MicroFish/backend/app/services/zep_entity_reader.py

386 lines
13 KiB
Python

"""
Zep entity read and filter service
Reads nodes from the Zep graph and keeps only the nodes that match predefined entity types
"""
import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
# Module-level logger shared by everything in this file
logger = get_logger('mirofish.zep_entity_reader')
# Generic return type: lets _call_with_retry be annotated as returning
# whatever the wrapped callable returns
T = TypeVar('T')
@dataclass
class EntityNode:
    """A graph node plus the edges and neighbouring nodes attached to it."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Edge records connected to this node (empty until enriched)
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Basic info about the nodes on the other end of those edges
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dict; values are the same objects, not copies."""
        exported_fields = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in exported_fields}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not a default ("Entity"/"Node"), else None."""
        custom_labels = (
            label for label in self.labels if label not in ("Entity", "Node")
        )
        return next(custom_labels, None)
@dataclass
class FilteredEntities:
    """Result of an entity filtering pass over a graph."""

    # Entities that passed the filter
    entities: List["EntityNode"]
    # Distinct entity types seen among the kept entities
    entity_types: Set[str]
    # Number of nodes examined before filtering
    total_count: int
    # Number of nodes kept after filtering
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """
        Return a plain-dict representation of the result.

        Returns:
            Dict with each entity converted via its own to_dict(), the entity
            types as a sorted list, and the two counters.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            # Sets iterate in arbitrary order; sorting keeps the serialised
            # output identical from one run to the next.
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
class ZepEntityReader:
    """
    Entity read and filter service
    Main features:
    1. Read all nodes (and edges) of a graph through the graph backend
    2. Keep only nodes carrying a label other than the defaults
       ("Entity" / "Node"), optionally restricted to a caller-supplied list
    3. Attach related edges and basic info about neighbouring nodes to
       each kept entity
    """

    # Labels that do not identify an entity type on their own
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """
        Args:
            api_key: currently unused; the backend is obtained from
                get_graph_backend()
        """
        self._graph = get_graph_backend()

    def _call_with_retry(
        self,
        func: "Callable[[], T]",
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0
    ) -> "T":
        """
        Call func, retrying with exponential backoff when it raises
        Args:
            func: function to execute (a lambda or callable with no arguments)
            operation_name: operation name used in log messages
            max_retries: total number of attempts (default 3); must be >= 1
            initial_delay: delay in seconds before the second attempt; it is
                doubled after every further failure
        Returns:
            Whatever func returns on its first successful call
        Raises:
            ValueError: if max_retries is smaller than 1
            Exception: the exception raised by the final failed attempt
        """
        if max_retries < 1:
            # Without this guard no attempt would run and there would be no
            # exception to re-raise.
            raise ValueError(f"max_retries must be at least 1, got {max_retries}")

        delay = initial_delay
        for attempt in range(1, max_retries + 1):
            try:
                return func()
            except Exception as e:
                if attempt == max_retries:
                    logger.error(f"Zep {operation_name} still failing after {max_retries} attempts: {str(e)}")
                    raise
                logger.warning(
                    f"Zep {operation_name} attempt {attempt} failed: {str(e)[:100]}, "
                    f"retrying in {delay:.1f}s..."
                )
                time.sleep(delay)
                delay *= 2  # Exponential backoff

    @staticmethod
    def _index_edges_by_node(
        edges: List[Dict[str, Any]]
    ) -> Dict[str, List[Dict[str, Any]]]:
        """Group edges by the UUID of each node they touch, keeping input order."""
        edges_by_node: Dict[str, List[Dict[str, Any]]] = {}
        for edge in edges:
            source_uuid = edge["source_node_uuid"]
            target_uuid = edge["target_node_uuid"]
            edges_by_node.setdefault(source_uuid, []).append(edge)
            # A self-loop is listed once so it is reported only as outgoing.
            if target_uuid != source_uuid:
                edges_by_node.setdefault(target_uuid, []).append(edge)
        return edges_by_node

    @staticmethod
    def _describe_edges(
        node_uuid: str,
        edges: List[Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Turn the edges touching node_uuid into compact edge records
        An edge whose source is node_uuid is "outgoing"; every other edge is
        treated as "incoming", so callers must pass only edges that touch
        the node.
        """
        records = []
        for edge in edges:
            if edge["source_node_uuid"] == node_uuid:
                records.append({
                    "direction": "outgoing",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "target_node_uuid": edge["target_node_uuid"],
                })
            else:
                records.append({
                    "direction": "incoming",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "source_node_uuid": edge["source_node_uuid"],
                })
        return records

    @staticmethod
    def _related_node_summaries(
        related_edges: List[Dict[str, Any]],
        node_map: Dict[str, Dict[str, Any]]
    ) -> List[Dict[str, Any]]:
        """
        Build basic info for the nodes on the other end of related_edges
        Each neighbour appears once, in the order it is first referenced;
        neighbours missing from node_map are skipped.
        """
        summaries = []
        seen_uuids = set()
        for record in related_edges:
            if record["direction"] == "outgoing":
                other_uuid = record["target_node_uuid"]
            else:
                other_uuid = record["source_node_uuid"]
            if other_uuid in seen_uuids:
                continue
            seen_uuids.add(other_uuid)
            related_node = node_map.get(other_uuid)
            if related_node is None:
                continue
            summaries.append({
                "uuid": related_node["uuid"],
                "name": related_node.get("name", ""),
                "labels": related_node.get("labels", []),
                "summary": related_node.get("summary", ""),
            })
        return summaries

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all nodes in the graph from the graph backend
        Args:
            graph_id: graph ID
        Returns:
            Node list
        """
        logger.info(f"Fetching all nodes for graph {graph_id}...")
        nodes_data = self._graph.get_all_nodes(graph_id)
        logger.info(f"Fetched {len(nodes_data)} nodes")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all edges in the graph from the graph backend
        Args:
            graph_id: graph ID
        Returns:
            Edge list
        """
        logger.info(f"Fetching all edges for graph {graph_id}...")
        edges_data = self._graph.get_all_edges(graph_id)
        logger.info(f"Fetched {len(edges_data)} edges")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Get all edges related to the specified node (best effort)
        The backend is called once; if it raises, the failure is logged as a
        warning and an empty list is returned.
        Args:
            node_uuid: node UUID
        Returns:
            Edge list, or [] when the backend call fails
        """
        try:
            return self._graph.get_node_edges(node_uuid)
        except Exception as e:
            logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> "FilteredEntities":
        """
        Keep only the nodes that match predefined entity types
        Filter logic:
        - A node whose labels are only "Entity" / "Node" (or empty) is skipped
        - A node with any other label is kept; its entity type is the first
          such label
        - If defined_entity_types is non-empty, the node is kept only when
          one of its custom labels is in that list, and the first matching
          label becomes its entity type
        Args:
            graph_id: graph ID
            defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept)
            enrich_with_edges: whether to attach related edges and related
                node info to each kept entity
        Returns:
            FilteredEntities: filtered entity collection
        """
        logger.info(f"Starting entity filtering for graph {graph_id}...")

        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)
        # Edges are only needed when entities are enriched
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

        node_map = {n["uuid"]: n for n in all_nodes}
        # Index edges once so each entity only looks at its own edges
        # instead of rescanning the whole edge list.
        edges_by_node = self._index_edges_by_node(all_edges)

        filtered_entities: List[EntityNode] = []
        entity_types_found: Set[str] = set()

        for node in all_nodes:
            labels = node.get("labels") or []
            custom_labels = [
                label for label in labels if label not in self._DEFAULT_LABELS
            ]
            if not custom_labels:
                # Only default labels; skip
                continue

            if defined_entity_types:
                matching_labels = [
                    label for label in custom_labels
                    if label in defined_entity_types
                ]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]
            entity_types_found.add(entity_type)

            # Optional fields use the same defaults as get_entity_with_context
            entity = EntityNode(
                uuid=node["uuid"],
                name=node.get("name", ""),
                labels=labels,
                summary=node.get("summary", ""),
                attributes=node.get("attributes", {}),
            )

            if enrich_with_edges:
                entity.related_edges = self._describe_edges(
                    entity.uuid, edges_by_node.get(entity.uuid, [])
                )
                entity.related_nodes = self._related_node_summaries(
                    entity.related_edges, node_map
                )

            filtered_entities.append(entity)

        logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, "
                    f"entity types: {entity_types_found}")

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> "Optional[EntityNode]":
        """
        Get a single entity and its context (edges and related nodes)
        Any exception raised while building the result is logged and turned
        into a None return value.
        Args:
            graph_id: graph ID
            entity_uuid: entity UUID
        Returns:
            EntityNode, or None when the node is not found or an error occurs
        """
        try:
            node = self._graph.get_node(entity_uuid)
            if not node:
                return None

            edges = self.get_node_edges(entity_uuid)
            related_edges = self._describe_edges(entity_uuid, edges)

            # The full node list is only needed to resolve neighbours, so
            # skip fetching it when the entity has no edges.
            node_map: Dict[str, Dict[str, Any]] = {}
            if related_edges:
                node_map = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}
            related_nodes = self._related_node_summaries(related_edges, node_map)

            return EntityNode(
                uuid=node.get("uuid", ""),
                name=node.get("name", ""),
                labels=node.get("labels", []),
                summary=node.get("summary", ""),
                attributes=node.get("attributes", {}),
                related_edges=related_edges,
                related_nodes=related_nodes,
            )
        except Exception as e:
            logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> "List[EntityNode]":
        """
        Get all entities of a specified type
        Args:
            graph_id: graph ID
            entity_type: entity type (e.g. "Student", "PublicFigure")
            enrich_with_edges: whether to fetch related edge info
        Returns:
            Entity list
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities