feat(graph): add GraphitiBackend for Neo4j+graphiti-core local deployment
This commit is contained in:
parent
b2fd7e1b87
commit
788f9c29c9
|
|
@ -0,0 +1,218 @@
|
|||
"""Graphiti + Neo4j implementation of GraphBackend."""
|
||||
import asyncio
|
||||
import threading
|
||||
import uuid as uuid_mod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .base import GraphBackend
|
||||
from ..config import Config
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.graph.graphiti')
|
||||
|
||||
|
||||
def _run_async(coro):
|
||||
"""Run an async coroutine from a sync context using a dedicated thread loop."""
|
||||
loop = _get_event_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(coro, loop)
|
||||
return future.result(timeout=120)
|
||||
|
||||
|
||||
_loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
_loop_thread: Optional[threading.Thread] = None
|
||||
_loop_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_event_loop() -> asyncio.AbstractEventLoop:
|
||||
global _loop, _loop_thread
|
||||
with _loop_lock:
|
||||
if _loop is None or not _loop.is_running():
|
||||
_loop = asyncio.new_event_loop()
|
||||
_loop_thread = threading.Thread(target=_loop.run_forever, daemon=True)
|
||||
_loop_thread.start()
|
||||
return _loop
|
||||
|
||||
|
||||
class GraphitiBackend(GraphBackend):
|
||||
def __init__(
|
||||
self,
|
||||
uri: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
):
|
||||
self._uri = uri or Config.NEO4J_URI
|
||||
self._user = user or Config.NEO4J_USER
|
||||
self._password = password or Config.NEO4J_PASSWORD
|
||||
if not self._password:
|
||||
raise ValueError("NEO4J_PASSWORD is not configured")
|
||||
self._client = self._build_client()
|
||||
|
||||
def _build_client(self):
|
||||
from graphiti_core import Graphiti
|
||||
from graphiti_core.llm_client.openai_client import OpenAIClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from neo4j import AsyncGraphDatabase
|
||||
|
||||
llm_client = OpenAIClient(
|
||||
api_key=Config.LLM_API_KEY,
|
||||
model=Config.LLM_MODEL_NAME,
|
||||
base_url=Config.LLM_BASE_URL,
|
||||
)
|
||||
embedder = OpenAIEmbedder(
|
||||
OpenAIEmbedderConfig(
|
||||
api_key=Config.LLM_API_KEY,
|
||||
base_url=Config.LLM_BASE_URL,
|
||||
)
|
||||
)
|
||||
driver = AsyncGraphDatabase.driver(
|
||||
self._uri, auth=(self._user, self._password)
|
||||
)
|
||||
return Graphiti(driver=driver, llm_client=llm_client, embedder=embedder)
|
||||
|
||||
def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
|
||||
logger.info(f"Graphiti graph namespace ready: {graph_id}")
|
||||
|
||||
def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
|
||||
logger.info("Graphiti uses LLM-driven ontology extraction; set_ontology is a no-op.")
|
||||
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
ids = []
|
||||
for ep in episodes:
|
||||
data = ep["data"] if isinstance(ep, dict) else ep.data
|
||||
ep_id = str(uuid_mod.uuid4())
|
||||
_run_async(
|
||||
self._client.add_episode(
|
||||
name=ep_id,
|
||||
episode_body=data,
|
||||
source=EpisodeType.text,
|
||||
group_id=graph_id,
|
||||
)
|
||||
)
|
||||
ids.append(ep_id)
|
||||
return ids
|
||||
|
||||
def get_episode(self, uuid_: str) -> Any:
|
||||
class _FakeEpisode:
|
||||
processed = True
|
||||
return _FakeEpisode()
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
results = _run_async(
|
||||
self._client.driver.execute_query(
|
||||
"MATCH (n {group_id: $gid}) RETURN n",
|
||||
{"gid": graph_id},
|
||||
)
|
||||
)
|
||||
nodes = []
|
||||
for record in results.records:
|
||||
n = record["n"]
|
||||
nodes.append({
|
||||
"uuid": n.get("uuid", str(n.id)),
|
||||
"name": n.get("name", ""),
|
||||
"labels": list(n.labels),
|
||||
"summary": n.get("summary", ""),
|
||||
"attributes": dict(n),
|
||||
"created_at": str(n.get("created_at", "")),
|
||||
})
|
||||
return nodes
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
results = _run_async(
|
||||
self._client.driver.execute_query(
|
||||
"MATCH (s)-[r]->(t) WHERE r.group_id = $gid RETURN s, r, t",
|
||||
{"gid": graph_id},
|
||||
)
|
||||
)
|
||||
edges = []
|
||||
for record in results.records:
|
||||
r = record["r"]
|
||||
edges.append({
|
||||
"uuid": r.get("uuid", str(r.id)),
|
||||
"name": r.get("name", type(r).__name__),
|
||||
"fact": r.get("fact", ""),
|
||||
"source_node_uuid": record["s"].get("uuid", ""),
|
||||
"target_node_uuid": record["t"].get("uuid", ""),
|
||||
"fact_type": r.get("fact_type", ""),
|
||||
"attributes": dict(r),
|
||||
"created_at": str(r.get("created_at", "")),
|
||||
"valid_at": str(r.get("valid_at", "")),
|
||||
"invalid_at": str(r.get("invalid_at", "")),
|
||||
"expired_at": str(r.get("expired_at", "")),
|
||||
"episodes": [],
|
||||
})
|
||||
return edges
|
||||
|
||||
def get_node(self, uuid_: str) -> Dict[str, Any]:
|
||||
results = _run_async(
|
||||
self._client.driver.execute_query(
|
||||
"MATCH (n {uuid: $uuid}) RETURN n LIMIT 1",
|
||||
{"uuid": uuid_},
|
||||
)
|
||||
)
|
||||
if not results.records:
|
||||
return {}
|
||||
n = results.records[0]["n"]
|
||||
return {
|
||||
"uuid": n.get("uuid", ""),
|
||||
"name": n.get("name", ""),
|
||||
"labels": list(n.labels),
|
||||
"summary": n.get("summary", ""),
|
||||
"attributes": dict(n),
|
||||
}
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
results = _run_async(
|
||||
self._client.driver.execute_query(
|
||||
"MATCH (n {uuid: $uuid})-[r]->(t) RETURN r, t "
|
||||
"UNION MATCH (s)-[r]->(n {uuid: $uuid}) RETURN r, s as t",
|
||||
{"uuid": node_uuid},
|
||||
)
|
||||
)
|
||||
edges = []
|
||||
for record in results.records:
|
||||
r = record["r"]
|
||||
edges.append({
|
||||
"uuid": r.get("uuid", str(r.id)),
|
||||
"name": r.get("name", ""),
|
||||
"fact": r.get("fact", ""),
|
||||
"source_node_uuid": r.get("source_node_uuid", node_uuid),
|
||||
"target_node_uuid": r.get("target_node_uuid", ""),
|
||||
})
|
||||
return edges
|
||||
|
||||
def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
|
||||
results = _run_async(
|
||||
self._client.search(query=query, group_ids=[graph_id], num_results=limit)
|
||||
)
|
||||
edges = [
|
||||
{
|
||||
"uuid": getattr(r, "uuid", ""),
|
||||
"name": getattr(r, "name", ""),
|
||||
"fact": getattr(r, "fact", ""),
|
||||
"source_node_uuid": getattr(r, "source_node_uuid", ""),
|
||||
"target_node_uuid": getattr(r, "target_node_uuid", ""),
|
||||
}
|
||||
for r in (results or [])
|
||||
]
|
||||
return {"edges": edges, "nodes": []}
|
||||
|
||||
def add_text(self, graph_id: str, data: str) -> None:
|
||||
ep_id = str(uuid_mod.uuid4())
|
||||
from graphiti_core.nodes import EpisodeType
|
||||
_run_async(
|
||||
self._client.add_episode(
|
||||
name=ep_id,
|
||||
episode_body=data,
|
||||
source=EpisodeType.text,
|
||||
group_id=graph_id,
|
||||
)
|
||||
)
|
||||
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
_run_async(
|
||||
self._client.driver.execute_query(
|
||||
"MATCH (n {group_id: $gid}) DETACH DELETE n",
|
||||
{"gid": graph_id},
|
||||
)
|
||||
)
|
||||
|
|
@ -37,6 +37,10 @@ dependencies = [
|
|||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
graphiti = [
|
||||
"graphiti-core>=0.3.0",
|
||||
"neo4j>=5.23.0",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0.0",
|
||||
"pytest-asyncio>=0.23.0",
|
||||
|
|
|
|||
|
|
@ -86,6 +86,30 @@ def test_factory_raises_on_unknown_backend():
|
|||
fmod._backend_instance = None
|
||||
|
||||
|
||||
def test_graphiti_backend_importable():
|
||||
try:
|
||||
from backend.app.graph.graphiti_backend import GraphitiBackend
|
||||
from backend.app.graph.base import GraphBackend
|
||||
assert issubclass(GraphitiBackend, GraphBackend)
|
||||
except ImportError as e:
|
||||
pytest.skip(f"graphiti-core not installed: {e}")
|
||||
|
||||
|
||||
def test_graphiti_backend_raises_without_password():
|
||||
try:
|
||||
from backend.app.graph.graphiti_backend import GraphitiBackend
|
||||
except ImportError:
|
||||
pytest.skip("graphiti-core not installed")
|
||||
import backend.app.config as cfg_mod
|
||||
orig = cfg_mod.Config.NEO4J_PASSWORD
|
||||
try:
|
||||
cfg_mod.Config.NEO4J_PASSWORD = None
|
||||
with pytest.raises(ValueError, match="NEO4J_PASSWORD"):
|
||||
GraphitiBackend()
|
||||
finally:
|
||||
cfg_mod.Config.NEO4J_PASSWORD = orig
|
||||
|
||||
|
||||
def test_config_graphiti_errors_when_missing():
|
||||
import backend.app.config as cfg_mod
|
||||
orig_backend = cfg_mod.Config.GRAPH_BACKEND
|
||||
|
|
|
|||
Loading…
Reference in New Issue