MicroFish/backend/app/services/zep_entity_reader.py

449 lines
16 KiB
Python

"""Zep entity reader and filter service.
Reads nodes from a Zep graph and filters down to those that match a
predefined ontology of entity types.
"""
import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field
from .graphiti_adapter import GraphitiAdapter
from ..config import Config
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from ..utils.locale import t
# Module-level logger, obtained from the project's get_logger helper.
logger = get_logger('mirofish.zep_entity_reader')
# Generic type variable: lets _call_with_retry declare that it returns
# whatever type the wrapped zero-argument callable returns.
T = TypeVar('T')
@dataclass
class EntityNode:
    """A graph entity node held in memory, plus optional relationship context."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edge records touching this entity; defaults to a fresh empty list.
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Records for nodes reached through those edges; defaults to a fresh empty list.
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dict keyed by field name.

        The values are the very objects stored on the instance (no copies
        are made), so mutating a returned list also mutates the node.
        """
        field_names = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in field_names}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label other than ``Entity``/``Node``, else ``None``."""
        return next(
            (label for label in self.labels if label not in ("Entity", "Node")),
            None,
        )
@dataclass
class FilteredEntities:
    """Result of a filter pass over the graph: matching entities plus counts."""

    # Entities that passed the filter.
    entities: List[EntityNode]
    # Distinct entity-type labels chosen for the entities above.
    entity_types: Set[str]
    # Number of nodes read from the graph before filtering.
    total_count: int
    # Number of entities kept after filtering.
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict form of the result.

        ``entity_types`` is emitted as a sorted list. A set has no stable
        iteration order for strings across interpreter runs, so sorting
        keeps the serialized output reproducible.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """Read entities from a Zep graph and filter to ontology-defined types.

    Capabilities:
        1. Read all nodes from the graph.
        2. Keep nodes whose labels include something other than the default
           ``Entity`` / ``Node`` labels.
        3. Optionally enrich each entity with its connected edges and
           neighboring nodes.
    """

    # Labels that do not identify an ontology type on their own.
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """Create a reader backed by a ``GraphitiAdapter`` client.

        Args:
            api_key: Currently unused; the adapter is constructed without
                arguments.
        """
        self.client = GraphitiAdapter()

    def _call_with_retry(
        self,
        func: Callable[[], T],
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0
    ) -> T:
        """Call ``func`` and retry with exponential backoff when it raises.

        Args:
            func: A zero-argument callable performing the request.
            operation_name: Operation label used in log output.
            max_retries: Total number of attempts (default 3). Values below 1
                are treated as 1 so that ``func`` is always tried once.
            initial_delay: Seconds to sleep after the first failure; the
                delay doubles after every further failure.

        Returns:
            The return value of ``func``.

        Raises:
            Exception: Whatever ``func`` raised on the final attempt.
        """
        # Guard against max_retries <= 0: without this the loop body never
        # ran and the method ended in ``raise None`` (a TypeError).
        attempts = max(1, max_retries)
        delay = initial_delay

        # Every attempt except the last one: log, back off, try again.
        for attempt in range(1, attempts):
            try:
                return func()
            except Exception as e:
                logger.warning(
                    t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt, str=str(e)[:100], delay=delay)
                )
                time.sleep(delay)
                delay *= 2  # exponential backoff

        # Final attempt: a failure here is reported and propagated.
        try:
            return func()
        except Exception as e:
            logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=attempts, str=str(e)))
            raise

    @staticmethod
    def _node_to_dict(node: Any) -> Dict[str, Any]:
        """Convert a client node object into the plain dict used by this class."""
        return {
            # Prefer ``uuid_`` and fall back to ``uuid`` when it is missing/empty.
            "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
            "name": node.name or "",
            "labels": node.labels or [],
            "summary": node.summary or "",
            "attributes": node.attributes or {},
        }

    @staticmethod
    def _edge_to_dict(edge: Any) -> Dict[str, Any]:
        """Convert a client edge object into the plain dict used by this class."""
        return {
            # Prefer ``uuid_`` and fall back to ``uuid`` when it is missing/empty.
            "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
            "name": edge.name or "",
            "fact": edge.fact or "",
            "source_node_uuid": edge.source_node_uuid,
            "target_node_uuid": edge.target_node_uuid,
            "attributes": edge.attributes or {},
        }

    @staticmethod
    def _outgoing_entry(edge: Dict[str, Any]) -> Dict[str, Any]:
        """Build the related-edge record for an edge leaving the entity."""
        return {
            "direction": "outgoing",
            "edge_name": edge["name"],
            "fact": edge["fact"],
            "target_node_uuid": edge["target_node_uuid"],
        }

    @staticmethod
    def _incoming_entry(edge: Dict[str, Any]) -> Dict[str, Any]:
        """Build the related-edge record for an edge arriving at the entity."""
        return {
            "direction": "incoming",
            "edge_name": edge["name"],
            "fact": edge["fact"],
            "source_node_uuid": edge["source_node_uuid"],
        }

    @classmethod
    def _index_edges_by_node(cls, edges: List[Dict[str, Any]]) -> Dict[Any, List[Dict[str, Any]]]:
        """Group related-edge records by the uuid of the node they touch.

        One pass over ``edges`` replaces a full edge scan per node. Records
        keep the order of ``edges``. A self-loop is recorded once, as
        outgoing, matching the source-before-target check used per node.
        """
        index: Dict[Any, List[Dict[str, Any]]] = {}
        for edge in edges:
            source_uuid = edge["source_node_uuid"]
            target_uuid = edge["target_node_uuid"]
            index.setdefault(source_uuid, []).append(cls._outgoing_entry(edge))
            if target_uuid != source_uuid:
                index.setdefault(target_uuid, []).append(cls._incoming_entry(edge))
        return index

    @staticmethod
    def _neighbor_summaries(
        related_edges: List[Dict[str, Any]],
        node_map: Dict[Any, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """Return basic info for each distinct node at the far end of ``related_edges``.

        Neighbors are listed once each, in first-seen order, and uuids that
        are missing from ``node_map`` are skipped.
        """
        # dict.fromkeys de-duplicates while keeping insertion order, which
        # makes the output order deterministic (a set would not).
        neighbor_uuids = dict.fromkeys(
            entry["target_node_uuid"] if entry["direction"] == "outgoing" else entry["source_node_uuid"]
            for entry in related_edges
        )
        summaries = []
        for neighbor_uuid in neighbor_uuids:
            if neighbor_uuid not in node_map:
                continue
            neighbor = node_map[neighbor_uuid]
            summaries.append({
                "uuid": neighbor["uuid"],
                "name": neighbor["name"],
                "labels": neighbor["labels"],
                "summary": neighbor.get("summary", ""),
            })
        return summaries

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every node in the graph as a list of plain dicts.

        Args:
            graph_id: Graph identifier.

        Returns:
            A list of node dicts with ``uuid``, ``name``, ``labels``,
            ``summary`` and ``attributes`` keys.
        """
        logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))
        nodes_data = [self._node_to_dict(node) for node in fetch_all_nodes(self.client, graph_id)]
        logger.info(t("log.zep_entity_reader.m004", len=len(nodes_data)))
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every edge in the graph as a list of plain dicts.

        Args:
            graph_id: Graph identifier.

        Returns:
            A list of edge dicts with ``uuid``, ``name``, ``fact``,
            ``source_node_uuid``, ``target_node_uuid`` and ``attributes`` keys.
        """
        logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))
        edges_data = [self._edge_to_dict(edge) for edge in fetch_all_edges(self.client, graph_id)]
        logger.info(t("log.zep_entity_reader.m006", len=len(edges_data)))
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Return every edge connected to the given node (with retry).

        This is best-effort: any failure is logged as a warning and an empty
        list is returned instead of raising.

        Args:
            node_uuid: Node UUID.

        Returns:
            A list of edge dicts, or ``[]`` when the lookup fails.
        """
        try:
            edges = self._call_with_retry(
                func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
                # Runtime log label (Chinese), kept exactly as emitted today.
                operation_name=f"获取节点边(node={node_uuid[:8]}...)"
            )
            return [self._edge_to_dict(edge) for edge in edges]
        except Exception as e:
            logger.warning(t("log.zep_entity_reader.m007", node_uuid=node_uuid, str=str(e)))
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """Filter nodes down to entities matching the predefined ontology types.

        Filtering rules:
            - Skip nodes whose labels are only ``Entity`` / ``Node``.
            - Keep nodes whose labels include anything else.
            - When a project with this ``graph_id`` has an ontology, nodes
              with only default labels are first classified against it and
              gain the resulting type label (unless it is ``Entity``).

        Args:
            graph_id: Graph identifier.
            defined_entity_types: Optional allow-list; when non-empty, only
                nodes carrying one of these labels are kept.
            enrich_with_edges: When ``True``, populate ``related_edges`` and
                ``related_nodes`` on each entity.

        Returns:
            A ``FilteredEntities`` summary.
        """
        logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))

        # Best-effort ontology lookup. The imports stay local to this call,
        # and any failure simply disables ontology-based classification.
        ontology = None
        classify = None
        try:
            from ..models.project import ProjectManager
            from .graph_builder import _classify_entity_type as classify
            for project in ProjectManager.list_projects():
                if project.graph_id == graph_id and project.ontology:
                    ontology = project.ontology
                    break
        except Exception as e:
            # Previously swallowed silently; surface it so that missing type
            # labels can be traced back to the failed lookup.
            ontology = None
            logger.warning(
                f"Ontology lookup failed for graph {graph_id}; "
                f"continuing without ontology-based classification: {e}"
            )

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # Give uncategorized nodes a type label derived from the ontology.
        if ontology:
            for node in all_nodes:
                labels = node.get("labels", [])
                if any(label not in self._DEFAULT_LABELS for label in labels):
                    continue
                entity_type = classify(
                    node.get("name", ""), node.get("summary", ""), ontology
                )
                if entity_type != "Entity":
                    node["labels"] = [entity_type] + labels

        # Edge data is only needed for enrichment; index it once so each
        # entity is a dictionary lookup instead of a scan over every edge.
        if enrich_with_edges:
            edges_by_node = self._index_edges_by_node(self.get_all_edges(graph_id))
            node_map = {n["uuid"]: n for n in all_nodes}
        else:
            edges_by_node = {}
            node_map = {}

        # An empty or missing allow-list means "accept any custom label".
        allowed_types = set(defined_entity_types) if defined_entity_types else None

        filtered_entities = []
        entity_types_found = set()
        for node in all_nodes:
            labels = node.get("labels", [])
            custom_labels = [label for label in labels if label not in self._DEFAULT_LABELS]
            if not custom_labels:
                # Only default labels: uncategorized, skip.
                continue

            if allowed_types is not None:
                matching_labels = [label for label in custom_labels if label in allowed_types]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]
            entity_types_found.add(entity_type)

            entity = EntityNode(
                uuid=node["uuid"],
                name=node["name"],
                labels=labels,
                summary=node["summary"],
                attributes=node["attributes"],
            )
            if enrich_with_edges:
                # Copy the list so entities never share one list object.
                entity.related_edges = list(edges_by_node.get(node["uuid"], ()))
                entity.related_nodes = self._neighbor_summaries(entity.related_edges, node_map)
            filtered_entities.append(entity)

        logger.info(t("log.zep_entity_reader.m009", total_count=total_count, len=len(filtered_entities), entity_types_found=entity_types_found))
        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """Fetch a single entity with its edges and neighbors (with retry).

        Any failure is logged as an error and ``None`` is returned.

        Args:
            graph_id: Graph identifier.
            entity_uuid: Entity UUID.

        Returns:
            ``EntityNode`` or ``None`` if the node is not found or an error
            occurs.
        """
        try:
            node = self._call_with_retry(
                func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
                # Runtime log label (Chinese), kept exactly as emitted today.
                operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
            )
            if not node:
                return None

            # Edges here come from a per-node lookup, so an edge that does
            # not start at this entity is treated as arriving at it.
            related_edges = [
                self._outgoing_entry(edge)
                if edge["source_node_uuid"] == entity_uuid
                else self._incoming_entry(edge)
                for edge in self.get_node_edges(entity_uuid)
            ]

            # The full node listing is only needed to describe neighbors, so
            # skip that graph-wide fetch when the entity has no edges.
            if related_edges:
                node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}
                related_nodes = self._neighbor_summaries(related_edges, node_map)
            else:
                related_nodes = []

            return EntityNode(
                related_edges=related_edges,
                related_nodes=related_nodes,
                **self._node_to_dict(node),
            )
        except Exception as e:
            logger.error(t("log.zep_entity_reader.m010", entity_uuid=entity_uuid, str=str(e)))
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """Return every entity matching the given type.

        Args:
            graph_id: Graph identifier.
            entity_type: Entity type label (e.g. ``Student``, ``PublicFigure``).
            enrich_with_edges: When ``True``, populate related edges/nodes.

        Returns:
            A list of matching ``EntityNode`` instances.
        """
        return self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges,
        ).entities