554 lines
23 KiB
Python
554 lines
23 KiB
Python
"""Graphiti + Neo4j implementation of GraphBackend."""
|
|
import asyncio
|
|
import json
|
|
import threading
|
|
import typing
|
|
import uuid as uuid_mod
|
|
from typing import Any, Dict, List, Optional
|
|
|
|
from .base import GraphBackend
|
|
from ..config import Config
|
|
from ..utils.logger import get_logger
|
|
from ..utils.llm_client import parse_azure_url
|
|
|
|
|
|
def _neo4j_val(v: Any) -> Any:
    """Recursively convert a Neo4j value into a plain, JSON-serializable value.

    Neo4j temporal objects are recognised by class name (``DateTime``,
    ``Date``, ``Time``, ``LocalDateTime``, ``LocalTime``, ``Duration``) and
    rendered with ``str()``. Dicts are rebuilt with converted values, lists and
    tuples become lists of converted items, and anything else (including
    ``None``) is returned unchanged.
    """
    if v is None:
        return None

    temporal_names = {'DateTime', 'Date', 'Time', 'LocalDateTime', 'LocalTime', 'Duration'}
    if type(v).__name__ in temporal_names:
        return str(v)

    # dict and list/tuple cannot share an instance layout, so the order of
    # these two checks cannot change the result.
    if isinstance(v, dict):
        return {key: _neo4j_val(item) for key, item in v.items()}
    if isinstance(v, (list, tuple)):
        return list(map(_neo4j_val, v))

    return v
|
|
|
|
|
|
def _flatten_attributes(attrs: dict) -> dict:
    """Flatten entity attribute dicts so every value is a Neo4j-safe primitive.

    Graphiti extracts entity attributes via a Pydantic model, but the raw LLM
    response sometimes wraps each value in a nested dict (e.g. {"value": "CTTI"}).
    Neo4j only stores primitive property values or arrays of them (arrays may
    not contain nulls), so values are coerced as follows:

    * ``None`` values are dropped (the key is omitted from the result);
    * a dict is unwrapped to its ``"value"`` entry, or failing that its
      ``"text"`` entry, provided that entry is a non-empty primitive. Falsy
      primitives such as ``0`` and ``False`` are kept. Otherwise the whole
      dict is stringified with ``str()``;
    * a list/tuple becomes a list with ``None`` items dropped and every item
      coerced the same way (nested lists are stringified);
    * any other value is kept as-is.

    Args:
        attrs: Attribute mapping returned by Graphiti's attribute extraction.

    Returns:
        A new dict; *attrs* itself is not modified.
    """

    def _scalar(value: Any) -> Any:
        # Reduce one value to something Neo4j can store as a property/array item.
        if isinstance(value, dict):
            # Unwrap {"value": "..."} / {"text": "..."} patterns emitted by some LLMs.
            for key in ("value", "text"):
                inner = value.get(key)
                if inner is None or isinstance(inner, (dict, list, tuple)) or inner == "":
                    continue
                return inner
            return str(value)
        if isinstance(value, (list, tuple)):
            return str(value)
        return value

    result = {}
    for k, v in attrs.items():
        if v is None:
            continue
        if isinstance(v, (list, tuple)):
            result[k] = [_scalar(item) for item in v if item is not None]
        else:
            result[k] = _scalar(v)
    return result
|
|
|
|
|
|
def _neo4j_props(node_or_rel: Any) -> Dict[str, Any]:
    """Copy a Neo4j node's or relationship's properties into a JSON-safe dict.

    The entity is materialised with ``dict()`` and every value is passed
    through ``_neo4j_val``; key order is preserved.
    """
    raw = dict(node_or_rel)
    safe: Dict[str, Any] = {}
    for key, value in raw.items():
        safe[key] = _neo4j_val(value)
    return safe
|
|
|
|
# Module logger; GraphitiBackend and the graphiti monkey-patches log through it.
logger = get_logger('mirofish.graph.graphiti')
|
|
|
|
|
|
def _make_azure_generic_client(config, client):
    """Return a Graphiti LLM client that sends ``max_completion_tokens``.

    gpt-5 / o-series models on Azure reject ``max_tokens`` and require
    ``max_completion_tokens`` instead, so this builds a private subclass of
    graphiti's ``OpenAIGenericClient`` with an overridden ``_generate_response``.

    Args:
        config: Graphiti ``LLMConfig`` passed to the client constructor.
        client: Pre-built ``AsyncOpenAI`` client used for every request.

    Returns:
        An instance of the private ``OpenAIGenericClient`` subclass.
    """
    # Imported inside the function so they are only resolved when a client is built.
    from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
    from graphiti_core.llm_client.errors import RateLimitError as _RateLimitError
    import openai as _openai

    class _AzureGenericClient(OpenAIGenericClient):
        async def _generate_response(self, messages, response_model=None, max_tokens=None, model_size=None):
            """Send *messages* in one chat-completion call and return the parsed JSON.

            Only ``user`` and ``system`` messages are forwarded. With a
            *response_model*, its JSON schema is requested as a ``json_schema``
            response format; otherwise plain JSON-object mode is used.
            *max_tokens* defaults to ``self.max_tokens``; *model_size* is
            accepted for signature compatibility and ignored. An empty reply
            is parsed as ``{}``.

            Raises:
                graphiti RateLimitError: when OpenAI reports a rate limit.
            """
            if max_tokens is None:
                max_tokens = self.max_tokens

            openai_messages: list[dict[str, Any]] = [
                {'role': m.role, 'content': m.content}
                for m in messages
                if m.role in ('user', 'system')
            ]

            response_format: dict[str, Any] = {'type': 'json_object'}
            if response_model is not None:
                response_format = {
                    'type': 'json_schema',
                    'json_schema': {
                        'name': getattr(response_model, '__name__', 'structured_response'),
                        'schema': response_model.model_json_schema(),
                    },
                }

            try:
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    # The whole point of this subclass: not `max_tokens`.
                    max_completion_tokens=max_tokens,
                    response_format=response_format,
                )
            except _openai.RateLimitError as e:
                # Translate to graphiti's own error type for its retry logic.
                raise _RateLimitError from e
            return json.loads(response.choices[0].message.content or '{}')

    return _AzureGenericClient(config=config, client=client)
|
|
|
|
|
|
def _run_async(coro, timeout=300):
    """Run *coro* on the shared background event loop and block for its result.

    Must not be called from the background loop's own thread: the blocking
    wait would stop the loop from ever running the coroutine.

    Args:
        coro: Coroutine object to execute.
        timeout: Seconds to wait for the result (``None`` waits indefinitely).

    Returns:
        The coroutine's return value.

    Raises:
        concurrent.futures.TimeoutError: if the coroutine does not finish in
            time. The coroutine is cancelled before the error propagates.
        Any exception raised by the coroutine itself.
    """
    loop = _get_event_loop()
    future = asyncio.run_coroutine_threadsafe(coro, loop)
    try:
        return future.result(timeout=timeout)
    except BaseException:
        # On timeout or interrupt the coroutine would otherwise keep running
        # unattended on the background loop. cancel() is a no-op when the
        # future already finished (e.g. the coroutine itself raised).
        future.cancel()
        raise
|
|
|
|
|
|
# State of the single background event loop used by _run_async to drive the
# async Graphiti / Neo4j APIs from synchronous code. Guarded by _loop_lock.
_loop: Optional[asyncio.AbstractEventLoop] = None
_loop_thread: Optional[threading.Thread] = None
_loop_lock = threading.Lock()


def _get_event_loop() -> asyncio.AbstractEventLoop:
    """Return the shared background event loop, starting it if needed.

    The loop runs ``run_forever`` on a daemon thread. A new loop and thread
    are created on first use, or whenever the previous loop is no longer
    running. The function returns only once the loop is actually running.
    Without that wait, a caller arriving right after start-up would see
    ``is_running() == False`` and spin up a second loop, splitting
    loop-bound objects (e.g. driver connections) across loops.

    Raises:
        RuntimeError: if a freshly created loop does not start within 30s.
    """
    global _loop, _loop_thread
    with _loop_lock:
        if _loop is None or not _loop.is_running():
            new_loop = asyncio.new_event_loop()
            started = threading.Event()
            # call_soon callbacks only fire inside run_forever(), so `started`
            # is set exactly when the loop is running.
            new_loop.call_soon(started.set)
            thread = threading.Thread(
                target=new_loop.run_forever, name="graphiti-event-loop", daemon=True
            )
            thread.start()
            if not started.wait(timeout=30):
                new_loop.call_soon_threadsafe(new_loop.stop)
                raise RuntimeError("Background asyncio event loop failed to start")
            _loop, _loop_thread = new_loop, thread
        return _loop
|
|
|
|
|
|
class GraphitiBackend(GraphBackend):
    """GraphBackend implementation backed by Graphiti on top of Neo4j.

    A graph id is used as the Graphiti ``group_id`` for everything stored in
    that graph. Graphiti's API is async: the synchronous methods below block on
    it through ``_run_async``, whereas ``clone_graph`` is a coroutine that the
    caller must await.
    """

    def __init__(
        self,
        uri: Optional[str] = None,
        user: Optional[str] = None,
        password: Optional[str] = None,
    ):
        """Resolve Neo4j credentials and build the Graphiti client.

        Args:
            uri: Neo4j URI; ``Config.NEO4J_URI`` is used when falsy.
            user: Neo4j user; ``Config.NEO4J_USER`` is used when falsy.
            password: Neo4j password; ``Config.NEO4J_PASSWORD`` is used when falsy.

        Raises:
            ValueError: if no password is available from either source.
        """
        self._uri = uri or Config.NEO4J_URI
        self._user = user or Config.NEO4J_USER
        self._password = password or Config.NEO4J_PASSWORD
        if not self._password:
            raise ValueError("NEO4J_PASSWORD is not configured")
        # Pydantic models passed to Graphiti as entity/edge types (see set_ontology).
        self._entity_types: Dict[str, Any] = {}
        self._edge_types: Dict[str, Any] = {}
        # Raw ontology definitions, used to build the extraction instructions.
        self._entity_defs: Dict[str, Any] = {}
        self._edge_defs: Dict[str, Any] = {}
        self._client = self._build_client()

    def _build_client(self):
        """Create a Graphiti client wired to the configured OpenAI-compatible endpoints.

        Three ``AsyncOpenAI`` clients are built: the main LLM, the small LLM
        (used by the reranker), and embeddings. Each receives the query string
        that ``parse_azure_url`` split off its base URL. This also installs the
        graphiti monkey-patches, which is a no-op after the first call.

        Returns:
            A new ``graphiti_core.Graphiti`` instance.
        """
        from graphiti_core import Graphiti
        from graphiti_core.llm_client.config import LLMConfig
        from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
        from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
        from openai import AsyncOpenAI

        llm_base_url, llm_query = parse_azure_url(Config.LLM_BASE_URL)
        small_base_url, small_query = parse_azure_url(Config.LLM_SMALL_BASE_URL)
        embed_base_url, embed_query = parse_azure_url(Config.LLM_EMBED_BASE_URL)

        # Pre-built async clients so api-version is passed as default_query (Azure requirement)
        async_llm_client = AsyncOpenAI(
            api_key=Config.LLM_API_KEY,
            base_url=llm_base_url,
            default_query=llm_query or None,
        )
        async_small_client = AsyncOpenAI(
            api_key=Config.LLM_SMALL_API_KEY,
            base_url=small_base_url,
            default_query=small_query or None,
        )
        async_embed_client = AsyncOpenAI(
            api_key=Config.LLM_EMBED_API_KEY,
            base_url=embed_base_url,
            default_query=embed_query or None,
        )

        llm_config = LLMConfig(
            api_key=Config.LLM_API_KEY,
            model=Config.LLM_MODEL_NAME,
            small_model=Config.LLM_SMALL_MODEL_NAME,
            base_url=llm_base_url,
        )
        llm_client = _make_azure_generic_client(config=llm_config, client=async_llm_client)
        embedder = OpenAIEmbedder(
            config=OpenAIEmbedderConfig(
                api_key=Config.LLM_EMBED_API_KEY,
                base_url=embed_base_url,
                embedding_model=Config.LLM_EMBED_MODEL_NAME,
            ),
            client=async_embed_client,
        )
        # NOTE(review): the reranker sends requests through the *small* client
        # but is given llm_config (main model name / main key). Confirm that the
        # model it requests is deployed on the small endpoint.
        cross_encoder = OpenAIRerankerClient(config=llm_config, client=async_small_client)
        client = Graphiti(
            uri=self._uri,
            user=self._user,
            password=self._password,
            llm_client=llm_client,
            embedder=embedder,
            cross_encoder=cross_encoder,
        )
        self._patch_extract_entity_attributes()
        return client

    def _rebuild_client(self) -> None:
        """Swap in a freshly built Graphiti client and close the stale one.

        If building fails, the current client is kept and the error is logged.
        Closing the stale client is best-effort; without it, every reconnect
        would leak the old client's Neo4j driver connections.
        """
        stale_client = self._client
        try:
            self._client = self._build_client()
        except Exception as rebuild_exc:
            logger.warning(f"Graphiti client rebuild failed: {rebuild_exc}")
            return
        try:
            _run_async(stale_client.close(), timeout=30)
        except Exception as close_exc:
            logger.debug(f"Closing stale Graphiti client failed: {close_exc}")

    @staticmethod
    def _patch_extract_entity_attributes() -> None:
        """Monkey-patch graphiti internals to work around two LLM quirks.

        1. ``_extract_entity_attributes``: some LLMs wrap attribute values in
           nested dicts ({"value": "CTTI"}), which Neo4j rejects. Results are
           flattened with ``_flatten_attributes``.
        2. ``_extract_nodes_single``: some LLMs omit ``entity_type_id`` from
           extracted entities, causing a Pydantic ValidationError. On such an
           error the extraction LLM call is made again and missing ids default
           to 0 (the generic "Entity" type) before validation re-runs.

        The patches replace attributes of the
        ``graphiti_core.utils.maintenance.node_operations`` module, so they
        affect the whole process. Each replacement is tagged with a marker
        attribute, so repeated calls (one per client build, including
        reconnects) install each wrapper once instead of stacking wrappers.
        """
        import graphiti_core.utils.maintenance.node_operations as _node_ops

        marker = "_mirofish_patched"

        # --- patch 1: attribute flattening ---
        if not getattr(_node_ops._extract_entity_attributes, marker, False):
            original_attrs = _node_ops._extract_entity_attributes

            async def _patched_attrs(llm_client, node, episode, previous_episodes, entity_type):
                result = await original_attrs(llm_client, node, episode, previous_episodes, entity_type)
                return _flatten_attributes(result) if result else result

            setattr(_patched_attrs, marker, True)
            _node_ops._extract_entity_attributes = _patched_attrs

        # --- patch 2: entity_type_id defaulting ---
        if not getattr(_node_ops._extract_nodes_single, marker, False):
            original_nodes = _node_ops._extract_nodes_single

            async def _patched_nodes(llm_client, episode, context):
                from graphiti_core.utils.maintenance.node_operations import ExtractedEntities
                # Call the LLM the normal way but catch the Pydantic validation error
                # that arises when the LLM forgets entity_type_id.
                try:
                    return await original_nodes(llm_client, episode, context)
                except Exception as exc:
                    # Only intercept errors whose message mentions entity_type_id
                    if "entity_type_id" not in str(exc):
                        raise
                    logger.warning(f"LLM omitted entity_type_id — defaulting to 0 and retrying validation: {exc}")
                    # Makes a second LLM call via the internal helper to get the raw dict
                    from graphiti_core.utils.maintenance.node_operations import _call_extraction_llm
                    llm_response = await _call_extraction_llm(llm_client, episode, context)
                    # Inject entity_type_id=0 for any entity that is missing it
                    entities = llm_response.get("extracted_entities", [])
                    for ent in entities:
                        if isinstance(ent, dict) and "entity_type_id" not in ent:
                            ent["entity_type_id"] = 0
                    response_object = ExtractedEntities(**llm_response)
                    return response_object.extracted_entities

            setattr(_patched_nodes, marker, True)
            _node_ops._extract_nodes_single = _patched_nodes

    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
        """Only logs: nothing is written to Neo4j here.

        The other methods use *graph_id* as the Graphiti ``group_id``.
        *name* and *description* are accepted for interface compatibility.
        """
        logger.info(f"Graphiti graph namespace ready: {graph_id}")

    def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
        """Register entity and edge types used for all later extractions.

        Each definition is either a dict with ``"description"`` and
        ``"attributes"`` (a list of ``{"name", "description"}`` dicts) or any
        object whose ``__doc__`` serves as the description. Every definition
        becomes a dynamically created Pydantic model with one optional ``str``
        field per named attribute. The raw definitions are also kept for
        ``_build_extraction_instructions``. *graph_ids* is not used: the
        ontology applies to every graph handled by this backend.
        """
        from pydantic import BaseModel as _BaseModel, Field as _Field

        def _make_model(name: str, type_def: Any) -> Any:
            # Build a Pydantic model class named *name* from one ontology entry.
            if isinstance(type_def, dict):
                doc = type_def.get("description", "")
                attrs_defs = type_def.get("attributes", [])
            else:
                doc = getattr(type_def, "__doc__", "") or ""
                attrs_defs = []

            annotations: Dict[str, Any] = {}
            fields: Dict[str, Any] = {"__doc__": doc, "__annotations__": annotations}
            for attr in attrs_defs:
                attr_name = attr.get("name", "")
                attr_desc = attr.get("description", attr_name)
                if not attr_name:
                    continue
                annotations[attr_name] = Optional[str]
                fields[attr_name] = _Field(default=None, description=attr_desc)

            return type(name, (_BaseModel,), fields)

        self._entity_types = {
            name: _make_model(name, td) for name, td in (entities or {}).items()
        }
        self._edge_types = {
            name: _make_model(name, td) for name, td in (edges or {}).items()
        }
        # Keep a separate plain dict for use in extraction instructions
        self._entity_defs = dict(entities or {})
        self._edge_defs = dict(edges or {})
        if self._entity_types:
            logger.info(f"Graphiti entity types: {list(self._entity_types.keys())}")
        if self._edge_types:
            logger.info(f"Graphiti edge types: {list(self._edge_types.keys())}")

    def _build_extraction_instructions(self) -> Optional[str]:
        """Return custom instructions that constrain extraction to ontology types and attributes.

        Returns ``None`` when no ontology has been registered.
        """
        entity_defs = self._entity_defs or {}
        edge_defs = self._edge_defs or {}
        if not entity_defs and not edge_defs:
            return None

        parts = []

        if entity_defs:
            entity_lines = []
            for name, td in entity_defs.items():
                desc = td.get("description", "") if isinstance(td, dict) else ""
                attrs = td.get("attributes", []) if isinstance(td, dict) else []
                attr_str = ", ".join(
                    f"{a['name']} ({a.get('description', a['name'])})"
                    for a in attrs if a.get("name")
                ) if attrs else ""
                # Only mention attributes when at least one of them has a name.
                if attr_str:
                    entity_lines.append(f" - {name}: {desc} [attributes: {attr_str}]")
                else:
                    entity_lines.append(f" - {name}: {desc}")
            parts.append(
                "Only classify entities using these types (use 'Entity' only if none fits):\n"
                + "\n".join(entity_lines)
                + "\nFor each entity, extract values for the listed attributes when present in the text."
            )

        if edge_defs:
            edge_names = list(edge_defs.keys())
            parts.append(
                f"Only use these relationship types: {', '.join(edge_names)}. "
                "Do not invent new relationship type names."
            )

        return "\n\n".join(parts)

    def _add_episode(
        self,
        graph_id: str,
        data: str,
        ep_id: str,
        *,
        entity_types: Optional[Dict[str, Any]],
        edge_types: Optional[Dict[str, Any]],
        instructions: Optional[str],
    ) -> None:
        """Ingest one text episode named *ep_id* into *graph_id*, blocking until done (300s timeout)."""
        from graphiti_core.nodes import EpisodeType
        from datetime import datetime, timezone

        _run_async(
            self._client.add_episode(
                name=ep_id,
                episode_body=data,
                source_description="MiroFish document chunk",
                reference_time=datetime.now(timezone.utc),
                source=EpisodeType.text,
                group_id=graph_id,
                entity_types=entity_types,
                edge_types=edge_types,
                custom_extraction_instructions=instructions,
            ),
            timeout=300,
        )

    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
        """Ingest each item as its own Graphiti episode, sequentially.

        Each item is either a dict with a ``"data"`` key or an object with a
        ``data`` attribute. An error whose message contains "not found" is
        retried up to twice, after sleeping 2s and then 4s. Any other error,
        or a third failure, propagates immediately.

        Returns:
            The generated episode names (UUID strings), in input order.
        """
        import time as _time

        entity_types = self._entity_types or None
        edge_types = self._edge_types or None
        instructions = self._build_extraction_instructions()
        max_attempts = 3
        ids = []

        for ep in episodes:
            data = ep["data"] if isinstance(ep, dict) else ep.data
            ep_id = str(uuid_mod.uuid4())
            ids.append(ep_id)

            for attempt in range(max_attempts):
                try:
                    self._add_episode(
                        graph_id,
                        data,
                        ep_id,
                        entity_types=entity_types,
                        edge_types=edge_types,
                        instructions=instructions,
                    )
                    break
                except Exception as exc:
                    # "node not found" race condition — wait and retry
                    retryable = "not found" in str(exc).lower()
                    if not retryable or attempt == max_attempts - 1:
                        raise
                    logger.warning(f"Episode {ep_id} attempt {attempt + 1} failed ({exc}), retrying...")
                    _time.sleep(2 * (attempt + 1))

        return ids

    def get_episode(self, uuid_: str) -> Any:
        """Return a stand-in episode that always reports ``processed = True``.

        ``add_batch`` / ``add_text`` block until Graphiti has ingested each
        episode, so there is nothing left to poll.
        """
        class _FakeEpisode:
            processed = True
        return _FakeEpisode()

    def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return every ``Entity`` node whose ``group_id`` is *graph_id*.

        Each dict holds uuid (the Neo4j element id when the node has no uuid),
        name, labels, summary, created_at as a string, and all node properties
        under "attributes".
        """
        results = _run_async(self._execute_neo4j_query(
            "MATCH (n:Entity {group_id: $gid}) RETURN n",
            {"gid": graph_id},
        ))
        nodes = []
        for record in results.records:
            n = record["n"]
            nodes.append({
                "uuid": n.get("uuid", n.element_id),
                "name": n.get("name", ""),
                "labels": list(n.labels),
                "summary": n.get("summary", ""),
                "attributes": _neo4j_props(n),
                "created_at": str(n.get("created_at", "")),
            })
        return nodes

    def get_all_edges(self, graph_id: str, max_items: int = 5000) -> List[Dict[str, Any]]:
        """Return up to *max_items* relationships whose ``group_id`` is *graph_id*.

        A warning is logged when more edges exist than were returned.
        """
        results = _run_async(self._execute_neo4j_query(
            "MATCH (s)-[r]->(t) WHERE r.group_id = $gid RETURN s, r, t LIMIT $limit",
            # Fetch one extra row so truncation is detected exactly, rather than
            # warning whenever the graph holds exactly max_items edges.
            {"gid": graph_id, "limit": max_items + 1},
        ))
        records = results.records
        if len(records) > max_items:
            logger.warning(
                f"get_all_edges: result truncated at {max_items} edges for graph {graph_id}"
            )
            records = records[:max_items]
        edges = []
        for record in records:
            r = record["r"]
            edges.append({
                "uuid": r.get("uuid", r.element_id),
                "name": r.get("name", type(r).__name__),
                "fact": r.get("fact", ""),
                "source_node_uuid": record["s"].get("uuid", ""),
                "target_node_uuid": record["t"].get("uuid", ""),
                "fact_type": r.get("fact_type", ""),
                "attributes": _neo4j_props(r),
                "created_at": str(r.get("created_at", "")),
                "valid_at": str(r.get("valid_at", "")),
                "invalid_at": str(r.get("invalid_at", "")),
                "expired_at": str(r.get("expired_at", "")),
                "episodes": [],
            })
        return edges

    def get_node(self, uuid_: str) -> Dict[str, Any]:
        """Return the first node whose ``uuid`` property matches, or ``{}`` if none does."""
        results = _run_async(self._execute_neo4j_query(
            "MATCH (n {uuid: $uuid}) RETURN n LIMIT 1",
            {"uuid": uuid_},
        ))
        if not results.records:
            return {}
        n = results.records[0]["n"]
        return {
            "uuid": n.get("uuid", ""),
            "name": n.get("name", ""),
            "labels": list(n.labels),
            "summary": n.get("summary", ""),
            "attributes": _neo4j_props(n),
        }

    def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
        """Return the outgoing and incoming relationships of the node with uuid *node_uuid*.

        Source and target uuids are read from the relationship's endpoint
        nodes, as in ``get_all_edges``. The relationship properties do not
        carry them, so reading them from there produced empty targets and
        wrong sources for incoming edges.
        """
        results = _run_async(self._execute_neo4j_query(
            "MATCH (n {uuid: $uuid})-[r]->(t) "
            "RETURN r, n.uuid AS source_uuid, t.uuid AS target_uuid "
            "UNION "
            "MATCH (s)-[r]->(n {uuid: $uuid}) "
            "RETURN r, s.uuid AS source_uuid, n.uuid AS target_uuid",
            {"uuid": node_uuid},
        ))
        edges = []
        for record in results.records:
            r = record["r"]
            edges.append({
                "uuid": r.get("uuid", r.element_id),
                "name": r.get("name", ""),
                "fact": r.get("fact", ""),
                "source_node_uuid": record["source_uuid"] or "",
                "target_node_uuid": record["target_uuid"] or "",
            })
        return edges

    def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
        """Run a Graphiti search restricted to *graph_id* and return matching edges.

        Up to three attempts are made. Between attempts the method sleeps 2s,
        then 4s, and rebuilds the Graphiti client in case the Neo4j connection
        dropped. *scope* is accepted for interface compatibility: results are
        always returned under "edges", with "nodes" empty.

        Raises:
            The last exception raised by Graphiti once every attempt has failed.
        """
        import time as _time
        import traceback as _tb

        max_retries = 3
        delay = 2.0
        last_exc: Optional[BaseException] = None
        for attempt in range(max_retries):
            try:
                results = _run_async(
                    self._client.search(query=query, group_ids=[graph_id], num_results=limit)
                )
            except Exception as e:
                last_exc = e
                logger.debug(
                    f"Graphiti search attempt {attempt + 1}/{max_retries} failed: "
                    f"{type(e).__name__}: {e}"
                )
                if attempt < max_retries - 1:
                    _time.sleep(delay)
                    delay *= 2
                    # Reconnect in case the Neo4j TCP connection dropped
                    self._rebuild_client()
                continue

            edges = [
                {
                    "uuid": getattr(r, "uuid", ""),
                    "name": getattr(r, "name", ""),
                    "fact": getattr(r, "fact", ""),
                    "source_node_uuid": getattr(r, "source_node_uuid", ""),
                    "target_node_uuid": getattr(r, "target_node_uuid", ""),
                }
                for r in (results or [])
            ]
            return {"edges": edges, "nodes": []}

        # traceback.format_exc() only works inside an except block (outside it
        # yields "NoneType: None"), so format the stored exception explicitly.
        tb_text = "".join(_tb.format_exception(type(last_exc), last_exc, last_exc.__traceback__))
        logger.error(
            f"Graphiti search failed after {max_retries} attempts: "
            f"{type(last_exc).__name__}: {last_exc}\n{tb_text}"
        )
        raise last_exc

    def add_text(self, graph_id: str, data: str) -> None:
        """Ingest *data* as a single Graphiti episode under a fresh UUID name (no retries)."""
        self._add_episode(
            graph_id,
            data,
            str(uuid_mod.uuid4()),
            entity_types=self._entity_types or None,
            edge_types=self._edge_types or None,
            instructions=self._build_extraction_instructions(),
        )

    async def _execute_neo4j_query(self, query: str, parameters: Optional[dict] = None):
        """Execute a raw Cypher query through the Graphiti client's driver.

        ``parameters`` is forwarded as ``params=`` only when non-empty.
        """
        kwargs = {"params": parameters} if parameters else {}
        return await self._client.driver.execute_query(query, **kwargs)

    async def clone_graph(self, src_group_id: str, dst_group_id: str) -> None:
        """Clone all nodes and relationships from src_group_id to dst_group_id.

        Requires the APOC plugin (``apoc.create.node`` /
        ``apoc.create.relationship``). Labels and all properties are copied and
        only ``group_id`` is rewritten. Copies therefore keep the source
        ``uuid``, so each uuid afterwards exists in both groups, and lookups by
        uuid alone (e.g. ``get_node``) may return either copy.
        """
        params = {"src": src_group_id, "dst": dst_group_id}

        # Labels must be copied too: a bare CREATE (m) yields unlabelled nodes
        # that queries such as MATCH (n:Entity ...) never see. collect/UNWIND
        # materialises the source set before any node is created.
        clone_nodes_query = """
        MATCH (n) WHERE n.group_id = $src
        WITH collect(n) AS src_nodes
        UNWIND src_nodes AS src_node
        CALL apoc.create.node(labels(src_node), properties(src_node)) YIELD node
        SET node.group_id = $dst
        """
        await self._execute_neo4j_query(clone_nodes_query, params)

        clone_rels_query = """
        MATCH (n)-[r]->(m)
        WHERE n.group_id = $src AND m.group_id = $src
        MATCH (n2 {uuid: n.uuid, group_id: $dst})
        MATCH (m2 {uuid: m.uuid, group_id: $dst})
        CALL apoc.create.relationship(n2, type(r), properties(r), m2) YIELD rel
        SET rel.group_id = $dst
        RETURN rel
        """
        await self._execute_neo4j_query(clone_rels_query, params)

    def delete_graph(self, graph_id: str) -> None:
        """Delete all nodes and relationships for a given group_id."""
        _run_async(self._execute_neo4j_query(
            "MATCH (n) WHERE n.group_id = $gid DETACH DELETE n",
            {"gid": graph_id},
        ))
|