feat(graph): add pluggable graph backend with Graphiti support

This commit is contained in:
MiroFish Bot 2026-03-21 17:25:34 +09:00
parent 1536a79334
commit 25d43f8a4b
20 changed files with 2834 additions and 408 deletions

View File

@ -1,16 +1,78 @@
# LLM API配置支持 OpenAI SDK 格式的任意 LLM API
# 推荐使用阿里百炼平台qwen-plus模型https://bailian.console.aliyun.com/
# 注意消耗较大可先进行小于40轮的模拟尝试
# LLM API 配置(支持 OpenAI SDK 格式的任意 LLM API
# 可直接填写你自己的 LLM 接口,例如:
# LLM_BASE_URL=http://127.0.0.1:18081/v1
# LLM_MODEL_NAME=gpt-5.4
LLM_API_KEY=your_api_key_here
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
LLM_MODEL_NAME=qwen-plus
# ===== ZEP记忆图谱配置 =====
# 每月免费额度即可支撑简单使用https://app.getzep.com/
ZEP_API_KEY=your_zep_api_key_here
# Docker 容器内访问宿主机 LLM 时使用的地址
# Linux + Docker Compose 下可保持默认 host.docker.internal
DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
# ===== Graphiti + Neo4j默认推荐=====
GRAPH_BACKEND=graphiti
GRAPHITI_URI=bolt://localhost:7687
GRAPHITI_USER=neo4j
GRAPHITI_PASSWORD=password123
GRAPHITI_DATABASE=neo4j
GRAPHITI_LLM_CLIENT_MODE=openai
GRAPHITI_EMBEDDER_API_KEY=ollama
GRAPHITI_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
GRAPHITI_EMBEDDER_MODEL=qwen3-embedding:8b
GRAPHITI_EMBEDDER_DIM=1024
GRAPH_SEARCH_RERANKER=rrf
GRAPH_SEARCH_APP_RERANKER=embedding_rrf
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT=2.0
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES=true
OLLAMA_PORT=11434
OLLAMA_EMBEDDER_MODEL=qwen3-embedding:8b
# 可选:如需独立于 Graphiti / OpenZep 当前 embedder可单独覆写
# GRAPH_SEARCH_APP_EMBEDDER_API_KEY=ollama
# GRAPH_SEARCH_APP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
# GRAPH_SEARCH_APP_EMBEDDER_MODEL=qwen3-embedding:8b
# 可选:如果你另起了免费 cross-encoder / rerank 服务TEI / Infinity / vLLM
# GRAPH_SEARCH_APP_RERANKER=api_rrf
# GRAPH_SEARCH_APP_RERANKER_PROVIDER=tei
# GRAPH_SEARCH_APP_RERANKER_BASE_URL=http://127.0.0.1:18090
# GRAPH_SEARCH_APP_RERANKER_MODEL=your_reranker_model
# GRAPH_SEARCH_APP_RERANKER_TIMEOUT=20
# 免费召回增强:从高相关节点补抓相邻边,默认开启
# GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT=2
# GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT=8
# GRAPHITI_ENABLE_CROSS_ENCODER=false
# Docker 中的 MiroFish 容器访问宿主机 LLM / 容器内 Ollama 时使用:
DOCKER_GRAPHITI_LLM_BASE_URL=http://host.docker.internal:18081/v1
DOCKER_GRAPHITI_EMBEDDER_BASE_URL=http://ollama:11434/v1
# ===== OpenZep / Zep可选非默认=====
# ZEP_API_KEY=your_zep_api_key_here
# ZEP_MODE=openzep
# 本地源码运行 MiroFish 时使用 localhost
# ZEP_BASE_URL=http://localhost:8000/api/v2
# Docker 中的 MiroFish 容器会自动改用 openzep 服务名
# DOCKER_ZEP_BASE_URL=http://openzep:8000/api/v2
# 留空表示不启用 OpenZep API 鉴权
# OPENZEP_API_KEY=
# OPENZEP_LLM_API_KEY=your_api_key_here
# OPENZEP_LLM_BASE_URL=http://127.0.0.1:18081/v1
# OPENZEP_DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
# OPENZEP_LLM_MODEL=gpt-5.4
# OPENZEP_EMBEDDER_API_KEY=ollama
# OPENZEP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
# OPENZEP_DOCKER_EMBEDDER_BASE_URL=http://ollama:11434/v1
# OPENZEP_EMBEDDER_MODEL=qwen3-embedding:8b
# ===== Neo4j / OpenZep 端口 =====
NEO4J_PASSWORD=password123
NEO4J_HTTP_PORT=7474
NEO4J_BOLT_PORT=7687
OPENZEP_PORT=8000
# ===== 加速 LLM 配置(可选)=====
# 注意如果不使用加速配置env文件中就不要出现下面的配置项
# 注意如果不使用加速配置env 文件中就不要出现下面的配置项
LLM_BOOST_API_KEY=your_api_key_here
LLM_BOOST_BASE_URL=your_base_url_here
LLM_BOOST_MODEL_NAME=your_model_name_here
LLM_BOOST_MODEL_NAME=your_model_name_here

View File

@ -10,6 +10,8 @@ COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/
WORKDIR /app
ARG INSTALL_GRAPHITI=true
# 先复制依赖描述文件以利用缓存
COPY package.json package-lock.json ./
COPY frontend/package.json frontend/package-lock.json ./frontend/
@ -18,7 +20,10 @@ COPY backend/pyproject.toml backend/uv.lock ./backend/
# 安装依赖Node + Python
RUN npm ci \
&& npm ci --prefix frontend \
&& cd backend && uv sync --frozen
&& cd backend && uv sync --frozen \
&& if [ "$INSTALL_GRAPHITI" = "true" ]; then \
uv pip install --python /app/backend/.venv/bin/python --no-cache-dir graphiti-core==0.28.2 "neo4j>=5.26.0"; \
fi
# 复制项目源码
COPY . .

View File

@ -283,9 +283,7 @@ def build_graph():
logger.info("=== 开始构建图谱 ===")
# 检查配置
errors = []
if not Config.ZEP_API_KEY:
errors.append("ZEP_API_KEY未配置")
errors = Config.get_graph_backend_config_errors()
if errors:
logger.error(f"配置错误: {errors}")
return jsonify({
@ -432,10 +430,12 @@ def build_graph():
progress=15
)
# OpenZep 本地链路在批量抽取时更容易卡在长时间的联合推理里。
# 改为单块发送可以显著降低单次处理负载,牺牲吞吐换稳定性。
episode_uuids = builder.add_text_batches(
graph_id,
graph_id,
chunks,
batch_size=3,
batch_size=1 if Config.use_openzep() else 3,
progress_callback=add_progress_callback
)
@ -454,7 +454,7 @@ def build_graph():
progress=progress
)
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
builder._wait_for_episodes(graph_id, episode_uuids, wait_progress_callback)
# 获取图谱数据
task_manager.update_task(
@ -464,12 +464,20 @@ def build_graph():
)
graph_data = builder.get_graph_data(graph_id)
node_count = graph_data.get("node_count", 0)
edge_count = graph_data.get("edge_count", 0)
# 如果图谱仍然是空的,说明 OpenZep 没有成功完成抽取。
# 不能把这种情况标记为成功,否则前端会误以为构图完成。
if node_count == 0 and edge_count == 0:
raise RuntimeError(
"图谱构建未产出任何节点或边OpenZep 处理可能超时或未完成"
)
# 更新项目状态
project.status = ProjectStatus.GRAPH_COMPLETED
ProjectManager.save_project(project)
node_count = graph_data.get("node_count", 0)
edge_count = graph_data.get("edge_count", 0)
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
# 完成
@ -567,10 +575,11 @@ def get_graph_data(graph_id: str):
获取图谱数据节点和边
"""
try:
if not Config.ZEP_API_KEY:
errors = Config.get_graph_backend_config_errors()
if errors:
return jsonify({
"success": False,
"error": "ZEP_API_KEY未配置"
"error": "; ".join(errors)
}), 500
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
@ -595,10 +604,11 @@ def delete_graph(graph_id: str):
删除Zep图谱
"""
try:
if not Config.ZEP_API_KEY:
errors = Config.get_graph_backend_config_errors()
if errors:
return jsonify({
"success": False,
"error": "ZEP_API_KEY未配置"
"error": "; ".join(errors)
}), 500
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)

View File

