449 lines
16 KiB
Python
449 lines
16 KiB
Python
"""Zep entity reader and filter service.
|
|
|
|
Reads nodes from a Zep graph and filters down to those that match a
|
|
predefined ontology of entity types.
|
|
"""
|
|
|
|
import time
|
|
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
|
from dataclasses import dataclass, field
|
|
|
|
from .graphiti_adapter import GraphitiAdapter
|
|
|
|
from ..config import Config
|
|
from ..utils.logger import get_logger
|
|
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
|
from ..utils.locale import t
|
|
|
|
# Logger shared by everything in this module.
logger = get_logger('mirofish.zep_entity_reader')

# Type variable that ties the return type of a callable handed to
# ``ZepEntityReader._call_with_retry`` to that method's own return type.
T = TypeVar('T')
|
|
|
|
|
|
@dataclass
class EntityNode:
    """A graph entity held in memory, optionally enriched with its neighbourhood.

    The five leading fields mirror a node read from the graph; the two list
    fields start empty and are filled in only when the reader enriches the
    entity with its connected edges and neighbouring nodes.
    """

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edges touching this entity, one dict per edge.
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Short summaries of the nodes reached through ``related_edges``.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return every field as a plain dict, in declaration order.

        The contained lists/dicts are the entity's own objects, not copies.
        """
        field_names = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {name: getattr(self, name) for name in field_names}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is neither ``Entity`` nor ``Node``.

        Returns:
            That label, or ``None`` when the node only carries default labels
            (or no labels at all).
        """
        return next(
            (label for label in self.labels if label not in ("Entity", "Node")),
            None,
        )
|
|
|
|
|
|
@dataclass
class FilteredEntities:
    """Outcome of one filter pass over a graph: the kept entities plus counts."""

    # Entities that passed the filter.
    entities: List["EntityNode"]
    # Distinct entity types observed among ``entities``.
    entity_types: Set[str]
    # Number of nodes read from the graph before filtering.
    total_count: int
    # Number of entities that were kept.
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Serialize to a plain dict.

        ``entity_types`` is emitted as a *sorted* list: iterating a ``set`` of
        strings yields an order that changes between interpreter runs (string
        hashing is seeded per process), which made serialized output unstable.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
