feat(graph): add pluggable graph backend with Graphiti support
This commit is contained in:
parent
1536a79334
commit
25d43f8a4b
78
.env.example
78
.env.example
|
|
@ -1,16 +1,78 @@
|
|||
# LLM API配置(支持 OpenAI SDK 格式的任意 LLM API)
|
||||
# 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/
|
||||
# 注意消耗较大,可先进行小于40轮的模拟尝试
|
||||
# LLM API 配置(支持 OpenAI SDK 格式的任意 LLM API)
|
||||
# 可直接填写你自己的 LLM 接口,例如:
|
||||
# LLM_BASE_URL=http://127.0.0.1:18081/v1
|
||||
# LLM_MODEL_NAME=gpt-5.4
|
||||
LLM_API_KEY=your_api_key_here
|
||||
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
LLM_MODEL_NAME=qwen-plus
|
||||
|
||||
# ===== ZEP记忆图谱配置 =====
|
||||
# 每月免费额度即可支撑简单使用:https://app.getzep.com/
|
||||
ZEP_API_KEY=your_zep_api_key_here
|
||||
# Docker 容器内访问宿主机 LLM 时使用的地址
|
||||
# Linux + Docker Compose 下可保持默认 host.docker.internal
|
||||
DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
|
||||
# ===== Graphiti + Neo4j(默认推荐)=====
|
||||
GRAPH_BACKEND=graphiti
|
||||
GRAPHITI_URI=bolt://localhost:7687
|
||||
GRAPHITI_USER=neo4j
|
||||
GRAPHITI_PASSWORD=password123
|
||||
GRAPHITI_DATABASE=neo4j
|
||||
GRAPHITI_LLM_CLIENT_MODE=openai
|
||||
GRAPHITI_EMBEDDER_API_KEY=ollama
|
||||
GRAPHITI_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
GRAPHITI_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
GRAPHITI_EMBEDDER_DIM=1024
|
||||
GRAPH_SEARCH_RERANKER=rrf
|
||||
GRAPH_SEARCH_APP_RERANKER=embedding_rrf
|
||||
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT=2.0
|
||||
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES=true
|
||||
OLLAMA_PORT=11434
|
||||
OLLAMA_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
|
||||
# 可选:如需独立于 Graphiti / OpenZep 当前 embedder,可单独覆写:
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_API_KEY=ollama
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
# 可选:如果你另起了免费 cross-encoder / rerank 服务(TEI / Infinity / vLLM)
|
||||
# GRAPH_SEARCH_APP_RERANKER=api_rrf
|
||||
# GRAPH_SEARCH_APP_RERANKER_PROVIDER=tei
|
||||
# GRAPH_SEARCH_APP_RERANKER_BASE_URL=http://127.0.0.1:18090
|
||||
# GRAPH_SEARCH_APP_RERANKER_MODEL=your_reranker_model
|
||||
# GRAPH_SEARCH_APP_RERANKER_TIMEOUT=20
|
||||
# 免费召回增强:从高相关节点补抓相邻边,默认开启
|
||||
# GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT=2
|
||||
# GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT=8
|
||||
# GRAPHITI_ENABLE_CROSS_ENCODER=false
|
||||
|
||||
# Docker 中的 MiroFish 容器访问宿主机 LLM / 容器内 Ollama 时使用:
|
||||
DOCKER_GRAPHITI_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
DOCKER_GRAPHITI_EMBEDDER_BASE_URL=http://ollama:11434/v1
|
||||
|
||||
# ===== OpenZep / Zep(可选,非默认)=====
|
||||
# ZEP_API_KEY=your_zep_api_key_here
|
||||
# ZEP_MODE=openzep
|
||||
# 本地源码运行 MiroFish 时使用 localhost
|
||||
# ZEP_BASE_URL=http://localhost:8000/api/v2
|
||||
# Docker 中的 MiroFish 容器会自动改用 openzep 服务名
|
||||
# DOCKER_ZEP_BASE_URL=http://openzep:8000/api/v2
|
||||
# 留空表示不启用 OpenZep API 鉴权
|
||||
# OPENZEP_API_KEY=
|
||||
# OPENZEP_LLM_API_KEY=your_api_key_here
|
||||
# OPENZEP_LLM_BASE_URL=http://127.0.0.1:18081/v1
|
||||
# OPENZEP_DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
# OPENZEP_LLM_MODEL=gpt-5.4
|
||||
# OPENZEP_EMBEDDER_API_KEY=ollama
|
||||
# OPENZEP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
# OPENZEP_DOCKER_EMBEDDER_BASE_URL=http://ollama:11434/v1
|
||||
# OPENZEP_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
|
||||
# ===== Neo4j / OpenZep 端口 =====
|
||||
NEO4J_PASSWORD=password123
|
||||
NEO4J_HTTP_PORT=7474
|
||||
NEO4J_BOLT_PORT=7687
|
||||
OPENZEP_PORT=8000
|
||||
|
||||
# ===== 加速 LLM 配置(可选)=====
|
||||
# 注意如果不使用加速配置,env文件中就不要出现下面的配置项
|
||||
# 注意如果不使用加速配置,env 文件中就不要出现下面的配置项
|
||||
LLM_BOOST_API_KEY=your_api_key_here
|
||||
LLM_BOOST_BASE_URL=your_base_url_here
|
||||
LLM_BOOST_MODEL_NAME=your_model_name_here
|
||||
LLM_BOOST_MODEL_NAME=your_model_name_here
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/
|
|||
|
||||
WORKDIR /app
|
||||
|
||||
ARG INSTALL_GRAPHITI=true
|
||||
|
||||
# 先复制依赖描述文件以利用缓存
|
||||
COPY package.json package-lock.json ./
|
||||
COPY frontend/package.json frontend/package-lock.json ./frontend/
|
||||
|
|
@ -18,7 +20,10 @@ COPY backend/pyproject.toml backend/uv.lock ./backend/
|
|||
# 安装依赖(Node + Python)
|
||||
RUN npm ci \
|
||||
&& npm ci --prefix frontend \
|
||||
&& cd backend && uv sync --frozen
|
||||
&& cd backend && uv sync --frozen \
|
||||
&& if [ "$INSTALL_GRAPHITI" = "true" ]; then \
|
||||
uv pip install --python /app/backend/.venv/bin/python --no-cache-dir graphiti-core==0.28.2 "neo4j>=5.26.0"; \
|
||||
fi
|
||||
|
||||
# 复制项目源码
|
||||
COPY . .
|
||||
|
|
|
|||
|
|
@ -283,9 +283,7 @@ def build_graph():
|
|||
logger.info("=== 开始构建图谱 ===")
|
||||
|
||||
# 检查配置
|
||||
errors = []
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY未配置")
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
logger.error(f"配置错误: {errors}")
|
||||
return jsonify({
|
||||
|
|
@ -432,10 +430,12 @@ def build_graph():
|
|||
progress=15
|
||||
)
|
||||
|
||||
# OpenZep 本地链路在批量抽取时更容易卡在长时间的联合推理里。
|
||||
# 改为单块发送可以显著降低单次处理负载,牺牲吞吐换稳定性。
|
||||
episode_uuids = builder.add_text_batches(
|
||||
graph_id,
|
||||
graph_id,
|
||||
chunks,
|
||||
batch_size=3,
|
||||
batch_size=1 if Config.use_openzep() else 3,
|
||||
progress_callback=add_progress_callback
|
||||
)
|
||||
|
||||
|
|
@ -454,7 +454,7 @@ def build_graph():
|
|||
progress=progress
|
||||
)
|
||||
|
||||
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
||||
builder._wait_for_episodes(graph_id, episode_uuids, wait_progress_callback)
|
||||
|
||||
# 获取图谱数据
|
||||
task_manager.update_task(
|
||||
|
|
@ -464,12 +464,20 @@ def build_graph():
|
|||
)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
|
||||
# 如果图谱仍然是空的,说明 OpenZep 没有成功完成抽取。
|
||||
# 不能把这种情况标记为成功,否则前端会误以为构图完成。
|
||||
if node_count == 0 and edge_count == 0:
|
||||
raise RuntimeError(
|
||||
"图谱构建未产出任何节点或边;OpenZep 处理可能超时或未完成"
|
||||
)
|
||||
|
||||
# 更新项目状态
|
||||
project.status = ProjectStatus.GRAPH_COMPLETED
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
|
||||
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
|
||||
|
||||
# 完成
|
||||
|
|
@ -567,10 +575,11 @@ def get_graph_data(graph_id: str):
|
|||
获取图谱数据(节点和边)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY未配置"
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
|
@ -595,10 +604,11 @@ def delete_graph(graph_id: str):
|
|||
删除Zep图谱
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY未配置"
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
|
|
|||
|
|
@ -56,10 +56,11 @@ def get_graph_entities(graph_id: str):
|
|||
enrich: 是否获取相关边信息(默认true)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY未配置"
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
entity_types_str = request.args.get('entity_types', '')
|
||||
|
|
@ -93,10 +94,11 @@ def get_graph_entities(graph_id: str):
|
|||
def get_entity_detail(graph_id: str, entity_uuid: str):
|
||||
"""获取单个实体的详细信息"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY未配置"
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
reader = ZepEntityReader()
|
||||
|
|
@ -126,10 +128,11 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
|
|||
def get_entities_by_type(graph_id: str, entity_type: str):
|
||||
"""获取指定类型的所有实体"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": "ZEP_API_KEY未配置"
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||
|
|
|
|||
|
|
@ -33,7 +33,54 @@ class Config:
|
|||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||
|
||||
# Zep配置
|
||||
ZEP_MODE = os.environ.get('ZEP_MODE', 'cloud').lower()
|
||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||
ZEP_BASE_URL = os.environ.get('ZEP_BASE_URL')
|
||||
OPENZEP_EMBEDDER_API_KEY = os.environ.get('OPENZEP_EMBEDDER_API_KEY') or None
|
||||
OPENZEP_EMBEDDER_BASE_URL = os.environ.get('OPENZEP_EMBEDDER_BASE_URL') or None
|
||||
OPENZEP_EMBEDDER_MODEL = os.environ.get('OPENZEP_EMBEDDER_MODEL') or None
|
||||
|
||||
# 图后端配置
|
||||
GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'graphiti').lower()
|
||||
GRAPH_SEARCH_RERANKER = os.environ.get('GRAPH_SEARCH_RERANKER', 'rrf').strip() or None
|
||||
GRAPH_SEARCH_APP_RERANKER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER', 'embedding_rrf').strip().lower() or 'embedding_rrf')
|
||||
GRAPH_SEARCH_APP_RERANK_FUSION_K = max(1, int(os.environ.get('GRAPH_SEARCH_APP_RERANK_FUSION_K', '60')))
|
||||
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT = max(0.0, float(os.environ.get('GRAPH_SEARCH_APP_SEMANTIC_WEIGHT', '2.0')))
|
||||
GRAPH_SEARCH_APP_EMBEDDER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_API_KEY') or None
|
||||
GRAPH_SEARCH_APP_EMBEDDER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_BASE_URL') or None
|
||||
GRAPH_SEARCH_APP_EMBEDDER_MODEL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_MODEL') or None
|
||||
GRAPH_SEARCH_APP_EMBED_BATCH_SIZE = max(1, int(os.environ.get('GRAPH_SEARCH_APP_EMBED_BATCH_SIZE', '32')))
|
||||
GRAPH_SEARCH_APP_RERANKER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_RERANKER_API_KEY') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_BASE_URL') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_MODEL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_MODEL') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_PROVIDER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER_PROVIDER', 'auto').strip().lower() or 'auto')
|
||||
GRAPH_SEARCH_APP_RERANKER_TIMEOUT = max(1.0, float(os.environ.get('GRAPH_SEARCH_APP_RERANKER_TIMEOUT', '20')))
|
||||
GRAPH_SEARCH_INCLUDE_NODES = os.environ.get('GRAPH_SEARCH_INCLUDE_NODES', 'true').lower() == 'true'
|
||||
GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER', '2')))
|
||||
GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER', '1')))
|
||||
GRAPH_SEARCH_NODE_SUMMARY_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_SUMMARY_LIMIT', '5')))
|
||||
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES = os.environ.get('GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES', 'true').lower() == 'true'
|
||||
GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT = max(0, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT', '2')))
|
||||
GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT', '8')))
|
||||
GRAPHITI_URI = os.environ.get('GRAPHITI_URI')
|
||||
GRAPHITI_USER = os.environ.get('GRAPHITI_USER', 'neo4j')
|
||||
GRAPHITI_PASSWORD = os.environ.get('GRAPHITI_PASSWORD')
|
||||
GRAPHITI_DATABASE = os.environ.get('GRAPHITI_DATABASE', 'neo4j')
|
||||
GRAPHITI_LLM_API_KEY = os.environ.get('GRAPHITI_LLM_API_KEY') or LLM_API_KEY
|
||||
GRAPHITI_LLM_BASE_URL = os.environ.get('GRAPHITI_LLM_BASE_URL') or LLM_BASE_URL
|
||||
GRAPHITI_LLM_MODEL = os.environ.get('GRAPHITI_LLM_MODEL') or LLM_MODEL_NAME
|
||||
GRAPHITI_LLM_SMALL_MODEL = os.environ.get('GRAPHITI_LLM_SMALL_MODEL') or GRAPHITI_LLM_MODEL
|
||||
GRAPHITI_LLM_CLIENT_MODE = os.environ.get('GRAPHITI_LLM_CLIENT_MODE', 'openai').lower()
|
||||
GRAPHITI_LLM_MAX_TOKENS = max(1024, int(os.environ.get('GRAPHITI_LLM_MAX_TOKENS', '16384')))
|
||||
GRAPHITI_EMBEDDER_API_KEY = os.environ.get('GRAPHITI_EMBEDDER_API_KEY') or GRAPHITI_LLM_API_KEY
|
||||
GRAPHITI_EMBEDDER_BASE_URL = os.environ.get('GRAPHITI_EMBEDDER_BASE_URL') or GRAPHITI_LLM_BASE_URL
|
||||
GRAPHITI_EMBEDDER_MODEL = os.environ.get('GRAPHITI_EMBEDDER_MODEL', 'qwen3-embedding:8b')
|
||||
GRAPHITI_EMBEDDER_DIM = max(128, int(os.environ.get('GRAPHITI_EMBEDDER_DIM', '1024')))
|
||||
GRAPHITI_RERANKER_API_KEY = os.environ.get('GRAPHITI_RERANKER_API_KEY') or GRAPHITI_LLM_API_KEY
|
||||
GRAPHITI_RERANKER_BASE_URL = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or GRAPHITI_LLM_BASE_URL
|
||||
GRAPHITI_RERANKER_MODEL = os.environ.get('GRAPHITI_RERANKER_MODEL') or GRAPHITI_LLM_MODEL
|
||||
GRAPHITI_ENABLE_CROSS_ENCODER = os.environ.get('GRAPHITI_ENABLE_CROSS_ENCODER', 'false').lower() == 'true'
|
||||
GRAPHITI_MAX_COROUTINES = max(1, int(os.environ.get('GRAPHITI_MAX_COROUTINES', '20')))
|
||||
|
||||
# 文件上传配置
|
||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||
|
|
@ -62,6 +109,123 @@ class Config:
|
|||
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
||||
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
||||
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
||||
|
||||
@classmethod
|
||||
def use_openzep(cls):
|
||||
"""是否启用 OpenZep / 自定义 Zep endpoint。"""
|
||||
return cls.ZEP_MODE == 'openzep' or bool(cls.ZEP_BASE_URL)
|
||||
|
||||
@classmethod
|
||||
def is_zep_configured(cls, api_key=None):
|
||||
"""检查 Zep/OpenZep 是否已完成最小配置。"""
|
||||
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
|
||||
|
||||
if cls.use_openzep():
|
||||
return bool(cls.ZEP_BASE_URL)
|
||||
|
||||
return bool(resolved_api_key)
|
||||
|
||||
@classmethod
|
||||
def get_zep_client_kwargs(cls, api_key=None):
|
||||
"""生成 Zep 客户端初始化参数。"""
|
||||
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
|
||||
kwargs = {}
|
||||
|
||||
# OpenZep 场景允许关闭鉴权;此时不要传空字符串 api_key,
|
||||
# 否则 zep sdk 会构造出非法的 `Api-Key:` 请求头,httpx 会直接拒绝。
|
||||
if resolved_api_key:
|
||||
kwargs['api_key'] = resolved_api_key
|
||||
if cls.ZEP_BASE_URL:
|
||||
kwargs['base_url'] = cls.ZEP_BASE_URL
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def get_zep_config_errors(cls, api_key=None):
|
||||
"""返回 Zep/OpenZep 的配置错误。"""
|
||||
if cls.is_zep_configured(api_key=api_key):
|
||||
return []
|
||||
|
||||
if cls.use_openzep():
|
||||
return ["ZEP_BASE_URL 未配置"]
|
||||
|
||||
return ["ZEP_API_KEY 未配置"]
|
||||
|
||||
@classmethod
|
||||
def get_graph_search_embedder_config(cls):
|
||||
"""返回 app-side 图搜索语义重排使用的 embedding 配置。"""
|
||||
api_key = cls.GRAPH_SEARCH_APP_EMBEDDER_API_KEY
|
||||
base_url = cls.GRAPH_SEARCH_APP_EMBEDDER_BASE_URL
|
||||
model = cls.GRAPH_SEARCH_APP_EMBEDDER_MODEL
|
||||
|
||||
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
|
||||
if backend == 'graphiti':
|
||||
api_key = api_key or cls.GRAPHITI_EMBEDDER_API_KEY
|
||||
base_url = base_url or cls.GRAPHITI_EMBEDDER_BASE_URL
|
||||
model = model or cls.GRAPHITI_EMBEDDER_MODEL
|
||||
elif cls.use_openzep():
|
||||
api_key = api_key or cls.OPENZEP_EMBEDDER_API_KEY
|
||||
base_url = base_url or cls.OPENZEP_EMBEDDER_BASE_URL
|
||||
model = model or cls.OPENZEP_EMBEDDER_MODEL
|
||||
|
||||
if model and not api_key:
|
||||
api_key = 'ollama'
|
||||
|
||||
return {
|
||||
'api_key': api_key,
|
||||
'base_url': base_url,
|
||||
'model': model,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_graph_search_reranker_config(cls):
|
||||
"""返回 app-side 图搜索交叉编码重排使用的 reranker 配置。"""
|
||||
api_key = cls.GRAPH_SEARCH_APP_RERANKER_API_KEY
|
||||
base_url = cls.GRAPH_SEARCH_APP_RERANKER_BASE_URL
|
||||
model = cls.GRAPH_SEARCH_APP_RERANKER_MODEL
|
||||
provider = cls.GRAPH_SEARCH_APP_RERANKER_PROVIDER
|
||||
|
||||
graphiti_reranker_base_url = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or None
|
||||
if graphiti_reranker_base_url:
|
||||
api_key = api_key or (os.environ.get('GRAPHITI_RERANKER_API_KEY') or None)
|
||||
base_url = base_url or graphiti_reranker_base_url
|
||||
model = model or (os.environ.get('GRAPHITI_RERANKER_MODEL') or None)
|
||||
|
||||
return {
|
||||
'api_key': api_key,
|
||||
'base_url': base_url,
|
||||
'model': model,
|
||||
'provider': provider,
|
||||
'timeout': cls.GRAPH_SEARCH_APP_RERANKER_TIMEOUT,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_graphiti_config_errors(cls):
|
||||
"""返回 Graphiti + Neo4j 的配置错误。"""
|
||||
errors = []
|
||||
if not cls.GRAPHITI_URI:
|
||||
errors.append("GRAPHITI_URI 未配置")
|
||||
if not cls.GRAPHITI_DATABASE:
|
||||
errors.append("GRAPHITI_DATABASE 未配置")
|
||||
if not cls.GRAPHITI_LLM_MODEL:
|
||||
errors.append("GRAPHITI_LLM_MODEL 未配置")
|
||||
if not cls.GRAPHITI_EMBEDDER_MODEL:
|
||||
errors.append("GRAPHITI_EMBEDDER_MODEL 未配置")
|
||||
return errors
|
||||
|
||||
@classmethod
|
||||
def get_graph_backend_config_errors(cls, api_key=None):
|
||||
"""根据当前 GRAPH_BACKEND 返回对应的配置错误。"""
|
||||
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
|
||||
if backend in {'zep', 'openzep'}:
|
||||
return cls.get_zep_config_errors(api_key=api_key)
|
||||
if backend == 'graphiti':
|
||||
return cls.get_graphiti_config_errors()
|
||||
return [f"不支持的 GRAPH_BACKEND: {backend}"]
|
||||
|
||||
@classmethod
|
||||
def is_graph_backend_configured(cls, api_key=None):
|
||||
return len(cls.get_graph_backend_config_errors(api_key=api_key)) == 0
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
|
|
@ -69,7 +233,5 @@ class Config:
|
|||
errors = []
|
||||
if not cls.LLM_API_KEY:
|
||||
errors.append("LLM_API_KEY 未配置")
|
||||
if not cls.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY 未配置")
|
||||
errors.extend(cls.get_graph_backend_config_errors())
|
||||
return errors
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
Graph backend abstractions.
|
||||
"""
|
||||
|
||||
from .base import GraphBackend
|
||||
from .factory import get_graph_backend
|
||||
|
||||
__all__ = ["GraphBackend", "get_graph_backend"]
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
Common graph backend interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class GraphBackend(ABC):
|
||||
"""Minimal graph backend interface used by the application services."""
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Any:
|
||||
"""Return the underlying SDK client when a service still needs it."""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
"""Create a graph."""
|
||||
|
||||
@abstractmethod
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
"""Set graph ontology."""
|
||||
|
||||
@abstractmethod
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
"""Add a batch of episodes."""
|
||||
|
||||
@abstractmethod
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
"""Add a single text episode."""
|
||||
|
||||
@abstractmethod
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
"""Fetch a single episode."""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Search the graph."""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
"""Fetch all nodes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
"""Fetch all edges."""
|
||||
|
||||
@abstractmethod
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
"""Fetch a node by UUID."""
|
||||
|
||||
@abstractmethod
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
"""Fetch edges for a node."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
"""Delete a graph."""
|
||||
|
||||
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch backend ontology metadata when available."""
|
||||
return None
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
"""Fetch backend-specific live statistics when available."""
|
||||
return None
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""
|
||||
Graph backend factory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..config import Config
|
||||
from .base import GraphBackend
|
||||
|
||||
|
||||
def get_graph_backend(api_key: Optional[str] = None) -> GraphBackend:
|
||||
"""Create the configured graph backend."""
|
||||
backend = (Config.GRAPH_BACKEND or "graphiti").lower()
|
||||
|
||||
if backend in {"zep", "openzep"}:
|
||||
from .zep_backend import ZepGraphBackend
|
||||
|
||||
return ZepGraphBackend(api_key=api_key)
|
||||
|
||||
if backend == "graphiti":
|
||||
from .graphiti_backend import GraphitiBackend
|
||||
|
||||
return GraphitiBackend(api_key=api_key)
|
||||
|
||||
raise ValueError(f"不支持的 GRAPH_BACKEND: {backend}")
|
||||
|
|
@ -0,0 +1,875 @@
|
|||
"""
|
||||
Graphiti + Neo4j graph backend implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from ..config import Config
|
||||
from .base import GraphBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatEpisode:
|
||||
uuid: str
|
||||
processed: bool = True
|
||||
name: str = ""
|
||||
content: str = ""
|
||||
valid_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatNode:
|
||||
uuid: str
|
||||
name: str = ""
|
||||
labels: List[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
attributes: Dict[str, Any] = field(default_factory=dict)
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatEdge:
|
||||
uuid: str
|
||||
name: str
|
||||
fact: str
|
||||
source_node_uuid: str
|
||||
target_node_uuid: str
|
||||
source_node_name: str = ""
|
||||
target_node_name: str = ""
|
||||
attributes: Dict[str, Any] = field(default_factory=dict)
|
||||
episodes: List[str] = field(default_factory=list)
|
||||
created_at: Optional[datetime] = None
|
||||
valid_at: Optional[datetime] = None
|
||||
invalid_at: Optional[datetime] = None
|
||||
expired_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatSearchResults:
|
||||
edges: List[_CompatEdge] = field(default_factory=list)
|
||||
nodes: List[_CompatNode] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OntologyBundle:
|
||||
entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
|
||||
edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict)
|
||||
spec: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class _AsyncBridge:
|
||||
"""Run async Graphiti calls from the app's synchronous service layer."""
|
||||
|
||||
def __init__(self):
|
||||
self._ready = threading.Event()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._ready.wait()
|
||||
|
||||
def _run_loop(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._loop = loop
|
||||
self._ready.set()
|
||||
loop.run_forever()
|
||||
|
||||
def run(self, coro):
|
||||
if self._loop is None:
|
||||
raise RuntimeError("Graphiti async loop 未初始化")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result()
|
||||
|
||||
|
||||
class GraphitiBackend(GraphBackend):
|
||||
"""Graph backend backed by Graphiti OSS + Neo4j."""
|
||||
|
||||
_bridge: Optional[_AsyncBridge] = None
|
||||
_bridge_lock = threading.Lock()
|
||||
_indices_ready = False
|
||||
_indices_lock = threading.Lock()
|
||||
_ontology_registry: Dict[str, _OntologyBundle] = {}
|
||||
_ontology_lock = threading.Lock()
|
||||
_cross_encoder_warning_emitted = False
|
||||
PAGE_SIZE = 200
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
del api_key
|
||||
|
||||
errors = Config.get_graphiti_config_errors()
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
try:
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.llm_client import OpenAIClient
|
||||
from graphiti_core.llm_client.config import LLMConfig
|
||||
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j"
|
||||
) from exc
|
||||
|
||||
llm_config = LLMConfig(
|
||||
api_key=Config.GRAPHITI_LLM_API_KEY,
|
||||
base_url=Config.GRAPHITI_LLM_BASE_URL,
|
||||
model=Config.GRAPHITI_LLM_MODEL,
|
||||
small_model=Config.GRAPHITI_LLM_SMALL_MODEL,
|
||||
temperature=0,
|
||||
)
|
||||
reranker_config = LLMConfig(
|
||||
api_key=Config.GRAPHITI_RERANKER_API_KEY,
|
||||
base_url=Config.GRAPHITI_RERANKER_BASE_URL,
|
||||
model=Config.GRAPHITI_RERANKER_MODEL,
|
||||
temperature=0,
|
||||
)
|
||||
embedder_config = OpenAIEmbedderConfig(
|
||||
api_key=Config.GRAPHITI_EMBEDDER_API_KEY,
|
||||
base_url=Config.GRAPHITI_EMBEDDER_BASE_URL,
|
||||
embedding_model=Config.GRAPHITI_EMBEDDER_MODEL,
|
||||
embedding_dim=Config.GRAPHITI_EMBEDDER_DIM,
|
||||
)
|
||||
|
||||
llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower()
|
||||
if llm_client_mode == "generic":
|
||||
llm_client = OpenAIGenericClient(
|
||||
config=llm_config,
|
||||
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
|
||||
)
|
||||
else:
|
||||
llm_client = OpenAIClient(
|
||||
config=llm_config,
|
||||
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
|
||||
)
|
||||
|
||||
self._graphiti = Graphiti(
|
||||
uri=Config.GRAPHITI_URI,
|
||||
user=Config.GRAPHITI_USER,
|
||||
password=Config.GRAPHITI_PASSWORD,
|
||||
llm_client=llm_client,
|
||||
embedder=OpenAIEmbedder(config=embedder_config),
|
||||
cross_encoder=OpenAIRerankerClient(config=reranker_config),
|
||||
max_coroutines=Config.GRAPHITI_MAX_COROUTINES,
|
||||
)
|
||||
self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE)
|
||||
self._graphiti.driver = self._driver
|
||||
self._graphiti.clients.driver = self._driver
|
||||
self._bridge = self._get_bridge()
|
||||
|
||||
self._ensure_indices()
|
||||
|
||||
@classmethod
|
||||
def _get_bridge(cls) -> _AsyncBridge:
|
||||
with cls._bridge_lock:
|
||||
if cls._bridge is None:
|
||||
cls._bridge = _AsyncBridge()
|
||||
return cls._bridge
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Any:
|
||||
return self._graphiti
|
||||
|
||||
def _run(self, coro):
|
||||
return self._bridge.run(coro)
|
||||
|
||||
def _ensure_indices(self) -> None:
|
||||
if self.__class__._indices_ready:
|
||||
return
|
||||
|
||||
with self.__class__._indices_lock:
|
||||
if self.__class__._indices_ready:
|
||||
return
|
||||
self._run(self._graphiti.build_indices_and_constraints())
|
||||
self.__class__._indices_ready = True
|
||||
|
||||
def _validate_graph_id(self, graph_id: str) -> None:
|
||||
from graphiti_core.helpers import validate_group_id
|
||||
|
||||
validate_group_id(graph_id)
|
||||
|
||||
def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]:
|
||||
fields = []
|
||||
for field_name, model_field in model.model_fields.items():
|
||||
fields.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"description": model_field.description or field_name,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"description": (getattr(model, "__doc__", "") or "").strip(),
|
||||
"fields": fields,
|
||||
}
|
||||
|
||||
def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]:
|
||||
field_definitions = {}
|
||||
for field_spec in spec.get("fields", []):
|
||||
field_name = field_spec.get("name", "").strip()
|
||||
if not field_name:
|
||||
continue
|
||||
field_definitions[field_name] = (
|
||||
Optional[str],
|
||||
Field(
|
||||
default=None,
|
||||
description=field_spec.get("description") or field_name,
|
||||
),
|
||||
)
|
||||
|
||||
model = create_model(name, __base__=BaseModel, **field_definitions)
|
||||
model.__doc__ = spec.get("description") or name
|
||||
return model
|
||||
|
||||
def _serialize_ontology_spec(
|
||||
self,
|
||||
entity_specs: Dict[str, Dict[str, Any]],
|
||||
edge_specs: Dict[str, Dict[str, Any]],
|
||||
edge_type_map: Dict[tuple[str, str], List[str]],
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"entity_types": entity_specs,
|
||||
"edge_types": edge_specs,
|
||||
"edge_type_map": [
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"edges": edge_names,
|
||||
}
|
||||
for (source, target), edge_names in sorted(edge_type_map.items())
|
||||
],
|
||||
}
|
||||
|
||||
def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle:
|
||||
entity_types = {
|
||||
name: self._build_dynamic_model(name, model_spec)
|
||||
for name, model_spec in (spec.get("entity_types") or {}).items()
|
||||
}
|
||||
edge_types = {
|
||||
name: self._build_dynamic_model(name, model_spec)
|
||||
for name, model_spec in (spec.get("edge_types") or {}).items()
|
||||
}
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = {}
|
||||
for entry in spec.get("edge_type_map") or []:
|
||||
source = entry.get("source", "Entity")
|
||||
target = entry.get("target", "Entity")
|
||||
edge_type_map[(source, target)] = list(entry.get("edges") or [])
|
||||
|
||||
return _OntologyBundle(
|
||||
entity_types=entity_types,
|
||||
edge_types=edge_types,
|
||||
edge_type_map=edge_type_map,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle:
|
||||
entity_specs = {}
|
||||
entity_types = {}
|
||||
for entity_name, entity_model in (entities or {}).items():
|
||||
entity_spec = self._normalize_model_spec(entity_model)
|
||||
entity_specs[entity_name] = entity_spec
|
||||
entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec)
|
||||
|
||||
edge_specs = {}
|
||||
edge_types = {}
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = {}
|
||||
for edge_name, edge_value in (edges or {}).items():
|
||||
if not isinstance(edge_value, tuple) or len(edge_value) != 2:
|
||||
continue
|
||||
edge_model, source_targets = edge_value
|
||||
edge_spec = self._normalize_model_spec(edge_model)
|
||||
edge_specs[edge_name] = edge_spec
|
||||
edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec)
|
||||
|
||||
for source_target in source_targets or []:
|
||||
source = getattr(source_target, "source", "Entity") or "Entity"
|
||||
target = getattr(source_target, "target", "Entity") or "Entity"
|
||||
edge_type_map.setdefault((source, target), []).append(edge_name)
|
||||
|
||||
if not edge_type_map:
|
||||
edge_type_map = {("Entity", "Entity"): list(edge_types.keys())}
|
||||
|
||||
spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map)
|
||||
return _OntologyBundle(
|
||||
entity_types=entity_types,
|
||||
edge_types=edge_types,
|
||||
edge_type_map=edge_type_map,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
async def _upsert_graph_metadata_async(
|
||||
self,
|
||||
graph_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
ontology_spec: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None
|
||||
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MERGE (m:GraphMetadata {graph_id: $graph_id})
|
||||
ON CREATE SET
|
||||
m.group_id = $graph_id,
|
||||
m.created_at = datetime()
|
||||
SET
|
||||
m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END,
|
||||
m.description = CASE
|
||||
WHEN $description IS NULL OR $description = ''
|
||||
THEN coalesce(m.description, '')
|
||||
ELSE $description
|
||||
END,
|
||||
m.ontology_json = CASE
|
||||
WHEN $ontology_json IS NULL OR $ontology_json = ''
|
||||
THEN m.ontology_json
|
||||
ELSE $ontology_json
|
||||
END,
|
||||
m.updated_at = datetime()
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
ontology_json=payload,
|
||||
)
|
||||
|
||||
async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
|
||||
records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (m:GraphMetadata {graph_id: $graph_id})
|
||||
RETURN m.ontology_json AS ontology_json
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
|
||||
if not records:
|
||||
return _OntologyBundle()
|
||||
|
||||
ontology_json = records[0].get("ontology_json")
|
||||
if not ontology_json:
|
||||
return _OntologyBundle()
|
||||
|
||||
try:
|
||||
spec = json.loads(ontology_json)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
logger.warning("Graphiti ontology metadata 解析失败,graph_id=%s", graph_id)
|
||||
return _OntologyBundle()
|
||||
|
||||
return self._bundle_from_spec(spec)
|
||||
|
||||
async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
|
||||
with self.__class__._ontology_lock:
|
||||
cached = self.__class__._ontology_registry.get(graph_id)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
bundle = await self._load_ontology_bundle_async(graph_id)
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry[graph_id] = bundle
|
||||
return bundle
|
||||
|
||||
def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle:
|
||||
return self._run(self._get_ontology_bundle_async(graph_id))
|
||||
|
||||
def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None:
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry[graph_id] = bundle
|
||||
|
||||
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||
self._validate_graph_id(graph_id)
|
||||
bundle = self._get_ontology_bundle(graph_id)
|
||||
return dict(bundle.spec) if bundle.spec else None
|
||||
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
self._run(
|
||||
self._upsert_graph_metadata_async(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
bundle = self._build_ontology_bundle(entities=entities, edges=edges)
|
||||
self._set_ontology_bundle(graph_id, bundle)
|
||||
self._run(
|
||||
self._upsert_graph_metadata_async(
|
||||
graph_id=graph_id,
|
||||
ontology_spec=bundle.spec,
|
||||
)
|
||||
)
|
||||
|
||||
async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode:
|
||||
from graphiti_core.helpers import validate_excluded_entity_types
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
extract_attributes_from_nodes,
|
||||
extract_nodes,
|
||||
resolve_extracted_nodes,
|
||||
)
|
||||
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
||||
|
||||
bundle = await self._get_ontology_bundle_async(graph_id)
|
||||
entity_types = bundle.entity_types or None
|
||||
edge_types = bundle.edge_types or None
|
||||
edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []}
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
validate_excluded_entity_types(None, entity_types)
|
||||
|
||||
now = utc_now()
|
||||
previous_episodes = await self._graphiti.retrieve_episodes(
|
||||
reference_time=now,
|
||||
last_n=RELEVANT_SCHEMA_LIMIT,
|
||||
group_ids=[graph_id],
|
||||
source=EpisodeType.text,
|
||||
driver=self._driver,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}",
|
||||
group_id=graph_id,
|
||||
labels=[],
|
||||
source=EpisodeType.text,
|
||||
content=data,
|
||||
source_description="text",
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
)
|
||||
|
||||
extracted_nodes = await extract_nodes(
|
||||
self._graphiti.clients,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
||||
self._graphiti.clients,
|
||||
extracted_nodes,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
)
|
||||
resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges(
|
||||
episode,
|
||||
extracted_nodes,
|
||||
previous_episodes,
|
||||
edge_type_map,
|
||||
graph_id,
|
||||
edge_types,
|
||||
nodes,
|
||||
uuid_map,
|
||||
None,
|
||||
)
|
||||
entity_edges = resolved_edges + invalidated_edges
|
||||
hydrated_nodes = await extract_attributes_from_nodes(
|
||||
self._graphiti.clients,
|
||||
nodes,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
edges=new_edges,
|
||||
)
|
||||
_, saved_episode = await self._graphiti._process_episode_data(
|
||||
episode,
|
||||
hydrated_nodes,
|
||||
entity_edges,
|
||||
now,
|
||||
graph_id,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
return _CompatEpisode(
|
||||
uuid=saved_episode.uuid,
|
||||
processed=True,
|
||||
name=saved_episode.name,
|
||||
content=saved_episode.content,
|
||||
valid_at=saved_episode.valid_at,
|
||||
created_at=saved_episode.created_at,
|
||||
)
|
||||
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
results = []
|
||||
for episode in episodes:
|
||||
data = getattr(episode, "data", None)
|
||||
if data is None and isinstance(episode, dict):
|
||||
data = episode.get("data", "")
|
||||
results.append(self.add_text(graph_id=graph_id, data=str(data or "")))
|
||||
return results
|
||||
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._add_text_async(graph_id=graph_id, data=data))
|
||||
|
||||
async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode:
|
||||
from graphiti_core.nodes import EpisodicNode
|
||||
|
||||
episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid)
|
||||
return _CompatEpisode(
|
||||
uuid=episode.uuid,
|
||||
processed=True,
|
||||
name=episode.name,
|
||||
content=episode.content,
|
||||
valid_at=episode.valid_at,
|
||||
created_at=episode.created_at,
|
||||
)
|
||||
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
return self._run(self._get_episode_async(episode_uuid))
|
||||
|
||||
def _warn_cross_encoder_fallback(self) -> None:
|
||||
if self.__class__._cross_encoder_warning_emitted:
|
||||
return
|
||||
logger.info(
|
||||
"Graphiti cross_encoder 默认已降级为 rrf;如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true"
|
||||
)
|
||||
self.__class__._cross_encoder_warning_emitted = True
|
||||
|
||||
def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]):
|
||||
from graphiti_core.search.search_config import (
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
SearchConfig,
|
||||
)
|
||||
|
||||
reranker_name = (reranker or "rrf").strip().lower()
|
||||
edge_reranker_map = {
|
||||
"rrf": EdgeReranker.rrf,
|
||||
"reciprocal_rank_fusion": EdgeReranker.rrf,
|
||||
"cross_encoder": EdgeReranker.cross_encoder,
|
||||
"node_distance": EdgeReranker.node_distance,
|
||||
"episode_mentions": EdgeReranker.episode_mentions,
|
||||
"mmr": EdgeReranker.mmr,
|
||||
}
|
||||
node_reranker_map = {
|
||||
"rrf": NodeReranker.rrf,
|
||||
"reciprocal_rank_fusion": NodeReranker.rrf,
|
||||
"cross_encoder": NodeReranker.cross_encoder,
|
||||
"node_distance": NodeReranker.node_distance,
|
||||
"episode_mentions": NodeReranker.episode_mentions,
|
||||
"mmr": NodeReranker.mmr,
|
||||
}
|
||||
|
||||
edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf)
|
||||
node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf)
|
||||
edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity]
|
||||
node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity]
|
||||
|
||||
if reranker_name == "cross_encoder":
|
||||
if Config.GRAPHITI_ENABLE_CROSS_ENCODER:
|
||||
edge_methods.append(EdgeSearchMethod.bfs)
|
||||
node_methods.append(NodeSearchMethod.bfs)
|
||||
else:
|
||||
self._warn_cross_encoder_fallback()
|
||||
edge_reranker = EdgeReranker.rrf
|
||||
node_reranker = NodeReranker.rrf
|
||||
|
||||
edge_config = None
|
||||
node_config = None
|
||||
if scope in {"edges", "both"}:
|
||||
edge_config = EdgeSearchConfig(
|
||||
search_methods=edge_methods,
|
||||
reranker=edge_reranker,
|
||||
)
|
||||
if scope in {"nodes", "both"}:
|
||||
node_config = NodeSearchConfig(
|
||||
search_methods=node_methods,
|
||||
reranker=node_reranker,
|
||||
)
|
||||
|
||||
return SearchConfig(
|
||||
edge_config=edge_config,
|
||||
node_config=node_config,
|
||||
limit=max(1, limit),
|
||||
)
|
||||
|
||||
def _wrap_node(self, node: Any) -> _CompatNode:
|
||||
return _CompatNode(
|
||||
uuid=getattr(node, "uuid", ""),
|
||||
name=getattr(node, "name", "") or "",
|
||||
labels=list(getattr(node, "labels", []) or []),
|
||||
summary=getattr(node, "summary", "") or "",
|
||||
attributes=dict(getattr(node, "attributes", {}) or {}),
|
||||
created_at=getattr(node, "created_at", None),
|
||||
)
|
||||
|
||||
def _wrap_edge(
|
||||
self,
|
||||
edge: Any,
|
||||
source_node_name: str = "",
|
||||
target_node_name: str = "",
|
||||
) -> _CompatEdge:
|
||||
return _CompatEdge(
|
||||
uuid=getattr(edge, "uuid", ""),
|
||||
name=getattr(edge, "name", "") or "",
|
||||
fact=getattr(edge, "fact", "") or "",
|
||||
source_node_uuid=getattr(edge, "source_node_uuid", "") or "",
|
||||
target_node_uuid=getattr(edge, "target_node_uuid", "") or "",
|
||||
source_node_name=source_node_name,
|
||||
target_node_name=target_node_name,
|
||||
attributes=dict(getattr(edge, "attributes", {}) or {}),
|
||||
episodes=list(getattr(edge, "episodes", []) or []),
|
||||
created_at=getattr(edge, "created_at", None),
|
||||
valid_at=getattr(edge, "valid_at", None),
|
||||
invalid_at=getattr(edge, "invalid_at", None),
|
||||
expired_at=getattr(edge, "expired_at", None),
|
||||
)
|
||||
|
||||
async def _search_async(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
scope: str,
|
||||
reranker: Optional[str],
|
||||
) -> _CompatSearchResults:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker)
|
||||
results = await self._graphiti.search_(
|
||||
query=query,
|
||||
config=search_config,
|
||||
group_ids=[graph_id],
|
||||
driver=self._driver,
|
||||
)
|
||||
|
||||
nodes = [self._wrap_node(node) for node in results.nodes]
|
||||
node_name_map = {node.uuid: node.name for node in nodes if node.uuid}
|
||||
|
||||
missing_node_ids = {
|
||||
node_uuid
|
||||
for edge in results.edges
|
||||
for node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
|
||||
if node_uuid and node_uuid not in node_name_map
|
||||
}
|
||||
if missing_node_ids:
|
||||
for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)):
|
||||
node_name_map[node.uuid] = node.name or ""
|
||||
|
||||
edges = [
|
||||
self._wrap_edge(
|
||||
edge,
|
||||
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
|
||||
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
|
||||
)
|
||||
for edge in results.edges
|
||||
]
|
||||
|
||||
return _CompatSearchResults(edges=edges, nodes=nodes)
|
||||
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(
|
||||
self._search_async(
|
||||
graph_id=graph_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
scope=scope,
|
||||
reranker=reranker,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
result = []
|
||||
cursor = None
|
||||
while True:
|
||||
batch = await EntityNode.get_by_group_ids(
|
||||
self._driver,
|
||||
[graph_id],
|
||||
limit=self.PAGE_SIZE,
|
||||
uuid_cursor=cursor,
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
result.extend(self._wrap_node(node) for node in batch)
|
||||
if len(batch) < self.PAGE_SIZE:
|
||||
break
|
||||
cursor = batch[-1].uuid
|
||||
|
||||
return result
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_all_nodes_async(graph_id))
|
||||
|
||||
async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]:
|
||||
from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError
|
||||
|
||||
result = []
|
||||
cursor = None
|
||||
while True:
|
||||
try:
|
||||
batch = await EntityEdge.get_by_group_ids(
|
||||
self._driver,
|
||||
[graph_id],
|
||||
limit=self.PAGE_SIZE,
|
||||
uuid_cursor=cursor,
|
||||
)
|
||||
except GroupsEdgesNotFoundError:
|
||||
break
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
result.extend(self._wrap_edge(edge) for edge in batch)
|
||||
if len(batch) < self.PAGE_SIZE:
|
||||
break
|
||||
cursor = batch[-1].uuid
|
||||
|
||||
return result
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_all_edges_async(graph_id))
|
||||
|
||||
async def _get_node_async(self, node_uuid: str) -> _CompatNode:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid))
|
||||
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
return self._run(self._get_node_async(node_uuid))
|
||||
|
||||
async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]:
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid)
|
||||
related_node_ids = {
|
||||
related_uuid
|
||||
for edge in edges
|
||||
for related_uuid in (edge.source_node_uuid, edge.target_node_uuid)
|
||||
if related_uuid
|
||||
}
|
||||
node_name_map = {}
|
||||
if related_node_ids:
|
||||
for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)):
|
||||
node_name_map[node.uuid] = node.name or ""
|
||||
|
||||
return [
|
||||
self._wrap_edge(
|
||||
edge,
|
||||
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
|
||||
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
return self._run(self._get_node_edges_async(node_uuid))
|
||||
|
||||
async def _delete_graph_async(self, graph_id: str) -> None:
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (s:Saga {group_id: $graph_id})
|
||||
DETACH DELETE s
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
)
|
||||
await Node.delete_by_group_id(self._driver, graph_id)
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (m:GraphMetadata {graph_id: $graph_id})
|
||||
DETACH DELETE m
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
self._run(self._delete_graph_async(graph_id))
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry.pop(graph_id, None)
|
||||
|
||||
async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]:
|
||||
node_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {group_id: $graph_id})
|
||||
RETURN count(n) AS node_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
edge_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->()
|
||||
RETURN count(e) AS edge_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
episode_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic {group_id: $graph_id})
|
||||
RETURN count(n) AS episode_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
|
||||
return {
|
||||
"node_count": int((node_records[0] if node_records else {}).get("node_count", 0) or 0),
|
||||
"edge_count": int((edge_records[0] if edge_records else {}).get("edge_count", 0) or 0),
|
||||
"episode_count": int(
|
||||
(episode_records[0] if episode_records else {}).get("episode_count", 0) or 0
|
||||
),
|
||||
}
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_live_graph_statistics_async(graph_id))
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Zep / OpenZep graph backend implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes
|
||||
from .base import GraphBackend
|
||||
|
||||
|
||||
class ZepGraphBackend(GraphBackend):
|
||||
"""Graph backend backed by Zep Cloud or OpenZep."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_zep_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(**Config.get_zep_client_kwargs(api_key=self.api_key))
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Zep:
|
||||
return self.client
|
||||
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
self.client.graph.create(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
entities=entities,
|
||||
edges=edges,
|
||||
)
|
||||
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
return self.client.graph.add_batch(graph_id=graph_id, episodes=episodes)
|
||||
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
return self.client.graph.add(graph_id=graph_id, type="text", data=data)
|
||||
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
return self.client.graph.episode.get(uuid_=episode_uuid)
|
||||
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
kwargs = {
|
||||
"graph_id": graph_id,
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"scope": scope,
|
||||
}
|
||||
if reranker:
|
||||
kwargs["reranker"] = reranker
|
||||
return self.client.graph.search(**kwargs)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
return fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
return fetch_all_edges(self.client, graph_id)
|
||||
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
return self.client.graph.node.get(uuid_=node_uuid)
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
return self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
||||
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
if not Config.ZEP_BASE_URL:
|
||||
return None
|
||||
|
||||
base_url = Config.ZEP_BASE_URL.rstrip("/")
|
||||
request = Request(f"{base_url}/graph/{graph_id}/statistics")
|
||||
if self.api_key:
|
||||
request.add_header("Authorization", f"Bearer {self.api_key}")
|
||||
|
||||
try:
|
||||
with urlopen(request, timeout=10) as response:
|
||||
payload = json.loads(response.read().decode("utf-8"))
|
||||
except (HTTPError, URLError, TimeoutError, OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return {
|
||||
"node_count": max(0, int(payload.get("node_count", 0) or 0)),
|
||||
"edge_count": max(0, int(payload.get("edge_count", 0) or 0)),
|
||||
"episode_count": max(0, int(payload.get("episode_count", 0) or 0)),
|
||||
}
|
||||
|
|
@ -7,15 +7,14 @@ import os
|
|||
import uuid
|
||||
import time
|
||||
import threading
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
from .text_processor import TextProcessor
|
||||
|
||||
|
||||
|
|
@ -43,11 +42,12 @@ class GraphBuilderService:
|
|||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
self.task_manager = TaskManager()
|
||||
|
||||
def build_graph_async(
|
||||
|
|
@ -57,7 +57,7 @@ class GraphBuilderService:
|
|||
graph_name: str = "MiroFish Graph",
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 3
|
||||
batch_size: int = 1
|
||||
) -> str:
|
||||
"""
|
||||
异步构建图谱
|
||||
|
|
@ -155,6 +155,7 @@ class GraphBuilderService:
|
|||
)
|
||||
|
||||
self._wait_for_episodes(
|
||||
graph_id,
|
||||
episode_uuids,
|
||||
lambda msg, prog: self.task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -188,10 +189,10 @@ class GraphBuilderService:
|
|||
"""创建Zep图谱(公开方法)"""
|
||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
self.client.graph.create(
|
||||
self.backend.create_graph(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description="MiroFish Social Simulation Graph"
|
||||
description="MiroFish Social Simulation Graph",
|
||||
)
|
||||
|
||||
return graph_id
|
||||
|
|
@ -279,8 +280,8 @@ class GraphBuilderService:
|
|||
|
||||
# 调用Zep API设置本体
|
||||
if entity_types or edge_definitions:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
self.backend.set_ontology(
|
||||
graph_id=graph_id,
|
||||
entities=entity_types if entity_types else None,
|
||||
edges=edge_definitions if edge_definitions else None,
|
||||
)
|
||||
|
|
@ -289,7 +290,7 @@ class GraphBuilderService:
|
|||
self,
|
||||
graph_id: str,
|
||||
chunks: List[str],
|
||||
batch_size: int = 3,
|
||||
batch_size: int = 1,
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> List[str]:
|
||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
||||
|
|
@ -316,7 +317,7 @@ class GraphBuilderService:
|
|||
|
||||
# 发送到Zep
|
||||
try:
|
||||
batch_result = self.client.graph.add_batch(
|
||||
batch_result = self.backend.add_batch(
|
||||
graph_id=graph_id,
|
||||
episodes=episodes
|
||||
)
|
||||
|
|
@ -337,70 +338,152 @@ class GraphBuilderService:
|
|||
raise
|
||||
|
||||
return episode_uuids
|
||||
|
||||
|
||||
|
||||
|
||||
def _get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
"""直接读取后端的实时图谱统计。"""
|
||||
return self.backend.get_live_graph_statistics(graph_id)
|
||||
|
||||
def _wait_for_episodes(
|
||||
self,
|
||||
graph_id: str,
|
||||
episode_uuids: List[str],
|
||||
progress_callback: Optional[Callable] = None,
|
||||
timeout: int = 600
|
||||
):
|
||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
||||
"""等待 OpenZep 处理完成,优先参考真实图谱状态。"""
|
||||
if not episode_uuids:
|
||||
if progress_callback:
|
||||
progress_callback("无需等待(没有 episode)", 1.0)
|
||||
return
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
pending_episodes = set(episode_uuids)
|
||||
completed_count = 0
|
||||
total_episodes = len(episode_uuids)
|
||||
|
||||
last_graph_signature: Optional[tuple[int, int, int]] = None
|
||||
stable_graph_polls = 0
|
||||
stable_graph_required = 2
|
||||
last_live_stats: Optional[Dict[str, int]] = None
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0)
|
||||
|
||||
|
||||
while pending_episodes:
|
||||
if time.time() - start_time > timeout:
|
||||
elapsed_seconds = time.time() - start_time
|
||||
if elapsed_seconds > timeout:
|
||||
if last_live_stats is not None:
|
||||
graph_episode_count = min(last_live_stats["episode_count"], total_episodes)
|
||||
graph_node_count = last_live_stats["node_count"]
|
||||
graph_edge_count = last_live_stats["edge_count"]
|
||||
graph_entity_like_nodes = max(0, graph_node_count - last_live_stats["episode_count"])
|
||||
if graph_episode_count >= total_episodes and (graph_entity_like_nodes > 0 or graph_edge_count > 0):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep 接口进度未返回完成标记,但真实图谱已写入 "
|
||||
f"episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count}"
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
return
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"部分文本块超时,已完成 {completed_count}/{total_episodes}",
|
||||
completed_count / total_episodes
|
||||
)
|
||||
break
|
||||
|
||||
# 检查每个 episode 的处理状态
|
||||
|
||||
for ep_uuid in list(pending_episodes):
|
||||
try:
|
||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||||
episode = self.backend.get_episode(ep_uuid)
|
||||
is_processed = getattr(episode, 'processed', False)
|
||||
|
||||
|
||||
if is_processed:
|
||||
pending_episodes.remove(ep_uuid)
|
||||
completed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
# 忽略单个查询错误,继续
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
|
||||
live_stats = self._get_live_graph_statistics(graph_id)
|
||||
graph_episode_count = 0
|
||||
graph_node_count = 0
|
||||
graph_edge_count = 0
|
||||
graph_entity_like_nodes = 0
|
||||
graph_progress = 0.0
|
||||
|
||||
if live_stats is not None:
|
||||
last_live_stats = live_stats
|
||||
graph_episode_count = min(live_stats["episode_count"], total_episodes)
|
||||
graph_node_count = live_stats["node_count"]
|
||||
graph_edge_count = live_stats["edge_count"]
|
||||
graph_entity_like_nodes = max(0, graph_node_count - live_stats["episode_count"])
|
||||
graph_progress = graph_episode_count / total_episodes if total_episodes > 0 else 1.0
|
||||
|
||||
graph_signature = (
|
||||
graph_episode_count,
|
||||
graph_entity_like_nodes,
|
||||
graph_edge_count,
|
||||
)
|
||||
|
||||
if graph_signature == last_graph_signature:
|
||||
stable_graph_polls += 1
|
||||
else:
|
||||
last_graph_signature = graph_signature
|
||||
stable_graph_polls = 0
|
||||
|
||||
graph_ready = (
|
||||
graph_episode_count >= total_episodes
|
||||
and (graph_entity_like_nodes > 0 or graph_edge_count > 0)
|
||||
and stable_graph_polls >= stable_graph_required
|
||||
)
|
||||
if graph_ready:
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep 图谱已稳定: episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count}"
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
return
|
||||
|
||||
elapsed = int(elapsed_seconds)
|
||||
effective_progress = max(
|
||||
completed_count / total_episodes if total_episodes > 0 else 1.0,
|
||||
graph_progress,
|
||||
)
|
||||
if progress_callback:
|
||||
if live_stats is not None:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep处理中... 接口完成 {completed_count}/{total_episodes}, "
|
||||
f"图中已写入 episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count} ({elapsed}秒)"
|
||||
),
|
||||
effective_progress,
|
||||
)
|
||||
else:
|
||||
progress_callback(
|
||||
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
)
|
||||
|
||||
if pending_episodes:
|
||||
time.sleep(3) # 每3秒检查一次
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0)
|
||||
|
||||
|
||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||
"""获取图谱信息"""
|
||||
# 获取节点(分页)
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
|
||||
# 获取边(分页)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
# 统计实体类型
|
||||
entity_types = set()
|
||||
|
|
@ -427,8 +510,8 @@ class GraphBuilderService:
|
|||
Returns:
|
||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||||
"""
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
# 创建节点映射用于获取节点名称
|
||||
node_map = {}
|
||||
|
|
@ -496,5 +579,4 @@ class GraphBuilderService:
|
|||
|
||||
def delete_graph(self, graph_id: str):
|
||||
"""删除图谱"""
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
self.backend.delete_graph(graph_id)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ from dataclasses import dataclass, field
|
|||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
||||
|
|
@ -198,15 +198,15 @@ class OasisProfileGenerator:
|
|||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
|
||||
self.zep_client = None
|
||||
self.zep_api_key = Config.ZEP_API_KEY if zep_api_key is None else zep_api_key
|
||||
self.zep_backend = None
|
||||
self.graph_id = graph_id
|
||||
|
||||
if self.zep_api_key:
|
||||
if Config.is_graph_backend_configured(api_key=self.zep_api_key):
|
||||
try:
|
||||
self.zep_client = Zep(api_key=self.zep_api_key)
|
||||
self.zep_backend = get_graph_backend(api_key=self.zep_api_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Zep客户端初始化失败: {e}")
|
||||
logger.warning(f"图谱客户端初始化失败: {e}")
|
||||
|
||||
def generate_profile_from_entity(
|
||||
self,
|
||||
|
|
@ -297,7 +297,7 @@ class OasisProfileGenerator:
|
|||
"""
|
||||
import concurrent.futures
|
||||
|
||||
if not self.zep_client:
|
||||
if not self.zep_backend:
|
||||
return {"facts": [], "node_summaries": [], "context": ""}
|
||||
|
||||
entity_name = entity.name
|
||||
|
|
@ -323,7 +323,7 @@ class OasisProfileGenerator:
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
return self.zep_backend.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=30,
|
||||
|
|
@ -348,7 +348,7 @@ class OasisProfileGenerator:
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
return self.zep_backend.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=20,
|
||||
|
|
@ -1197,4 +1197,3 @@ class OasisProfileGenerator:
|
|||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@ import time
|
|||
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
|
|
@ -79,11 +77,12 @@ class ZepEntityReader:
|
|||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
|
||||
def _call_with_retry(
|
||||
self,
|
||||
|
|
@ -136,7 +135,7 @@ class ZepEntityReader:
|
|||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
|
|
@ -163,7 +162,7 @@ class ZepEntityReader:
|
|||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
|
|
@ -192,7 +191,7 @@ class ZepEntityReader:
|
|||
try:
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||
func=lambda: self.backend.get_node_edges(node_uuid),
|
||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
|
@ -348,7 +347,7 @@ class ZepEntityReader:
|
|||
try:
|
||||
# 使用重试机制获取节点
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||
func=lambda: self.backend.get_node(entity_uuid),
|
||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
|
@ -434,4 +433,3 @@ class ZepEntityReader:
|
|||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,8 @@ from dataclasses import dataclass
|
|||
from datetime import datetime
|
||||
from queue import Queue, Empty
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_graph_memory_updater')
|
||||
|
|
@ -237,12 +236,13 @@ class ZepGraphMemoryUpdater:
|
|||
api_key: Zep API Key(可选,默认从配置读取)
|
||||
"""
|
||||
self.graph_id = graph_id
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY未配置")
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
|
||||
# 活动队列
|
||||
self._activity_queue: Queue = Queue()
|
||||
|
|
@ -405,9 +405,8 @@ class ZepGraphMemoryUpdater:
|
|||
# 带重试的发送
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
self.client.graph.add(
|
||||
self.backend.add_text(
|
||||
graph_id=self.graph_id,
|
||||
type="text",
|
||||
data=combined_text
|
||||
)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Embedding client wrapper for OpenAI-compatible embedding APIs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class EmbeddingClient:
|
||||
"""Thin wrapper around OpenAI-compatible embedding endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
base_url: str,
|
||||
model: str,
|
||||
batch_size: int = 32,
|
||||
):
|
||||
if not base_url:
|
||||
raise ValueError("Embedding base_url 未配置")
|
||||
if not model:
|
||||
raise ValueError("Embedding model 未配置")
|
||||
|
||||
self.api_key = api_key or 'ollama'
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed texts in batches while preserving input order."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embeddings: List[List[float]] = []
|
||||
normalized_inputs = [str(text or ' ').strip() or ' ' for text in texts]
|
||||
|
||||
for start in range(0, len(normalized_inputs), self.batch_size):
|
||||
batch = normalized_inputs[start:start + self.batch_size]
|
||||
response = self.client.embeddings.create(model=self.model, input=batch)
|
||||
data = sorted(response.data, key=lambda item: item.index)
|
||||
embeddings.extend(item.embedding for item in data)
|
||||
|
||||
return embeddings
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Reranker client wrappers for common HTTP rerank APIs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from urllib import error, request
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerRequestSpec:
|
||||
provider: str
|
||||
path: str
|
||||
body: dict
|
||||
|
||||
|
||||
class RerankerClient:
|
||||
"""Thin wrapper around HTTP rerank endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
provider: str = "auto",
|
||||
timeout: float = 20.0,
|
||||
):
|
||||
if not base_url:
|
||||
raise ValueError("Reranker base_url 未配置")
|
||||
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.api_key = api_key or None
|
||||
self.provider = (provider or "auto").strip().lower() or "auto"
|
||||
self.timeout = max(1.0, float(timeout))
|
||||
|
||||
def rerank(self, query: str, documents: List[str]) -> Dict[int, float]:
|
||||
"""Return index -> score for the supplied candidate documents."""
|
||||
if not documents:
|
||||
return {}
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for spec in self._build_request_specs(query, documents):
|
||||
try:
|
||||
return self._execute_request(spec)
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("没有可用的 reranker provider")
|
||||
|
||||
def _build_request_specs(self, query: str, documents: List[str]) -> List[RerankerRequestSpec]:
|
||||
providers = [self.provider]
|
||||
if self.provider == "auto":
|
||||
providers = ["tei", "jina", "cohere", "vllm", "infinity"]
|
||||
|
||||
specs: List[RerankerRequestSpec] = []
|
||||
for provider in providers:
|
||||
if provider == "tei":
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/rerank",
|
||||
body={
|
||||
"query": query,
|
||||
"texts": documents,
|
||||
"truncate": True,
|
||||
"raw_scores": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
elif provider in {"jina", "vllm", "infinity"}:
|
||||
body = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
if self.model:
|
||||
body["model"] = self.model
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/v1/rerank",
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
elif provider == "cohere":
|
||||
body = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
if self.model:
|
||||
body["model"] = self.model
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/v2/rerank",
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
return specs
|
||||
|
||||
def _headers(self) -> dict:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
def _execute_request(self, spec: RerankerRequestSpec) -> Dict[int, float]:
|
||||
payload = json.dumps(spec.body).encode("utf-8")
|
||||
req = request.Request(
|
||||
f"{self.base_url}{spec.path}",
|
||||
data=payload,
|
||||
headers=self._headers(),
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with request.urlopen(req, timeout=self.timeout) as resp:
|
||||
body = resp.read().decode("utf-8")
|
||||
except error.HTTPError as exc:
|
||||
detail = exc.read().decode("utf-8", errors="replace")
|
||||
raise RuntimeError(f"{spec.provider} rerank 请求失败: HTTP {exc.code}: {detail[:300]}") from exc
|
||||
except error.URLError as exc:
|
||||
raise RuntimeError(f"{spec.provider} rerank 请求失败: {exc.reason}") from exc
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"{spec.provider} rerank 返回非 JSON 响应") from exc
|
||||
|
||||
scores = self._parse_scores(spec.provider, data)
|
||||
if not scores:
|
||||
raise RuntimeError(f"{spec.provider} rerank 未返回有效分数")
|
||||
return scores
|
||||
|
||||
def _parse_scores(self, provider: str, payload: object) -> Dict[int, float]:
|
||||
if provider == "tei":
|
||||
if not isinstance(payload, list):
|
||||
raise RuntimeError("TEI rerank 响应格式异常")
|
||||
|
||||
scores: Dict[int, float] = {}
|
||||
for item in payload:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
index = item.get("index")
|
||||
score = item.get("score")
|
||||
if isinstance(index, int) and score is not None:
|
||||
scores[index] = float(score)
|
||||
return scores
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("rerank 响应格式异常")
|
||||
|
||||
results = payload.get("results") or payload.get("data") or []
|
||||
scores: Dict[int, float] = {}
|
||||
if isinstance(results, list):
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
index = item.get("index")
|
||||
score = item.get("relevance_score")
|
||||
if score is None:
|
||||
score = item.get("score")
|
||||
if isinstance(index, int) and score is not None:
|
||||
scores[index] = float(score)
|
||||
return scores
|
||||
|
|
@ -14,10 +14,14 @@ dependencies = [
|
|||
"flask-cors>=6.0.0",
|
||||
|
||||
# LLM 相关
|
||||
"openai>=1.0.0",
|
||||
"openai>=1.91.0",
|
||||
|
||||
# Zep Cloud
|
||||
"zep-cloud==3.13.0",
|
||||
|
||||
# Graphiti / Neo4j
|
||||
"graphiti-core==0.28.2",
|
||||
"neo4j>=5.26.0",
|
||||
|
||||
# OASIS 社交媒体模拟
|
||||
"camel-oasis==0.2.5",
|
||||
|
|
@ -31,7 +35,7 @@ dependencies = [
|
|||
|
||||
# 工具库
|
||||
"python-dotenv>=1.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic>=2.11.5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -11,11 +11,15 @@ flask-cors>=6.0.0
|
|||
|
||||
# ============= LLM 相关 =============
|
||||
# OpenAI SDK(统一使用 OpenAI 格式调用 LLM)
|
||||
openai>=1.0.0
|
||||
openai>=1.91.0
|
||||
|
||||
# ============= Zep Cloud =============
|
||||
zep-cloud==3.13.0
|
||||
|
||||
# ============= Graphiti / Neo4j =============
|
||||
graphiti-core==0.28.2
|
||||
neo4j>=5.26.0
|
||||
|
||||
# ============= OASIS 社交媒体模拟 =============
|
||||
# OASIS 社交模拟框架
|
||||
camel-oasis==0.2.5
|
||||
|
|
@ -32,4 +36,4 @@ chardet>=5.0.0
|
|||
python-dotenv>=1.0.0
|
||||
|
||||
# 数据验证
|
||||
pydantic>=2.0.0
|
||||
pydantic>=2.11.5
|
||||
|
|
|
|||
|
|
@ -0,0 +1,223 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Service-level Graphiti smoke test for MiroFish.
|
||||
|
||||
The script creates a temporary graph, applies a small ontology, ingests a few
|
||||
text chunks through GraphBuilderService, then queries through ZepToolsService.
|
||||
It prints JSON output and deletes the graph by default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
BACKEND_ROOT = REPO_ROOT / "backend"
|
||||
|
||||
DEFAULT_CHUNKS = [
|
||||
"Alice works at Acme Robotics. Bob manages Alice.",
|
||||
"Acme Robotics is located in Tokyo. Alice relocated to Tokyo for work.",
|
||||
]
|
||||
|
||||
|
||||
def load_env_file(path: Path) -> None:
|
||||
"""Load a simple .env file without overriding existing variables."""
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
for raw_line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
if value and value[0] == value[-1] and value[0] in {"'", '"'}:
|
||||
value = value[1:-1]
|
||||
|
||||
os.environ.setdefault(key, value)
|
||||
|
||||
|
||||
def prepare_environment() -> None:
|
||||
"""Apply repo-local defaults that make Graphiti smoke testing predictable."""
|
||||
load_env_file(REPO_ROOT / ".env")
|
||||
load_env_file(BACKEND_ROOT / ".env")
|
||||
|
||||
os.environ.setdefault("GRAPH_BACKEND", "graphiti")
|
||||
os.environ.setdefault("GRAPHITI_URI", "bolt://127.0.0.1:7687")
|
||||
os.environ.setdefault("GRAPHITI_USER", "neo4j")
|
||||
os.environ.setdefault("GRAPHITI_PASSWORD", "password123")
|
||||
os.environ.setdefault("GRAPHITI_DATABASE", "neo4j")
|
||||
os.environ.setdefault("GRAPHITI_LLM_BASE_URL", os.environ.get("LLM_BASE_URL", "http://127.0.0.1:18081/v1"))
|
||||
os.environ.setdefault("GRAPHITI_LLM_MODEL", os.environ.get("LLM_MODEL_NAME", "gpt-5.4"))
|
||||
os.environ.setdefault("GRAPHITI_LLM_CLIENT_MODE", "openai")
|
||||
os.environ.setdefault("GRAPH_SEARCH_RERANKER", "rrf")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_API_KEY", "ollama")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_BASE_URL", "http://127.0.0.1:11434/v1")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_MODEL", "qwen3-embedding:8b")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_DIM", "1024")
|
||||
|
||||
|
||||
def parse_args(argv: Iterable[str]) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run a Graphiti service-level smoke test.")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
default="Where does Alice work?",
|
||||
help="Search query executed through ZepToolsService.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk",
|
||||
action="append",
|
||||
dest="chunks",
|
||||
help="Text chunk to ingest. Repeat to add multiple chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--graph-name",
|
||||
default="graphiti smoke test",
|
||||
help="Temporary graph name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-graph",
|
||||
action="store_true",
|
||||
help="Keep the temporary graph instead of deleting it on exit.",
|
||||
)
|
||||
return parser.parse_args(list(argv))
|
||||
|
||||
|
||||
def build_sample_ontology() -> dict:
|
||||
return {
|
||||
"entity_types": [
|
||||
{"name": "Person", "description": "A human individual.", "attributes": []},
|
||||
{"name": "Organization", "description": "A company or institution.", "attributes": []},
|
||||
{"name": "Location", "description": "A place or city.", "attributes": []},
|
||||
],
|
||||
"edge_types": [
|
||||
{
|
||||
"name": "WORKS_AT",
|
||||
"description": "A person works at an organization.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Person", "target": "Organization"}],
|
||||
},
|
||||
{
|
||||
"name": "MANAGES",
|
||||
"description": "A person manages another person.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Person", "target": "Person"}],
|
||||
},
|
||||
{
|
||||
"name": "LOCATED_IN",
|
||||
"description": "An organization is located in a place.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Organization", "target": "Location"}],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main(argv: Iterable[str]) -> int:
|
||||
prepare_environment()
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
|
||||
from app.config import Config
|
||||
|
||||
search_embedder = Config.get_graph_search_embedder_config()
|
||||
search_reranker = Config.get_graph_search_reranker_config()
|
||||
errors = Config.get_graph_backend_config_errors(api_key=Config.ZEP_API_KEY)
|
||||
if errors:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Graph backend config is incomplete.",
|
||||
"backend": Config.GRAPH_BACKEND,
|
||||
"errors": errors,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
from app.services.graph_builder import GraphBuilderService
|
||||
from app.services.zep_tools import ZepToolsService
|
||||
|
||||
args = parse_args(argv)
|
||||
chunks = args.chunks or list(DEFAULT_CHUNKS)
|
||||
|
||||
builder = GraphBuilderService()
|
||||
tools = ZepToolsService()
|
||||
graph_id = builder.create_graph(f"{args.graph_name}-{uuid.uuid4().hex[:8]}")
|
||||
|
||||
try:
|
||||
builder.set_ontology(graph_id, build_sample_ontology())
|
||||
builder.add_text_batches(graph_id, chunks, batch_size=1)
|
||||
|
||||
result = tools.search_graph(
|
||||
graph_id=graph_id,
|
||||
query=args.query,
|
||||
limit=5,
|
||||
scope="edges",
|
||||
)
|
||||
stats = tools.get_graph_statistics(graph_id)
|
||||
|
||||
output = {
|
||||
"success": True,
|
||||
"graph_id": graph_id,
|
||||
"kept_graph": args.keep_graph,
|
||||
"backend": Config.GRAPH_BACKEND,
|
||||
"llm_base_url": Config.GRAPHITI_LLM_BASE_URL or Config.LLM_BASE_URL,
|
||||
"llm_model": Config.GRAPHITI_LLM_MODEL or Config.LLM_MODEL_NAME,
|
||||
"embedder_base_url": Config.GRAPHITI_EMBEDDER_BASE_URL,
|
||||
"embedder_model": Config.GRAPHITI_EMBEDDER_MODEL,
|
||||
"app_reranker": Config.GRAPH_SEARCH_APP_RERANKER,
|
||||
"app_embedder_base_url": search_embedder.get("base_url"),
|
||||
"app_embedder_model": search_embedder.get("model"),
|
||||
"app_reranker_base_url": search_reranker.get("base_url"),
|
||||
"app_reranker_model": search_reranker.get("model"),
|
||||
"app_reranker_provider": search_reranker.get("provider"),
|
||||
"query": args.query,
|
||||
"chunks": chunks,
|
||||
"stats": stats,
|
||||
"facts": result.facts,
|
||||
"edges": result.edges,
|
||||
"nodes": result.nodes,
|
||||
"total_count": result.total_count,
|
||||
}
|
||||
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||
return 0
|
||||
except Exception as exc:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"graph_id": graph_id,
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
finally:
|
||||
if not args.keep_graph:
|
||||
try:
|
||||
builder.backend.delete_graph(graph_id)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(sys.argv[1:]))
|
||||
Loading…
Reference in New Issue