""" Zep entity read and filter service Reads nodes from the Zep graph and filters out nodes that match predefined entity types """ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field from ..config import Config from ..graph import get_graph_backend from ..utils.logger import get_logger logger = get_logger('mirofish.zep_entity_reader') # Generic return type T = TypeVar('T') @dataclass class EntityNode: """Entity node data structure""" uuid: str name: str labels: List[str] summary: str attributes: Dict[str, Any] # Related edge info related_edges: List[Dict[str, Any]] = field(default_factory=list) # Related node info related_nodes: List[Dict[str, Any]] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, "name": self.name, "labels": self.labels, "summary": self.summary, "attributes": self.attributes, "related_edges": self.related_edges, "related_nodes": self.related_nodes, } def get_entity_type(self) -> Optional[str]: """Get entity type (excluding the default Entity label)""" for label in self.labels: if label not in ["Entity", "Node"]: return label return None @dataclass class FilteredEntities: """Filtered entity collection""" entities: List[EntityNode] entity_types: Set[str] total_count: int filtered_count: int def to_dict(self) -> Dict[str, Any]: return { "entities": [e.to_dict() for e in self.entities], "entity_types": list(self.entity_types), "total_count": self.total_count, "filtered_count": self.filtered_count, } class ZepEntityReader: """ Zep entity read and filter service Main features: 1. Read all nodes from the Zep graph 2. Filter out nodes matching predefined entity types (nodes with labels beyond just "Entity") 3. Fetch related edges and associated node info for each entity """ def __init__(self, api_key: Optional[str] = None): self._graph = get_graph_backend() def _call_with_retry( self, func: Callable[[], T], operation_name: str, max_retries: int = 3, initial_delay: float = 2.0 ) -> T: """ Zep API call with retry logic Args: func: function to execute (a lambda or callable with no arguments) operation_name: operation name for logging max_retries: maximum number of retries (default 3, meaning up to 3 attempts total) initial_delay: initial delay in seconds Returns: API call result """ last_exception = None delay = initial_delay for attempt in range(max_retries): try: return func() except Exception as e: last_exception = e if attempt < max_retries - 1: logger.warning( f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, " f"retrying in {delay:.1f}s..." ) time.sleep(delay) delay *= 2 # Exponential backoff else: logger.error(f"Zep {operation_name} still failing after {max_retries} attempts: {str(e)}") raise last_exception def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: """ Get all nodes in the graph (paginated) Args: graph_id: graph ID Returns: Node list """ logger.info(f"Fetching all nodes for graph {graph_id}...") nodes_data = self._graph.get_all_nodes(graph_id) logger.info(f"Fetched {len(nodes_data)} nodes") return nodes_data def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: """ Get all edges in the graph (paginated) Args: graph_id: graph ID Returns: Edge list """ logger.info(f"Fetching all edges for graph {graph_id}...") edges_data = self._graph.get_all_edges(graph_id) logger.info(f"Fetched {len(edges_data)} edges") return edges_data def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: """ Get all edges related to the specified node (with retry logic) Args: node_uuid: node UUID Returns: Edge list """ try: return self._graph.get_node_edges(node_uuid) except Exception as e: logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}") return [] def filter_defined_entities( self, graph_id: str, defined_entity_types: Optional[List[str]] = None, enrich_with_edges: bool = True ) -> FilteredEntities: """ Filter out nodes that match predefined entity types Filter logic: - If a node's Labels contain only "Entity", it does not match our predefined types; skip it - If a node's Labels contain labels other than "Entity" and "Node", it matches a predefined type; keep it Args: graph_id: graph ID defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept) enrich_with_edges: whether to fetch related edge info for each entity Returns: FilteredEntities: filtered entity collection """ logger.info(f"Starting entity filtering for graph {graph_id}...") # Get all nodes all_nodes = self.get_all_nodes(graph_id) total_count = len(all_nodes) # Get all edges (for relation lookup) all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] # Build UUID-to-node mapping node_map = {n["uuid"]: n for n in all_nodes} # Filter matching entities filtered_entities = [] entity_types_found = set() for node in all_nodes: labels = node.get("labels", []) # Filter logic: Labels must contain at least one label other than "Entity" and "Node" custom_labels = [l for l in labels if l not in ["Entity", "Node"]] if not custom_labels: # Only default labels; skip continue # If predefined types are specified, check for a match if defined_entity_types: matching_labels = [l for l in custom_labels if l in defined_entity_types] if not matching_labels: continue entity_type = matching_labels[0] else: entity_type = custom_labels[0] entity_types_found.add(entity_type) # Create entity node object entity = EntityNode( uuid=node["uuid"], name=node["name"], labels=labels, summary=node["summary"], attributes=node["attributes"], ) # Fetch related edges and nodes if enrich_with_edges: related_edges = [] related_node_uuids = set() for edge in all_edges: if edge["source_node_uuid"] == node["uuid"]: related_edges.append({ "direction": "outgoing", "edge_name": edge["name"], "fact": edge["fact"], "target_node_uuid": edge["target_node_uuid"], }) related_node_uuids.add(edge["target_node_uuid"]) elif edge["target_node_uuid"] == node["uuid"]: related_edges.append({ "direction": "incoming", "edge_name": edge["name"], "fact": edge["fact"], "source_node_uuid": edge["source_node_uuid"], }) related_node_uuids.add(edge["source_node_uuid"]) entity.related_edges = related_edges # Fetch basic info for related nodes related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: related_node = node_map[related_uuid] related_nodes.append({ "uuid": related_node["uuid"], "name": related_node["name"], "labels": related_node["labels"], "summary": related_node.get("summary", ""), }) entity.related_nodes = related_nodes filtered_entities.append(entity) logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, " f"entity types: {entity_types_found}") return FilteredEntities( entities=filtered_entities, entity_types=entity_types_found, total_count=total_count, filtered_count=len(filtered_entities), ) def get_entity_with_context( self, graph_id: str, entity_uuid: str ) -> Optional[EntityNode]: """ Get a single entity and its full context (edges and related nodes, with retry) Args: graph_id: graph ID entity_uuid: entity UUID Returns: EntityNode or None """ try: node = self._graph.get_node(entity_uuid) if not node: return None # Get the node's edges edges = self.get_node_edges(entity_uuid) # Get all nodes for relation lookup all_nodes = self.get_all_nodes(graph_id) node_map = {n["uuid"]: n for n in all_nodes} # Process related edges and nodes related_edges = [] related_node_uuids = set() for edge in edges: if edge["source_node_uuid"] == entity_uuid: related_edges.append({ "direction": "outgoing", "edge_name": edge["name"], "fact": edge["fact"], "target_node_uuid": edge["target_node_uuid"], }) related_node_uuids.add(edge["target_node_uuid"]) else: related_edges.append({ "direction": "incoming", "edge_name": edge["name"], "fact": edge["fact"], "source_node_uuid": edge["source_node_uuid"], }) related_node_uuids.add(edge["source_node_uuid"]) # Fetch related node info related_nodes = [] for related_uuid in related_node_uuids: if related_uuid in node_map: related_node = node_map[related_uuid] related_nodes.append({ "uuid": related_node["uuid"], "name": related_node["name"], "labels": related_node["labels"], "summary": related_node.get("summary", ""), }) return EntityNode( uuid=node.get("uuid", ""), name=node.get("name", ""), labels=node.get("labels", []), summary=node.get("summary", ""), attributes=node.get("attributes", {}), related_edges=related_edges, related_nodes=related_nodes, ) except Exception as e: logger.error(f"Failed to get entity {entity_uuid}: {str(e)}") return None def get_entities_by_connectivity( self, graph_id: str, max_n: Optional[int] = None, defined_entity_types: Optional[List[str]] = None, ) -> List[EntityNode]: """Return entities sorted by edge degree (descending), optionally capped at max_n.""" filtered = self.filter_defined_entities( graph_id=graph_id, defined_entity_types=defined_entity_types, enrich_with_edges=True, ) entities = sorted( filtered.entities, key=lambda e: len(e.related_edges), reverse=True, ) if max_n is not None and max_n > 0: entities = entities[:max_n] return entities def get_entities_by_type( self, graph_id: str, entity_type: str, enrich_with_edges: bool = True ) -> List[EntityNode]: """ Get all entities of a specified type Args: graph_id: graph ID entity_type: entity type (e.g. "Student", "PublicFigure") enrich_with_edges: whether to fetch related edge info Returns: Entity list """ result = self.filter_defined_entities( graph_id=graph_id, defined_entity_types=[entity_type], enrich_with_edges=enrich_with_edges ) return result.entities