MicroFish/backend/app/graph/graphiti_backend.py

876 lines
30 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Graphiti + Neo4j graph backend implementation.
"""
from __future__ import annotations
import asyncio
import json
import logging
import threading
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, create_model
from ..config import Config
from .base import GraphBackend
logger = logging.getLogger(__name__)
@dataclass
class _CompatEpisode:
uuid: str
processed: bool = True
name: str = ""
content: str = ""
valid_at: Optional[datetime] = None
created_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatNode:
uuid: str
name: str = ""
labels: List[str] = field(default_factory=list)
summary: str = ""
attributes: Dict[str, Any] = field(default_factory=dict)
created_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatEdge:
uuid: str
name: str
fact: str
source_node_uuid: str
target_node_uuid: str
source_node_name: str = ""
target_node_name: str = ""
attributes: Dict[str, Any] = field(default_factory=dict)
episodes: List[str] = field(default_factory=list)
created_at: Optional[datetime] = None
valid_at: Optional[datetime] = None
invalid_at: Optional[datetime] = None
expired_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatSearchResults:
edges: List[_CompatEdge] = field(default_factory=list)
nodes: List[_CompatNode] = field(default_factory=list)
@dataclass
class _OntologyBundle:
entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict)
spec: Dict[str, Any] = field(default_factory=dict)
class _AsyncBridge:
"""Run async Graphiti calls from the app's synchronous service layer."""
def __init__(self):
self._ready = threading.Event()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._ready.wait()
def _run_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._loop = loop
self._ready.set()
loop.run_forever()
def run(self, coro):
if self._loop is None:
raise RuntimeError("Graphiti async loop 未初始化")
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()
class GraphitiBackend(GraphBackend):
"""Graph backend backed by Graphiti OSS + Neo4j."""
_bridge: Optional[_AsyncBridge] = None
_bridge_lock = threading.Lock()
_indices_ready = False
_indices_lock = threading.Lock()
_ontology_registry: Dict[str, _OntologyBundle] = {}
_ontology_lock = threading.Lock()
_cross_encoder_warning_emitted = False
PAGE_SIZE = 200
def __init__(self, api_key: Optional[str] = None):
del api_key
errors = Config.get_graphiti_config_errors()
if errors:
raise ValueError("; ".join(errors))
try:
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
except ImportError as exc:
raise ImportError(
"Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j"
) from exc
llm_config = LLMConfig(
api_key=Config.GRAPHITI_LLM_API_KEY,
base_url=Config.GRAPHITI_LLM_BASE_URL,
model=Config.GRAPHITI_LLM_MODEL,
small_model=Config.GRAPHITI_LLM_SMALL_MODEL,
temperature=0,
)
reranker_config = LLMConfig(
api_key=Config.GRAPHITI_RERANKER_API_KEY,
base_url=Config.GRAPHITI_RERANKER_BASE_URL,
model=Config.GRAPHITI_RERANKER_MODEL,
temperature=0,
)
embedder_config = OpenAIEmbedderConfig(
api_key=Config.GRAPHITI_EMBEDDER_API_KEY,
base_url=Config.GRAPHITI_EMBEDDER_BASE_URL,
embedding_model=Config.GRAPHITI_EMBEDDER_MODEL,
embedding_dim=Config.GRAPHITI_EMBEDDER_DIM,
)
llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower()
if llm_client_mode == "generic":
llm_client = OpenAIGenericClient(
config=llm_config,
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
)
else:
llm_client = OpenAIClient(
config=llm_config,
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
)
self._graphiti = Graphiti(
uri=Config.GRAPHITI_URI,
user=Config.GRAPHITI_USER,
password=Config.GRAPHITI_PASSWORD,
llm_client=llm_client,
embedder=OpenAIEmbedder(config=embedder_config),
cross_encoder=OpenAIRerankerClient(config=reranker_config),
max_coroutines=Config.GRAPHITI_MAX_COROUTINES,
)
self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE)
self._graphiti.driver = self._driver
self._graphiti.clients.driver = self._driver
self._bridge = self._get_bridge()
self._ensure_indices()
@classmethod
def _get_bridge(cls) -> _AsyncBridge:
with cls._bridge_lock:
if cls._bridge is None:
cls._bridge = _AsyncBridge()
return cls._bridge
@property
def raw_client(self) -> Any:
return self._graphiti
def _run(self, coro):
return self._bridge.run(coro)
def _ensure_indices(self) -> None:
if self.__class__._indices_ready:
return
with self.__class__._indices_lock:
if self.__class__._indices_ready:
return
self._run(self._graphiti.build_indices_and_constraints())
self.__class__._indices_ready = True
def _validate_graph_id(self, graph_id: str) -> None:
from graphiti_core.helpers import validate_group_id
validate_group_id(graph_id)
def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]:
fields = []
for field_name, model_field in model.model_fields.items():
fields.append(
{
"name": field_name,
"description": model_field.description or field_name,
}
)
return {
"description": (getattr(model, "__doc__", "") or "").strip(),
"fields": fields,
}
def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]:
field_definitions = {}
for field_spec in spec.get("fields", []):
field_name = field_spec.get("name", "").strip()
if not field_name:
continue
field_definitions[field_name] = (
Optional[str],
Field(
default=None,
description=field_spec.get("description") or field_name,
),
)
model = create_model(name, __base__=BaseModel, **field_definitions)
model.__doc__ = spec.get("description") or name
return model
def _serialize_ontology_spec(
self,
entity_specs: Dict[str, Dict[str, Any]],
edge_specs: Dict[str, Dict[str, Any]],
edge_type_map: Dict[tuple[str, str], List[str]],
) -> Dict[str, Any]:
return {
"entity_types": entity_specs,
"edge_types": edge_specs,
"edge_type_map": [
{
"source": source,
"target": target,
"edges": edge_names,
}
for (source, target), edge_names in sorted(edge_type_map.items())
],
}
def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle:
entity_types = {
name: self._build_dynamic_model(name, model_spec)
for name, model_spec in (spec.get("entity_types") or {}).items()
}
edge_types = {
name: self._build_dynamic_model(name, model_spec)
for name, model_spec in (spec.get("edge_types") or {}).items()
}
edge_type_map: Dict[tuple[str, str], List[str]] = {}
for entry in spec.get("edge_type_map") or []:
source = entry.get("source", "Entity")
target = entry.get("target", "Entity")
edge_type_map[(source, target)] = list(entry.get("edges") or [])
return _OntologyBundle(
entity_types=entity_types,
edge_types=edge_types,
edge_type_map=edge_type_map,
spec=spec,
)
def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle:
entity_specs = {}
entity_types = {}
for entity_name, entity_model in (entities or {}).items():
entity_spec = self._normalize_model_spec(entity_model)
entity_specs[entity_name] = entity_spec
entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec)
edge_specs = {}
edge_types = {}
edge_type_map: Dict[tuple[str, str], List[str]] = {}
for edge_name, edge_value in (edges or {}).items():
if not isinstance(edge_value, tuple) or len(edge_value) != 2:
continue
edge_model, source_targets = edge_value
edge_spec = self._normalize_model_spec(edge_model)
edge_specs[edge_name] = edge_spec
edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec)
for source_target in source_targets or []:
source = getattr(source_target, "source", "Entity") or "Entity"
target = getattr(source_target, "target", "Entity") or "Entity"
edge_type_map.setdefault((source, target), []).append(edge_name)
if not edge_type_map:
edge_type_map = {("Entity", "Entity"): list(edge_types.keys())}
spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map)
return _OntologyBundle(
entity_types=entity_types,
edge_types=edge_types,
edge_type_map=edge_type_map,
spec=spec,
)
async def _upsert_graph_metadata_async(
self,
graph_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
ontology_spec: Optional[Dict[str, Any]] = None,
) -> None:
payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None
await self._driver.execute_query(
"""
MERGE (m:GraphMetadata {graph_id: $graph_id})
ON CREATE SET
m.group_id = $graph_id,
m.created_at = datetime()
SET
m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END,
m.description = CASE
WHEN $description IS NULL OR $description = ''
THEN coalesce(m.description, '')
ELSE $description
END,
m.ontology_json = CASE
WHEN $ontology_json IS NULL OR $ontology_json = ''
THEN m.ontology_json
ELSE $ontology_json
END,
m.updated_at = datetime()
""",
graph_id=graph_id,
name=name,
description=description,
ontology_json=payload,
)
async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
records, _, _ = await self._driver.execute_query(
"""
MATCH (m:GraphMetadata {graph_id: $graph_id})
RETURN m.ontology_json AS ontology_json
""",
graph_id=graph_id,
routing_="r",
)
if not records:
return _OntologyBundle()
ontology_json = records[0].get("ontology_json")
if not ontology_json:
return _OntologyBundle()
try:
spec = json.loads(ontology_json)
except (TypeError, ValueError, json.JSONDecodeError):
logger.warning("Graphiti ontology metadata 解析失败graph_id=%s", graph_id)
return _OntologyBundle()
return self._bundle_from_spec(spec)
async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
with self.__class__._ontology_lock:
cached = self.__class__._ontology_registry.get(graph_id)
if cached is not None:
return cached
bundle = await self._load_ontology_bundle_async(graph_id)
with self.__class__._ontology_lock:
self.__class__._ontology_registry[graph_id] = bundle
return bundle
def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle:
return self._run(self._get_ontology_bundle_async(graph_id))
def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None:
with self.__class__._ontology_lock:
self.__class__._ontology_registry[graph_id] = bundle
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
self._validate_graph_id(graph_id)
bundle = self._get_ontology_bundle(graph_id)
return dict(bundle.spec) if bundle.spec else None
def create_graph(self, graph_id: str, name: str, description: str) -> None:
self._validate_graph_id(graph_id)
self._run(
self._upsert_graph_metadata_async(
graph_id=graph_id,
name=name,
description=description,
)
)
def set_ontology(
self,
graph_id: str,
entities: Any = None,
edges: Any = None,
) -> None:
self._validate_graph_id(graph_id)
bundle = self._build_ontology_bundle(entities=entities, edges=edges)
self._set_ontology_bundle(graph_id, bundle)
self._run(
self._upsert_graph_metadata_async(
graph_id=graph_id,
ontology_spec=bundle.spec,
)
)
async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode:
from graphiti_core.helpers import validate_excluded_entity_types
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.node_operations import (
extract_attributes_from_nodes,
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
bundle = await self._get_ontology_bundle_async(graph_id)
entity_types = bundle.entity_types or None
edge_types = bundle.edge_types or None
edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []}
validate_entity_types(entity_types)
validate_excluded_entity_types(None, entity_types)
now = utc_now()
previous_episodes = await self._graphiti.retrieve_episodes(
reference_time=now,
last_n=RELEVANT_SCHEMA_LIMIT,
group_ids=[graph_id],
source=EpisodeType.text,
driver=self._driver,
)
episode = EpisodicNode(
name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}",
group_id=graph_id,
labels=[],
source=EpisodeType.text,
content=data,
source_description="text",
created_at=now,
valid_at=now,
)
extracted_nodes = await extract_nodes(
self._graphiti.clients,
episode,
previous_episodes,
entity_types,
None,
None,
)
nodes, uuid_map, _ = await resolve_extracted_nodes(
self._graphiti.clients,
extracted_nodes,
episode,
previous_episodes,
entity_types,
)
resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges(
episode,
extracted_nodes,
previous_episodes,
edge_type_map,
graph_id,
edge_types,
nodes,
uuid_map,
None,
)
entity_edges = resolved_edges + invalidated_edges
hydrated_nodes = await extract_attributes_from_nodes(
self._graphiti.clients,
nodes,
episode,
previous_episodes,
entity_types,
edges=new_edges,
)
_, saved_episode = await self._graphiti._process_episode_data(
episode,
hydrated_nodes,
entity_edges,
now,
graph_id,
None,
None,
)
return _CompatEpisode(
uuid=saved_episode.uuid,
processed=True,
name=saved_episode.name,
content=saved_episode.content,
valid_at=saved_episode.valid_at,
created_at=saved_episode.created_at,
)
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
results = []
for episode in episodes:
data = getattr(episode, "data", None)
if data is None and isinstance(episode, dict):
data = episode.get("data", "")
results.append(self.add_text(graph_id=graph_id, data=str(data or "")))
return results
def add_text(self, graph_id: str, data: str) -> Any:
self._validate_graph_id(graph_id)
return self._run(self._add_text_async(graph_id=graph_id, data=data))
async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode:
from graphiti_core.nodes import EpisodicNode
episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid)
return _CompatEpisode(
uuid=episode.uuid,
processed=True,
name=episode.name,
content=episode.content,
valid_at=episode.valid_at,
created_at=episode.created_at,
)
def get_episode(self, episode_uuid: str) -> Any:
return self._run(self._get_episode_async(episode_uuid))
def _warn_cross_encoder_fallback(self) -> None:
if self.__class__._cross_encoder_warning_emitted:
return
logger.info(
"Graphiti cross_encoder 默认已降级为 rrf如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true"
)
self.__class__._cross_encoder_warning_emitted = True
def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]):
from graphiti_core.search.search_config import (
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
)
reranker_name = (reranker or "rrf").strip().lower()
edge_reranker_map = {
"rrf": EdgeReranker.rrf,
"reciprocal_rank_fusion": EdgeReranker.rrf,
"cross_encoder": EdgeReranker.cross_encoder,
"node_distance": EdgeReranker.node_distance,
"episode_mentions": EdgeReranker.episode_mentions,
"mmr": EdgeReranker.mmr,
}
node_reranker_map = {
"rrf": NodeReranker.rrf,
"reciprocal_rank_fusion": NodeReranker.rrf,
"cross_encoder": NodeReranker.cross_encoder,
"node_distance": NodeReranker.node_distance,
"episode_mentions": NodeReranker.episode_mentions,
"mmr": NodeReranker.mmr,
}
edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf)
node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf)
edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity]
node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity]
if reranker_name == "cross_encoder":
if Config.GRAPHITI_ENABLE_CROSS_ENCODER:
edge_methods.append(EdgeSearchMethod.bfs)
node_methods.append(NodeSearchMethod.bfs)
else:
self._warn_cross_encoder_fallback()
edge_reranker = EdgeReranker.rrf
node_reranker = NodeReranker.rrf
edge_config = None
node_config = None
if scope in {"edges", "both"}:
edge_config = EdgeSearchConfig(
search_methods=edge_methods,
reranker=edge_reranker,
)
if scope in {"nodes", "both"}:
node_config = NodeSearchConfig(
search_methods=node_methods,
reranker=node_reranker,
)
return SearchConfig(
edge_config=edge_config,
node_config=node_config,
limit=max(1, limit),
)
def _wrap_node(self, node: Any) -> _CompatNode:
return _CompatNode(
uuid=getattr(node, "uuid", ""),
name=getattr(node, "name", "") or "",
labels=list(getattr(node, "labels", []) or []),
summary=getattr(node, "summary", "") or "",
attributes=dict(getattr(node, "attributes", {}) or {}),
created_at=getattr(node, "created_at", None),
)
def _wrap_edge(
self,
edge: Any,
source_node_name: str = "",
target_node_name: str = "",
) -> _CompatEdge:
return _CompatEdge(
uuid=getattr(edge, "uuid", ""),
name=getattr(edge, "name", "") or "",
fact=getattr(edge, "fact", "") or "",
source_node_uuid=getattr(edge, "source_node_uuid", "") or "",
target_node_uuid=getattr(edge, "target_node_uuid", "") or "",
source_node_name=source_node_name,
target_node_name=target_node_name,
attributes=dict(getattr(edge, "attributes", {}) or {}),
episodes=list(getattr(edge, "episodes", []) or []),
created_at=getattr(edge, "created_at", None),
valid_at=getattr(edge, "valid_at", None),
invalid_at=getattr(edge, "invalid_at", None),
expired_at=getattr(edge, "expired_at", None),
)
async def _search_async(
self,
graph_id: str,
query: str,
limit: int,
scope: str,
reranker: Optional[str],
) -> _CompatSearchResults:
from graphiti_core.nodes import EntityNode
search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker)
results = await self._graphiti.search_(
query=query,
config=search_config,
group_ids=[graph_id],
driver=self._driver,
)
nodes = [self._wrap_node(node) for node in results.nodes]
node_name_map = {node.uuid: node.name for node in nodes if node.uuid}
missing_node_ids = {
node_uuid
for edge in results.edges
for node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
if node_uuid and node_uuid not in node_name_map
}
if missing_node_ids:
for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)):
node_name_map[node.uuid] = node.name or ""
edges = [
self._wrap_edge(
edge,
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
)
for edge in results.edges
]
return _CompatSearchResults(edges=edges, nodes=nodes)
def search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges",
reranker: Optional[str] = None,
) -> Any:
self._validate_graph_id(graph_id)
return self._run(
self._search_async(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
reranker=reranker,
)
)
async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]:
from graphiti_core.nodes import EntityNode
result = []
cursor = None
while True:
batch = await EntityNode.get_by_group_ids(
self._driver,
[graph_id],
limit=self.PAGE_SIZE,
uuid_cursor=cursor,
)
if not batch:
break
result.extend(self._wrap_node(node) for node in batch)
if len(batch) < self.PAGE_SIZE:
break
cursor = batch[-1].uuid
return result
def get_all_nodes(self, graph_id: str) -> List[Any]:
self._validate_graph_id(graph_id)
return self._run(self._get_all_nodes_async(graph_id))
async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]:
from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError
result = []
cursor = None
while True:
try:
batch = await EntityEdge.get_by_group_ids(
self._driver,
[graph_id],
limit=self.PAGE_SIZE,
uuid_cursor=cursor,
)
except GroupsEdgesNotFoundError:
break
if not batch:
break
result.extend(self._wrap_edge(edge) for edge in batch)
if len(batch) < self.PAGE_SIZE:
break
cursor = batch[-1].uuid
return result
def get_all_edges(self, graph_id: str) -> List[Any]:
self._validate_graph_id(graph_id)
return self._run(self._get_all_edges_async(graph_id))
async def _get_node_async(self, node_uuid: str) -> _CompatNode:
from graphiti_core.nodes import EntityNode
return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid))
def get_node(self, node_uuid: str) -> Any:
return self._run(self._get_node_async(node_uuid))
async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]:
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode
edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid)
related_node_ids = {
related_uuid
for edge in edges
for related_uuid in (edge.source_node_uuid, edge.target_node_uuid)
if related_uuid
}
node_name_map = {}
if related_node_ids:
for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)):
node_name_map[node.uuid] = node.name or ""
return [
self._wrap_edge(
edge,
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
)
for edge in edges
]
def get_node_edges(self, node_uuid: str) -> List[Any]:
return self._run(self._get_node_edges_async(node_uuid))
async def _delete_graph_async(self, graph_id: str) -> None:
from graphiti_core.nodes import Node
await self._driver.execute_query(
"""
MATCH (s:Saga {group_id: $graph_id})
DETACH DELETE s
""",
graph_id=graph_id,
)
await Node.delete_by_group_id(self._driver, graph_id)
await self._driver.execute_query(
"""
MATCH (m:GraphMetadata {graph_id: $graph_id})
DETACH DELETE m
""",
graph_id=graph_id,
)
def delete_graph(self, graph_id: str) -> None:
self._validate_graph_id(graph_id)
self._run(self._delete_graph_async(graph_id))
with self.__class__._ontology_lock:
self.__class__._ontology_registry.pop(graph_id, None)
async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]:
node_records, _, _ = await self._driver.execute_query(
"""
MATCH (n:Entity {group_id: $graph_id})
RETURN count(n) AS node_count
""",
graph_id=graph_id,
routing_="r",
)
edge_records, _, _ = await self._driver.execute_query(
"""
MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->()
RETURN count(e) AS edge_count
""",
graph_id=graph_id,
routing_="r",
)
episode_records, _, _ = await self._driver.execute_query(
"""
MATCH (n:Episodic {group_id: $graph_id})
RETURN count(n) AS episode_count
""",
graph_id=graph_id,
routing_="r",
)
return {
"node_count": int((node_records[0] if node_records else {}).get("node_count", 0) or 0),
"edge_count": int((edge_records[0] if edge_records else {}).get("edge_count", 0) or 0),
"episode_count": int(
(episode_records[0] if episode_records else {}).get("episode_count", 0) or 0
),
}
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
self._validate_graph_id(graph_id)
return self._run(self._get_live_graph_statistics_async(graph_id))