386 lines
13 KiB
Python
386 lines
13 KiB
Python
"""
Zep entity read and filter service.

Reads nodes from the Zep graph and keeps only the nodes that carry a
predefined entity type label (i.e. a label other than the default ones).
"""
|
|
|
|
import time
|
|
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
|
from dataclasses import dataclass, field
|
|
|
|
from ..config import Config
|
|
from ..graph import get_graph_backend
|
|
from ..utils.logger import get_logger
|
|
|
|
# Module-level logger shared by every function and method in this file.
logger = get_logger('mirofish.zep_entity_reader')

# Generic return type: lets ZepEntityReader._call_with_retry declare that it
# returns whatever the wrapped zero-argument callable returns.
T = TypeVar('T')
|
|
|
|
|
|
@dataclass
class EntityNode:
    """A graph node together with the relations attached to it."""

    uuid: str
    name: str
    labels: List[str]
    summary: str
    attributes: Dict[str, Any]
    # Summaries of the edges touching this node (empty until enriched)
    related_edges: List[Dict[str, Any]] = field(default_factory=list)
    # Summaries of the nodes at the other end of those edges
    related_nodes: List[Dict[str, Any]] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the node as a plain dict, one key per field, in field order.

        The values are the field objects themselves, not copies.
        """
        exported = (
            "uuid",
            "name",
            "labels",
            "summary",
            "attributes",
            "related_edges",
            "related_nodes",
        )
        return {key: getattr(self, key) for key in exported}

    def get_entity_type(self) -> Optional[str]:
        """Return the first label that is not a default one ("Entity"/"Node").

        Returns None when the node carries only default labels or none at all.
        """
        custom = (tag for tag in self.labels if tag not in ("Entity", "Node"))
        return next(custom, None)
|
|
|
|
|
|
@dataclass
class FilteredEntities:
    """Result of an entity filtering pass over a graph."""

    # Entities that passed the filter
    entities: List["EntityNode"]
    # Distinct entity type labels found among the kept entities
    entity_types: Set[str]
    # Number of nodes examined before filtering
    total_count: int
    # Number of entities kept after filtering
    filtered_count: int

    def to_dict(self) -> Dict[str, Any]:
        """Return a plain-dict view of the collection.

        Each entity is serialised through its own ``to_dict()``.
        ``entity_types`` is emitted as a sorted list: a set has no stable
        iteration order, so sorting keeps the serialised output reproducible
        across runs.
        """
        return {
            "entities": [entity.to_dict() for entity in self.entities],
            "entity_types": sorted(self.entity_types),
            "total_count": self.total_count,
            "filtered_count": self.filtered_count,
        }
|
|
|
|
|
|
class ZepEntityReader:
    """
    Zep entity read and filter service

    Main features:
    1. Read all nodes from the graph backend
    2. Keep only nodes that carry a label beyond the default "Entity"/"Node" labels
    3. Attach related edges and neighbouring node info to each kept entity
    """

    # Labels that do not identify an entity type on their own.
    _DEFAULT_LABELS = ("Entity", "Node")

    def __init__(self, api_key: Optional[str] = None):
        """
        Create a reader bound to the configured graph backend.

        Args:
            api_key: currently unused; the backend is obtained from
                get_graph_backend(), which is called without arguments.
        """
        self._graph = get_graph_backend()

    def _call_with_retry(
        self,
        func: Callable[[], T],
        operation_name: str,
        max_retries: int = 3,
        initial_delay: float = 2.0
    ) -> T:
        """
        Run a zero-argument callable, retrying with exponential backoff on failure.

        Args:
            func: function to execute (a lambda or callable with no arguments)
            operation_name: operation name for logging
            max_retries: total number of attempts (default 3); values below 1
                are treated as 1 so that func is always tried at least once
            initial_delay: delay in seconds before the second attempt; it is
                doubled after every further failure

        Returns:
            Whatever func returns on its first successful call

        Raises:
            Exception: the exception raised by the final failed attempt
        """
        # With zero attempts the loop would never run and there would be no
        # exception to re-raise, so always allow at least one attempt.
        attempts = max(1, max_retries)
        delay = initial_delay

        for attempt in range(attempts):
            try:
                return func()
            except Exception as e:
                if attempt >= attempts - 1:
                    logger.error(f"Zep {operation_name} still failing after {attempts} attempts: {str(e)}")
                    raise
                logger.warning(
                    f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, "
                    f"retrying in {delay:.1f}s..."
                )
                time.sleep(delay)
                delay *= 2  # Exponential backoff

    @staticmethod
    def _summarize_edges(node_uuid: str, edges: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """Describe each edge from the point of view of node_uuid.

        An edge whose source is node_uuid is reported as "outgoing" (with the
        target UUID); every other edge is reported as "incoming" (with the
        source UUID). Callers are expected to pass only edges touching the node.
        """
        summaries = []
        for edge in edges:
            if edge["source_node_uuid"] == node_uuid:
                summaries.append({
                    "direction": "outgoing",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "target_node_uuid": edge["target_node_uuid"],
                })
            else:
                summaries.append({
                    "direction": "incoming",
                    "edge_name": edge["name"],
                    "fact": edge["fact"],
                    "source_node_uuid": edge["source_node_uuid"],
                })
        return summaries

    @staticmethod
    def _summarize_neighbors(
        related_edges: List[Dict[str, Any]],
        node_map: Dict[str, Dict[str, Any]],
    ) -> List[Dict[str, Any]]:
        """Return basic info for the nodes at the far end of related_edges.

        Each neighbour appears once, in order of first appearance in
        related_edges; neighbours missing from node_map are skipped.
        """
        neighbors = []
        seen = set()
        for relation in related_edges:
            if relation["direction"] == "outgoing":
                other_uuid = relation["target_node_uuid"]
            else:
                other_uuid = relation["source_node_uuid"]

            if other_uuid in seen:
                continue
            seen.add(other_uuid)

            other = node_map.get(other_uuid)
            if other is None:
                continue
            neighbors.append({
                "uuid": other["uuid"],
                "name": other.get("name", ""),
                "labels": other.get("labels", []),
                "summary": other.get("summary", ""),
            })
        return neighbors

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all nodes in the graph (delegates to the graph backend)

        Args:
            graph_id: graph ID

        Returns:
            Node list
        """
        logger.info(f"Fetching all nodes for graph {graph_id}...")

        nodes_data = self._graph.get_all_nodes(graph_id)
        logger.info(f"Fetched {len(nodes_data)} nodes")
        return nodes_data

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """
        Get all edges in the graph (delegates to the graph backend)

        Args:
            graph_id: graph ID

        Returns:
            Edge list
        """
        logger.info(f"Fetching all edges for graph {graph_id}...")

        edges_data = self._graph.get_all_edges(graph_id)
        logger.info(f"Fetched {len(edges_data)} edges")
        return edges_data

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """
        Get all edges related to the specified node (best effort)

        Any backend failure is logged as a warning and reported as "no edges";
        this method itself does not retry.

        Args:
            node_uuid: node UUID

        Returns:
            Edge list, or an empty list if the backend call failed
        """
        try:
            return self._graph.get_node_edges(node_uuid)
        except Exception as e:
            logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
            return []

    def filter_defined_entities(
        self,
        graph_id: str,
        defined_entity_types: Optional[List[str]] = None,
        enrich_with_edges: bool = True
    ) -> FilteredEntities:
        """
        Keep only the nodes that match predefined entity types

        Filter logic:
        - A node whose labels are only "Entity"/"Node" (or empty) is skipped
        - A node with any other label is kept; if defined_entity_types is
          given, at least one of its labels must be in that list

        Args:
            graph_id: graph ID
            defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept)
            enrich_with_edges: whether to fetch related edge info for each entity

        Returns:
            FilteredEntities: filtered entity collection
        """
        logger.info(f"Starting entity filtering for graph {graph_id}...")

        # Get all nodes
        all_nodes = self.get_all_nodes(graph_id)
        total_count = len(all_nodes)

        # Get all edges (for relation lookup)
        all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []

        # Build UUID-to-node mapping
        node_map = {n["uuid"]: n for n in all_nodes}

        # Index edges by endpoint once, so each entity looks up its own edges
        # instead of rescanning the whole edge list. Edge order is preserved
        # per node; a self-loop is stored once and reported as outgoing.
        edges_by_node: Dict[str, List[Dict[str, Any]]] = {}
        for edge in all_edges:
            source_uuid = edge["source_node_uuid"]
            target_uuid = edge["target_node_uuid"]
            edges_by_node.setdefault(source_uuid, []).append(edge)
            if target_uuid != source_uuid:
                edges_by_node.setdefault(target_uuid, []).append(edge)

        # Filter matching entities
        filtered_entities = []
        entity_types_found = set()

        for node in all_nodes:
            labels = node.get("labels") or []

            # Labels must contain at least one label other than the defaults
            custom_labels = [label for label in labels if label not in self._DEFAULT_LABELS]
            if not custom_labels:
                continue

            # If predefined types are specified, check for a match
            if defined_entity_types:
                matching_labels = [label for label in custom_labels if label in defined_entity_types]
                if not matching_labels:
                    continue
                entity_type = matching_labels[0]
            else:
                entity_type = custom_labels[0]

            entity_types_found.add(entity_type)

            # Optional fields fall back to empty defaults, matching
            # get_entity_with_context.
            entity = EntityNode(
                uuid=node["uuid"],
                name=node.get("name", ""),
                labels=labels,
                summary=node.get("summary", ""),
                attributes=node.get("attributes", {}),
            )

            # Attach related edges and nodes
            if enrich_with_edges:
                entity.related_edges = self._summarize_edges(
                    node["uuid"], edges_by_node.get(node["uuid"], [])
                )
                entity.related_nodes = self._summarize_neighbors(entity.related_edges, node_map)

            filtered_entities.append(entity)

        logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, "
                    f"entity types: {entity_types_found}")

        return FilteredEntities(
            entities=filtered_entities,
            entity_types=entity_types_found,
            total_count=total_count,
            filtered_count=len(filtered_entities),
        )

    def get_entity_with_context(
        self,
        graph_id: str,
        entity_uuid: str
    ) -> Optional[EntityNode]:
        """
        Get a single entity and its full context (edges and related nodes)

        Any exception raised while building the result is logged and reported
        as None; this method itself does not retry.

        Args:
            graph_id: graph ID
            entity_uuid: entity UUID

        Returns:
            EntityNode, or None if the node is missing or an error occurred
        """
        try:
            node = self._graph.get_node(entity_uuid)

            if not node:
                return None

            # Get the node's edges
            edges = self.get_node_edges(entity_uuid)
            related_edges = self._summarize_edges(entity_uuid, edges)

            # The full node list is only needed to resolve neighbours, so skip
            # the fetch when the entity has no edges.
            node_map = {}
            if related_edges:
                all_nodes = self.get_all_nodes(graph_id)
                node_map = {n["uuid"]: n for n in all_nodes}

            related_nodes = self._summarize_neighbors(related_edges, node_map)

            return EntityNode(
                uuid=node.get("uuid", ""),
                name=node.get("name", ""),
                labels=node.get("labels", []),
                summary=node.get("summary", ""),
                attributes=node.get("attributes", {}),
                related_edges=related_edges,
                related_nodes=related_nodes,
            )

        except Exception as e:
            logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
            return None

    def get_entities_by_type(
        self,
        graph_id: str,
        entity_type: str,
        enrich_with_edges: bool = True
    ) -> List[EntityNode]:
        """
        Get all entities of a specified type

        Args:
            graph_id: graph ID
            entity_type: entity type (e.g. "Student", "PublicFigure")
            enrich_with_edges: whether to fetch related edge info

        Returns:
            Entity list
        """
        result = self.filter_defined_entities(
            graph_id=graph_id,
            defined_entity_types=[entity_type],
            enrich_with_edges=enrich_with_edges
        )
        return result.entities
|