@ -56,10 +56,11 @@ def get_graph_entities(graph_id: str):
enrich: 是否获取相关边信息默认true
"""
try:
if not Config.ZEP_API_KEY:
errors = Config.get_graph_backend_config_errors()
if errors:
return jsonify({
"success": False,
"error": "ZEP_API_KEY未配置"
"error": "; ".join(errors)
}), 500
entity_types_str = request.args.get('entity_types', '')
@ -93,10 +94,11 @@ def get_graph_entities(graph_id: str):
def get_entity_detail(graph_id: str, entity_uuid: str):
"""获取单个实体的详细信息"""
try:
if not Config.ZEP_API_KEY:
errors = Config.get_graph_backend_config_errors()
if errors:
return jsonify({
"success": False,
"error": "ZEP_API_KEY未配置"
"error": "; ".join(errors)
}), 500
reader = ZepEntityReader()
@ -126,10 +128,11 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
def get_entities_by_type(graph_id: str, entity_type: str):
"""获取指定类型的所有实体"""
try:
if not Config.ZEP_API_KEY:
errors = Config.get_graph_backend_config_errors()
if errors:
return jsonify({
"success": False,
"error": "ZEP_API_KEY未配置"
"error": "; ".join(errors)
}), 500
enrich = request.args.get('enrich', 'true').lower() == 'true'

View File

@ -33,7 +33,54 @@ class Config:
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
# Zep配置
ZEP_MODE = os.environ.get('ZEP_MODE', 'cloud').lower()
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
ZEP_BASE_URL = os.environ.get('ZEP_BASE_URL')
OPENZEP_EMBEDDER_API_KEY = os.environ.get('OPENZEP_EMBEDDER_API_KEY') or None
OPENZEP_EMBEDDER_BASE_URL = os.environ.get('OPENZEP_EMBEDDER_BASE_URL') or None
OPENZEP_EMBEDDER_MODEL = os.environ.get('OPENZEP_EMBEDDER_MODEL') or None
# 图后端配置
GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'graphiti').lower()
GRAPH_SEARCH_RERANKER = os.environ.get('GRAPH_SEARCH_RERANKER', 'rrf').strip() or None
GRAPH_SEARCH_APP_RERANKER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER', 'embedding_rrf').strip().lower() or 'embedding_rrf')
GRAPH_SEARCH_APP_RERANK_FUSION_K = max(1, int(os.environ.get('GRAPH_SEARCH_APP_RERANK_FUSION_K', '60')))
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT = max(0.0, float(os.environ.get('GRAPH_SEARCH_APP_SEMANTIC_WEIGHT', '2.0')))
GRAPH_SEARCH_APP_EMBEDDER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_API_KEY') or None
GRAPH_SEARCH_APP_EMBEDDER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_BASE_URL') or None
GRAPH_SEARCH_APP_EMBEDDER_MODEL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_MODEL') or None
GRAPH_SEARCH_APP_EMBED_BATCH_SIZE = max(1, int(os.environ.get('GRAPH_SEARCH_APP_EMBED_BATCH_SIZE', '32')))
GRAPH_SEARCH_APP_RERANKER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_RERANKER_API_KEY') or None
GRAPH_SEARCH_APP_RERANKER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_BASE_URL') or None
GRAPH_SEARCH_APP_RERANKER_MODEL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_MODEL') or None
GRAPH_SEARCH_APP_RERANKER_PROVIDER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER_PROVIDER', 'auto').strip().lower() or 'auto')
GRAPH_SEARCH_APP_RERANKER_TIMEOUT = max(1.0, float(os.environ.get('GRAPH_SEARCH_APP_RERANKER_TIMEOUT', '20')))
GRAPH_SEARCH_INCLUDE_NODES = os.environ.get('GRAPH_SEARCH_INCLUDE_NODES', 'true').lower() == 'true'
GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER', '2')))
GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER', '1')))
GRAPH_SEARCH_NODE_SUMMARY_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_SUMMARY_LIMIT', '5')))
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES = os.environ.get('GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES', 'true').lower() == 'true'
GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT = max(0, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT', '2')))
GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT', '8')))
GRAPHITI_URI = os.environ.get('GRAPHITI_URI')
GRAPHITI_USER = os.environ.get('GRAPHITI_USER', 'neo4j')
GRAPHITI_PASSWORD = os.environ.get('GRAPHITI_PASSWORD')
GRAPHITI_DATABASE = os.environ.get('GRAPHITI_DATABASE', 'neo4j')
GRAPHITI_LLM_API_KEY = os.environ.get('GRAPHITI_LLM_API_KEY') or LLM_API_KEY
GRAPHITI_LLM_BASE_URL = os.environ.get('GRAPHITI_LLM_BASE_URL') or LLM_BASE_URL
GRAPHITI_LLM_MODEL = os.environ.get('GRAPHITI_LLM_MODEL') or LLM_MODEL_NAME
GRAPHITI_LLM_SMALL_MODEL = os.environ.get('GRAPHITI_LLM_SMALL_MODEL') or GRAPHITI_LLM_MODEL
GRAPHITI_LLM_CLIENT_MODE = os.environ.get('GRAPHITI_LLM_CLIENT_MODE', 'openai').lower()
GRAPHITI_LLM_MAX_TOKENS = max(1024, int(os.environ.get('GRAPHITI_LLM_MAX_TOKENS', '16384')))
GRAPHITI_EMBEDDER_API_KEY = os.environ.get('GRAPHITI_EMBEDDER_API_KEY') or GRAPHITI_LLM_API_KEY
GRAPHITI_EMBEDDER_BASE_URL = os.environ.get('GRAPHITI_EMBEDDER_BASE_URL') or GRAPHITI_LLM_BASE_URL
GRAPHITI_EMBEDDER_MODEL = os.environ.get('GRAPHITI_EMBEDDER_MODEL', 'qwen3-embedding:8b')
GRAPHITI_EMBEDDER_DIM = max(128, int(os.environ.get('GRAPHITI_EMBEDDER_DIM', '1024')))
GRAPHITI_RERANKER_API_KEY = os.environ.get('GRAPHITI_RERANKER_API_KEY') or GRAPHITI_LLM_API_KEY
GRAPHITI_RERANKER_BASE_URL = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or GRAPHITI_LLM_BASE_URL
GRAPHITI_RERANKER_MODEL = os.environ.get('GRAPHITI_RERANKER_MODEL') or GRAPHITI_LLM_MODEL
GRAPHITI_ENABLE_CROSS_ENCODER = os.environ.get('GRAPHITI_ENABLE_CROSS_ENCODER', 'false').lower() == 'true'
GRAPHITI_MAX_COROUTINES = max(1, int(os.environ.get('GRAPHITI_MAX_COROUTINES', '20')))
# 文件上传配置
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
@ -62,6 +109,123 @@ class Config:
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
@classmethod
def use_openzep(cls):
"""是否启用 OpenZep / 自定义 Zep endpoint。"""
return cls.ZEP_MODE == 'openzep' or bool(cls.ZEP_BASE_URL)
@classmethod
def is_zep_configured(cls, api_key=None):
"""检查 Zep/OpenZep 是否已完成最小配置。"""
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
if cls.use_openzep():
return bool(cls.ZEP_BASE_URL)
return bool(resolved_api_key)
@classmethod
def get_zep_client_kwargs(cls, api_key=None):
"""生成 Zep 客户端初始化参数。"""
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
kwargs = {}
# OpenZep 场景允许关闭鉴权;此时不要传空字符串 api_key
# 否则 zep sdk 会构造出非法的 `Api-Key:` 请求头httpx 会直接拒绝。
if resolved_api_key:
kwargs['api_key'] = resolved_api_key
if cls.ZEP_BASE_URL:
kwargs['base_url'] = cls.ZEP_BASE_URL
return kwargs
@classmethod
def get_zep_config_errors(cls, api_key=None):
"""返回 Zep/OpenZep 的配置错误。"""
if cls.is_zep_configured(api_key=api_key):
return []
if cls.use_openzep():
return ["ZEP_BASE_URL 未配置"]
return ["ZEP_API_KEY 未配置"]
@classmethod
def get_graph_search_embedder_config(cls):
"""返回 app-side 图搜索语义重排使用的 embedding 配置。"""
api_key = cls.GRAPH_SEARCH_APP_EMBEDDER_API_KEY
base_url = cls.GRAPH_SEARCH_APP_EMBEDDER_BASE_URL
model = cls.GRAPH_SEARCH_APP_EMBEDDER_MODEL
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
if backend == 'graphiti':
api_key = api_key or cls.GRAPHITI_EMBEDDER_API_KEY
base_url = base_url or cls.GRAPHITI_EMBEDDER_BASE_URL
model = model or cls.GRAPHITI_EMBEDDER_MODEL
elif cls.use_openzep():
api_key = api_key or cls.OPENZEP_EMBEDDER_API_KEY
base_url = base_url or cls.OPENZEP_EMBEDDER_BASE_URL
model = model or cls.OPENZEP_EMBEDDER_MODEL
if model and not api_key:
api_key = 'ollama'
return {
'api_key': api_key,
'base_url': base_url,
'model': model,
}
@classmethod
def get_graph_search_reranker_config(cls):
"""返回 app-side 图搜索交叉编码重排使用的 reranker 配置。"""
api_key = cls.GRAPH_SEARCH_APP_RERANKER_API_KEY
base_url = cls.GRAPH_SEARCH_APP_RERANKER_BASE_URL
model = cls.GRAPH_SEARCH_APP_RERANKER_MODEL
provider = cls.GRAPH_SEARCH_APP_RERANKER_PROVIDER
graphiti_reranker_base_url = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or None
if graphiti_reranker_base_url:
api_key = api_key or (os.environ.get('GRAPHITI_RERANKER_API_KEY') or None)
base_url = base_url or graphiti_reranker_base_url
model = model or (os.environ.get('GRAPHITI_RERANKER_MODEL') or None)
return {
'api_key': api_key,
'base_url': base_url,
'model': model,
'provider': provider,
'timeout': cls.GRAPH_SEARCH_APP_RERANKER_TIMEOUT,
}
@classmethod
def get_graphiti_config_errors(cls):
"""返回 Graphiti + Neo4j 的配置错误。"""
errors = []
if not cls.GRAPHITI_URI:
errors.append("GRAPHITI_URI 未配置")
if not cls.GRAPHITI_DATABASE:
errors.append("GRAPHITI_DATABASE 未配置")
if not cls.GRAPHITI_LLM_MODEL:
errors.append("GRAPHITI_LLM_MODEL 未配置")
if not cls.GRAPHITI_EMBEDDER_MODEL:
errors.append("GRAPHITI_EMBEDDER_MODEL 未配置")
return errors
@classmethod
def get_graph_backend_config_errors(cls, api_key=None):
"""根据当前 GRAPH_BACKEND 返回对应的配置错误。"""
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
if backend in {'zep', 'openzep'}:
return cls.get_zep_config_errors(api_key=api_key)
if backend == 'graphiti':
return cls.get_graphiti_config_errors()
return [f"不支持的 GRAPH_BACKEND: {backend}"]
@classmethod
def is_graph_backend_configured(cls, api_key=None):
return len(cls.get_graph_backend_config_errors(api_key=api_key)) == 0
@classmethod
def validate(cls):
@ -69,7 +233,5 @@ class Config:
errors = []
if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置")
if not cls.ZEP_API_KEY:
errors.append("ZEP_API_KEY 未配置")
errors.extend(cls.get_graph_backend_config_errors())
return errors

View File

@ -0,0 +1,8 @@
"""
Graph backend abstractions.
"""
from .base import GraphBackend
from .factory import get_graph_backend
__all__ = ["GraphBackend", "get_graph_backend"]

81
backend/app/graph/base.py Normal file
View File

@ -0,0 +1,81 @@
"""
Common graph backend interface.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
class GraphBackend(ABC):
"""Minimal graph backend interface used by the application services."""
@property
def raw_client(self) -> Any:
"""Return the underlying SDK client when a service still needs it."""
return None
@abstractmethod
def create_graph(self, graph_id: str, name: str, description: str) -> None:
"""Create a graph."""
@abstractmethod
def set_ontology(
self,
graph_id: str,
entities: Any = None,
edges: Any = None,
) -> None:
"""Set graph ontology."""
@abstractmethod
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
"""Add a batch of episodes."""
@abstractmethod
def add_text(self, graph_id: str, data: str) -> Any:
"""Add a single text episode."""
@abstractmethod
def get_episode(self, episode_uuid: str) -> Any:
"""Fetch a single episode."""
@abstractmethod
def search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges",
reranker: Optional[str] = None,
) -> Any:
"""Search the graph."""
@abstractmethod
def get_all_nodes(self, graph_id: str) -> List[Any]:
"""Fetch all nodes."""
@abstractmethod
def get_all_edges(self, graph_id: str) -> List[Any]:
"""Fetch all edges."""
@abstractmethod
def get_node(self, node_uuid: str) -> Any:
"""Fetch a node by UUID."""
@abstractmethod
def get_node_edges(self, node_uuid: str) -> List[Any]:
"""Fetch edges for a node."""
@abstractmethod
def delete_graph(self, graph_id: str) -> None:
"""Delete a graph."""
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
"""Fetch backend ontology metadata when available."""
return None
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
"""Fetch backend-specific live statistics when available."""
return None

View File

@ -0,0 +1,27 @@
"""
Graph backend factory.
"""
from __future__ import annotations
from typing import Optional
from ..config import Config
from .base import GraphBackend
def get_graph_backend(api_key: Optional[str] = None) -> GraphBackend:
"""Create the configured graph backend."""
backend = (Config.GRAPH_BACKEND or "graphiti").lower()
if backend in {"zep", "openzep"}:
from .zep_backend import ZepGraphBackend
return ZepGraphBackend(api_key=api_key)
if backend == "graphiti":
from .graphiti_backend import GraphitiBackend
return GraphitiBackend(api_key=api_key)
raise ValueError(f"不支持的 GRAPH_BACKEND: {backend}")

View File

@ -0,0 +1,875 @@
"""
Graphiti + Neo4j graph backend implementation.
"""
from __future__ import annotations
import asyncio
import json
import logging
import threading
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, Dict, List, Optional
from pydantic import BaseModel, Field, create_model
from ..config import Config
from .base import GraphBackend
logger = logging.getLogger(__name__)
@dataclass
class _CompatEpisode:
uuid: str
processed: bool = True
name: str = ""
content: str = ""
valid_at: Optional[datetime] = None
created_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatNode:
uuid: str
name: str = ""
labels: List[str] = field(default_factory=list)
summary: str = ""
attributes: Dict[str, Any] = field(default_factory=dict)
created_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatEdge:
uuid: str
name: str
fact: str
source_node_uuid: str
target_node_uuid: str
source_node_name: str = ""
target_node_name: str = ""
attributes: Dict[str, Any] = field(default_factory=dict)
episodes: List[str] = field(default_factory=list)
created_at: Optional[datetime] = None
valid_at: Optional[datetime] = None
invalid_at: Optional[datetime] = None
expired_at: Optional[datetime] = None
@property
def uuid_(self) -> str:
return self.uuid
@dataclass
class _CompatSearchResults:
edges: List[_CompatEdge] = field(default_factory=list)
nodes: List[_CompatNode] = field(default_factory=list)
@dataclass
class _OntologyBundle:
entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict)
spec: Dict[str, Any] = field(default_factory=dict)
class _AsyncBridge:
"""Run async Graphiti calls from the app's synchronous service layer."""
def __init__(self):
self._ready = threading.Event()
self._loop: Optional[asyncio.AbstractEventLoop] = None
self._thread = threading.Thread(target=self._run_loop, daemon=True)
self._thread.start()
self._ready.wait()
def _run_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
self._loop = loop
self._ready.set()
loop.run_forever()
def run(self, coro):
if self._loop is None:
raise RuntimeError("Graphiti async loop 未初始化")
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
return future.result()
class GraphitiBackend(GraphBackend):
"""Graph backend backed by Graphiti OSS + Neo4j."""
_bridge: Optional[_AsyncBridge] = None
_bridge_lock = threading.Lock()
_indices_ready = False
_indices_lock = threading.Lock()
_ontology_registry: Dict[str, _OntologyBundle] = {}
_ontology_lock = threading.Lock()
_cross_encoder_warning_emitted = False
PAGE_SIZE = 200
def __init__(self, api_key: Optional[str] = None):
del api_key
errors = Config.get_graphiti_config_errors()
if errors:
raise ValueError("; ".join(errors))
try:
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
from graphiti_core.graphiti import Graphiti
from graphiti_core.llm_client import OpenAIClient
from graphiti_core.llm_client.config import LLMConfig
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
except ImportError as exc:
raise ImportError(
"Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j"
) from exc
llm_config = LLMConfig(
api_key=Config.GRAPHITI_LLM_API_KEY,
base_url=Config.GRAPHITI_LLM_BASE_URL,
model=Config.GRAPHITI_LLM_MODEL,
small_model=Config.GRAPHITI_LLM_SMALL_MODEL,
temperature=0,
)
reranker_config = LLMConfig(
api_key=Config.GRAPHITI_RERANKER_API_KEY,
base_url=Config.GRAPHITI_RERANKER_BASE_URL,
model=Config.GRAPHITI_RERANKER_MODEL,
temperature=0,
)
embedder_config = OpenAIEmbedderConfig(
api_key=Config.GRAPHITI_EMBEDDER_API_KEY,
base_url=Config.GRAPHITI_EMBEDDER_BASE_URL,
embedding_model=Config.GRAPHITI_EMBEDDER_MODEL,
embedding_dim=Config.GRAPHITI_EMBEDDER_DIM,
)
llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower()
if llm_client_mode == "generic":
llm_client = OpenAIGenericClient(
config=llm_config,
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
)
else:
llm_client = OpenAIClient(
config=llm_config,
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
)
self._graphiti = Graphiti(
uri=Config.GRAPHITI_URI,
user=Config.GRAPHITI_USER,
password=Config.GRAPHITI_PASSWORD,
llm_client=llm_client,
embedder=OpenAIEmbedder(config=embedder_config),
cross_encoder=OpenAIRerankerClient(config=reranker_config),
max_coroutines=Config.GRAPHITI_MAX_COROUTINES,
)
self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE)
self._graphiti.driver = self._driver
self._graphiti.clients.driver = self._driver
self._bridge = self._get_bridge()
self._ensure_indices()
@classmethod
def _get_bridge(cls) -> _AsyncBridge:
with cls._bridge_lock:
if cls._bridge is None:
cls._bridge = _AsyncBridge()
return cls._bridge
@property
def raw_client(self) -> Any:
return self._graphiti
def _run(self, coro):
return self._bridge.run(coro)
def _ensure_indices(self) -> None:
if self.__class__._indices_ready:
return
with self.__class__._indices_lock:
if self.__class__._indices_ready:
return
self._run(self._graphiti.build_indices_and_constraints())
self.__class__._indices_ready = True
def _validate_graph_id(self, graph_id: str) -> None:
from graphiti_core.helpers import validate_group_id
validate_group_id(graph_id)
def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]:
fields = []
for field_name, model_field in model.model_fields.items():
fields.append(
{
"name": field_name,
"description": model_field.description or field_name,
}
)
return {
"description": (getattr(model, "__doc__", "") or "").strip(),
"fields": fields,
}
def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]:
field_definitions = {}
for field_spec in spec.get("fields", []):
field_name = field_spec.get("name", "").strip()
if not field_name:
continue
field_definitions[field_name] = (
Optional[str],
Field(
default=None,
description=field_spec.get("description") or field_name,
),
)
model = create_model(name, __base__=BaseModel, **field_definitions)
model.__doc__ = spec.get("description") or name
return model
def _serialize_ontology_spec(
self,
entity_specs: Dict[str, Dict[str, Any]],
edge_specs: Dict[str, Dict[str, Any]],
edge_type_map: Dict[tuple[str, str], List[str]],
) -> Dict[str, Any]:
return {
"entity_types": entity_specs,
"edge_types": edge_specs,
"edge_type_map": [
{
"source": source,
"target": target,
"edges": edge_names,
}
for (source, target), edge_names in sorted(edge_type_map.items())
],
}
def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle:
entity_types = {
name: self._build_dynamic_model(name, model_spec)
for name, model_spec in (spec.get("entity_types") or {}).items()
}
edge_types = {
name: self._build_dynamic_model(name, model_spec)
for name, model_spec in (spec.get("edge_types") or {}).items()
}
edge_type_map: Dict[tuple[str, str], List[str]] = {}
for entry in spec.get("edge_type_map") or []:
source = entry.get("source", "Entity")
target = entry.get("target", "Entity")
edge_type_map[(source, target)] = list(entry.get("edges") or [])
return _OntologyBundle(
entity_types=entity_types,
edge_types=edge_types,
edge_type_map=edge_type_map,
spec=spec,
)
def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle:
entity_specs = {}
entity_types = {}
for entity_name, entity_model in (entities or {}).items():
entity_spec = self._normalize_model_spec(entity_model)
entity_specs[entity_name] = entity_spec
entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec)
edge_specs = {}
edge_types = {}
edge_type_map: Dict[tuple[str, str], List[str]] = {}
for edge_name, edge_value in (edges or {}).items():
if not isinstance(edge_value, tuple) or len(edge_value) != 2:
continue
edge_model, source_targets = edge_value
edge_spec = self._normalize_model_spec(edge_model)
edge_specs[edge_name] = edge_spec
edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec)
for source_target in source_targets or []:
source = getattr(source_target, "source", "Entity") or "Entity"
target = getattr(source_target, "target", "Entity") or "Entity"
edge_type_map.setdefault((source, target), []).append(edge_name)
if not edge_type_map:
edge_type_map = {("Entity", "Entity"): list(edge_types.keys())}
spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map)
return _OntologyBundle(
entity_types=entity_types,
edge_types=edge_types,
edge_type_map=edge_type_map,
spec=spec,
)
async def _upsert_graph_metadata_async(
self,
graph_id: str,
name: Optional[str] = None,
description: Optional[str] = None,
ontology_spec: Optional[Dict[str, Any]] = None,
) -> None:
payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None
await self._driver.execute_query(
"""
MERGE (m:GraphMetadata {graph_id: $graph_id})
ON CREATE SET
m.group_id = $graph_id,
m.created_at = datetime()
SET
m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END,
m.description = CASE
WHEN $description IS NULL OR $description = ''
THEN coalesce(m.description, '')
ELSE $description
END,
m.ontology_json = CASE
WHEN $ontology_json IS NULL OR $ontology_json = ''
THEN m.ontology_json
ELSE $ontology_json
END,
m.updated_at = datetime()
""",
graph_id=graph_id,
name=name,
description=description,
ontology_json=payload,
)
async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
records, _, _ = await self._driver.execute_query(
"""
MATCH (m:GraphMetadata {graph_id: $graph_id})
RETURN m.ontology_json AS ontology_json
""",
graph_id=graph_id,
routing_="r",
)
if not records:
return _OntologyBundle()
ontology_json = records[0].get("ontology_json")
if not ontology_json:
return _OntologyBundle()
try:
spec = json.loads(ontology_json)
except (TypeError, ValueError, json.JSONDecodeError):
logger.warning("Graphiti ontology metadata 解析失败graph_id=%s", graph_id)
return _OntologyBundle()
return self._bundle_from_spec(spec)
async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
with self.__class__._ontology_lock:
cached = self.__class__._ontology_registry.get(graph_id)
if cached is not None:
return cached
bundle = await self._load_ontology_bundle_async(graph_id)
with self.__class__._ontology_lock:
self.__class__._ontology_registry[graph_id] = bundle
return bundle
def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle:
return self._run(self._get_ontology_bundle_async(graph_id))
def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None:
with self.__class__._ontology_lock:
self.__class__._ontology_registry[graph_id] = bundle
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
self._validate_graph_id(graph_id)
bundle = self._get_ontology_bundle(graph_id)
return dict(bundle.spec) if bundle.spec else None
def create_graph(self, graph_id: str, name: str, description: str) -> None:
self._validate_graph_id(graph_id)
self._run(
self._upsert_graph_metadata_async(
graph_id=graph_id,
name=name,
description=description,
)
)
def set_ontology(
self,
graph_id: str,
entities: Any = None,
edges: Any = None,
) -> None:
self._validate_graph_id(graph_id)
bundle = self._build_ontology_bundle(entities=entities, edges=edges)
self._set_ontology_bundle(graph_id, bundle)
self._run(
self._upsert_graph_metadata_async(
graph_id=graph_id,
ontology_spec=bundle.spec,
)
)
async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode:
from graphiti_core.helpers import validate_excluded_entity_types
from graphiti_core.nodes import EpisodeType, EpisodicNode
from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT
from graphiti_core.utils.datetime_utils import utc_now
from graphiti_core.utils.maintenance.node_operations import (
extract_attributes_from_nodes,
extract_nodes,
resolve_extracted_nodes,
)
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
bundle = await self._get_ontology_bundle_async(graph_id)
entity_types = bundle.entity_types or None
edge_types = bundle.edge_types or None
edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []}
validate_entity_types(entity_types)
validate_excluded_entity_types(None, entity_types)
now = utc_now()
previous_episodes = await self._graphiti.retrieve_episodes(
reference_time=now,
last_n=RELEVANT_SCHEMA_LIMIT,
group_ids=[graph_id],
source=EpisodeType.text,
driver=self._driver,
)
episode = EpisodicNode(
name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}",
group_id=graph_id,
labels=[],
source=EpisodeType.text,
content=data,
source_description="text",
created_at=now,
valid_at=now,
)
extracted_nodes = await extract_nodes(
self._graphiti.clients,
episode,
previous_episodes,
entity_types,
None,
None,
)
nodes, uuid_map, _ = await resolve_extracted_nodes(
self._graphiti.clients,
extracted_nodes,
episode,
previous_episodes,
entity_types,
)
resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges(
episode,
extracted_nodes,
previous_episodes,
edge_type_map,
graph_id,
edge_types,
nodes,
uuid_map,
None,
)
entity_edges = resolved_edges + invalidated_edges
hydrated_nodes = await extract_attributes_from_nodes(
self._graphiti.clients,
nodes,
episode,
previous_episodes,
entity_types,
edges=new_edges,
)
_, saved_episode = await self._graphiti._process_episode_data(
episode,
hydrated_nodes,
entity_edges,
now,
graph_id,
None,
None,
)
return _CompatEpisode(
uuid=saved_episode.uuid,
processed=True,
name=saved_episode.name,
content=saved_episode.content,
valid_at=saved_episode.valid_at,
created_at=saved_episode.created_at,
)
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
results = []
for episode in episodes:
data = getattr(episode, "data", None)
if data is None and isinstance(episode, dict):
data = episode.get("data", "")
results.append(self.add_text(graph_id=graph_id, data=str(data or "")))
return results
def add_text(self, graph_id: str, data: str) -> Any:
self._validate_graph_id(graph_id)
return self._run(self._add_text_async(graph_id=graph_id, data=data))
async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode:
from graphiti_core.nodes import EpisodicNode
episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid)
return _CompatEpisode(
uuid=episode.uuid,
processed=True,
name=episode.name,
content=episode.content,
valid_at=episode.valid_at,
created_at=episode.created_at,
)
def get_episode(self, episode_uuid: str) -> Any:
return self._run(self._get_episode_async(episode_uuid))
def _warn_cross_encoder_fallback(self) -> None:
if self.__class__._cross_encoder_warning_emitted:
return
logger.info(
"Graphiti cross_encoder 默认已降级为 rrf如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true"
)
self.__class__._cross_encoder_warning_emitted = True
def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]):
from graphiti_core.search.search_config import (
EdgeReranker,
EdgeSearchConfig,
EdgeSearchMethod,
NodeReranker,
NodeSearchConfig,
NodeSearchMethod,
SearchConfig,
)
reranker_name = (reranker or "rrf").strip().lower()
edge_reranker_map = {
"rrf": EdgeReranker.rrf,
"reciprocal_rank_fusion": EdgeReranker.rrf,
"cross_encoder": EdgeReranker.cross_encoder,
"node_distance": EdgeReranker.node_distance,
"episode_mentions": EdgeReranker.episode_mentions,
"mmr": EdgeReranker.mmr,
}
node_reranker_map = {
"rrf": NodeReranker.rrf,
"reciprocal_rank_fusion": NodeReranker.rrf,
"cross_encoder": NodeReranker.cross_encoder,
"node_distance": NodeReranker.node_distance,
"episode_mentions": NodeReranker.episode_mentions,
"mmr": NodeReranker.mmr,
}
edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf)
node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf)
edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity]
node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity]
if reranker_name == "cross_encoder":
if Config.GRAPHITI_ENABLE_CROSS_ENCODER:
edge_methods.append(EdgeSearchMethod.bfs)
node_methods.append(NodeSearchMethod.bfs)
else:
self._warn_cross_encoder_fallback()
edge_reranker = EdgeReranker.rrf
node_reranker = NodeReranker.rrf
edge_config = None
node_config = None
if scope in {"edges", "both"}:
edge_config = EdgeSearchConfig(
search_methods=edge_methods,
reranker=edge_reranker,
)
if scope in {"nodes", "both"}:
node_config = NodeSearchConfig(
search_methods=node_methods,
reranker=node_reranker,
)
return SearchConfig(
edge_config=edge_config,
node_config=node_config,
limit=max(1, limit),
)
def _wrap_node(self, node: Any) -> _CompatNode:
return _CompatNode(
uuid=getattr(node, "uuid", ""),
name=getattr(node, "name", "") or "",
labels=list(getattr(node, "labels", []) or []),
summary=getattr(node, "summary", "") or "",
attributes=dict(getattr(node, "attributes", {}) or {}),
created_at=getattr(node, "created_at", None),
)
def _wrap_edge(
self,
edge: Any,
source_node_name: str = "",
target_node_name: str = "",
) -> _CompatEdge:
return _CompatEdge(
uuid=getattr(edge, "uuid", ""),
name=getattr(edge, "name", "") or "",
fact=getattr(edge, "fact", "") or "",
source_node_uuid=getattr(edge, "source_node_uuid", "") or "",
target_node_uuid=getattr(edge, "target_node_uuid", "") or "",
source_node_name=source_node_name,
target_node_name=target_node_name,
attributes=dict(getattr(edge, "attributes", {}) or {}),
episodes=list(getattr(edge, "episodes", []) or []),
created_at=getattr(edge, "created_at", None),
valid_at=getattr(edge, "valid_at", None),
invalid_at=getattr(edge, "invalid_at", None),
expired_at=getattr(edge, "expired_at", None),
)
async def _search_async(
self,
graph_id: str,
query: str,
limit: int,
scope: str,
reranker: Optional[str],
) -> _CompatSearchResults:
from graphiti_core.nodes import EntityNode
search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker)
results = await self._graphiti.search_(
query=query,
config=search_config,
group_ids=[graph_id],
driver=self._driver,
)
nodes = [self._wrap_node(node) for node in results.nodes]
node_name_map = {node.uuid: node.name for node in nodes if node.uuid}
missing_node_ids = {
node_uuid
for edge in results.edges
for node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
if node_uuid and node_uuid not in node_name_map
}
if missing_node_ids:
for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)):
node_name_map[node.uuid] = node.name or ""
edges = [
self._wrap_edge(
edge,
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
)
for edge in results.edges
]
return _CompatSearchResults(edges=edges, nodes=nodes)
def search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges",
reranker: Optional[str] = None,
) -> Any:
self._validate_graph_id(graph_id)
return self._run(
self._search_async(
graph_id=graph_id,
query=query,
limit=limit,
scope=scope,
reranker=reranker,
)
)
async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]:
from graphiti_core.nodes import EntityNode
result = []
cursor = None
while True:
batch = await EntityNode.get_by_group_ids(
self._driver,
[graph_id],
limit=self.PAGE_SIZE,
uuid_cursor=cursor,
)
if not batch:
break
result.extend(self._wrap_node(node) for node in batch)
if len(batch) < self.PAGE_SIZE:
break
cursor = batch[-1].uuid
return result
def get_all_nodes(self, graph_id: str) -> List[Any]:
self._validate_graph_id(graph_id)
return self._run(self._get_all_nodes_async(graph_id))
async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]:
from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError
result = []
cursor = None
while True:
try:
batch = await EntityEdge.get_by_group_ids(
self._driver,
[graph_id],
limit=self.PAGE_SIZE,
uuid_cursor=cursor,
)
except GroupsEdgesNotFoundError:
break
if not batch:
break
result.extend(self._wrap_edge(edge) for edge in batch)
if len(batch) < self.PAGE_SIZE:
break
cursor = batch[-1].uuid
return result
def get_all_edges(self, graph_id: str) -> List[Any]:
self._validate_graph_id(graph_id)
return self._run(self._get_all_edges_async(graph_id))
async def _get_node_async(self, node_uuid: str) -> _CompatNode:
from graphiti_core.nodes import EntityNode
return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid))
def get_node(self, node_uuid: str) -> Any:
return self._run(self._get_node_async(node_uuid))
async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]:
from graphiti_core.edges import EntityEdge
from graphiti_core.nodes import EntityNode
edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid)
related_node_ids = {
related_uuid
for edge in edges
for related_uuid in (edge.source_node_uuid, edge.target_node_uuid)
if related_uuid
}
node_name_map = {}
if related_node_ids:
for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)):
node_name_map[node.uuid] = node.name or ""
return [
self._wrap_edge(
edge,
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
)
for edge in edges
]
def get_node_edges(self, node_uuid: str) -> List[Any]:
return self._run(self._get_node_edges_async(node_uuid))
async def _delete_graph_async(self, graph_id: str) -> None:
from graphiti_core.nodes import Node
await self._driver.execute_query(
"""
MATCH (s:Saga {group_id: $graph_id})
DETACH DELETE s
""",
graph_id=graph_id,
)
await Node.delete_by_group_id(self._driver, graph_id)
await self._driver.execute_query(
"""
MATCH (m:GraphMetadata {graph_id: $graph_id})
DETACH DELETE m
""",
graph_id=graph_id,
)
def delete_graph(self, graph_id: str) -> None:
self._validate_graph_id(graph_id)
self._run(self._delete_graph_async(graph_id))
with self.__class__._ontology_lock:
self.__class__._ontology_registry.pop(graph_id, None)
async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]:
node_records, _, _ = await self._driver.execute_query(
"""
MATCH (n:Entity {group_id: $graph_id})
RETURN count(n) AS node_count
""",
graph_id=graph_id,
routing_="r",
)
edge_records, _, _ = await self._driver.execute_query(
"""
MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->()
RETURN count(e) AS edge_count
""",
graph_id=graph_id,
routing_="r",
)
episode_records, _, _ = await self._driver.execute_query(
"""
MATCH (n:Episodic {group_id: $graph_id})
RETURN count(n) AS episode_count
""",
graph_id=graph_id,
routing_="r",
)
return {
"node_count": int((node_records[0] if node_records else {}).get("node_count", 0) or 0),
"edge_count": int((edge_records[0] if edge_records else {}).get("edge_count", 0) or 0),
"episode_count": int(
(episode_records[0] if episode_records else {}).get("episode_count", 0) or 0
),
}
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
self._validate_graph_id(graph_id)
return self._run(self._get_live_graph_statistics_async(graph_id))

View File

@ -0,0 +1,114 @@
"""
Zep / OpenZep graph backend implementation.
"""
from __future__ import annotations
import json
from typing import Any, Dict, List, Optional
from urllib.error import HTTPError, URLError
from urllib.request import Request, urlopen
from zep_cloud.client import Zep
from ..config import Config
from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes
from .base import GraphBackend
class ZepGraphBackend(GraphBackend):
"""Graph backend backed by Zep Cloud or OpenZep."""
def __init__(self, api_key: Optional[str] = None):
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
errors = Config.get_zep_config_errors(api_key=self.api_key)
if errors:
raise ValueError("; ".join(errors))
self.client = Zep(**Config.get_zep_client_kwargs(api_key=self.api_key))
@property
def raw_client(self) -> Zep:
return self.client
def create_graph(self, graph_id: str, name: str, description: str) -> None:
self.client.graph.create(
graph_id=graph_id,
name=name,
description=description,
)
def set_ontology(
self,
graph_id: str,
entities: Any = None,
edges: Any = None,
) -> None:
self.client.graph.set_ontology(
graph_ids=[graph_id],
entities=entities,
edges=edges,
)
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
return self.client.graph.add_batch(graph_id=graph_id, episodes=episodes)
def add_text(self, graph_id: str, data: str) -> Any:
return self.client.graph.add(graph_id=graph_id, type="text", data=data)
def get_episode(self, episode_uuid: str) -> Any:
return self.client.graph.episode.get(uuid_=episode_uuid)
def search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges",
reranker: Optional[str] = None,
) -> Any:
kwargs = {
"graph_id": graph_id,
"query": query,
"limit": limit,
"scope": scope,
}
if reranker:
kwargs["reranker"] = reranker
return self.client.graph.search(**kwargs)
def get_all_nodes(self, graph_id: str) -> List[Any]:
return fetch_all_nodes(self.client, graph_id)
def get_all_edges(self, graph_id: str) -> List[Any]:
return fetch_all_edges(self.client, graph_id)
def get_node(self, node_uuid: str) -> Any:
return self.client.graph.node.get(uuid_=node_uuid)
def get_node_edges(self, node_uuid: str) -> List[Any]:
return self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
def delete_graph(self, graph_id: str) -> None:
self.client.graph.delete(graph_id=graph_id)
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
if not Config.ZEP_BASE_URL:
return None
base_url = Config.ZEP_BASE_URL.rstrip("/")
request = Request(f"{base_url}/graph/{graph_id}/statistics")
if self.api_key:
request.add_header("Authorization", f"Bearer {self.api_key}")
try:
with urlopen(request, timeout=10) as response:
payload = json.loads(response.read().decode("utf-8"))
except (HTTPError, URLError, TimeoutError, OSError, json.JSONDecodeError):
return None
return {
"node_count": max(0, int(payload.get("node_count", 0) or 0)),
"edge_count": max(0, int(payload.get("edge_count", 0) or 0)),
"episode_count": max(0, int(payload.get("episode_count", 0) or 0)),
}

