MicroFish/backend/app/graph/graphiti_backend.py

554 lines
23 KiB
Python

"""Graphiti + Neo4j implementation of GraphBackend."""
import asyncio
import json
import threading
import typing
import uuid as uuid_mod
from typing import Any, Dict, List, Optional
from .base import GraphBackend
from ..config import Config
from ..utils.logger import get_logger
from ..utils.llm_client import parse_azure_url
def _neo4j_val(v: Any) -> Any:
"""Convert Neo4j native types to JSON-serializable Python types."""
if v is None:
return None
t = type(v).__name__
if t in ('DateTime', 'Date', 'Time', 'LocalDateTime', 'LocalTime', 'Duration'):
return str(v)
if isinstance(v, (list, tuple)):
return [_neo4j_val(i) for i in v]
if isinstance(v, dict):
return {k: _neo4j_val(vv) for k, vv in v.items()}
return v
def _flatten_attributes(attrs: dict) -> dict:
"""Flatten entity attribute dicts so every value is a Neo4j-safe primitive.
Graphiti extracts entity attributes via a Pydantic model, but the raw LLM
response sometimes wraps each value in a nested dict (e.g. {"value": "CTTI"}).
Neo4j only accepts primitive types or arrays thereof, so we coerce any
dict value to its string representation. Lists of primitives are kept as-is
because Neo4j supports array properties.
"""
result = {}
for k, v in attrs.items():
if v is None:
continue
if isinstance(v, dict):
# Unwrap {"value": "..."} pattern emitted by some LLMs; fall back to str()
result[k] = v.get("value") or v.get("text") or str(v)
else:
result[k] = v
return result
def _neo4j_props(node_or_rel: Any) -> Dict[str, Any]:
    """Copy a node's or relationship's properties into a JSON-safe dict.

    The argument is converted with ``dict(...)`` and every property value is
    passed through ``_neo4j_val``.
    """
    props: Dict[str, Any] = {}
    for key, raw in dict(node_or_rel).items():
        props[key] = _neo4j_val(raw)
    return props
# Module-level logger used by the helpers and backend methods in this file.
logger = get_logger('mirofish.graph.graphiti')
def _make_azure_generic_client(config, client):
    """Return an OpenAIGenericClient subclass instance that uses
    max_completion_tokens instead of max_tokens — required by gpt-5 /
    o-series models on Azure.

    Args:
        config: passed through to the ``OpenAIGenericClient`` constructor.
        client: pre-built async OpenAI client, passed through likewise.

    Returns:
        An instance of the locally defined ``_AzureGenericClient``.
    """
    from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
    import openai as _openai
    from graphiti_core.llm_client.errors import RateLimitError as _RateLimitError

    class _AzureGenericClient(OpenAIGenericClient):
        async def _generate_response(self, messages, response_model=None, max_tokens=None, model_size=None):
            """Send ``messages`` to the chat completions API and return the parsed JSON reply.

            ``model_size`` is accepted for signature compatibility but is not
            used here. Raises Graphiti's ``RateLimitError`` when the OpenAI
            SDK reports a rate limit.
            """
            from openai.types.chat import ChatCompletionMessageParam

            if max_tokens is None:
                max_tokens = self.max_tokens

            # Only 'user' and 'system' messages are forwarded; any other role
            # is skipped.
            openai_messages: list[ChatCompletionMessageParam] = []
            for m in messages:
                if m.role == 'user':
                    openai_messages.append({'role': 'user', 'content': m.content})
                elif m.role == 'system':
                    openai_messages.append({'role': 'system', 'content': m.content})

            # Plain JSON mode by default; a JSON schema derived from the
            # response model when one is supplied.
            response_format: dict[str, Any] = {'type': 'json_object'}
            if response_model is not None:
                schema_name = getattr(response_model, '__name__', 'structured_response')
                response_format = {
                    'type': 'json_schema',
                    'json_schema': {
                        'name': schema_name,
                        'schema': response_model.model_json_schema(),
                    },
                }
            try:
                response = await self.client.chat.completions.create(
                    model=self.model,
                    messages=openai_messages,
                    temperature=self.temperature,
                    max_completion_tokens=max_tokens,
                    response_format=response_format,
                )
                # An empty/None message body is treated as an empty JSON object.
                return json.loads(response.choices[0].message.content or '{}')
            except _openai.RateLimitError as e:
                raise _RateLimitError from e

    return _AzureGenericClient(config=config, client=client)
def _run_async(coro, timeout=300):
"""Run an async coroutine from a sync context using a dedicated thread loop."""
loop = _get_event_loop()
future = asyncio.run_coroutine_threadsafe(coro, loop)
return future.result(timeout=timeout)
_loop: Optional[asyncio.AbstractEventLoop] = None
_loop_thread: Optional[threading.Thread] = None
_loop_lock = threading.Lock()
def _get_event_loop() -> asyncio.AbstractEventLoop:
global _loop, _loop_thread
with _loop_lock:
if _loop is None or not _loop.is_running():
_loop = asyncio.new_event_loop()
_loop_thread = threading.Thread(target=_loop.run_forever, daemon=True)
_loop_thread.start()
return _loop
class GraphitiBackend(GraphBackend):
def __init__(
self,
uri: Optional[str] = None,
user: Optional[str] = None,
password: Optional[str] = None,
):
self._uri = uri or Config.NEO4J_URI
self._user = user or Config.NEO4J_USER
self._password = password or Config.NEO4J_PASSWORD
if not self._password:
raise ValueError("NEO4J_PASSWORD is not configured")
self._entity_types: Dict[str, Any] = {}
self._edge_types: Dict[str, Any] = {}
self._entity_defs: Dict[str, Any] = {}
self._edge_defs: Dict[str, Any] = {}
self._client = self._build_client()
def _build_client(self):
    """Construct a Graphiti client wired to the configured LLM endpoints.

    Three ``AsyncOpenAI`` clients are created (main LLM, small LLM, embedder)
    from the ``Config.LLM_*`` settings, each URL being split by
    ``parse_azure_url`` into a base URL and query parameters. The Graphiti
    node-operation patches are installed before the client is returned.

    Returns:
        The new ``Graphiti`` instance.
    """
    from graphiti_core import Graphiti
    from graphiti_core.llm_client.config import LLMConfig
    from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
    from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
    from openai import AsyncOpenAI

    llm_base_url, llm_query = parse_azure_url(Config.LLM_BASE_URL)
    small_base_url, small_query = parse_azure_url(Config.LLM_SMALL_BASE_URL)
    embed_base_url, embed_query = parse_azure_url(Config.LLM_EMBED_BASE_URL)

    # Pre-built async clients so api-version is passed as default_query (Azure requirement)
    async_llm_client = AsyncOpenAI(
        api_key=Config.LLM_API_KEY,
        base_url=llm_base_url,
        default_query=llm_query or None,
    )
    async_small_client = AsyncOpenAI(
        api_key=Config.LLM_SMALL_API_KEY,
        base_url=small_base_url,
        default_query=small_query or None,
    )
    async_embed_client = AsyncOpenAI(
        api_key=Config.LLM_EMBED_API_KEY,
        base_url=embed_base_url,
        default_query=embed_query or None,
    )

    llm_config = LLMConfig(
        api_key=Config.LLM_API_KEY,
        model=Config.LLM_MODEL_NAME,
        small_model=Config.LLM_SMALL_MODEL_NAME,
        base_url=llm_base_url,
    )
    llm_client = _make_azure_generic_client(config=llm_config, client=async_llm_client)
    embedder = OpenAIEmbedder(
        config=OpenAIEmbedderConfig(
            api_key=Config.LLM_EMBED_API_KEY,
            base_url=embed_base_url,
            embedding_model=Config.LLM_EMBED_MODEL_NAME,
        ),
        client=async_embed_client,
    )
    # The reranker shares the main LLM config but talks through the small client.
    cross_encoder = OpenAIRerankerClient(config=llm_config, client=async_small_client)

    client = Graphiti(
        uri=self._uri,
        user=self._user,
        password=self._password,
        llm_client=llm_client,
        embedder=embedder,
        cross_encoder=cross_encoder,
    )
    self._patch_extract_entity_attributes()
    return client
