From 25d43f8a4bca10e5608ce7b1a2ad250d43bafea5 Mon Sep 17 00:00:00 2001 From: MiroFish Bot Date: Sat, 21 Mar 2026 17:25:34 +0900 Subject: [PATCH 1/2] feat(graph): add pluggable graph backend with Graphiti support --- .env.example | 78 +- Dockerfile | 7 +- backend/app/api/graph.py | 36 +- backend/app/api/simulation.py | 15 +- backend/app/config.py | 168 ++- backend/app/graph/__init__.py | 8 + backend/app/graph/base.py | 81 ++ backend/app/graph/factory.py | 27 + backend/app/graph/graphiti_backend.py | 875 +++++++++++++ backend/app/graph/zep_backend.py | 114 ++ backend/app/services/graph_builder.py | 168 ++- .../app/services/oasis_profile_generator.py | 19 +- backend/app/services/zep_entity_reader.py | 22 +- .../app/services/zep_graph_memory_updater.py | 15 +- backend/app/services/zep_tools.py | 1148 ++++++++++++----- backend/app/utils/embedding_client.py | 47 + backend/app/utils/reranker_client.py | 175 +++ backend/pyproject.toml | 8 +- backend/requirements.txt | 8 +- backend/scripts/graphiti_smoke_test.py | 223 ++++ 20 files changed, 2834 insertions(+), 408 deletions(-) create mode 100644 backend/app/graph/__init__.py create mode 100644 backend/app/graph/base.py create mode 100644 backend/app/graph/factory.py create mode 100644 backend/app/graph/graphiti_backend.py create mode 100644 backend/app/graph/zep_backend.py create mode 100644 backend/app/utils/embedding_client.py create mode 100644 backend/app/utils/reranker_client.py create mode 100644 backend/scripts/graphiti_smoke_test.py diff --git a/.env.example b/.env.example index 78a3b72c..17632aba 100644 --- a/.env.example +++ b/.env.example @@ -1,16 +1,78 @@ -# LLM API配置(支持 OpenAI SDK 格式的任意 LLM API) -# 推荐使用阿里百炼平台qwen-plus模型:https://bailian.console.aliyun.com/ -# 注意消耗较大,可先进行小于40轮的模拟尝试 +# LLM API 配置(支持 OpenAI SDK 格式的任意 LLM API) +# 可直接填写你自己的 LLM 接口,例如: +# LLM_BASE_URL=http://127.0.0.1:18081/v1 +# LLM_MODEL_NAME=gpt-5.4 LLM_API_KEY=your_api_key_here LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1 LLM_MODEL_NAME=qwen-plus -# ===== ZEP记忆图谱配置 ===== -# 每月免费额度即可支撑简单使用:https://app.getzep.com/ -ZEP_API_KEY=your_zep_api_key_here +# Docker 容器内访问宿主机 LLM 时使用的地址 +# Linux + Docker Compose 下可保持默认 host.docker.internal +DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1 + +# ===== Graphiti + Neo4j(默认推荐)===== +GRAPH_BACKEND=graphiti +GRAPHITI_URI=bolt://localhost:7687 +GRAPHITI_USER=neo4j +GRAPHITI_PASSWORD=password123 +GRAPHITI_DATABASE=neo4j +GRAPHITI_LLM_CLIENT_MODE=openai +GRAPHITI_EMBEDDER_API_KEY=ollama +GRAPHITI_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1 +GRAPHITI_EMBEDDER_MODEL=qwen3-embedding:8b +GRAPHITI_EMBEDDER_DIM=1024 +GRAPH_SEARCH_RERANKER=rrf +GRAPH_SEARCH_APP_RERANKER=embedding_rrf +GRAPH_SEARCH_APP_SEMANTIC_WEIGHT=2.0 +GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES=true +OLLAMA_PORT=11434 +OLLAMA_EMBEDDER_MODEL=qwen3-embedding:8b + +# 可选:如需独立于 Graphiti / OpenZep 当前 embedder,可单独覆写: +# GRAPH_SEARCH_APP_EMBEDDER_API_KEY=ollama +# GRAPH_SEARCH_APP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1 +# GRAPH_SEARCH_APP_EMBEDDER_MODEL=qwen3-embedding:8b +# 可选:如果你另起了免费 cross-encoder / rerank 服务(TEI / Infinity / vLLM) +# GRAPH_SEARCH_APP_RERANKER=api_rrf +# GRAPH_SEARCH_APP_RERANKER_PROVIDER=tei +# GRAPH_SEARCH_APP_RERANKER_BASE_URL=http://127.0.0.1:18090 +# GRAPH_SEARCH_APP_RERANKER_MODEL=your_reranker_model +# GRAPH_SEARCH_APP_RERANKER_TIMEOUT=20 +# 免费召回增强:从高相关节点补抓相邻边,默认开启 +# GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT=2 +# GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT=8 +# GRAPHITI_ENABLE_CROSS_ENCODER=false + +# Docker 中的 MiroFish 容器访问宿主机 LLM / 容器内 Ollama 时使用: +DOCKER_GRAPHITI_LLM_BASE_URL=http://host.docker.internal:18081/v1 +DOCKER_GRAPHITI_EMBEDDER_BASE_URL=http://ollama:11434/v1 + +# ===== OpenZep / Zep(可选,非默认)===== +# ZEP_API_KEY=your_zep_api_key_here +# ZEP_MODE=openzep +# 本地源码运行 MiroFish 时使用 localhost +# ZEP_BASE_URL=http://localhost:8000/api/v2 +# Docker 中的 MiroFish 容器会自动改用 openzep 服务名 +# DOCKER_ZEP_BASE_URL=http://openzep:8000/api/v2 +# 留空表示不启用 OpenZep API 鉴权 +# OPENZEP_API_KEY= +# OPENZEP_LLM_API_KEY=your_api_key_here +# OPENZEP_LLM_BASE_URL=http://127.0.0.1:18081/v1 +# OPENZEP_DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1 +# OPENZEP_LLM_MODEL=gpt-5.4 +# OPENZEP_EMBEDDER_API_KEY=ollama +# OPENZEP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1 +# OPENZEP_DOCKER_EMBEDDER_BASE_URL=http://ollama:11434/v1 +# OPENZEP_EMBEDDER_MODEL=qwen3-embedding:8b + +# ===== Neo4j / OpenZep 端口 ===== +NEO4J_PASSWORD=password123 +NEO4J_HTTP_PORT=7474 +NEO4J_BOLT_PORT=7687 +OPENZEP_PORT=8000 # ===== 加速 LLM 配置(可选)===== -# 注意如果不使用加速配置,env文件中就不要出现下面的配置项 +# 注意如果不使用加速配置,env 文件中就不要出现下面的配置项 LLM_BOOST_API_KEY=your_api_key_here LLM_BOOST_BASE_URL=your_base_url_here -LLM_BOOST_MODEL_NAME=your_model_name_here \ No newline at end of file +LLM_BOOST_MODEL_NAME=your_model_name_here diff --git a/Dockerfile b/Dockerfile index e6564686..ce2dd30d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,6 +10,8 @@ COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/ WORKDIR /app +ARG INSTALL_GRAPHITI=true + # 先复制依赖描述文件以利用缓存 COPY package.json package-lock.json ./ COPY frontend/package.json frontend/package-lock.json ./frontend/ @@ -18,7 +20,10 @@ COPY backend/pyproject.toml backend/uv.lock ./backend/ # 安装依赖(Node + Python) RUN npm ci \ && npm ci --prefix frontend \ - && cd backend && uv sync --frozen + && cd backend && uv sync --frozen \ + && if [ "$INSTALL_GRAPHITI" = "true" ]; then \ + uv pip install --python /app/backend/.venv/bin/python --no-cache-dir graphiti-core==0.28.2 "neo4j>=5.26.0"; \ + fi # 复制项目源码 COPY . . diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 12ff1ba2..d26e7d1a 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -283,9 +283,7 @@ def build_graph(): logger.info("=== 开始构建图谱 ===") # 检查配置 - errors = [] - if not Config.ZEP_API_KEY: - errors.append("ZEP_API_KEY未配置") + errors = Config.get_graph_backend_config_errors() if errors: logger.error(f"配置错误: {errors}") return jsonify({ @@ -432,10 +430,12 @@ def build_graph(): progress=15 ) + # OpenZep 本地链路在批量抽取时更容易卡在长时间的联合推理里。 + # 改为单块发送可以显著降低单次处理负载,牺牲吞吐换稳定性。 episode_uuids = builder.add_text_batches( - graph_id, + graph_id, chunks, - batch_size=3, + batch_size=1 if Config.use_openzep() else 3, progress_callback=add_progress_callback ) @@ -454,7 +454,7 @@ def build_graph(): progress=progress ) - builder._wait_for_episodes(episode_uuids, wait_progress_callback) + builder._wait_for_episodes(graph_id, episode_uuids, wait_progress_callback) # 获取图谱数据 task_manager.update_task( @@ -464,12 +464,20 @@ def build_graph(): ) graph_data = builder.get_graph_data(graph_id) + node_count = graph_data.get("node_count", 0) + edge_count = graph_data.get("edge_count", 0) + + # 如果图谱仍然是空的,说明 OpenZep 没有成功完成抽取。 + # 不能把这种情况标记为成功,否则前端会误以为构图完成。 + if node_count == 0 and edge_count == 0: + raise RuntimeError( + "图谱构建未产出任何节点或边;OpenZep 处理可能超时或未完成" + ) + # 更新项目状态 project.status = ProjectStatus.GRAPH_COMPLETED ProjectManager.save_project(project) - - node_count = graph_data.get("node_count", 0) - edge_count = graph_data.get("edge_count", 0) + build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") # 完成 @@ -567,10 +575,11 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: - if not Config.ZEP_API_KEY: + errors = Config.get_graph_backend_config_errors() + if errors: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "; ".join(errors) }), 500 builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) @@ -595,10 +604,11 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: + errors = Config.get_graph_backend_config_errors() + if errors: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "; ".join(errors) }), 500 builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a0f6816..53b2bb7b 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -56,10 +56,11 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: + errors = Config.get_graph_backend_config_errors() + if errors: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "; ".join(errors) }), 500 entity_types_str = request.args.get('entity_types', '') @@ -93,10 +94,11 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: + errors = Config.get_graph_backend_config_errors() + if errors: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "; ".join(errors) }), 500 reader = ZepEntityReader() @@ -126,10 +128,11 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: + errors = Config.get_graph_backend_config_errors() + if errors: return jsonify({ "success": False, - "error": "ZEP_API_KEY未配置" + "error": "; ".join(errors) }), 500 enrich = request.args.get('enrich', 'true').lower() == 'true' diff --git a/backend/app/config.py b/backend/app/config.py index 953dfa50..967c862e 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -33,7 +33,54 @@ class Config: LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') # Zep配置 + ZEP_MODE = os.environ.get('ZEP_MODE', 'cloud').lower() ZEP_API_KEY = os.environ.get('ZEP_API_KEY') + ZEP_BASE_URL = os.environ.get('ZEP_BASE_URL') + OPENZEP_EMBEDDER_API_KEY = os.environ.get('OPENZEP_EMBEDDER_API_KEY') or None + OPENZEP_EMBEDDER_BASE_URL = os.environ.get('OPENZEP_EMBEDDER_BASE_URL') or None + OPENZEP_EMBEDDER_MODEL = os.environ.get('OPENZEP_EMBEDDER_MODEL') or None + + # 图后端配置 + GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'graphiti').lower() + GRAPH_SEARCH_RERANKER = os.environ.get('GRAPH_SEARCH_RERANKER', 'rrf').strip() or None + GRAPH_SEARCH_APP_RERANKER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER', 'embedding_rrf').strip().lower() or 'embedding_rrf') + GRAPH_SEARCH_APP_RERANK_FUSION_K = max(1, int(os.environ.get('GRAPH_SEARCH_APP_RERANK_FUSION_K', '60'))) + GRAPH_SEARCH_APP_SEMANTIC_WEIGHT = max(0.0, float(os.environ.get('GRAPH_SEARCH_APP_SEMANTIC_WEIGHT', '2.0'))) + GRAPH_SEARCH_APP_EMBEDDER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_API_KEY') or None + GRAPH_SEARCH_APP_EMBEDDER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_BASE_URL') or None + GRAPH_SEARCH_APP_EMBEDDER_MODEL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_MODEL') or None + GRAPH_SEARCH_APP_EMBED_BATCH_SIZE = max(1, int(os.environ.get('GRAPH_SEARCH_APP_EMBED_BATCH_SIZE', '32'))) + GRAPH_SEARCH_APP_RERANKER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_RERANKER_API_KEY') or None + GRAPH_SEARCH_APP_RERANKER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_BASE_URL') or None + GRAPH_SEARCH_APP_RERANKER_MODEL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_MODEL') or None + GRAPH_SEARCH_APP_RERANKER_PROVIDER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER_PROVIDER', 'auto').strip().lower() or 'auto') + GRAPH_SEARCH_APP_RERANKER_TIMEOUT = max(1.0, float(os.environ.get('GRAPH_SEARCH_APP_RERANKER_TIMEOUT', '20'))) + GRAPH_SEARCH_INCLUDE_NODES = os.environ.get('GRAPH_SEARCH_INCLUDE_NODES', 'true').lower() == 'true' + GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER', '2'))) + GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER', '1'))) + GRAPH_SEARCH_NODE_SUMMARY_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_SUMMARY_LIMIT', '5'))) + GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES = os.environ.get('GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES', 'true').lower() == 'true' + GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT = max(0, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT', '2'))) + GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT', '8'))) + GRAPHITI_URI = os.environ.get('GRAPHITI_URI') + GRAPHITI_USER = os.environ.get('GRAPHITI_USER', 'neo4j') + GRAPHITI_PASSWORD = os.environ.get('GRAPHITI_PASSWORD') + GRAPHITI_DATABASE = os.environ.get('GRAPHITI_DATABASE', 'neo4j') + GRAPHITI_LLM_API_KEY = os.environ.get('GRAPHITI_LLM_API_KEY') or LLM_API_KEY + GRAPHITI_LLM_BASE_URL = os.environ.get('GRAPHITI_LLM_BASE_URL') or LLM_BASE_URL + GRAPHITI_LLM_MODEL = os.environ.get('GRAPHITI_LLM_MODEL') or LLM_MODEL_NAME + GRAPHITI_LLM_SMALL_MODEL = os.environ.get('GRAPHITI_LLM_SMALL_MODEL') or GRAPHITI_LLM_MODEL + GRAPHITI_LLM_CLIENT_MODE = os.environ.get('GRAPHITI_LLM_CLIENT_MODE', 'openai').lower() + GRAPHITI_LLM_MAX_TOKENS = max(1024, int(os.environ.get('GRAPHITI_LLM_MAX_TOKENS', '16384'))) + GRAPHITI_EMBEDDER_API_KEY = os.environ.get('GRAPHITI_EMBEDDER_API_KEY') or GRAPHITI_LLM_API_KEY + GRAPHITI_EMBEDDER_BASE_URL = os.environ.get('GRAPHITI_EMBEDDER_BASE_URL') or GRAPHITI_LLM_BASE_URL + GRAPHITI_EMBEDDER_MODEL = os.environ.get('GRAPHITI_EMBEDDER_MODEL', 'qwen3-embedding:8b') + GRAPHITI_EMBEDDER_DIM = max(128, int(os.environ.get('GRAPHITI_EMBEDDER_DIM', '1024'))) + GRAPHITI_RERANKER_API_KEY = os.environ.get('GRAPHITI_RERANKER_API_KEY') or GRAPHITI_LLM_API_KEY + GRAPHITI_RERANKER_BASE_URL = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or GRAPHITI_LLM_BASE_URL + GRAPHITI_RERANKER_MODEL = os.environ.get('GRAPHITI_RERANKER_MODEL') or GRAPHITI_LLM_MODEL + GRAPHITI_ENABLE_CROSS_ENCODER = os.environ.get('GRAPHITI_ENABLE_CROSS_ENCODER', 'false').lower() == 'true' + GRAPHITI_MAX_COROUTINES = max(1, int(os.environ.get('GRAPHITI_MAX_COROUTINES', '20'))) # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB @@ -62,6 +109,123 @@ class Config: REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5')) REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) + + @classmethod + def use_openzep(cls): + """是否启用 OpenZep / 自定义 Zep endpoint。""" + return cls.ZEP_MODE == 'openzep' or bool(cls.ZEP_BASE_URL) + + @classmethod + def is_zep_configured(cls, api_key=None): + """检查 Zep/OpenZep 是否已完成最小配置。""" + resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key + + if cls.use_openzep(): + return bool(cls.ZEP_BASE_URL) + + return bool(resolved_api_key) + + @classmethod + def get_zep_client_kwargs(cls, api_key=None): + """生成 Zep 客户端初始化参数。""" + resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key + kwargs = {} + + # OpenZep 场景允许关闭鉴权;此时不要传空字符串 api_key, + # 否则 zep sdk 会构造出非法的 `Api-Key:` 请求头,httpx 会直接拒绝。 + if resolved_api_key: + kwargs['api_key'] = resolved_api_key + if cls.ZEP_BASE_URL: + kwargs['base_url'] = cls.ZEP_BASE_URL + + return kwargs + + @classmethod + def get_zep_config_errors(cls, api_key=None): + """返回 Zep/OpenZep 的配置错误。""" + if cls.is_zep_configured(api_key=api_key): + return [] + + if cls.use_openzep(): + return ["ZEP_BASE_URL 未配置"] + + return ["ZEP_API_KEY 未配置"] + + @classmethod + def get_graph_search_embedder_config(cls): + """返回 app-side 图搜索语义重排使用的 embedding 配置。""" + api_key = cls.GRAPH_SEARCH_APP_EMBEDDER_API_KEY + base_url = cls.GRAPH_SEARCH_APP_EMBEDDER_BASE_URL + model = cls.GRAPH_SEARCH_APP_EMBEDDER_MODEL + + backend = (cls.GRAPH_BACKEND or 'graphiti').lower() + if backend == 'graphiti': + api_key = api_key or cls.GRAPHITI_EMBEDDER_API_KEY + base_url = base_url or cls.GRAPHITI_EMBEDDER_BASE_URL + model = model or cls.GRAPHITI_EMBEDDER_MODEL + elif cls.use_openzep(): + api_key = api_key or cls.OPENZEP_EMBEDDER_API_KEY + base_url = base_url or cls.OPENZEP_EMBEDDER_BASE_URL + model = model or cls.OPENZEP_EMBEDDER_MODEL + + if model and not api_key: + api_key = 'ollama' + + return { + 'api_key': api_key, + 'base_url': base_url, + 'model': model, + } + + @classmethod + def get_graph_search_reranker_config(cls): + """返回 app-side 图搜索交叉编码重排使用的 reranker 配置。""" + api_key = cls.GRAPH_SEARCH_APP_RERANKER_API_KEY + base_url = cls.GRAPH_SEARCH_APP_RERANKER_BASE_URL + model = cls.GRAPH_SEARCH_APP_RERANKER_MODEL + provider = cls.GRAPH_SEARCH_APP_RERANKER_PROVIDER + + graphiti_reranker_base_url = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or None + if graphiti_reranker_base_url: + api_key = api_key or (os.environ.get('GRAPHITI_RERANKER_API_KEY') or None) + base_url = base_url or graphiti_reranker_base_url + model = model or (os.environ.get('GRAPHITI_RERANKER_MODEL') or None) + + return { + 'api_key': api_key, + 'base_url': base_url, + 'model': model, + 'provider': provider, + 'timeout': cls.GRAPH_SEARCH_APP_RERANKER_TIMEOUT, + } + + @classmethod + def get_graphiti_config_errors(cls): + """返回 Graphiti + Neo4j 的配置错误。""" + errors = [] + if not cls.GRAPHITI_URI: + errors.append("GRAPHITI_URI 未配置") + if not cls.GRAPHITI_DATABASE: + errors.append("GRAPHITI_DATABASE 未配置") + if not cls.GRAPHITI_LLM_MODEL: + errors.append("GRAPHITI_LLM_MODEL 未配置") + if not cls.GRAPHITI_EMBEDDER_MODEL: + errors.append("GRAPHITI_EMBEDDER_MODEL 未配置") + return errors + + @classmethod + def get_graph_backend_config_errors(cls, api_key=None): + """根据当前 GRAPH_BACKEND 返回对应的配置错误。""" + backend = (cls.GRAPH_BACKEND or 'graphiti').lower() + if backend in {'zep', 'openzep'}: + return cls.get_zep_config_errors(api_key=api_key) + if backend == 'graphiti': + return cls.get_graphiti_config_errors() + return [f"不支持的 GRAPH_BACKEND: {backend}"] + + @classmethod + def is_graph_backend_configured(cls, api_key=None): + return len(cls.get_graph_backend_config_errors(api_key=api_key)) == 0 @classmethod def validate(cls): @@ -69,7 +233,5 @@ class Config: errors = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + errors.extend(cls.get_graph_backend_config_errors()) return errors - diff --git a/backend/app/graph/__init__.py b/backend/app/graph/__init__.py new file mode 100644 index 00000000..3e9b2d7b --- /dev/null +++ b/backend/app/graph/__init__.py @@ -0,0 +1,8 @@ +""" +Graph backend abstractions. +""" + +from .base import GraphBackend +from .factory import get_graph_backend + +__all__ = ["GraphBackend", "get_graph_backend"] diff --git a/backend/app/graph/base.py b/backend/app/graph/base.py new file mode 100644 index 00000000..143849e7 --- /dev/null +++ b/backend/app/graph/base.py @@ -0,0 +1,81 @@ +""" +Common graph backend interface. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + + +class GraphBackend(ABC): + """Minimal graph backend interface used by the application services.""" + + @property + def raw_client(self) -> Any: + """Return the underlying SDK client when a service still needs it.""" + return None + + @abstractmethod + def create_graph(self, graph_id: str, name: str, description: str) -> None: + """Create a graph.""" + + @abstractmethod + def set_ontology( + self, + graph_id: str, + entities: Any = None, + edges: Any = None, + ) -> None: + """Set graph ontology.""" + + @abstractmethod + def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]: + """Add a batch of episodes.""" + + @abstractmethod + def add_text(self, graph_id: str, data: str) -> Any: + """Add a single text episode.""" + + @abstractmethod + def get_episode(self, episode_uuid: str) -> Any: + """Fetch a single episode.""" + + @abstractmethod + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: Optional[str] = None, + ) -> Any: + """Search the graph.""" + + @abstractmethod + def get_all_nodes(self, graph_id: str) -> List[Any]: + """Fetch all nodes.""" + + @abstractmethod + def get_all_edges(self, graph_id: str) -> List[Any]: + """Fetch all edges.""" + + @abstractmethod + def get_node(self, node_uuid: str) -> Any: + """Fetch a node by UUID.""" + + @abstractmethod + def get_node_edges(self, node_uuid: str) -> List[Any]: + """Fetch edges for a node.""" + + @abstractmethod + def delete_graph(self, graph_id: str) -> None: + """Delete a graph.""" + + def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]: + """Fetch backend ontology metadata when available.""" + return None + + def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]: + """Fetch backend-specific live statistics when available.""" + return None diff --git a/backend/app/graph/factory.py b/backend/app/graph/factory.py new file mode 100644 index 00000000..f2e00b0f --- /dev/null +++ b/backend/app/graph/factory.py @@ -0,0 +1,27 @@ +""" +Graph backend factory. +""" + +from __future__ import annotations + +from typing import Optional + +from ..config import Config +from .base import GraphBackend + + +def get_graph_backend(api_key: Optional[str] = None) -> GraphBackend: + """Create the configured graph backend.""" + backend = (Config.GRAPH_BACKEND or "graphiti").lower() + + if backend in {"zep", "openzep"}: + from .zep_backend import ZepGraphBackend + + return ZepGraphBackend(api_key=api_key) + + if backend == "graphiti": + from .graphiti_backend import GraphitiBackend + + return GraphitiBackend(api_key=api_key) + + raise ValueError(f"不支持的 GRAPH_BACKEND: {backend}") diff --git a/backend/app/graph/graphiti_backend.py b/backend/app/graph/graphiti_backend.py new file mode 100644 index 00000000..83960592 --- /dev/null +++ b/backend/app/graph/graphiti_backend.py @@ -0,0 +1,875 @@ +""" +Graphiti + Neo4j graph backend implementation. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import threading +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field, create_model + +from ..config import Config +from .base import GraphBackend + +logger = logging.getLogger(__name__) + + +@dataclass +class _CompatEpisode: + uuid: str + processed: bool = True + name: str = "" + content: str = "" + valid_at: Optional[datetime] = None + created_at: Optional[datetime] = None + + @property + def uuid_(self) -> str: + return self.uuid + + +@dataclass +class _CompatNode: + uuid: str + name: str = "" + labels: List[str] = field(default_factory=list) + summary: str = "" + attributes: Dict[str, Any] = field(default_factory=dict) + created_at: Optional[datetime] = None + + @property + def uuid_(self) -> str: + return self.uuid + + +@dataclass +class _CompatEdge: + uuid: str + name: str + fact: str + source_node_uuid: str + target_node_uuid: str + source_node_name: str = "" + target_node_name: str = "" + attributes: Dict[str, Any] = field(default_factory=dict) + episodes: List[str] = field(default_factory=list) + created_at: Optional[datetime] = None + valid_at: Optional[datetime] = None + invalid_at: Optional[datetime] = None + expired_at: Optional[datetime] = None + + @property + def uuid_(self) -> str: + return self.uuid + + +@dataclass +class _CompatSearchResults: + edges: List[_CompatEdge] = field(default_factory=list) + nodes: List[_CompatNode] = field(default_factory=list) + + +@dataclass +class _OntologyBundle: + entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict) + edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict) + edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict) + spec: Dict[str, Any] = field(default_factory=dict) + + +class _AsyncBridge: + """Run async Graphiti calls from the app's synchronous service layer.""" + + def __init__(self): + self._ready = threading.Event() + self._loop: Optional[asyncio.AbstractEventLoop] = None + self._thread = threading.Thread(target=self._run_loop, daemon=True) + self._thread.start() + self._ready.wait() + + def _run_loop(self): + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + self._loop = loop + self._ready.set() + loop.run_forever() + + def run(self, coro): + if self._loop is None: + raise RuntimeError("Graphiti async loop 未初始化") + future = asyncio.run_coroutine_threadsafe(coro, self._loop) + return future.result() + + +class GraphitiBackend(GraphBackend): + """Graph backend backed by Graphiti OSS + Neo4j.""" + + _bridge: Optional[_AsyncBridge] = None + _bridge_lock = threading.Lock() + _indices_ready = False + _indices_lock = threading.Lock() + _ontology_registry: Dict[str, _OntologyBundle] = {} + _ontology_lock = threading.Lock() + _cross_encoder_warning_emitted = False + PAGE_SIZE = 200 + + def __init__(self, api_key: Optional[str] = None): + del api_key + + errors = Config.get_graphiti_config_errors() + if errors: + raise ValueError("; ".join(errors)) + + try: + from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient + from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig + from graphiti_core.graphiti import Graphiti + from graphiti_core.llm_client import OpenAIClient + from graphiti_core.llm_client.config import LLMConfig + from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient + except ImportError as exc: + raise ImportError( + "Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j" + ) from exc + + llm_config = LLMConfig( + api_key=Config.GRAPHITI_LLM_API_KEY, + base_url=Config.GRAPHITI_LLM_BASE_URL, + model=Config.GRAPHITI_LLM_MODEL, + small_model=Config.GRAPHITI_LLM_SMALL_MODEL, + temperature=0, + ) + reranker_config = LLMConfig( + api_key=Config.GRAPHITI_RERANKER_API_KEY, + base_url=Config.GRAPHITI_RERANKER_BASE_URL, + model=Config.GRAPHITI_RERANKER_MODEL, + temperature=0, + ) + embedder_config = OpenAIEmbedderConfig( + api_key=Config.GRAPHITI_EMBEDDER_API_KEY, + base_url=Config.GRAPHITI_EMBEDDER_BASE_URL, + embedding_model=Config.GRAPHITI_EMBEDDER_MODEL, + embedding_dim=Config.GRAPHITI_EMBEDDER_DIM, + ) + + llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower() + if llm_client_mode == "generic": + llm_client = OpenAIGenericClient( + config=llm_config, + max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS, + ) + else: + llm_client = OpenAIClient( + config=llm_config, + max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS, + ) + + self._graphiti = Graphiti( + uri=Config.GRAPHITI_URI, + user=Config.GRAPHITI_USER, + password=Config.GRAPHITI_PASSWORD, + llm_client=llm_client, + embedder=OpenAIEmbedder(config=embedder_config), + cross_encoder=OpenAIRerankerClient(config=reranker_config), + max_coroutines=Config.GRAPHITI_MAX_COROUTINES, + ) + self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE) + self._graphiti.driver = self._driver + self._graphiti.clients.driver = self._driver + self._bridge = self._get_bridge() + + self._ensure_indices() + + @classmethod + def _get_bridge(cls) -> _AsyncBridge: + with cls._bridge_lock: + if cls._bridge is None: + cls._bridge = _AsyncBridge() + return cls._bridge + + @property + def raw_client(self) -> Any: + return self._graphiti + + def _run(self, coro): + return self._bridge.run(coro) + + def _ensure_indices(self) -> None: + if self.__class__._indices_ready: + return + + with self.__class__._indices_lock: + if self.__class__._indices_ready: + return + self._run(self._graphiti.build_indices_and_constraints()) + self.__class__._indices_ready = True + + def _validate_graph_id(self, graph_id: str) -> None: + from graphiti_core.helpers import validate_group_id + + validate_group_id(graph_id) + + def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]: + fields = [] + for field_name, model_field in model.model_fields.items(): + fields.append( + { + "name": field_name, + "description": model_field.description or field_name, + } + ) + + return { + "description": (getattr(model, "__doc__", "") or "").strip(), + "fields": fields, + } + + def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]: + field_definitions = {} + for field_spec in spec.get("fields", []): + field_name = field_spec.get("name", "").strip() + if not field_name: + continue + field_definitions[field_name] = ( + Optional[str], + Field( + default=None, + description=field_spec.get("description") or field_name, + ), + ) + + model = create_model(name, __base__=BaseModel, **field_definitions) + model.__doc__ = spec.get("description") or name + return model + + def _serialize_ontology_spec( + self, + entity_specs: Dict[str, Dict[str, Any]], + edge_specs: Dict[str, Dict[str, Any]], + edge_type_map: Dict[tuple[str, str], List[str]], + ) -> Dict[str, Any]: + return { + "entity_types": entity_specs, + "edge_types": edge_specs, + "edge_type_map": [ + { + "source": source, + "target": target, + "edges": edge_names, + } + for (source, target), edge_names in sorted(edge_type_map.items()) + ], + } + + def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle: + entity_types = { + name: self._build_dynamic_model(name, model_spec) + for name, model_spec in (spec.get("entity_types") or {}).items() + } + edge_types = { + name: self._build_dynamic_model(name, model_spec) + for name, model_spec in (spec.get("edge_types") or {}).items() + } + edge_type_map: Dict[tuple[str, str], List[str]] = {} + for entry in spec.get("edge_type_map") or []: + source = entry.get("source", "Entity") + target = entry.get("target", "Entity") + edge_type_map[(source, target)] = list(entry.get("edges") or []) + + return _OntologyBundle( + entity_types=entity_types, + edge_types=edge_types, + edge_type_map=edge_type_map, + spec=spec, + ) + + def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle: + entity_specs = {} + entity_types = {} + for entity_name, entity_model in (entities or {}).items(): + entity_spec = self._normalize_model_spec(entity_model) + entity_specs[entity_name] = entity_spec + entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec) + + edge_specs = {} + edge_types = {} + edge_type_map: Dict[tuple[str, str], List[str]] = {} + for edge_name, edge_value in (edges or {}).items(): + if not isinstance(edge_value, tuple) or len(edge_value) != 2: + continue + edge_model, source_targets = edge_value + edge_spec = self._normalize_model_spec(edge_model) + edge_specs[edge_name] = edge_spec + edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec) + + for source_target in source_targets or []: + source = getattr(source_target, "source", "Entity") or "Entity" + target = getattr(source_target, "target", "Entity") or "Entity" + edge_type_map.setdefault((source, target), []).append(edge_name) + + if not edge_type_map: + edge_type_map = {("Entity", "Entity"): list(edge_types.keys())} + + spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map) + return _OntologyBundle( + entity_types=entity_types, + edge_types=edge_types, + edge_type_map=edge_type_map, + spec=spec, + ) + + async def _upsert_graph_metadata_async( + self, + graph_id: str, + name: Optional[str] = None, + description: Optional[str] = None, + ontology_spec: Optional[Dict[str, Any]] = None, + ) -> None: + payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None + + await self._driver.execute_query( + """ + MERGE (m:GraphMetadata {graph_id: $graph_id}) + ON CREATE SET + m.group_id = $graph_id, + m.created_at = datetime() + SET + m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END, + m.description = CASE + WHEN $description IS NULL OR $description = '' + THEN coalesce(m.description, '') + ELSE $description + END, + m.ontology_json = CASE + WHEN $ontology_json IS NULL OR $ontology_json = '' + THEN m.ontology_json + ELSE $ontology_json + END, + m.updated_at = datetime() + """, + graph_id=graph_id, + name=name, + description=description, + ontology_json=payload, + ) + + async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle: + records, _, _ = await self._driver.execute_query( + """ + MATCH (m:GraphMetadata {graph_id: $graph_id}) + RETURN m.ontology_json AS ontology_json + """, + graph_id=graph_id, + routing_="r", + ) + + if not records: + return _OntologyBundle() + + ontology_json = records[0].get("ontology_json") + if not ontology_json: + return _OntologyBundle() + + try: + spec = json.loads(ontology_json) + except (TypeError, ValueError, json.JSONDecodeError): + logger.warning("Graphiti ontology metadata 解析失败,graph_id=%s", graph_id) + return _OntologyBundle() + + return self._bundle_from_spec(spec) + + async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle: + with self.__class__._ontology_lock: + cached = self.__class__._ontology_registry.get(graph_id) + if cached is not None: + return cached + + bundle = await self._load_ontology_bundle_async(graph_id) + with self.__class__._ontology_lock: + self.__class__._ontology_registry[graph_id] = bundle + return bundle + + def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle: + return self._run(self._get_ontology_bundle_async(graph_id)) + + def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None: + with self.__class__._ontology_lock: + self.__class__._ontology_registry[graph_id] = bundle + + def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]: + self._validate_graph_id(graph_id) + bundle = self._get_ontology_bundle(graph_id) + return dict(bundle.spec) if bundle.spec else None + + def create_graph(self, graph_id: str, name: str, description: str) -> None: + self._validate_graph_id(graph_id) + self._run( + self._upsert_graph_metadata_async( + graph_id=graph_id, + name=name, + description=description, + ) + ) + + def set_ontology( + self, + graph_id: str, + entities: Any = None, + edges: Any = None, + ) -> None: + self._validate_graph_id(graph_id) + bundle = self._build_ontology_bundle(entities=entities, edges=edges) + self._set_ontology_bundle(graph_id, bundle) + self._run( + self._upsert_graph_metadata_async( + graph_id=graph_id, + ontology_spec=bundle.spec, + ) + ) + + async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode: + from graphiti_core.helpers import validate_excluded_entity_types + from graphiti_core.nodes import EpisodeType, EpisodicNode + from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT + from graphiti_core.utils.datetime_utils import utc_now + from graphiti_core.utils.maintenance.node_operations import ( + extract_attributes_from_nodes, + extract_nodes, + resolve_extracted_nodes, + ) + from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types + + bundle = await self._get_ontology_bundle_async(graph_id) + entity_types = bundle.entity_types or None + edge_types = bundle.edge_types or None + edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []} + + validate_entity_types(entity_types) + validate_excluded_entity_types(None, entity_types) + + now = utc_now() + previous_episodes = await self._graphiti.retrieve_episodes( + reference_time=now, + last_n=RELEVANT_SCHEMA_LIMIT, + group_ids=[graph_id], + source=EpisodeType.text, + driver=self._driver, + ) + + episode = EpisodicNode( + name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}", + group_id=graph_id, + labels=[], + source=EpisodeType.text, + content=data, + source_description="text", + created_at=now, + valid_at=now, + ) + + extracted_nodes = await extract_nodes( + self._graphiti.clients, + episode, + previous_episodes, + entity_types, + None, + None, + ) + nodes, uuid_map, _ = await resolve_extracted_nodes( + self._graphiti.clients, + extracted_nodes, + episode, + previous_episodes, + entity_types, + ) + resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges( + episode, + extracted_nodes, + previous_episodes, + edge_type_map, + graph_id, + edge_types, + nodes, + uuid_map, + None, + ) + entity_edges = resolved_edges + invalidated_edges + hydrated_nodes = await extract_attributes_from_nodes( + self._graphiti.clients, + nodes, + episode, + previous_episodes, + entity_types, + edges=new_edges, + ) + _, saved_episode = await self._graphiti._process_episode_data( + episode, + hydrated_nodes, + entity_edges, + now, + graph_id, + None, + None, + ) + + return _CompatEpisode( + uuid=saved_episode.uuid, + processed=True, + name=saved_episode.name, + content=saved_episode.content, + valid_at=saved_episode.valid_at, + created_at=saved_episode.created_at, + ) + + def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]: + results = [] + for episode in episodes: + data = getattr(episode, "data", None) + if data is None and isinstance(episode, dict): + data = episode.get("data", "") + results.append(self.add_text(graph_id=graph_id, data=str(data or ""))) + return results + + def add_text(self, graph_id: str, data: str) -> Any: + self._validate_graph_id(graph_id) + return self._run(self._add_text_async(graph_id=graph_id, data=data)) + + async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode: + from graphiti_core.nodes import EpisodicNode + + episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid) + return _CompatEpisode( + uuid=episode.uuid, + processed=True, + name=episode.name, + content=episode.content, + valid_at=episode.valid_at, + created_at=episode.created_at, + ) + + def get_episode(self, episode_uuid: str) -> Any: + return self._run(self._get_episode_async(episode_uuid)) + + def _warn_cross_encoder_fallback(self) -> None: + if self.__class__._cross_encoder_warning_emitted: + return + logger.info( + "Graphiti cross_encoder 默认已降级为 rrf;如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true" + ) + self.__class__._cross_encoder_warning_emitted = True + + def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]): + from graphiti_core.search.search_config import ( + EdgeReranker, + EdgeSearchConfig, + EdgeSearchMethod, + NodeReranker, + NodeSearchConfig, + NodeSearchMethod, + SearchConfig, + ) + + reranker_name = (reranker or "rrf").strip().lower() + edge_reranker_map = { + "rrf": EdgeReranker.rrf, + "reciprocal_rank_fusion": EdgeReranker.rrf, + "cross_encoder": EdgeReranker.cross_encoder, + "node_distance": EdgeReranker.node_distance, + "episode_mentions": EdgeReranker.episode_mentions, + "mmr": EdgeReranker.mmr, + } + node_reranker_map = { + "rrf": NodeReranker.rrf, + "reciprocal_rank_fusion": NodeReranker.rrf, + "cross_encoder": NodeReranker.cross_encoder, + "node_distance": NodeReranker.node_distance, + "episode_mentions": NodeReranker.episode_mentions, + "mmr": NodeReranker.mmr, + } + + edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf) + node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf) + edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity] + node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity] + + if reranker_name == "cross_encoder": + if Config.GRAPHITI_ENABLE_CROSS_ENCODER: + edge_methods.append(EdgeSearchMethod.bfs) + node_methods.append(NodeSearchMethod.bfs) + else: + self._warn_cross_encoder_fallback() + edge_reranker = EdgeReranker.rrf + node_reranker = NodeReranker.rrf + + edge_config = None + node_config = None + if scope in {"edges", "both"}: + edge_config = EdgeSearchConfig( + search_methods=edge_methods, + reranker=edge_reranker, + ) + if scope in {"nodes", "both"}: + node_config = NodeSearchConfig( + search_methods=node_methods, + reranker=node_reranker, + ) + + return SearchConfig( + edge_config=edge_config, + node_config=node_config, + limit=max(1, limit), + ) + + def _wrap_node(self, node: Any) -> _CompatNode: + return _CompatNode( + uuid=getattr(node, "uuid", ""), + name=getattr(node, "name", "") or "", + labels=list(getattr(node, "labels", []) or []), + summary=getattr(node, "summary", "") or "", + attributes=dict(getattr(node, "attributes", {}) or {}), + created_at=getattr(node, "created_at", None), + ) + + def _wrap_edge( + self, + edge: Any, + source_node_name: str = "", + target_node_name: str = "", + ) -> _CompatEdge: + return _CompatEdge( + uuid=getattr(edge, "uuid", ""), + name=getattr(edge, "name", "") or "", + fact=getattr(edge, "fact", "") or "", + source_node_uuid=getattr(edge, "source_node_uuid", "") or "", + target_node_uuid=getattr(edge, "target_node_uuid", "") or "", + source_node_name=source_node_name, + target_node_name=target_node_name, + attributes=dict(getattr(edge, "attributes", {}) or {}), + episodes=list(getattr(edge, "episodes", []) or []), + created_at=getattr(edge, "created_at", None), + valid_at=getattr(edge, "valid_at", None), + invalid_at=getattr(edge, "invalid_at", None), + expired_at=getattr(edge, "expired_at", None), + ) + + async def _search_async( + self, + graph_id: str, + query: str, + limit: int, + scope: str, + reranker: Optional[str], + ) -> _CompatSearchResults: + from graphiti_core.nodes import EntityNode + + search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker) + results = await self._graphiti.search_( + query=query, + config=search_config, + group_ids=[graph_id], + driver=self._driver, + ) + + nodes = [self._wrap_node(node) for node in results.nodes] + node_name_map = {node.uuid: node.name for node in nodes if node.uuid} + + missing_node_ids = { + node_uuid + for edge in results.edges + for node_uuid in (edge.source_node_uuid, edge.target_node_uuid) + if node_uuid and node_uuid not in node_name_map + } + if missing_node_ids: + for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)): + node_name_map[node.uuid] = node.name or "" + + edges = [ + self._wrap_edge( + edge, + source_node_name=node_name_map.get(edge.source_node_uuid, ""), + target_node_name=node_name_map.get(edge.target_node_uuid, ""), + ) + for edge in results.edges + ] + + return _CompatSearchResults(edges=edges, nodes=nodes) + + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: Optional[str] = None, + ) -> Any: + self._validate_graph_id(graph_id) + return self._run( + self._search_async( + graph_id=graph_id, + query=query, + limit=limit, + scope=scope, + reranker=reranker, + ) + ) + + + async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]: + from graphiti_core.nodes import EntityNode + + result = [] + cursor = None + while True: + batch = await EntityNode.get_by_group_ids( + self._driver, + [graph_id], + limit=self.PAGE_SIZE, + uuid_cursor=cursor, + ) + if not batch: + break + result.extend(self._wrap_node(node) for node in batch) + if len(batch) < self.PAGE_SIZE: + break + cursor = batch[-1].uuid + + return result + + def get_all_nodes(self, graph_id: str) -> List[Any]: + self._validate_graph_id(graph_id) + return self._run(self._get_all_nodes_async(graph_id)) + + async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]: + from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError + + result = [] + cursor = None + while True: + try: + batch = await EntityEdge.get_by_group_ids( + self._driver, + [graph_id], + limit=self.PAGE_SIZE, + uuid_cursor=cursor, + ) + except GroupsEdgesNotFoundError: + break + + if not batch: + break + + result.extend(self._wrap_edge(edge) for edge in batch) + if len(batch) < self.PAGE_SIZE: + break + cursor = batch[-1].uuid + + return result + + def get_all_edges(self, graph_id: str) -> List[Any]: + self._validate_graph_id(graph_id) + return self._run(self._get_all_edges_async(graph_id)) + + async def _get_node_async(self, node_uuid: str) -> _CompatNode: + from graphiti_core.nodes import EntityNode + + return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid)) + + def get_node(self, node_uuid: str) -> Any: + return self._run(self._get_node_async(node_uuid)) + + async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]: + from graphiti_core.edges import EntityEdge + from graphiti_core.nodes import EntityNode + + edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid) + related_node_ids = { + related_uuid + for edge in edges + for related_uuid in (edge.source_node_uuid, edge.target_node_uuid) + if related_uuid + } + node_name_map = {} + if related_node_ids: + for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)): + node_name_map[node.uuid] = node.name or "" + + return [ + self._wrap_edge( + edge, + source_node_name=node_name_map.get(edge.source_node_uuid, ""), + target_node_name=node_name_map.get(edge.target_node_uuid, ""), + ) + for edge in edges + ] + + def get_node_edges(self, node_uuid: str) -> List[Any]: + return self._run(self._get_node_edges_async(node_uuid)) + + async def _delete_graph_async(self, graph_id: str) -> None: + from graphiti_core.nodes import Node + + await self._driver.execute_query( + """ + MATCH (s:Saga {group_id: $graph_id}) + DETACH DELETE s + """, + graph_id=graph_id, + ) + await Node.delete_by_group_id(self._driver, graph_id) + await self._driver.execute_query( + """ + MATCH (m:GraphMetadata {graph_id: $graph_id}) + DETACH DELETE m + """, + graph_id=graph_id, + ) + + def delete_graph(self, graph_id: str) -> None: + self._validate_graph_id(graph_id) + self._run(self._delete_graph_async(graph_id)) + with self.__class__._ontology_lock: + self.__class__._ontology_registry.pop(graph_id, None) + + async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]: + node_records, _, _ = await self._driver.execute_query( + """ + MATCH (n:Entity {group_id: $graph_id}) + RETURN count(n) AS node_count + """, + graph_id=graph_id, + routing_="r", + ) + edge_records, _, _ = await self._driver.execute_query( + """ + MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->() + RETURN count(e) AS edge_count + """, + graph_id=graph_id, + routing_="r", + ) + episode_records, _, _ = await self._driver.execute_query( + """ + MATCH (n:Episodic {group_id: $graph_id}) + RETURN count(n) AS episode_count + """, + graph_id=graph_id, + routing_="r", + ) + + return { + "node_count": int((node_records[0] if node_records else {}).get("node_count", 0) or 0), + "edge_count": int((edge_records[0] if edge_records else {}).get("edge_count", 0) or 0), + "episode_count": int( + (episode_records[0] if episode_records else {}).get("episode_count", 0) or 0 + ), + } + + def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]: + self._validate_graph_id(graph_id) + return self._run(self._get_live_graph_statistics_async(graph_id)) diff --git a/backend/app/graph/zep_backend.py b/backend/app/graph/zep_backend.py new file mode 100644 index 00000000..8759eec2 --- /dev/null +++ b/backend/app/graph/zep_backend.py @@ -0,0 +1,114 @@ +""" +Zep / OpenZep graph backend implementation. +""" + +from __future__ import annotations + +import json +from typing import Any, Dict, List, Optional +from urllib.error import HTTPError, URLError +from urllib.request import Request, urlopen + +from zep_cloud.client import Zep + +from ..config import Config +from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes +from .base import GraphBackend + + +class ZepGraphBackend(GraphBackend): + """Graph backend backed by Zep Cloud or OpenZep.""" + + def __init__(self, api_key: Optional[str] = None): + self.api_key = Config.ZEP_API_KEY if api_key is None else api_key + errors = Config.get_zep_config_errors(api_key=self.api_key) + if errors: + raise ValueError("; ".join(errors)) + + self.client = Zep(**Config.get_zep_client_kwargs(api_key=self.api_key)) + + @property + def raw_client(self) -> Zep: + return self.client + + def create_graph(self, graph_id: str, name: str, description: str) -> None: + self.client.graph.create( + graph_id=graph_id, + name=name, + description=description, + ) + + def set_ontology( + self, + graph_id: str, + entities: Any = None, + edges: Any = None, + ) -> None: + self.client.graph.set_ontology( + graph_ids=[graph_id], + entities=entities, + edges=edges, + ) + + def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]: + return self.client.graph.add_batch(graph_id=graph_id, episodes=episodes) + + def add_text(self, graph_id: str, data: str) -> Any: + return self.client.graph.add(graph_id=graph_id, type="text", data=data) + + def get_episode(self, episode_uuid: str) -> Any: + return self.client.graph.episode.get(uuid_=episode_uuid) + + def search( + self, + graph_id: str, + query: str, + limit: int = 10, + scope: str = "edges", + reranker: Optional[str] = None, + ) -> Any: + kwargs = { + "graph_id": graph_id, + "query": query, + "limit": limit, + "scope": scope, + } + if reranker: + kwargs["reranker"] = reranker + return self.client.graph.search(**kwargs) + + def get_all_nodes(self, graph_id: str) -> List[Any]: + return fetch_all_nodes(self.client, graph_id) + + def get_all_edges(self, graph_id: str) -> List[Any]: + return fetch_all_edges(self.client, graph_id) + + def get_node(self, node_uuid: str) -> Any: + return self.client.graph.node.get(uuid_=node_uuid) + + def get_node_edges(self, node_uuid: str) -> List[Any]: + return self.client.graph.node.get_entity_edges(node_uuid=node_uuid) + + def delete_graph(self, graph_id: str) -> None: + self.client.graph.delete(graph_id=graph_id) + + def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]: + if not Config.ZEP_BASE_URL: + return None + + base_url = Config.ZEP_BASE_URL.rstrip("/") + request = Request(f"{base_url}/graph/{graph_id}/statistics") + if self.api_key: + request.add_header("Authorization", f"Bearer {self.api_key}") + + try: + with urlopen(request, timeout=10) as response: + payload = json.loads(response.read().decode("utf-8")) + except (HTTPError, URLError, TimeoutError, OSError, json.JSONDecodeError): + return None + + return { + "node_count": max(0, int(payload.get("node_count", 0) or 0)), + "edge_count": max(0, int(payload.get("edge_count", 0) or 0)), + "episode_count": max(0, int(payload.get("episode_count", 0) or 0)), + } diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 0e0444bf..c53c8e1f 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -7,15 +7,14 @@ import os import uuid import time import threading +import json from typing import Dict, Any, List, Optional, Callable from dataclasses import dataclass - -from zep_cloud.client import Zep from zep_cloud import EpisodeData, EntityEdgeSourceTarget from ..config import Config +from ..graph import get_graph_backend from ..models.task import TaskManager, TaskStatus -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges from .text_processor import TextProcessor @@ -43,11 +42,12 @@ class GraphBuilderService: """ def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") + self.api_key = Config.ZEP_API_KEY if api_key is None else api_key + errors = Config.get_graph_backend_config_errors(api_key=self.api_key) + if errors: + raise ValueError("; ".join(errors)) - self.client = Zep(api_key=self.api_key) + self.backend = get_graph_backend(api_key=self.api_key) self.task_manager = TaskManager() def build_graph_async( @@ -57,7 +57,7 @@ class GraphBuilderService: graph_name: str = "MiroFish Graph", chunk_size: int = 500, chunk_overlap: int = 50, - batch_size: int = 3 + batch_size: int = 1 ) -> str: """ 异步构建图谱 @@ -155,6 +155,7 @@ class GraphBuilderService: ) self._wait_for_episodes( + graph_id, episode_uuids, lambda msg, prog: self.task_manager.update_task( task_id, @@ -188,10 +189,10 @@ class GraphBuilderService: """创建Zep图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - self.client.graph.create( + self.backend.create_graph( graph_id=graph_id, name=name, - description="MiroFish Social Simulation Graph" + description="MiroFish Social Simulation Graph", ) return graph_id @@ -279,8 +280,8 @@ class GraphBuilderService: # 调用Zep API设置本体 if entity_types or edge_definitions: - self.client.graph.set_ontology( - graph_ids=[graph_id], + self.backend.set_ontology( + graph_id=graph_id, entities=entity_types if entity_types else None, edges=edge_definitions if edge_definitions else None, ) @@ -289,7 +290,7 @@ class GraphBuilderService: self, graph_id: str, chunks: List[str], - batch_size: int = 3, + batch_size: int = 1, progress_callback: Optional[Callable] = None ) -> List[str]: """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" @@ -316,7 +317,7 @@ class GraphBuilderService: # 发送到Zep try: - batch_result = self.client.graph.add_batch( + batch_result = self.backend.add_batch( graph_id=graph_id, episodes=episodes ) @@ -337,70 +338,152 @@ class GraphBuilderService: raise return episode_uuids - + + + + def _get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]: + """直接读取后端的实时图谱统计。""" + return self.backend.get_live_graph_statistics(graph_id) + def _wait_for_episodes( self, + graph_id: str, episode_uuids: List[str], progress_callback: Optional[Callable] = None, timeout: int = 600 ): - """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" + """等待 OpenZep 处理完成,优先参考真实图谱状态。""" if not episode_uuids: if progress_callback: progress_callback("无需等待(没有 episode)", 1.0) return - + start_time = time.time() pending_episodes = set(episode_uuids) completed_count = 0 total_episodes = len(episode_uuids) - + last_graph_signature: Optional[tuple[int, int, int]] = None + stable_graph_polls = 0 + stable_graph_required = 2 + last_live_stats: Optional[Dict[str, int]] = None + if progress_callback: progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0) - + while pending_episodes: - if time.time() - start_time > timeout: + elapsed_seconds = time.time() - start_time + if elapsed_seconds > timeout: + if last_live_stats is not None: + graph_episode_count = min(last_live_stats["episode_count"], total_episodes) + graph_node_count = last_live_stats["node_count"] + graph_edge_count = last_live_stats["edge_count"] + graph_entity_like_nodes = max(0, graph_node_count - last_live_stats["episode_count"]) + if graph_episode_count >= total_episodes and (graph_entity_like_nodes > 0 or graph_edge_count > 0): + if progress_callback: + progress_callback( + ( + f"OpenZep 接口进度未返回完成标记,但真实图谱已写入 " + f"episodes={graph_episode_count}/{total_episodes}, " + f"nodes={graph_node_count}, edges={graph_edge_count}" + ), + 1.0, + ) + return if progress_callback: progress_callback( f"部分文本块超时,已完成 {completed_count}/{total_episodes}", completed_count / total_episodes ) break - - # 检查每个 episode 的处理状态 + for ep_uuid in list(pending_episodes): try: - episode = self.client.graph.episode.get(uuid_=ep_uuid) + episode = self.backend.get_episode(ep_uuid) is_processed = getattr(episode, 'processed', False) - + if is_processed: pending_episodes.remove(ep_uuid) completed_count += 1 - - except Exception as e: - # 忽略单个查询错误,继续 + + except Exception: pass - - elapsed = int(time.time() - start_time) - if progress_callback: - progress_callback( - f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)", - completed_count / total_episodes if total_episodes > 0 else 0 + + live_stats = self._get_live_graph_statistics(graph_id) + graph_episode_count = 0 + graph_node_count = 0 + graph_edge_count = 0 + graph_entity_like_nodes = 0 + graph_progress = 0.0 + + if live_stats is not None: + last_live_stats = live_stats + graph_episode_count = min(live_stats["episode_count"], total_episodes) + graph_node_count = live_stats["node_count"] + graph_edge_count = live_stats["edge_count"] + graph_entity_like_nodes = max(0, graph_node_count - live_stats["episode_count"]) + graph_progress = graph_episode_count / total_episodes if total_episodes > 0 else 1.0 + + graph_signature = ( + graph_episode_count, + graph_entity_like_nodes, + graph_edge_count, ) - + if graph_signature == last_graph_signature: + stable_graph_polls += 1 + else: + last_graph_signature = graph_signature + stable_graph_polls = 0 + + graph_ready = ( + graph_episode_count >= total_episodes + and (graph_entity_like_nodes > 0 or graph_edge_count > 0) + and stable_graph_polls >= stable_graph_required + ) + if graph_ready: + if progress_callback: + progress_callback( + ( + f"OpenZep 图谱已稳定: episodes={graph_episode_count}/{total_episodes}, " + f"nodes={graph_node_count}, edges={graph_edge_count}" + ), + 1.0, + ) + return + + elapsed = int(elapsed_seconds) + effective_progress = max( + completed_count / total_episodes if total_episodes > 0 else 1.0, + graph_progress, + ) + if progress_callback: + if live_stats is not None: + progress_callback( + ( + f"OpenZep处理中... 接口完成 {completed_count}/{total_episodes}, " + f"图中已写入 episodes={graph_episode_count}/{total_episodes}, " + f"nodes={graph_node_count}, edges={graph_edge_count} ({elapsed}秒)" + ), + effective_progress, + ) + else: + progress_callback( + f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)", + completed_count / total_episodes if total_episodes > 0 else 0 + ) + if pending_episodes: - time.sleep(3) # 每3秒检查一次 - + time.sleep(3) + if progress_callback: progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0) - + def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.backend.get_all_nodes(graph_id) # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + edges = self.backend.get_all_edges(graph_id) # 统计实体类型 entity_types = set() @@ -427,8 +510,8 @@ class GraphBuilderService: Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + nodes = self.backend.get_all_nodes(graph_id) + edges = self.backend.get_all_edges(graph_id) # 创建节点映射用于获取节点名称 node_map = {} @@ -496,5 +579,4 @@ class GraphBuilderService: def delete_graph(self, graph_id: str): """删除图谱""" - self.client.graph.delete(graph_id=graph_id) - + self.backend.delete_graph(graph_id) diff --git a/backend/app/services/oasis_profile_generator.py b/backend/app/services/oasis_profile_generator.py index 57836c53..2aef63fb 100644 --- a/backend/app/services/oasis_profile_generator.py +++ b/backend/app/services/oasis_profile_generator.py @@ -16,9 +16,9 @@ from dataclasses import dataclass, field from datetime import datetime from openai import OpenAI -from zep_cloud.client import Zep from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger from .zep_entity_reader import EntityNode, ZepEntityReader @@ -198,15 +198,15 @@ class OasisProfileGenerator: ) # Zep客户端用于检索丰富上下文 - self.zep_api_key = zep_api_key or Config.ZEP_API_KEY - self.zep_client = None + self.zep_api_key = Config.ZEP_API_KEY if zep_api_key is None else zep_api_key + self.zep_backend = None self.graph_id = graph_id - if self.zep_api_key: + if Config.is_graph_backend_configured(api_key=self.zep_api_key): try: - self.zep_client = Zep(api_key=self.zep_api_key) + self.zep_backend = get_graph_backend(api_key=self.zep_api_key) except Exception as e: - logger.warning(f"Zep客户端初始化失败: {e}") + logger.warning(f"图谱客户端初始化失败: {e}") def generate_profile_from_entity( self, @@ -297,7 +297,7 @@ class OasisProfileGenerator: """ import concurrent.futures - if not self.zep_client: + if not self.zep_backend: return {"facts": [], "node_summaries": [], "context": ""} entity_name = entity.name @@ -323,7 +323,7 @@ class OasisProfileGenerator: for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.zep_backend.search( query=comprehensive_query, graph_id=self.graph_id, limit=30, @@ -348,7 +348,7 @@ class OasisProfileGenerator: for attempt in range(max_retries): try: - return self.zep_client.graph.search( + return self.zep_backend.search( query=comprehensive_query, graph_id=self.graph_id, limit=20, @@ -1197,4 +1197,3 @@ class OasisProfileGenerator: """[已废弃] 请使用 save_profiles() 方法""" logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法") self.save_profiles(profiles, file_path, platform) - diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be4..0fa0193c 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -7,11 +7,9 @@ import time from typing import Dict, Any, List, Optional, Set, Callable, TypeVar from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_entity_reader') @@ -79,11 +77,12 @@ class ZepEntityReader: """ def __init__(self, api_key: Optional[str] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") + self.api_key = Config.ZEP_API_KEY if api_key is None else api_key + errors = Config.get_graph_backend_config_errors(api_key=self.api_key) + if errors: + raise ValueError("; ".join(errors)) - self.client = Zep(api_key=self.api_key) + self.backend = get_graph_backend(api_key=self.api_key) def _call_with_retry( self, @@ -136,7 +135,7 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.backend.get_all_nodes(graph_id) nodes_data = [] for node in nodes: @@ -163,7 +162,7 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + edges = self.backend.get_all_edges(graph_id) edges_data = [] for edge in edges: @@ -192,7 +191,7 @@ class ZepEntityReader: try: # 使用重试机制调用Zep API edges = self._call_with_retry( - func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), + func=lambda: self.backend.get_node_edges(node_uuid), operation_name=f"获取节点边(node={node_uuid[:8]}...)" ) @@ -348,7 +347,7 @@ class ZepEntityReader: try: # 使用重试机制获取节点 node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + func=lambda: self.backend.get_node(entity_uuid), operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" ) @@ -434,4 +433,3 @@ class ZepEntityReader: ) return result.entities - diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index a8f3cecd..508d1afc 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -12,9 +12,8 @@ from dataclasses import dataclass from datetime import datetime from queue import Queue, Empty -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend from ..utils.logger import get_logger logger = get_logger('mirofish.zep_graph_memory_updater') @@ -237,12 +236,13 @@ class ZepGraphMemoryUpdater: api_key: Zep API Key(可选,默认从配置读取) """ self.graph_id = graph_id - self.api_key = api_key or Config.ZEP_API_KEY + self.api_key = Config.ZEP_API_KEY if api_key is None else api_key - if not self.api_key: - raise ValueError("ZEP_API_KEY未配置") + errors = Config.get_graph_backend_config_errors(api_key=self.api_key) + if errors: + raise ValueError("; ".join(errors)) - self.client = Zep(api_key=self.api_key) + self.backend = get_graph_backend(api_key=self.api_key) # 活动队列 self._activity_queue: Queue = Queue() @@ -405,9 +405,8 @@ class ZepGraphMemoryUpdater: # 带重试的发送 for attempt in range(self.MAX_RETRIES): try: - self.client.graph.add( + self.backend.add_text( graph_id=self.graph_id, - type="text", data=combined_text ) diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 384cf540..320fb6a4 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -10,15 +10,18 @@ Zep检索工具服务 import time import json -from typing import Dict, Any, List, Optional +import math +import re +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional from dataclasses import dataclass, field -from zep_cloud.client import Zep - from ..config import Config +from ..graph import get_graph_backend +from ..utils.embedding_client import EmbeddingClient +from ..utils.reranker_client import RerankerClient from ..utils.logger import get_logger from ..utils.llm_client import LLMClient -from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges logger = get_logger('mirofish.zep_tools') @@ -422,13 +425,16 @@ class ZepToolsService: RETRY_DELAY = 2.0 def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): - self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") + self.api_key = Config.ZEP_API_KEY if api_key is None else api_key + errors = Config.get_graph_backend_config_errors(api_key=self.api_key) + if errors: + raise ValueError("; ".join(errors)) - self.client = Zep(api_key=self.api_key) + self.backend = get_graph_backend(api_key=self.api_key) # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client + self._search_embedder_client = None + self._search_reranker_client = None logger.info("ZepToolsService 初始化完成") @property @@ -461,148 +467,696 @@ class ZepToolsService: raise last_exception + + + def _normalize_text(self, text: Optional[str]) -> str: + """标准化文本,便于后续打分和去重。""" + return " ".join(str(text or "").split()) + + def _query_tokens(self, query: str) -> List[str]: + """提取查询词,兼顾中英文。""" + normalized = self._normalize_text(query).lower() + tokens = set(re.findall(r"[a-z0-9_]+", normalized)) + + for run in re.findall(r"[一-鿿]+", normalized): + if len(run) <= 4: + tokens.add(run) + continue + + tokens.add(run) + for size in (2, 3, 4): + for idx in range(len(run) - size + 1): + tokens.add(run[idx:idx + size]) + + return [token for token in tokens if len(token) > 1] + + def _score_texts(self, query_lower: str, query_tokens: List[str], *parts: str) -> int: + """轻量级文本相关性打分,用于本地合并与退化检索。""" + combined = self._normalize_text(" ".join(part for part in parts if part)).lower() + if not combined: + return 0 + + score = 0 + if query_lower and query_lower in combined: + score += 120 + + for token in query_tokens: + if token in combined: + score += 12 if len(token) >= 3 else 5 + + return score + + def _graph_search_app_reranker(self) -> str: + """返回 app-side 检索重排模式。""" + return (Config.GRAPH_SEARCH_APP_RERANKER or "lexical").strip().lower() or "lexical" + + def _get_search_embedder(self) -> Optional[EmbeddingClient]: + """懒加载图搜索 embedding client。""" + if self._search_embedder_client is False: + return None + if self._search_embedder_client is not None: + return self._search_embedder_client + + embedder_config = Config.get_graph_search_embedder_config() + base_url = embedder_config.get("base_url") + model = embedder_config.get("model") + if not base_url or not model: + self._search_embedder_client = False + return None + + try: + self._search_embedder_client = EmbeddingClient( + api_key=embedder_config.get("api_key") or "ollama", + base_url=base_url, + model=model, + batch_size=Config.GRAPH_SEARCH_APP_EMBED_BATCH_SIZE, + ) + logger.info( + "图搜索语义重排已启用: mode=%s, model=%s", + self._graph_search_app_reranker(), + model, + ) + except Exception as exc: + logger.warning(f"图搜索 embedding reranker 初始化失败: {exc}") + self._search_embedder_client = False + return None + + return self._search_embedder_client + + def _edge_search_text(self, edge: Dict[str, Any]) -> str: + """构建边候选的语义检索文本。""" + return self._normalize_text( + " ".join( + part + for part in [ + edge.get("fact", ""), + edge.get("name", ""), + edge.get("source_node_name", ""), + edge.get("target_node_name", ""), + ] + if part + ) + ) + + def _node_search_text(self, node: Dict[str, Any]) -> str: + """构建节点候选的语义检索文本。""" + return self._normalize_text( + " ".join( + part + for part in [ + node.get("name", ""), + node.get("summary", ""), + " ".join(node.get("labels", [])), + ] + if part + ) + ) + + def _cosine_similarity(self, left: List[float], right: List[float]) -> float: + """计算两个 embedding 向量的余弦相似度。""" + if not left or not right or len(left) != len(right): + return 0.0 + + numerator = sum(a * b for a, b in zip(left, right)) + left_norm = math.sqrt(sum(a * a for a in left)) + right_norm = math.sqrt(sum(b * b for b in right)) + if left_norm == 0 or right_norm == 0: + return 0.0 + + return numerator / (left_norm * right_norm) + + def _rrf_score(self, rank: int) -> float: + """Reciprocal Rank Fusion score.""" + rank = max(1, int(rank)) + return 1.0 / (Config.GRAPH_SEARCH_APP_RERANK_FUSION_K + rank) + + def _sort_candidates(self, candidates: List[Dict[str, Any]], score_key: str) -> List[Dict[str, Any]]: + """按分数倒序、原始召回顺序升序排序。""" + return sorted( + candidates, + key=lambda item: ( + -float(item.get(score_key, 0.0) or 0.0), + int(item.get("_backend_rank", 10**9)), + ), + ) + + def _strip_candidate_meta(self, candidate: Dict[str, Any]) -> Dict[str, Any]: + """移除仅用于本地重排的内部字段。""" + return {key: value for key, value in candidate.items() if not key.startswith("_")} + + def _edge_candidate_key(self, edge: Dict[str, Any]) -> str: + """生成边候选的稳定 key。""" + return edge.get("uuid") or "|".join([ + edge.get("name", ""), + edge.get("fact", ""), + edge.get("source_node_uuid", ""), + edge.get("target_node_uuid", ""), + ]) + + def _edge_info_to_candidate(self, edge: EdgeInfo) -> Dict[str, Any]: + """将 EdgeInfo 转换为本地重排使用的候选字典。""" + edge_candidate = { + "uuid": edge.uuid, + "name": edge.name, + "fact": edge.fact, + "source_node_uuid": edge.source_node_uuid, + "target_node_uuid": edge.target_node_uuid, + "source_node_name": edge.source_node_name or "", + "target_node_name": edge.target_node_name or "", + } + edge_candidate["_candidate_key"] = self._edge_candidate_key(edge_candidate) + return edge_candidate + + def _expand_edge_candidates_from_nodes( + self, + graph_id: str, + ranked_nodes: List[Dict[str, Any]], + candidate_edges: Dict[str, Dict[str, Any]], + query_lower: str, + query_tokens: List[str], + ) -> None: + """从高相关节点补抓相邻边,提升边召回率。""" + if not Config.GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES: + return + + node_limit = Config.GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT + per_node_limit = Config.GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT + if node_limit <= 0 or per_node_limit <= 0 or not ranked_nodes: + return + + expanded_node_count = 0 + added_edge_count = 0 + + for node_rank, node in enumerate(ranked_nodes[:node_limit], start=1): + node_uuid = node.get("uuid") + if not node_uuid: + continue + + related_edges = self.get_node_edges(graph_id, node_uuid) + if not related_edges: + continue + + expanded_node_count += 1 + scored_edges = [] + for edge in related_edges: + edge_candidate = self._edge_info_to_candidate(edge) + edge_key = edge_candidate.get("_candidate_key") + if not edge_key or edge_key in candidate_edges: + continue + + lexical_score = self._score_texts( + query_lower, + query_tokens, + self._edge_search_text(edge_candidate), + ) + scored_edges.append((lexical_score, node_rank, edge_candidate)) + + scored_edges.sort( + key=lambda item: ( + -int(item[0]), + int(item[1]), + item[2].get("name", ""), + item[2].get("fact", ""), + ) + ) + + for _, _, edge_candidate in scored_edges[:per_node_limit]: + edge_key = edge_candidate.get("_candidate_key") + if edge_key and edge_key not in candidate_edges: + candidate_edges[edge_key] = edge_candidate + added_edge_count += 1 + + if added_edge_count > 0: + logger.info( + "节点召回补边完成: expanded_nodes=%s, added_edges=%s", + expanded_node_count, + added_edge_count, + ) + + def _compute_semantic_scores(self, query: str, candidates: List[Dict[str, Any]]) -> Optional[Dict[str, float]]: + """使用 embedding 计算 query 与候选文本的相似度。""" + if len(candidates) < 2: + return {} + + embedder = self._get_search_embedder() + if embedder is None: + return None + + keyed_texts = [ + (candidate.get("_candidate_key", ""), candidate.get("_search_text", "")) + for candidate in candidates + if candidate.get("_candidate_key") and candidate.get("_search_text") + ] + if not keyed_texts: + return {} + + try: + embeddings = embedder.embed_texts([query] + [candidate_text for _, candidate_text in keyed_texts]) + except Exception as exc: + logger.warning(f"图搜索 embedding reranker 调用失败,回退到词面排序: {exc}") + return None + + if len(embeddings) != len(keyed_texts) + 1: + logger.warning("图搜索 embedding reranker 返回向量数量异常,回退到词面排序") + return None + + query_vector = embeddings[0] + scores: Dict[str, float] = {} + for (candidate_key, _), vector in zip(keyed_texts, embeddings[1:]): + scores[candidate_key] = self._cosine_similarity(query_vector, vector) + + return scores + + def _get_search_reranker(self) -> Optional[RerankerClient]: + """懒加载图搜索 API reranker client。""" + if self._search_reranker_client is False: + return None + if self._search_reranker_client is not None: + return self._search_reranker_client + + reranker_config = Config.get_graph_search_reranker_config() + base_url = reranker_config.get("base_url") + if not base_url: + self._search_reranker_client = False + return None + + try: + self._search_reranker_client = RerankerClient( + api_key=reranker_config.get("api_key"), + base_url=base_url, + model=reranker_config.get("model"), + provider=reranker_config.get("provider") or "auto", + timeout=reranker_config.get("timeout") or 20.0, + ) + logger.info( + "图搜索 API reranker 已启用: mode=%s, provider=%s, model=%s", + self._graph_search_app_reranker(), + reranker_config.get("provider") or "auto", + reranker_config.get("model"), + ) + except Exception as exc: + logger.warning(f"图搜索 API reranker 初始化失败: {exc}") + self._search_reranker_client = False + return None + + return self._search_reranker_client + + def _compute_api_rerank_scores(self, query: str, candidates: List[Dict[str, Any]]) -> Optional[Dict[str, float]]: + """使用独立 reranker endpoint 计算 query 与候选文本的相关性。""" + if len(candidates) < 2: + return {} + + reranker = self._get_search_reranker() + if reranker is None: + return None + + keyed_texts = [ + (candidate.get("_candidate_key", ""), candidate.get("_search_text", "")) + for candidate in candidates + if candidate.get("_candidate_key") and candidate.get("_search_text") + ] + if not keyed_texts: + return {} + + try: + score_by_index = reranker.rerank( + query=query, + documents=[candidate_text for _, candidate_text in keyed_texts], + ) + except Exception as exc: + logger.warning(f"图搜索 API reranker 调用失败,回退到词面排序: {exc}") + return None + + scores: Dict[str, float] = {} + for index, (candidate_key, _) in enumerate(keyed_texts): + scores[candidate_key] = float(score_by_index.get(index, 0.0)) + + return scores + + def _apply_app_rerank( + self, + candidates: List[Dict[str, Any]], + query_normalized: str, + query_lower: str, + query_tokens: List[str], + text_builder: Callable[[Dict[str, Any]], str], + ) -> List[Dict[str, Any]]: + """对候选结果执行 app-side 重排。""" + if not candidates: + return [] + + prepared: List[Dict[str, Any]] = [] + for backend_rank, original in enumerate(candidates, start=1): + candidate = dict(original) + candidate.setdefault( + "_candidate_key", + candidate.get("uuid") or candidate.get("name") or f"candidate_{backend_rank}", + ) + candidate["_backend_rank"] = backend_rank + candidate["_search_text"] = self._normalize_text(text_builder(candidate)) + candidate["_lexical_score"] = self._score_texts( + query_lower, + query_tokens, + candidate.get("_search_text", ""), + ) + prepared.append(candidate) + + mode = self._graph_search_app_reranker() + lexical_ranked = self._sort_candidates(prepared, "_lexical_score") + + if mode in {"none", "off"}: + return [self._strip_candidate_meta(candidate) for candidate in prepared] + + if mode in {"lexical", "keyword"}: + return [self._strip_candidate_meta(candidate) for candidate in lexical_ranked] + + semantic_modes = {"embedding_rrf", "semantic_rrf", "hybrid", "embedding_similarity", "semantic", "semantic_similarity"} + api_score_modes = {"api_rerank", "rerank_api", "cross_encoder", "cross_encoder_api"} + api_rrf_modes = {"api_rrf", "rerank_rrf", "cross_encoder_rrf"} + supported_modes = semantic_modes | api_score_modes | api_rrf_modes + + if mode not in supported_modes: + logger.warning(f"未知 GRAPH_SEARCH_APP_RERANKER={mode},回退到 lexical") + return [self._strip_candidate_meta(candidate) for candidate in lexical_ranked] + + if mode in semantic_modes: + semantic_scores = self._compute_semantic_scores(query_normalized, prepared) + if semantic_scores is None: + return [self._strip_candidate_meta(candidate) for candidate in lexical_ranked] + + for candidate in prepared: + candidate["_semantic_score"] = float(semantic_scores.get(candidate["_candidate_key"], 0.0)) + + semantic_ranked = self._sort_candidates(prepared, "_semantic_score") + + if mode in {"embedding_similarity", "semantic", "semantic_similarity"}: + ranked = sorted( + prepared, + key=lambda item: ( + -float(item.get("_semantic_score", 0.0) or 0.0), + -float(item.get("_lexical_score", 0.0) or 0.0), + int(item.get("_backend_rank", 10**9)), + ), + ) + return [self._strip_candidate_meta(candidate) for candidate in ranked] + + backend_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(prepared, start=1)} + lexical_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(lexical_ranked, start=1)} + semantic_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(semantic_ranked, start=1)} + + for candidate in prepared: + candidate_key = candidate["_candidate_key"] + candidate["_fusion_score"] = ( + self._rrf_score(backend_ranks[candidate_key]) + + self._rrf_score(lexical_ranks[candidate_key]) + + (Config.GRAPH_SEARCH_APP_SEMANTIC_WEIGHT * self._rrf_score(semantic_ranks[candidate_key])) + ) + + ranked = sorted( + prepared, + key=lambda item: ( + -float(item.get("_fusion_score", 0.0) or 0.0), + -float(item.get("_semantic_score", 0.0) or 0.0), + -float(item.get("_lexical_score", 0.0) or 0.0), + int(item.get("_backend_rank", 10**9)), + ), + ) + return [self._strip_candidate_meta(candidate) for candidate in ranked] + + api_scores = self._compute_api_rerank_scores(query_normalized, prepared) + if api_scores is None: + return [self._strip_candidate_meta(candidate) for candidate in lexical_ranked] + + for candidate in prepared: + candidate["_api_rerank_score"] = float(api_scores.get(candidate["_candidate_key"], 0.0)) + + api_ranked = self._sort_candidates(prepared, "_api_rerank_score") + + if mode in api_score_modes: + ranked = sorted( + prepared, + key=lambda item: ( + -float(item.get("_api_rerank_score", 0.0) or 0.0), + -float(item.get("_lexical_score", 0.0) or 0.0), + int(item.get("_backend_rank", 10**9)), + ), + ) + return [self._strip_candidate_meta(candidate) for candidate in ranked] + + backend_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(prepared, start=1)} + lexical_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(lexical_ranked, start=1)} + api_ranks = {candidate["_candidate_key"]: idx for idx, candidate in enumerate(api_ranked, start=1)} + + for candidate in prepared: + candidate_key = candidate["_candidate_key"] + candidate["_fusion_score"] = ( + self._rrf_score(backend_ranks[candidate_key]) + + self._rrf_score(lexical_ranks[candidate_key]) + + (Config.GRAPH_SEARCH_APP_SEMANTIC_WEIGHT * self._rrf_score(api_ranks[candidate_key])) + ) + + ranked = sorted( + prepared, + key=lambda item: ( + -float(item.get("_fusion_score", 0.0) or 0.0), + -float(item.get("_api_rerank_score", 0.0) or 0.0), + -float(item.get("_lexical_score", 0.0) or 0.0), + int(item.get("_backend_rank", 10**9)), + ), + ) + return [self._strip_candidate_meta(candidate) for candidate in ranked] + + def _search_scope(self, graph_id: str, query: str, scope: str, limit: int) -> Any: + """执行单个 scope 的后端检索。""" + reranker = Config.GRAPH_SEARCH_RERANKER + return self._call_with_retry( + func=lambda: self.backend.search( + graph_id=graph_id, + query=query, + limit=limit, + scope=scope, + reranker=reranker, + ), + operation_name=f"图谱搜索(graph={graph_id}, scope={scope})", + ) + + def _parse_search_edges(self, search_results: Any) -> List[Dict[str, Any]]: + edges = [] + if hasattr(search_results, 'edges') and search_results.edges: + for edge in search_results.edges: + edges.append({ + "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), + "name": getattr(edge, 'name', ''), + "fact": getattr(edge, 'fact', ''), + "source_node_uuid": getattr(edge, 'source_node_uuid', ''), + "target_node_uuid": getattr(edge, 'target_node_uuid', ''), + "source_node_name": getattr(edge, 'source_node_name', ''), + "target_node_name": getattr(edge, 'target_node_name', ''), + }) + return edges + + def _parse_search_nodes(self, search_results: Any) -> List[Dict[str, Any]]: + nodes = [] + if hasattr(search_results, 'nodes') and search_results.nodes: + for node in search_results.nodes: + nodes.append({ + "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), + "name": getattr(node, 'name', ''), + "labels": getattr(node, 'labels', []), + "summary": getattr(node, 'summary', ''), + }) + return nodes + def search_graph( - self, - graph_id: str, - query: str, + self, + graph_id: str, + query: str, limit: int = 10, scope: str = "edges" ) -> SearchResult: """ - 图谱语义搜索 - - 使用混合搜索(语义+BM25)在图谱中搜索相关信息。 - 如果Zep Cloud的search API不可用,则降级为本地关键词匹配。 - - Args: - graph_id: 图谱ID (Standalone Graph) - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围,"edges" 或 "nodes" - - Returns: - SearchResult: 搜索结果 + 图谱语义搜索。 + + 当前实现会优先召回边,再按配置补充节点摘要,避免只拿到零散 fact + 或在 OpenZep 上完全退化到本地关键词搜索。 """ logger.info(f"图谱搜索: graph_id={graph_id}, query={query[:50]}...") - - # 尝试使用Zep Cloud Search API - try: - search_results = self._call_with_retry( - func=lambda: self.client.graph.search( + + query_normalized = self._normalize_text(query) + query_lower = query_normalized.lower() + query_tokens = self._query_tokens(query_normalized) + + edge_limit = max(limit, limit * Config.GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER) + node_limit = max(limit, limit * Config.GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER) + + candidate_edges: Dict[str, Dict[str, Any]] = {} + candidate_nodes: Dict[str, Dict[str, Any]] = {} + search_errors: List[str] = [] + + scopes_to_search: List[tuple[str, int]] = [] + if scope in {"edges", "both"}: + scopes_to_search.append(("edges", edge_limit)) + if scope in {"nodes", "both"} or Config.GRAPH_SEARCH_INCLUDE_NODES: + scopes_to_search.append(("nodes", node_limit)) + + for search_scope, scoped_limit in scopes_to_search: + try: + search_results = self._search_scope( graph_id=graph_id, - query=query, - limit=limit, - scope=scope, - reranker="cross_encoder" - ), - operation_name=f"图谱搜索(graph={graph_id})" + query=query_normalized, + scope=search_scope, + limit=scoped_limit, + ) + except Exception as e: + logger.warning(f"{search_scope} 检索失败: {str(e)}") + search_errors.append(f"{search_scope}:{str(e)}") + continue + + if search_scope == "edges": + for edge in self._parse_search_edges(search_results): + edge_key = self._edge_candidate_key(edge) + if edge_key and edge_key not in candidate_edges: + edge["_candidate_key"] = edge_key + candidate_edges[edge_key] = edge + else: + for node in self._parse_search_nodes(search_results): + node_key = node["uuid"] or node.get("name", "") + if node_key and node_key not in candidate_nodes: + node["_candidate_key"] = node_key + candidate_nodes[node_key] = node + + if not candidate_edges and not candidate_nodes: + if search_errors: + logger.warning("后端图搜索不可用,降级为本地搜索") + return self._local_search(graph_id, query_normalized, limit, scope) + + ranked_nodes = self._apply_app_rerank( + list(candidate_nodes.values()), + query_normalized=query_normalized, + query_lower=query_lower, + query_tokens=query_tokens, + text_builder=self._node_search_text, + ) + + if scope in {"edges", "both"} and ranked_nodes: + self._expand_edge_candidates_from_nodes( + graph_id=graph_id, + ranked_nodes=ranked_nodes, + candidate_edges=candidate_edges, + query_lower=query_lower, + query_tokens=query_tokens, ) - - facts = [] - edges = [] - nodes = [] - - # 解析边搜索结果 - if hasattr(search_results, 'edges') and search_results.edges: - for edge in search_results.edges: - if hasattr(edge, 'fact') and edge.fact: - facts.append(edge.fact) - edges.append({ - "uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''), - "name": getattr(edge, 'name', ''), - "fact": getattr(edge, 'fact', ''), - "source_node_uuid": getattr(edge, 'source_node_uuid', ''), - "target_node_uuid": getattr(edge, 'target_node_uuid', ''), - }) - - # 解析节点搜索结果 - if hasattr(search_results, 'nodes') and search_results.nodes: - for node in search_results.nodes: - nodes.append({ - "uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), - "name": getattr(node, 'name', ''), - "labels": getattr(node, 'labels', []), - "summary": getattr(node, 'summary', ''), - }) - # 节点摘要也算作事实 - if hasattr(node, 'summary') and node.summary: - facts.append(f"[{node.name}]: {node.summary}") - - logger.info(f"搜索完成: 找到 {len(facts)} 条相关事实") - - return SearchResult( - facts=facts, - edges=edges, - nodes=nodes, - query=query, - total_count=len(facts) + + ranked_edges = self._apply_app_rerank( + list(candidate_edges.values()), + query_normalized=query_normalized, + query_lower=query_lower, + query_tokens=query_tokens, + text_builder=self._edge_search_text, + ) + + selected_edges = ranked_edges[:limit] if scope in {"edges", "both"} else [] + + if scope == "nodes": + selected_nodes = ranked_nodes[:limit] + else: + node_summary_limit = min(Config.GRAPH_SEARCH_NODE_SUMMARY_LIMIT, max(1, limit)) + related_node_uuids = { + edge.get("source_node_uuid", "") + for edge in selected_edges + if edge.get("source_node_uuid") + } + related_node_uuids.update( + edge.get("target_node_uuid", "") + for edge in selected_edges + if edge.get("target_node_uuid") ) - - except Exception as e: - logger.warning(f"Zep Search API失败,降级为本地搜索: {str(e)}") - # 降级:使用本地关键词匹配搜索 - return self._local_search(graph_id, query, limit, scope) - + selected_nodes = [] + for node in ranked_nodes: + if scope == "edges" and selected_edges and node.get("uuid") not in related_node_uuids: + continue + selected_nodes.append(node) + if len(selected_nodes) >= node_summary_limit: + break + + facts: List[str] = [] + seen_facts = set() + + for edge in selected_edges: + fact = self._normalize_text(edge.get("fact", "")) + if fact and fact not in seen_facts: + facts.append(fact) + seen_facts.add(fact) + + for node in selected_nodes: + summary = self._normalize_text(node.get("summary", "")) + if not summary: + continue + fact = f"[{node.get('name', '未知实体')}]: {summary}" + if fact not in seen_facts: + facts.append(fact) + seen_facts.add(fact) + + logger.info( + "搜索完成: edges=%s, nodes=%s, facts=%s, backend_reranker=%s, app_reranker=%s", + len(selected_edges), + len(selected_nodes), + len(facts), + Config.GRAPH_SEARCH_RERANKER, + self._graph_search_app_reranker(), + ) + + return SearchResult( + facts=facts, + edges=selected_edges, + nodes=selected_nodes, + query=query_normalized, + total_count=len(facts), + ) + def _local_search( - self, - graph_id: str, - query: str, + self, + graph_id: str, + query: str, limit: int = 10, scope: str = "edges" ) -> SearchResult: """ - 本地关键词匹配搜索(作为Zep Search API的降级方案) - - 获取所有边/节点,然后在本地进行关键词匹配 - - Args: - graph_id: 图谱ID - query: 搜索查询 - limit: 返回结果数量 - scope: 搜索范围 - - Returns: - SearchResult: 搜索结果 + 本地关键词匹配搜索(作为后端 search 不可用时的降级方案)。 """ logger.info(f"使用本地搜索: query={query[:30]}...") - - facts = [] - edges_result = [] - nodes_result = [] - - # 提取查询关键词(简单分词) - query_lower = query.lower() - keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] - - def match_score(text: str) -> int: - """计算文本与查询的匹配分数""" - if not text: - return 0 - text_lower = text.lower() - # 完全匹配查询 - if query_lower in text_lower: - return 100 - # 关键词匹配 - score = 0 - for keyword in keywords: - if keyword in text_lower: - score += 10 - return score - + + facts: List[str] = [] + edges_result: List[Dict[str, Any]] = [] + nodes_result: List[Dict[str, Any]] = [] + + query_normalized = self._normalize_text(query) + query_lower = query_normalized.lower() + query_tokens = self._query_tokens(query_normalized) + try: + node_map = {node.uuid: node for node in self.get_all_nodes(graph_id)} + if scope in ["edges", "both"]: - # 获取所有边并匹配 all_edges = self.get_all_edges(graph_id) scored_edges = [] for edge in all_edges: - score = match_score(edge.fact) + match_score(edge.name) + source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name + target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name + score = self._score_texts( + query_lower, + query_tokens, + edge.fact, + edge.name, + source_name, + target_name, + ) if score > 0: - scored_edges.append((score, edge)) - - # 按分数排序 - scored_edges.sort(key=lambda x: x[0], reverse=True) - - for score, edge in scored_edges[:limit]: + scored_edges.append((score, edge, source_name, target_name)) + + scored_edges.sort(key=lambda item: item[0], reverse=True) + + for score, edge, source_name, target_name in scored_edges[:limit]: if edge.fact: facts.append(edge.fact) edges_result.append({ @@ -611,20 +1165,28 @@ class ZepToolsService: "fact": edge.fact, "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, + "source_node_name": source_name, + "target_node_name": target_name, }) - - if scope in ["nodes", "both"]: - # 获取所有节点并匹配 - all_nodes = self.get_all_nodes(graph_id) + + if scope in ["nodes", "both"] or Config.GRAPH_SEARCH_INCLUDE_NODES: + all_nodes = list(node_map.values()) scored_nodes = [] for node in all_nodes: - score = match_score(node.name) + match_score(node.summary) + score = self._score_texts( + query_lower, + query_tokens, + node.name, + node.summary, + " ".join(node.labels), + ) if score > 0: scored_nodes.append((score, node)) - - scored_nodes.sort(key=lambda x: x[0], reverse=True) - - for score, node in scored_nodes[:limit]: + + scored_nodes.sort(key=lambda item: item[0], reverse=True) + + node_limit = limit if scope == "nodes" else min(limit, Config.GRAPH_SEARCH_NODE_SUMMARY_LIMIT) + for score, node in scored_nodes[:node_limit]: nodes_result.append({ "uuid": node.uuid, "name": node.name, @@ -633,33 +1195,28 @@ class ZepToolsService: }) if node.summary: facts.append(f"[{node.name}]: {node.summary}") - + + facts = list(dict.fromkeys(facts)) logger.info(f"本地搜索完成: 找到 {len(facts)} 条相关事实") - + except Exception as e: logger.error(f"本地搜索失败: {str(e)}") - + return SearchResult( facts=facts, edges=edges_result, nodes=nodes_result, - query=query, + query=query_normalized, total_count=len(facts) ) - + def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ 获取图谱的所有节点(分页获取) - - Args: - graph_id: 图谱ID - - Returns: - 节点列表 """ logger.info(f"获取图谱 {graph_id} 的所有节点...") - nodes = fetch_all_nodes(self.client, graph_id) + nodes = self.backend.get_all_nodes(graph_id) result = [] for node in nodes: @@ -678,17 +1235,10 @@ class ZepToolsService: def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]: """ 获取图谱的所有边(分页获取,包含时间信息) - - Args: - graph_id: 图谱ID - include_temporal: 是否包含时间信息(默认True) - - Returns: - 边列表(包含created_at, valid_at, invalid_at, expired_at) """ logger.info(f"获取图谱 {graph_id} 的所有边...") - edges = fetch_all_edges(self.client, graph_id) + edges = self.backend.get_all_edges(graph_id) result = [] for edge in edges: @@ -701,7 +1251,6 @@ class ZepToolsService: target_node_uuid=edge.target_node_uuid or "" ) - # 添加时间信息 if include_temporal: edge_info.created_at = getattr(edge, 'created_at', None) edge_info.valid_at = getattr(edge, 'valid_at', None) @@ -712,28 +1261,20 @@ class ZepToolsService: logger.info(f"获取到 {len(result)} 条边") return result - + def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: - """ - 获取单个节点的详细信息 - - Args: - node_uuid: 节点UUID - - Returns: - 节点信息或None - """ + """获取单个节点的详细信息。""" logger.info(f"获取节点详情: {node_uuid[:8]}...") - + try: node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), + func=lambda: self.backend.get_node(node_uuid), operation_name=f"获取节点详情(uuid={node_uuid[:8]}...)" ) - + if not node: return None - + return NodeInfo( uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), name=node.name or "", @@ -744,39 +1285,41 @@ class ZepToolsService: except Exception as e: logger.error(f"获取节点详情失败: {str(e)}") return None - + def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]: - """ - 获取节点相关的所有边 - - 通过获取图谱所有边,然后过滤出与指定节点相关的边 - - Args: - graph_id: 图谱ID - node_uuid: 节点UUID - - Returns: - 边列表 - """ + """获取节点相关的所有边。""" logger.info(f"获取节点 {node_uuid[:8]}... 的相关边") - + try: - # 获取图谱所有边,然后过滤 - all_edges = self.get_all_edges(graph_id) - + edges = self._call_with_retry( + func=lambda: self.backend.get_node_edges(node_uuid), + operation_name=f"获取节点边(uuid={node_uuid[:8]}...)" + ) + result = [] - for edge in all_edges: - # 检查边是否与指定节点相关(作为源或目标) - if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid: - result.append(edge) - + for edge in edges: + edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or "" + result.append(EdgeInfo( + uuid=str(edge_uuid) if edge_uuid else "", + name=edge.name or "", + fact=edge.fact or "", + source_node_uuid=edge.source_node_uuid or "", + target_node_uuid=edge.target_node_uuid or "", + source_node_name=getattr(edge, 'source_node_name', None), + target_node_name=getattr(edge, 'target_node_name', None), + created_at=getattr(edge, 'created_at', None), + valid_at=getattr(edge, 'valid_at', None), + invalid_at=getattr(edge, 'invalid_at', None), + expired_at=getattr(edge, 'expired_at', None), + )) + logger.info(f"找到 {len(result)} 条与节点相关的边") return result - + except Exception as e: logger.warning(f"获取节点边失败: {str(e)}") return [] - + def get_entities_by_type( self, graph_id: str, @@ -942,6 +1485,8 @@ class ZepToolsService: # ========== 核心检索工具(优化后) ========== + + def insight_forge( self, graph_id: str, @@ -952,33 +1497,15 @@ class ZepToolsService: ) -> InsightForgeResult: """ 【InsightForge - 深度洞察检索】 - - 最强大的混合检索函数,自动分解问题并多维度检索: - 1. 使用LLM将问题分解为多个子问题 - 2. 对每个子问题进行语义搜索 - 3. 提取相关实体并获取其详细信息 - 4. 追踪关系链 - 5. 整合所有结果,生成深度洞察 - - Args: - graph_id: 图谱ID - query: 用户问题 - simulation_requirement: 模拟需求描述 - report_context: 报告上下文(可选,用于更精准的子问题生成) - max_sub_queries: 最大子问题数量 - - Returns: - InsightForgeResult: 深度洞察检索结果 """ logger.info(f"InsightForge 深度洞察检索: {query[:50]}...") - + result = InsightForgeResult( query=query, simulation_requirement=simulation_requirement, sub_queries=[] ) - - # Step 1: 使用LLM生成子问题 + sub_queries = self._generate_sub_queries( query=query, simulation_requirement=simulation_requirement, @@ -987,108 +1514,129 @@ class ZepToolsService: ) result.sub_queries = sub_queries logger.info(f"生成 {len(sub_queries)} 个子问题") - - # Step 2: 对每个子问题进行语义搜索 - all_facts = [] - all_edges = [] + + all_facts: List[str] = [] seen_facts = set() - - for sub_query in sub_queries: - search_result = self.search_graph( - graph_id=graph_id, - query=sub_query, - limit=15, - scope="edges" - ) - + all_edges: Dict[str, Dict[str, Any]] = {} + all_nodes: Dict[str, Dict[str, Any]] = {} + entity_fact_map: Dict[str, List[str]] = defaultdict(list) + + def merge_search(search_result: SearchResult) -> None: for fact in search_result.facts: if fact not in seen_facts: all_facts.append(fact) seen_facts.add(fact) - - all_edges.extend(search_result.edges) - - # 对原始问题也进行搜索 - main_search = self.search_graph( - graph_id=graph_id, - query=query, - limit=20, - scope="edges" + + for edge in search_result.edges: + edge_key = edge.get('uuid') or "|".join([ + edge.get('name', ''), + edge.get('fact', ''), + edge.get('source_node_uuid', ''), + edge.get('target_node_uuid', ''), + ]) + if edge_key and edge_key not in all_edges: + all_edges[edge_key] = edge + + fact = edge.get('fact', '') + for node_uuid in (edge.get('source_node_uuid', ''), edge.get('target_node_uuid', '')): + if node_uuid and fact and fact not in entity_fact_map[node_uuid]: + entity_fact_map[node_uuid].append(fact) + + for node in search_result.nodes: + node_uuid = node.get('uuid') or node.get('name') + if node_uuid and node_uuid not in all_nodes: + all_nodes[node_uuid] = node + + for sub_query in sub_queries: + merge_search( + self.search_graph( + graph_id=graph_id, + query=sub_query, + limit=15, + scope="edges" + ) + ) + + merge_search( + self.search_graph( + graph_id=graph_id, + query=query, + limit=20, + scope="edges" + ) ) - for fact in main_search.facts: - if fact not in seen_facts: - all_facts.append(fact) - seen_facts.add(fact) - + result.semantic_facts = all_facts result.total_facts = len(all_facts) - - # Step 3: 从边中提取相关实体UUID,只获取这些实体的信息(不获取全部节点) - entity_uuids = set() - for edge_data in all_edges: - if isinstance(edge_data, dict): - source_uuid = edge_data.get('source_node_uuid', '') - target_uuid = edge_data.get('target_node_uuid', '') - if source_uuid: - entity_uuids.add(source_uuid) - if target_uuid: - entity_uuids.add(target_uuid) - - # 获取所有相关实体的详情(不限制数量,完整输出) + + entity_uuids = set(all_nodes.keys()) + for edge in all_edges.values(): + source_uuid = edge.get('source_node_uuid', '') + target_uuid = edge.get('target_node_uuid', '') + if source_uuid: + entity_uuids.add(source_uuid) + if target_uuid: + entity_uuids.add(target_uuid) + entity_insights = [] - node_map = {} # 用于后续关系链构建 - - for uuid in list(entity_uuids): # 处理所有实体,不截断 - if not uuid: + node_map: Dict[str, NodeInfo] = {} + + for node_uuid in list(entity_uuids): + if not node_uuid: continue - try: - # 单独获取每个相关节点的信息 - node = self.get_node_detail(uuid) - if node: - node_map[uuid] = node - entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") - - # 获取该实体相关的所有事实(不截断) - related_facts = [ - f for f in all_facts - if node.name.lower() in f.lower() - ] - - entity_insights.append({ - "uuid": node.uuid, - "name": node.name, - "type": entity_type, - "summary": node.summary, - "related_facts": related_facts # 完整输出,不截断 - }) - except Exception as e: - logger.debug(f"获取节点 {uuid} 失败: {e}") + + search_node = all_nodes.get(node_uuid, {}) + node = self.get_node_detail(node_uuid) + if node is None and search_node: + node = NodeInfo( + uuid=search_node.get('uuid', node_uuid), + name=search_node.get('name', ''), + labels=search_node.get('labels', []), + summary=search_node.get('summary', ''), + attributes={}, + ) + + if not node: continue - + + node_map[node_uuid] = node + entity_type = next((label for label in node.labels if label not in ["Entity", "Node"]), "实体") + related_facts = list(dict.fromkeys(entity_fact_map.get(node_uuid, []))) + if not related_facts and node.name: + related_facts = [fact for fact in all_facts if node.name.lower() in fact.lower()] + + entity_insights.append({ + "uuid": node.uuid, + "name": node.name, + "type": entity_type, + "summary": node.summary, + "related_facts": related_facts, + }) + result.entity_insights = entity_insights result.total_entities = len(entity_insights) - - # Step 4: 构建所有关系链(不限制数量) + relationship_chains = [] - for edge_data in all_edges: # 处理所有边,不截断 - if isinstance(edge_data, dict): - source_uuid = edge_data.get('source_node_uuid', '') - target_uuid = edge_data.get('target_node_uuid', '') - relation_name = edge_data.get('name', '') - - source_name = node_map.get(source_uuid, NodeInfo('', '', [], '', {})).name or source_uuid[:8] - target_name = node_map.get(target_uuid, NodeInfo('', '', [], '', {})).name or target_uuid[:8] - - chain = f"{source_name} --[{relation_name}]--> {target_name}" - if chain not in relationship_chains: - relationship_chains.append(chain) - + for edge in all_edges.values(): + source_uuid = edge.get('source_node_uuid', '') + target_uuid = edge.get('target_node_uuid', '') + relation_name = edge.get('name', '') + + source_name = node_map.get(source_uuid, NodeInfo('', '', [], '', {})).name or edge.get('source_node_name', '') or source_uuid[:8] + target_name = node_map.get(target_uuid, NodeInfo('', '', [], '', {})).name or edge.get('target_node_name', '') or target_uuid[:8] + + chain = f"{source_name} --[{relation_name}]--> {target_name}" + if chain not in relationship_chains: + relationship_chains.append(chain) + result.relationship_chains = relationship_chains result.total_relationships = len(relationship_chains) - - logger.info(f"InsightForge完成: {result.total_facts}条事实, {result.total_entities}个实体, {result.total_relationships}条关系") + + logger.info( + f"InsightForge完成: {result.total_facts}条事实, {result.total_entities}个实体, {result.total_relationships}条关系" + ) return result - + def _generate_sub_queries( self, query: str, diff --git a/backend/app/utils/embedding_client.py b/backend/app/utils/embedding_client.py new file mode 100644 index 00000000..871bbdab --- /dev/null +++ b/backend/app/utils/embedding_client.py @@ -0,0 +1,47 @@ +""" +Embedding client wrapper for OpenAI-compatible embedding APIs. +""" + +from __future__ import annotations + +from typing import List, Optional + +from openai import OpenAI + + +class EmbeddingClient: + """Thin wrapper around OpenAI-compatible embedding endpoints.""" + + def __init__( + self, + api_key: Optional[str], + base_url: str, + model: str, + batch_size: int = 32, + ): + if not base_url: + raise ValueError("Embedding base_url 未配置") + if not model: + raise ValueError("Embedding model 未配置") + + self.api_key = api_key or 'ollama' + self.base_url = base_url + self.model = model + self.batch_size = max(1, int(batch_size)) + self.client = OpenAI(api_key=self.api_key, base_url=self.base_url) + + def embed_texts(self, texts: List[str]) -> List[List[float]]: + """Embed texts in batches while preserving input order.""" + if not texts: + return [] + + embeddings: List[List[float]] = [] + normalized_inputs = [str(text or ' ').strip() or ' ' for text in texts] + + for start in range(0, len(normalized_inputs), self.batch_size): + batch = normalized_inputs[start:start + self.batch_size] + response = self.client.embeddings.create(model=self.model, input=batch) + data = sorted(response.data, key=lambda item: item.index) + embeddings.extend(item.embedding for item in data) + + return embeddings diff --git a/backend/app/utils/reranker_client.py b/backend/app/utils/reranker_client.py new file mode 100644 index 00000000..6465863e --- /dev/null +++ b/backend/app/utils/reranker_client.py @@ -0,0 +1,175 @@ +""" +Reranker client wrappers for common HTTP rerank APIs. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Dict, List, Optional +from urllib import error, request + + +@dataclass +class RerankerRequestSpec: + provider: str + path: str + body: dict + + +class RerankerClient: + """Thin wrapper around HTTP rerank endpoints.""" + + def __init__( + self, + base_url: str, + model: Optional[str] = None, + api_key: Optional[str] = None, + provider: str = "auto", + timeout: float = 20.0, + ): + if not base_url: + raise ValueError("Reranker base_url 未配置") + + self.base_url = base_url.rstrip("/") + self.model = model + self.api_key = api_key or None + self.provider = (provider or "auto").strip().lower() or "auto" + self.timeout = max(1.0, float(timeout)) + + def rerank(self, query: str, documents: List[str]) -> Dict[int, float]: + """Return index -> score for the supplied candidate documents.""" + if not documents: + return {} + + last_error: Optional[Exception] = None + for spec in self._build_request_specs(query, documents): + try: + return self._execute_request(spec) + except Exception as exc: + last_error = exc + + if last_error is not None: + raise last_error + raise RuntimeError("没有可用的 reranker provider") + + def _build_request_specs(self, query: str, documents: List[str]) -> List[RerankerRequestSpec]: + providers = [self.provider] + if self.provider == "auto": + providers = ["tei", "jina", "cohere", "vllm", "infinity"] + + specs: List[RerankerRequestSpec] = [] + for provider in providers: + if provider == "tei": + specs.append( + RerankerRequestSpec( + provider=provider, + path="/rerank", + body={ + "query": query, + "texts": documents, + "truncate": True, + "raw_scores": False, + }, + ) + ) + elif provider in {"jina", "vllm", "infinity"}: + body = { + "query": query, + "documents": documents, + "top_n": len(documents), + "return_documents": False, + } + if self.model: + body["model"] = self.model + specs.append( + RerankerRequestSpec( + provider=provider, + path="/v1/rerank", + body=body, + ) + ) + elif provider == "cohere": + body = { + "query": query, + "documents": documents, + "top_n": len(documents), + "return_documents": False, + } + if self.model: + body["model"] = self.model + specs.append( + RerankerRequestSpec( + provider=provider, + path="/v2/rerank", + body=body, + ) + ) + + return specs + + def _headers(self) -> dict: + headers = {"Content-Type": "application/json"} + if self.api_key: + headers["Authorization"] = f"Bearer {self.api_key}" + return headers + + def _execute_request(self, spec: RerankerRequestSpec) -> Dict[int, float]: + payload = json.dumps(spec.body).encode("utf-8") + req = request.Request( + f"{self.base_url}{spec.path}", + data=payload, + headers=self._headers(), + method="POST", + ) + + try: + with request.urlopen(req, timeout=self.timeout) as resp: + body = resp.read().decode("utf-8") + except error.HTTPError as exc: + detail = exc.read().decode("utf-8", errors="replace") + raise RuntimeError(f"{spec.provider} rerank 请求失败: HTTP {exc.code}: {detail[:300]}") from exc + except error.URLError as exc: + raise RuntimeError(f"{spec.provider} rerank 请求失败: {exc.reason}") from exc + + try: + data = json.loads(body) + except json.JSONDecodeError as exc: + raise RuntimeError(f"{spec.provider} rerank 返回非 JSON 响应") from exc + + scores = self._parse_scores(spec.provider, data) + if not scores: + raise RuntimeError(f"{spec.provider} rerank 未返回有效分数") + return scores + + def _parse_scores(self, provider: str, payload: object) -> Dict[int, float]: + if provider == "tei": + if not isinstance(payload, list): + raise RuntimeError("TEI rerank 响应格式异常") + + scores: Dict[int, float] = {} + for item in payload: + if not isinstance(item, dict): + continue + index = item.get("index") + score = item.get("score") + if isinstance(index, int) and score is not None: + scores[index] = float(score) + return scores + + if not isinstance(payload, dict): + raise RuntimeError("rerank 响应格式异常") + + results = payload.get("results") or payload.get("data") or [] + scores: Dict[int, float] = {} + if isinstance(results, list): + for item in results: + if not isinstance(item, dict): + continue + index = item.get("index") + score = item.get("relevance_score") + if score is None: + score = item.get("score") + if isinstance(index, int) and score is not None: + scores[index] = float(score) + return scores diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 4f5361d5..161949a7 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -14,10 +14,14 @@ dependencies = [ "flask-cors>=6.0.0", # LLM 相关 - "openai>=1.0.0", + "openai>=1.91.0", # Zep Cloud "zep-cloud==3.13.0", + + # Graphiti / Neo4j + "graphiti-core==0.28.2", + "neo4j>=5.26.0", # OASIS 社交媒体模拟 "camel-oasis==0.2.5", @@ -31,7 +35,7 @@ dependencies = [ # 工具库 "python-dotenv>=1.0.0", - "pydantic>=2.0.0", + "pydantic>=2.11.5", ] [project.optional-dependencies] diff --git a/backend/requirements.txt b/backend/requirements.txt index 4f146296..628bea54 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -11,11 +11,15 @@ flask-cors>=6.0.0 # ============= LLM 相关 ============= # OpenAI SDK(统一使用 OpenAI 格式调用 LLM) -openai>=1.0.0 +openai>=1.91.0 # ============= Zep Cloud ============= zep-cloud==3.13.0 +# ============= Graphiti / Neo4j ============= +graphiti-core==0.28.2 +neo4j>=5.26.0 + # ============= OASIS 社交媒体模拟 ============= # OASIS 社交模拟框架 camel-oasis==0.2.5 @@ -32,4 +36,4 @@ chardet>=5.0.0 python-dotenv>=1.0.0 # 数据验证 -pydantic>=2.0.0 +pydantic>=2.11.5 diff --git a/backend/scripts/graphiti_smoke_test.py b/backend/scripts/graphiti_smoke_test.py new file mode 100644 index 00000000..ac23cbf2 --- /dev/null +++ b/backend/scripts/graphiti_smoke_test.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python3 +""" +Service-level Graphiti smoke test for MiroFish. + +The script creates a temporary graph, applies a small ontology, ingests a few +text chunks through GraphBuilderService, then queries through ZepToolsService. +It prints JSON output and deletes the graph by default. +""" + +from __future__ import annotations + +import argparse +import json +import os +import sys +import traceback +import uuid +from pathlib import Path +from typing import Iterable + + +REPO_ROOT = Path(__file__).resolve().parents[2] +BACKEND_ROOT = REPO_ROOT / "backend" + +DEFAULT_CHUNKS = [ + "Alice works at Acme Robotics. Bob manages Alice.", + "Acme Robotics is located in Tokyo. Alice relocated to Tokyo for work.", +] + + +def load_env_file(path: Path) -> None: + """Load a simple .env file without overriding existing variables.""" + if not path.exists(): + return + + for raw_line in path.read_text(encoding="utf-8").splitlines(): + line = raw_line.strip() + if not line or line.startswith("#") or "=" not in line: + continue + + key, value = line.split("=", 1) + key = key.strip() + value = value.strip() + + if value and value[0] == value[-1] and value[0] in {"'", '"'}: + value = value[1:-1] + + os.environ.setdefault(key, value) + + +def prepare_environment() -> None: + """Apply repo-local defaults that make Graphiti smoke testing predictable.""" + load_env_file(REPO_ROOT / ".env") + load_env_file(BACKEND_ROOT / ".env") + + os.environ.setdefault("GRAPH_BACKEND", "graphiti") + os.environ.setdefault("GRAPHITI_URI", "bolt://127.0.0.1:7687") + os.environ.setdefault("GRAPHITI_USER", "neo4j") + os.environ.setdefault("GRAPHITI_PASSWORD", "password123") + os.environ.setdefault("GRAPHITI_DATABASE", "neo4j") + os.environ.setdefault("GRAPHITI_LLM_BASE_URL", os.environ.get("LLM_BASE_URL", "http://127.0.0.1:18081/v1")) + os.environ.setdefault("GRAPHITI_LLM_MODEL", os.environ.get("LLM_MODEL_NAME", "gpt-5.4")) + os.environ.setdefault("GRAPHITI_LLM_CLIENT_MODE", "openai") + os.environ.setdefault("GRAPH_SEARCH_RERANKER", "rrf") + os.environ.setdefault("GRAPHITI_EMBEDDER_API_KEY", "ollama") + os.environ.setdefault("GRAPHITI_EMBEDDER_BASE_URL", "http://127.0.0.1:11434/v1") + os.environ.setdefault("GRAPHITI_EMBEDDER_MODEL", "qwen3-embedding:8b") + os.environ.setdefault("GRAPHITI_EMBEDDER_DIM", "1024") + + +def parse_args(argv: Iterable[str]) -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Run a Graphiti service-level smoke test.") + parser.add_argument( + "--query", + default="Where does Alice work?", + help="Search query executed through ZepToolsService.", + ) + parser.add_argument( + "--chunk", + action="append", + dest="chunks", + help="Text chunk to ingest. Repeat to add multiple chunks.", + ) + parser.add_argument( + "--graph-name", + default="graphiti smoke test", + help="Temporary graph name.", + ) + parser.add_argument( + "--keep-graph", + action="store_true", + help="Keep the temporary graph instead of deleting it on exit.", + ) + return parser.parse_args(list(argv)) + + +def build_sample_ontology() -> dict: + return { + "entity_types": [ + {"name": "Person", "description": "A human individual.", "attributes": []}, + {"name": "Organization", "description": "A company or institution.", "attributes": []}, + {"name": "Location", "description": "A place or city.", "attributes": []}, + ], + "edge_types": [ + { + "name": "WORKS_AT", + "description": "A person works at an organization.", + "attributes": [], + "source_targets": [{"source": "Person", "target": "Organization"}], + }, + { + "name": "MANAGES", + "description": "A person manages another person.", + "attributes": [], + "source_targets": [{"source": "Person", "target": "Person"}], + }, + { + "name": "LOCATED_IN", + "description": "An organization is located in a place.", + "attributes": [], + "source_targets": [{"source": "Organization", "target": "Location"}], + }, + ], + } + + +def main(argv: Iterable[str]) -> int: + prepare_environment() + sys.path.insert(0, str(BACKEND_ROOT)) + + from app.config import Config + + search_embedder = Config.get_graph_search_embedder_config() + search_reranker = Config.get_graph_search_reranker_config() + errors = Config.get_graph_backend_config_errors(api_key=Config.ZEP_API_KEY) + if errors: + print( + json.dumps( + { + "success": False, + "error": "Graph backend config is incomplete.", + "backend": Config.GRAPH_BACKEND, + "errors": errors, + }, + ensure_ascii=False, + indent=2, + ), + file=sys.stderr, + ) + return 2 + + from app.services.graph_builder import GraphBuilderService + from app.services.zep_tools import ZepToolsService + + args = parse_args(argv) + chunks = args.chunks or list(DEFAULT_CHUNKS) + + builder = GraphBuilderService() + tools = ZepToolsService() + graph_id = builder.create_graph(f"{args.graph_name}-{uuid.uuid4().hex[:8]}") + + try: + builder.set_ontology(graph_id, build_sample_ontology()) + builder.add_text_batches(graph_id, chunks, batch_size=1) + + result = tools.search_graph( + graph_id=graph_id, + query=args.query, + limit=5, + scope="edges", + ) + stats = tools.get_graph_statistics(graph_id) + + output = { + "success": True, + "graph_id": graph_id, + "kept_graph": args.keep_graph, + "backend": Config.GRAPH_BACKEND, + "llm_base_url": Config.GRAPHITI_LLM_BASE_URL or Config.LLM_BASE_URL, + "llm_model": Config.GRAPHITI_LLM_MODEL or Config.LLM_MODEL_NAME, + "embedder_base_url": Config.GRAPHITI_EMBEDDER_BASE_URL, + "embedder_model": Config.GRAPHITI_EMBEDDER_MODEL, + "app_reranker": Config.GRAPH_SEARCH_APP_RERANKER, + "app_embedder_base_url": search_embedder.get("base_url"), + "app_embedder_model": search_embedder.get("model"), + "app_reranker_base_url": search_reranker.get("base_url"), + "app_reranker_model": search_reranker.get("model"), + "app_reranker_provider": search_reranker.get("provider"), + "query": args.query, + "chunks": chunks, + "stats": stats, + "facts": result.facts, + "edges": result.edges, + "nodes": result.nodes, + "total_count": result.total_count, + } + print(json.dumps(output, ensure_ascii=False, indent=2)) + return 0 + except Exception as exc: + print( + json.dumps( + { + "success": False, + "graph_id": graph_id, + "error": str(exc), + "traceback": traceback.format_exc(), + }, + ensure_ascii=False, + indent=2, + ), + file=sys.stderr, + ) + return 1 + finally: + if not args.keep_graph: + try: + builder.backend.delete_graph(graph_id) + except Exception: + traceback.print_exc() + + +if __name__ == "__main__": + raise SystemExit(main(sys.argv[1:])) From a026178d67d0c2f617efe829df23cad3b7bfb921 Mon Sep 17 00:00:00 2001 From: MiroFish Bot Date: Sat, 21 Mar 2026 17:24:55 +0900 Subject: [PATCH 2/2] fix(graph): harden ontology normalization for Zep limits --- backend/app/services/graph_builder.py | 91 ++++++++++++------ backend/app/services/ontology_generator.py | 104 +++++++++++++++------ 2 files changed, 139 insertions(+), 56 deletions(-) diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index c53c8e1f..2731c61a 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -210,74 +210,111 @@ class GraphBuilderService: # Zep 保留名称,不能作为属性名 RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} - + def safe_attr_name(attr_name: str) -> str: """将保留名称转换为安全名称""" if attr_name.lower() in RESERVED_NAMES: return f"entity_{attr_name}" return attr_name - + + def normalize_attributes(raw_attributes: Any) -> List[Dict[str, str]]: + normalized: List[Dict[str, str]] = [] + for attr_def in raw_attributes or []: + if isinstance(attr_def, str): + attr_def = {"name": attr_def, "description": attr_def} + if not isinstance(attr_def, dict): + continue + + attr_name = str(attr_def.get("name", "")).strip() + if not attr_name: + continue + + normalized.append({ + "name": attr_name, + "description": str(attr_def.get("description") or attr_name), + }) + return normalized + + def normalize_source_targets(raw_source_targets: Any) -> List[EntityEdgeSourceTarget]: + normalized: List[EntityEdgeSourceTarget] = [] + for source_target in raw_source_targets or []: + if not isinstance(source_target, dict): + continue + + normalized.append( + EntityEdgeSourceTarget( + source=str(source_target.get("source", "Entity")) or "Entity", + target=str(source_target.get("target", "Entity")) or "Entity", + ) + ) + + # Zep API allows max 10 source_targets per edge type. + return normalized[:10] + # 动态创建实体类型 entity_types = {} for entity_def in ontology.get("entity_types", []): - name = entity_def["name"] + if not isinstance(entity_def, dict): + continue + + name = str(entity_def.get("name", "")).strip() + if not name: + continue + description = entity_def.get("description", f"A {name} entity.") - + # 创建属性字典和类型注解(Pydantic v2 需要) attrs = {"__doc__": description} annotations = {} - - for attr_def in entity_def.get("attributes", []): + + for attr_def in normalize_attributes(entity_def.get("attributes", [])): attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_desc = attr_def.get("description", attr_name) # Zep API 需要 Field 的 description,这是必需的 attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[EntityText] # 类型注解 - + attrs["__annotations__"] = annotations - + # 动态创建类 entity_class = type(name, (EntityModel,), attrs) entity_class.__doc__ = description entity_types[name] = entity_class - + # 动态创建边类型 edge_definitions = {} for edge_def in ontology.get("edge_types", []): - name = edge_def["name"] + if not isinstance(edge_def, dict): + continue + + name = str(edge_def.get("name", "")).strip() + if not name: + continue + description = edge_def.get("description", f"A {name} relationship.") - + # 创建属性字典和类型注解 attrs = {"__doc__": description} annotations = {} - - for attr_def in edge_def.get("attributes", []): + + for attr_def in normalize_attributes(edge_def.get("attributes", [])): attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_desc = attr_def.get("description", attr_name) # Zep API 需要 Field 的 description,这是必需的 attrs[attr_name] = Field(description=attr_desc, default=None) annotations[attr_name] = Optional[str] # 边属性用str类型 - + attrs["__annotations__"] = annotations - + # 动态创建类 class_name = ''.join(word.capitalize() for word in name.split('_')) edge_class = type(class_name, (EdgeModel,), attrs) edge_class.__doc__ = description - - # 构建source_targets - source_targets = [] - for st in edge_def.get("source_targets", []): - source_targets.append( - EntityEdgeSourceTarget( - source=st.get("source", "Entity"), - target=st.get("target", "Entity") - ) - ) - + + source_targets = normalize_source_targets(edge_def.get("source_targets", [])) if source_targets: edge_definitions[name] = (edge_class, source_targets) - + # 调用Zep API设置本体 if entity_types or edge_definitions: self.backend.set_ontology( diff --git a/backend/app/services/ontology_generator.py b/backend/app/services/ontology_generator.py index 2d3e39bd..330ef831 100644 --- a/backend/app/services/ontology_generator.py +++ b/backend/app/services/ontology_generator.py @@ -256,38 +256,84 @@ class OntologyGenerator: def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: """验证和后处理结果""" - + + if not isinstance(result, dict): + result = {} + # 确保必要字段存在 - if "entity_types" not in result: + if not isinstance(result.get("entity_types"), list): result["entity_types"] = [] - if "edge_types" not in result: + if not isinstance(result.get("edge_types"), list): result["edge_types"] = [] if "analysis_summary" not in result: result["analysis_summary"] = "" - + # 验证实体类型 + validated_entities = [] for entity in result["entity_types"]: - if "attributes" not in entity: - entity["attributes"] = [] - if "examples" not in entity: - entity["examples"] = [] - # 确保description不超过100字符 - if len(entity.get("description", "")) > 100: - entity["description"] = entity["description"][:97] + "..." - + if isinstance(entity, str): + entity = {"name": entity, "description": f"Entity type: {entity}"} + if not isinstance(entity, dict): + continue + + name = str(entity.get("name", "")).strip() + if not name: + continue + + attributes = entity.get("attributes") + if not isinstance(attributes, list): + attributes = [] + + examples = entity.get("examples") + if not isinstance(examples, list): + examples = [] + + normalized = dict(entity) + normalized["name"] = name + normalized["attributes"] = attributes + normalized["examples"] = examples + if len(normalized.get("description", "")) > 100: + normalized["description"] = normalized["description"][:97] + "..." + + validated_entities.append(normalized) + + result["entity_types"] = validated_entities + # 验证关系类型 + validated_edges = [] for edge in result["edge_types"]: - if "source_targets" not in edge: - edge["source_targets"] = [] - if "attributes" not in edge: - edge["attributes"] = [] - if len(edge.get("description", "")) > 100: - edge["description"] = edge["description"][:97] + "..." - + if isinstance(edge, str): + edge = {"name": edge, "description": f"Relationship type: {edge}"} + if not isinstance(edge, dict): + continue + + name = str(edge.get("name", "")).strip() + if not name: + continue + + source_targets = edge.get("source_targets") + if not isinstance(source_targets, list): + source_targets = [] + + attributes = edge.get("attributes") + if not isinstance(attributes, list): + attributes = [] + + normalized = dict(edge) + normalized["name"] = name + normalized["source_targets"] = source_targets + normalized["attributes"] = attributes + if len(normalized.get("description", "")) > 100: + normalized["description"] = normalized["description"][:97] + "..." + + validated_edges.append(normalized) + + result["edge_types"] = validated_edges + # Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 MAX_ENTITY_TYPES = 10 MAX_EDGE_TYPES = 10 - + # 兜底类型定义 person_fallback = { "name": "Person", @@ -298,7 +344,7 @@ class OntologyGenerator: ], "examples": ["ordinary citizen", "anonymous netizen"] } - + organization_fallback = { "name": "Organization", "description": "Any organization not fitting other specific organization types.", @@ -308,40 +354,40 @@ class OntologyGenerator: ], "examples": ["small business", "community group"] } - + # 检查是否已有兜底类型 entity_names = {e["name"] for e in result["entity_types"]} has_person = "Person" in entity_names has_organization = "Organization" in entity_names - + # 需要添加的兜底类型 fallbacks_to_add = [] if not has_person: fallbacks_to_add.append(person_fallback) if not has_organization: fallbacks_to_add.append(organization_fallback) - + if fallbacks_to_add: current_count = len(result["entity_types"]) needed_slots = len(fallbacks_to_add) - + # 如果添加后会超过 10 个,需要移除一些现有类型 if current_count + needed_slots > MAX_ENTITY_TYPES: # 计算需要移除多少个 to_remove = current_count + needed_slots - MAX_ENTITY_TYPES # 从末尾移除(保留前面更重要的具体类型) result["entity_types"] = result["entity_types"][:-to_remove] - + # 添加兜底类型 result["entity_types"].extend(fallbacks_to_add) - + # 最终确保不超过限制(防御性编程) if len(result["entity_types"]) > MAX_ENTITY_TYPES: result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] - + if len(result["edge_types"]) > MAX_EDGE_TYPES: result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES] - + return result def generate_python_code(self, ontology: Dict[str, Any]) -> str: