From 247ecc86ae751f3f88eec1e2a6af15e9bb609f21 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Sat, 25 Apr 2026 11:29:59 +0000
Subject: [PATCH] feat(graph): add ZepBackend adapter implementing GraphBackend

---
Review notes (ignored by git-am): line structure was rebuilt from a
whitespace-collapsed copy of this patch. The indentation of the context
lines in the test hunk is assumed, and the blob id on the "index" line
predates the edits below - regenerate with git format-patch before applying.

 backend/app/graph/zep_backend.py    | 215 ++++++++++++++++++++++++++++
 backend/tests/test_graph_factory.py |  18 ++++
 2 files changed, 233 insertions(+)
 create mode 100644 backend/app/graph/zep_backend.py

diff --git a/backend/app/graph/zep_backend.py b/backend/app/graph/zep_backend.py
new file mode 100644
index 00000000..478a214d
--- /dev/null
+++ b/backend/app/graph/zep_backend.py
@@ -0,0 +1,215 @@
+"""Zep Cloud implementation of GraphBackend."""
+import time
+from typing import Any, Dict, List, Optional
+
+from zep_cloud.client import Zep
+from zep_cloud import InternalServerError
+
+from .base import GraphBackend
+from ..config import Config
+from ..utils.logger import get_logger
+
+logger = get_logger('mirofish.graph.zep')
+
+_PAGE_SIZE = 100  # items requested per list call
+_MAX_ITEMS = 2000  # paging stops once this many items are collected
+_MAX_RETRIES = 3  # attempts per API call, including the first
+_RETRY_DELAY = 2.0  # base back-off in seconds, doubled after each failure
+
+# Failures treated as transient and therefore retried.
+_TRANSIENT_ERRORS = (ConnectionError, TimeoutError, OSError, InternalServerError)
+
+
+def _fetch_page_with_retry(api_call, *args, max_retries=_MAX_RETRIES, retry_delay=_RETRY_DELAY, **kwargs):
+    """Call ``api_call(*args, **kwargs)``, retrying transient failures.
+
+    Sleeps ``retry_delay * 2 ** attempt`` seconds between attempts and
+    re-raises the error once ``max_retries`` attempts have failed. A falsy
+    result is normalised to ``[]``.
+    """
+    for attempt in range(max_retries):
+        try:
+            return api_call(*args, **kwargs) or []
+        except _TRANSIENT_ERRORS as exc:
+            if attempt == max_retries - 1:
+                raise
+            delay = retry_delay * (2 ** attempt)
+            logger.warning(
+                "Zep call failed (attempt %d/%d), retrying in %.1fs: %s",
+                attempt + 1, max_retries, delay, exc,
+            )
+            time.sleep(delay)
+    return []
+
+
+def _uuid_of(item: Any) -> Optional[str]:
+    """Return ``item.uuid_``, falling back to ``item.uuid``."""
+    return getattr(item, "uuid_", None) or getattr(item, "uuid", None)
+
+
+def _timestamp(item: Any, field: str) -> str:
+    """Return the named attribute as a string, or ``""`` when it is unset.
+
+    Avoids ``str(None)`` leaking the literal ``"None"`` into results.
+    """
+    value = getattr(item, field, None)
+    return "" if value is None else str(value)
+
+
+def _fetch_all(list_fn, graph_id: str, cursor_key: str = "uuid_cursor") -> List[Any]:
+    """Collect the pages returned by ``list_fn`` for ``graph_id``.
+
+    Requests ``_PAGE_SIZE`` items at a time, passing the last item's UUID as
+    ``cursor_key``, until a short page arrives or ``_MAX_ITEMS`` is reached.
+    """
+    results: List[Any] = []
+    cursor = None
+    while True:
+        kwargs: Dict[str, Any] = {"limit": _PAGE_SIZE}
+        if cursor:
+            kwargs[cursor_key] = cursor
+        batch = _fetch_page_with_retry(list_fn, graph_id, **kwargs)
+        results.extend(batch)
+        if len(batch) < _PAGE_SIZE:
+            break
+        if len(results) >= _MAX_ITEMS:
+            logger.warning(
+                "Graph %s: paging stopped at %d items (cap %d); results may be incomplete",
+                graph_id, len(results), _MAX_ITEMS,
+            )
+            break
+        cursor = _uuid_of(batch[-1])
+        if not cursor:
+            # Without a cursor the next request would repeat the first page.
+            logger.warning("Graph %s: last item has no UUID; stopping pagination", graph_id)
+            break
+    return results
+
+
+class ZepBackend(GraphBackend):
+    """GraphBackend adapter that delegates to the Zep Cloud client."""
+
+    def __init__(self, api_key: Optional[str] = None):
+        """Create the client from ``api_key`` or ``Config.ZEP_API_KEY``.
+
+        Raises:
+            ValueError: If neither source provides a key.
+        """
+        self.api_key = api_key or Config.ZEP_API_KEY
+        if not self.api_key:
+            raise ValueError("ZEP_API_KEY is not configured")
+        self._client = Zep(api_key=self.api_key)
+
+    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
+        """Create a graph with the given id, name and description."""
+        self._client.graph.create(graph_id=graph_id, name=name, description=description)
+
+    def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
+        """Forward the entity and edge ontology to the given graphs."""
+        self._client.graph.set_ontology(graph_ids=graph_ids, entities=entities, edges=edges)
+
+    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
+        """Add episodes to a graph and return the UUIDs of the results.
+
+        Dict entries become ``EpisodeData`` (``type`` defaults to ``"text"``);
+        any other entry is passed through unchanged.
+        """
+        from zep_cloud import EpisodeData
+        ep_objects = [
+            EpisodeData(data=ep["data"], type=ep.get("type", "text"))
+            if isinstance(ep, dict) else ep
+            for ep in episodes
+        ]
+        result = self._client.graph.add_batch(graph_id=graph_id, episodes=ep_objects)
+        return [ep.uuid_ for ep in (result or [])]
+
+    def get_episode(self, uuid_: str) -> Any:
+        """Return the client's result for the episode ``uuid_``."""
+        return self._client.graph.episode.get(uuid_=uuid_)
+
+    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
+        """Return the graph's nodes as plain dicts (paging stops at ``_MAX_ITEMS``)."""
+        nodes = _fetch_all(self._client.graph.node.get_by_graph_id, graph_id)
+        return [
+            {
+                "uuid": _uuid_of(n),
+                "name": getattr(n, "name", ""),
+                "labels": list(getattr(n, "labels", []) or []),
+                "summary": getattr(n, "summary", ""),
+                "attributes": dict(getattr(n, "attributes", {}) or {}),
+                "created_at": _timestamp(n, "created_at"),
+            }
+            for n in nodes
+        ]
+
+    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
+        """Return the graph's edges as plain dicts (paging stops at ``_MAX_ITEMS``)."""
+        edges = _fetch_all(self._client.graph.edge.get_by_graph_id, graph_id)
+        return [
+            {
+                "uuid": _uuid_of(e),
+                "name": getattr(e, "name", ""),
+                "fact": getattr(e, "fact", ""),
+                "source_node_uuid": getattr(e, "source_node_uuid", None),
+                "target_node_uuid": getattr(e, "target_node_uuid", None),
+                "fact_type": getattr(e, "fact_type", None),
+                "attributes": dict(getattr(e, "attributes", {}) or {}),
+                "created_at": _timestamp(e, "created_at"),
+                "valid_at": _timestamp(e, "valid_at"),
+                "invalid_at": _timestamp(e, "invalid_at"),
+                "expired_at": _timestamp(e, "expired_at"),
+                "episodes": list(getattr(e, "episodes", []) or []),
+            }
+            for e in edges
+        ]
+
+    def get_node(self, uuid_: str) -> Dict[str, Any]:
+        """Return a single node as a plain dict."""
+        n = self._client.graph.node.get(uuid_=uuid_)
+        return {
+            "uuid": _uuid_of(n),
+            "name": getattr(n, "name", ""),
+            "labels": list(getattr(n, "labels", []) or []),
+            "summary": getattr(n, "summary", ""),
+            "attributes": dict(getattr(n, "attributes", {}) or {}),
+        }
+
+    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
+        """Return the edges attached to a node, retrying transient failures."""
+        edges = _fetch_page_with_retry(
+            self._client.graph.node.get_entity_edges, node_uuid=node_uuid
+        )
+        return [
+            {
+                "uuid": _uuid_of(e),
+                "name": getattr(e, "name", ""),
+                "fact": getattr(e, "fact", ""),
+                "source_node_uuid": getattr(e, "source_node_uuid", None),
+                "target_node_uuid": getattr(e, "target_node_uuid", None),
+            }
+            for e in edges
+        ]
+
+    def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
+        """Search a graph and return ``{"edges": [...], "nodes": [...]}``.
+
+        The lists hold the client's own result objects, not plain dicts.
+        """
+        result = self._client.graph.search(
+            graph_id=graph_id,
+            query=query,
+            limit=limit,
+            scope=scope,
+            reranker="cross_encoder",
+        )
+        edges = getattr(result, "edges", []) or []
+        nodes = getattr(result, "nodes", []) or []
+        return {"edges": edges, "nodes": nodes}
+
+    def add_text(self, graph_id: str, data: str) -> None:
+        """Add text data to a graph."""
+        self._client.graph.add(graph_id=graph_id, type="text", data=data)
+
+    def delete_graph(self, graph_id: str) -> None:
+        """Delete the graph with the given id."""
+        self._client.graph.delete(graph_id=graph_id)
diff --git a/backend/tests/test_graph_factory.py b/backend/tests/test_graph_factory.py
index eb4f0976..48327310 100644
--- a/backend/tests/test_graph_factory.py
+++ b/backend/tests/test_graph_factory.py
@@ -36,6 +36,24 @@ def test_config_zep_errors_when_key_missing():
         cfg_mod.Config.ZEP_API_KEY = orig_key
 
 
+def test_zep_backend_implements_interface():
+    from backend.app.graph.base import GraphBackend
+    from backend.app.graph.zep_backend import ZepBackend
+    assert issubclass(ZepBackend, GraphBackend)
+
+
+def test_zep_backend_raises_without_key():
+    import backend.app.config as cfg_mod
+    orig = cfg_mod.Config.ZEP_API_KEY
+    try:
+        cfg_mod.Config.ZEP_API_KEY = None
+        from backend.app.graph.zep_backend import ZepBackend
+        with pytest.raises(ValueError, match="ZEP_API_KEY"):
+            ZepBackend()
+    finally:
+        cfg_mod.Config.ZEP_API_KEY = orig
+
+
 def test_config_graphiti_errors_when_missing():
     import backend.app.config as cfg_mod
     orig_backend = cfg_mod.Config.GRAPH_BACKEND