refactor(services): replace direct Zep SDK calls with GraphBackend interface

This commit is contained in:
Ubuntu 2026-04-25 13:09:59 +00:00
parent e073ef8716
commit b2fd7e1b87
4 changed files with 103 additions and 272 deletions

View File

@ -10,12 +10,11 @@ import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from zep_cloud.client import Zep
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from zep_cloud import EntityEdgeSourceTarget
from ..config import Config
from ..graph import get_graph_backend
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale
@ -43,12 +42,8 @@ class GraphBuilderService:
Responsible for calling the Zep API to build the knowledge graph.
"""
def __init__(self, api_key: Optional[str] = None):
    """Bind the service to the configured graph backend and a task manager.

    Args:
        api_key: Accepted only for backward compatibility with callers
            that still pass a Zep API key. It is not used: the backend is
            obtained from ``get_graph_backend()`` without it.
    """
    # Keeping the parameter mirrors the other services touched by this
    # refactor, which also still accept (and ignore) ``api_key``.
    self._graph = get_graph_backend()
    self.task_manager = TaskManager()
def build_graph_async(
@ -191,15 +186,13 @@ class GraphBuilderService:
self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str:
    """Create a new graph through the backend and return its generated ID.

    The ID is ``mirofish_`` followed by the first 16 hex characters of a
    random UUID4.

    Args:
        name: Display name passed to the backend for the new graph.

    Returns:
        The generated graph ID.
    """
    suffix = uuid.uuid4().hex[:16]
    graph_id = f"mirofish_{suffix}"
    self._graph.create_graph(
        graph_id=graph_id,
        name=name,
        description="MiroFish Social Simulation Graph",
    )
    return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
@ -283,9 +276,8 @@ class GraphBuilderService:
if source_targets:
edge_definitions[name] = (edge_class, source_targets)
# Call Zep API to set ontology
if entity_types or edge_definitions:
self.client.graph.set_ontology(
self._graph.set_ontology(
graph_ids=[graph_id],
entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None,
@ -314,25 +306,14 @@ class GraphBuilderService:
progress
)
# Build episode data
episodes = [
EpisodeData(data=chunk, type="text")
{"data": chunk, "type": "text"}
for chunk in batch_chunks
]
# Send to Zep
try:
batch_result = self.client.graph.add_batch(
graph_id=graph_id,
episodes=episodes
)
# Collect returned episode UUIDs
if batch_result and isinstance(batch_result, list):
for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid:
episode_uuids.append(ep_uuid)
returned_uuids = self._graph.add_batch(graph_id=graph_id, episodes=episodes)
episode_uuids.extend(returned_uuids)
# Avoid sending requests too quickly
time.sleep(1)
@ -376,7 +357,7 @@ class GraphBuilderService:
# Check processing status of each episode
for ep_uuid in list(pending_episodes):
try:
episode = self.client.graph.episode.get(uuid_=ep_uuid)
episode = self._graph.get_episode(ep_uuid)
is_processed = getattr(episode, 'processed', False)
if is_processed:
@ -402,19 +383,14 @@ class GraphBuilderService:
def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""Retrieve graph info"""
# Fetch nodes (paginated)
nodes = fetch_all_nodes(self.client, graph_id)
nodes = self._graph.get_all_nodes(graph_id)
edges = self._graph.get_all_edges(graph_id)
# Fetch edges (paginated)
edges = fetch_all_edges(self.client, graph_id)
# Count entity types
entity_types = set()
for node in nodes:
if node.labels:
for label in node.labels:
if label not in ["Entity", "Node"]:
entity_types.add(label)
for label in node.get("labels", []):
if label not in ["Entity", "Node"]:
entity_types.add(label)
return GraphInfo(
graph_id=graph_id,
@ -424,83 +400,25 @@ class GraphBuilderService:
)
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
    """Retrieve the full graph (nodes and edges) as plain dictionaries.

    Nodes are returned exactly as the backend provides them. Each edge is
    copied and enriched with the names of its endpoint nodes. The
    ``fact_type`` and ``episodes`` keys, which this method always emitted
    before it was moved onto the graph backend, are filled in when the
    backend leaves them out or empty.

    Args:
        graph_id: ID of the graph to export.

    Returns:
        Dictionary with ``graph_id``, ``nodes``, ``edges``, ``node_count``
        and ``edge_count``.
    """
    nodes = self._graph.get_all_nodes(graph_id)
    edges = self._graph.get_all_edges(graph_id)

    # ``dict.get(key, default)`` only falls back when the key is absent, so
    # ``or ""`` is needed to cover an explicit ``name=None``. Nodes without
    # a uuid cannot be referenced by an edge and are left out of the lookup
    # instead of raising KeyError.
    node_names = {
        node["uuid"]: node.get("name") or ""
        for node in nodes
        if node.get("uuid")
    }

    edges_data = [
        {
            **edge,
            # Same fallback chain the pre-refactor code used.
            "fact_type": edge.get("fact_type") or edge.get("name") or "",
            "episodes": edge.get("episodes") or [],
            "source_node_name": node_names.get(edge.get("source_node_uuid"), ""),
            "target_node_name": node_names.get(edge.get("target_node_uuid"), ""),
        }
        for edge in edges
    ]

    return {
        "graph_id": graph_id,
        "nodes": nodes,
        "edges": edges_data,
        "node_count": len(nodes),
        "edge_count": len(edges_data),
    }
def delete_graph(self, graph_id: str):
    """Delete the graph identified by ``graph_id`` through the graph backend."""
    backend = self._graph
    backend.delete_graph(graph_id)

View File

@ -7,11 +7,9 @@ import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader')
@ -79,11 +77,7 @@ class ZepEntityReader:
"""
def __init__(self, api_key: Optional[str] = None):
    """Attach the reader to the configured graph backend.

    Args:
        api_key: Not used by this constructor; the backend comes from
            ``get_graph_backend()``, which is called without it.
    """
    backend = get_graph_backend()
    self._graph = backend
def _call_with_retry(
self,
@ -136,18 +130,7 @@ class ZepEntityReader:
"""
logger.info(f"Fetching all nodes for graph {graph_id}...")
nodes = fetch_all_nodes(self.client, graph_id)
nodes_data = []
for node in nodes:
nodes_data.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": node.name or "",
"labels": node.labels or [],
"summary": node.summary or "",
"attributes": node.attributes or {},
})
nodes_data = self._graph.get_all_nodes(graph_id)
logger.info(f"Fetched {len(nodes_data)} nodes")
return nodes_data
@ -163,19 +146,7 @@ class ZepEntityReader:
"""
logger.info(f"Fetching all edges for graph {graph_id}...")
edges = fetch_all_edges(self.client, graph_id)
edges_data = []
for edge in edges:
edges_data.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": edge.name or "",
"fact": edge.fact or "",
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {},
})
edges_data = self._graph.get_all_edges(graph_id)
logger.info(f"Fetched {len(edges_data)} edges")
return edges_data
@ -190,24 +161,7 @@ class ZepEntityReader:
Edge list
"""
try:
# Call Zep API with retry
edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"get node edges (node={node_uuid[:8]}...)"
)
edges_data = []
for edge in edges:
edges_data.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": edge.name or "",
"fact": edge.fact or "",
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"attributes": edge.attributes or {},
})
return edges_data
return self._graph.get_node_edges(node_uuid)
except Exception as e:
logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
return []
@ -346,11 +300,7 @@ class ZepEntityReader:
EntityNode or None
"""
try:
# Get the node with retry
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"get node detail (uuid={entity_uuid[:8]}...)"
)
node = self._graph.get_node(entity_uuid)
if not node:
return None
@ -397,11 +347,11 @@ class ZepEntityReader:
})
return EntityNode(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {},
uuid=node.get("uuid", ""),
name=node.get("name", ""),
labels=node.get("labels", []),
summary=node.get("summary", ""),
attributes=node.get("attributes", {}),
related_edges=related_edges,
related_nodes=related_nodes,
)

View File

@ -12,9 +12,8 @@ from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Empty
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
from ..utils.locale import get_locale, set_locale
@ -240,12 +239,7 @@ class ZepGraphMemoryUpdater:
api_key: Zep API key (optional; defaults to config value)
"""
self.graph_id = graph_id
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
self._graph = get_graph_backend()
# Activity queue
self._activity_queue: Queue = Queue()
@ -413,11 +407,7 @@ class ZepGraphMemoryUpdater:
# Send with retry
for attempt in range(self.MAX_RETRIES):
try:
self.client.graph.add(
graph_id=self.graph_id,
type="text",
data=combined_text
)
self._graph.add_text(self.graph_id, combined_text)
self._total_sent += 1
self._total_items_sent += len(activities)

View File

