diff --git a/backend/app/graph/graphiti_backend.py b/backend/app/graph/graphiti_backend.py
new file mode 100644
index 00000000..5a2436d4
--- /dev/null
+++ b/backend/app/graph/graphiti_backend.py
@@ -0,0 +1,320 @@
+"""Graphiti + Neo4j implementation of GraphBackend."""
+import asyncio
+import threading
+import uuid as uuid_mod
+from datetime import datetime, timezone
+from typing import Any, Dict, List, Optional
+
+from .base import GraphBackend
+from ..config import Config
+from ..utils.logger import get_logger
+
+logger = get_logger('mirofish.graph.graphiti')
+
+# Seconds a synchronous caller waits for one coroutine before giving up.
+_DEFAULT_TIMEOUT_SECONDS = 120
+
+# Free-text provenance label that Graphiti requires on every episode.
+_EPISODE_SOURCE_DESCRIPTION = "mirofish"
+
+
+def _run_async(coro, timeout: float = _DEFAULT_TIMEOUT_SECONDS):
+    """Run a coroutine on the shared background loop and wait for its result.
+
+    Args:
+        coro: Coroutine object to execute.
+        timeout: Seconds to wait before giving up.
+
+    Returns:
+        Whatever the coroutine returns.
+
+    Raises:
+        concurrent.futures.TimeoutError: If no result arrives in time. The
+            pending coroutine is cancelled so it does not keep running
+            unattended on the background loop.
+    """
+    future = asyncio.run_coroutine_threadsafe(coro, _get_event_loop())
+    try:
+        return future.result(timeout=timeout)
+    finally:
+        # Still pending here means result() timed out or was interrupted.
+        if not future.done():
+            future.cancel()
+
+
+_loop: Optional[asyncio.AbstractEventLoop] = None
+_loop_thread: Optional[threading.Thread] = None
+_loop_lock = threading.Lock()
+
+
+def _get_event_loop() -> asyncio.AbstractEventLoop:
+    """Return the shared background event loop, starting it on first use.
+
+    Liveness is judged by the worker thread, not ``loop.is_running()``: the
+    latter is still False between ``Thread.start()`` and the moment
+    ``run_forever()`` begins, so checking it could replace a healthy loop
+    and leak its thread.
+    """
+    global _loop, _loop_thread
+    with _loop_lock:
+        thread_alive = _loop_thread is not None and _loop_thread.is_alive()
+        if _loop is None or _loop.is_closed() or not thread_alive:
+            _loop = asyncio.new_event_loop()
+            _loop_thread = threading.Thread(
+                target=_loop.run_forever,
+                name="graphiti-event-loop",
+                daemon=True,
+            )
+            _loop_thread.start()
+        return _loop
+
+
+def _entity_uuid(entity: Any) -> str:
+    """Return the ``uuid`` property of a node or relationship.
+
+    Falls back to the driver's ``element_id`` when the property is absent;
+    the numeric ``id`` accessor is deprecated in the Neo4j 5 driver.
+    """
+    value = entity.get("uuid")
+    return str(entity.element_id) if value is None else value
+
+
+class GraphitiBackend(GraphBackend):
+    """GraphBackend implementation backed by Graphiti on top of Neo4j.
+
+    Each graph is a Graphiti ``group_id`` namespace. Graphiti's API is
+    asynchronous, so every call is funnelled through ``_run_async``.
+    """
+
+    def __init__(
+        self,
+        uri: Optional[str] = None,
+        user: Optional[str] = None,
+        password: Optional[str] = None,
+    ):
+        """Resolve connection settings and build the Graphiti client.
+
+        Args:
+            uri: Neo4j URI; defaults to ``Config.NEO4J_URI``.
+            user: Neo4j user; defaults to ``Config.NEO4J_USER``.
+            password: Neo4j password; defaults to ``Config.NEO4J_PASSWORD``.
+
+        Raises:
+            ValueError: If no password is supplied or configured.
+        """
+        self._uri = uri or Config.NEO4J_URI
+        self._user = user or Config.NEO4J_USER
+        self._password = password or Config.NEO4J_PASSWORD
+        if not self._password:
+            raise ValueError("NEO4J_PASSWORD is not configured")
+        self._client = self._build_client()
+
+    def _build_client(self):
+        """Create the Graphiti client wired to the configured LLM and Neo4j.
+
+        graphiti-core is imported lazily so this module stays importable
+        when the optional ``graphiti`` extra is not installed.
+        """
+        from graphiti_core import Graphiti
+        from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
+        from graphiti_core.llm_client.config import LLMConfig
+        from graphiti_core.llm_client.openai_client import OpenAIClient
+
+        # OpenAIClient takes its settings as one LLMConfig object rather
+        # than as loose api_key/model/base_url keyword arguments.
+        llm_client = OpenAIClient(
+            config=LLMConfig(
+                api_key=Config.LLM_API_KEY,
+                model=Config.LLM_MODEL_NAME,
+                base_url=Config.LLM_BASE_URL,
+            )
+        )
+        embedder = OpenAIEmbedder(
+            config=OpenAIEmbedderConfig(
+                api_key=Config.LLM_API_KEY,
+                base_url=Config.LLM_BASE_URL,
+            )
+        )
+        # Graphiti opens its own Neo4j driver from the connection details;
+        # its constructor has no ``driver`` keyword.
+        return Graphiti(
+            self._uri,
+            self._user,
+            self._password,
+            llm_client=llm_client,
+            embedder=embedder,
+        )
+
+    def _query(self, cypher: str, **params: Any) -> List[Any]:
+        """Run a Cypher statement and return its result records.
+
+        Parameters go in as keyword arguments, the form accepted both by
+        the raw Neo4j async driver and by Graphiti's own driver wrapper.
+        """
+        result = _run_async(self._client.driver.execute_query(cypher, **params))
+        return list(result.records)
+
+    def _add_episode(self, graph_id: str, body: str) -> str:
+        """Ingest one text episode into ``graph_id`` and return its name."""
+        from graphiti_core.nodes import EpisodeType
+
+        episode_name = str(uuid_mod.uuid4())
+        _run_async(
+            self._client.add_episode(
+                name=episode_name,
+                episode_body=body,
+                # Both are required by add_episode; omitting them raises
+                # TypeError before anything is ingested.
+                source_description=_EPISODE_SOURCE_DESCRIPTION,
+                reference_time=datetime.now(timezone.utc),
+                source=EpisodeType.text,
+                group_id=graph_id,
+            )
+        )
+        return episode_name
+
+    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
+        """Log that the namespace is ready; group ids need no setup."""
+        logger.info(f"Graphiti graph namespace ready: {graph_id}")
+
+    def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
+        """Do nothing: Graphiti derives entity and edge types via the LLM."""
+        logger.info("Graphiti uses LLM-driven ontology extraction; set_ontology is a no-op.")
+
+    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
+        """Ingest episodes one at a time and return their generated names.
+
+        Each episode is either a dict with a ``"data"`` key or an object
+        with a ``data`` attribute.
+        """
+        return [
+            self._add_episode(
+                graph_id,
+                ep["data"] if isinstance(ep, dict) else ep.data,
+            )
+            for ep in episodes
+        ]
+
+    def get_episode(self, uuid_: str) -> Any:
+        """Return a stub whose ``processed`` flag is always True.
+
+        add_batch blocks until ingestion finishes, so there is no pending
+        state to report.
+        """
+        class _FakeEpisode:
+            processed = True
+
+        return _FakeEpisode()
+
+    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
+        """Return every node in the ``graph_id`` namespace as a plain dict."""
+        records = self._query("MATCH (n {group_id: $gid}) RETURN n", gid=graph_id)
+        nodes = []
+        for record in records:
+            n = record["n"]
+            nodes.append({
+                "uuid": _entity_uuid(n),
+                "name": n.get("name", ""),
+                "labels": list(n.labels),
+                "summary": n.get("summary", ""),
+                "attributes": dict(n),
+                "created_at": str(n.get("created_at", "")),
+            })
+        return nodes
+
+    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
+        """Return every relationship in the ``graph_id`` namespace."""
+        records = self._query(
+            "MATCH (s)-[r]->(t) WHERE r.group_id = $gid RETURN s, r, t",
+            gid=graph_id,
+        )
+        edges = []
+        for record in records:
+            r = record["r"]
+            edges.append({
+                "uuid": _entity_uuid(r),
+                # Fall back to the relationship type when no name is stored.
+                "name": r.get("name", r.type),
+                "fact": r.get("fact", ""),
+                "source_node_uuid": record["s"].get("uuid", ""),
+                "target_node_uuid": record["t"].get("uuid", ""),
+                "fact_type": r.get("fact_type", ""),
+                "attributes": dict(r),
+                "created_at": str(r.get("created_at", "")),
+                "valid_at": str(r.get("valid_at", "")),
+                "invalid_at": str(r.get("invalid_at", "")),
+                "expired_at": str(r.get("expired_at", "")),
+                "episodes": [],
+            })
+        return edges
+
+    def get_node(self, uuid_: str) -> Dict[str, Any]:
+        """Return the node with ``uuid_``, or an empty dict if none exists."""
+        records = self._query("MATCH (n {uuid: $uuid}) RETURN n LIMIT 1", uuid=uuid_)
+        if not records:
+            return {}
+        n = records[0]["n"]
+        return {
+            "uuid": n.get("uuid", ""),
+            "name": n.get("name", ""),
+            "labels": list(n.labels),
+            "summary": n.get("summary", ""),
+            "attributes": dict(n),
+        }
+
+    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
+        """Return all edges touching ``node_uuid`` with their true direction.
+
+        Both branches return the real source and target uuids, so an
+        incoming edge is no longer reported as starting at ``node_uuid``.
+        UNION de-duplicates self-loops that match both branches.
+        """
+        records = self._query(
+            "MATCH (s {uuid: $uuid})-[r]->(t) "
+            "RETURN r, s.uuid AS source_uuid, t.uuid AS target_uuid "
+            "UNION MATCH (s)-[r]->(t {uuid: $uuid}) "
+            "RETURN r, s.uuid AS source_uuid, t.uuid AS target_uuid",
+            uuid=node_uuid,
+        )
+        edges = []
+        for record in records:
+            r = record["r"]
+            edges.append({
+                "uuid": _entity_uuid(r),
+                "name": r.get("name", ""),
+                "fact": r.get("fact", ""),
+                # Endpoint uuids stored on the edge win; otherwise use the
+                # endpoints the query actually matched.
+                "source_node_uuid": r.get("source_node_uuid", record["source_uuid"] or ""),
+                "target_node_uuid": r.get("target_node_uuid", record["target_uuid"] or ""),
+            })
+        return edges
+
+    def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
+        """Search ``graph_id`` for edges relevant to ``query``.
+
+        ``scope`` is accepted for interface compatibility but not used;
+        only edges are returned and ``"nodes"`` is always empty.
+        """
+        results = _run_async(
+            self._client.search(query=query, group_ids=[graph_id], num_results=limit)
+        )
+        edges = [
+            {
+                "uuid": getattr(r, "uuid", ""),
+                "name": getattr(r, "name", ""),
+                "fact": getattr(r, "fact", ""),
+                "source_node_uuid": getattr(r, "source_node_uuid", ""),
+                "target_node_uuid": getattr(r, "target_node_uuid", ""),
+            }
+            for r in (results or [])
+        ]
+        return {"edges": edges, "nodes": []}
+
+    def add_text(self, graph_id: str, data: str) -> None:
+        """Ingest a single text episode into ``graph_id``."""
+        self._add_episode(graph_id, data)
+
+    def delete_graph(self, graph_id: str) -> None:
+        """Delete every node in ``graph_id`` along with its relationships."""
+        self._query("MATCH (n {group_id: $gid}) DETACH DELETE n", gid=graph_id)
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index ccdd04f9..28a12e1f 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -37,6 +37,10 @@ dependencies = [
 ]
 
 [project.optional-dependencies]
