219 lines
7.6 KiB
Python
219 lines
7.6 KiB
Python
"""Graphiti + Neo4j implementation of GraphBackend."""
|
|
import asyncio
|
|
import threading
|
|
import uuid as uuid_mod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from .base import GraphBackend
|
|
from ..config import Config
|
|
from ..utils.logger import get_logger
|
|
|
|
# Module-level logger, obtained from the project's ``get_logger`` helper under
# the name 'mirofish.graph.graphiti'.
logger = get_logger('mirofish.graph.graphiti')
|
|
|
|
|
|
def _run_async(coro, timeout: Optional[float] = 120):
    """Run ``coro`` to completion from synchronous code and return its result.

    The coroutine is submitted to the shared background event loop returned by
    ``_get_event_loop`` and the calling thread blocks until it finishes.

    Args:
        coro: The coroutine object to execute.
        timeout: Maximum number of seconds to wait; ``None`` waits
            indefinitely. Defaults to 120, the previously hard-coded limit.

    Returns:
        Whatever the coroutine returns.

    Raises:
        concurrent.futures.TimeoutError: If the coroutine does not finish
            within ``timeout`` seconds. The pending coroutine is cancelled
            before the error propagates.
        Exception: Any exception raised inside the coroutine is re-raised in
            the calling thread.
    """
    # Local import: keeps this helper independent of the file's import block.
    # On Python < 3.11 this class is distinct from the builtin TimeoutError.
    from concurrent.futures import TimeoutError as FutureTimeoutError

    loop = _get_event_loop()
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout=timeout)
    except FutureTimeoutError:
        # Without the cancel, the coroutine would keep running on the
        # background loop after the caller has already given up on it.
        # (Cancelling an already-finished future is a harmless no-op.)
        future.cancel()
        raise
|
|
|
|
|
|
# Shared background event loop and the daemon thread that drives it.
# Both are created lazily by _get_event_loop().
_loop: Optional[asyncio.AbstractEventLoop] = None
_loop_thread: Optional[threading.Thread] = None
# Serializes creation/replacement of the loop and its thread.
_loop_lock = threading.Lock()


def _get_event_loop() -> asyncio.AbstractEventLoop:
    """Return the shared background event loop, starting it if necessary.

    A new loop and daemon thread are created on first use, and again if the
    previous loop was closed or its thread has exited.

    Returns:
        The event loop being driven by the background thread.
    """
    global _loop, _loop_thread
    with _loop_lock:
        # Decide on thread liveness rather than ``_loop.is_running()``:
        # ``is_running()`` stays False between ``Thread.start()`` and the
        # thread actually entering ``run_forever()``, so a second caller in
        # that window would create a second loop and orphan the first.
        thread_alive = _loop_thread is not None and _loop_thread.is_alive()
        if _loop is None or _loop.is_closed() or not thread_alive:
            _loop = asyncio.new_event_loop()
            _loop_thread = threading.Thread(
                target=_loop.run_forever,
                name="graphiti-event-loop",
                daemon=True,
            )
            _loop_thread.start()
        return _loop
