Merge PR #276: pluggable graph backend with Graphiti support
This commit is contained in:
commit
1466ea2c69
87
.env.example
87
.env.example
|
|
@ -1,8 +1,7 @@
|
|||
# ===== LLM API Configuration =====
|
||||
# Default: any OpenAI-compatible API
|
||||
# With Prompture installed (pip install prompture): 12+ providers supported
|
||||
#
|
||||
# ── OpenAI-compatible (default, no Prompture needed) ──
|
||||
# LLM API 配置(支持 OpenAI SDK 格式的任意 LLM API)
|
||||
# 可直接填写你自己的 LLM 接口,例如:
|
||||
# LLM_BASE_URL=http://127.0.0.1:18081/v1
|
||||
# LLM_MODEL_NAME=gpt-5.4
|
||||
LLM_API_KEY=your_api_key_here
|
||||
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
LLM_MODEL_NAME=qwen-plus
|
||||
|
|
@ -30,15 +29,73 @@ LLM_MODEL_NAME=qwen-plus
|
|||
#
|
||||
# See all providers: https://github.com/jhd3197/prompture#providers
|
||||
|
||||
# ===== ZEP Memory Graph =====
|
||||
# Free monthly quota: https://app.getzep.com/
|
||||
ZEP_API_KEY=your_zep_api_key_here
|
||||
# Docker 容器内访问宿主机 LLM 时使用的地址
|
||||
# Linux + Docker Compose 下可保持默认 host.docker.internal
|
||||
DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
|
||||
# ===== Boost LLM (optional) =====
|
||||
# LLM_BOOST_API_KEY=your_api_key_here
|
||||
# LLM_BOOST_BASE_URL=your_base_url_here
|
||||
# LLM_BOOST_MODEL_NAME=your_model_name_here
|
||||
# ===== Graphiti + Neo4j(默认推荐)=====
|
||||
GRAPH_BACKEND=graphiti
|
||||
GRAPHITI_URI=bolt://localhost:7687
|
||||
GRAPHITI_USER=neo4j
|
||||
GRAPHITI_PASSWORD=password123
|
||||
GRAPHITI_DATABASE=neo4j
|
||||
GRAPHITI_LLM_CLIENT_MODE=openai
|
||||
GRAPHITI_EMBEDDER_API_KEY=ollama
|
||||
GRAPHITI_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
GRAPHITI_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
GRAPHITI_EMBEDDER_DIM=1024
|
||||
GRAPH_SEARCH_RERANKER=rrf
|
||||
GRAPH_SEARCH_APP_RERANKER=embedding_rrf
|
||||
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT=2.0
|
||||
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES=true
|
||||
OLLAMA_PORT=11434
|
||||
OLLAMA_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
|
||||
# ===== Frontend API timeout (optional) =====
|
||||
# Increase this value for slow local LLMs (milliseconds)
|
||||
# VITE_API_TIMEOUT=600000 # 10 minutes
|
||||
# 可选:如需独立于 Graphiti / OpenZep 当前 embedder,可单独覆写:
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_API_KEY=ollama
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
# GRAPH_SEARCH_APP_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
# 可选:如果你另起了免费 cross-encoder / rerank 服务(TEI / Infinity / vLLM)
|
||||
# GRAPH_SEARCH_APP_RERANKER=api_rrf
|
||||
# GRAPH_SEARCH_APP_RERANKER_PROVIDER=tei
|
||||
# GRAPH_SEARCH_APP_RERANKER_BASE_URL=http://127.0.0.1:18090
|
||||
# GRAPH_SEARCH_APP_RERANKER_MODEL=your_reranker_model
|
||||
# GRAPH_SEARCH_APP_RERANKER_TIMEOUT=20
|
||||
# 免费召回增强:从高相关节点补抓相邻边,默认开启
|
||||
# GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT=2
|
||||
# GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT=8
|
||||
# GRAPHITI_ENABLE_CROSS_ENCODER=false
|
||||
|
||||
# Docker 中的 MiroFish 容器访问宿主机 LLM / 容器内 Ollama 时使用:
|
||||
DOCKER_GRAPHITI_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
DOCKER_GRAPHITI_EMBEDDER_BASE_URL=http://ollama:11434/v1
|
||||
|
||||
# ===== OpenZep / Zep(可选,非默认)=====
|
||||
# ZEP_API_KEY=your_zep_api_key_here
|
||||
# ZEP_MODE=openzep
|
||||
# 本地源码运行 MiroFish 时使用 localhost
|
||||
# ZEP_BASE_URL=http://localhost:8000/api/v2
|
||||
# Docker 中的 MiroFish 容器会自动改用 openzep 服务名
|
||||
# DOCKER_ZEP_BASE_URL=http://openzep:8000/api/v2
|
||||
# 留空表示不启用 OpenZep API 鉴权
|
||||
# OPENZEP_API_KEY=
|
||||
# OPENZEP_LLM_API_KEY=your_api_key_here
|
||||
# OPENZEP_LLM_BASE_URL=http://127.0.0.1:18081/v1
|
||||
# OPENZEP_DOCKER_LLM_BASE_URL=http://host.docker.internal:18081/v1
|
||||
# OPENZEP_LLM_MODEL=gpt-5.4
|
||||
# OPENZEP_EMBEDDER_API_KEY=ollama
|
||||
# OPENZEP_EMBEDDER_BASE_URL=http://127.0.0.1:11434/v1
|
||||
# OPENZEP_DOCKER_EMBEDDER_BASE_URL=http://ollama:11434/v1
|
||||
# OPENZEP_EMBEDDER_MODEL=qwen3-embedding:8b
|
||||
|
||||
# ===== Neo4j / OpenZep 端口 =====
|
||||
NEO4J_PASSWORD=password123
|
||||
NEO4J_HTTP_PORT=7474
|
||||
NEO4J_BOLT_PORT=7687
|
||||
OPENZEP_PORT=8000
|
||||
|
||||
# ===== 加速 LLM 配置(可选)=====
|
||||
# 注意如果不使用加速配置,env 文件中就不要出现下面的配置项
|
||||
LLM_BOOST_API_KEY=your_api_key_here
|
||||
LLM_BOOST_BASE_URL=your_base_url_here
|
||||
LLM_BOOST_MODEL_NAME=your_model_name_here
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ COPY --from=ghcr.io/astral-sh/uv:0.9.26 /uv /uvx /bin/
|
|||
|
||||
WORKDIR /app
|
||||
|
||||
ARG INSTALL_GRAPHITI=true
|
||||
|
||||
# 先复制依赖描述文件以利用缓存
|
||||
COPY package.json package-lock.json ./
|
||||
COPY frontend/package.json frontend/package-lock.json ./frontend/
|
||||
|
|
@ -18,7 +20,10 @@ COPY backend/pyproject.toml backend/uv.lock ./backend/
|
|||
# 安装依赖(Node + Python)
|
||||
RUN npm ci \
|
||||
&& npm ci --prefix frontend \
|
||||
&& cd backend && uv sync --frozen
|
||||
&& cd backend && uv sync --frozen \
|
||||
&& if [ "$INSTALL_GRAPHITI" = "true" ]; then \
|
||||
uv pip install --python /app/backend/.venv/bin/python --no-cache-dir graphiti-core==0.28.2 "neo4j>=5.26.0"; \
|
||||
fi
|
||||
|
||||
# 复制项目源码
|
||||
COPY . .
|
||||
|
|
|
|||
|
|
@ -284,9 +284,7 @@ def build_graph():
|
|||
logger.info("=== 开始构建图谱 ===")
|
||||
|
||||
# 检查配置
|
||||
errors = []
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors.append(t('api.zepApiKeyMissing'))
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
logger.error(f"配置错误: {errors}")
|
||||
return jsonify({
|
||||
|
|
@ -437,10 +435,12 @@ def build_graph():
|
|||
progress=15
|
||||
)
|
||||
|
||||
# OpenZep 本地链路在批量抽取时更容易卡在长时间的联合推理里。
|
||||
# 改为单块发送可以显著降低单次处理负载,牺牲吞吐换稳定性。
|
||||
episode_uuids = builder.add_text_batches(
|
||||
graph_id,
|
||||
graph_id,
|
||||
chunks,
|
||||
batch_size=3,
|
||||
batch_size=1 if Config.use_openzep() else 3,
|
||||
progress_callback=add_progress_callback
|
||||
)
|
||||
|
||||
|
|
@ -459,7 +459,7 @@ def build_graph():
|
|||
progress=progress
|
||||
)
|
||||
|
||||
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
||||
builder._wait_for_episodes(graph_id, episode_uuids, wait_progress_callback)
|
||||
|
||||
# 获取图谱数据
|
||||
task_manager.update_task(
|
||||
|
|
@ -469,12 +469,20 @@ def build_graph():
|
|||
)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
|
||||
# 如果图谱仍然是空的,说明 OpenZep 没有成功完成抽取。
|
||||
# 不能把这种情况标记为成功,否则前端会误以为构图完成。
|
||||
if node_count == 0 and edge_count == 0:
|
||||
raise RuntimeError(
|
||||
"图谱构建未产出任何节点或边;OpenZep 处理可能超时或未完成"
|
||||
)
|
||||
|
||||
# 更新项目状态
|
||||
project.status = ProjectStatus.GRAPH_COMPLETED
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
|
||||
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
|
||||
|
||||
# 完成
|
||||
|
|
@ -572,10 +580,11 @@ def get_graph_data(graph_id: str):
|
|||
获取图谱数据(节点和边)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
|
@ -600,10 +609,11 @@ def delete_graph(graph_id: str):
|
|||
删除Zep图谱
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
|
|
|||
|
|
@ -57,10 +57,11 @@ def get_graph_entities(graph_id: str):
|
|||
enrich: 是否获取相关边信息(默认true)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
entity_types_str = request.args.get('entity_types', '')
|
||||
|
|
@ -94,10 +95,11 @@ def get_graph_entities(graph_id: str):
|
|||
def get_entity_detail(graph_id: str, entity_uuid: str):
|
||||
"""获取单个实体的详细信息"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
reader = ZepEntityReader()
|
||||
|
|
@ -127,10 +129,11 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
|
|||
def get_entities_by_type(graph_id: str, entity_type: str):
|
||||
"""获取指定类型的所有实体"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors = Config.get_graph_backend_config_errors()
|
||||
if errors:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
"error": "; ".join(errors)
|
||||
}), 500
|
||||
|
||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||
|
|
|
|||
|
|
@ -33,7 +33,54 @@ class Config:
|
|||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||
|
||||
# Zep配置
|
||||
ZEP_MODE = os.environ.get('ZEP_MODE', 'cloud').lower()
|
||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||
ZEP_BASE_URL = os.environ.get('ZEP_BASE_URL')
|
||||
OPENZEP_EMBEDDER_API_KEY = os.environ.get('OPENZEP_EMBEDDER_API_KEY') or None
|
||||
OPENZEP_EMBEDDER_BASE_URL = os.environ.get('OPENZEP_EMBEDDER_BASE_URL') or None
|
||||
OPENZEP_EMBEDDER_MODEL = os.environ.get('OPENZEP_EMBEDDER_MODEL') or None
|
||||
|
||||
# 图后端配置
|
||||
GRAPH_BACKEND = os.environ.get('GRAPH_BACKEND', 'graphiti').lower()
|
||||
GRAPH_SEARCH_RERANKER = os.environ.get('GRAPH_SEARCH_RERANKER', 'rrf').strip() or None
|
||||
GRAPH_SEARCH_APP_RERANKER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER', 'embedding_rrf').strip().lower() or 'embedding_rrf')
|
||||
GRAPH_SEARCH_APP_RERANK_FUSION_K = max(1, int(os.environ.get('GRAPH_SEARCH_APP_RERANK_FUSION_K', '60')))
|
||||
GRAPH_SEARCH_APP_SEMANTIC_WEIGHT = max(0.0, float(os.environ.get('GRAPH_SEARCH_APP_SEMANTIC_WEIGHT', '2.0')))
|
||||
GRAPH_SEARCH_APP_EMBEDDER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_API_KEY') or None
|
||||
GRAPH_SEARCH_APP_EMBEDDER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_BASE_URL') or None
|
||||
GRAPH_SEARCH_APP_EMBEDDER_MODEL = os.environ.get('GRAPH_SEARCH_APP_EMBEDDER_MODEL') or None
|
||||
GRAPH_SEARCH_APP_EMBED_BATCH_SIZE = max(1, int(os.environ.get('GRAPH_SEARCH_APP_EMBED_BATCH_SIZE', '32')))
|
||||
GRAPH_SEARCH_APP_RERANKER_API_KEY = os.environ.get('GRAPH_SEARCH_APP_RERANKER_API_KEY') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_BASE_URL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_BASE_URL') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_MODEL = os.environ.get('GRAPH_SEARCH_APP_RERANKER_MODEL') or None
|
||||
GRAPH_SEARCH_APP_RERANKER_PROVIDER = (os.environ.get('GRAPH_SEARCH_APP_RERANKER_PROVIDER', 'auto').strip().lower() or 'auto')
|
||||
GRAPH_SEARCH_APP_RERANKER_TIMEOUT = max(1.0, float(os.environ.get('GRAPH_SEARCH_APP_RERANKER_TIMEOUT', '20')))
|
||||
GRAPH_SEARCH_INCLUDE_NODES = os.environ.get('GRAPH_SEARCH_INCLUDE_NODES', 'true').lower() == 'true'
|
||||
GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_EDGE_LIMIT_MULTIPLIER', '2')))
|
||||
GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_LIMIT_MULTIPLIER', '1')))
|
||||
GRAPH_SEARCH_NODE_SUMMARY_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_SUMMARY_LIMIT', '5')))
|
||||
GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES = os.environ.get('GRAPH_SEARCH_EXPAND_EDGES_FROM_NODES', 'true').lower() == 'true'
|
||||
GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT = max(0, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_EXPANSION_LIMIT', '2')))
|
||||
GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT = max(1, int(os.environ.get('GRAPH_SEARCH_NODE_EDGE_PER_NODE_LIMIT', '8')))
|
||||
GRAPHITI_URI = os.environ.get('GRAPHITI_URI')
|
||||
GRAPHITI_USER = os.environ.get('GRAPHITI_USER', 'neo4j')
|
||||
GRAPHITI_PASSWORD = os.environ.get('GRAPHITI_PASSWORD')
|
||||
GRAPHITI_DATABASE = os.environ.get('GRAPHITI_DATABASE', 'neo4j')
|
||||
GRAPHITI_LLM_API_KEY = os.environ.get('GRAPHITI_LLM_API_KEY') or LLM_API_KEY
|
||||
GRAPHITI_LLM_BASE_URL = os.environ.get('GRAPHITI_LLM_BASE_URL') or LLM_BASE_URL
|
||||
GRAPHITI_LLM_MODEL = os.environ.get('GRAPHITI_LLM_MODEL') or LLM_MODEL_NAME
|
||||
GRAPHITI_LLM_SMALL_MODEL = os.environ.get('GRAPHITI_LLM_SMALL_MODEL') or GRAPHITI_LLM_MODEL
|
||||
GRAPHITI_LLM_CLIENT_MODE = os.environ.get('GRAPHITI_LLM_CLIENT_MODE', 'openai').lower()
|
||||
GRAPHITI_LLM_MAX_TOKENS = max(1024, int(os.environ.get('GRAPHITI_LLM_MAX_TOKENS', '16384')))
|
||||
GRAPHITI_EMBEDDER_API_KEY = os.environ.get('GRAPHITI_EMBEDDER_API_KEY') or GRAPHITI_LLM_API_KEY
|
||||
GRAPHITI_EMBEDDER_BASE_URL = os.environ.get('GRAPHITI_EMBEDDER_BASE_URL') or GRAPHITI_LLM_BASE_URL
|
||||
GRAPHITI_EMBEDDER_MODEL = os.environ.get('GRAPHITI_EMBEDDER_MODEL', 'qwen3-embedding:8b')
|
||||
GRAPHITI_EMBEDDER_DIM = max(128, int(os.environ.get('GRAPHITI_EMBEDDER_DIM', '1024')))
|
||||
GRAPHITI_RERANKER_API_KEY = os.environ.get('GRAPHITI_RERANKER_API_KEY') or GRAPHITI_LLM_API_KEY
|
||||
GRAPHITI_RERANKER_BASE_URL = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or GRAPHITI_LLM_BASE_URL
|
||||
GRAPHITI_RERANKER_MODEL = os.environ.get('GRAPHITI_RERANKER_MODEL') or GRAPHITI_LLM_MODEL
|
||||
GRAPHITI_ENABLE_CROSS_ENCODER = os.environ.get('GRAPHITI_ENABLE_CROSS_ENCODER', 'false').lower() == 'true'
|
||||
GRAPHITI_MAX_COROUTINES = max(1, int(os.environ.get('GRAPHITI_MAX_COROUTINES', '20')))
|
||||
|
||||
# 文件上传配置
|
||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||
|
|
@ -62,6 +109,123 @@ class Config:
|
|||
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
||||
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
||||
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
||||
|
||||
@classmethod
|
||||
def use_openzep(cls):
|
||||
"""是否启用 OpenZep / 自定义 Zep endpoint。"""
|
||||
return cls.ZEP_MODE == 'openzep' or bool(cls.ZEP_BASE_URL)
|
||||
|
||||
@classmethod
|
||||
def is_zep_configured(cls, api_key=None):
|
||||
"""检查 Zep/OpenZep 是否已完成最小配置。"""
|
||||
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
|
||||
|
||||
if cls.use_openzep():
|
||||
return bool(cls.ZEP_BASE_URL)
|
||||
|
||||
return bool(resolved_api_key)
|
||||
|
||||
@classmethod
|
||||
def get_zep_client_kwargs(cls, api_key=None):
|
||||
"""生成 Zep 客户端初始化参数。"""
|
||||
resolved_api_key = cls.ZEP_API_KEY if api_key is None else api_key
|
||||
kwargs = {}
|
||||
|
||||
# OpenZep 场景允许关闭鉴权;此时不要传空字符串 api_key,
|
||||
# 否则 zep sdk 会构造出非法的 `Api-Key:` 请求头,httpx 会直接拒绝。
|
||||
if resolved_api_key:
|
||||
kwargs['api_key'] = resolved_api_key
|
||||
if cls.ZEP_BASE_URL:
|
||||
kwargs['base_url'] = cls.ZEP_BASE_URL
|
||||
|
||||
return kwargs
|
||||
|
||||
@classmethod
|
||||
def get_zep_config_errors(cls, api_key=None):
|
||||
"""返回 Zep/OpenZep 的配置错误。"""
|
||||
if cls.is_zep_configured(api_key=api_key):
|
||||
return []
|
||||
|
||||
if cls.use_openzep():
|
||||
return ["ZEP_BASE_URL 未配置"]
|
||||
|
||||
return ["ZEP_API_KEY 未配置"]
|
||||
|
||||
@classmethod
|
||||
def get_graph_search_embedder_config(cls):
|
||||
"""返回 app-side 图搜索语义重排使用的 embedding 配置。"""
|
||||
api_key = cls.GRAPH_SEARCH_APP_EMBEDDER_API_KEY
|
||||
base_url = cls.GRAPH_SEARCH_APP_EMBEDDER_BASE_URL
|
||||
model = cls.GRAPH_SEARCH_APP_EMBEDDER_MODEL
|
||||
|
||||
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
|
||||
if backend == 'graphiti':
|
||||
api_key = api_key or cls.GRAPHITI_EMBEDDER_API_KEY
|
||||
base_url = base_url or cls.GRAPHITI_EMBEDDER_BASE_URL
|
||||
model = model or cls.GRAPHITI_EMBEDDER_MODEL
|
||||
elif cls.use_openzep():
|
||||
api_key = api_key or cls.OPENZEP_EMBEDDER_API_KEY
|
||||
base_url = base_url or cls.OPENZEP_EMBEDDER_BASE_URL
|
||||
model = model or cls.OPENZEP_EMBEDDER_MODEL
|
||||
|
||||
if model and not api_key:
|
||||
api_key = 'ollama'
|
||||
|
||||
return {
|
||||
'api_key': api_key,
|
||||
'base_url': base_url,
|
||||
'model': model,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_graph_search_reranker_config(cls):
|
||||
"""返回 app-side 图搜索交叉编码重排使用的 reranker 配置。"""
|
||||
api_key = cls.GRAPH_SEARCH_APP_RERANKER_API_KEY
|
||||
base_url = cls.GRAPH_SEARCH_APP_RERANKER_BASE_URL
|
||||
model = cls.GRAPH_SEARCH_APP_RERANKER_MODEL
|
||||
provider = cls.GRAPH_SEARCH_APP_RERANKER_PROVIDER
|
||||
|
||||
graphiti_reranker_base_url = os.environ.get('GRAPHITI_RERANKER_BASE_URL') or None
|
||||
if graphiti_reranker_base_url:
|
||||
api_key = api_key or (os.environ.get('GRAPHITI_RERANKER_API_KEY') or None)
|
||||
base_url = base_url or graphiti_reranker_base_url
|
||||
model = model or (os.environ.get('GRAPHITI_RERANKER_MODEL') or None)
|
||||
|
||||
return {
|
||||
'api_key': api_key,
|
||||
'base_url': base_url,
|
||||
'model': model,
|
||||
'provider': provider,
|
||||
'timeout': cls.GRAPH_SEARCH_APP_RERANKER_TIMEOUT,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_graphiti_config_errors(cls):
|
||||
"""返回 Graphiti + Neo4j 的配置错误。"""
|
||||
errors = []
|
||||
if not cls.GRAPHITI_URI:
|
||||
errors.append("GRAPHITI_URI 未配置")
|
||||
if not cls.GRAPHITI_DATABASE:
|
||||
errors.append("GRAPHITI_DATABASE 未配置")
|
||||
if not cls.GRAPHITI_LLM_MODEL:
|
||||
errors.append("GRAPHITI_LLM_MODEL 未配置")
|
||||
if not cls.GRAPHITI_EMBEDDER_MODEL:
|
||||
errors.append("GRAPHITI_EMBEDDER_MODEL 未配置")
|
||||
return errors
|
||||
|
||||
@classmethod
|
||||
def get_graph_backend_config_errors(cls, api_key=None):
|
||||
"""根据当前 GRAPH_BACKEND 返回对应的配置错误。"""
|
||||
backend = (cls.GRAPH_BACKEND or 'graphiti').lower()
|
||||
if backend in {'zep', 'openzep'}:
|
||||
return cls.get_zep_config_errors(api_key=api_key)
|
||||
if backend == 'graphiti':
|
||||
return cls.get_graphiti_config_errors()
|
||||
return [f"不支持的 GRAPH_BACKEND: {backend}"]
|
||||
|
||||
@classmethod
|
||||
def is_graph_backend_configured(cls, api_key=None):
|
||||
return len(cls.get_graph_backend_config_errors(api_key=api_key)) == 0
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
|
|
@ -69,7 +233,5 @@ class Config:
|
|||
errors = []
|
||||
if not cls.LLM_API_KEY:
|
||||
errors.append("LLM_API_KEY 未配置")
|
||||
if not cls.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY 未配置")
|
||||
errors.extend(cls.get_graph_backend_config_errors())
|
||||
return errors
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
Graph backend abstractions.
|
||||
"""
|
||||
|
||||
from .base import GraphBackend
|
||||
from .factory import get_graph_backend
|
||||
|
||||
__all__ = ["GraphBackend", "get_graph_backend"]
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
"""
|
||||
Common graph backend interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class GraphBackend(ABC):
|
||||
"""Minimal graph backend interface used by the application services."""
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Any:
|
||||
"""Return the underlying SDK client when a service still needs it."""
|
||||
return None
|
||||
|
||||
@abstractmethod
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
"""Create a graph."""
|
||||
|
||||
@abstractmethod
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
"""Set graph ontology."""
|
||||
|
||||
@abstractmethod
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
"""Add a batch of episodes."""
|
||||
|
||||
@abstractmethod
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
"""Add a single text episode."""
|
||||
|
||||
@abstractmethod
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
"""Fetch a single episode."""
|
||||
|
||||
@abstractmethod
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
"""Search the graph."""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
"""Fetch all nodes."""
|
||||
|
||||
@abstractmethod
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
"""Fetch all edges."""
|
||||
|
||||
@abstractmethod
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
"""Fetch a node by UUID."""
|
||||
|
||||
@abstractmethod
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
"""Fetch edges for a node."""
|
||||
|
||||
@abstractmethod
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
"""Delete a graph."""
|
||||
|
||||
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""Fetch backend ontology metadata when available."""
|
||||
return None
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
"""Fetch backend-specific live statistics when available."""
|
||||
return None
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
"""
|
||||
Graph backend factory.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..config import Config
|
||||
from .base import GraphBackend
|
||||
|
||||
|
||||
def get_graph_backend(api_key: Optional[str] = None) -> GraphBackend:
|
||||
"""Create the configured graph backend."""
|
||||
backend = (Config.GRAPH_BACKEND or "graphiti").lower()
|
||||
|
||||
if backend in {"zep", "openzep"}:
|
||||
from .zep_backend import ZepGraphBackend
|
||||
|
||||
return ZepGraphBackend(api_key=api_key)
|
||||
|
||||
if backend == "graphiti":
|
||||
from .graphiti_backend import GraphitiBackend
|
||||
|
||||
return GraphitiBackend(api_key=api_key)
|
||||
|
||||
raise ValueError(f"不支持的 GRAPH_BACKEND: {backend}")
|
||||
|
|
@ -0,0 +1,875 @@
|
|||
"""
|
||||
Graphiti + Neo4j graph backend implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from ..config import Config
|
||||
from .base import GraphBackend
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatEpisode:
|
||||
uuid: str
|
||||
processed: bool = True
|
||||
name: str = ""
|
||||
content: str = ""
|
||||
valid_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatNode:
|
||||
uuid: str
|
||||
name: str = ""
|
||||
labels: List[str] = field(default_factory=list)
|
||||
summary: str = ""
|
||||
attributes: Dict[str, Any] = field(default_factory=dict)
|
||||
created_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatEdge:
|
||||
uuid: str
|
||||
name: str
|
||||
fact: str
|
||||
source_node_uuid: str
|
||||
target_node_uuid: str
|
||||
source_node_name: str = ""
|
||||
target_node_name: str = ""
|
||||
attributes: Dict[str, Any] = field(default_factory=dict)
|
||||
episodes: List[str] = field(default_factory=list)
|
||||
created_at: Optional[datetime] = None
|
||||
valid_at: Optional[datetime] = None
|
||||
invalid_at: Optional[datetime] = None
|
||||
expired_at: Optional[datetime] = None
|
||||
|
||||
@property
|
||||
def uuid_(self) -> str:
|
||||
return self.uuid
|
||||
|
||||
|
||||
@dataclass
|
||||
class _CompatSearchResults:
|
||||
edges: List[_CompatEdge] = field(default_factory=list)
|
||||
nodes: List[_CompatNode] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class _OntologyBundle:
|
||||
entity_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
|
||||
edge_types: Dict[str, type[BaseModel]] = field(default_factory=dict)
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = field(default_factory=dict)
|
||||
spec: Dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class _AsyncBridge:
|
||||
"""Run async Graphiti calls from the app's synchronous service layer."""
|
||||
|
||||
def __init__(self):
|
||||
self._ready = threading.Event()
|
||||
self._loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
self._thread = threading.Thread(target=self._run_loop, daemon=True)
|
||||
self._thread.start()
|
||||
self._ready.wait()
|
||||
|
||||
def _run_loop(self):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
self._loop = loop
|
||||
self._ready.set()
|
||||
loop.run_forever()
|
||||
|
||||
def run(self, coro):
|
||||
if self._loop is None:
|
||||
raise RuntimeError("Graphiti async loop 未初始化")
|
||||
future = asyncio.run_coroutine_threadsafe(coro, self._loop)
|
||||
return future.result()
|
||||
|
||||
|
||||
class GraphitiBackend(GraphBackend):
|
||||
"""Graph backend backed by Graphiti OSS + Neo4j."""
|
||||
|
||||
_bridge: Optional[_AsyncBridge] = None
|
||||
_bridge_lock = threading.Lock()
|
||||
_indices_ready = False
|
||||
_indices_lock = threading.Lock()
|
||||
_ontology_registry: Dict[str, _OntologyBundle] = {}
|
||||
_ontology_lock = threading.Lock()
|
||||
_cross_encoder_warning_emitted = False
|
||||
PAGE_SIZE = 200
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
del api_key
|
||||
|
||||
errors = Config.get_graphiti_config_errors()
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
try:
|
||||
from graphiti_core.cross_encoder.openai_reranker_client import OpenAIRerankerClient
|
||||
from graphiti_core.embedder.openai import OpenAIEmbedder, OpenAIEmbedderConfig
|
||||
from graphiti_core.graphiti import Graphiti
|
||||
from graphiti_core.llm_client import OpenAIClient
|
||||
from graphiti_core.llm_client.config import LLMConfig
|
||||
from graphiti_core.llm_client.openai_generic_client import OpenAIGenericClient
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"Graphiti 依赖未安装,请先在 backend 环境中安装 graphiti-core 与 neo4j"
|
||||
) from exc
|
||||
|
||||
llm_config = LLMConfig(
|
||||
api_key=Config.GRAPHITI_LLM_API_KEY,
|
||||
base_url=Config.GRAPHITI_LLM_BASE_URL,
|
||||
model=Config.GRAPHITI_LLM_MODEL,
|
||||
small_model=Config.GRAPHITI_LLM_SMALL_MODEL,
|
||||
temperature=0,
|
||||
)
|
||||
reranker_config = LLMConfig(
|
||||
api_key=Config.GRAPHITI_RERANKER_API_KEY,
|
||||
base_url=Config.GRAPHITI_RERANKER_BASE_URL,
|
||||
model=Config.GRAPHITI_RERANKER_MODEL,
|
||||
temperature=0,
|
||||
)
|
||||
embedder_config = OpenAIEmbedderConfig(
|
||||
api_key=Config.GRAPHITI_EMBEDDER_API_KEY,
|
||||
base_url=Config.GRAPHITI_EMBEDDER_BASE_URL,
|
||||
embedding_model=Config.GRAPHITI_EMBEDDER_MODEL,
|
||||
embedding_dim=Config.GRAPHITI_EMBEDDER_DIM,
|
||||
)
|
||||
|
||||
llm_client_mode = (Config.GRAPHITI_LLM_CLIENT_MODE or "openai").lower()
|
||||
if llm_client_mode == "generic":
|
||||
llm_client = OpenAIGenericClient(
|
||||
config=llm_config,
|
||||
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
|
||||
)
|
||||
else:
|
||||
llm_client = OpenAIClient(
|
||||
config=llm_config,
|
||||
max_tokens=Config.GRAPHITI_LLM_MAX_TOKENS,
|
||||
)
|
||||
|
||||
self._graphiti = Graphiti(
|
||||
uri=Config.GRAPHITI_URI,
|
||||
user=Config.GRAPHITI_USER,
|
||||
password=Config.GRAPHITI_PASSWORD,
|
||||
llm_client=llm_client,
|
||||
embedder=OpenAIEmbedder(config=embedder_config),
|
||||
cross_encoder=OpenAIRerankerClient(config=reranker_config),
|
||||
max_coroutines=Config.GRAPHITI_MAX_COROUTINES,
|
||||
)
|
||||
self._driver = self._graphiti.driver.with_database(Config.GRAPHITI_DATABASE)
|
||||
self._graphiti.driver = self._driver
|
||||
self._graphiti.clients.driver = self._driver
|
||||
self._bridge = self._get_bridge()
|
||||
|
||||
self._ensure_indices()
|
||||
|
||||
@classmethod
|
||||
def _get_bridge(cls) -> _AsyncBridge:
|
||||
with cls._bridge_lock:
|
||||
if cls._bridge is None:
|
||||
cls._bridge = _AsyncBridge()
|
||||
return cls._bridge
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Any:
|
||||
return self._graphiti
|
||||
|
||||
def _run(self, coro):
|
||||
return self._bridge.run(coro)
|
||||
|
||||
def _ensure_indices(self) -> None:
|
||||
if self.__class__._indices_ready:
|
||||
return
|
||||
|
||||
with self.__class__._indices_lock:
|
||||
if self.__class__._indices_ready:
|
||||
return
|
||||
self._run(self._graphiti.build_indices_and_constraints())
|
||||
self.__class__._indices_ready = True
|
||||
|
||||
def _validate_graph_id(self, graph_id: str) -> None:
|
||||
from graphiti_core.helpers import validate_group_id
|
||||
|
||||
validate_group_id(graph_id)
|
||||
|
||||
def _normalize_model_spec(self, model: type[BaseModel]) -> Dict[str, Any]:
|
||||
fields = []
|
||||
for field_name, model_field in model.model_fields.items():
|
||||
fields.append(
|
||||
{
|
||||
"name": field_name,
|
||||
"description": model_field.description or field_name,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"description": (getattr(model, "__doc__", "") or "").strip(),
|
||||
"fields": fields,
|
||||
}
|
||||
|
||||
def _build_dynamic_model(self, name: str, spec: Dict[str, Any]) -> type[BaseModel]:
|
||||
field_definitions = {}
|
||||
for field_spec in spec.get("fields", []):
|
||||
field_name = field_spec.get("name", "").strip()
|
||||
if not field_name:
|
||||
continue
|
||||
field_definitions[field_name] = (
|
||||
Optional[str],
|
||||
Field(
|
||||
default=None,
|
||||
description=field_spec.get("description") or field_name,
|
||||
),
|
||||
)
|
||||
|
||||
model = create_model(name, __base__=BaseModel, **field_definitions)
|
||||
model.__doc__ = spec.get("description") or name
|
||||
return model
|
||||
|
||||
def _serialize_ontology_spec(
|
||||
self,
|
||||
entity_specs: Dict[str, Dict[str, Any]],
|
||||
edge_specs: Dict[str, Dict[str, Any]],
|
||||
edge_type_map: Dict[tuple[str, str], List[str]],
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"entity_types": entity_specs,
|
||||
"edge_types": edge_specs,
|
||||
"edge_type_map": [
|
||||
{
|
||||
"source": source,
|
||||
"target": target,
|
||||
"edges": edge_names,
|
||||
}
|
||||
for (source, target), edge_names in sorted(edge_type_map.items())
|
||||
],
|
||||
}
|
||||
|
||||
def _bundle_from_spec(self, spec: Dict[str, Any]) -> _OntologyBundle:
|
||||
entity_types = {
|
||||
name: self._build_dynamic_model(name, model_spec)
|
||||
for name, model_spec in (spec.get("entity_types") or {}).items()
|
||||
}
|
||||
edge_types = {
|
||||
name: self._build_dynamic_model(name, model_spec)
|
||||
for name, model_spec in (spec.get("edge_types") or {}).items()
|
||||
}
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = {}
|
||||
for entry in spec.get("edge_type_map") or []:
|
||||
source = entry.get("source", "Entity")
|
||||
target = entry.get("target", "Entity")
|
||||
edge_type_map[(source, target)] = list(entry.get("edges") or [])
|
||||
|
||||
return _OntologyBundle(
|
||||
entity_types=entity_types,
|
||||
edge_types=edge_types,
|
||||
edge_type_map=edge_type_map,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
def _build_ontology_bundle(self, entities: Any = None, edges: Any = None) -> _OntologyBundle:
|
||||
entity_specs = {}
|
||||
entity_types = {}
|
||||
for entity_name, entity_model in (entities or {}).items():
|
||||
entity_spec = self._normalize_model_spec(entity_model)
|
||||
entity_specs[entity_name] = entity_spec
|
||||
entity_types[entity_name] = self._build_dynamic_model(entity_name, entity_spec)
|
||||
|
||||
edge_specs = {}
|
||||
edge_types = {}
|
||||
edge_type_map: Dict[tuple[str, str], List[str]] = {}
|
||||
for edge_name, edge_value in (edges or {}).items():
|
||||
if not isinstance(edge_value, tuple) or len(edge_value) != 2:
|
||||
continue
|
||||
edge_model, source_targets = edge_value
|
||||
edge_spec = self._normalize_model_spec(edge_model)
|
||||
edge_specs[edge_name] = edge_spec
|
||||
edge_types[edge_name] = self._build_dynamic_model(edge_name, edge_spec)
|
||||
|
||||
for source_target in source_targets or []:
|
||||
source = getattr(source_target, "source", "Entity") or "Entity"
|
||||
target = getattr(source_target, "target", "Entity") or "Entity"
|
||||
edge_type_map.setdefault((source, target), []).append(edge_name)
|
||||
|
||||
if not edge_type_map:
|
||||
edge_type_map = {("Entity", "Entity"): list(edge_types.keys())}
|
||||
|
||||
spec = self._serialize_ontology_spec(entity_specs, edge_specs, edge_type_map)
|
||||
return _OntologyBundle(
|
||||
entity_types=entity_types,
|
||||
edge_types=edge_types,
|
||||
edge_type_map=edge_type_map,
|
||||
spec=spec,
|
||||
)
|
||||
|
||||
async def _upsert_graph_metadata_async(
|
||||
self,
|
||||
graph_id: str,
|
||||
name: Optional[str] = None,
|
||||
description: Optional[str] = None,
|
||||
ontology_spec: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
payload = json.dumps(ontology_spec, ensure_ascii=False) if ontology_spec is not None else None
|
||||
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MERGE (m:GraphMetadata {graph_id: $graph_id})
|
||||
ON CREATE SET
|
||||
m.group_id = $graph_id,
|
||||
m.created_at = datetime()
|
||||
SET
|
||||
m.name = CASE WHEN $name IS NULL OR $name = '' THEN coalesce(m.name, '') ELSE $name END,
|
||||
m.description = CASE
|
||||
WHEN $description IS NULL OR $description = ''
|
||||
THEN coalesce(m.description, '')
|
||||
ELSE $description
|
||||
END,
|
||||
m.ontology_json = CASE
|
||||
WHEN $ontology_json IS NULL OR $ontology_json = ''
|
||||
THEN m.ontology_json
|
||||
ELSE $ontology_json
|
||||
END,
|
||||
m.updated_at = datetime()
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
ontology_json=payload,
|
||||
)
|
||||
|
||||
async def _load_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
|
||||
records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (m:GraphMetadata {graph_id: $graph_id})
|
||||
RETURN m.ontology_json AS ontology_json
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
|
||||
if not records:
|
||||
return _OntologyBundle()
|
||||
|
||||
ontology_json = records[0].get("ontology_json")
|
||||
if not ontology_json:
|
||||
return _OntologyBundle()
|
||||
|
||||
try:
|
||||
spec = json.loads(ontology_json)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
logger.warning("Graphiti ontology metadata 解析失败,graph_id=%s", graph_id)
|
||||
return _OntologyBundle()
|
||||
|
||||
return self._bundle_from_spec(spec)
|
||||
|
||||
async def _get_ontology_bundle_async(self, graph_id: str) -> _OntologyBundle:
|
||||
with self.__class__._ontology_lock:
|
||||
cached = self.__class__._ontology_registry.get(graph_id)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
bundle = await self._load_ontology_bundle_async(graph_id)
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry[graph_id] = bundle
|
||||
return bundle
|
||||
|
||||
def _get_ontology_bundle(self, graph_id: str) -> _OntologyBundle:
|
||||
return self._run(self._get_ontology_bundle_async(graph_id))
|
||||
|
||||
def _set_ontology_bundle(self, graph_id: str, bundle: _OntologyBundle) -> None:
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry[graph_id] = bundle
|
||||
|
||||
def get_ontology_spec(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||
self._validate_graph_id(graph_id)
|
||||
bundle = self._get_ontology_bundle(graph_id)
|
||||
return dict(bundle.spec) if bundle.spec else None
|
||||
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
self._run(
|
||||
self._upsert_graph_metadata_async(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
)
|
||||
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
bundle = self._build_ontology_bundle(entities=entities, edges=edges)
|
||||
self._set_ontology_bundle(graph_id, bundle)
|
||||
self._run(
|
||||
self._upsert_graph_metadata_async(
|
||||
graph_id=graph_id,
|
||||
ontology_spec=bundle.spec,
|
||||
)
|
||||
)
|
||||
|
||||
async def _add_text_async(self, graph_id: str, data: str) -> _CompatEpisode:
|
||||
from graphiti_core.helpers import validate_excluded_entity_types
|
||||
from graphiti_core.nodes import EpisodeType, EpisodicNode
|
||||
from graphiti_core.search.search_utils import RELEVANT_SCHEMA_LIMIT
|
||||
from graphiti_core.utils.datetime_utils import utc_now
|
||||
from graphiti_core.utils.maintenance.node_operations import (
|
||||
extract_attributes_from_nodes,
|
||||
extract_nodes,
|
||||
resolve_extracted_nodes,
|
||||
)
|
||||
from graphiti_core.utils.ontology_utils.entity_types_utils import validate_entity_types
|
||||
|
||||
bundle = await self._get_ontology_bundle_async(graph_id)
|
||||
entity_types = bundle.entity_types or None
|
||||
edge_types = bundle.edge_types or None
|
||||
edge_type_map = bundle.edge_type_map or {("Entity", "Entity"): []}
|
||||
|
||||
validate_entity_types(entity_types)
|
||||
validate_excluded_entity_types(None, entity_types)
|
||||
|
||||
now = utc_now()
|
||||
previous_episodes = await self._graphiti.retrieve_episodes(
|
||||
reference_time=now,
|
||||
last_n=RELEVANT_SCHEMA_LIMIT,
|
||||
group_ids=[graph_id],
|
||||
source=EpisodeType.text,
|
||||
driver=self._driver,
|
||||
)
|
||||
|
||||
episode = EpisodicNode(
|
||||
name=f"episode_{now.strftime('%Y%m%d%H%M%S%f')}",
|
||||
group_id=graph_id,
|
||||
labels=[],
|
||||
source=EpisodeType.text,
|
||||
content=data,
|
||||
source_description="text",
|
||||
created_at=now,
|
||||
valid_at=now,
|
||||
)
|
||||
|
||||
extracted_nodes = await extract_nodes(
|
||||
self._graphiti.clients,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
nodes, uuid_map, _ = await resolve_extracted_nodes(
|
||||
self._graphiti.clients,
|
||||
extracted_nodes,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
)
|
||||
resolved_edges, invalidated_edges, new_edges = await self._graphiti._extract_and_resolve_edges(
|
||||
episode,
|
||||
extracted_nodes,
|
||||
previous_episodes,
|
||||
edge_type_map,
|
||||
graph_id,
|
||||
edge_types,
|
||||
nodes,
|
||||
uuid_map,
|
||||
None,
|
||||
)
|
||||
entity_edges = resolved_edges + invalidated_edges
|
||||
hydrated_nodes = await extract_attributes_from_nodes(
|
||||
self._graphiti.clients,
|
||||
nodes,
|
||||
episode,
|
||||
previous_episodes,
|
||||
entity_types,
|
||||
edges=new_edges,
|
||||
)
|
||||
_, saved_episode = await self._graphiti._process_episode_data(
|
||||
episode,
|
||||
hydrated_nodes,
|
||||
entity_edges,
|
||||
now,
|
||||
graph_id,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
|
||||
return _CompatEpisode(
|
||||
uuid=saved_episode.uuid,
|
||||
processed=True,
|
||||
name=saved_episode.name,
|
||||
content=saved_episode.content,
|
||||
valid_at=saved_episode.valid_at,
|
||||
created_at=saved_episode.created_at,
|
||||
)
|
||||
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
results = []
|
||||
for episode in episodes:
|
||||
data = getattr(episode, "data", None)
|
||||
if data is None and isinstance(episode, dict):
|
||||
data = episode.get("data", "")
|
||||
results.append(self.add_text(graph_id=graph_id, data=str(data or "")))
|
||||
return results
|
||||
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._add_text_async(graph_id=graph_id, data=data))
|
||||
|
||||
async def _get_episode_async(self, episode_uuid: str) -> _CompatEpisode:
|
||||
from graphiti_core.nodes import EpisodicNode
|
||||
|
||||
episode = await EpisodicNode.get_by_uuid(self._driver, episode_uuid)
|
||||
return _CompatEpisode(
|
||||
uuid=episode.uuid,
|
||||
processed=True,
|
||||
name=episode.name,
|
||||
content=episode.content,
|
||||
valid_at=episode.valid_at,
|
||||
created_at=episode.created_at,
|
||||
)
|
||||
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
return self._run(self._get_episode_async(episode_uuid))
|
||||
|
||||
def _warn_cross_encoder_fallback(self) -> None:
|
||||
if self.__class__._cross_encoder_warning_emitted:
|
||||
return
|
||||
logger.info(
|
||||
"Graphiti cross_encoder 默认已降级为 rrf;如需启用请设置 GRAPHITI_ENABLE_CROSS_ENCODER=true"
|
||||
)
|
||||
self.__class__._cross_encoder_warning_emitted = True
|
||||
|
||||
def _build_search_config(self, scope: str, limit: int, reranker: Optional[str]):
|
||||
from graphiti_core.search.search_config import (
|
||||
EdgeReranker,
|
||||
EdgeSearchConfig,
|
||||
EdgeSearchMethod,
|
||||
NodeReranker,
|
||||
NodeSearchConfig,
|
||||
NodeSearchMethod,
|
||||
SearchConfig,
|
||||
)
|
||||
|
||||
reranker_name = (reranker or "rrf").strip().lower()
|
||||
edge_reranker_map = {
|
||||
"rrf": EdgeReranker.rrf,
|
||||
"reciprocal_rank_fusion": EdgeReranker.rrf,
|
||||
"cross_encoder": EdgeReranker.cross_encoder,
|
||||
"node_distance": EdgeReranker.node_distance,
|
||||
"episode_mentions": EdgeReranker.episode_mentions,
|
||||
"mmr": EdgeReranker.mmr,
|
||||
}
|
||||
node_reranker_map = {
|
||||
"rrf": NodeReranker.rrf,
|
||||
"reciprocal_rank_fusion": NodeReranker.rrf,
|
||||
"cross_encoder": NodeReranker.cross_encoder,
|
||||
"node_distance": NodeReranker.node_distance,
|
||||
"episode_mentions": NodeReranker.episode_mentions,
|
||||
"mmr": NodeReranker.mmr,
|
||||
}
|
||||
|
||||
edge_reranker = edge_reranker_map.get(reranker_name, EdgeReranker.rrf)
|
||||
node_reranker = node_reranker_map.get(reranker_name, NodeReranker.rrf)
|
||||
edge_methods = [EdgeSearchMethod.bm25, EdgeSearchMethod.cosine_similarity]
|
||||
node_methods = [NodeSearchMethod.bm25, NodeSearchMethod.cosine_similarity]
|
||||
|
||||
if reranker_name == "cross_encoder":
|
||||
if Config.GRAPHITI_ENABLE_CROSS_ENCODER:
|
||||
edge_methods.append(EdgeSearchMethod.bfs)
|
||||
node_methods.append(NodeSearchMethod.bfs)
|
||||
else:
|
||||
self._warn_cross_encoder_fallback()
|
||||
edge_reranker = EdgeReranker.rrf
|
||||
node_reranker = NodeReranker.rrf
|
||||
|
||||
edge_config = None
|
||||
node_config = None
|
||||
if scope in {"edges", "both"}:
|
||||
edge_config = EdgeSearchConfig(
|
||||
search_methods=edge_methods,
|
||||
reranker=edge_reranker,
|
||||
)
|
||||
if scope in {"nodes", "both"}:
|
||||
node_config = NodeSearchConfig(
|
||||
search_methods=node_methods,
|
||||
reranker=node_reranker,
|
||||
)
|
||||
|
||||
return SearchConfig(
|
||||
edge_config=edge_config,
|
||||
node_config=node_config,
|
||||
limit=max(1, limit),
|
||||
)
|
||||
|
||||
def _wrap_node(self, node: Any) -> _CompatNode:
|
||||
return _CompatNode(
|
||||
uuid=getattr(node, "uuid", ""),
|
||||
name=getattr(node, "name", "") or "",
|
||||
labels=list(getattr(node, "labels", []) or []),
|
||||
summary=getattr(node, "summary", "") or "",
|
||||
attributes=dict(getattr(node, "attributes", {}) or {}),
|
||||
created_at=getattr(node, "created_at", None),
|
||||
)
|
||||
|
||||
def _wrap_edge(
|
||||
self,
|
||||
edge: Any,
|
||||
source_node_name: str = "",
|
||||
target_node_name: str = "",
|
||||
) -> _CompatEdge:
|
||||
return _CompatEdge(
|
||||
uuid=getattr(edge, "uuid", ""),
|
||||
name=getattr(edge, "name", "") or "",
|
||||
fact=getattr(edge, "fact", "") or "",
|
||||
source_node_uuid=getattr(edge, "source_node_uuid", "") or "",
|
||||
target_node_uuid=getattr(edge, "target_node_uuid", "") or "",
|
||||
source_node_name=source_node_name,
|
||||
target_node_name=target_node_name,
|
||||
attributes=dict(getattr(edge, "attributes", {}) or {}),
|
||||
episodes=list(getattr(edge, "episodes", []) or []),
|
||||
created_at=getattr(edge, "created_at", None),
|
||||
valid_at=getattr(edge, "valid_at", None),
|
||||
invalid_at=getattr(edge, "invalid_at", None),
|
||||
expired_at=getattr(edge, "expired_at", None),
|
||||
)
|
||||
|
||||
async def _search_async(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int,
|
||||
scope: str,
|
||||
reranker: Optional[str],
|
||||
) -> _CompatSearchResults:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
search_config = self._build_search_config(scope=scope, limit=limit, reranker=reranker)
|
||||
results = await self._graphiti.search_(
|
||||
query=query,
|
||||
config=search_config,
|
||||
group_ids=[graph_id],
|
||||
driver=self._driver,
|
||||
)
|
||||
|
||||
nodes = [self._wrap_node(node) for node in results.nodes]
|
||||
node_name_map = {node.uuid: node.name for node in nodes if node.uuid}
|
||||
|
||||
missing_node_ids = {
|
||||
node_uuid
|
||||
for edge in results.edges
|
||||
for node_uuid in (edge.source_node_uuid, edge.target_node_uuid)
|
||||
if node_uuid and node_uuid not in node_name_map
|
||||
}
|
||||
if missing_node_ids:
|
||||
for node in await EntityNode.get_by_uuids(self._driver, list(missing_node_ids)):
|
||||
node_name_map[node.uuid] = node.name or ""
|
||||
|
||||
edges = [
|
||||
self._wrap_edge(
|
||||
edge,
|
||||
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
|
||||
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
|
||||
)
|
||||
for edge in results.edges
|
||||
]
|
||||
|
||||
return _CompatSearchResults(edges=edges, nodes=nodes)
|
||||
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(
|
||||
self._search_async(
|
||||
graph_id=graph_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
scope=scope,
|
||||
reranker=reranker,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
async def _get_all_nodes_async(self, graph_id: str) -> List[_CompatNode]:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
result = []
|
||||
cursor = None
|
||||
while True:
|
||||
batch = await EntityNode.get_by_group_ids(
|
||||
self._driver,
|
||||
[graph_id],
|
||||
limit=self.PAGE_SIZE,
|
||||
uuid_cursor=cursor,
|
||||
)
|
||||
if not batch:
|
||||
break
|
||||
result.extend(self._wrap_node(node) for node in batch)
|
||||
if len(batch) < self.PAGE_SIZE:
|
||||
break
|
||||
cursor = batch[-1].uuid
|
||||
|
||||
return result
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_all_nodes_async(graph_id))
|
||||
|
||||
async def _get_all_edges_async(self, graph_id: str) -> List[_CompatEdge]:
|
||||
from graphiti_core.edges import EntityEdge, GroupsEdgesNotFoundError
|
||||
|
||||
result = []
|
||||
cursor = None
|
||||
while True:
|
||||
try:
|
||||
batch = await EntityEdge.get_by_group_ids(
|
||||
self._driver,
|
||||
[graph_id],
|
||||
limit=self.PAGE_SIZE,
|
||||
uuid_cursor=cursor,
|
||||
)
|
||||
except GroupsEdgesNotFoundError:
|
||||
break
|
||||
|
||||
if not batch:
|
||||
break
|
||||
|
||||
result.extend(self._wrap_edge(edge) for edge in batch)
|
||||
if len(batch) < self.PAGE_SIZE:
|
||||
break
|
||||
cursor = batch[-1].uuid
|
||||
|
||||
return result
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_all_edges_async(graph_id))
|
||||
|
||||
async def _get_node_async(self, node_uuid: str) -> _CompatNode:
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
return self._wrap_node(await EntityNode.get_by_uuid(self._driver, node_uuid))
|
||||
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
return self._run(self._get_node_async(node_uuid))
|
||||
|
||||
async def _get_node_edges_async(self, node_uuid: str) -> List[_CompatEdge]:
|
||||
from graphiti_core.edges import EntityEdge
|
||||
from graphiti_core.nodes import EntityNode
|
||||
|
||||
edges = await EntityEdge.get_by_node_uuid(self._driver, node_uuid)
|
||||
related_node_ids = {
|
||||
related_uuid
|
||||
for edge in edges
|
||||
for related_uuid in (edge.source_node_uuid, edge.target_node_uuid)
|
||||
if related_uuid
|
||||
}
|
||||
node_name_map = {}
|
||||
if related_node_ids:
|
||||
for node in await EntityNode.get_by_uuids(self._driver, list(related_node_ids)):
|
||||
node_name_map[node.uuid] = node.name or ""
|
||||
|
||||
return [
|
||||
self._wrap_edge(
|
||||
edge,
|
||||
source_node_name=node_name_map.get(edge.source_node_uuid, ""),
|
||||
target_node_name=node_name_map.get(edge.target_node_uuid, ""),
|
||||
)
|
||||
for edge in edges
|
||||
]
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
return self._run(self._get_node_edges_async(node_uuid))
|
||||
|
||||
async def _delete_graph_async(self, graph_id: str) -> None:
|
||||
from graphiti_core.nodes import Node
|
||||
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (s:Saga {group_id: $graph_id})
|
||||
DETACH DELETE s
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
)
|
||||
await Node.delete_by_group_id(self._driver, graph_id)
|
||||
await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (m:GraphMetadata {graph_id: $graph_id})
|
||||
DETACH DELETE m
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
)
|
||||
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
self._validate_graph_id(graph_id)
|
||||
self._run(self._delete_graph_async(graph_id))
|
||||
with self.__class__._ontology_lock:
|
||||
self.__class__._ontology_registry.pop(graph_id, None)
|
||||
|
||||
async def _get_live_graph_statistics_async(self, graph_id: str) -> Dict[str, int]:
|
||||
node_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Entity {group_id: $graph_id})
|
||||
RETURN count(n) AS node_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
edge_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH ()-[e:RELATES_TO {group_id: $graph_id}]->()
|
||||
RETURN count(e) AS edge_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
episode_records, _, _ = await self._driver.execute_query(
|
||||
"""
|
||||
MATCH (n:Episodic {group_id: $graph_id})
|
||||
RETURN count(n) AS episode_count
|
||||
""",
|
||||
graph_id=graph_id,
|
||||
routing_="r",
|
||||
)
|
||||
|
||||
return {
|
||||
"node_count": int((node_records[0] if node_records else {}).get("node_count", 0) or 0),
|
||||
"edge_count": int((edge_records[0] if edge_records else {}).get("edge_count", 0) or 0),
|
||||
"episode_count": int(
|
||||
(episode_records[0] if episode_records else {}).get("episode_count", 0) or 0
|
||||
),
|
||||
}
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
self._validate_graph_id(graph_id)
|
||||
return self._run(self._get_live_graph_statistics_async(graph_id))
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
Zep / OpenZep graph backend implementation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
from urllib.error import HTTPError, URLError
|
||||
from urllib.request import Request, urlopen
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.zep_paging import fetch_all_edges, fetch_all_nodes
|
||||
from .base import GraphBackend
|
||||
|
||||
|
||||
class ZepGraphBackend(GraphBackend):
|
||||
"""Graph backend backed by Zep Cloud or OpenZep."""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_zep_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(**Config.get_zep_client_kwargs(api_key=self.api_key))
|
||||
|
||||
@property
|
||||
def raw_client(self) -> Zep:
|
||||
return self.client
|
||||
|
||||
def create_graph(self, graph_id: str, name: str, description: str) -> None:
|
||||
self.client.graph.create(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
|
||||
def set_ontology(
|
||||
self,
|
||||
graph_id: str,
|
||||
entities: Any = None,
|
||||
edges: Any = None,
|
||||
) -> None:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
entities=entities,
|
||||
edges=edges,
|
||||
)
|
||||
|
||||
def add_batch(self, graph_id: str, episodes: List[Any]) -> List[Any]:
|
||||
return self.client.graph.add_batch(graph_id=graph_id, episodes=episodes)
|
||||
|
||||
def add_text(self, graph_id: str, data: str) -> Any:
|
||||
return self.client.graph.add(graph_id=graph_id, type="text", data=data)
|
||||
|
||||
def get_episode(self, episode_uuid: str) -> Any:
|
||||
return self.client.graph.episode.get(uuid_=episode_uuid)
|
||||
|
||||
def search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges",
|
||||
reranker: Optional[str] = None,
|
||||
) -> Any:
|
||||
kwargs = {
|
||||
"graph_id": graph_id,
|
||||
"query": query,
|
||||
"limit": limit,
|
||||
"scope": scope,
|
||||
}
|
||||
if reranker:
|
||||
kwargs["reranker"] = reranker
|
||||
return self.client.graph.search(**kwargs)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Any]:
|
||||
return fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Any]:
|
||||
return fetch_all_edges(self.client, graph_id)
|
||||
|
||||
def get_node(self, node_uuid: str) -> Any:
|
||||
return self.client.graph.node.get(uuid_=node_uuid)
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Any]:
|
||||
return self.client.graph.node.get_entity_edges(node_uuid=node_uuid)
|
||||
|
||||
def delete_graph(self, graph_id: str) -> None:
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
def get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
if not Config.ZEP_BASE_URL:
|
||||
return None
|
||||
|
||||
base_url = Config.ZEP_BASE_URL.rstrip("/")
|
||||
request = Request(f"{base_url}/graph/{graph_id}/statistics")
|
||||
if self.api_key:
|
||||
request.add_header("Authorization", f"Bearer {self.api_key}")
|
||||
|
||||
try:
|
||||
with urlopen(request, timeout=10) as response:
|
||||
payload = json.loads(response.read().decode("utf-8"))
|
||||
except (HTTPError, URLError, TimeoutError, OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return {
|
||||
"node_count": max(0, int(payload.get("node_count", 0) or 0)),
|
||||
"edge_count": max(0, int(payload.get("edge_count", 0) or 0)),
|
||||
"episode_count": max(0, int(payload.get("episode_count", 0) or 0)),
|
||||
}
|
||||
|
|
@ -7,15 +7,14 @@ import os
|
|||
import uuid
|
||||
import time
|
||||
import threading
|
||||
import json
|
||||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
from .text_processor import TextProcessor
|
||||
from ..utils.locale import t, get_locale, set_locale
|
||||
|
||||
|
|
@ -44,11 +43,12 @@ class GraphBuilderService:
|
|||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
self.task_manager = TaskManager()
|
||||
|
||||
def build_graph_async(
|
||||
|
|
@ -58,7 +58,7 @@ class GraphBuilderService:
|
|||
graph_name: str = "MiroFish Graph",
|
||||
chunk_size: int = 500,
|
||||
chunk_overlap: int = 50,
|
||||
batch_size: int = 3
|
||||
batch_size: int = 1
|
||||
) -> str:
|
||||
"""
|
||||
异步构建图谱
|
||||
|
|
@ -161,6 +161,7 @@ class GraphBuilderService:
|
|||
)
|
||||
|
||||
self._wait_for_episodes(
|
||||
graph_id,
|
||||
episode_uuids,
|
||||
lambda msg, prog: self.task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -194,10 +195,10 @@ class GraphBuilderService:
|
|||
"""创建Zep图谱(公开方法)"""
|
||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
self.client.graph.create(
|
||||
self.backend.create_graph(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description="MiroFish Social Simulation Graph"
|
||||
description="MiroFish Social Simulation Graph",
|
||||
)
|
||||
|
||||
return graph_id
|
||||
|
|
@ -215,78 +216,115 @@ class GraphBuilderService:
|
|||
|
||||
# Zep 保留名称,不能作为属性名
|
||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||
|
||||
|
||||
def safe_attr_name(attr_name: str) -> str:
|
||||
"""将保留名称转换为安全名称"""
|
||||
if attr_name.lower() in RESERVED_NAMES:
|
||||
return f"entity_{attr_name}"
|
||||
return attr_name
|
||||
|
||||
|
||||
def normalize_attributes(raw_attributes: Any) -> List[Dict[str, str]]:
|
||||
normalized: List[Dict[str, str]] = []
|
||||
for attr_def in raw_attributes or []:
|
||||
if isinstance(attr_def, str):
|
||||
attr_def = {"name": attr_def, "description": attr_def}
|
||||
if not isinstance(attr_def, dict):
|
||||
continue
|
||||
|
||||
attr_name = str(attr_def.get("name", "")).strip()
|
||||
if not attr_name:
|
||||
continue
|
||||
|
||||
normalized.append({
|
||||
"name": attr_name,
|
||||
"description": str(attr_def.get("description") or attr_name),
|
||||
})
|
||||
return normalized
|
||||
|
||||
def normalize_source_targets(raw_source_targets: Any) -> List[EntityEdgeSourceTarget]:
|
||||
normalized: List[EntityEdgeSourceTarget] = []
|
||||
for source_target in raw_source_targets or []:
|
||||
if not isinstance(source_target, dict):
|
||||
continue
|
||||
|
||||
normalized.append(
|
||||
EntityEdgeSourceTarget(
|
||||
source=str(source_target.get("source", "Entity")) or "Entity",
|
||||
target=str(source_target.get("target", "Entity")) or "Entity",
|
||||
)
|
||||
)
|
||||
|
||||
# Zep API allows max 10 source_targets per edge type.
|
||||
return normalized[:10]
|
||||
|
||||
# 动态创建实体类型
|
||||
entity_types = {}
|
||||
for entity_def in ontology.get("entity_types", []):
|
||||
name = entity_def["name"]
|
||||
if not isinstance(entity_def, dict):
|
||||
continue
|
||||
|
||||
name = str(entity_def.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
description = entity_def.get("description", f"A {name} entity.")
|
||||
|
||||
|
||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in entity_def.get("attributes", []):
|
||||
|
||||
for attr_def in normalize_attributes(entity_def.get("attributes", [])):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||||
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
|
||||
# 动态创建类
|
||||
entity_class = type(name, (EntityModel,), attrs)
|
||||
entity_class.__doc__ = description
|
||||
entity_types[name] = entity_class
|
||||
|
||||
|
||||
# 动态创建边类型
|
||||
edge_definitions = {}
|
||||
for edge_def in ontology.get("edge_types", []):
|
||||
name = edge_def["name"]
|
||||
if not isinstance(edge_def, dict):
|
||||
continue
|
||||
|
||||
name = str(edge_def.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
description = edge_def.get("description", f"A {name} relationship.")
|
||||
|
||||
|
||||
# 创建属性字典和类型注解
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in edge_def.get("attributes", []):
|
||||
|
||||
for attr_def in normalize_attributes(edge_def.get("attributes", [])):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||||
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
|
||||
# 动态创建类
|
||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||
edge_class.__doc__ = description
|
||||
|
||||
# 构建source_targets
|
||||
source_targets = []
|
||||
for st in edge_def.get("source_targets", []):
|
||||
source_targets.append(
|
||||
EntityEdgeSourceTarget(
|
||||
source=st.get("source", "Entity"),
|
||||
target=st.get("target", "Entity")
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
source_targets = normalize_source_targets(edge_def.get("source_targets", []))
|
||||
if source_targets:
|
||||
edge_definitions[name] = (edge_class, source_targets)
|
||||
|
||||
|
||||
# 调用Zep API设置本体
|
||||
if entity_types or edge_definitions:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
self.backend.set_ontology(
|
||||
graph_id=graph_id,
|
||||
entities=entity_types if entity_types else None,
|
||||
edges=edge_definitions if edge_definitions else None,
|
||||
)
|
||||
|
|
@ -295,7 +333,7 @@ class GraphBuilderService:
|
|||
self,
|
||||
graph_id: str,
|
||||
chunks: List[str],
|
||||
batch_size: int = 3,
|
||||
batch_size: int = 1,
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> List[str]:
|
||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
||||
|
|
@ -322,7 +360,7 @@ class GraphBuilderService:
|
|||
|
||||
# 发送到Zep
|
||||
try:
|
||||
batch_result = self.client.graph.add_batch(
|
||||
batch_result = self.backend.add_batch(
|
||||
graph_id=graph_id,
|
||||
episodes=episodes
|
||||
)
|
||||
|
|
@ -343,70 +381,152 @@ class GraphBuilderService:
|
|||
raise
|
||||
|
||||
return episode_uuids
|
||||
|
||||
|
||||
|
||||
|
||||
def _get_live_graph_statistics(self, graph_id: str) -> Optional[Dict[str, int]]:
|
||||
"""直接读取后端的实时图谱统计。"""
|
||||
return self.backend.get_live_graph_statistics(graph_id)
|
||||
|
||||
def _wait_for_episodes(
|
||||
self,
|
||||
graph_id: str,
|
||||
episode_uuids: List[str],
|
||||
progress_callback: Optional[Callable] = None,
|
||||
timeout: int = 600
|
||||
):
|
||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
||||
"""等待 OpenZep 处理完成,优先参考真实图谱状态。"""
|
||||
if not episode_uuids:
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
||||
return
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
pending_episodes = set(episode_uuids)
|
||||
completed_count = 0
|
||||
total_episodes = len(episode_uuids)
|
||||
|
||||
last_graph_signature: Optional[tuple[int, int, int]] = None
|
||||
stable_graph_polls = 0
|
||||
stable_graph_required = 2
|
||||
last_live_stats: Optional[Dict[str, int]] = None
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)
|
||||
|
||||
progress_callback(f"开始等待 {total_episodes} 个文本块处理...", 0)
|
||||
|
||||
while pending_episodes:
|
||||
if time.time() - start_time > timeout:
|
||||
elapsed_seconds = time.time() - start_time
|
||||
if elapsed_seconds > timeout:
|
||||
if last_live_stats is not None:
|
||||
graph_episode_count = min(last_live_stats["episode_count"], total_episodes)
|
||||
graph_node_count = last_live_stats["node_count"]
|
||||
graph_edge_count = last_live_stats["edge_count"]
|
||||
graph_entity_like_nodes = max(0, graph_node_count - last_live_stats["episode_count"])
|
||||
if graph_episode_count >= total_episodes and (graph_entity_like_nodes > 0 or graph_edge_count > 0):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep 接口进度未返回完成标记,但真实图谱已写入 "
|
||||
f"episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count}"
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
return
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
|
||||
completed_count / total_episodes
|
||||
)
|
||||
break
|
||||
|
||||
# 检查每个 episode 的处理状态
|
||||
|
||||
for ep_uuid in list(pending_episodes):
|
||||
try:
|
||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||||
episode = self.backend.get_episode(ep_uuid)
|
||||
is_processed = getattr(episode, 'processed', False)
|
||||
|
||||
|
||||
if is_processed:
|
||||
pending_episodes.remove(ep_uuid)
|
||||
completed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
# 忽略单个查询错误,继续
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
|
||||
live_stats = self._get_live_graph_statistics(graph_id)
|
||||
graph_episode_count = 0
|
||||
graph_node_count = 0
|
||||
graph_edge_count = 0
|
||||
graph_entity_like_nodes = 0
|
||||
graph_progress = 0.0
|
||||
|
||||
if live_stats is not None:
|
||||
last_live_stats = live_stats
|
||||
graph_episode_count = min(live_stats["episode_count"], total_episodes)
|
||||
graph_node_count = live_stats["node_count"]
|
||||
graph_edge_count = live_stats["edge_count"]
|
||||
graph_entity_like_nodes = max(0, graph_node_count - live_stats["episode_count"])
|
||||
graph_progress = graph_episode_count / total_episodes if total_episodes > 0 else 1.0
|
||||
|
||||
graph_signature = (
|
||||
graph_episode_count,
|
||||
graph_entity_like_nodes,
|
||||
graph_edge_count,
|
||||
)
|
||||
|
||||
if graph_signature == last_graph_signature:
|
||||
stable_graph_polls += 1
|
||||
else:
|
||||
last_graph_signature = graph_signature
|
||||
stable_graph_polls = 0
|
||||
|
||||
graph_ready = (
|
||||
graph_episode_count >= total_episodes
|
||||
and (graph_entity_like_nodes > 0 or graph_edge_count > 0)
|
||||
and stable_graph_polls >= stable_graph_required
|
||||
)
|
||||
if graph_ready:
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep 图谱已稳定: episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count}"
|
||||
),
|
||||
1.0,
|
||||
)
|
||||
return
|
||||
|
||||
elapsed = int(elapsed_seconds)
|
||||
effective_progress = max(
|
||||
completed_count / total_episodes if total_episodes > 0 else 1.0,
|
||||
graph_progress,
|
||||
)
|
||||
if progress_callback:
|
||||
if live_stats is not None:
|
||||
progress_callback(
|
||||
(
|
||||
f"OpenZep处理中... 接口完成 {completed_count}/{total_episodes}, "
|
||||
f"图中已写入 episodes={graph_episode_count}/{total_episodes}, "
|
||||
f"nodes={graph_node_count}, edges={graph_edge_count} ({elapsed}秒)"
|
||||
),
|
||||
effective_progress,
|
||||
)
|
||||
else:
|
||||
progress_callback(
|
||||
f"Zep处理中... {completed_count}/{total_episodes} 完成, {len(pending_episodes)} 待处理 ({elapsed}秒)",
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
)
|
||||
|
||||
if pending_episodes:
|
||||
time.sleep(3) # 每3秒检查一次
|
||||
|
||||
time.sleep(3)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
||||
|
||||
progress_callback(f"处理完成: {completed_count}/{total_episodes}", 1.0)
|
||||
|
||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||
"""获取图谱信息"""
|
||||
# 获取节点(分页)
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
|
||||
# 获取边(分页)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
# 统计实体类型
|
||||
entity_types = set()
|
||||
|
|
@ -433,8 +553,8 @@ class GraphBuilderService:
|
|||
Returns:
|
||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||||
"""
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
# 创建节点映射用于获取节点名称
|
||||
node_map = {}
|
||||
|
|
@ -502,5 +622,4 @@ class GraphBuilderService:
|
|||
|
||||
def delete_graph(self, graph_id: str):
|
||||
"""删除图谱"""
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
self.backend.delete_graph(graph_id)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ from dataclasses import dataclass, field
|
|||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -199,15 +199,15 @@ class OasisProfileGenerator:
|
|||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
|
||||
self.zep_client = None
|
||||
self.zep_api_key = Config.ZEP_API_KEY if zep_api_key is None else zep_api_key
|
||||
self.zep_backend = None
|
||||
self.graph_id = graph_id
|
||||
|
||||
if self.zep_api_key:
|
||||
if Config.is_graph_backend_configured(api_key=self.zep_api_key):
|
||||
try:
|
||||
self.zep_client = Zep(api_key=self.zep_api_key)
|
||||
self.zep_backend = get_graph_backend(api_key=self.zep_api_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Zep客户端初始化失败: {e}")
|
||||
logger.warning(f"图谱客户端初始化失败: {e}")
|
||||
|
||||
def generate_profile_from_entity(
|
||||
self,
|
||||
|
|
@ -298,7 +298,7 @@ class OasisProfileGenerator:
|
|||
"""
|
||||
import concurrent.futures
|
||||
|
||||
if not self.zep_client:
|
||||
if not self.zep_backend:
|
||||
return {"facts": [], "node_summaries": [], "context": ""}
|
||||
|
||||
entity_name = entity.name
|
||||
|
|
@ -324,7 +324,7 @@ class OasisProfileGenerator:
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
return self.zep_backend.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=30,
|
||||
|
|
@ -349,7 +349,7 @@ class OasisProfileGenerator:
|
|||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
return self.zep_backend.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=20,
|
||||
|
|
@ -1202,4 +1202,3 @@ class OasisProfileGenerator:
|
|||
"""[已废弃] 请使用 save_profiles() 方法"""
|
||||
logger.warning("save_profiles_to_json已废弃,请使用save_profiles方法")
|
||||
self.save_profiles(profiles, file_path, platform)
|
||||
|
||||
|
|
|
|||
|
|
@ -276,71 +276,84 @@ class OntologyGenerator:
|
|||
|
||||
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证和后处理结果"""
|
||||
|
||||
|
||||
if not isinstance(result, dict):
|
||||
result = {}
|
||||
|
||||
# 确保必要字段存在
|
||||
if "entity_types" not in result:
|
||||
if not isinstance(result.get("entity_types"), list):
|
||||
result["entity_types"] = []
|
||||
if "edge_types" not in result:
|
||||
if not isinstance(result.get("edge_types"), list):
|
||||
result["edge_types"] = []
|
||||
if "analysis_summary" not in result:
|
||||
result["analysis_summary"] = ""
|
||||
|
||||
|
||||
# 验证实体类型
|
||||
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用
|
||||
entity_name_map = {}
|
||||
validated_entities = []
|
||||
for entity in result["entity_types"]:
|
||||
# 强制将 entity name 转为 PascalCase(Zep API 要求)
|
||||
if "name" in entity:
|
||||
original_name = entity["name"]
|
||||
entity["name"] = _to_pascal_case(original_name)
|
||||
if entity["name"] != original_name:
|
||||
logger.warning(f"Entity type name '{original_name}' auto-converted to '{entity['name']}'")
|
||||
entity_name_map[original_name] = entity["name"]
|
||||
if "attributes" not in entity:
|
||||
entity["attributes"] = []
|
||||
if "examples" not in entity:
|
||||
entity["examples"] = []
|
||||
# 确保description不超过100字符
|
||||
if len(entity.get("description", "")) > 100:
|
||||
entity["description"] = entity["description"][:97] + "..."
|
||||
|
||||
if isinstance(entity, str):
|
||||
entity = {"name": entity, "description": f"Entity type: {entity}"}
|
||||
if not isinstance(entity, dict):
|
||||
continue
|
||||
|
||||
name = str(entity.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
attributes = entity.get("attributes")
|
||||
if not isinstance(attributes, list):
|
||||
attributes = []
|
||||
|
||||
examples = entity.get("examples")
|
||||
if not isinstance(examples, list):
|
||||
examples = []
|
||||
|
||||
normalized = dict(entity)
|
||||
normalized["name"] = name
|
||||
normalized["attributes"] = attributes
|
||||
normalized["examples"] = examples
|
||||
if len(normalized.get("description", "")) > 100:
|
||||
normalized["description"] = normalized["description"][:97] + "..."
|
||||
|
||||
validated_entities.append(normalized)
|
||||
|
||||
result["entity_types"] = validated_entities
|
||||
|
||||
# 验证关系类型
|
||||
validated_edges = []
|
||||
for edge in result["edge_types"]:
|
||||
# 强制将 edge name 转为 SCREAMING_SNAKE_CASE(Zep API 要求)
|
||||
if "name" in edge:
|
||||
original_name = edge["name"]
|
||||
edge["name"] = original_name.upper()
|
||||
if edge["name"] != original_name:
|
||||
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
|
||||
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致
|
||||
for st in edge.get("source_targets", []):
|
||||
if st.get("source") in entity_name_map:
|
||||
st["source"] = entity_name_map[st["source"]]
|
||||
if st.get("target") in entity_name_map:
|
||||
st["target"] = entity_name_map[st["target"]]
|
||||
if "source_targets" not in edge:
|
||||
edge["source_targets"] = []
|
||||
if "attributes" not in edge:
|
||||
edge["attributes"] = []
|
||||
if len(edge.get("description", "")) > 100:
|
||||
edge["description"] = edge["description"][:97] + "..."
|
||||
|
||||
if isinstance(edge, str):
|
||||
edge = {"name": edge, "description": f"Relationship type: {edge}"}
|
||||
if not isinstance(edge, dict):
|
||||
continue
|
||||
|
||||
name = str(edge.get("name", "")).strip()
|
||||
if not name:
|
||||
continue
|
||||
|
||||
source_targets = edge.get("source_targets")
|
||||
if not isinstance(source_targets, list):
|
||||
source_targets = []
|
||||
|
||||
attributes = edge.get("attributes")
|
||||
if not isinstance(attributes, list):
|
||||
attributes = []
|
||||
|
||||
normalized = dict(edge)
|
||||
normalized["name"] = name
|
||||
normalized["source_targets"] = source_targets
|
||||
normalized["attributes"] = attributes
|
||||
if len(normalized.get("description", "")) > 100:
|
||||
normalized["description"] = normalized["description"][:97] + "..."
|
||||
|
||||
validated_edges.append(normalized)
|
||||
|
||||
result["edge_types"] = validated_edges
|
||||
|
||||
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
||||
MAX_ENTITY_TYPES = 10
|
||||
MAX_EDGE_TYPES = 10
|
||||
|
||||
# 去重:按 name 去重,保留首次出现的
|
||||
seen_names = set()
|
||||
deduped = []
|
||||
for entity in result["entity_types"]:
|
||||
name = entity.get("name", "")
|
||||
if name and name not in seen_names:
|
||||
seen_names.add(name)
|
||||
deduped.append(entity)
|
||||
elif name in seen_names:
|
||||
logger.warning(f"Duplicate entity type '{name}' removed during validation")
|
||||
result["entity_types"] = deduped
|
||||
|
||||
# 兜底类型定义
|
||||
person_fallback = {
|
||||
"name": "Person",
|
||||
|
|
@ -351,7 +364,7 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["ordinary citizen", "anonymous netizen"]
|
||||
}
|
||||
|
||||
|
||||
organization_fallback = {
|
||||
"name": "Organization",
|
||||
"description": "Any organization not fitting other specific organization types.",
|
||||
|
|
@ -361,40 +374,40 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["small business", "community group"]
|
||||
}
|
||||
|
||||
|
||||
# 检查是否已有兜底类型
|
||||
entity_names = {e["name"] for e in result["entity_types"]}
|
||||
has_person = "Person" in entity_names
|
||||
has_organization = "Organization" in entity_names
|
||||
|
||||
|
||||
# 需要添加的兜底类型
|
||||
fallbacks_to_add = []
|
||||
if not has_person:
|
||||
fallbacks_to_add.append(person_fallback)
|
||||
if not has_organization:
|
||||
fallbacks_to_add.append(organization_fallback)
|
||||
|
||||
|
||||
if fallbacks_to_add:
|
||||
current_count = len(result["entity_types"])
|
||||
needed_slots = len(fallbacks_to_add)
|
||||
|
||||
|
||||
# 如果添加后会超过 10 个,需要移除一些现有类型
|
||||
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
||||
# 计算需要移除多少个
|
||||
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
||||
# 从末尾移除(保留前面更重要的具体类型)
|
||||
result["entity_types"] = result["entity_types"][:-to_remove]
|
||||
|
||||
|
||||
# 添加兜底类型
|
||||
result["entity_types"].extend(fallbacks_to_add)
|
||||
|
||||
|
||||
# 最终确保不超过限制(防御性编程)
|
||||
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
||||
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
||||
|
||||
|
||||
if len(result["edge_types"]) > MAX_EDGE_TYPES:
|
||||
result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
||||
|
|
|
|||
|
|
@ -7,11 +7,9 @@ import time
|
|||
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
|
|
@ -79,11 +77,12 @@ class ZepEntityReader:
|
|||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
|
||||
def _call_with_retry(
|
||||
self,
|
||||
|
|
@ -136,7 +135,7 @@ class ZepEntityReader:
|
|||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
nodes = self.backend.get_all_nodes(graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
|
|
@ -163,7 +162,7 @@ class ZepEntityReader:
|
|||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
edges = self.backend.get_all_edges(graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
|
|
@ -192,7 +191,7 @@ class ZepEntityReader:
|
|||
try:
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||
func=lambda: self.backend.get_node_edges(node_uuid),
|
||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
|
@ -348,7 +347,7 @@ class ZepEntityReader:
|
|||
try:
|
||||
# 使用重试机制获取节点
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||
func=lambda: self.backend.get_node(entity_uuid),
|
||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
|
@ -434,4 +433,3 @@ class ZepEntityReader:
|
|||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,9 +12,8 @@ from dataclasses import dataclass
|
|||
from datetime import datetime
|
||||
from queue import Queue, Empty
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..graph import get_graph_backend
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_locale, set_locale
|
||||
|
||||
|
|
@ -238,12 +237,13 @@ class ZepGraphMemoryUpdater:
|
|||
api_key: Zep API Key(可选,默认从配置读取)
|
||||
"""
|
||||
self.graph_id = graph_id
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
self.api_key = Config.ZEP_API_KEY if api_key is None else api_key
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY未配置")
|
||||
errors = Config.get_graph_backend_config_errors(api_key=self.api_key)
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.backend = get_graph_backend(api_key=self.api_key)
|
||||
|
||||
# 活动队列
|
||||
self._activity_queue: Queue = Queue()
|
||||
|
|
@ -411,9 +411,8 @@ class ZepGraphMemoryUpdater:
|
|||
# 带重试的发送
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
self.client.graph.add(
|
||||
self.backend.add_text(
|
||||
graph_id=self.graph_id,
|
||||
type="text",
|
||||
data=combined_text
|
||||
)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -0,0 +1,47 @@
|
|||
"""
|
||||
Embedding client wrapper for OpenAI-compatible embedding APIs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class EmbeddingClient:
|
||||
"""Thin wrapper around OpenAI-compatible embedding endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str],
|
||||
base_url: str,
|
||||
model: str,
|
||||
batch_size: int = 32,
|
||||
):
|
||||
if not base_url:
|
||||
raise ValueError("Embedding base_url 未配置")
|
||||
if not model:
|
||||
raise ValueError("Embedding model 未配置")
|
||||
|
||||
self.api_key = api_key or 'ollama'
|
||||
self.base_url = base_url
|
||||
self.model = model
|
||||
self.batch_size = max(1, int(batch_size))
|
||||
self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
def embed_texts(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed texts in batches while preserving input order."""
|
||||
if not texts:
|
||||
return []
|
||||
|
||||
embeddings: List[List[float]] = []
|
||||
normalized_inputs = [str(text or ' ').strip() or ' ' for text in texts]
|
||||
|
||||
for start in range(0, len(normalized_inputs), self.batch_size):
|
||||
batch = normalized_inputs[start:start + self.batch_size]
|
||||
response = self.client.embeddings.create(model=self.model, input=batch)
|
||||
data = sorted(response.data, key=lambda item: item.index)
|
||||
embeddings.extend(item.embedding for item in data)
|
||||
|
||||
return embeddings
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
"""
|
||||
Reranker client wrappers for common HTTP rerank APIs.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional
|
||||
from urllib import error, request
|
||||
|
||||
|
||||
@dataclass
|
||||
class RerankerRequestSpec:
|
||||
provider: str
|
||||
path: str
|
||||
body: dict
|
||||
|
||||
|
||||
class RerankerClient:
|
||||
"""Thin wrapper around HTTP rerank endpoints."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
model: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
provider: str = "auto",
|
||||
timeout: float = 20.0,
|
||||
):
|
||||
if not base_url:
|
||||
raise ValueError("Reranker base_url 未配置")
|
||||
|
||||
self.base_url = base_url.rstrip("/")
|
||||
self.model = model
|
||||
self.api_key = api_key or None
|
||||
self.provider = (provider or "auto").strip().lower() or "auto"
|
||||
self.timeout = max(1.0, float(timeout))
|
||||
|
||||
def rerank(self, query: str, documents: List[str]) -> Dict[int, float]:
|
||||
"""Return index -> score for the supplied candidate documents."""
|
||||
if not documents:
|
||||
return {}
|
||||
|
||||
last_error: Optional[Exception] = None
|
||||
for spec in self._build_request_specs(query, documents):
|
||||
try:
|
||||
return self._execute_request(spec)
|
||||
except Exception as exc:
|
||||
last_error = exc
|
||||
|
||||
if last_error is not None:
|
||||
raise last_error
|
||||
raise RuntimeError("没有可用的 reranker provider")
|
||||
|
||||
def _build_request_specs(self, query: str, documents: List[str]) -> List[RerankerRequestSpec]:
|
||||
providers = [self.provider]
|
||||
if self.provider == "auto":
|
||||
providers = ["tei", "jina", "cohere", "vllm", "infinity"]
|
||||
|
||||
specs: List[RerankerRequestSpec] = []
|
||||
for provider in providers:
|
||||
if provider == "tei":
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/rerank",
|
||||
body={
|
||||
"query": query,
|
||||
"texts": documents,
|
||||
"truncate": True,
|
||||
"raw_scores": False,
|
||||
},
|
||||
)
|
||||
)
|
||||
elif provider in {"jina", "vllm", "infinity"}:
|
||||
body = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
if self.model:
|
||||
body["model"] = self.model
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/v1/rerank",
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
elif provider == "cohere":
|
||||
body = {
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": len(documents),
|
||||
"return_documents": False,
|
||||
}
|
||||
if self.model:
|
||||
body["model"] = self.model
|
||||
specs.append(
|
||||
RerankerRequestSpec(
|
||||
provider=provider,
|
||||
path="/v2/rerank",
|
||||
body=body,
|
||||
)
|
||||
)
|
||||
|
||||
return specs
|
||||
|
||||
def _headers(self) -> dict:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if self.api_key:
|
||||
headers["Authorization"] = f"Bearer {self.api_key}"
|
||||
return headers
|
||||
|
||||
def _execute_request(self, spec: RerankerRequestSpec) -> Dict[int, float]:
|
||||
payload = json.dumps(spec.body).encode("utf-8")
|
||||
req = request.Request(
|
||||
f"{self.base_url}{spec.path}",
|
||||
data=payload,
|
||||
headers=self._headers(),
|
||||
method="POST",
|
||||
)
|
||||
|
||||
try:
|
||||
with request.urlopen(req, timeout=self.timeout) as resp:
|
||||
body = resp.read().decode("utf-8")
|
||||
except error.HTTPError as exc:
|
||||
detail = exc.read().decode("utf-8", errors="replace")
|
||||
raise RuntimeError(f"{spec.provider} rerank 请求失败: HTTP {exc.code}: {detail[:300]}") from exc
|
||||
except error.URLError as exc:
|
||||
raise RuntimeError(f"{spec.provider} rerank 请求失败: {exc.reason}") from exc
|
||||
|
||||
try:
|
||||
data = json.loads(body)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise RuntimeError(f"{spec.provider} rerank 返回非 JSON 响应") from exc
|
||||
|
||||
scores = self._parse_scores(spec.provider, data)
|
||||
if not scores:
|
||||
raise RuntimeError(f"{spec.provider} rerank 未返回有效分数")
|
||||
return scores
|
||||
|
||||
def _parse_scores(self, provider: str, payload: object) -> Dict[int, float]:
|
||||
if provider == "tei":
|
||||
if not isinstance(payload, list):
|
||||
raise RuntimeError("TEI rerank 响应格式异常")
|
||||
|
||||
scores: Dict[int, float] = {}
|
||||
for item in payload:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
index = item.get("index")
|
||||
score = item.get("score")
|
||||
if isinstance(index, int) and score is not None:
|
||||
scores[index] = float(score)
|
||||
return scores
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
raise RuntimeError("rerank 响应格式异常")
|
||||
|
||||
results = payload.get("results") or payload.get("data") or []
|
||||
scores: Dict[int, float] = {}
|
||||
if isinstance(results, list):
|
||||
for item in results:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
index = item.get("index")
|
||||
score = item.get("relevance_score")
|
||||
if score is None:
|
||||
score = item.get("score")
|
||||
if isinstance(index, int) and score is not None:
|
||||
scores[index] = float(score)
|
||||
return scores
|
||||
|
|
@ -14,10 +14,14 @@ dependencies = [
|
|||
"flask-cors>=6.0.0",
|
||||
|
||||
# LLM 相关
|
||||
"openai>=1.0.0",
|
||||
"openai>=1.91.0",
|
||||
|
||||
# Zep Cloud
|
||||
"zep-cloud==3.13.0",
|
||||
|
||||
# Graphiti / Neo4j
|
||||
"graphiti-core==0.28.2",
|
||||
"neo4j>=5.26.0",
|
||||
|
||||
# OASIS 社交媒体模拟
|
||||
"camel-oasis==0.2.5",
|
||||
|
|
@ -31,7 +35,7 @@ dependencies = [
|
|||
|
||||
# 工具库
|
||||
"python-dotenv>=1.0.0",
|
||||
"pydantic>=2.0.0",
|
||||
"pydantic>=2.11.5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@ flask>=3.0.0
|
|||
flask-cors>=6.0.0
|
||||
|
||||
# ============= LLM 相关 =============
|
||||
# OpenAI SDK(默认 LLM 后端)
|
||||
openai>=1.0.0
|
||||
# OpenAI SDK(统一使用 OpenAI 格式调用 LLM)
|
||||
openai>=1.91.0
|
||||
|
||||
# Prompture(可选)— 多供应商 LLM 支持:LM Studio, Ollama, Claude, Groq, Kimi 等
|
||||
# Install for multi-provider support: pip install prompture
|
||||
|
|
@ -21,6 +21,10 @@ openai>=1.0.0
|
|||
# ============= Zep Cloud =============
|
||||
zep-cloud==3.13.0
|
||||
|
||||
# ============= Graphiti / Neo4j =============
|
||||
graphiti-core==0.28.2
|
||||
neo4j>=5.26.0
|
||||
|
||||
# ============= OASIS 社交媒体模拟 =============
|
||||
# OASIS 社交模拟框架
|
||||
camel-oasis==0.2.5
|
||||
|
|
@ -37,4 +41,4 @@ chardet>=5.0.0
|
|||
python-dotenv>=1.0.0
|
||||
|
||||
# 数据验证
|
||||
pydantic>=2.0.0
|
||||
pydantic>=2.11.5
|
||||
|
|
|
|||
|
|
@ -0,0 +1,223 @@
|
|||
#!/usr/bin/env python3
|
||||
"""
|
||||
Service-level Graphiti smoke test for MiroFish.
|
||||
|
||||
The script creates a temporary graph, applies a small ontology, ingests a few
|
||||
text chunks through GraphBuilderService, then queries through ZepToolsService.
|
||||
It prints JSON output and deletes the graph by default.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Iterable
|
||||
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[2]
|
||||
BACKEND_ROOT = REPO_ROOT / "backend"
|
||||
|
||||
DEFAULT_CHUNKS = [
|
||||
"Alice works at Acme Robotics. Bob manages Alice.",
|
||||
"Acme Robotics is located in Tokyo. Alice relocated to Tokyo for work.",
|
||||
]
|
||||
|
||||
|
||||
def load_env_file(path: Path) -> None:
|
||||
"""Load a simple .env file without overriding existing variables."""
|
||||
if not path.exists():
|
||||
return
|
||||
|
||||
for raw_line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = raw_line.strip()
|
||||
if not line or line.startswith("#") or "=" not in line:
|
||||
continue
|
||||
|
||||
key, value = line.split("=", 1)
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
|
||||
if value and value[0] == value[-1] and value[0] in {"'", '"'}:
|
||||
value = value[1:-1]
|
||||
|
||||
os.environ.setdefault(key, value)
|
||||
|
||||
|
||||
def prepare_environment() -> None:
|
||||
"""Apply repo-local defaults that make Graphiti smoke testing predictable."""
|
||||
load_env_file(REPO_ROOT / ".env")
|
||||
load_env_file(BACKEND_ROOT / ".env")
|
||||
|
||||
os.environ.setdefault("GRAPH_BACKEND", "graphiti")
|
||||
os.environ.setdefault("GRAPHITI_URI", "bolt://127.0.0.1:7687")
|
||||
os.environ.setdefault("GRAPHITI_USER", "neo4j")
|
||||
os.environ.setdefault("GRAPHITI_PASSWORD", "password123")
|
||||
os.environ.setdefault("GRAPHITI_DATABASE", "neo4j")
|
||||
os.environ.setdefault("GRAPHITI_LLM_BASE_URL", os.environ.get("LLM_BASE_URL", "http://127.0.0.1:18081/v1"))
|
||||
os.environ.setdefault("GRAPHITI_LLM_MODEL", os.environ.get("LLM_MODEL_NAME", "gpt-5.4"))
|
||||
os.environ.setdefault("GRAPHITI_LLM_CLIENT_MODE", "openai")
|
||||
os.environ.setdefault("GRAPH_SEARCH_RERANKER", "rrf")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_API_KEY", "ollama")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_BASE_URL", "http://127.0.0.1:11434/v1")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_MODEL", "qwen3-embedding:8b")
|
||||
os.environ.setdefault("GRAPHITI_EMBEDDER_DIM", "1024")
|
||||
|
||||
|
||||
def parse_args(argv: Iterable[str]) -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(description="Run a Graphiti service-level smoke test.")
|
||||
parser.add_argument(
|
||||
"--query",
|
||||
default="Where does Alice work?",
|
||||
help="Search query executed through ZepToolsService.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--chunk",
|
||||
action="append",
|
||||
dest="chunks",
|
||||
help="Text chunk to ingest. Repeat to add multiple chunks.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--graph-name",
|
||||
default="graphiti smoke test",
|
||||
help="Temporary graph name.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--keep-graph",
|
||||
action="store_true",
|
||||
help="Keep the temporary graph instead of deleting it on exit.",
|
||||
)
|
||||
return parser.parse_args(list(argv))
|
||||
|
||||
|
||||
def build_sample_ontology() -> dict:
|
||||
return {
|
||||
"entity_types": [
|
||||
{"name": "Person", "description": "A human individual.", "attributes": []},
|
||||
{"name": "Organization", "description": "A company or institution.", "attributes": []},
|
||||
{"name": "Location", "description": "A place or city.", "attributes": []},
|
||||
],
|
||||
"edge_types": [
|
||||
{
|
||||
"name": "WORKS_AT",
|
||||
"description": "A person works at an organization.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Person", "target": "Organization"}],
|
||||
},
|
||||
{
|
||||
"name": "MANAGES",
|
||||
"description": "A person manages another person.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Person", "target": "Person"}],
|
||||
},
|
||||
{
|
||||
"name": "LOCATED_IN",
|
||||
"description": "An organization is located in a place.",
|
||||
"attributes": [],
|
||||
"source_targets": [{"source": "Organization", "target": "Location"}],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def main(argv: Iterable[str]) -> int:
|
||||
prepare_environment()
|
||||
sys.path.insert(0, str(BACKEND_ROOT))
|
||||
|
||||
from app.config import Config
|
||||
|
||||
search_embedder = Config.get_graph_search_embedder_config()
|
||||
search_reranker = Config.get_graph_search_reranker_config()
|
||||
errors = Config.get_graph_backend_config_errors(api_key=Config.ZEP_API_KEY)
|
||||
if errors:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"error": "Graph backend config is incomplete.",
|
||||
"backend": Config.GRAPH_BACKEND,
|
||||
"errors": errors,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 2
|
||||
|
||||
from app.services.graph_builder import GraphBuilderService
|
||||
from app.services.zep_tools import ZepToolsService
|
||||
|
||||
args = parse_args(argv)
|
||||
chunks = args.chunks or list(DEFAULT_CHUNKS)
|
||||
|
||||
builder = GraphBuilderService()
|
||||
tools = ZepToolsService()
|
||||
graph_id = builder.create_graph(f"{args.graph_name}-{uuid.uuid4().hex[:8]}")
|
||||
|
||||
try:
|
||||
builder.set_ontology(graph_id, build_sample_ontology())
|
||||
builder.add_text_batches(graph_id, chunks, batch_size=1)
|
||||
|
||||
result = tools.search_graph(
|
||||
graph_id=graph_id,
|
||||
query=args.query,
|
||||
limit=5,
|
||||
scope="edges",
|
||||
)
|
||||
stats = tools.get_graph_statistics(graph_id)
|
||||
|
||||
output = {
|
||||
"success": True,
|
||||
"graph_id": graph_id,
|
||||
"kept_graph": args.keep_graph,
|
||||
"backend": Config.GRAPH_BACKEND,
|
||||
"llm_base_url": Config.GRAPHITI_LLM_BASE_URL or Config.LLM_BASE_URL,
|
||||
"llm_model": Config.GRAPHITI_LLM_MODEL or Config.LLM_MODEL_NAME,
|
||||
"embedder_base_url": Config.GRAPHITI_EMBEDDER_BASE_URL,
|
||||
"embedder_model": Config.GRAPHITI_EMBEDDER_MODEL,
|
||||
"app_reranker": Config.GRAPH_SEARCH_APP_RERANKER,
|
||||
"app_embedder_base_url": search_embedder.get("base_url"),
|
||||
"app_embedder_model": search_embedder.get("model"),
|
||||
"app_reranker_base_url": search_reranker.get("base_url"),
|
||||
"app_reranker_model": search_reranker.get("model"),
|
||||
"app_reranker_provider": search_reranker.get("provider"),
|
||||
"query": args.query,
|
||||
"chunks": chunks,
|
||||
"stats": stats,
|
||||
"facts": result.facts,
|
||||
"edges": result.edges,
|
||||
"nodes": result.nodes,
|
||||
"total_count": result.total_count,
|
||||
}
|
||||
print(json.dumps(output, ensure_ascii=False, indent=2))
|
||||
return 0
|
||||
except Exception as exc:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"success": False,
|
||||
"graph_id": graph_id,
|
||||
"error": str(exc),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
ensure_ascii=False,
|
||||
indent=2,
|
||||
),
|
||||
file=sys.stderr,
|
||||
)
|
||||
return 1
|
||||
finally:
|
||||
if not args.keep_graph:
|
||||
try:
|
||||
builder.backend.delete_graph(graph_id)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main(sys.argv[1:]))
|
||||
Loading…
Reference in New Issue