refactor(services): replace direct Zep SDK calls with GraphBackend interface

This commit is contained in:
Ubuntu 2026-04-25 13:09:59 +00:00
parent e073ef8716
commit b2fd7e1b87
4 changed files with 103 additions and 272 deletions

View File

@ -10,12 +10,11 @@ import threading
from typing import Dict, Any, List, Optional, Callable from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass from dataclasses import dataclass
from zep_cloud.client import Zep from zep_cloud import EntityEdgeSourceTarget
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ..config import Config from ..config import Config
from ..graph import get_graph_backend
from ..models.task import TaskManager, TaskStatus from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale from ..utils.locale import t, get_locale, set_locale
@ -43,12 +42,8 @@ class GraphBuilderService:
Responsible for calling the Zep API to build the knowledge graph. Responsible for calling the Zep API to build the knowledge graph.
""" """
def __init__(self, api_key: Optional[str] = None): def __init__(self):
self.api_key = api_key or Config.ZEP_API_KEY self._graph = get_graph_backend()
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
self.task_manager = TaskManager() self.task_manager = TaskManager()
def build_graph_async( def build_graph_async(
@ -191,15 +186,13 @@ class GraphBuilderService:
self.task_manager.fail_task(task_id, error_msg) self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str: def create_graph(self, name: str) -> str:
"""Create a Zep graph (public method)""" """Create a graph (public method)"""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self._graph.create_graph(
self.client.graph.create(
graph_id=graph_id, graph_id=graph_id,
name=name, name=name,
description="MiroFish Social Simulation Graph" description="MiroFish Social Simulation Graph"
) )
return graph_id return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
@ -283,9 +276,8 @@ class GraphBuilderService:
if source_targets: if source_targets:
edge_definitions[name] = (edge_class, source_targets) edge_definitions[name] = (edge_class, source_targets)
# Call Zep API to set ontology
if entity_types or edge_definitions: if entity_types or edge_definitions:
self.client.graph.set_ontology( self._graph.set_ontology(
graph_ids=[graph_id], graph_ids=[graph_id],
entities=entity_types if entity_types else None, entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None, edges=edge_definitions if edge_definitions else None,
@ -314,25 +306,14 @@ class GraphBuilderService:
progress progress
) )
# Build episode data
episodes = [ episodes = [
EpisodeData(data=chunk, type="text") {"data": chunk, "type": "text"}
for chunk in batch_chunks for chunk in batch_chunks
] ]
# Send to Zep
try: try:
batch_result = self.client.graph.add_batch( returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes)
graph_id=graph_id, episode_uuids.extend(returned_uuids)
episodes=episodes
)
# Collect returned episode UUIDs
if batch_result and isinstance(batch_result, list):
for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid:
episode_uuids.append(ep_uuid)
# Avoid sending requests too quickly # Avoid sending requests too quickly
time.sleep(1) time.sleep(1)
@ -376,7 +357,7 @@ class GraphBuilderService:
# Check processing status of each episode # Check processing status of each episode
for ep_uuid in list(pending_episodes): for ep_uuid in list(pending_episodes):
try: try:
episode = self.client.graph.episode.get(uuid_=ep_uuid) episode = self._graph.get_episode(ep_uuid)
is_processed = getattr(episode, 'processed', False) is_processed = getattr(episode, 'processed', False)
if is_processed: if is_processed:
@ -402,19 +383,14 @@ class GraphBuilderService:
def _get_graph_info(self, graph_id: str) -> GraphInfo: def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""Retrieve graph info""" """Retrieve graph info"""
# Fetch nodes (paginated) nodes = self._graph.get_all_nodes(graph_id)
nodes = fetch_all_nodes(self.client, graph_id) edges = self._graph.get_all_edges(graph_id)
# Fetch edges (paginated)
edges = fetch_all_edges(self.client, graph_id)
# Count entity types
entity_types = set() entity_types = set()
for node in nodes: for node in nodes:
if node.labels: for label in node.get("labels", []):
for label in node.labels: if label not in ["Entity", "Node"]:
if label not in ["Entity", "Node"]: entity_types.add(label)
entity_types.add(label)
return GraphInfo( return GraphInfo(
graph_id=graph_id, graph_id=graph_id,
@ -424,83 +400,25 @@ class GraphBuilderService:
) )
def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
""" """Retrieve full graph data (nodes + edges with timestamps and attributes)."""
Retrieve full graph data (with detailed information). nodes = self._graph.get_all_nodes(graph_id)
edges = self._graph.get_all_edges(graph_id)
Args: node_map = {n["uuid"]: n.get("name", "") for n in nodes}
graph_id: graph ID
Returns:
Dictionary containing nodes and edges with timestamps, attributes, and other details
"""
nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id)
# Build node map for looking up node names
node_map = {}
for node in nodes:
node_map[node.uuid_] = node.name or ""
nodes_data = []
for node in nodes:
# Get creation timestamp
created_at = getattr(node, 'created_at', None)
if created_at:
created_at = str(created_at)
nodes_data.append({
"uuid": node.uuid_,
"name": node.name,
"labels": node.labels or [],
"summary": node.summary or "",
"attributes": node.attributes or {},
"created_at": created_at,
})
edges_data = []
for edge in edges:
# Get timestamps
created_at = getattr(edge, 'created_at', None)
valid_at = getattr(edge, 'valid_at', None)
invalid_at = getattr(edge, 'invalid_at', None)
expired_at = getattr(edge, 'expired_at', None)
# Get episodes
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
if episodes and not isinstance(episodes, list):
episodes = [str(episodes)]
elif episodes:
episodes = [str(e) for e in episodes]
# Get fact_type
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
edges_data.append({
"uuid": edge.uuid_,
"name": edge.name or "",
"fact": edge.fact or "",
"fact_type": fact_type,
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"source_node_name": node_map.get(edge.source_node_uuid, ""),
"target_node_name": node_map.get(edge.target_node_uuid, ""),
"attributes": edge.attributes or {},
"created_at": str(created_at) if created_at else None,
"valid_at": str(valid_at) if valid_at else None,
"invalid_at": str(invalid_at) if invalid_at else None,
"expired_at": str(expired_at) if expired_at else None,
"episodes": episodes or [],
})
return { return {
"graph_id": graph_id, "graph_id": graph_id,
"nodes": nodes_data, "nodes": nodes,
"edges": edges_data, "edges": [
"node_count": len(nodes_data), {**e, "source_node_name": node_map.get(e.get("source_node_uuid", ""), ""),
"edge_count": len(edges_data), "target_node_name": node_map.get(e.get("target_node_uuid", ""), "")}
for e in edges
],
"node_count": len(nodes),
"edge_count": len(edges),
} }
def delete_graph(self, graph_id: str): def delete_graph(self, graph_id: str):
"""Delete graph""" """Delete graph"""
self.client.graph.delete(graph_id=graph_id) self._graph.delete_graph(graph_id)

View File

@ -7,11 +7,9 @@ import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader') logger = get_logger('mirofish.zep_entity_reader')
@ -79,11 +77,7 @@ class ZepEntityReader:
""" """
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY self._graph = get_graph_backend()
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
def _call_with_retry( def _call_with_retry(
self, self,
@ -136,18 +130,7 @@ class ZepEntityReader:
""" """
logger.info(f"Fetching all nodes for graph {graph_id}...") logger.info(f"Fetching all nodes for graph {graph_id}...")
nodes = fetch_all_nodes(self.client, graph_id) nodes_data = self._graph.get_all_nodes(graph_id)
nodes_data = []
for node in nodes:
nodes_data.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": node.name or "",
"labels": node.labels or [],
"summary": node.summary or "",
"attributes": node.attributes or {},
})
logger.info(f"Fetched {len(nodes_data)} nodes") logger.info(f"Fetched {len(nodes_data)} nodes")
return nodes_data return nodes_data
@ -163,19 +146,7 @@ class ZepEntityReader:
""" """
logger.info(f"Fetching all edges for graph {graph_id}...") logger.info(f"Fetching all edges for graph {graph_id}...")
edges = fetch_all_edges(self.client, graph_id) edges_data = self._graph.get_all_edges(graph_id)
edges_data = []
for edge in edges:
edges_data.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": edge.name or "",
"fact": edge.fact or "",
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {},
})
logger.info(f"Fetched {len(edges_data)} edges") logger.info(f"Fetched {len(edges_data)} edges")
return edges_data return edges_data
@ -190,24 +161,7 @@ class ZepEntityReader:
Edge list Edge list
""" """
try: try:
# Call Zep API with retry return self._graph.get_node_edges(node_uuid)
edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"get node edges (node={node_uuid[:8]}...)"
)
edges_data = []
for edge in edges:
edges_data.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": edge.name or "",
"fact": edge.fact or "",
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {},
})
return edges_data
except Exception as e: except Exception as e:
logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}") logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
return [] return []
@ -346,11 +300,7 @@ class ZepEntityReader:
EntityNode or None EntityNode or None
""" """
try: try:
# Get the node with retry node = self._graph.get_node(entity_uuid)
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"get node detail (uuid={entity_uuid[:8]}...)"
)
if not node: if not node:
return None return None
@ -397,11 +347,11 @@ class ZepEntityReader:
}) })
return EntityNode( return EntityNode(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), uuid=node.get("uuid", ""),
name=node.name or "", name=node.get("name", ""),
labels=node.labels or [], labels=node.get("labels", []),
summary=node.summary or "", summary=node.get("summary", ""),
attributes=node.attributes or {}, attributes=node.get("attributes", {}),
related_edges=related_edges, related_edges=related_edges,
related_nodes=related_nodes, related_nodes=related_nodes,
) )

View File

@ -12,9 +12,8 @@ from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from queue import Queue, Empty from queue import Queue, Empty
from zep_cloud.client import Zep
from ..config import Config from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger from ..utils.logger import get_logger
from ..utils.locale import get_locale, set_locale from ..utils.locale import get_locale, set_locale
@ -240,12 +239,7 @@ class ZepGraphMemoryUpdater:
api_key: Zep API key (optional; defaults to config value) api_key: Zep API key (optional; defaults to config value)
""" """
self.graph_id = graph_id self.graph_id = graph_id
self.api_key = api_key or Config.ZEP_API_KEY self._graph = get_graph_backend()
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
# Activity queue # Activity queue
self._activity_queue: Queue = Queue() self._activity_queue: Queue = Queue()
@ -413,11 +407,7 @@ class ZepGraphMemoryUpdater:
# Send with retry # Send with retry
for attempt in range(self.MAX_RETRIES): for attempt in range(self.MAX_RETRIES):
try: try:
self.client.graph.add( self._graph.add_text(self.graph_id, combined_text)
graph_id=self.graph_id,
type="text",
data=combined_text
)
self._total_sent += 1 self._total_sent += 1
self._total_items_sent += len(activities) self._total_items_sent += len(activities)

View File

@ -13,13 +13,11 @@ import json
from typing import Dict, Any, List, Optional from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger from ..utils.logger import get_logger
from ..utils.llm_client import LLMClient from ..utils.llm_client import LLMClient
from ..utils.locale import get_locale, t from ..utils.locale import get_locale, t
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_tools') logger = get_logger('mirofish.zep_tools')
@ -423,12 +421,7 @@ class ZepToolsService:
RETRY_DELAY = 2.0 RETRY_DELAY = 2.0
def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None):
self.api_key = api_key or Config.ZEP_API_KEY self._graph = get_graph_backend()
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
# LLM client used by InsightForge to generate sub-queries
self._llm_client = llm_client self._llm_client = llm_client
logger.info(t("console.zepToolsInitialized")) logger.info(t("console.zepToolsInitialized"))
@ -485,51 +478,38 @@ class ZepToolsService:
""" """
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50])) logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
# Try using the Zep Cloud Search API
try: try:
search_results = self._call_with_retry( raw = self._graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope)
func=lambda: self.client.graph.search(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
reranker="cross_encoder"
),
operation_name=t("console.graphSearchOp", graphId=graph_id)
)
facts = [] facts = []
edges = [] edges = []
nodes = [] nodes = []
# Parse edge search results for edge in raw.get("edges", []) or []:
if hasattr(search_results, 'edges') and search_results.edges: fact = edge.get("fact", "") if isinstance(edge, dict) else getattr(edge, "fact", "")
for edge in search_results.edges: if fact:
if hasattr(edge, 'fact') and edge.fact: facts.append(fact)
facts.append(edge.fact) edges.append(edge if isinstance(edge, dict) else {
edges.append({ "uuid": getattr(edge, "uuid_", None) or getattr(edge, "uuid", ""),
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), "name": getattr(edge, "name", ""),
"name": getattr(edge, 'name', ''), "fact": getattr(edge, "fact", ""),
"fact": getattr(edge, 'fact', ''), "source_node_uuid": getattr(edge, "source_node_uuid", ""),
"source_node_uuid": getattr(edge, 'source_node_uuid', ''), "target_node_uuid": getattr(edge, "target_node_uuid", ""),
"target_node_uuid": getattr(edge, 'target_node_uuid', ''), })
})
for node in raw.get("nodes", []) or []:
# Parse node search results node_dict = node if isinstance(node, dict) else {
if hasattr(search_results, 'nodes') and search_results.nodes: "uuid": getattr(node, "uuid_", None) or getattr(node, "uuid", ""),
for node in search_results.nodes: "name": getattr(node, "name", ""),
nodes.append({ "labels": getattr(node, "labels", []),
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), "summary": getattr(node, "summary", ""),
"name": getattr(node, 'name', ''), }
"labels": getattr(node, 'labels', []), nodes.append(node_dict)
"summary": getattr(node, 'summary', ''), if node_dict.get("summary"):
}) facts.append(f"[{node_dict['name']}]: {node_dict['summary']}")
# Node summaries count as facts too
if hasattr(node, 'summary') and node.summary:
facts.append(f"[{node.name}]: {node.summary}")
logger.info(t("console.searchComplete", count=len(facts))) logger.info(t("console.searchComplete", count=len(facts)))
return SearchResult( return SearchResult(
facts=facts, facts=facts,
edges=edges, edges=edges,
@ -659,18 +639,18 @@ class ZepToolsService:
""" """
logger.info(t("console.fetchingAllNodes", graphId=graph_id)) logger.info(t("console.fetchingAllNodes", graphId=graph_id))
nodes = fetch_all_nodes(self.client, graph_id) nodes = self._graph.get_all_nodes(graph_id)
result = [] result = [
for node in nodes: NodeInfo(
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or "" uuid=n.get("uuid", ""),
result.append(NodeInfo( name=n.get("name", ""),
uuid=str(node_uuid) if node_uuid else "", labels=n.get("labels", []),
name=node.name or "", summary=n.get("summary", ""),
labels=node.labels or [], attributes=n.get("attributes", {})
summary=node.summary or "", )
attributes=node.attributes or {} for n in nodes
)) ]
logger.info(t("console.fetchedNodes", count=len(result))) logger.info(t("console.fetchedNodes", count=len(result)))
return result return result
@ -688,26 +668,22 @@ class ZepToolsService:
""" """
logger.info(t("console.fetchingAllEdges", graphId=graph_id)) logger.info(t("console.fetchingAllEdges", graphId=graph_id))
edges = fetch_all_edges(self.client, graph_id) edges = self._graph.get_all_edges(graph_id)
result = [] result = []
for edge in edges: for e in edges:
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
edge_info = EdgeInfo( edge_info = EdgeInfo(
uuid=str(edge_uuid) if edge_uuid else "", uuid=e.get("uuid", ""),
name=edge.name or "", name=e.get("name", ""),
fact=edge.fact or "", fact=e.get("fact", ""),
source_node_uuid=edge.source_node_uuid or "", source_node_uuid=e.get("source_node_uuid", ""),
target_node_uuid=edge.target_node_uuid or "" target_node_uuid=e.get("target_node_uuid", ""),
) )
# Add temporal info
if include_temporal: if include_temporal:
edge_info.created_at = getattr(edge, 'created_at', None) edge_info.created_at = e.get("created_at")
edge_info.valid_at = getattr(edge, 'valid_at', None) edge_info.valid_at = e.get("valid_at")
edge_info.invalid_at = getattr(edge, 'invalid_at', None) edge_info.invalid_at = e.get("invalid_at")
edge_info.expired_at = getattr(edge, 'expired_at', None) edge_info.expired_at = e.get("expired_at")
result.append(edge_info) result.append(edge_info)
logger.info(t("console.fetchedEdges", count=len(result))) logger.info(t("console.fetchedEdges", count=len(result)))
@ -726,20 +702,17 @@ class ZepToolsService:
logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8])) logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))
try: try:
node = self._call_with_retry( node = self._graph.get_node(node_uuid)
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8])
)
if not node: if not node:
return None return None
return NodeInfo( return NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), uuid=node.get("uuid", ""),
name=node.name or "", name=node.get("name", ""),
labels=node.labels or [], labels=node.get("labels", []),
summary=node.summary or "", summary=node.get("summary", ""),
attributes=node.attributes or {} attributes=node.get("attributes", {})
) )
except Exception as e: except Exception as e:
logger.error(t("console.fetchNodeDetailFailed", error=str(e))) logger.error(t("console.fetchNodeDetailFailed", error=str(e)))