|
|
|
|
|
|
class ZepEntityReader:
    """Read entities from a graph and filter them down to ontology-defined types.

    Capabilities:

    1. Read every node (and, when enrichment is requested, every edge) of a
       graph through the ``GraphitiAdapter`` client.
    2. Keep nodes whose labels include something other than the defaults
       ``Entity`` / ``Node``. Nodes with only default labels may first be
       classified with the owning project's ontology, when one is found.
    3. Optionally enrich each kept entity with its connected edges and a short
       summary of each neighbouring node.
    """

    # Labels that never identify an entity type on their own.
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """Create a reader backed by a new ``GraphitiAdapter``.

        Args:
            api_key: Accepted for backward compatibility only; it is not used
                (the adapter is constructed without arguments).
        """
        self.client = GraphitiAdapter()

    # ------------------------------------------------------------------
    # Retry helper
    # ------------------------------------------------------------------

    def _call_with_retry(
        self,
        func: Callable[[], "T"],
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0,
    ) -> "T":
        """Call ``func``, retrying with exponential backoff when it raises.

        Any ``Exception`` triggers a retry. The wait starts at
        ``initial_delay`` seconds and doubles after every failed attempt.

        Args:
            func: Zero-argument callable performing the request.
            operation_name: Label used in log output.
            max_retries: Total number of attempts (default 3). Values below 1
                are treated as 1 so ``func`` is always called at least once;
                previously such values skipped the call and raised ``None``.
            initial_delay: Seconds to wait before the first retry.

        Returns:
            The value returned by the first successful call of ``func``.

        Raises:
            Exception: Whatever the final attempt raised, once every attempt
                has failed.
        """
        attempts = max(1, max_retries)
        delay = initial_delay
        for attempt in range(1, attempts + 1):
            try:
                return func()
            except Exception as e:
                if attempt >= attempts:
                    logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=attempts, str=str(e)))
                    raise
                logger.warning(
                    t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt, str=str(e)[:100], delay=delay)
                )
                time.sleep(delay)
                delay *= 2  # exponential backoff

    # ------------------------------------------------------------------
    # Private conversion / enrichment helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _node_to_dict(node: Any) -> Dict[str, Any]:
        """Convert a node object returned by the client into a plain dict.

        Missing/falsy fields are normalised to empty values; the uuid is read
        from ``uuid_`` first and falls back to ``uuid``.
        """
        return {
            "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }

    @staticmethod
    def _edge_to_dict(edge: Any) -> Dict[str, Any]:
        """Convert an edge object returned by the client into a plain dict."""
        return {
            "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }

    @classmethod
    def _custom_labels(cls, labels: List[str]) -> List[str]:
        """Return ``labels`` without the default ``Entity`` / ``Node`` labels, order kept."""
        return [label for label in labels if label not in cls._DEFAULT_LABELS]

    @staticmethod
    def _related_edge_entry(edge: Dict[str, Any], outgoing: bool) -> Dict[str, Any]:
        """Describe ``edge`` from the point of view of one of its endpoints.

        Outgoing entries record the target node; incoming entries record the
        source node.
        """
        if outgoing:
            return {
                "direction": "outgoing",
                "edge_name": edge["name"],
                "fact": edge["fact"],
                "target_node_uuid": edge["target_node_uuid"],
            }
        return {
            "direction": "incoming",
            "edge_name": edge["name"],
            "fact": edge["fact"],
            "source_node_uuid": edge["source_node_uuid"],
        }

    @staticmethod
    def _neighbor_uuid(entry: Dict[str, Any]) -> Any:
        """Return the uuid at the far end of a related-edge entry."""
        if entry["direction"] == "outgoing":
            return entry["target_node_uuid"]
        return entry["source_node_uuid"]

    @classmethod
    def _index_edges_by_node(cls, edges: List[Dict[str, Any]]) -> Dict[Any, List[Dict[str, Any]]]:
        """Group edges by node uuid in a single pass.

        Each edge is listed as "outgoing" under its source node and as
        "incoming" under its target node. A self-loop is listed once, as
        outgoing. Within each node's list, entries keep the order of ``edges``.
        """
        index: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in edges:
            source = edge["source_node_uuid"]
            target = edge["target_node_uuid"]
            index.setdefault(source, []).append(cls._related_edge_entry(edge, outgoing=True))
            if target != source:
                index.setdefault(target, []).append(cls._related_edge_entry(edge, outgoing=False))
        return index

    @classmethod
    def _summarize_neighbors(
        cls,
        entries: List[Dict[str, Any]],
        node_map: Dict[Any, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Build basic info for each distinct neighbour named in ``entries``.

        Neighbours are listed in first-seen order. Uuids missing from
        ``node_map`` are skipped.
        """
        # dict.fromkeys de-duplicates while keeping insertion order.
        neighbor_uuids = dict.fromkeys(cls._neighbor_uuid(entry) for entry in entries)
        summaries = []
        for neighbor_uuid in neighbor_uuids:
            neighbor = node_map.get(neighbor_uuid)
            if neighbor is None:
                continue
            summaries.append({
                "uuid": neighbor["uuid"],
                "name": neighbor["name"],
                "labels": neighbor["labels"],
                "summary": neighbor.get("summary", ""),
            })
        return summaries

    @staticmethod
    def _lookup_project_ontology(graph_id: str):
        """Best-effort lookup of the ontology of the project that owns ``graph_id``.

        Returns:
            ``(ontology, classifier)``, where ``classifier`` is
            ``graph_builder._classify_entity_type``. Returns ``(None, None)``
            when no matching project with an ontology exists, or when the
            lookup fails. A failure is logged at debug level and otherwise
            ignored, because classification is optional.
        """
        try:
            from ..models.project import ProjectManager
            from .graph_builder import _classify_entity_type

            for project in ProjectManager.list_projects():
                if project.graph_id == graph_id and project.ontology:
                    return project.ontology, _classify_entity_type
        except Exception as e:
            logger.debug(f"Project ontology lookup failed for graph {graph_id}: {e}")
        return None, None

    @staticmethod
    def _apply_ontology_labels(
        nodes: List[Dict[str, Any]],
        ontology: Any,
        classify: Callable[..., str],
    ) -> None:
        """Prepend an ontology-derived type to nodes that only have default labels.

        Mutates ``nodes`` in place. A classifier result of ``"Entity"`` means
        "unclassified" and leaves the node unchanged.
        """
        for node in nodes:
            labels = node.get("labels", [])
            if ZepEntityReader._custom_labels(labels):
                continue
            entity_type = classify(node.get("name", ""), node.get("summary", ""), ontology)
            if entity_type != "Entity":
                node["labels"] = [entity_type] + labels

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every node in the graph as a plain dict (via ``fetch_all_nodes``).

        Args:
            graph_id: Graph identifier.

        Returns:
            A list of node dicts with keys uuid/name/labels/summary/attributes.
        """
        logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))
        nodes_data = [self._node_to_dict(node) for node in fetch_all_nodes(self.client, graph_id)]
        logger.info(t("log.zep_entity_reader.m004", len=len(nodes_data)))
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every edge in the graph as a plain dict (via ``fetch_all_edges``).

        Args:
            graph_id: Graph identifier.

        Returns:
            A list of edge dicts with keys uuid/name/fact/source_node_uuid/
            target_node_uuid/attributes.
        """
        logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))
        edges_data = [self._edge_to_dict(edge) for edge in fetch_all_edges(self.client, graph_id)]
        logger.info(t("log.zep_entity_reader.m006", len=len(edges_data)))
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Return every edge connected to the given node, retrying on failure.

        Args:
            node_uuid: Node UUID.

        Returns:
            A list of edge dicts. On failure, after the retries are exhausted,
            a warning is logged and an empty list is returned.
        """
        try:
            edges = self._call_with_retry(
                func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
                operation_name=f"获取节点边(node={node_uuid[:8]}...)"
            )
            return [self._edge_to_dict(edge) for edge in edges]
        except Exception as e:
            logger.warning(t("log.zep_entity_reader.m007", node_uuid=node_uuid, str=str(e)))
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> "FilteredEntities":
        """Filter graph nodes down to entities matching the predefined ontology types.

        Filtering rules:
            - Nodes with only ``Entity`` / ``Node`` labels are skipped. Before
              filtering, such nodes may receive a type from the owning
              project's ontology.
            - Every other node is kept. When ``defined_entity_types`` is
              non-empty, a node is kept only if one of its custom labels is in
              that list.

        Args:
            graph_id: Graph identifier.
            defined_entity_types: Optional allow-list of entity types.
            enrich_with_edges: When ``True``, populate ``related_edges`` and
                ``related_nodes`` on every kept entity.

        Returns:
            A ``FilteredEntities`` summary.
        """
        logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))

        ontology, classify = self._lookup_project_ontology(graph_id)

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        if ontology:
            self._apply_ontology_labels(all_nodes, ontology, classify)

        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

        node_map = {n["uuid"]: n for n in all_nodes}
        # Build the per-node edge index once. The previous version scanned
        # every edge for every kept entity, which is O(nodes * edges).
        edges_by_node = self._index_edges_by_node(all_edges)
        allowed_types = set(defined_entity_types) if defined_entity_types else None

        filtered_entities = []
        entity_types_found: Set[str] = set()

        for node in all_nodes:
            labels = node.get("labels", [])
            custom_labels = self._custom_labels(labels)
            if not custom_labels:
                continue  # only default labels

            if allowed_types is not None:
                matching_labels = [label for label in custom_labels if label in allowed_types]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]

            entity_types_found.add(entity_type)

            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )

            if enrich_with_edges:
                entries = edges_by_node.get(node["uuid"], [])
                entity.related_edges = list(entries)
                entity.related_nodes = self._summarize_neighbors(entries, node_map)

            filtered_entities.append(entity)

        logger.info(t("log.zep_entity_reader.m009", total_count=total_count, len=len(filtered_entities), entity_types_found=entity_types_found))

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional["EntityNode"]:
        """Fetch one entity together with its edges and neighbouring nodes.

        The node fetch is retried via ``_call_with_retry``. An edge whose
        source is this entity is reported as outgoing; any other returned edge
        is reported as incoming.

        Args:
            graph_id: Graph identifier, used to resolve neighbour summaries.
            entity_uuid: Entity UUID.

        Returns:
            ``EntityNode``, or ``None`` when the node is not found or any step
            fails (failures are logged as errors).
        """
        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
                operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
            )
            if not node:
                return None

            edges = self.get_node_edges(entity_uuid)
            node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}

            entries = [
                self._related_edge_entry(edge, outgoing=edge["source_node_uuid"] == entity_uuid)
                for edge in edges
            ]

            return EntityNode(
                **self._node_to_dict(node),
                related_edges=entries,
                related_nodes=self._summarize_neighbors(entries, node_map),
            )
        except Exception as e:
            logger.error(t("log.zep_entity_reader.m010", entity_uuid=entity_uuid, str=str(e)))
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List["EntityNode"]:
        """Return every entity of the given type.

        Args:
            graph_id: Graph identifier.
            entity_type: Entity type label (e.g. ``Student``, ``PublicFigure``).
            enrich_with_edges: When ``True``, populate related edges/nodes.

        Returns:
            A list of matching ``EntityNode`` instances.
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities
|
|
|
|
|