@ -13,13 +13,11 @@ import json
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
from ..utils.llm_client import LLMClient
from ..utils.locale import get_locale, t
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_tools')
@ -423,12 +421,7 @@ class ZepToolsService:
RETRY_DELAY = 2.0
def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key)
# LLM client used by InsightForge to generate sub-queries
self._graph = get_graph_backend()
self._llm_client = llm_client
logger.info(t("console.zepToolsInitialized"))
@ -485,51 +478,38 @@ class ZepToolsService:
"""
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
# Try using the Zep Cloud Search API
try:
search_results = self._call_with_retry(
func=lambda: self.client.graph.search(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
reranker="cross_encoder"
),
operation_name=t("console.graphSearchOp", graphId=graph_id)
)
raw = self._graph.search(graph_id=graph_id, query=query, limit=limit, scope=scope)
facts = []
edges = []
nodes = []
# Parse edge search results
if hasattr(search_results, 'edges') and search_results.edges:
for edge in search_results.edges:
if hasattr(edge, 'fact') and edge.fact:
facts.append(edge.fact)
edges.append({
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
"name": getattr(edge, 'name', ''),
"fact": getattr(edge, 'fact', ''),
"source_node_uuid": getattr(edge, 'source_node_uuid', ''),
"target_node_uuid": getattr(edge, 'target_node_uuid', ''),
})
# Parse node search results
if hasattr(search_results, 'nodes') and search_results.nodes:
for node in search_results.nodes:
nodes.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": getattr(node, 'name', ''),
"labels": getattr(node, 'labels', []),
"summary": getattr(node, 'summary', ''),
})
# Node summaries count as facts too
if hasattr(node, 'summary') and node.summary:
facts.append(f"[{node.name}]: {node.summary}")
for edge in raw.get("edges", []) or []:
fact = edge.get("fact", "") if isinstance(edge, dict) else getattr(edge, "fact", "")
if fact:
facts.append(fact)
edges.append(edge if isinstance(edge, dict) else {
"uuid": getattr(edge, "uuid_", None) or getattr(edge, "uuid", ""),
"name": getattr(edge, "name", ""),
"fact": getattr(edge, "fact", ""),
"source_node_uuid": getattr(edge, "source_node_uuid", ""),
"target_node_uuid": getattr(edge, "target_node_uuid", ""),
})
for node in raw.get("nodes", []) or []:
node_dict = node if isinstance(node, dict) else {
"uuid": getattr(node, "uuid_", None) or getattr(node, "uuid", ""),
"name": getattr(node, "name", ""),
"labels": getattr(node, "labels", []),
"summary": getattr(node, "summary", ""),
}
nodes.append(node_dict)
if node_dict.get("summary"):
facts.append(f"[{node_dict['name']}]: {node_dict['summary']}")
logger.info(t("console.searchComplete", count=len(facts)))
return SearchResult(
facts=facts,
edges=edges,
@ -659,18 +639,18 @@ class ZepToolsService:
"""
logger.info(t("console.fetchingAllNodes", graphId=graph_id))
nodes = fetch_all_nodes(self.client, graph_id)
nodes = self._graph.get_all_nodes(graph_id)
result = []
for node in nodes:
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
result.append(NodeInfo(
uuid=str(node_uuid) if node_uuid else "",
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
))
result = [
NodeInfo(
uuid=n.get("uuid", ""),
name=n.get("name", ""),
labels=n.get("labels", []),
summary=n.get("summary", ""),
attributes=n.get("attributes", {})
)
for n in nodes
]
logger.info(t("console.fetchedNodes", count=len(result)))
return result
@ -688,26 +668,22 @@ class ZepToolsService:
"""
logger.info(t("console.fetchingAllEdges", graphId=graph_id))
edges = fetch_all_edges(self.client, graph_id)
edges = self._graph.get_all_edges(graph_id)
result = []
for edge in edges:
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
for e in edges:
edge_info = EdgeInfo(
uuid=str(edge_uuid) if edge_uuid else "",
name=edge.name or "",
fact=edge.fact or "",
source_node_uuid=edge.source_node_uuid or "",
target_node_uuid=edge.target_node_uuid or ""
uuid=e.get("uuid", ""),
name=e.get("name", ""),
fact=e.get("fact", ""),
source_node_uuid=e.get("source_node_uuid", ""),
target_node_uuid=e.get("target_node_uuid", ""),
)
# Add temporal info
if include_temporal:
edge_info.created_at = getattr(edge, 'created_at', None)
edge_info.valid_at = getattr(edge, 'valid_at', None)
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
edge_info.expired_at = getattr(edge, 'expired_at', None)
edge_info.created_at = e.get("created_at")
edge_info.valid_at = e.get("valid_at")
edge_info.invalid_at = e.get("invalid_at")
edge_info.expired_at = e.get("expired_at")
result.append(edge_info)
logger.info(t("console.fetchedEdges", count=len(result)))
@ -726,20 +702,17 @@ class ZepToolsService:
logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))
try:
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8])
)
node = self._graph.get_node(node_uuid)
if not node:
return None
return NodeInfo(
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
name=node.name or "",
labels=node.labels or [],
summary=node.summary or "",
attributes=node.attributes or {}
uuid=node.get("uuid", ""),
name=node.get("name", ""),
labels=node.get("labels", []),
summary=node.get("summary", ""),
attributes=node.get("attributes", {})
)
except Exception as e:
logger.error(t("console.fetchNodeDetailFailed", error=str(e)))