+graphiti = [
+    "graphiti-core>=0.3.0",
+    "neo4j>=5.23.0",
+]
 dev = [
     "pytest>=8.0.0",
     "pytest-asyncio>=0.23.0",
diff --git a/backend/tests/test_graph_factory.py b/backend/tests/test_graph_factory.py
index f1e9e60b..a9ebcc59 100644
--- a/backend/tests/test_graph_factory.py
+++ b/backend/tests/test_graph_factory.py
@@ -86,6 +86,30 @@ def test_factory_raises_on_unknown_backend():
         fmod._backend_instance = None
 
 
+def test_graphiti_backend_importable():
+    try:
+        from backend.app.graph.graphiti_backend import GraphitiBackend
+        from backend.app.graph.base import GraphBackend
+        assert issubclass(GraphitiBackend, GraphBackend)
+    except ImportError as e:
+        pytest.skip(f"graphiti-core not installed: {e}")
+
+
+def test_graphiti_backend_raises_without_password():
+    try:
+        from backend.app.graph.graphiti_backend import GraphitiBackend
+    except ImportError:
+        pytest.skip("graphiti-core not installed")
+    import backend.app.config as cfg_mod
+    orig = cfg_mod.Config.NEO4J_PASSWORD
+    try:
+        cfg_mod.Config.NEO4J_PASSWORD = None
+        with pytest.raises(ValueError, match="NEO4J_PASSWORD"):
+            GraphitiBackend()
+    finally:
+        cfg_mod.Config.NEO4J_PASSWORD = orig
+
+
 def test_config_graphiti_errors_when_missing():
     import backend.app.config as cfg_mod
     orig_backend = cfg_mod.Config.GRAPH_BACKEND