""" Graphiti Adapter — Drop-in replacement for the Zep Cloud client. Exposes the same namespace as the Zep client so all consuming code (graph_builder, zep_tools, zep_entity_reader, etc.) needs only a one-line import swap: from .graphiti_adapter import GraphitiAdapter self.client = GraphitiAdapter() Then all self.client.graph.* calls work unchanged. """ import asyncio import threading import uuid as _uuid_mod from dataclasses import dataclass, field from datetime import datetime, timezone from typing import Any, Dict, List, Optional from graphiti_core import Graphiti from graphiti_core.nodes import EpisodeType, EntityNode from graphiti_core.edges import EntityEdge from graphiti_core.search.search_config import SearchConfig, SearchResults from graphiti_core.search.search_config_recipes import ( NODE_HYBRID_SEARCH_RRF, EDGE_HYBRID_SEARCH_RRF, ) from graphiti_core.llm_client.config import LLMConfig from graphiti_core.cross_encoder.client import CrossEncoderClient from ..config import Config from ..utils.logger import get_logger from .ollama_reranker import OllamaReranker logger = get_logger('mirofish.graphiti_adapter') class _PassthroughReranker(CrossEncoderClient): """Provider-agnostic no-op reranker. Returns passages in the order Graphiti supplied them with synthetic descending scores. Injected explicitly so Graphiti does not fall back to its default ``OpenAIRerankerClient`` (which uses a hard-coded ``gpt-4.1-nano`` model with logprobs and would 401 against Qwen / Dashscope keys). Selected when ``Config.RERANKER_PROVIDER == "none"`` — useful for CI / slim containers that cannot pull the reranker model. For real reranking, set ``RERANKER_PROVIDER=ollama`` (the default). 
""" async def rank(self, query: str, passages: list[str]) -> list[tuple[str, float]]: if not passages: return [] return [(p, 1.0 - i * 0.01) for i, p in enumerate(passages)] # --------------------------------------------------------------------------- # Persistent event loop in a dedicated background thread. # All async calls are submitted here so the Neo4j driver (which is bound # to one event loop) never crosses loop boundaries. # --------------------------------------------------------------------------- _loop: Optional[asyncio.AbstractEventLoop] = None _loop_thread: Optional[threading.Thread] = None _loop_lock = threading.Lock() def _get_loop() -> asyncio.AbstractEventLoop: global _loop, _loop_thread if _loop is None: with _loop_lock: if _loop is None: _loop = asyncio.new_event_loop() _loop_thread = threading.Thread( target=_loop.run_forever, daemon=True, name="graphiti-event-loop" ) _loop_thread.start() return _loop def _run(coro): """Submit coroutine to the persistent event loop thread and wait for result.""" future = asyncio.run_coroutine_threadsafe(coro, _get_loop()) return future.result(timeout=300) # --------------------------------------------------------------------------- # Singleton Graphiti instance (one Neo4j driver for the whole process) # --------------------------------------------------------------------------- _graphiti_instance: Optional[Graphiti] = None _graphiti_lock = threading.Lock() _ALLOWED_GRAPHITI_PROVIDERS = ("openai", "gemini") _ALLOWED_RERANKER_PROVIDERS = ("ollama", "none") def _build_reranker(provider: str) -> CrossEncoderClient: """Build the cross-encoder reranker for the configured provider. Defers to ``_PassthroughReranker`` when ``provider`` is ``"none"`` (the legacy no-op behaviour, useful for CI / slim containers that cannot pull the reranker model). 
For ``"ollama"`` it constructs the real Ollama-backed reranker; the construction is side-effect-free, so Graphiti initialisation does not depend on the Ollama daemon being reachable at startup. """ if provider == "none": return _PassthroughReranker() if provider == "ollama": return OllamaReranker( model=Config.RERANKER_MODEL, base_url=Config.RERANKER_BASE_URL, api_key=Config.RERANKER_API_KEY, ) raise ValueError( f"Unknown RERANKER_PROVIDER={provider!r}; " f"allowed: {_ALLOWED_RERANKER_PROVIDERS}" ) def _build_llm_and_embedder(provider: str): """Build (llm_client, embedder) for the requested Graphiti provider. Lazy-imports provider-specific Graphiti classes so a missing optional dependency for one provider does not break the other at import time. """ if provider == "openai": from graphiti_core.llm_client.openai_client import OpenAIClient from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig llm_client = OpenAIClient( config=LLMConfig( api_key=Config.LLM_API_KEY, base_url=Config.LLM_BASE_URL, model=Config.LLM_MODEL_NAME, ) ) embedder = OpenAIEmbedder( config=OpenAIEmbedderConfig( api_key=Config.EMBEDDING_API_KEY or Config.LLM_API_KEY, base_url=Config.EMBEDDING_BASE_URL or Config.LLM_BASE_URL, embedding_model=Config.EMBEDDING_MODEL, ) ) return llm_client, embedder if provider == "gemini": from graphiti_core.llm_client.gemini_client import GeminiClient from graphiti_core.embedder.gemini import GeminiEmbedder, GeminiEmbedderConfig llm_client = GeminiClient( config=LLMConfig( api_key=Config.LLM_API_KEY, model=Config.LLM_MODEL_NAME, ) ) embedder = GeminiEmbedder( config=GeminiEmbedderConfig( api_key=Config.LLM_API_KEY, embedding_model=Config.EMBEDDING_MODEL, ) ) return llm_client, embedder raise ValueError( f"Unknown GRAPHITI_LLM_PROVIDER={provider!r}; " f"allowed: {_ALLOWED_GRAPHITI_PROVIDERS}" ) def _get_graphiti() -> Graphiti: global _graphiti_instance if _graphiti_instance is None: with _graphiti_lock: if _graphiti_instance is None: provider = 
(Config.GRAPHITI_LLM_PROVIDER or "openai").lower() logger.info(f"Initializing Graphiti client (provider={provider})...") reranker_provider = (Config.RERANKER_PROVIDER or "ollama").lower() logger.info( f"Initializing Graphiti reranker (provider={reranker_provider})..." ) llm_client, embedder = _build_llm_and_embedder(provider) cross_encoder = _build_reranker(reranker_provider) g = Graphiti( Config.NEO4J_URI, Config.NEO4J_USER, Config.NEO4J_PASSWORD, llm_client=llm_client, embedder=embedder, cross_encoder=cross_encoder, ) # Use the persistent loop so the driver is bound to it from the start _run(g.build_indices_and_constraints()) _graphiti_instance = g logger.info("Graphiti client ready.") return _graphiti_instance # --------------------------------------------------------------------------- # Compatibility data classes (mimic Zep response objects) # --------------------------------------------------------------------------- @dataclass class _NodeResult: """Zep-compatible node object.""" uuid_: str name: str labels: List[str] summary: str attributes: Dict[str, Any] created_at: Optional[str] = None @property def uuid(self): return self.uuid_ @dataclass class _EdgeResult: """Zep-compatible edge object.""" uuid_: str name: str fact: str source_node_uuid: str target_node_uuid: str attributes: Dict[str, Any] created_at: Optional[str] = None valid_at: Optional[str] = None invalid_at: Optional[str] = None expired_at: Optional[str] = None @property def uuid(self): return self.uuid_ @dataclass class _EpisodeResult: """Zep-compatible episode object — always processed (Graphiti is sync).""" uuid_: str processed: bool = True @property def uuid(self): return self.uuid_ @dataclass class _SearchResults: """Zep-compatible search result object.""" edges: List[_EdgeResult] = field(default_factory=list) nodes: List[_NodeResult] = field(default_factory=list) # --------------------------------------------------------------------------- # Helpers: convert Graphiti objects → Zep-compatible 
# objects
# ---------------------------------------------------------------------------

def _to_ts(dt: Optional[datetime]) -> Optional[str]:
    """Return ``dt`` as an ISO-8601 string, or None when ``dt`` is None."""
    return None if dt is None else dt.isoformat()


def _record_ts(value: Any) -> Optional[str]:
    """Stringify a raw timestamp value from a query record.

    Falsy values (None, empty string) become None, so a null property never
    turns into the literal string ``"None"``.
    """
    return str(value) if value else None


def _entity_node_to_result(n: EntityNode) -> _NodeResult:
    """Convert a Graphiti ``EntityNode`` into a Zep-compatible node object."""
    return _NodeResult(
        uuid_=n.uuid,
        name=n.name,
        labels=list(n.labels) if n.labels else ["Entity"],
        summary=n.summary or "",
        attributes=n.attributes or {},
        created_at=_to_ts(n.created_at),
    )


def _entity_edge_to_result(e: EntityEdge) -> _EdgeResult:
    """Convert a Graphiti ``EntityEdge`` into a Zep-compatible edge object.

    Edge attributes are copied when the edge object exposes a non-empty
    ``attributes`` mapping; otherwise an empty dict is used.
    """
    return _EdgeResult(
        uuid_=e.uuid,
        name=e.name or "",
        fact=e.fact or "",
        source_node_uuid=e.source_node_uuid,
        target_node_uuid=e.target_node_uuid,
        # getattr keeps this working with edge objects lacking the field.
        attributes=dict(getattr(e, "attributes", None) or {}),
        created_at=_to_ts(e.created_at),
        valid_at=_to_ts(e.valid_at),
        invalid_at=_to_ts(e.invalid_at),
        expired_at=_to_ts(e.expired_at),
    )


def _neo4j_record_to_node(record: Dict) -> _NodeResult:
    """Convert a node query record (dict) into a Zep-compatible node object.

    Keys that are missing *or* present with a null value fall back to neutral
    defaults: empty string for uuid/name/summary, ``["Entity"]`` for labels,
    empty dict for attributes and None for created_at.
    """
    raw_labels = record.get("labels")
    if raw_labels is None:
        labels = ["Entity"]
    elif isinstance(raw_labels, (list, tuple)):
        labels = [str(label) for label in raw_labels]
    else:
        # Unexpected shape: pass through untouched rather than guess.
        labels = raw_labels
    return _NodeResult(
        uuid_=record.get("uuid") or "",
        name=record.get("name") or "",
        labels=labels,
        summary=record.get("summary") or "",
        attributes=record.get("attributes") or {},
        created_at=_record_ts(record.get("created_at")),
    )


def _neo4j_record_to_edge(record: Dict) -> _EdgeResult:
    """Convert an edge query record (dict) into a Zep-compatible edge object.

    Missing or null string fields become empty strings, missing or null
    attributes become an empty dict, and missing or null timestamps become
    None.
    """
    return _EdgeResult(
        uuid_=record.get("uuid") or "",
        name=record.get("name") or "",
        fact=record.get("fact") or "",
        source_node_uuid=record.get("source_node_uuid") or "",
        target_node_uuid=record.get("target_node_uuid") or "",
        attributes=record.get("attributes") or {},
        created_at=_record_ts(record.get("created_at")),
        valid_at=_record_ts(record.get("valid_at")),
        invalid_at=_record_ts(record.get("invalid_at")),
        expired_at=_record_ts(record.get("expired_at")),
    )


# ---------------------------------------------------------------------------
# Neo4j direct query helpers
# ---------------------------------------------------------------------------

async def _neo4j_query(graphiti: Graphiti, cypher: str, params: Dict) -> List[Dict]:
    """Execute a read Cypher query and return list of record dicts."""
    # The driver result unpacks into three parts; only the records are used.
    records, _, _ = await graphiti.driver.execute_query(cypher, params)
    return [dict(r) for r in records]


async def _neo4j_write(graphiti: Graphiti, cypher: str, params: Dict) -> None:
    """Execute a write Cypher query."""
    await graphiti.driver.execute_query(cypher, params)


# Cypher queries

# Page of entity nodes in one group, oldest first.
_NODES_BY_GROUP = """
MATCH (n:Entity {group_id: $group_id})
RETURN n.uuid AS uuid, n.name AS name, n.summary AS summary,
       labels(n) AS labels, n.created_at AS created_at,
       n.attributes AS attributes
ORDER BY n.created_at ASC
SKIP $skip LIMIT $limit
"""

# Page of RELATES_TO edges whose endpoints are both in the group, oldest first.
_EDGES_BY_GROUP = """
MATCH (s:Entity {group_id: $group_id})-[r:RELATES_TO]->(t:Entity {group_id: $group_id})
RETURN r.uuid AS uuid, r.name AS name, r.fact AS fact,
       s.uuid AS source_node_uuid, t.uuid AS target_node_uuid,
       r.created_at AS created_at, r.valid_at AS valid_at,
       r.invalid_at AS invalid_at, r.expired_at AS expired_at,
       r.attributes AS attributes
ORDER BY r.created_at ASC
SKIP $skip LIMIT $limit
"""

# Single entity node looked up by uuid.
_NODE_BY_UUID = """
MATCH (n:Entity {uuid: $uuid})
RETURN n.uuid AS uuid, n.name AS name, n.summary AS summary,
       labels(n) AS labels, n.created_at AS created_at,
       n.group_id AS group_id, n.attributes AS attributes
LIMIT 1
"""

# All RELATES_TO edges touching a node: outgoing first, then incoming.
_EDGES_BY_NODE_UUID = """
MATCH (s:Entity {uuid: $node_uuid})-[r:RELATES_TO]->(t:Entity)
RETURN r.uuid AS uuid, r.name AS name, r.fact AS fact,
       s.uuid AS source_node_uuid, t.uuid AS target_node_uuid,
       r.created_at AS created_at, r.valid_at AS valid_at,
       r.invalid_at AS invalid_at, r.expired_at AS expired_at
UNION
MATCH (s:Entity)-[r:RELATES_TO]->(t:Entity {uuid: $node_uuid})
RETURN r.uuid AS uuid, r.name AS name, r.fact AS fact,
       s.uuid AS source_node_uuid, t.uuid AS target_node_uuid,
       r.created_at AS created_at, r.valid_at AS valid_at,
       r.invalid_at AS invalid_at, r.expired_at AS expired_at
"""

# NOTE(review): only :Entity nodes are matched here; nodes with other labels
# that carry the same group_id (if any exist) are left in place — confirm
# that this is intended.
_DELETE_GROUP = """
MATCH (n:Entity {group_id: $group_id})
DETACH DELETE n
"""


# ---------------------------------------------------------------------------
# Sub-namespaces
# ---------------------------------------------------------------------------

class _EpisodeNamespace:
    """Episode lookups, exposed as ``client.graph.episode``."""

    def get(self, uuid_: str) -> _EpisodeResult:
        """Always returns processed=True — Graphiti is synchronous."""
        return _EpisodeResult(uuid_=uuid_, processed=True)


class _NodeNamespace:
    """Node lookups, exposed as ``client.graph.node``."""

    # Upper bound on rows fetched by the single "return everything" query.
    _FETCH_CAP = 10000

    def __init__(self, graphiti: Graphiti):
        self._g = graphiti

    def get_by_graph_id(
        self,
        graph_id: str,
        limit: int = 100,
        uuid_cursor: Optional[str] = None,
    ) -> List[_NodeResult]:
        """Return nodes for a group. First call returns all; cursor call returns empty.

        ``limit`` is accepted for signature compatibility but not used: the
        first call fetches up to ``_FETCH_CAP`` nodes in one query.
        """
        if uuid_cursor is not None:
            # Already fetched all on first call — signal end of pagination
            return []
        records = _run(_neo4j_query(
            self._g,
            _NODES_BY_GROUP,
            {"group_id": graph_id, "skip": 0, "limit": self._FETCH_CAP},
        ))
        return [_neo4j_record_to_node(r) for r in records]

    def get(self, uuid_: str) -> _NodeResult:
        """Return the node with this uuid, or an empty placeholder node if none matches."""
        records = _run(_neo4j_query(self._g, _NODE_BY_UUID, {"uuid": uuid_}))
        if not records:
            return _NodeResult(uuid_=uuid_, name="", labels=[], summary="", attributes={})
        return _neo4j_record_to_node(records[0])

    def get_entity_edges(self, node_uuid: str) -> List[_EdgeResult]:
        """Return every edge that starts or ends at the given node."""
        records = _run(_neo4j_query(
            self._g, _EDGES_BY_NODE_UUID, {"node_uuid": node_uuid}
        ))
        return [_neo4j_record_to_edge(r) for r in records]


class _EdgeNamespace:
    """Edge lookups, exposed as ``client.graph.edge``."""

    # Upper bound on rows fetched by the single "return everything" query.
    _FETCH_CAP = 50000

    def __init__(self, graphiti: Graphiti):
        self._g = graphiti

    def get_by_graph_id(
        self,
        graph_id: str,
        limit: int = 100,
        uuid_cursor: Optional[str] = None,
    ) -> List[_EdgeResult]:
        """Return edges for a group. First call returns all; cursor call returns empty.

        ``limit`` is accepted for signature compatibility but not used: the
        first call fetches up to ``_FETCH_CAP`` edges in one query.
        """
        if uuid_cursor is not None:
            return []
        records = _run(_neo4j_query(
            self._g,
            _EDGES_BY_GROUP,
            {"group_id": graph_id, "skip": 0, "limit": self._FETCH_CAP},
        ))
        return [_neo4j_record_to_edge(r) for r in records]


class _GraphNamespace:
    """Graph-level operations, exposed as ``client.graph``."""

    def __init__(self, graphiti: Graphiti):
        self._g = graphiti
        self.node = _NodeNamespace(graphiti)
        self.edge = _EdgeNamespace(graphiti)
        self.episode = _EpisodeNamespace()
        self._ontologies: Dict[str, Dict] = {}  # graph_id -> ontology dict

    def create(self, graph_id: str, name: str, description: str = "") -> None:
        """No-op — Graphiti uses group_id implicitly, no explicit creation needed."""
        logger.info(f"Graph '{graph_id}' registered (group_id in Graphiti)")

    def set_ontology(
        self,
        graph_ids: List[str],
        entities: Any = None,
        edges: Any = None,
    ) -> None:
        """Store ontology hints for use during episode ingestion. Graphiti extracts entities dynamically."""
        for gid in graph_ids:
            self._ontologies[gid] = {"entities": entities, "edges": edges}
        logger.info(f"Ontology hints stored for graphs: {graph_ids}")

    def _add_text_episode(
        self,
        graph_id: str,
        body: str,
        *,
        name_prefix: str,
        source_description: str,
    ) -> _EpisodeResult:
        """Ingest one text episode and wrap the resulting episode uuid.

        The episode name is ``<name_prefix>_<8 random hex chars>``. If the
        ingestion call returns no episode, a freshly generated uuid is used.
        """
        result = _run(self._g.add_episode(
            name=f"{name_prefix}_{_uuid_mod.uuid4().hex[:8]}",
            episode_body=body,
            source_description=source_description,
            reference_time=datetime.now(timezone.utc),
            source=EpisodeType.text,
            group_id=graph_id,
            update_communities=False,
        ))
        if result and result.episode:
            return _EpisodeResult(uuid_=result.episode.uuid)
        return _EpisodeResult(uuid_=str(_uuid_mod.uuid4()))

    def add(self, graph_id: str, type: str = "text", data: str = "") -> _EpisodeResult:
        """Add a single text episode to the graph.

        ``type`` is accepted for signature compatibility; episodes are always
        ingested as text.
        """
        return self._add_text_episode(
            graph_id,
            data,
            name_prefix="activity",
            source_description="MiroFish simulation activity",
        )

    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[_EpisodeResult]:
        """Add a batch of episodes. Returns one _EpisodeResult per episode in input order.

        On the first ingestion failure the underlying exception is logged at
        ERROR level (with traceback) and re-raised; episodes successfully
        ingested before the failure remain committed in Neo4j. The caller
        (the graph-build worker) translates the propagated exception into
        Task.status = FAILED with the underlying error message — never
        substitute a placeholder UUID, since that would produce a Task that
        looks completed while the graph is empty.
        """
        results = []
        for index, ep in enumerate(episodes):
            # Use the object's ``data`` attribute when it is non-empty,
            # otherwise fall back to the object's string form.
            text = getattr(ep, 'data', '') or str(ep)
            try:
                episode = self._add_text_episode(
                    graph_id,
                    text,
                    name_prefix="chunk",
                    source_description="MiroFish document chunk",
                )
            except Exception:
                logger.exception(
                    "Episode add failed (group_id=%s, episode_index=%d)",
                    graph_id,
                    index,
                )
                raise
            results.append(episode)
        return results

    @staticmethod
    def _wanted_kinds(scope: str) -> tuple:
        """Map a search scope to ``(want_edges, want_nodes)``.

        ``"nodes"`` selects nodes only, ``"both"`` selects both kinds, and
        ``"edges"`` or any unrecognised value selects edges only.
        """
        if scope == "nodes":
            return False, True
        if scope == "both":
            return True, True
        return True, False

    def search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges",
    ) -> _SearchResults:
        """Semantic search over the graph. scope='edges'|'nodes'|'both'.

        With ``scope='both'`` a node search and an edge search are both run
        and their results combined; each is limited to ``limit`` items. The
        search is best-effort: on any exception a warning is logged and an
        empty result object is returned.
        """
        want_edges, want_nodes = self._wanted_kinds(scope)
        try:
            found = _SearchResults()
            if want_nodes:
                node_hits = _run(self._g.search_(
                    query=query,
                    config=SearchConfig(
                        node_config=NODE_HYBRID_SEARCH_RRF.node_config,
                        limit=limit,
                    ),
                    group_ids=[graph_id],
                ))
                found.nodes = [
                    _entity_node_to_result(n) for n in (node_hits.nodes or [])
                ]
            if want_edges:
                edge_hits = _run(self._g.search(
                    query=query,
                    group_ids=[graph_id],
                    num_results=limit,
                ))
                found.edges = [
                    _entity_edge_to_result(e) for e in (edge_hits or [])
                ]
            return found
        except Exception as e:
            logger.warning(f"Graph search failed: {str(e)[:150]}")
            return _SearchResults()

    def delete(self, graph_id: str) -> None:
        """Delete all nodes and edges for a group_id."""
        _run(_neo4j_write(self._g, _DELETE_GROUP, {"group_id": graph_id}))
        logger.info(f"Graph '{graph_id}' deleted from Neo4j")


# ---------------------------------------------------------------------------
# Main adapter class — drop-in for Zep(api_key=...)
# ---------------------------------------------------------------------------

class GraphitiAdapter:
    """
    Drop-in replacement for `from zep_cloud.client import Zep`.

    Usage:
        self.client = GraphitiAdapter()
        self.client.graph.create(graph_id, name)
        self.client.graph.search(graph_id, query, limit, scope)
        self.client.graph.node.get(uuid_)
        ...
    """

    def __init__(self, api_key: Optional[str] = None):
        """Attach the ``graph`` namespace to the shared Graphiti client.

        ``api_key`` is ignored; it exists only so the constructor accepts the
        same argument as the client it replaces.
        """
        self.graph = _GraphNamespace(_get_graphiti())