diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 566c4321..d6aa33c3 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -10,12 +10,11 @@ import threading from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass -from zep_cloud.client import Zep -from zep_cloud import EpisodeData, EntityEdgeSourceTarget +from zep_cloud import EntityEdgeSourceTarget from ..config import Config +from ..graph import get_graph_backend from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale @@ -43,12 +42,8 @@ class GraphBuilderService: Responsible for calling the Zep API to build the knowledge graph. """ - def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY is not configured") - - self.client = Zep(api_key=self.api_key) + def __init__(self): + self._graph = get_graph_backend() self.task_manager = TaskManager() def build_graph_async( @@ -191,15 +186,13 @@ class GraphBuilderService: self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: - """Create a Zep graph (public method)""" + """Create a graph (public method)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - - self.client.graph.create( + self._graph.create_graph( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) - return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): @@ -283,9 +276,8 @@ class GraphBuilderService: if source_targets: edge_definitions[name] = (edge_class, source_targets) - # Call Zep API to set ontology if entity_types or edge_definitions: - self.client.graph.set_ontology( + self._graph.set_ontology( graph_ids=[graph_id], entities=entity_types if entity_types else None, edges=edge_definitions if edge_definitions else None, @@ -314,25 +306,14 @@ class GraphBuilderService: progress ) - # Build episode data episodes = [ - EpisodeData(data=chunk, type="text") + {"data": chunk, "type": "text"} for chunk in batch_chunks ] - - # Send to Zep + try: - batch_result = self.client.graph.add_batch( - graph_id=graph_id, - episodes=episodes - ) - - # Collect returned episode UUIDs - if batch_result and isinstance(batch_result, list): - for ep in batch_result: - ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) - if ep_uuid: - episode_uuids.append(ep_uuid) + returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes) + episode_uuids.extend(returned_uuids) # Avoid sending requests too quickly time.sleep(1) @@ -376,7 +357,7 @@ class GraphBuilderService: # Check processing status of each episode for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) + episode = self._graph.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) if is_processed: @@ -402,19 +383,14 @@ class GraphBuilderService: def _get_graph_info(self, graph_id: str) -> GraphInfo: """Retrieve graph info""" - # Fetch nodes (paginated) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self._graph.get_all_nodes(graph_id) + edges = self._graph.get_all_edges(graph_id) - # Fetch edges (paginated) - edges = fetch_all_edges(self.client, graph_id) - - # Count entity types entity_types = set() for node in nodes: - if node.labels: - for label in node.labels: - if label not in ["Entity", "Node"]: - entity_types.add(label) + for label in node.get("labels", []): + if label not in ["Entity", "Node"]: + entity_types.add(label) return GraphInfo( graph_id=graph_id, @@ -424,83 +400,25 @@ class GraphBuilderService: ) def get_graph_data(self, graph_id: str) -> Dict[str, Any]: - """ - Retrieve full graph data (with detailed information). + """Retrieve full graph data (nodes + edges with timestamps and attributes).""" + nodes = self._graph.get_all_nodes(graph_id) + edges = self._graph.get_all_edges(graph_id) - Args: - graph_id: graph ID + node_map = {n["uuid"]: n.get("name", "") for n in nodes} - Returns: - Dictionary containing nodes and edges with timestamps, attributes, and other details - """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) - - # Build node map for looking up node names - node_map = {} - for node in nodes: - node_map[node.uuid_] = node.name or "" - - nodes_data = [] - for node in nodes: - # Get creation timestamp - created_at = getattr(node, 'created_at', None) - if created_at: - created_at = str(created_at) - - nodes_data.append({ - "uuid": node.uuid_, - "name": node.name, - "labels": node.labels or [], - "summary": node.summary or "", - "attributes": node.attributes or {}, - "created_at": created_at, - }) - - edges_data = [] - for edge in edges: - # Get timestamps - created_at = getattr(edge, 'created_at', None) - valid_at = getattr(edge, 'valid_at', None) - invalid_at = getattr(edge, 'invalid_at', None) - expired_at = getattr(edge, 'expired_at', None) - - # Get episodes - episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) - if episodes and not isinstance(episodes, list): - episodes = [str(episodes)] - elif episodes: - episodes = [str(e) for e in episodes] - - # Get fact_type - fact_type = getattr(edge, 'fact_type', None) or edge.name or "" - - edges_data.append({ - "uuid": edge.uuid_, - "name": edge.name or "", - "fact": edge.fact or "", - "fact_type": fact_type, - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "source_node_name": node_map.get(edge.source_node_uuid, ""), - "target_node_name": node_map.get(edge.target_node_uuid, ""), - "attributes": edge.attributes or {}, - "created_at": str(created_at) if created_at else None, - "valid_at": str(valid_at) if valid_at else None, - "invalid_at": str(invalid_at) if invalid_at else None, - "expired_at": str(expired_at) if expired_at else None, - "episodes": episodes or [], - }) - return { "graph_id": graph_id, - "nodes": nodes_data, - "edges": edges_data, - "node_count": len(nodes_data), - "edge_count": len(edges_data), + "nodes": nodes, + "edges": [ + {**e, "source_node_name": node_map.get(e.get("source_node_uuid", ""), ""), + "target_node_name": node_map.get(e.get("target_node_uuid", ""), "")} + for e in edges + ], + "node_count": len(nodes), + "edge_count": len(edges), } def delete_graph(self, graph_id: str): """Delete graph""" - self.client.graph.delete(graph_id=graph_id) + self._graph.delete_graph(graph_id) diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index dc14961e..30a5fef2 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -7,11 +7,9 @@ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_entity_reader') @@ -79,11 +77,7 @@ class ZepEntityReader: """ def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY is not configured") - - self.client = Zep(api_key=self.api_key) + self._graph = get_graph_backend() def _call_with_retry( self, @@ -136,18 +130,7 @@ class ZepEntityReader: """ logger.info(f"Fetching all nodes for graph {graph_id}...") - nodes = fetch_all_nodes(self.client, graph_id) - - nodes_data = [] - for node in nodes: - nodes_data.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": node.name or "", - "labels": node.labels or [], - "summary": node.summary or "", - "attributes": node.attributes or {}, - }) - + nodes_data = self._graph.get_all_nodes(graph_id) logger.info(f"Fetched {len(nodes_data)} nodes") return nodes_data @@ -163,19 +146,7 @@ class ZepEntityReader: """ logger.info(f"Fetching all edges for graph {graph_id}...") - edges = fetch_all_edges(self.client, graph_id) - - edges_data = [] - for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) - + edges_data = self._graph.get_all_edges(graph_id) logger.info(f"Fetched {len(edges_data)} edges") return edges_data @@ -190,24 +161,7 @@ class ZepEntityReader: Edge list """ try: - # Call Zep API with retry - edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), - operation_name=f"get node edges (node={node_uuid[:8]}...)" - ) - - edges_data = [] - for edge in edges: - edges_data.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": edge.name or "", - "fact": edge.fact or "", - "source_node_uuid": edge.source_node_uuid, - "target_node_uuid": edge.target_node_uuid, - "attributes": edge.attributes or {}, - }) - - return edges_data + return self._graph.get_node_edges(node_uuid) except Exception as e: logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}") return [] @@ -346,11 +300,7 @@ class ZepEntityReader: EntityNode or None """ try: - # Get the node with retry - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), - operation_name=f"get node detail (uuid={entity_uuid[:8]}...)" - ) + node = self._graph.get_node(entity_uuid) if not node: return None @@ -397,11 +347,11 @@ class ZepEntityReader: }) return EntityNode( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {}, + uuid=node.get("uuid", ""), + name=node.get("name", ""), + labels=node.get("labels", []), + summary=node.get("summary", ""), + attributes=node.get("attributes", {}), related_edges=related_edges, related_nodes=related_nodes, ) diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index eab77fdd..683f1634 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -12,9 +12,8 @@ from dataclasses import dataclass from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger from ..utils.locale import get_locale, set_locale @@ -240,12 +239,7 @@ class ZepGraphMemoryUpdater: api_key: Zep API key (optional; defaults to config value) """ self.graph_id = graph_id - self.api_key = api_key or Config.ZEP_API_KEY - - if not self.api_key: - raise ValueError("ZEP_API_KEY is not configured") - - self.client = Zep(api_key=self.api_key) + self._graph = get_graph_backend() # Activity queue self._activity_queue: Queue = Queue() @@ -413,11 +407,7 @@ class ZepGraphMemoryUpdater: # Send with retry for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( - graph_id=self.graph_id, - type="text", - data=combined_text - ) + self._graph.add_text(self.graph_id, combined_text) self._total_sent += 1 self._total_items_sent += len(activities) diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 1cadcbd5..ee9e981c 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -13,13 +13,11 @@ import json from typing import Dict, Any, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger from ..utils.llm_client import LLMClient from ..utils.locale import get_locale, t -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_tools') @@ -423,12 +421,7 @@ class ZepToolsService: RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY is not configured") - - self.client = Zep(api_key=self.api_key) - # LLM client used by InsightForge to generate sub-queries + self._graph = get_graph_backend() self._llm_client = llm_client logger.info(t("console.zepToolsInitialized")) @@ -485,51 +478,38 @@ class ZepToolsService: """ logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50])) - # Try using the Zep Cloud Search API try: - search_results = self._call_with_retry( - func=lambda: self.client.graph.search( - graph_id=graph_id, - query=query, - limit=limit, - scope=scope, - reranker="cross_encoder" - ), - operation_name=t("console.graphSearchOp", graphId=graph_id) - ) - + raw = self._graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope) + facts = [] edges = [] nodes = [] - - # Parse edge search results - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # Parse node search results - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # Node summaries count as facts too - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") - + + for edge in raw.get("edges", []) or []: + fact = edge.get("fact", "") if isinstance(edge, dict) else getattr(edge, "fact", "") + if fact: + facts.append(fact) + edges.append(edge if isinstance(edge, dict) else { + "uuid": getattr(edge, "uuid_", None) or getattr(edge, "uuid", ""), + "name": getattr(edge, "name", ""), + "fact": getattr(edge, "fact", ""), + "source_node_uuid": getattr(edge, "source_node_uuid", ""), + "target_node_uuid": getattr(edge, "target_node_uuid", ""), + }) + + for node in raw.get("nodes", []) or []: + node_dict = node if isinstance(node, dict) else { + "uuid": getattr(node, "uuid_", None) or getattr(node, "uuid", ""), + "name": getattr(node, "name", ""), + "labels": getattr(node, "labels", []), + "summary": getattr(node, "summary", ""), + } + nodes.append(node_dict) + if node_dict.get("summary"): + facts.append(f"[{node_dict['name']}]: {node_dict['summary']}") + logger.info(t("console.searchComplete", count=len(facts))) - + return SearchResult( facts=facts, edges=edges, @@ -659,18 +639,18 @@ class ZepToolsService: """ logger.info(t("console.fetchingAllNodes", graphId=graph_id)) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self._graph.get_all_nodes(graph_id) - result = [] - for node in nodes: - node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" - result.append(NodeInfo( - uuid=str(node_uuid) if node_uuid else "", - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} - )) + result = [ + NodeInfo( + uuid=n.get("uuid", ""), + name=n.get("name", ""), + labels=n.get("labels", []), + summary=n.get("summary", ""), + attributes=n.get("attributes", {}) + ) + for n in nodes + ] logger.info(t("console.fetchedNodes", count=len(result))) return result @@ -688,26 +668,22 @@ class ZepToolsService: """ logger.info(t("console.fetchingAllEdges", graphId=graph_id)) - edges = fetch_all_edges(self.client, graph_id) + edges = self._graph.get_all_edges(graph_id) result = [] - for edge in edges: - edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" + for e in edges: edge_info = EdgeInfo( - uuid=str(edge_uuid) if edge_uuid else "", - name=edge.name or "", - fact=edge.fact or "", - source_node_uuid=edge.source_node_uuid or "", - target_node_uuid=edge.target_node_uuid or "" + uuid=e.get("uuid", ""), + name=e.get("name", ""), + fact=e.get("fact", ""), + source_node_uuid=e.get("source_node_uuid", ""), + target_node_uuid=e.get("target_node_uuid", ""), ) - - # Add temporal info if include_temporal: - edge_info.created_at = getattr(edge, 'created_at', None) - edge_info.valid_at = getattr(edge, 'valid_at', None) - edge_info.invalid_at = getattr(edge, 'invalid_at', None) - edge_info.expired_at = getattr(edge, 'expired_at', None) - + edge_info.created_at = e.get("created_at") + edge_info.valid_at = e.get("valid_at") + edge_info.invalid_at = e.get("invalid_at") + edge_info.expired_at = e.get("expired_at") result.append(edge_info) logger.info(t("console.fetchedEdges", count=len(result))) @@ -726,20 +702,17 @@ class ZepToolsService: logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8])) try: - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), - operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8]) - ) - + node = self._graph.get_node(node_uuid) + if not node: return None - + return NodeInfo( - uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - name=node.name or "", - labels=node.labels or [], - summary=node.summary or "", - attributes=node.attributes or {} + uuid=node.get("uuid", ""), + name=node.get("name", ""), + labels=node.get("labels", []), + summary=node.get("summary", ""), + attributes=node.get("attributes", {}) ) except Exception as e: logger.error(t("console.fetchNodeDetailFailed", error=str(e)))