|
|
|
|
|
|
class GraphitiBackend(GraphBackend):
    """GraphBackend implementation that stores graphs in Neo4j via Graphiti.

    All public methods are synchronous; Graphiti / Neo4j coroutines are
    executed through the module-level ``_run_async`` bridge. A ``graph_id``
    is used as the Graphiti ``group_id`` of the data it owns.
    """

    # Free-text provenance label attached to every episode this backend ingests.
    _EPISODE_SOURCE_DESCRIPTION = "text ingested via GraphitiBackend"

    def __init__(
        self,
        uri: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Store connection settings and build the Graphiti client.

        Args:
            uri: Neo4j URI; falls back to ``Config.NEO4J_URI`` when falsy.
            user: Neo4j user; falls back to ``Config.NEO4J_USER`` when falsy.
            password: Neo4j password; falls back to
                ``Config.NEO4J_PASSWORD`` when falsy.

        Raises:
            ValueError: If no password is available from either source.
        """
        self._uri = uri or Config.NEO4J_URI
        self._user = user or Config.NEO4J_USER
        self._password = password or Config.NEO4J_PASSWORD
        if not self._password:
            raise ValueError("NEO4J_PASSWORD is not configured")
        self._client = self._build_client()

    def _build_client(self):
        """Construct the Graphiti client from the stored settings.

        Third-party imports are deferred to call time so that importing this
        module does not itself require ``graphiti_core``.

        Returns:
            A ``graphiti_core.Graphiti`` instance configured with an OpenAI
            compatible LLM client and embedder taken from ``Config``.
        """
        from graphiti_core import Graphiti
        from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
        from graphiti_core.llm_client.config import LLMConfig
        from graphiti_core.llm_client.openai_client import OpenAIClient

        # NOTE(review): per the upstream graphiti_core API, OpenAIClient takes
        # its settings through an LLMConfig object rather than loose keyword
        # arguments — confirm against the pinned graphiti_core version.
        llm_client = OpenAIClient(
            config=LLMConfig(
                api_key=Config.LLM_API_KEY,
                model=Config.LLM_MODEL_NAME,
                base_url=Config.LLM_BASE_URL,
            )
        )
        embedder = OpenAIEmbedder(
            OpenAIEmbedderConfig(
                api_key=Config.LLM_API_KEY,
                base_url=Config.LLM_BASE_URL,
            )
        )
        # NOTE(review): upstream Graphiti opens its own Neo4j driver from
        # (uri, user, password) and has no ``driver`` keyword — confirm
        # against the pinned graphiti_core version.
        return Graphiti(
            self._uri,
            self._user,
            self._password,
            llm_client=llm_client,
            embedder=embedder,
        )

    def _query(self, cypher: str, **params: Any):
        """Run ``cypher`` on the client's driver and return the query result.

        Query parameters are passed as keyword arguments, a form the Neo4j
        driver's ``execute_query`` accepts alongside a parameter dict.
        """
        return _run_async(self._client.driver.execute_query(cypher, **params))

    def _add_episode(self, graph_id: str, body: str) -> str:
        """Ingest one text episode into ``graph_id`` and return its name.

        The name is a freshly generated UUID4 string.
        """
        from datetime import datetime, timezone

        from graphiti_core.nodes import EpisodeType

        episode_name = str(uuid_mod.uuid4())
        _run_async(
            self._client.add_episode(
                name=episode_name,
                episode_body=body,
                # NOTE(review): upstream ``add_episode`` declares
                # ``source_description`` and ``reference_time`` as required
                # arguments — confirm against the pinned version.
                source_description=self._EPISODE_SOURCE_DESCRIPTION,
                reference_time=datetime.now(timezone.utc),
                source=EpisodeType.text,
                group_id=graph_id,
            )
        )
        return episode_name

    @staticmethod
    def _uuid_or_internal_id(entity) -> Any:
        """Return the entity's ``uuid`` property, else ``str(entity.id)``.

        ``entity.id`` is only read when the property is missing.
        NOTE(review): ``id`` is believed deprecated in neo4j driver 5 in
        favour of ``element_id`` — confirm before switching.
        """
        value = entity.get("uuid")
        return str(entity.id) if value is None else value

    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
        """Log that the namespace is ready; nothing is created here.

        ``name`` and ``description`` are accepted but not used.
        """
        logger.info(f"Graphiti graph namespace ready: {graph_id}")

    def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
        """No-op: only logs that Graphiti relies on LLM-driven extraction."""
        logger.info("Graphiti uses LLM-driven ontology extraction; set_ontology is a no-op.")

    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
        """Ingest several episodes sequentially.

        Args:
            graph_id: Group the episodes are stored under.
            episodes: Items that are either dicts with a ``"data"`` key or
                objects with a ``data`` attribute.

        Returns:
            The generated episode names, in input order.
        """
        ids = []
        for episode in episodes:
            body = episode["data"] if isinstance(episode, dict) else episode.data
            ids.append(self._add_episode(graph_id, body))
        return ids

    def get_episode(self, uuid_: str) -> Any:
        """Return a stand-in episode object whose ``processed`` is True.

        ``uuid_`` is not looked up; ``add_batch`` / ``add_text`` only return
        after ``add_episode`` has completed.
        """
        class _FakeEpisode:
            processed = True

        return _FakeEpisode()

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every node whose ``group_id`` property equals ``graph_id``.

        Each node is a dict with ``uuid``, ``name``, ``labels``, ``summary``,
        ``attributes`` (all node properties) and ``created_at`` (stringified).
        """
        result = self._query("MATCH (n {group_id: $gid}) RETURN n", gid=graph_id)
        nodes = []
        for record in result.records:
            node = record["n"]
            nodes.append({
                "uuid": self._uuid_or_internal_id(node),
                "name": node.get("name", ""),
                "labels": list(node.labels),
                "summary": node.get("summary", ""),
                "attributes": dict(node),
                "created_at": str(node.get("created_at", "")),
            })
        return nodes

    def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every relationship whose ``group_id`` equals ``graph_id``.

        Source/target UUIDs come from the ``uuid`` property of the start and
        end nodes; the temporal fields are stringified; ``episodes`` is
        always an empty list.
        """
        result = self._query(
            "MATCH (s)-[r]->(t) WHERE r.group_id = $gid RETURN s, r, t",
            gid=graph_id,
        )
        edges = []
        for record in result.records:
            rel = record["r"]
            edges.append({
                "uuid": self._uuid_or_internal_id(rel),
                # ``rel.type`` is the relationship type name, i.e. the same
                # value the previous ``type(r).__name__`` fallback produced.
                "name": rel.get("name", rel.type),
                "fact": rel.get("fact", ""),
                "source_node_uuid": record["s"].get("uuid", ""),
                "target_node_uuid": record["t"].get("uuid", ""),
                "fact_type": rel.get("fact_type", ""),
                "attributes": dict(rel),
                "created_at": str(rel.get("created_at", "")),
                "valid_at": str(rel.get("valid_at", "")),
                "invalid_at": str(rel.get("invalid_at", "")),
                "expired_at": str(rel.get("expired_at", "")),
                "episodes": [],
            })
        return edges

    def get_node(self, uuid_: str) -> Dict[str, Any]:
        """Return the node with the given ``uuid`` property, or ``{}`` if none."""
        result = self._query("MATCH (n {uuid: $uuid}) RETURN n LIMIT 1", uuid=uuid_)
        if not result.records:
            return {}
        node = result.records[0]["n"]
        return {
            "uuid": node.get("uuid", ""),
            "name": node.get("name", ""),
            "labels": list(node.labels),
            "summary": node.get("summary", ""),
            "attributes": dict(node),
        }

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Return the relationships attached to a node, in either direction.

        The endpoints are read from the relationship's actual start and end
        nodes, so incoming edges report the other node as source and
        ``node_uuid`` as target. An explicit ``source_node_uuid`` /
        ``target_node_uuid`` property on the relationship still wins.
        """
        result = self._query(
            "MATCH (n {uuid: $uuid})-[r]-() "
            "RETURN DISTINCT r, startNode(r).uuid AS source_uuid, "
            "endNode(r).uuid AS target_uuid",
            uuid=node_uuid,
        )
        edges = []
        for record in result.records:
            rel = record["r"]
            edges.append({
                "uuid": self._uuid_or_internal_id(rel),
                "name": rel.get("name", ""),
                "fact": rel.get("fact", ""),
                "source_node_uuid": rel.get("source_node_uuid") or record["source_uuid"] or "",
                "target_node_uuid": rel.get("target_node_uuid") or record["target_uuid"] or "",
            })
        return edges

    def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
        """Search ``graph_id`` through the Graphiti client.

        Args:
            graph_id: Group to search in.
            query: Search text.
            limit: Passed to Graphiti as ``num_results``.
            scope: Accepted for interface compatibility; currently unused.

        Returns:
            ``{"edges": [...], "nodes": []}``. Each edge dict is built from
            the result object's attributes; ``nodes`` is always empty.
        """
        hits = _run_async(
            self._client.search(query=query, group_ids=[graph_id], num_results=limit)
        )
        edges = [
            {
                "uuid": getattr(hit, "uuid", ""),
                "name": getattr(hit, "name", ""),
                "fact": getattr(hit, "fact", ""),
                "source_node_uuid": getattr(hit, "source_node_uuid", ""),
                "target_node_uuid": getattr(hit, "target_node_uuid", ""),
            }
            for hit in (hits or [])
        ]
        return {"edges": edges, "nodes": []}

    def add_text(self, graph_id: str, data: str) -> None:
        """Ingest a single text episode into ``graph_id``."""
        self._add_episode(graph_id, data)

    def delete_graph(self, graph_id: str) -> None:
        """Delete every node with ``group_id == graph_id`` and its relationships."""
        self._query("MATCH (n {group_id: $gid}) DETACH DELETE n", gid=graph_id)
|