@staticmethod
def _patch_extract_entity_attributes() -> None:
    """Monkey-patch graphiti internals to fix two LLM quirks:

    1. _extract_entity_attributes: some LLMs wrap attribute values in nested
       dicts ({"value": "CTTI"}). Neo4j rejects these — flatten them.
    2. _extract_nodes_single: some LLMs omit entity_type_id from extracted
       entities, causing a Pydantic ValidationError. Default missing IDs to 0
       (the generic "Entity" type) before validation runs.

    The patch is idempotent: wrappers installed here carry a marker
    attribute, and a function that already carries it is not wrapped again.
    This matters because the method runs on every client (re)build.
    """
    import graphiti_core.utils.maintenance.node_operations as _node_ops

    # Attribute set on our wrappers so a repeated call can recognise them.
    patched_flag = "_mirofish_patched"

    # --- patch 1: attribute flattening ---
    original_attrs = _node_ops._extract_entity_attributes
    if not getattr(original_attrs, patched_flag, False):

        async def _patched_attrs(llm_client, node, episode, previous_episodes, entity_type):
            result = await original_attrs(llm_client, node, episode, previous_episodes, entity_type)
            return _flatten_attributes(result) if result else result

        setattr(_patched_attrs, patched_flag, True)
        _node_ops._extract_entity_attributes = _patched_attrs

    # --- patch 2: entity_type_id defaulting ---
    original_nodes = _node_ops._extract_nodes_single
    if not getattr(original_nodes, patched_flag, False):

        async def _patched_nodes(llm_client, episode, context):
            from graphiti_core.utils.maintenance.node_operations import ExtractedEntities

            # Call the LLM the normal way but catch the Pydantic validation error
            # that arises when the LLM forgets entity_type_id.
            try:
                return await original_nodes(llm_client, episode, context)
            except Exception as exc:
                # Only intercept Pydantic validation errors about entity_type_id
                if "entity_type_id" not in str(exc):
                    raise
                logger.warning(f"LLM omitted entity_type_id — defaulting to 0 and retrying validation: {exc}")
                # Re-run the LLM call via the internal helper to get the raw dict
                from graphiti_core.utils.maintenance.node_operations import _call_extraction_llm
                llm_response = await _call_extraction_llm(llm_client, episode, context)
                # Inject entity_type_id=0 for any entity that is missing it
                entities = llm_response.get("extracted_entities", [])
                for ent in entities:
                    if isinstance(ent, dict) and "entity_type_id" not in ent:
                        ent["entity_type_id"] = 0
                response_object = ExtractedEntities(**llm_response)
                return response_object.extracted_entities

        setattr(_patched_nodes, patched_flag, True)
        _node_ops._extract_nodes_single = _patched_nodes
def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
    """Log that the namespace for ``graph_id`` is ready.

    Nothing is created here; ``name`` and ``description`` are not used.
    """
    message = f"Graphiti graph namespace ready: {graph_id}"
    logger.info(message)
def set_ontology(self, graph_ids: List[str], entities: Dict[str, Any], edges: Dict[str, Any]) -> None:
    """Turn ontology definitions into Pydantic models and remember them.

    For every entry in ``entities`` and ``edges`` a ``BaseModel`` subclass is
    generated whose optional string fields mirror the definition's named
    attributes. The raw definitions are also kept as plain dict copies.
    ``graph_ids`` is not used.
    """
    from pydantic import BaseModel as _BaseModel, Field as _Field

    def _make_model(name: str, type_def: Any) -> Any:
        # Dict definitions carry a description and attribute list; anything
        # else contributes only its docstring.
        is_mapping = isinstance(type_def, dict)
        doc = type_def.get("description", "") if is_mapping else (getattr(type_def, "__doc__", "") or "")
        attrs_defs = type_def.get("attributes", []) if is_mapping else []

        annotations: Dict[str, Any] = {}
        namespace: Dict[str, Any] = {"__doc__": doc, "__annotations__": annotations}
        for attr in attrs_defs:
            attr_name = attr.get("name", "")
            if not attr_name:
                # Attributes without a name cannot become fields.
                continue
            annotations[attr_name] = Optional[str]
            namespace[attr_name] = _Field(
                default=None,
                description=attr.get("description", attr_name),
            )
        return type(name, (_BaseModel,), namespace)

    entity_models: Dict[str, Any] = {}
    for type_name, definition in (entities or {}).items():
        entity_models[type_name] = _make_model(type_name, definition)
    self._entity_types = entity_models

    edge_models: Dict[str, Any] = {}
    for type_name, definition in (edges or {}).items():
        edge_models[type_name] = _make_model(type_name, definition)
    self._edge_types = edge_models

    # Keep a separate plain dict for use in extraction instructions
    self._entity_defs = dict(entities or {})
    self._edge_defs = dict(edges or {})

    if self._entity_types:
        logger.info(f"Graphiti entity types: {list(self._entity_types.keys())}")
    if self._edge_types:
        logger.info(f"Graphiti edge types: {list(self._edge_types.keys())}")
def _build_extraction_instructions(self) -> Optional[str]:
"""Return custom instructions that constrain extraction to ontology types and attributes."""
entity_defs = self._entity_defs or {}
edge_defs = self._edge_defs or {}
if not entity_defs and not edge_defs:
return None
parts = []
if entity_defs:
entity_lines = []
for name, td in entity_defs.items():
desc = td.get("description", "") if isinstance(td, dict) else ""
attrs = td.get("attributes", []) if isinstance(td, dict) else []
if attrs:
attr_str = ", ".join(
f"{a['name']} ({a.get('description', a['name'])})"
for a in attrs if a.get("name")
)
entity_lines.append(f" - {name}: {desc} [attributes: {attr_str}]")
else:
entity_lines.append(f" - {name}: {desc}")
parts.append(
"Only classify entities using these types (use 'Entity' only if none fits):\n"
+ "\n".join(entity_lines)
+ "\nFor each entity, extract values for the listed attributes when present in the text."
)
if edge_defs:
edge_names = list(edge_defs.keys())
parts.append(
f"Only use these relationship types: {', '.join(edge_names)}. "
"Do not invent new relationship type names."
)
return "\n\n".join(parts)
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[str]:
    """Ingest each episode into Graphiti under ``graph_id``, one at a time.

    Args:
        graph_id: used as the Graphiti ``group_id`` of every episode.
        episodes: items exposing their text either as ``ep["data"]`` (dicts)
            or as ``ep.data`` (objects).

    Returns:
        The generated episode ids (fresh UUID4 strings), in input order.

    Raises:
        Exception: whatever the ingestion raised. Only errors whose message
            contains "not found" are retried, up to three attempts in total.
    """
    from graphiti_core.nodes import EpisodeType
    from datetime import datetime, timezone
    import time as _time

    max_attempts = 3
    entity_types = self._entity_types or None
    edge_types = self._edge_types or None
    instructions = self._build_extraction_instructions()

    ids = []
    for ep in episodes:
        data = ep["data"] if isinstance(ep, dict) else ep.data
        ep_id = str(uuid_mod.uuid4())
        ids.append(ep_id)
        for attempt in range(1, max_attempts + 1):
            try:
                # A new coroutine is created for every attempt.
                _run_async(
                    self._client.add_episode(
                        name=ep_id,
                        episode_body=data,
                        source_description="MiroFish document chunk",
                        reference_time=datetime.now(timezone.utc),
                        source=EpisodeType.text,
                        group_id=graph_id,
                        entity_types=entity_types,
                        edge_types=edge_types,
                        custom_extraction_instructions=instructions,
                    ),
                    timeout=300,
                )
                break
            except Exception as exc:
                # "node not found" race condition — wait and retry
                retryable = "not found" in str(exc).lower()
                if not retryable or attempt == max_attempts:
                    raise
                logger.warning(f"Episode {ep_id} attempt {attempt} failed ({exc}), retrying...")
                # Linear back-off: 2s after the first failure, 4s after the second.
                _time.sleep(2 * attempt)
    return ids
def get_episode(self, uuid_: str) -> Any:
    """Return a placeholder episode object whose ``processed`` attribute is True.

    ``uuid_`` is not used; no lookup is performed.
    """
    fake_episode_cls = type("_FakeEpisode", (), {"processed": True})
    return fake_episode_cls()
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
    """Fetch every ``Entity`` node of ``graph_id`` as a plain dict.

    Each dict carries uuid (falling back to the node's element id), name,
    labels, summary, all properties under "attributes", and created_at
    rendered as a string.
    """
    cypher = "MATCH (n:Entity {group_id: $gid}) RETURN n"
    results = _run_async(
        self._client.driver.execute_query(cypher, params={"gid": graph_id})
    )

    def _describe(n: Any) -> Dict[str, Any]:
        return {
            "uuid": n.get("uuid", n.element_id),
            "name": n.get("name", ""),
            "labels": list(n.labels),
            "summary": n.get("summary", ""),
            "attributes": _neo4j_props(n),
            "created_at": str(n.get("created_at", "")),
        }

    return [_describe(record["n"]) for record in results.records]
def get_all_edges(self, graph_id: str, max_items: int = 5000) -> List[Dict[str, Any]]:
    """Fetch up to ``max_items`` relationships of ``graph_id`` as plain dicts.

    A warning is logged when the number of returned rows reaches
    ``max_items``. The "episodes" entry of every edge dict is an empty list.
    """
    cypher = "MATCH (s)-[r]->(t) WHERE r.group_id = $gid RETURN s, r, t LIMIT $limit"
    results = _run_async(
        self._client.driver.execute_query(
            cypher,
            params={"gid": graph_id, "limit": max_items},
        )
    )
    rows = results.records
    if len(rows) >= max_items:
        logger.warning(
            f"get_all_edges: result truncated at {max_items} edges for graph {graph_id}"
        )

    def _describe(record: Any) -> Dict[str, Any]:
        r = record["r"]
        edge: Dict[str, Any] = {
            "uuid": r.get("uuid", r.element_id),
            # Falls back to the relationship object's class name.
            "name": r.get("name", type(r).__name__),
            "fact": r.get("fact", ""),
            "source_node_uuid": record["s"].get("uuid", ""),
            "target_node_uuid": record["t"].get("uuid", ""),
            "fact_type": r.get("fact_type", ""),
            "attributes": _neo4j_props(r),
        }
        # Timestamp-like properties are all rendered as strings.
        for stamp in ("created_at", "valid_at", "invalid_at", "expired_at"):
            edge[stamp] = str(r.get(stamp, ""))
        edge["episodes"] = []
        return edge

    return [_describe(record) for record in rows]
def get_node(self, uuid_: str) -> Dict[str, Any]:
    """Look up a single node by its ``uuid`` property.

    Returns an empty dict when no node matches; otherwise a dict with uuid,
    name, labels, summary and all properties under "attributes".
    """
    cypher = "MATCH (n {uuid: $uuid}) RETURN n LIMIT 1"
    results = _run_async(
        self._client.driver.execute_query(cypher, params={"uuid": uuid_})
    )
    rows = results.records
    if len(rows) == 0:
        return {}

    n = rows[0]["n"]
    node: Dict[str, Any] = {}
    node["uuid"] = n.get("uuid", "")
    node["name"] = n.get("name", "")
    node["labels"] = list(n.labels)
    node["summary"] = n.get("summary", "")
    node["attributes"] = _neo4j_props(n)
    return node
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
    """Return every relationship attached to the node with ``node_uuid``.

    Both incoming and outgoing relationships are matched. The endpoints are
    taken from the relationship's actual start and end nodes, so an incoming
    edge reports the other node as its source and the queried node as its
    target. A ``source_node_uuid`` / ``target_node_uuid`` property stored on
    the relationship itself, if any, still takes precedence.

    Returns:
        A list of dicts with uuid (falling back to the element id), name,
        fact, source_node_uuid and target_node_uuid.
    """
    results = _run_async(
        self._client.driver.execute_query(
            # Undirected match; DISTINCT keeps a self-loop from appearing twice.
            "MATCH (n {uuid: $uuid})-[r]-() "
            "RETURN DISTINCT r, "
            "startNode(r).uuid AS source_uuid, "
            "endNode(r).uuid AS target_uuid",
            params={"uuid": node_uuid},
        )
    )
    edges = []
    for record in results.records:
        r = record["r"]
        # An endpoint without a uuid property yields null; normalise to "".
        source_uuid = record["source_uuid"] or ""
        target_uuid = record["target_uuid"] or ""
        edges.append({
            "uuid": r.get("uuid", r.element_id),
            "name": r.get("name", ""),
            "fact": r.get("fact", ""),
            "source_node_uuid": r.get("source_node_uuid", source_uuid),
            "target_node_uuid": r.get("target_node_uuid", target_uuid),
        })
    return edges
def search(self, graph_id: str, query: str, limit: int = 10, scope: str = "edges") -> Dict[str, Any]:
    """Run a Graphiti search restricted to ``graph_id``, retrying on failure.

    Up to three attempts are made, sleeping 2s then 4s between them and
    rebuilding the Graphiti client before each retry.

    Args:
        graph_id: the only group id searched.
        query: search text.
        limit: maximum number of results requested.
        scope: accepted but not used; only edges are returned.

    Returns:
        ``{"edges": [...], "nodes": []}`` — each edge has uuid, name, fact,
        source_node_uuid and target_node_uuid (missing attributes become "").

    Raises:
        Exception: the error of the last attempt, once all attempts failed.
    """
    import time as _time
    import traceback as _tb

    max_retries = 3
    delay = 2.0
    last_exc = None
    for attempt in range(max_retries):
        try:
            results = _run_async(
                self._client.search(query=query, group_ids=[graph_id], num_results=limit)
            )
            edges = [
                {
                    "uuid": getattr(r, "uuid", ""),
                    "name": getattr(r, "name", ""),
                    "fact": getattr(r, "fact", ""),
                    "source_node_uuid": getattr(r, "source_node_uuid", ""),
                    "target_node_uuid": getattr(r, "target_node_uuid", ""),
                }
                for r in (results or [])
            ]
            return {"edges": edges, "nodes": []}
        except Exception as e:
            last_exc = e
            logger.debug(
                f"Graphiti search attempt {attempt + 1}/{max_retries} failed: "
                f"{type(e).__name__}: {e}"
            )
            if attempt < max_retries - 1:
                _time.sleep(delay)
                delay *= 2
                # Reconnect in case the Neo4j TCP connection dropped
                # NOTE(review): the previous client is replaced without being
                # closed — confirm whether it should be shut down here.
                try:
                    self._client = self._build_client()
                except Exception as rebuild_exc:
                    logger.warning(f"Graphiti client rebuild failed: {rebuild_exc}")

    # We are outside the except block here, so format the traceback from the
    # saved exception; format_exc() would report no active exception.
    trace = "".join(
        _tb.format_exception(type(last_exc), last_exc, last_exc.__traceback__)
    )
    logger.error(
        f"Graphiti search failed after {max_retries} attempts: "
        f"{type(last_exc).__name__}: {last_exc}\n{trace}"
    )
    raise last_exc
def add_text(self, graph_id: str, data: str) -> None:
    """Ingest one piece of text into Graphiti as a single episode.

    The episode is named with a fresh UUID4, stamped with the current UTC
    time and stored under ``graph_id`` as its group id. No retry is done.
    """
    from graphiti_core.nodes import EpisodeType
    from datetime import datetime, timezone

    episode_kwargs = {
        "name": str(uuid_mod.uuid4()),
        "episode_body": data,
        "source_description": "MiroFish document chunk",
        "reference_time": datetime.now(timezone.utc),
        "source": EpisodeType.text,
        "group_id": graph_id,
        # Empty ontology dicts are passed on as None.
        "entity_types": self._entity_types or None,
        "edge_types": self._edge_types or None,
        "custom_extraction_instructions": self._build_extraction_instructions(),
    }
    pending = self._client.add_episode(**episode_kwargs)
    _run_async(pending)
async def _execute_neo4j_query(self, query: str, parameters: dict = None):
"""Execute a raw Cypher query against the async Neo4j driver."""
kwargs = {"params": parameters} if parameters else {}
return await self._client.driver.execute_query(query, **kwargs)
async def clone_graph(self, src_group_id: str, dst_group_id: str) -> None:
    """Clone all nodes and relationships from src_group_id to dst_group_id.

    Nodes are copied first, keeping their labels and properties (only
    ``group_id`` is rewritten); relationships between two source nodes are
    then recreated between the matching copies, paired up by ``uuid``.
    Both queries rely on APOC procedures.

    NOTE(review): copies keep the ``uuid`` of their originals, so a uuid is
    no longer unique across groups after cloning — confirm that is intended.
    """
    # apoc.create.node carries the labels over; a bare CREATE (m) would
    # produce unlabeled copies that label-based queries cannot find.
    clone_nodes_query = """
        MATCH (n) WHERE n.group_id = $src
        CALL apoc.create.node(labels(n), properties(n)) YIELD node
        SET node.group_id = $dst
    """
    await self._execute_neo4j_query(
        clone_nodes_query, {"src": src_group_id, "dst": dst_group_id}
    )

    clone_rels_query = """
        MATCH (n)-[r]->(m)
        WHERE n.group_id = $src AND m.group_id = $src
        MATCH (n2 {uuid: n.uuid, group_id: $dst})
        MATCH (m2 {uuid: m.uuid, group_id: $dst})
        CALL apoc.create.relationship(n2, type(r), properties(r), m2) YIELD rel
        SET rel.group_id = $dst
        RETURN rel
    """
    await self._execute_neo4j_query(
        clone_rels_query, {"src": src_group_id, "dst": dst_group_id}
    )
def delete_graph(self, graph_id: str) -> None:
    """Remove every node whose ``group_id`` equals ``graph_id``.

    ``DETACH DELETE`` also removes the relationships attached to those nodes.
    """
    cypher = "MATCH (n) WHERE n.group_id = $gid DETACH DELETE n"
    pending = self._execute_neo4j_query(cypher, {"gid": graph_id})
    _run_async(pending)