# Listing metadata: 876 lines, 30 KiB, Python.
"""
|
||
Graphiti + Neo4j graph backend implementation.
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import logging
|
||
import threading
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from pydantic import BaseModel, Field, create_model
|
||
|
||
from ..config import Config
|
||
from .base import GraphBackend
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class _CompatEpisode:
    """Plain-data view of an episode handed back to callers of the backend.

    Only ``uuid`` is required; every other field has a default. ``uuid_`` is
    a read-only alias for ``uuid``.
    """

    uuid: str
    processed: bool = True
    name: str = ""
    content: str = ""
    valid_at: Optional[datetime] = None
    created_at: Optional[datetime] = None

    @property
    def uuid_(self) -> str:
        """Return ``uuid`` under the alternate attribute name ``uuid_``."""
        return self.uuid
|
||
|
||
|
||
@dataclass
class _CompatNode:
    """Plain-data view of a graph node handed back to callers of the backend.

    Only ``uuid`` is required. ``labels`` and ``attributes`` default to fresh
    empty containers per instance. ``uuid_`` is a read-only alias for ``uuid``.
    """

    uuid: str
    name: str = ""
    labels: List[str] = field(default_factory=list)
    summary: str = ""
    attributes: Dict[str, Any] = field(default_factory=dict)
    created_at: Optional[datetime] = None

    @property
    def uuid_(self) -> str:
        """Return ``uuid`` under the alternate attribute name ``uuid_``."""
        return self.uuid
|
||
|
||
|
||
@dataclass
class _CompatEdge:
    """Plain-data view of a graph edge handed back to callers of the backend.

    ``uuid``, ``name``, ``fact`` and both endpoint UUIDs are required; the
    endpoint names, attributes, episode list and timestamps are optional.
    ``uuid_`` is a read-only alias for ``uuid``.
    """

    uuid: str
    name: str
    fact: str
    source_node_uuid: str
    target_node_uuid: str
    source_node_name: str = ""
    target_node_name: str = ""
    attributes: Dict[str, Any] = field(default_factory=dict)
    episodes: List[str] = field(default_factory=list)
    created_at: Optional[datetime] = None
    valid_at: Optional[datetime] = None
    invalid_at: Optional[datetime] = None
    expired_at: Optional[datetime] = None

    @property
    def uuid_(self) -> str:
        """Return ``uuid`` under the alternate attribute name ``uuid_``."""
        return self.uuid
|
||
|
||
|
||
@dataclass
class _CompatSearchResults:
    """Container for the edges and nodes produced by one search call.

    Both lists default to fresh empty lists per instance.
    """

    edges: List[_CompatEdge] = field(default_factory=list)
    nodes: List[_CompatNode] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class _OntologyBundle:
    """Ontology of one graph: model classes plus their serializable spec.

    All fields default to empty dicts, so ``_OntologyBundle()`` represents
    "no ontology".
    """

    # Entity type name -> pydantic model class.
    entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
    # Edge type name -> pydantic model class.
    edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
    # (source entity type, target entity type) -> edge type names.
    edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict)
    # JSON-serializable description of the three mappings above.
    spec: Dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class _AsyncBridge:
    """Run async Graphiti calls from the app's synchronous service layer.

    A daemon thread owns a private event loop that runs forever. ``run``
    submits a coroutine to that loop and blocks the calling thread until the
    coroutine has finished.
    """

    def __init__(self) -> None:
        """Start the worker thread and wait until its event loop exists."""
        self._ready = threading.Event()
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()
        # Do not return before the worker thread has published ``_loop``.
        self._ready.wait()

    def _run_loop(self) -> None:
        """Thread target: create the loop, publish it, then run it forever."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self._loop = loop
        self._ready.set()
        loop.run_forever()

    def run(self, coro):
        """Execute ``coro`` on the bridge loop and return its result.

        Exceptions raised by the coroutine propagate to the caller.

        Raises:
            RuntimeError: if the loop has not been created, or if this method
                is called from the bridge's own loop thread.
        """
        if self._loop is None:
            raise RuntimeError("Graphiti async loop 未初始化")
        if threading.current_thread() is self._thread:
            # Waiting on future.result() here would block the only thread
            # able to run the coroutine, i.e. hang forever. Close the
            # coroutine so it is not reported as "never awaited".
            coro.close()
            raise RuntimeError(
                "_AsyncBridge.run() must not be called from the bridge event loop thread"
            )
        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
        return future.result()
|
||
|
||
|
||
class GraphitiBackend(GraphBackend):
    """Graph backend backed by Graphiti OSS + Neo4j.

    Public methods are synchronous; each one hands its ``async`` counterpart
    to a shared ``_AsyncBridge``. The class attributes below are shared by
    all instances in the process.
    """

    _bridge: Optional[_AsyncBridge] = None
    _bridge_lock = threading.Lock()
    _indices_ready = False
    _indices_lock = threading.Lock()
    # graph_id -> ontology bundle cache, guarded by ``_ontology_lock``.
    _ontology_registry: Dict[str, _OntologyBundle] = {}
    _ontology_lock = threading.Lock()
    _cross_encoder_warning_emitted = False
    # Page size used when listing all nodes / edges of a graph.
    PAGE_SIZE = 200

    def __init__(self, api_key: Optional[str] = None) -> None:
        """Build the Graphiti client from ``Config`` and ensure indices exist.

        Args:
            api_key: Deliberately unused; it is discarded immediately.

        Raises:
            ValueError: if ``Config.get_graphiti_config_errors()`` reports errors.
            ImportError: if the graphiti-core modules cannot be imported.
        """
        del api_key

        errors = Config.get_graphiti_config_errors()
        if errors:
            raise ValueError("; ".join(errors))

        try:
            from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
            from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
            from graphiti_core.graphiti import Graphiti
            from graphiti_core.llm_client import OpenAIClient
            from graphiti_core.llm_client.config import LLMConfig
            from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
        except ImportError as exc:
            raise ImportError(
                "Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j"
            ) from exc

        llm_config = LLMConfig(
            api_key=Config.GRAPHITI_LLM_API_KEY,
            base_url=Config.GRAPHITI_LLM_BASE_URL,
            model=Config.GRAPHITI_LLM_MODEL,
            small_model=Config.GRAPHITI_LLM_SMALL_MODEL,
            temperature=0,
        )
        reranker_config = LLMConfig(
            api_key=Config.GRAPHITI_RERANKER_API_KEY,
            base_url=Config.GRAPHITI_RERANKER_BASE_URL,
            model=Config.GRAPHITI_RERANKER_MODEL,
            temperature=0,
        )
        embedder_config = OpenAIEmbedderConfig(
            api_key=Config.GRAPHITI_EMBEDDER_API_KEY,
            base_url=Config.GRAPHITI_EMBEDDER_BASE_URL,
            embedding_model=Config.GRAPHITI_EMBEDDER_MODEL,
            embedding_dim=Config.GRAPHITI_EMBEDDER_DIM,
        )

        # "generic" selects OpenAIGenericClient; any other value uses OpenAIClient.
        llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower()
        if llm_client_mode == "generic":
            llm_client = OpenAIGenericClient(
                config=llm_config,
                max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
            )
        else:
            llm_client = OpenAIClient(
                config=llm_config,
                max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
            )

        self._graphiti = Graphiti(
            uri=Config.GRAPHITI_URI,
            user=Config.GRAPHITI_USER,
            password=Config.GRAPHITI_PASSWORD,
            llm_client=llm_client,
            embedder=OpenAIEmbedder(config=embedder_config),
            cross_encoder=OpenAIRerankerClient(config=reranker_config),
            max_coroutines=Config.GRAPHITI_MAX_COROUTINES,
        )
        # Use a driver bound to the configured database everywhere Graphiti
        # keeps a driver reference.
        self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE)
        self._graphiti.driver = self._driver
        self._graphiti.clients.driver = self._driver
        self._bridge = self._get_bridge()

        self._ensure_indices()

    @classmethod
    def _get_bridge(cls) -> _AsyncBridge:
        """Return the shared async bridge, creating it on first use."""
        with cls._bridge_lock:
            if cls._bridge is None:
                cls._bridge = _AsyncBridge()
            return cls._bridge

    @property
    def raw_client(self) -> Any:
        """The underlying ``Graphiti`` instance."""
        return self._graphiti

    def _run(self, coro):
        """Run ``coro`` on the bridge loop and return its result."""
        return self._bridge.run(coro)

    def _ensure_indices(self) -> None:
        """Build Graphiti indices and constraints once per process."""
        cls = self.__class__
        if cls._indices_ready:
            return

        with cls._indices_lock:
            # Re-check: another thread may have finished while we waited.
            if cls._indices_ready:
                return
            self._run(self._graphiti.build_indices_and_constraints())
            cls._indices_ready = True

    def _validate_graph_id(self, graph_id: str) -> None:
        """Validate ``graph_id`` with Graphiti's ``validate_group_id``."""
        from graphiti_core.helpers import validate_group_id

        validate_group_id(graph_id)

    def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]:
        """Reduce a pydantic model to ``{"description", "fields"}``.

        Each field becomes ``{"name", "description"}``; a field without a
        description uses its own name. The model description is the stripped
        class docstring (empty string when absent).
        """
        fields = [
            {
                "name": field_name,
                "description": model_field.description or field_name,
            }
            for field_name, model_field in model.model_fields.items()
        ]

        return {
            "description": (getattr(model, "__doc__", "") or "").strip(),
            "fields": fields,
        }

    def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]:
        """Create a pydantic model named ``name`` from a normalized spec.

        Every field is ``Optional[str]`` with default ``None``. Field specs
        whose name is missing, ``None`` or blank are skipped. The model
        docstring is the spec description, or ``name`` when that is empty.
        """
        field_definitions = {}
        for field_spec in spec.get("fields", []):
            # ``or ""`` also covers an explicit ``"name": None`` entry.
            field_name = (field_spec.get("name") or "").strip()
            if not field_name:
                continue
            field_definitions[field_name] = (
                Optional[str],
                Field(
                    default=None,
                    description=field_spec.get("description") or field_name,
                ),
            )

        model = create_model(name, __base__=BaseModel, **field_definitions)
        model.__doc__ = spec.get("description") or name
        return model

    def _serialize_ontology_spec(
        self,
        entity_specs: Dict[str, Dict[str, Any]],
        edge_specs: Dict[str, Dict[str, Any]],
        edge_type_map: Dict[tuple[str, str], List[str]],
    ) -> Dict[str, Any]:
        """Assemble the JSON-serializable ontology spec.

        The tuple-keyed ``edge_type_map`` is flattened into a list of
        ``{"source", "target", "edges"}`` dicts, ordered by sorted key.
        """
        return {
            "entity_types": entity_specs,
            "edge_types": edge_specs,
            "edge_type_map": [
                {
                    "source": source,
                    "target": target,
                    "edges": edge_names,
                }
                for (source, target), edge_names in sorted(edge_type_map.items())
            ],
        }

    def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle:
        """Rebuild an ontology bundle from a spec produced by
        ``_serialize_ontology_spec``.

        Missing ``source`` / ``target`` keys default to ``"Entity"``.
        ``edge_type_map`` entries that are not dicts are ignored.
        """
        entity_types = {
            name: self._build_dynamic_model(name, model_spec)
            for name, model_spec in (spec.get("entity_types") or {}).items()
        }
        edge_types = {
            name: self._build_dynamic_model(name, model_spec)
            for name, model_spec in (spec.get("edge_types") or {}).items()
        }
        edge_type_map: Dict[tuple[str, str], List[str]] = {}
        for entry in spec.get("edge_type_map") or []:
            if not isinstance(entry, dict):
                continue
            source = entry.get("source", "Entity")
            target = entry.get("target", "Entity")
            edge_type_map[(source, target)] = list(entry.get("edges") or [])

        return _OntologyBundle(
            entity_types=entity_types,
            edge_types=edge_types,
            edge_type_map=edge_type_map,
            spec=spec,
        )

    def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle:
        """Build an ontology bundle from caller-supplied model definitions.

        Args:
            entities: Mapping of entity name -> pydantic model class, or None.
            edges: Mapping of edge name -> ``(model, source_targets)`` tuple,
                or None. Each item of ``source_targets`` is read through its
                ``source`` / ``target`` attributes; missing or empty values
                become ``"Entity"``.

        Edge values that are not 2-tuples are skipped with a warning. When no
        source/target pair was collected, all edge types are mapped to
        ``("Entity", "Entity")``.
        """
        entity_specs = {}
        entity_types = {}
        for entity_name, entity_model in (entities or {}).items():
            entity_spec = self._normalize_model_spec(entity_model)
            entity_specs[entity_name] = entity_spec
            entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec)

        edge_specs = {}
        edge_types = {}
        edge_type_map: Dict[tuple[str, str], List[str]] = {}
        for edge_name, edge_value in (edges or {}).items():
            if not isinstance(edge_value, tuple) or len(edge_value) != 2:
                # Previously dropped silently; make the skipped definition visible.
                logger.warning(
                    "Graphiti ontology edge definition skipped "
                    "(expected a (model, source_targets) tuple): %s",
                    edge_name,
                )
                continue
            edge_model, source_targets = edge_value
            edge_spec = self._normalize_model_spec(edge_model)
            edge_specs[edge_name] = edge_spec
            edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec)

            for source_target in source_targets or []:
                source = getattr(source_target, "source", "Entity") or "Entity"
                target = getattr(source_target, "target", "Entity") or "Entity"
                edge_type_map.setdefault((source, target), []).append(edge_name)

        if not edge_type_map:
            edge_type_map = {("Entity", "Entity"): list(edge_types.keys())}

        spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map)
        return _OntologyBundle(
            entity_types=entity_types,
            edge_types=edge_types,
            edge_type_map=edge_type_map,
            spec=spec,
        )

    async def _upsert_graph_metadata_async(
        self,
        graph_id: str,
        name: Optional[str] = None,
        description: Optional[str] = None,
        ontology_spec: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Create or update the ``GraphMetadata`` node of ``graph_id``.

        A ``None`` or empty ``name`` / ``description`` / ontology payload
        leaves the stored value untouched (name and description fall back to
        ``''`` when nothing is stored yet). ``ontology_spec`` is stored as a
        JSON string.
        """
        payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None

        await self._driver.execute_query(
            """
            MERGE (m:GraphMetadata {graph_id: $graph_id})
            ON CREATE SET
                m.group_id = $graph_id,
                m.created_at = datetime()
            SET
                m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END,
                m.description = CASE
                    WHEN $description IS NULL OR $description = ''
                    THEN coalesce(m.description, '')
                    ELSE $description
                END,
                m.ontology_json = CASE
                    WHEN $ontology_json IS NULL OR $ontology_json = ''
                    THEN m.ontology_json
                    ELSE $ontology_json
                END,
                m.updated_at = datetime()
            """,
            graph_id=graph_id,
            name=name,
            description=description,
            ontology_json=payload,
        )

    async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
        """Load the persisted ontology of ``graph_id``.

        Returns an empty bundle when no metadata node exists, when no
        ontology JSON is stored, or when the stored value cannot be parsed
        into a JSON object (the latter is logged as a warning).
        """
        records, _, _ = await self._driver.execute_query(
            """
            MATCH (m:GraphMetadata {graph_id: $graph_id})
            RETURN m.ontology_json AS ontology_json
            """,
            graph_id=graph_id,
            routing_="r",
        )

        if not records:
            return _OntologyBundle()

        ontology_json = records[0].get("ontology_json")
        if not ontology_json:
            return _OntologyBundle()

        try:
            spec = json.loads(ontology_json)
        except (TypeError, ValueError):  # json.JSONDecodeError is a ValueError
            logger.warning("Graphiti ontology metadata 解析失败,graph_id=%s", graph_id)
            return _OntologyBundle()

        if not isinstance(spec, dict):
            # Valid JSON but not an object; ``_bundle_from_spec`` needs a dict.
            logger.warning("Graphiti ontology metadata 解析失败,graph_id=%s", graph_id)
            return _OntologyBundle()

        return self._bundle_from_spec(spec)

    async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
        """Return the cached ontology bundle, loading and caching it on a miss.

        Whatever the load returns is cached, including an empty bundle.
        """
        cls = self.__class__
        with cls._ontology_lock:
            cached = cls._ontology_registry.get(graph_id)
        if cached is not None:
            return cached

        # The lock is not held while awaiting the database.
        bundle = await self._load_ontology_bundle_async(graph_id)
        with cls._ontology_lock:
            cls._ontology_registry[graph_id] = bundle
        return bundle

    def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle:
        """Synchronous wrapper around ``_get_ontology_bundle_async``."""
        return self._run(self._get_ontology_bundle_async(graph_id))

    def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None:
        """Store ``bundle`` in the in-process ontology cache."""
        with self.__class__._ontology_lock:
            self.__class__._ontology_registry[graph_id] = bundle

    def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """Return a shallow copy of the graph's ontology spec, or None if empty."""
        self._validate_graph_id(graph_id)
        bundle = self._get_ontology_bundle(graph_id)
        return dict(bundle.spec) if bundle.spec else None

    def create_graph(self, graph_id: str, name: str, description: str) -> None:
        """Validate ``graph_id`` and upsert its metadata node."""
        self._validate_graph_id(graph_id)
        self._run(
            self._upsert_graph_metadata_async(
                graph_id=graph_id,
                name=name,
                description=description,
            )
        )

    def set_ontology(
        self,
        graph_id: str,
        entities: Any = None,
        edges: Any = None,
    ) -> None:
        """Build the ontology from ``entities`` / ``edges``, persist and cache it."""
        self._validate_graph_id(graph_id)
        bundle = self._build_ontology_bundle(entities=entities, edges=edges)
        # Persist first: if the write fails, the cache must not hold an
        # ontology that the database does not have.
        self._run(
            self._upsert_graph_metadata_async(
                graph_id=graph_id,
                ontology_spec=bundle.spec,
            )
        )
        self._set_ontology_bundle(graph_id, bundle)

    async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode:
        """Ingest ``data`` as one text episode of ``graph_id``.

        Steps, in order: load the ontology, fetch previous episodes, extract
        and resolve nodes, extract and resolve edges, extract node
        attributes, then save everything through Graphiti.

        NOTE(review): ``_extract_and_resolve_edges`` and
        ``_process_episode_data`` are private Graphiti methods called with
        positional arguments — confirm their signatures against the
        installed graphiti-core version.
        """
        from graphiti_core.helpers import validate_excluded_entity_types
        from graphiti_core.nodes import EpisodeType, EpisodicNode
        from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT
        from graphiti_core.utils.datetime_utils import utc_now
        from graphiti_core.utils.maintenance.node_operations import (
            extract_attributes_from_nodes,
            extract_nodes,
            resolve_extracted_nodes,
        )
        from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types

        bundle = await self._get_ontology_bundle_async(graph_id)
        # Empty mappings are passed to Graphiti as None.
        entity_types = bundle.entity_types or None
        edge_types = bundle.edge_types or None
        edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []}

        validate_entity_types(entity_types)
        validate_excluded_entity_types(None, entity_types)

        now = utc_now()
        previous_episodes = await self._graphiti.retrieve_episodes(
            reference_time=now,
            last_n=RELEVANT_SCHEMA_LIMIT,
            group_ids=[graph_id],
            source=EpisodeType.text,
            driver=self._driver,
        )

        episode = EpisodicNode(
            name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}",
            group_id=graph_id,
            labels=[],
            source=EpisodeType.text,
            content=data,
            source_description="text",
            created_at=now,
            valid_at=now,
        )

        extracted_nodes = await extract_nodes(
            self._graphiti.clients,
            episode,
            previous_episodes,
            entity_types,
            None,
            None,
        )
        nodes, uuid_map, _ = await resolve_extracted_nodes(
            self._graphiti.clients,
            extracted_nodes,
            episode,
            previous_episodes,
            entity_types,
        )
        resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges(
            episode,
            extracted_nodes,
            previous_episodes,
            edge_type_map,
            graph_id,
            edge_types,
            nodes,
            uuid_map,
            None,
        )
        entity_edges = resolved_edges + invalidated_edges
        hydrated_nodes = await extract_attributes_from_nodes(
            self._graphiti.clients,
            nodes,
            episode,
            previous_episodes,
            entity_types,
            edges=new_edges,
        )
        _, saved_episode = await self._graphiti._process_episode_data(
            episode,
            hydrated_nodes,
            entity_edges,
            now,
            graph_id,
            None,
            None,
        )

        return _CompatEpisode(
            uuid=saved_episode.uuid,
            processed=True,
            name=saved_episode.name,
            content=saved_episode.content,
            valid_at=saved_episode.valid_at,
            created_at=saved_episode.created_at,
        )

    def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
        """Add each episode's text one by one and return the results in order.

        The text is read from a ``data`` attribute, or from the ``"data"``
        key when the episode is a dict. Missing or falsy data becomes ``""``.
        """
        results = []
        for episode in episodes:
            data = getattr(episode, "data", None)
            if data is None and isinstance(episode, dict):
                data = episode.get("data", "")
            results.append(self.add_text(graph_id=graph_id, data=str(data or "")))
        return results

    def add_text(self, graph_id: str, data: str) -> Any:
        """Validate ``graph_id`` and ingest ``data`` as one text episode."""
        self._validate_graph_id(graph_id)
        return self._run(self._add_text_async(graph_id=graph_id, data=data))

    async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode:
        """Fetch an episode by UUID and wrap it as ``_CompatEpisode``."""
        from graphiti_core.nodes import EpisodicNode

        episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid)
        return _CompatEpisode(
            uuid=episode.uuid,
            processed=True,
            name=episode.name,
            content=episode.content,
            valid_at=episode.valid_at,
            created_at=episode.created_at,
        )

    def get_episode(self, episode_uuid: str) -> Any:
        """Synchronous wrapper around ``_get_episode_async``."""
        return self._run(self._get_episode_async(episode_uuid))

    def _warn_cross_encoder_fallback(self) -> None:
        """Log the cross-encoder -> rrf downgrade notice once per process."""
        if self.__class__._cross_encoder_warning_emitted:
            return
        logger.info(
            "Graphiti cross_encoder 默认已降级为 rrf;如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true"
        )
        self.__class__._cross_encoder_warning_emitted = True

    def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]):
        """Translate ``scope`` / ``limit`` / ``reranker`` into a ``SearchConfig``.

        * Unknown or missing reranker names use ``rrf``.
        * ``cross_encoder`` is honored only when
          ``Config.GRAPHITI_ENABLE_CROSS_ENCODER`` is truthy (BFS is then
          added to the search methods); otherwise it is downgraded to ``rrf``.
        * ``scope`` ``"edges"`` / ``"nodes"`` / ``"both"`` decides which
          sub-configs are built; any other value builds neither.
        * ``limit`` is clamped to at least 1.
        """
        from graphiti_core.search.search_config import (
            EdgeReranker,
            EdgeSearchConfig,
            EdgeSearchMethod,
            NodeReranker,
            NodeSearchConfig,
            NodeSearchMethod,
            SearchConfig,
        )

        reranker_name = (reranker or "rrf").strip().lower()
        edge_reranker_map = {
            "rrf": EdgeReranker.rrf,
            "reciprocal_rank_fusion": EdgeReranker.rrf,
            "cross_encoder": EdgeReranker.cross_encoder,
            "node_distance": EdgeReranker.node_distance,
            "episode_mentions": EdgeReranker.episode_mentions,
            "mmr": EdgeReranker.mmr,
        }
        node_reranker_map = {
            "rrf": NodeReranker.rrf,
            "reciprocal_rank_fusion": NodeReranker.rrf,
            "cross_encoder": NodeReranker.cross_encoder,
            "node_distance": NodeReranker.node_distance,
            "episode_mentions": NodeReranker.episode_mentions,
            "mmr": NodeReranker.mmr,
        }

        edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf)
        node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf)
        edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity]
        node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity]

        if reranker_name == "cross_encoder":
            if Config.GRAPHITI_ENABLE_CROSS_ENCODER:
                edge_methods.append(EdgeSearchMethod.bfs)
                node_methods.append(NodeSearchMethod.bfs)
            else:
                self._warn_cross_encoder_fallback()
                edge_reranker = EdgeReranker.rrf
                node_reranker = NodeReranker.rrf

        edge_config = None
        node_config = None
        if scope in {"edges", "both"}:
            edge_config = EdgeSearchConfig(
                search_methods=edge_methods,
                reranker=edge_reranker,
            )
        if scope in {"nodes", "both"}:
            node_config = NodeSearchConfig(
                search_methods=node_methods,
                reranker=node_reranker,
            )

        return SearchConfig(
            edge_config=edge_config,
            node_config=node_config,
            limit=max(1, limit),
        )

    def _wrap_node(self, node: Any) -> _CompatNode:
        """Copy a node's attributes into a ``_CompatNode``.

        Missing or falsy attributes fall back to empty defaults.
        """
        return _CompatNode(
            uuid=getattr(node, "uuid", ""),
            name=getattr(node, "name", "") or "",
            labels=list(getattr(node, "labels", []) or []),
            summary=getattr(node, "summary", "") or "",
            attributes=dict(getattr(node, "attributes", {}) or {}),
            created_at=getattr(node, "created_at", None),
        )

    def _wrap_edge(
        self,
        edge: Any,
        source_node_name: str = "",
        target_node_name: str = "",
    ) -> _CompatEdge:
        """Copy an edge's attributes into a ``_CompatEdge``.

        Endpoint names are supplied by the caller. Missing or falsy
        attributes fall back to empty defaults.
        """
        return _CompatEdge(
            uuid=getattr(edge, "uuid", ""),
            name=getattr(edge, "name", "") or "",
            fact=getattr(edge, "fact", "") or "",
            source_node_uuid=getattr(edge, "source_node_uuid", "") or "",
            target_node_uuid=getattr(edge, "target_node_uuid", "") or "",
            source_node_name=source_node_name,
            target_node_name=target_node_name,
            attributes=dict(getattr(edge, "attributes", {}) or {}),
            episodes=list(getattr(edge, "episodes", []) or []),
            created_at=getattr(edge, "created_at", None),
            valid_at=getattr(edge, "valid_at", None),
            invalid_at=getattr(edge, "invalid_at", None),
            expired_at=getattr(edge, "expired_at", None),
        )

    async def _search_async(
        self,
        graph_id: str,
        query: str,
        limit: int,
        scope: str,
        reranker: Optional[str],
    ) -> _CompatSearchResults:
        """Search one graph and return wrapped edges and nodes.

        Edge endpoint names come from the returned nodes; endpoints missing
        from the result set are fetched in one extra lookup.
        """
        from graphiti_core.nodes import EntityNode

        search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker)
        results = await self._graphiti.search_(
            query=query,
            config=search_config,
            group_ids=[graph_id],
            driver=self._driver,
        )

        nodes = [self._wrap_node(node) for node in results.nodes]
        node_name_map = {node.uuid: node.name for node in nodes if node.uuid}

        missing_node_ids = {
            node_uuid
            for edge in results.edges
            for node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
            if node_uuid and node_uuid not in node_name_map
        }
        if missing_node_ids:
            for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)):
                node_name_map[node.uuid] = node.name or ""

        edges = [
            self._wrap_edge(
                edge,
                source_node_name=node_name_map.get(edge.source_node_uuid, ""),
                target_node_name=node_name_map.get(edge.target_node_uuid, ""),
            )
            for edge in results.edges
        ]

        return _CompatSearchResults(edges=edges, nodes=nodes)

    def search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges",
        reranker: Optional[str] = None,
    ) -> Any:
        """Validate ``graph_id`` and run a search; see ``_build_search_config``."""
        self._validate_graph_id(graph_id)
        return self._run(
            self._search_async(
                graph_id=graph_id,
                query=query,
                limit=limit,
                scope=scope,
                reranker=reranker,
            )
        )

    async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]:
        """Return every entity node of ``graph_id``, fetched in pages.

        Paging stops at an empty or short page; the next page starts after
        the UUID of the last node received.
        """
        from graphiti_core.nodes import EntityNode

        result = []
        cursor = None
        while True:
            batch = await EntityNode.get_by_group_ids(
                self._driver,
                [graph_id],
                limit=self.PAGE_SIZE,
                uuid_cursor=cursor,
            )
            if not batch:
                break
            result.extend(self._wrap_node(node) for node in batch)
            if len(batch) < self.PAGE_SIZE:
                break
            cursor = batch[-1].uuid

        return result

    def get_all_nodes(self, graph_id: str) -> List[Any]:
        """Validate ``graph_id`` and return all of its entity nodes."""
        self._validate_graph_id(graph_id)
        return self._run(self._get_all_nodes_async(graph_id))

    async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]:
        """Return every entity edge of ``graph_id``, fetched in pages.

        ``GroupsEdgesNotFoundError`` ends paging, as does an empty or short
        page. Endpoint names are not resolved here (left as ``""``).
        """
        from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError

        result = []
        cursor = None
        while True:
            try:
                batch = await EntityEdge.get_by_group_ids(
                    self._driver,
                    [graph_id],
                    limit=self.PAGE_SIZE,
                    uuid_cursor=cursor,
                )
            except GroupsEdgesNotFoundError:
                break

            if not batch:
                break

            result.extend(self._wrap_edge(edge) for edge in batch)
            if len(batch) < self.PAGE_SIZE:
                break
            cursor = batch[-1].uuid

        return result

    def get_all_edges(self, graph_id: str) -> List[Any]:
        """Validate ``graph_id`` and return all of its entity edges."""
        self._validate_graph_id(graph_id)
        return self._run(self._get_all_edges_async(graph_id))

    async def _get_node_async(self, node_uuid: str) -> _CompatNode:
        """Fetch one entity node by UUID and wrap it."""
        from graphiti_core.nodes import EntityNode

        return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid))

    def get_node(self, node_uuid: str) -> Any:
        """Synchronous wrapper around ``_get_node_async``."""
        return self._run(self._get_node_async(node_uuid))

    async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]:
        """Return the edges attached to ``node_uuid`` with endpoint names filled in."""
        from graphiti_core.edges import EntityEdge
        from graphiti_core.nodes import EntityNode

        edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid)
        related_node_ids = {
            related_uuid
            for edge in edges
            for related_uuid in (edge.source_node_uuid, edge.target_node_uuid)
            if related_uuid
        }
        node_name_map = {}
        if related_node_ids:
            for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)):
                node_name_map[node.uuid] = node.name or ""

        return [
            self._wrap_edge(
                edge,
                source_node_name=node_name_map.get(edge.source_node_uuid, ""),
                target_node_name=node_name_map.get(edge.target_node_uuid, ""),
            )
            for edge in edges
        ]

    def get_node_edges(self, node_uuid: str) -> List[Any]:
        """Synchronous wrapper around ``_get_node_edges_async``."""
        return self._run(self._get_node_edges_async(node_uuid))

    async def _delete_graph_async(self, graph_id: str) -> None:
        """Delete the graph's ``Saga`` nodes, its Graphiti nodes, then its metadata node."""
        from graphiti_core.nodes import Node

        await self._driver.execute_query(
            """
            MATCH (s:Saga {group_id: $graph_id})
            DETACH DELETE s
            """,
            graph_id=graph_id,
        )
        await Node.delete_by_group_id(self._driver, graph_id)
        await self._driver.execute_query(
            """
            MATCH (m:GraphMetadata {graph_id: $graph_id})
            DETACH DELETE m
            """,
            graph_id=graph_id,
        )

    def delete_graph(self, graph_id: str) -> None:
        """Validate ``graph_id``, delete the graph, and evict its cached ontology."""
        self._validate_graph_id(graph_id)
        self._run(self._delete_graph_async(graph_id))
        with self.__class__._ontology_lock:
            self.__class__._ontology_registry.pop(graph_id, None)

    async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]:
        """Count ``Entity`` nodes, ``RELATES_TO`` edges and ``Episodic`` nodes.

        Returns a dict with ``node_count``, ``edge_count`` and
        ``episode_count``. A missing record or null count is reported as 0.
        """

        def _first_count(records: Any, key: str) -> int:
            # Read ``key`` from the first record, treating absence/null as 0.
            first = records[0] if records else {}
            return int(first.get(key, 0) or 0)

        node_records, _, _ = await self._driver.execute_query(
            """
            MATCH (n:Entity {group_id: $graph_id})
            RETURN count(n) AS node_count
            """,
            graph_id=graph_id,
            routing_="r",
        )
        edge_records, _, _ = await self._driver.execute_query(
            """
            MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->()
            RETURN count(e) AS edge_count
            """,
            graph_id=graph_id,
            routing_="r",
        )
        episode_records, _, _ = await self._driver.execute_query(
            """
            MATCH (n:Episodic {group_id: $graph_id})
            RETURN count(n) AS episode_count
            """,
            graph_id=graph_id,
            routing_="r",
        )

        return {
            "node_count": _first_count(node_records, "node_count"),
            "edge_count": _first_count(edge_records, "edge_count"),
            "episode_count": _first_count(episode_records, "episode_count"),
        }

    def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
        """Validate ``graph_id`` and return its live node/edge/episode counts."""
        self._validate_graph_id(graph_id)
        return self._run(self._get_live_graph_statistics_async(graph_id))
|