115 lines
3.6 KiB
Python
115 lines
3.6 KiB
Python
"""
Zep / OpenZep graph backend implementation.
"""
|
|
|
|
from __future__ import annotations

import json
from typing import Any, Dict, List, Optional
from urllib.error import HTTPError, URLError
from urllib.parse import quote
from urllib.request import Request, urlopen

from zep_cloud.client import Zep

from ..config import Config
from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes
from .base import GraphBackend
|
|
|
|
|
|
class ZepGraphBackend(GraphBackend):
    """Graph backend backed by Zep Cloud or OpenZep.

    Graph operations delegate to the ``graph`` API of a ``zep_cloud`` ``Zep``
    client. Live statistics are fetched with a direct HTTP request to
    ``Config.ZEP_BASE_URL`` rather than through the client.
    """

    # Keys read from the statistics endpoint's JSON object.
    _STATISTICS_KEYS = ("node_count", "edge_count", "episode_count")

    def __init__(self, api_key: Optional[str] = None):
        """Create the backend and its underlying ``Zep`` client.

        Args:
            api_key: API key to use. When ``None`` it falls back to
                ``Config.ZEP_API_KEY``; any other value (including an empty
                string) is used exactly as given.

        Raises:
            ValueError: If ``Config.get_zep_config_errors`` reports any
                configuration problems; the messages are joined with "; ".
        """
        # ``is None`` (not truthiness) so an explicit "" is not replaced.
        self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
        errors = Config.get_zep_config_errors(api_key=self.api_key)
        if errors:
            raise ValueError("; ".join(errors))
        self.client = Zep(**Config.get_zep_client_kwargs(api_key=self.api_key))

    @property
    def raw_client(self) -> Zep:
        """The underlying ``Zep`` client, for calls this backend does not wrap."""
        return self.client

    def create_graph(self, graph_id: str, name: str, description: str) -> None:
        """Create a graph with the given id, display name and description."""
        self.client.graph.create(
            graph_id=graph_id,
            name=name,
            description=description,
        )

    def set_ontology(
        self,
        graph_id: str,
        entities: Any = None,
        edges: Any = None,
    ) -> None:
        """Apply entity and edge ontology definitions to a single graph.

        ``entities`` and ``edges`` are forwarded unchanged to
        ``client.graph.set_ontology``; ``graph_id`` is wrapped in a list
        because that call takes ``graph_ids``.
        """
        self.client.graph.set_ontology(
            graph_ids=[graph_id],
            entities=entities,
            edges=edges,
        )

    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
        """Add several episodes to a graph and return the client's result."""
        return self.client.graph.add_batch(graph_id=graph_id, episodes=episodes)

    def add_text(self, graph_id: str, data: str) -> Any:
        """Add a single ``type="text"`` episode containing ``data``."""
        return self.client.graph.add(graph_id=graph_id, type="text", data=data)

    def get_episode(self, episode_uuid: str) -> Any:
        """Fetch one episode by its UUID."""
        return self.client.graph.episode.get(uuid_=episode_uuid)

    def search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges",
        reranker: Optional[str] = None,
    ) -> Any:
        """Search a graph and return the client's raw search result.

        Args:
            graph_id: Graph to search.
            query: Search query text.
            limit: Maximum number of results requested.
            scope: Search scope forwarded to the client (default ``"edges"``).
            reranker: Optional reranker name; only sent when truthy, so the
                client's own default applies otherwise.
        """
        kwargs: Dict[str, Any] = {
            "graph_id": graph_id,
            "query": query,
            "limit": limit,
            "scope": scope,
        }
        if reranker:
            kwargs["reranker"] = reranker
        return self.client.graph.search(**kwargs)

    def get_all_nodes(self, graph_id: str) -> List[Any]:
        """Return every node of the graph, as collected by ``fetch_all_nodes``."""
        return fetch_all_nodes(self.client, graph_id)

    def get_all_edges(self, graph_id: str) -> List[Any]:
        """Return every edge of the graph, as collected by ``fetch_all_edges``."""
        return fetch_all_edges(self.client, graph_id)

    def get_node(self, node_uuid: str) -> Any:
        """Fetch one node by its UUID."""
        return self.client.graph.node.get(uuid_=node_uuid)

    def get_node_edges(self, node_uuid: str) -> List[Any]:
        """Return the entity edges attached to the given node."""
        return self.client.graph.node.get_entity_edges(node_uuid=node_uuid)

    def delete_graph(self, graph_id: str) -> None:
        """Delete the graph with the given id."""
        self.client.graph.delete(graph_id=graph_id)

    @staticmethod
    def _parse_statistics(payload: Any) -> Optional[Dict[str, int]]:
        """Normalize a decoded statistics payload into non-negative int counts.

        Missing, ``None`` or otherwise falsy values count as 0 and negative
        values are clamped to 0. Returns ``None`` when the payload is not a
        JSON object or a count cannot be converted to ``int``.
        """
        if not isinstance(payload, dict):
            return None
        try:
            return {
                key: max(0, int(payload.get(key) or 0))
                for key in ZepGraphBackend._STATISTICS_KEYS
            }
        except (TypeError, ValueError, OverflowError):
            # e.g. non-numeric strings, nested objects, or Infinity/NaN
            # (which json.loads accepts by default).
            return None

    def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
        """Fetch node/edge/episode counts from ``{ZEP_BASE_URL}/graph/{id}/statistics``.

        Best-effort: returns ``None`` when ``Config.ZEP_BASE_URL`` is not set,
        when the request fails or times out (10 s), or when the response body
        is not a JSON object with integer-convertible counts.

        Returns:
            A dict with ``node_count``, ``edge_count`` and ``episode_count``
            (each >= 0), or ``None`` as described above.
        """
        if not Config.ZEP_BASE_URL:
            return None

        base_url = Config.ZEP_BASE_URL.rstrip("/")
        # Percent-encode the id so characters such as "/", "?" or "#" cannot
        # change the request path or add a query/fragment.
        encoded_id = quote(str(graph_id), safe="")
        request = Request(f"{base_url}/graph/{encoded_id}/statistics")
        if self.api_key:
            request.add_header("Authorization", f"Bearer {self.api_key}")

        try:
            with urlopen(request, timeout=10) as response:
                payload = json.loads(response.read().decode("utf-8"))
        except (HTTPError, URLError, TimeoutError, OSError, ValueError):
            # ValueError covers both json.JSONDecodeError and
            # UnicodeDecodeError; the network errors are OSError subclasses
            # listed explicitly for readability.
            return None

        return self._parse_statistics(payload)
|