View File

@ -7,15 +7,14 @@ import os
import uuid
import time
import threading
import json
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from zep_cloud.client import Zep
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ..config import Config
from ..graph import get_graph_backend
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from .text_processor import TextProcessor
@ -43,11 +42,12 @@ class GraphBuilderService:
"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
if errors:
raise ValueError("; ".join(errors))
self.client = Zep(api_key=self.api_key)
self.backend = get_graph_backend(api_key=self.api_key)
self.task_manager = TaskManager()
def build_graph_async(
@ -57,7 +57,7 @@ class GraphBuilderService:
graph_name: str = "MiroFish Graph",
chunk_size: int = 500,
chunk_overlap: int = 50,
batch_size: int = 3
batch_size: int = 1
) -> str:
"""
异步构建图谱
@ -155,6 +155,7 @@ class GraphBuilderService:
)
self._wait_for_episodes(
graph_id,
episode_uuids,
lambda msg, prog: self.task_manager.update_task(
task_id,
@ -188,10 +189,10 @@ class GraphBuilderService:
"""创建Zep图谱公开方法"""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self.client.graph.create(
self.backend.create_graph(
graph_id=graph_id,
name=name,
description="MiroFish Social Simulation Graph"
description="MiroFish Social Simulation Graph",
)
return graph_id
@ -279,8 +280,8 @@ class GraphBuilderService:
# 调用Zep API设置本体
if entity_types or edge_definitions:
self.client.graph.set_ontology(
graph_ids=[graph_id],
self.backend.set_ontology(
graph_id=graph_id,
entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None,
)
@ -289,7 +290,7 @@ class GraphBuilderService:
self,
graph_id: str,
chunks: List[str],
batch_size: int = 3,
batch_size: int = 1,
progress_callback: Optional[Callable] = None
) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
@ -316,7 +317,7 @@ class GraphBuilderService:
# 发送到Zep
try:
batch_result = self.client.graph.add_batch(
batch_result = self.backend.add_batch(
graph_id=graph_id,
episodes=episodes
)
@ -337,70 +338,152 @@ class GraphBuilderService:
raise
return episode_uuids
def _get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
"""直接读取后端的实时图谱统计。"""
return self.backend.get_live_graph_statistics(graph_id)
def _wait_for_episodes(
self,
graph_id: str,
episode_uuids: List[str],
progress_callback: Optional[Callable] = None,
timeout: int = 600
):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
"""等待 OpenZep 处理完成,优先参考真实图谱状态。"""
if not episode_uuids:
if progress_callback:
progress_callback("无需等待(没有 episode", 1.0)
return
start_time = time.time()
pending_episodes = set(episode_uuids)
completed_count = 0
total_episodes = len(episode_uuids)
last_graph_signature: Optional[tuple[int, int, int]] = None
stable_graph_polls = 0
stable_graph_required = 2
last_live_stats: Optional[Dict[str, int]] = None
if progress_callback:
progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0)
while pending_episodes:
if time.time() - start_time > timeout:
elapsed_seconds = time.time() - start_time
if elapsed_seconds > timeout:
if last_live_stats is not None:
graph_episode_count = min(last_live_stats["episode_count"], total_episodes)
graph_node_count = last_live_stats["node_count"]
graph_edge_count = last_live_stats["edge_count"]
graph_entity_like_nodes = max(0, graph_node_count - last_live_stats["episode_count"])
if graph_episode_count >= total_episodes and (graph_entity_like_nodes > 0 or graph_edge_count > 0):
if progress_callback:
progress_callback(
(
f"OpenZep 接口进度未返回完成标记,但真实图谱已写入 "
f"episodes={graph_episode_count}/{total_episodes}, "
f"nodes={graph_node_count}, edges={graph_edge_count}"
),
1.0,
)
return
if progress_callback:
progress_callback(
f"部分文本块超时,已完成 {completed_count}/{total_episodes}",
completed_count / total_episodes
)
break
# 检查每个 episode 的处理状态
for ep_uuid in list(pending_episodes):
try:
episode = self.client.graph.episode.get(uuid_=ep_uuid)
episode = self.backend.get_episode(ep_uuid)
is_processed = getattr(episode, 'processed', False)
if is_processed:
pending_episodes.remove(ep_uuid)
completed_count += 1
except Exception as e:
# 忽略单个查询错误,继续
except Exception:
pass
elapsed = int(time.time() - start_time)
if progress_callback:
progress_callback(
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
completed_count / total_episodes if total_episodes > 0 else 0
live_stats = self._get_live_graph_statistics(graph_id)
graph_episode_count = 0
graph_node_count = 0
graph_edge_count = 0
graph_entity_like_nodes = 0
graph_progress = 0.0
if live_stats is not None:
last_live_stats = live_stats
graph_episode_count = min(live_stats["episode_count"], total_episodes)
graph_node_count = live_stats["node_count"]
graph_edge_count = live_stats["edge_count"]
graph_entity_like_nodes = max(0, graph_node_count - live_stats["episode_count"])
graph_progress = graph_episode_count / total_episodes if total_episodes > 0 else 1.0
graph_signature = (
graph_episode_count,
graph_entity_like_nodes,
graph_edge_count,
)
if graph_signature == last_graph_signature:
stable_graph_polls += 1
else:
last_graph_signature = graph_signature
stable_graph_polls = 0
graph_ready = (
graph_episode_count >= total_episodes
and (graph_entity_like_nodes > 0 or graph_edge_count > 0)
and stable_graph_polls >= stable_graph_required
)
if graph_ready:
if progress_callback:
progress_callback(
(
f"OpenZep 图谱已稳定: episodes={graph_episode_count}/{total_episodes}, "
f"nodes={graph_node_count}, edges={graph_edge_count}"
),
1.0,
)
return
elapsed = int(elapsed_seconds)
effective_progress = max(
completed_count / total_episodes if total_episodes > 0 else 1.0,
graph_progress,
)
if progress_callback:
if live_stats is not None:
progress_callback(
(
f"OpenZep处理中... 接口完成 {completed_count}/{total_episodes}, "
f"图中已写入 episodes={graph_episode_count}/{total_episodes}, "
f"nodes={graph_node_count}, edges={graph_edge_count} ({elapsed}秒)"
),
effective_progress,
)
else:
progress_callback(
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
completed_count / total_episodes if total_episodes > 0 else 0
)
if pending_episodes:
time.sleep(3) # 每3秒检查一次
time.sleep(3)
if progress_callback:
progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0)
def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息"""
# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id)
nodes = self.backend.get_all_nodes(graph_id)
# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id)
edges = self.backend.get_all_edges(graph_id)
# 统计实体类型
entity_types = set()
@ -427,8 +510,8 @@ class GraphBuilderService:
Returns:
包含nodes和edges的字典包括时间信息属性等详细数据
"""
nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id)
nodes = self.backend.get_all_nodes(graph_id)
edges = self.backend.get_all_edges(graph_id)
# 创建节点映射用于获取节点名称
node_map = {}
@ -496,5 +579,4 @@ class GraphBuilderService:
def delete_graph(self, graph_id: str):
"""删除图谱"""
self.client.graph.delete(graph_id=graph_id)
self.backend.delete_graph(graph_id)

View File

@ -16,9 +16,9 @@ from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -198,15 +198,15 @@ class OasisProfileGenerator:
)
# Zep客户端用于检索丰富上下文
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
self.zep_client = None
self.zep_api_key = Config.ZEP_API_KEY if zep_api_key is None else zep_api_key
self.zep_backend = None
self.graph_id = graph_id
if self.zep_api_key:
if Config.is_graph_backend_configured(api_key=self.zep_api_key):
try:
self.zep_client = Zep(api_key=self.zep_api_key)
self.zep_backend = get_graph_backend(api_key=self.zep_api_key)
except Exception as e:
logger.warning(f"Zep客户端初始化失败: {e}")
logger.warning(f"图谱客户端初始化失败: {e}")
def generate_profile_from_entity(
self,
@ -297,7 +297,7 @@ class OasisProfileGenerator:
"""
import concurrent.futures
if not self.zep_client:
if not self.zep_backend:
return {"facts": [], "node_summaries": [], "context": ""}
entity_name = entity.name
@ -323,7 +323,7 @@ class OasisProfileGenerator:
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
return self.zep_backend.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=30,
@ -348,7 +348,7 @@ class OasisProfileGenerator:
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
return self.zep_backend.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=20,
@ -1197,4 +1197,3 @@ class OasisProfileGenerator:
"""[已废弃] 请使用 save_profiles() 方法"""
logger.warning("save_profiles_to_json已废弃请使用save_profiles方法")
self.save_profiles(profiles, file_path, platform)

View File

@ -7,11 +7,9 @@ import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader')
@ -79,11 +77,12 @@ class ZepEntityReader:
"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
if errors:
raise ValueError("; ".join(errors))
self.client = Zep(api_key=self.api_key)
self.backend = get_graph_backend(api_key=self.api_key)
def _call_with_retry(
self,
@ -136,7 +135,7 @@ class ZepEntityReader:
"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = fetch_all_nodes(self.client, graph_id)
nodes = self.backend.get_all_nodes(graph_id)
nodes_data = []
for node in nodes:
@ -163,7 +162,7 @@ class ZepEntityReader:
"""
logger.info(f"获取图谱 {graph_id} 的所有边...")
edges = fetch_all_edges(self.client, graph_id)
edges = self.backend.get_all_edges(graph_id)
edges_data = []
for edge in edges:
@ -192,7 +191,7 @@ class ZepEntityReader:
try:
# 使用重试机制调用Zep API
edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
func=lambda: self.backend.get_node_edges(node_uuid),
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
)
@ -348,7 +347,7 @@ class ZepEntityReader:
try:
# 使用重试机制获取节点
node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
func=lambda: self.backend.get_node(entity_uuid),
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
)
@ -434,4 +433,3 @@ class ZepEntityReader:
)
return result.entities

View File

@ -12,9 +12,8 @@ from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Empty
from zep_cloud.client import Zep
from ..config import Config
from ..graph import get_graph_backend
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_graph_memory_updater')
@ -237,12 +236,13 @@ class ZepGraphMemoryUpdater:
api_key: Zep API Key可选默认从配置读取
"""
self.graph_id = graph_id
self.api_key = api_key or Config.ZEP_API_KEY
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
if not self.api_key:
raise ValueError("ZEP_API_KEY未配置")
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
if errors:
raise ValueError("; ".join(errors))
self.client = Zep(api_key=self.api_key)
self.backend = get_graph_backend(api_key=self.api_key)
# 活动队列
self._activity_queue: Queue = Queue()
@ -405,9 +405,8 @@ class ZepGraphMemoryUpdater:
# 带重试的发送
for attempt in range(self.MAX_RETRIES):
try:
self.client.graph.add(
self.backend.add_text(
graph_id=self.graph_id,
type="text",
data=combined_text
)

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,47 @@
"""
Embedding client wrapper for OpenAI-compatible embedding APIs.
"""
from __future__ import annotations
from typing import List, Optional
from openai import OpenAI
class EmbeddingClient:
"""Thin wrapper around OpenAI-compatible embedding endpoints."""
def __init__(
self,
api_key: Optional[str],
base_url: str,
model: str,
batch_size: int = 32,
):
if not base_url:
raise ValueError("Embedding base_url 未配置")
if not model:
raise ValueError("Embedding model 未配置")
self.api_key = api_key or 'ollama'
self.base_url = base_url
self.model = model
self.batch_size = max(1, int(batch_size))
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
def embed_texts(self, texts: List[str]) -> List[List[float]]:
"""Embed texts in batches while preserving input order."""
if not texts:
return []
embeddings: List[List[float]] = []
normalized_inputs = [str(text or ' ').strip() or ' ' for text in texts]
for start in range(0, len(normalized_inputs), self.batch_size):
batch = normalized_inputs[start:start + self.batch_size]
response = self.client.embeddings.create(model=self.model, input=batch)
data = sorted(response.data, key=lambda item: item.index)
embeddings.extend(item.embedding for item in data)
return embeddings

View File

@ -0,0 +1,175 @@
"""
Reranker client wrappers for common HTTP rerank APIs.
"""
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Dict, List, Optional
from urllib import error, request
@dataclass
class RerankerRequestSpec:
provider: str
path: str
body: dict
class RerankerClient:
"""Thin wrapper around HTTP rerank endpoints."""
def __init__(
self,
base_url: str,
model: Optional[str] = None,
api_key: Optional[str] = None,
provider: str = "auto",
timeout: float = 20.0,
):
if not base_url:
raise ValueError("Reranker base_url 未配置")
self.base_url = base_url.rstrip("/")
self.model = model
self.api_key = api_key or None
self.provider = (provider or "auto").strip().lower() or "auto"
self.timeout = max(1.0, float(timeout))
def rerank(self, query: str, documents: List[str]) -> Dict[int, float]:
"""Return index -> score for the supplied candidate documents."""
if not documents:
return {}
last_error: Optional[Exception] = None
for spec in self._build_request_specs(query, documents):
try:
return self._execute_request(spec)
except Exception as exc:
last_error = exc
if last_error is not None:
raise last_error
raise RuntimeError("没有可用的 reranker provider")
def _build_request_specs(self, query: str, documents: List[str]) -> List[RerankerRequestSpec]:
providers = [self.provider]
if self.provider == "auto":
providers = ["tei", "jina", "cohere", "vllm", "infinity"]
specs: List[RerankerRequestSpec] = []
for provider in providers:
if provider == "tei":
specs.append(
RerankerRequestSpec(
provider=provider,
path="/rerank",
body={
"query": query,
"texts": documents,
"truncate": True,
"raw_scores": False,
},
)
)
elif provider in {"jina", "vllm", "infinity"}:
body = {
"query": query,
"documents": documents,
"top_n": len(documents),
"return_documents": False,
}
if self.model:
body["model"] = self.model
specs.append(
RerankerRequestSpec(
provider=provider,
path="/v1/rerank",
body=body,
)
)
elif provider == "cohere":
body = {
"query": query,
"documents": documents,
"top_n": len(documents),
"return_documents": False,
}
if self.model:
body["model"] = self.model
specs.append(
RerankerRequestSpec(
provider=provider,
path="/v2/rerank",
body=body,
)
)
return specs
def _headers(self) -> dict:
headers = {"Content-Type": "application/json"}
if self.api_key:
headers["Authorization"] = f"Bearer {self.api_key}"
return headers
def _execute_request(self, spec: RerankerRequestSpec) -> Dict[int, float]:
payload = json.dumps(spec.body).encode("utf-8")
req = request.Request(
f"{self.base_url}{spec.path}",
data=payload,
headers=self._headers(),
method="POST",
)
try:
with request.urlopen(req, timeout=self.timeout) as resp:
body = resp.read().decode("utf-8")
except error.HTTPError as exc:
detail = exc.read().decode("utf-8", errors="replace")
raise RuntimeError(f"{spec.provider} rerank 请求失败: HTTP {exc.code}: {detail[:300]}") from exc
except error.URLError as exc:
raise RuntimeError(f"{spec.provider} rerank 请求失败: {exc.reason}") from exc
try:
data = json.loads(body)
except json.JSONDecodeError as exc:
raise RuntimeError(f"{spec.provider} rerank 返回非 JSON 响应") from exc
scores = self._parse_scores(spec.provider, data)
if not scores:
raise RuntimeError(f"{spec.provider} rerank 未返回有效分数")
return scores
def _parse_scores(self, provider: str, payload: object) -> Dict[int, float]:
if provider == "tei":
if not isinstance(payload, list):
raise RuntimeError("TEI rerank 响应格式异常")
scores: Dict[int, float] = {}
for item in payload:
if not isinstance(item, dict):
continue
index = item.get("index")
score = item.get("score")
if isinstance(index, int) and score is not None:
scores[index] = float(score)
return scores
if not isinstance(payload, dict):
raise RuntimeError("rerank 响应格式异常")
results = payload.get("results") or payload.get("data") or []
scores: Dict[int, float] = {}
if isinstance(results, list):
for item in results:
if not isinstance(item, dict):
continue
index = item.get("index")
score = item.get("relevance_score")
if score is None:
score = item.get("score")
if isinstance(index, int) and score is not None:
scores[index] = float(score)
return scores

View File

@ -14,10 +14,14 @@ dependencies = [
"flask-cors>=6.0.0",
# LLM 相关
"openai>=1.0.0",
"openai>=1.91.0",
# Zep Cloud
"zep-cloud==3.13.0",
# Graphiti / Neo4j
"graphiti-core==0.28.2",
"neo4j>=5.26.0",
# OASIS 社交媒体模拟
"camel-oasis==0.2.5",
@ -31,7 +35,7 @@ dependencies = [
# 工具库
"python-dotenv>=1.0.0",
"pydantic>=2.0.0",
"pydantic>=2.11.5",
]
[project.optional-dependencies]

View File

@ -11,11 +11,15 @@ flask-cors>=6.0.0
# ============= LLM 相关 =============
# OpenAI SDK统一使用 OpenAI 格式调用 LLM
openai>=1.0.0
openai>=1.91.0
# ============= Zep Cloud =============
zep-cloud==3.13.0
# ============= Graphiti / Neo4j =============
graphiti-core==0.28.2
neo4j>=5.26.0
# ============= OASIS 社交媒体模拟 =============
# OASIS 社交模拟框架
camel-oasis==0.2.5
@ -32,4 +36,4 @@ chardet>=5.0.0
python-dotenv>=1.0.0
# 数据验证
pydantic>=2.0.0
pydantic>=2.11.5

View File

@ -0,0 +1,223 @@
#!/usr/bin/env python3
"""
Service-level Graphiti smoke test for MiroFish.
The script creates a temporary graph, applies a small ontology, ingests a few
text chunks through GraphBuilderService, then queries through ZepToolsService.
It prints JSON output and deletes the graph by default.
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import traceback
import uuid
from pathlib import Path
from typing import Iterable
REPO_ROOT = Path(__file__).resolve().parents[2]
BACKEND_ROOT = REPO_ROOT / "backend"
DEFAULT_CHUNKS = [
"Alice works at Acme Robotics. Bob manages Alice.",
"Acme Robotics is located in Tokyo. Alice relocated to Tokyo for work.",
]
def load_env_file(path: Path) -> None:
"""Load a simple .env file without overriding existing variables."""
if not path.exists():
return
for raw_line in path.read_text(encoding="utf-8").splitlines():
line = raw_line.strip()
if not line or line.startswith("#") or "=" not in line:
continue
key, value = line.split("=", 1)
key = key.strip()
value = value.strip()
if value and value[0] == value[-1] and value[0] in {"'", '"'}:
value = value[1:-1]
os.environ.setdefault(key, value)
def prepare_environment() -> None:
"""Apply repo-local defaults that make Graphiti smoke testing predictable."""
load_env_file(REPO_ROOT / ".env")
load_env_file(BACKEND_ROOT / ".env")
os.environ.setdefault("GRAPH_BACKEND", "graphiti")
os.environ.setdefault("GRAPHITI_URI", "bolt://127.0.0.1:7687")
os.environ.setdefault("GRAPHITI_USER", "neo4j")
os.environ.setdefault("GRAPHITI_PASSWORD", "password123")
os.environ.setdefault("GRAPHITI_DATABASE", "neo4j")
os.environ.setdefault("GRAPHITI_LLM_BASE_URL", os.environ.get("LLM_BASE_URL", "http://127.0.0.1:18081/v1"))
os.environ.setdefault("GRAPHITI_LLM_MODEL", os.environ.get("LLM_MODEL_NAME", "gpt-5.4"))
os.environ.setdefault("GRAPHITI_LLM_CLIENT_MODE", "openai")
os.environ.setdefault("GRAPH_SEARCH_RERANKER", "rrf")
os.environ.setdefault("GRAPHITI_EMBEDDER_API_KEY", "ollama")
os.environ.setdefault("GRAPHITI_EMBEDDER_BASE_URL", "http://127.0.0.1:11434/v1")
os.environ.setdefault("GRAPHITI_EMBEDDER_MODEL", "qwen3-embedding:8b")
os.environ.setdefault("GRAPHITI_EMBEDDER_DIM", "1024")
def parse_args(argv: Iterable[str]) -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Run a Graphiti service-level smoke test.")
parser.add_argument(
"--query",
default="Where does Alice work?",
help="Search query executed through ZepToolsService.",
)
parser.add_argument(
"--chunk",
action="append",
dest="chunks",
help="Text chunk to ingest. Repeat to add multiple chunks.",
)
parser.add_argument(
"--graph-name",
default="graphiti smoke test",
help="Temporary graph name.",
)
parser.add_argument(
"--keep-graph",
action="store_true",
help="Keep the temporary graph instead of deleting it on exit.",
)
return parser.parse_args(list(argv))
def build_sample_ontology() -> dict:
return {
"entity_types": [
{"name": "Person", "description": "A human individual.", "attributes": []},
{"name": "Organization", "description": "A company or institution.", "attributes": []},
{"name": "Location", "description": "A place or city.", "attributes": []},
],
"edge_types": [
{
"name": "WORKS_AT",
"description": "A person works at an organization.",
"attributes": [],
"source_targets": [{"source": "Person", "target": "Organization"}],
},
{
"name": "MANAGES",
"description": "A person manages another person.",
"attributes": [],
"source_targets": [{"source": "Person", "target": "Person"}],
},
{
"name": "LOCATED_IN",
"description": "An organization is located in a place.",
"attributes": [],
"source_targets": [{"source": "Organization", "target": "Location"}],
},
],
}
def main(argv: Iterable[str]) -> int:
prepare_environment()
sys.path.insert(0, str(BACKEND_ROOT))
from app.config import Config
search_embedder = Config.get_graph_search_embedder_config()
search_reranker = Config.get_graph_search_reranker_config()
errors = Config.get_graph_backend_config_errors(api_key=Config.ZEP_API_KEY)
if errors:
print(
json.dumps(
{
"success": False,
"error": "Graph backend config is incomplete.",
"backend": Config.GRAPH_BACKEND,
"errors": errors,
},
ensure_ascii=False,
indent=2,
),
file=sys.stderr,
)
return 2
from app.services.graph_builder import GraphBuilderService
from app.services.zep_tools import ZepToolsService
args = parse_args(argv)
chunks = args.chunks or list(DEFAULT_CHUNKS)
builder = GraphBuilderService()
tools = ZepToolsService()
graph_id = builder.create_graph(f"{args.graph_name}-{uuid.uuid4().hex[:8]}")
try:
builder.set_ontology(graph_id, build_sample_ontology())
builder.add_text_batches(graph_id, chunks, batch_size=1)
result = tools.search_graph(
graph_id=graph_id,
query=args.query,
limit=5,
scope="edges",
)
stats = tools.get_graph_statistics(graph_id)
output = {
"success": True,
"graph_id": graph_id,
"kept_graph": args.keep_graph,
"backend": Config.GRAPH_BACKEND,
"llm_base_url": Config.GRAPHITI_LLM_BASE_URL or Config.LLM_BASE_URL,
"llm_model": Config.GRAPHITI_LLM_MODEL or Config.LLM_MODEL_NAME,
"embedder_base_url": Config.GRAPHITI_EMBEDDER_BASE_URL,
"embedder_model": Config.GRAPHITI_EMBEDDER_MODEL,
"app_reranker": Config.GRAPH_SEARCH_APP_RERANKER,
"app_embedder_base_url": search_embedder.get("base_url"),
"app_embedder_model": search_embedder.get("model"),
"app_reranker_base_url": search_reranker.get("base_url"),
"app_reranker_model": search_reranker.get("model"),
"app_reranker_provider": search_reranker.get("provider"),
"query": args.query,
"chunks": chunks,
"stats": stats,
"facts": result.facts,
"edges": result.edges,
"nodes": result.nodes,
"total_count": result.total_count,
}
print(json.dumps(output, ensure_ascii=False, indent=2))
return 0
except Exception as exc:
print(
json.dumps(
{
"success": False,
"graph_id": graph_id,
"error": str(exc),
"traceback": traceback.format_exc(),
},
ensure_ascii=False,
indent=2,
),
file=sys.stderr,
)
return 1
finally:
if not args.keep_graph:
try:
builder.backend.delete_graph(graph_id)
except Exception:
traceback.print_exc()
if __name__ == "__main__":
raise SystemExit(main(sys.argv[1:]))