refactor: replace Zep Cloud memory with local JSON file storage
Removes the zep-cloud dependency entirely and replaces it with a
local file-based graph store (LocalGraphStore) that persists nodes,
edges, and episodes as JSON files under uploads/graphs/{graph_id}/.
- Add backend/app/utils/local_graph_store.py: thread-safe JSON store
with keyword search, node upsert, and episode append
- Rewrite graph_builder.py: LLM-based entity/relationship extraction
from text batches, stored locally instead of sent to Zep Cloud
- Rewrite zep_graph_memory_updater.py: agent activities written as
episodes + searchable fact edges in local JSON
- Rewrite zep_entity_reader.py: reads nodes/edges from local JSON
- Rewrite zep_tools.py: keyword search on local JSON replaces
Zep semantic search; _local_search is now the primary path
- Update oasis_profile_generator.py: local store replaces Zep client
for entity context enrichment
- Update ontology_generator.py: generated code template uses
pydantic BaseModel instead of Zep EntityModel/EdgeModel
- Convert zep_paging.py to a no-op stub (pagination not needed)
- Remove ZEP_API_KEY from config.py, add GRAPH_STORAGE_DIR
- Remove ZEP_API_KEY guards from api/graph.py and api/simulation.py
- Remove zep-cloud==3.13.0 from requirements.txt and pyproject.toml
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
fa0f6519b1
commit
2bec63be1b
|
|
@ -283,17 +283,6 @@ def build_graph():
|
||||||
try:
|
try:
|
||||||
logger.info("=== 开始构建图谱 ===")
|
logger.info("=== 开始构建图谱 ===")
|
||||||
|
|
||||||
# 检查配置
|
|
||||||
errors = []
|
|
||||||
if not Config.ZEP_API_KEY:
|
|
||||||
errors.append(t('api.zepApiKeyMissing'))
|
|
||||||
if errors:
|
|
||||||
logger.error(f"配置错误: {errors}")
|
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.configError', details="; ".join(errors))
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
# 解析请求
|
# 解析请求
|
||||||
data = request.get_json() or {}
|
data = request.get_json() or {}
|
||||||
project_id = data.get('project_id')
|
project_id = data.get('project_id')
|
||||||
|
|
@ -387,7 +376,7 @@ def build_graph():
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建图谱构建服务
|
# 创建图谱构建服务
|
||||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
builder = GraphBuilderService()
|
||||||
|
|
||||||
# 分块
|
# 分块
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
|
|
@ -572,13 +561,7 @@ def get_graph_data(graph_id: str):
|
||||||
获取图谱数据(节点和边)
|
获取图谱数据(节点和边)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
builder = GraphBuilderService()
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.zepApiKeyMissing')
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
|
||||||
graph_data = builder.get_graph_data(graph_id)
|
graph_data = builder.get_graph_data(graph_id)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -597,16 +580,10 @@ def get_graph_data(graph_id: str):
|
||||||
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
||||||
def delete_graph(graph_id: str):
|
def delete_graph(graph_id: str):
|
||||||
"""
|
"""
|
||||||
删除Zep图谱
|
删除本地图谱
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
builder = GraphBuilderService()
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.zepApiKeyMissing')
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
|
||||||
builder.delete_graph(graph_id)
|
builder.delete_graph(graph_id)
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
|
||||||
|
|
@ -57,12 +57,6 @@ def get_graph_entities(graph_id: str):
|
||||||
enrich: 是否获取相关边信息(默认true)
|
enrich: 是否获取相关边信息(默认true)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.zepApiKeyMissing')
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
entity_types_str = request.args.get('entity_types', '')
|
entity_types_str = request.args.get('entity_types', '')
|
||||||
entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None
|
entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None
|
||||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||||
|
|
@ -94,12 +88,6 @@ def get_graph_entities(graph_id: str):
|
||||||
def get_entity_detail(graph_id: str, entity_uuid: str):
|
def get_entity_detail(graph_id: str, entity_uuid: str):
|
||||||
"""获取单个实体的详细信息"""
|
"""获取单个实体的详细信息"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.zepApiKeyMissing')
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
reader = ZepEntityReader()
|
reader = ZepEntityReader()
|
||||||
entity = reader.get_entity_with_context(graph_id, entity_uuid)
|
entity = reader.get_entity_with_context(graph_id, entity_uuid)
|
||||||
|
|
||||||
|
|
@ -127,12 +115,6 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
|
||||||
def get_entities_by_type(graph_id: str, entity_type: str):
|
def get_entities_by_type(graph_id: str, entity_type: str):
|
||||||
"""获取指定类型的所有实体"""
|
"""获取指定类型的所有实体"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
|
||||||
return jsonify({
|
|
||||||
"success": False,
|
|
||||||
"error": t('api.zepApiKeyMissing')
|
|
||||||
}), 500
|
|
||||||
|
|
||||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||||
|
|
||||||
reader = ZepEntityReader()
|
reader = ZepEntityReader()
|
||||||
|
|
|
||||||
|
|
@ -32,8 +32,8 @@ class Config:
|
||||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||||
|
|
||||||
# Zep配置
|
# 本地图谱存储目录
|
||||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
GRAPH_STORAGE_DIR = os.path.join(os.path.dirname(__file__), '../uploads/graphs')
|
||||||
|
|
||||||
# 文件上传配置
|
# 文件上传配置
|
||||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||||
|
|
@ -69,7 +69,7 @@ class Config:
|
||||||
errors = []
|
errors = []
|
||||||
if not cls.LLM_API_KEY:
|
if not cls.LLM_API_KEY:
|
||||||
errors.append("LLM_API_KEY 未配置")
|
errors.append("LLM_API_KEY 未配置")
|
||||||
if not cls.ZEP_API_KEY:
|
# 确保图谱存储目录存在
|
||||||
errors.append("ZEP_API_KEY 未配置")
|
os.makedirs(cls.GRAPH_STORAGE_DIR, exist_ok=True)
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
图谱构建服务
|
图谱构建服务
|
||||||
接口2:使用Zep API构建Standalone Graph
|
使用本地JSON文件存储替代Zep Cloud
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -10,14 +10,15 @@ import threading
|
||||||
from typing import Dict, Any, List, Optional, Callable
|
from typing import Dict, Any, List, Optional, Callable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from zep_cloud.client import Zep
|
|
||||||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
from ..models.task import TaskManager, TaskStatus
|
from ..models.task import TaskManager, TaskStatus
|
||||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
from ..utils.local_graph_store import LocalGraphStore
|
||||||
|
from ..utils.llm_client import LLMClient
|
||||||
from .text_processor import TextProcessor
|
from .text_processor import TextProcessor
|
||||||
from ..utils.locale import t, get_locale, set_locale
|
from ..utils.locale import t, get_locale, set_locale
|
||||||
|
from ..utils.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.graph_builder')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -40,16 +41,22 @@ class GraphInfo:
|
||||||
class GraphBuilderService:
|
class GraphBuilderService:
|
||||||
"""
|
"""
|
||||||
图谱构建服务
|
图谱构建服务
|
||||||
负责调用Zep API构建知识图谱
|
使用本地JSON文件存储构建知识图谱
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||||
if not self.api_key:
|
self.storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||||
raise ValueError("ZEP_API_KEY 未配置")
|
self.store = LocalGraphStore(self.storage_dir)
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
|
||||||
self.task_manager = TaskManager()
|
self.task_manager = TaskManager()
|
||||||
|
self._llm: Optional[LLMClient] = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm(self) -> LLMClient:
|
||||||
|
"""延迟初始化LLM客户端"""
|
||||||
|
if self._llm is None:
|
||||||
|
self._llm = LLMClient()
|
||||||
|
return self._llm
|
||||||
|
|
||||||
def build_graph_async(
|
def build_graph_async(
|
||||||
self,
|
self,
|
||||||
|
|
@ -63,18 +70,9 @@ class GraphBuilderService:
|
||||||
"""
|
"""
|
||||||
异步构建图谱
|
异步构建图谱
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 输入文本
|
|
||||||
ontology: 本体定义(来自接口1的输出)
|
|
||||||
graph_name: 图谱名称
|
|
||||||
chunk_size: 文本块大小
|
|
||||||
chunk_overlap: 块重叠大小
|
|
||||||
batch_size: 每批发送的块数量
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
任务ID
|
任务ID
|
||||||
"""
|
"""
|
||||||
# 创建任务
|
|
||||||
task_id = self.task_manager.create_task(
|
task_id = self.task_manager.create_task(
|
||||||
task_type="graph_build",
|
task_type="graph_build",
|
||||||
metadata={
|
metadata={
|
||||||
|
|
@ -84,10 +82,8 @@ class GraphBuilderService:
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Capture locale before spawning background thread
|
|
||||||
current_locale = get_locale()
|
current_locale = get_locale()
|
||||||
|
|
||||||
# 在后台线程中执行构建
|
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._build_graph_worker,
|
target=self._build_graph_worker,
|
||||||
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
||||||
|
|
@ -126,7 +122,7 @@ class GraphBuilderService:
|
||||||
message=t('progress.graphCreated', graphId=graph_id)
|
message=t('progress.graphCreated', graphId=graph_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 设置本体
|
# 2. 保存本体
|
||||||
self.set_ontology(graph_id, ontology)
|
self.set_ontology(graph_id, ontology)
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
|
|
@ -143,33 +139,17 @@ class GraphBuilderService:
|
||||||
message=t('progress.textSplit', count=total_chunks)
|
message=t('progress.textSplit', count=total_chunks)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 分批发送数据
|
# 4. 分批处理:提取实体并存储
|
||||||
episode_uuids = self.add_text_batches(
|
self.add_text_batches(
|
||||||
graph_id, chunks, batch_size,
|
graph_id, chunks, batch_size,
|
||||||
lambda msg, prog: self.task_manager.update_task(
|
lambda msg, prog: self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
progress=20 + int(prog * 0.4), # 20-60%
|
progress=20 + int(prog * 0.7), # 20-90%
|
||||||
message=msg
|
message=msg
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 等待Zep处理完成
|
# 5. 获取图谱信息
|
||||||
self.task_manager.update_task(
|
|
||||||
task_id,
|
|
||||||
progress=60,
|
|
||||||
message=t('progress.waitingZepProcess')
|
|
||||||
)
|
|
||||||
|
|
||||||
self._wait_for_episodes(
|
|
||||||
episode_uuids,
|
|
||||||
lambda msg, prog: self.task_manager.update_task(
|
|
||||||
task_id,
|
|
||||||
progress=60 + int(prog * 0.3), # 60-90%
|
|
||||||
message=msg
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 6. 获取图谱信息
|
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
progress=90,
|
progress=90,
|
||||||
|
|
@ -178,7 +158,6 @@ class GraphBuilderService:
|
||||||
|
|
||||||
graph_info = self._get_graph_info(graph_id)
|
graph_info = self._get_graph_info(graph_id)
|
||||||
|
|
||||||
# 完成
|
|
||||||
self.task_manager.complete_task(task_id, {
|
self.task_manager.complete_task(task_id, {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"graph_info": graph_info.to_dict(),
|
"graph_info": graph_info.to_dict(),
|
||||||
|
|
@ -191,105 +170,14 @@ class GraphBuilderService:
|
||||||
self.task_manager.fail_task(task_id, error_msg)
|
self.task_manager.fail_task(task_id, error_msg)
|
||||||
|
|
||||||
def create_graph(self, name: str) -> str:
|
def create_graph(self, name: str) -> str:
|
||||||
"""创建Zep图谱(公开方法)"""
|
"""创建本地图谱"""
|
||||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||||
|
self.store.create_graph(graph_id, name, "MiroFish Social Simulation Graph")
|
||||||
self.client.graph.create(
|
|
||||||
graph_id=graph_id,
|
|
||||||
name=name,
|
|
||||||
description="MiroFish Social Simulation Graph"
|
|
||||||
)
|
|
||||||
|
|
||||||
return graph_id
|
return graph_id
|
||||||
|
|
||||||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||||||
"""设置图谱本体(公开方法)"""
|
"""保存本体定义"""
|
||||||
import warnings
|
self.store.set_ontology(graph_id, ontology)
|
||||||
from typing import Optional
|
|
||||||
from pydantic import Field
|
|
||||||
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
|
||||||
|
|
||||||
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
|
|
||||||
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
|
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
|
||||||
|
|
||||||
# Zep 保留名称,不能作为属性名
|
|
||||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
|
||||||
|
|
||||||
def safe_attr_name(attr_name: str) -> str:
|
|
||||||
"""将保留名称转换为安全名称"""
|
|
||||||
if attr_name.lower() in RESERVED_NAMES:
|
|
||||||
return f"entity_{attr_name}"
|
|
||||||
return attr_name
|
|
||||||
|
|
||||||
# 动态创建实体类型
|
|
||||||
entity_types = {}
|
|
||||||
for entity_def in ontology.get("entity_types", []):
|
|
||||||
name = entity_def["name"]
|
|
||||||
description = entity_def.get("description", f"A {name} entity.")
|
|
||||||
|
|
||||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
|
||||||
attrs = {"__doc__": description}
|
|
||||||
annotations = {}
|
|
||||||
|
|
||||||
for attr_def in entity_def.get("attributes", []):
|
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
|
||||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
|
||||||
|
|
||||||
# 动态创建类
|
|
||||||
entity_class = type(name, (EntityModel,), attrs)
|
|
||||||
entity_class.__doc__ = description
|
|
||||||
entity_types[name] = entity_class
|
|
||||||
|
|
||||||
# 动态创建边类型
|
|
||||||
edge_definitions = {}
|
|
||||||
for edge_def in ontology.get("edge_types", []):
|
|
||||||
name = edge_def["name"]
|
|
||||||
description = edge_def.get("description", f"A {name} relationship.")
|
|
||||||
|
|
||||||
# 创建属性字典和类型注解
|
|
||||||
attrs = {"__doc__": description}
|
|
||||||
annotations = {}
|
|
||||||
|
|
||||||
for attr_def in edge_def.get("attributes", []):
|
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
|
||||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
|
||||||
|
|
||||||
# 动态创建类
|
|
||||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
|
||||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
|
||||||
edge_class.__doc__ = description
|
|
||||||
|
|
||||||
# 构建source_targets
|
|
||||||
source_targets = []
|
|
||||||
for st in edge_def.get("source_targets", []):
|
|
||||||
source_targets.append(
|
|
||||||
EntityEdgeSourceTarget(
|
|
||||||
source=st.get("source", "Entity"),
|
|
||||||
target=st.get("target", "Entity")
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if source_targets:
|
|
||||||
edge_definitions[name] = (edge_class, source_targets)
|
|
||||||
|
|
||||||
# 调用Zep API设置本体
|
|
||||||
if entity_types or edge_definitions:
|
|
||||||
self.client.graph.set_ontology(
|
|
||||||
graph_ids=[graph_id],
|
|
||||||
entities=entity_types if entity_types else None,
|
|
||||||
edges=edge_definitions if edge_definitions else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
def add_text_batches(
|
def add_text_batches(
|
||||||
self,
|
self,
|
||||||
|
|
@ -298,209 +186,206 @@ class GraphBuilderService:
|
||||||
batch_size: int = 3,
|
batch_size: int = 3,
|
||||||
progress_callback: Optional[Callable] = None
|
progress_callback: Optional[Callable] = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
"""分批处理文本:提取实体/关系并存储,返回情节uuid列表"""
|
||||||
episode_uuids = []
|
episode_uuids = []
|
||||||
|
ontology = self.store.get_ontology(graph_id) or {}
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
|
|
||||||
for i in range(0, total_chunks, batch_size):
|
for i in range(0, total_chunks, batch_size):
|
||||||
batch_chunks = chunks[i:i + batch_size]
|
batch = chunks[i:i + batch_size]
|
||||||
batch_num = i // batch_size + 1
|
batch_num = i // batch_size + 1
|
||||||
total_batches = (total_chunks + batch_size - 1) // batch_size
|
total_batches = (total_chunks + batch_size - 1) // batch_size
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress = (i + len(batch_chunks)) / total_chunks
|
progress = (i + len(batch)) / total_chunks
|
||||||
progress_callback(
|
progress_callback(
|
||||||
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
|
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch)),
|
||||||
progress
|
progress
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建episode数据
|
# 存储情节文本
|
||||||
episodes = [
|
for text in batch:
|
||||||
EpisodeData(data=chunk, type="text")
|
ep_uuid = self.store.add_episode(graph_id, text)
|
||||||
for chunk in batch_chunks
|
episode_uuids.append(ep_uuid)
|
||||||
]
|
|
||||||
|
|
||||||
# 发送到Zep
|
# 使用LLM从批次文本中提取实体和关系
|
||||||
try:
|
if ontology.get("entity_types") or ontology.get("edge_types"):
|
||||||
batch_result = self.client.graph.add_batch(
|
try:
|
||||||
graph_id=graph_id,
|
extracted = self._extract_entities_from_batch(batch, ontology)
|
||||||
episodes=episodes
|
self._store_extracted(graph_id, extracted)
|
||||||
)
|
except Exception as e:
|
||||||
|
logger.warning(f"批次 {batch_num} 实体提取失败: {e}")
|
||||||
|
|
||||||
# 收集返回的 episode uuid
|
# 轻微延迟,避免LLM请求过快
|
||||||
if batch_result and isinstance(batch_result, list):
|
time.sleep(0.3)
|
||||||
for ep in batch_result:
|
|
||||||
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
|
||||||
if ep_uuid:
|
|
||||||
episode_uuids.append(ep_uuid)
|
|
||||||
|
|
||||||
# 避免请求过快
|
|
||||||
time.sleep(1)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return episode_uuids
|
return episode_uuids
|
||||||
|
|
||||||
|
def _extract_entities_from_batch(self, texts: List[str], ontology: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""使用LLM从文本批次中提取实体和关系"""
|
||||||
|
combined_text = "\n\n".join(texts)
|
||||||
|
|
||||||
|
entity_types_desc = "\n".join(
|
||||||
|
f"- {et['name']}: {et.get('description', '')}"
|
||||||
|
for et in ontology.get("entity_types", [])
|
||||||
|
) or "- Entity (通用实体)"
|
||||||
|
|
||||||
|
edge_types_desc = "\n".join(
|
||||||
|
f"- {rt['name']}: {rt.get('description', '')}"
|
||||||
|
for rt in ontology.get("edge_types", [])
|
||||||
|
) or "- RELATED_TO"
|
||||||
|
|
||||||
|
user_prompt = f"""从以下文本中提取实体和关系,仅使用给定的本体类型。
|
||||||
|
|
||||||
|
实体类型(只能使用这些):
|
||||||
|
{entity_types_desc}
|
||||||
|
|
||||||
|
关系类型(只能使用这些):
|
||||||
|
{edge_types_desc}
|
||||||
|
|
||||||
|
文本:
|
||||||
|
{combined_text[:4000]}
|
||||||
|
|
||||||
|
返回JSON格式:
|
||||||
|
{{
|
||||||
|
"entities": [
|
||||||
|
{{"name": "实体名称", "type": "实体类型", "summary": "一句话描述", "attributes": {{}}}}
|
||||||
|
],
|
||||||
|
"relationships": [
|
||||||
|
{{"source": "源实体名称", "target": "目标实体名称", "type": "关系类型", "fact": "事实描述"}}
|
||||||
|
]
|
||||||
|
}}
|
||||||
|
|
||||||
|
规则:
|
||||||
|
- 仅使用本体中定义的实体类型和关系类型
|
||||||
|
- 实体名称应具体(人名、地名、组织名等)
|
||||||
|
- fact字段应是简洁的事实陈述
|
||||||
|
- 若找不到匹配项,返回空列表"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = self.llm.chat_json(
|
||||||
|
messages=[{"role": "user", "content": user_prompt}],
|
||||||
|
temperature=0.1
|
||||||
|
)
|
||||||
|
return result if isinstance(result, dict) else {"entities": [], "relationships": []}
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"实体提取LLM调用失败: {e}")
|
||||||
|
return {"entities": [], "relationships": []}
|
||||||
|
|
||||||
|
def _store_extracted(self, graph_id: str, extracted: Dict[str, Any]):
|
||||||
|
"""将LLM提取的实体和关系存储到本地图谱"""
|
||||||
|
entities = extracted.get("entities", []) or []
|
||||||
|
relationships = extracted.get("relationships", []) or []
|
||||||
|
|
||||||
|
name_to_uuid: Dict[str, str] = {}
|
||||||
|
|
||||||
|
for entity in entities:
|
||||||
|
name = (entity.get("name") or "").strip()
|
||||||
|
if not name:
|
||||||
|
continue
|
||||||
|
entity_type = entity.get("type") or "Entity"
|
||||||
|
summary = entity.get("summary") or ""
|
||||||
|
attributes = entity.get("attributes") or {}
|
||||||
|
|
||||||
|
labels = [entity_type, "Entity"] if entity_type != "Entity" else ["Entity"]
|
||||||
|
node_uuid = self.store.upsert_node(
|
||||||
|
graph_id=graph_id,
|
||||||
|
name=name,
|
||||||
|
labels=labels,
|
||||||
|
summary=summary,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
name_to_uuid[name.lower()] = node_uuid
|
||||||
|
|
||||||
|
for rel in relationships:
|
||||||
|
source_name = (rel.get("source") or "").strip()
|
||||||
|
target_name = (rel.get("target") or "").strip()
|
||||||
|
rel_type = rel.get("type") or "RELATED_TO"
|
||||||
|
fact = rel.get("fact") or ""
|
||||||
|
|
||||||
|
if not source_name or not target_name or not fact:
|
||||||
|
continue
|
||||||
|
|
||||||
|
source_uuid = name_to_uuid.get(source_name.lower()) or \
|
||||||
|
self.store.upsert_node(graph_id, source_name, ["Entity"])
|
||||||
|
name_to_uuid[source_name.lower()] = source_uuid
|
||||||
|
|
||||||
|
target_uuid = name_to_uuid.get(target_name.lower()) or \
|
||||||
|
self.store.upsert_node(graph_id, target_name, ["Entity"])
|
||||||
|
name_to_uuid[target_name.lower()] = target_uuid
|
||||||
|
|
||||||
|
self.store.add_fact_edge(
|
||||||
|
graph_id=graph_id,
|
||||||
|
source_uuid=source_uuid,
|
||||||
|
target_uuid=target_uuid,
|
||||||
|
name=rel_type,
|
||||||
|
fact=fact,
|
||||||
|
)
|
||||||
|
|
||||||
def _wait_for_episodes(
|
def _wait_for_episodes(
|
||||||
self,
|
self,
|
||||||
episode_uuids: List[str],
|
episode_uuids: List[str],
|
||||||
progress_callback: Optional[Callable] = None,
|
progress_callback: Optional[Callable] = None,
|
||||||
timeout: int = 600
|
timeout: int = 600
|
||||||
):
|
):
|
||||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
"""本地存储中情节立即处理完成,无需等待"""
|
||||||
if not episode_uuids:
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
|
||||||
return
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
pending_episodes = set(episode_uuids)
|
|
||||||
completed_count = 0
|
|
||||||
total_episodes = len(episode_uuids)
|
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)
|
progress_callback(t('progress.processingComplete',
|
||||||
|
completed=len(episode_uuids),
|
||||||
while pending_episodes:
|
total=len(episode_uuids)), 1.0)
|
||||||
if time.time() - start_time > timeout:
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(
|
|
||||||
t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
|
|
||||||
completed_count / total_episodes
|
|
||||||
)
|
|
||||||
break
|
|
||||||
|
|
||||||
# 检查每个 episode 的处理状态
|
|
||||||
for ep_uuid in list(pending_episodes):
|
|
||||||
try:
|
|
||||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
|
||||||
is_processed = getattr(episode, 'processed', False)
|
|
||||||
|
|
||||||
if is_processed:
|
|
||||||
pending_episodes.remove(ep_uuid)
|
|
||||||
completed_count += 1
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
# 忽略单个查询错误,继续
|
|
||||||
pass
|
|
||||||
|
|
||||||
elapsed = int(time.time() - start_time)
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(
|
|
||||||
t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
|
|
||||||
completed_count / total_episodes if total_episodes > 0 else 0
|
|
||||||
)
|
|
||||||
|
|
||||||
if pending_episodes:
|
|
||||||
time.sleep(3) # 每3秒检查一次
|
|
||||||
|
|
||||||
if progress_callback:
|
|
||||||
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
|
||||||
|
|
||||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||||
"""获取图谱信息"""
|
"""获取图谱统计信息"""
|
||||||
# 获取节点(分页)
|
nodes = self.store.get_nodes(graph_id)
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
edges = self.store.get_edges(graph_id)
|
||||||
|
|
||||||
# 获取边(分页)
|
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
|
||||||
|
|
||||||
# 统计实体类型
|
|
||||||
entity_types = set()
|
entity_types = set()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.labels:
|
for label in (node.get("labels") or []):
|
||||||
for label in node.labels:
|
if label not in ("Entity", "Node"):
|
||||||
if label not in ["Entity", "Node"]:
|
entity_types.add(label)
|
||||||
entity_types.add(label)
|
|
||||||
|
|
||||||
return GraphInfo(
|
return GraphInfo(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
node_count=len(nodes),
|
node_count=len(nodes),
|
||||||
edge_count=len(edges),
|
edge_count=len(edges),
|
||||||
entity_types=list(entity_types)
|
entity_types=list(entity_types),
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||||
"""
|
"""获取完整图谱数据(含节点和边详情)"""
|
||||||
获取完整图谱数据(包含详细信息)
|
nodes = self.store.get_nodes(graph_id)
|
||||||
|
edges = self.store.get_edges(graph_id)
|
||||||
|
|
||||||
Args:
|
node_map = {n["uuid"]: n.get("name", "") for n in nodes}
|
||||||
graph_id: 图谱ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
|
||||||
"""
|
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
|
||||||
|
|
||||||
# 创建节点映射用于获取节点名称
|
|
||||||
node_map = {}
|
|
||||||
for node in nodes:
|
|
||||||
node_map[node.uuid_] = node.name or ""
|
|
||||||
|
|
||||||
nodes_data = []
|
|
||||||
for node in nodes:
|
|
||||||
# 获取创建时间
|
|
||||||
created_at = getattr(node, 'created_at', None)
|
|
||||||
if created_at:
|
|
||||||
created_at = str(created_at)
|
|
||||||
|
|
||||||
nodes_data.append({
|
|
||||||
"uuid": node.uuid_,
|
|
||||||
"name": node.name,
|
|
||||||
"labels": node.labels or [],
|
|
||||||
"summary": node.summary or "",
|
|
||||||
"attributes": node.attributes or {},
|
|
||||||
"created_at": created_at,
|
|
||||||
})
|
|
||||||
|
|
||||||
edges_data = []
|
edges_data = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
# 获取时间信息
|
|
||||||
created_at = getattr(edge, 'created_at', None)
|
|
||||||
valid_at = getattr(edge, 'valid_at', None)
|
|
||||||
invalid_at = getattr(edge, 'invalid_at', None)
|
|
||||||
expired_at = getattr(edge, 'expired_at', None)
|
|
||||||
|
|
||||||
# 获取 episodes
|
|
||||||
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
|
||||||
if episodes and not isinstance(episodes, list):
|
|
||||||
episodes = [str(episodes)]
|
|
||||||
elif episodes:
|
|
||||||
episodes = [str(e) for e in episodes]
|
|
||||||
|
|
||||||
# 获取 fact_type
|
|
||||||
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
|
||||||
|
|
||||||
edges_data.append({
|
edges_data.append({
|
||||||
"uuid": edge.uuid_,
|
"uuid": edge.get("uuid", ""),
|
||||||
"name": edge.name or "",
|
"name": edge.get("name", ""),
|
||||||
"fact": edge.fact or "",
|
"fact": edge.get("fact", ""),
|
||||||
"fact_type": fact_type,
|
"fact_type": edge.get("name", ""),
|
||||||
"source_node_uuid": edge.source_node_uuid,
|
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||||
"target_node_uuid": edge.target_node_uuid,
|
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||||
"source_node_name": node_map.get(edge.source_node_uuid, ""),
|
"source_node_name": node_map.get(edge.get("source_node_uuid", ""), ""),
|
||||||
"target_node_name": node_map.get(edge.target_node_uuid, ""),
|
"target_node_name": node_map.get(edge.get("target_node_uuid", ""), ""),
|
||||||
"attributes": edge.attributes or {},
|
"attributes": edge.get("attributes", {}),
|
||||||
"created_at": str(created_at) if created_at else None,
|
"created_at": edge.get("created_at"),
|
||||||
"valid_at": str(valid_at) if valid_at else None,
|
"valid_at": edge.get("valid_at"),
|
||||||
"invalid_at": str(invalid_at) if invalid_at else None,
|
"invalid_at": edge.get("invalid_at"),
|
||||||
"expired_at": str(expired_at) if expired_at else None,
|
"expired_at": edge.get("expired_at"),
|
||||||
"episodes": episodes or [],
|
"episodes": [],
|
||||||
})
|
})
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"nodes": nodes_data,
|
"nodes": nodes,
|
||||||
"edges": edges_data,
|
"edges": edges_data,
|
||||||
"node_count": len(nodes_data),
|
"node_count": len(nodes),
|
||||||
"edge_count": len(edges_data),
|
"edge_count": len(edges),
|
||||||
}
|
}
|
||||||
|
|
||||||
def delete_graph(self, graph_id: str):
|
def delete_graph(self, graph_id: str):
|
||||||
"""删除图谱"""
|
"""删除图谱"""
|
||||||
self.client.graph.delete(graph_id=graph_id)
|
self.store.delete_graph(graph_id)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,9 @@ from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
from zep_cloud.client import Zep
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.local_graph_store import LocalGraphStore
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
||||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||||
|
|
@ -183,8 +183,9 @@ class OasisProfileGenerator:
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
model_name: Optional[str] = None,
|
model_name: Optional[str] = None,
|
||||||
zep_api_key: Optional[str] = None,
|
zep_api_key: Optional[str] = None, # 已废弃,保留以兼容旧调用
|
||||||
graph_id: Optional[str] = None
|
graph_id: Optional[str] = None,
|
||||||
|
storage_dir: Optional[str] = None,
|
||||||
):
|
):
|
||||||
self.api_key = api_key or Config.LLM_API_KEY
|
self.api_key = api_key or Config.LLM_API_KEY
|
||||||
self.base_url = base_url or Config.LLM_BASE_URL
|
self.base_url = base_url or Config.LLM_BASE_URL
|
||||||
|
|
@ -198,17 +199,11 @@ class OasisProfileGenerator:
|
||||||
base_url=self.base_url
|
base_url=self.base_url
|
||||||
)
|
)
|
||||||
|
|
||||||
# Zep客户端用于检索丰富上下文
|
# 本地图谱存储
|
||||||
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
|
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||||
self.zep_client = None
|
self.store = LocalGraphStore(storage_dir)
|
||||||
self.graph_id = graph_id
|
self.graph_id = graph_id
|
||||||
|
|
||||||
if self.zep_api_key:
|
|
||||||
try:
|
|
||||||
self.zep_client = Zep(api_key=self.zep_api_key)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Zep客户端初始化失败: {e}")
|
|
||||||
|
|
||||||
def generate_profile_from_entity(
|
def generate_profile_from_entity(
|
||||||
self,
|
self,
|
||||||
entity: EntityNode,
|
entity: EntityNode,
|
||||||
|
|
@ -285,10 +280,7 @@ class OasisProfileGenerator:
|
||||||
|
|
||||||
def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
|
def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
使用Zep图谱混合搜索功能获取实体相关的丰富信息
|
使用本地图谱关键词搜索获取实体相关的丰富信息
|
||||||
|
|
||||||
Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。
|
|
||||||
使用并行请求同时搜索,提高效率。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
entity: 实体节点对象
|
entity: 实体节点对象
|
||||||
|
|
@ -296,105 +288,32 @@ class OasisProfileGenerator:
|
||||||
Returns:
|
Returns:
|
||||||
包含facts, node_summaries, context的字典
|
包含facts, node_summaries, context的字典
|
||||||
"""
|
"""
|
||||||
import concurrent.futures
|
results: Dict[str, Any] = {"facts": [], "node_summaries": [], "context": ""}
|
||||||
|
|
||||||
if not self.zep_client:
|
|
||||||
return {"facts": [], "node_summaries": [], "context": ""}
|
|
||||||
|
|
||||||
entity_name = entity.name
|
|
||||||
|
|
||||||
results = {
|
|
||||||
"facts": [],
|
|
||||||
"node_summaries": [],
|
|
||||||
"context": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
# 必须有graph_id才能进行搜索
|
|
||||||
if not self.graph_id:
|
if not self.graph_id:
|
||||||
logger.debug(f"跳过Zep检索:未设置graph_id")
|
logger.debug("跳过本地检索:未设置graph_id")
|
||||||
return results
|
return results
|
||||||
|
|
||||||
comprehensive_query = t('progress.zepSearchQuery', name=entity_name)
|
entity_name = entity.name
|
||||||
|
query = t('progress.zepSearchQuery', name=entity_name)
|
||||||
def search_edges():
|
|
||||||
"""搜索边(事实/关系)- 带重试机制"""
|
|
||||||
max_retries = 3
|
|
||||||
last_exception = None
|
|
||||||
delay = 2.0
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return self.zep_client.graph.search(
|
|
||||||
query=comprehensive_query,
|
|
||||||
graph_id=self.graph_id,
|
|
||||||
limit=30,
|
|
||||||
scope="edges",
|
|
||||||
reranker="rrf"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
|
||||||
time.sleep(delay)
|
|
||||||
delay *= 2
|
|
||||||
else:
|
|
||||||
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def search_nodes():
|
|
||||||
"""搜索节点(实体摘要)- 带重试机制"""
|
|
||||||
max_retries = 3
|
|
||||||
last_exception = None
|
|
||||||
delay = 2.0
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return self.zep_client.graph.search(
|
|
||||||
query=comprehensive_query,
|
|
||||||
graph_id=self.graph_id,
|
|
||||||
limit=20,
|
|
||||||
scope="nodes",
|
|
||||||
reranker="rrf"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
|
||||||
time.sleep(delay)
|
|
||||||
delay *= 2
|
|
||||||
else:
|
|
||||||
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 并行执行edges和nodes搜索
|
# 搜索边(事实)
|
||||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
edge_raw = self.store.search(self.graph_id, query, limit=30, scope="edges")
|
||||||
edge_future = executor.submit(search_edges)
|
facts = list({e.get("fact", "") for e in edge_raw.get("edges", []) if e.get("fact")})
|
||||||
node_future = executor.submit(search_nodes)
|
results["facts"] = facts
|
||||||
|
|
||||||
# 获取结果
|
# 搜索节点(摘要)
|
||||||
edge_result = edge_future.result(timeout=30)
|
node_raw = self.store.search(self.graph_id, query, limit=20, scope="nodes")
|
||||||
node_result = node_future.result(timeout=30)
|
summaries = set()
|
||||||
|
for n in node_raw.get("nodes", []):
|
||||||
|
if n.get("summary"):
|
||||||
|
summaries.add(n["summary"])
|
||||||
|
if n.get("name") and n["name"] != entity_name:
|
||||||
|
summaries.add(f"相关实体: {n['name']}")
|
||||||
|
results["node_summaries"] = list(summaries)
|
||||||
|
|
||||||
# 处理边搜索结果
|
# 构建上下文
|
||||||
all_facts = set()
|
|
||||||
if edge_result and hasattr(edge_result, 'edges') and edge_result.edges:
|
|
||||||
for edge in edge_result.edges:
|
|
||||||
if hasattr(edge, 'fact') and edge.fact:
|
|
||||||
all_facts.add(edge.fact)
|
|
||||||
results["facts"] = list(all_facts)
|
|
||||||
|
|
||||||
# 处理节点搜索结果
|
|
||||||
all_summaries = set()
|
|
||||||
if node_result and hasattr(node_result, 'nodes') and node_result.nodes:
|
|
||||||
for node in node_result.nodes:
|
|
||||||
if hasattr(node, 'summary') and node.summary:
|
|
||||||
all_summaries.add(node.summary)
|
|
||||||
if hasattr(node, 'name') and node.name and node.name != entity_name:
|
|
||||||
all_summaries.add(f"相关实体: {node.name}")
|
|
||||||
results["node_summaries"] = list(all_summaries)
|
|
||||||
|
|
||||||
# 构建综合上下文
|
|
||||||
context_parts = []
|
context_parts = []
|
||||||
if results["facts"]:
|
if results["facts"]:
|
||||||
context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20]))
|
context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20]))
|
||||||
|
|
@ -402,12 +321,11 @@ class OasisProfileGenerator:
|
||||||
context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10]))
|
context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10]))
|
||||||
results["context"] = "\n\n".join(context_parts)
|
results["context"] = "\n\n".join(context_parts)
|
||||||
|
|
||||||
logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点")
|
logger.info(f"本地检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, "
|
||||||
|
f"{len(results['node_summaries'])} 个相关节点")
|
||||||
|
|
||||||
except concurrent.futures.TimeoutError:
|
|
||||||
logger.warning(f"Zep检索超时 ({entity_name})")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Zep检索失败 ({entity_name}): {e}")
|
logger.warning(f"本地检索失败 ({entity_name}): {e}")
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -413,8 +413,8 @@ class OntologyGenerator:
|
||||||
'由MiroFish自动生成,用于社会舆论模拟',
|
'由MiroFish自动生成,用于社会舆论模拟',
|
||||||
'"""',
|
'"""',
|
||||||
'',
|
'',
|
||||||
'from pydantic import Field',
|
'from typing import Optional',
|
||||||
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
|
'from pydantic import BaseModel, Field',
|
||||||
'',
|
'',
|
||||||
'',
|
'',
|
||||||
'# ============== 实体类型定义 ==============',
|
'# ============== 实体类型定义 ==============',
|
||||||
|
|
@ -426,7 +426,7 @@ class OntologyGenerator:
|
||||||
name = entity["name"]
|
name = entity["name"]
|
||||||
desc = entity.get("description", f"A {name} entity.")
|
desc = entity.get("description", f"A {name} entity.")
|
||||||
|
|
||||||
code_lines.append(f'class {name}(EntityModel):')
|
code_lines.append(f'class {name}(BaseModel):')
|
||||||
code_lines.append(f' """{desc}"""')
|
code_lines.append(f' """{desc}"""')
|
||||||
|
|
||||||
attrs = entity.get("attributes", [])
|
attrs = entity.get("attributes", [])
|
||||||
|
|
@ -434,7 +434,7 @@ class OntologyGenerator:
|
||||||
for attr in attrs:
|
for attr in attrs:
|
||||||
attr_name = attr["name"]
|
attr_name = attr["name"]
|
||||||
attr_desc = attr.get("description", attr_name)
|
attr_desc = attr.get("description", attr_name)
|
||||||
code_lines.append(f' {attr_name}: EntityText = Field(')
|
code_lines.append(f' {attr_name}: Optional[str] = Field(')
|
||||||
code_lines.append(f' description="{attr_desc}",')
|
code_lines.append(f' description="{attr_desc}",')
|
||||||
code_lines.append(f' default=None')
|
code_lines.append(f' default=None')
|
||||||
code_lines.append(f' )')
|
code_lines.append(f' )')
|
||||||
|
|
@ -454,7 +454,7 @@ class OntologyGenerator:
|
||||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||||
desc = edge.get("description", f"A {name} relationship.")
|
desc = edge.get("description", f"A {name} relationship.")
|
||||||
|
|
||||||
code_lines.append(f'class {class_name}(EdgeModel):')
|
code_lines.append(f'class {class_name}(BaseModel):')
|
||||||
code_lines.append(f' """{desc}"""')
|
code_lines.append(f' """{desc}"""')
|
||||||
|
|
||||||
attrs = edge.get("attributes", [])
|
attrs = edge.get("attributes", [])
|
||||||
|
|
@ -462,7 +462,7 @@ class OntologyGenerator:
|
||||||
for attr in attrs:
|
for attr in attrs:
|
||||||
attr_name = attr["name"]
|
attr_name = attr["name"]
|
||||||
attr_desc = attr.get("description", attr_name)
|
attr_desc = attr.get("description", attr_name)
|
||||||
code_lines.append(f' {attr_name}: EntityText = Field(')
|
code_lines.append(f' {attr_name}: Optional[str] = Field(')
|
||||||
code_lines.append(f' description="{attr_desc}",')
|
code_lines.append(f' description="{attr_desc}",')
|
||||||
code_lines.append(f' default=None')
|
code_lines.append(f' default=None')
|
||||||
code_lines.append(f' )')
|
code_lines.append(f' )')
|
||||||
|
|
|
||||||
|
|
@ -1,23 +1,17 @@
|
||||||
"""
|
"""
|
||||||
Zep实体读取与过滤服务
|
实体读取与过滤服务
|
||||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
从本地JSON图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
from typing import Dict, Any, List, Optional, Set
|
||||||
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from zep_cloud.client import Zep
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.local_graph_store import LocalGraphStore
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_entity_reader')
|
logger = get_logger('mirofish.zep_entity_reader')
|
||||||
|
|
||||||
# 用于泛型返回类型
|
|
||||||
T = TypeVar('T')
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityNode:
|
class EntityNode:
|
||||||
|
|
@ -27,9 +21,7 @@ class EntityNode:
|
||||||
labels: List[str]
|
labels: List[str]
|
||||||
summary: str
|
summary: str
|
||||||
attributes: Dict[str, Any]
|
attributes: Dict[str, Any]
|
||||||
# 相关的边信息
|
|
||||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
# 相关的其他节点信息
|
|
||||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|
@ -44,9 +36,9 @@ class EntityNode:
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_entity_type(self) -> Optional[str]:
|
def get_entity_type(self) -> Optional[str]:
|
||||||
"""获取实体类型(排除默认的Entity标签)"""
|
"""获取实体类型(排除默认的Entity/Node标签)"""
|
||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
if label not in ["Entity", "Node"]:
|
if label not in ("Entity", "Node"):
|
||||||
return label
|
return label
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -70,146 +62,39 @@ class FilteredEntities:
|
||||||
|
|
||||||
class ZepEntityReader:
|
class ZepEntityReader:
|
||||||
"""
|
"""
|
||||||
Zep实体读取与过滤服务
|
实体读取与过滤服务
|
||||||
|
|
||||||
主要功能:
|
主要功能:
|
||||||
1. 从Zep图谱读取所有节点
|
1. 从本地图谱读取所有节点
|
||||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||||
3. 获取每个实体的相关边和关联节点信息
|
3. 获取每个实体的相关边和关联节点信息
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||||
if not self.api_key:
|
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||||
raise ValueError("ZEP_API_KEY 未配置")
|
self.store = LocalGraphStore(storage_dir)
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
|
||||||
|
|
||||||
def _call_with_retry(
|
|
||||||
self,
|
|
||||||
func: Callable[[], T],
|
|
||||||
operation_name: str,
|
|
||||||
max_retries: int = 3,
|
|
||||||
initial_delay: float = 2.0
|
|
||||||
) -> T:
|
|
||||||
"""
|
|
||||||
带重试机制的Zep API调用
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: 要执行的函数(无参数的lambda或callable)
|
|
||||||
operation_name: 操作名称,用于日志
|
|
||||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
|
||||||
initial_delay: 初始延迟秒数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
API调用结果
|
|
||||||
"""
|
|
||||||
last_exception = None
|
|
||||||
delay = initial_delay
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.warning(
|
|
||||||
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
|
||||||
f"{delay:.1f}秒后重试..."
|
|
||||||
)
|
|
||||||
time.sleep(delay)
|
|
||||||
delay *= 2 # 指数退避
|
|
||||||
else:
|
|
||||||
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
|
||||||
|
|
||||||
raise last_exception
|
|
||||||
|
|
||||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""获取图谱的所有节点"""
|
||||||
获取图谱的所有节点(分页获取)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
节点列表
|
|
||||||
"""
|
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||||
|
nodes = self.store.get_nodes(graph_id)
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
logger.info(f"共获取 {len(nodes)} 个节点")
|
||||||
|
return nodes
|
||||||
nodes_data = []
|
|
||||||
for node in nodes:
|
|
||||||
nodes_data.append({
|
|
||||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
|
||||||
"name": node.name or "",
|
|
||||||
"labels": node.labels or [],
|
|
||||||
"summary": node.summary or "",
|
|
||||||
"attributes": node.attributes or {},
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
|
||||||
return nodes_data
|
|
||||||
|
|
||||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""获取图谱的所有边"""
|
||||||
获取图谱的所有边(分页获取)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
边列表
|
|
||||||
"""
|
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||||
|
edges = self.store.get_edges(graph_id)
|
||||||
|
logger.info(f"共获取 {len(edges)} 条边")
|
||||||
|
return edges
|
||||||
|
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取指定节点的所有相关边"""
|
||||||
edges_data = []
|
|
||||||
for edge in edges:
|
|
||||||
edges_data.append({
|
|
||||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
|
||||||
"name": edge.name or "",
|
|
||||||
"fact": edge.fact or "",
|
|
||||||
"source_node_uuid": edge.source_node_uuid,
|
|
||||||
"target_node_uuid": edge.target_node_uuid,
|
|
||||||
"attributes": edge.attributes or {},
|
|
||||||
})
|
|
||||||
|
|
||||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
|
||||||
return edges_data
|
|
||||||
|
|
||||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
|
||||||
"""
|
|
||||||
获取指定节点的所有相关边(带重试机制)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
node_uuid: 节点UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
边列表
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 使用重试机制调用Zep API
|
return self.store.get_node_edges(graph_id, node_uuid)
|
||||||
edges = self._call_with_retry(
|
|
||||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
|
||||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
|
||||||
)
|
|
||||||
|
|
||||||
edges_data = []
|
|
||||||
for edge in edges:
|
|
||||||
edges_data.append({
|
|
||||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
|
||||||
"name": edge.name or "",
|
|
||||||
"fact": edge.fact or "",
|
|
||||||
"source_node_uuid": edge.source_node_uuid,
|
|
||||||
"target_node_uuid": edge.target_node_uuid,
|
|
||||||
"attributes": edge.attributes or {},
|
|
||||||
})
|
|
||||||
|
|
||||||
return edges_data
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
logger.warning(f"获取节点 {node_uuid} 的边失败: {e}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def filter_defined_entities(
|
def filter_defined_entities(
|
||||||
|
|
@ -222,106 +107,91 @@ class ZepEntityReader:
|
||||||
筛选出符合预定义实体类型的节点
|
筛选出符合预定义实体类型的节点
|
||||||
|
|
||||||
筛选逻辑:
|
筛选逻辑:
|
||||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
- 节点的Labels包含除"Entity"和"Node"之外的标签 → 符合预定义类型,保留
|
||||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
- 节点的Labels只有"Entity"/"Node" → 不符合,跳过
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: 图谱ID
|
||||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
defined_entity_types: 预定义实体类型列表(可选,若提供则只保留这些类型)
|
||||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
enrich_with_edges: 是否获取每个实体的相关边信息
|
||||||
|
|
||||||
Returns:
|
|
||||||
FilteredEntities: 过滤后的实体集合
|
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
||||||
|
|
||||||
# 获取所有节点
|
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
total_count = len(all_nodes)
|
total_count = len(all_nodes)
|
||||||
|
|
||||||
# 获取所有边(用于后续关联查找)
|
|
||||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||||
|
|
||||||
# 构建节点UUID到节点数据的映射
|
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 筛选符合条件的实体
|
|
||||||
filtered_entities = []
|
filtered_entities = []
|
||||||
entity_types_found = set()
|
entity_types_found: Set[str] = set()
|
||||||
|
|
||||||
for node in all_nodes:
|
for node in all_nodes:
|
||||||
labels = node.get("labels", [])
|
labels = node.get("labels") or []
|
||||||
|
custom_labels = [l for l in labels if l not in ("Entity", "Node")]
|
||||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
|
||||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
|
||||||
|
|
||||||
if not custom_labels:
|
if not custom_labels:
|
||||||
# 只有默认标签,跳过
|
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果指定了预定义类型,检查是否匹配
|
|
||||||
if defined_entity_types:
|
if defined_entity_types:
|
||||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
matching = [l for l in custom_labels if l in defined_entity_types]
|
||||||
if not matching_labels:
|
if not matching:
|
||||||
continue
|
continue
|
||||||
entity_type = matching_labels[0]
|
entity_type = matching[0]
|
||||||
else:
|
else:
|
||||||
entity_type = custom_labels[0]
|
entity_type = custom_labels[0]
|
||||||
|
|
||||||
entity_types_found.add(entity_type)
|
entity_types_found.add(entity_type)
|
||||||
|
|
||||||
# 创建实体节点对象
|
|
||||||
entity = EntityNode(
|
entity = EntityNode(
|
||||||
uuid=node["uuid"],
|
uuid=node["uuid"],
|
||||||
name=node["name"],
|
name=node.get("name", ""),
|
||||||
labels=labels,
|
labels=labels,
|
||||||
summary=node["summary"],
|
summary=node.get("summary", ""),
|
||||||
attributes=node["attributes"],
|
attributes=node.get("attributes", {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取相关边和节点
|
|
||||||
if enrich_with_edges:
|
if enrich_with_edges:
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids: Set[str] = set()
|
||||||
|
|
||||||
for edge in all_edges:
|
for edge in all_edges:
|
||||||
if edge["source_node_uuid"] == node["uuid"]:
|
if edge.get("source_node_uuid") == node["uuid"]:
|
||||||
related_edges.append({
|
related_edges.append({
|
||||||
"direction": "outgoing",
|
"direction": "outgoing",
|
||||||
"edge_name": edge["name"],
|
"edge_name": edge.get("name", ""),
|
||||||
"fact": edge["fact"],
|
"fact": edge.get("fact", ""),
|
||||||
"target_node_uuid": edge["target_node_uuid"],
|
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["target_node_uuid"])
|
related_node_uuids.add(edge.get("target_node_uuid", ""))
|
||||||
elif edge["target_node_uuid"] == node["uuid"]:
|
elif edge.get("target_node_uuid") == node["uuid"]:
|
||||||
related_edges.append({
|
related_edges.append({
|
||||||
"direction": "incoming",
|
"direction": "incoming",
|
||||||
"edge_name": edge["name"],
|
"edge_name": edge.get("name", ""),
|
||||||
"fact": edge["fact"],
|
"fact": edge.get("fact", ""),
|
||||||
"source_node_uuid": edge["source_node_uuid"],
|
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["source_node_uuid"])
|
related_node_uuids.add(edge.get("source_node_uuid", ""))
|
||||||
|
|
||||||
entity.related_edges = related_edges
|
entity.related_edges = related_edges
|
||||||
|
|
||||||
# 获取关联节点的基本信息
|
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid and related_uuid in node_map:
|
||||||
related_node = node_map[related_uuid]
|
rn = node_map[related_uuid]
|
||||||
related_nodes.append({
|
related_nodes.append({
|
||||||
"uuid": related_node["uuid"],
|
"uuid": rn["uuid"],
|
||||||
"name": related_node["name"],
|
"name": rn.get("name", ""),
|
||||||
"labels": related_node["labels"],
|
"labels": rn.get("labels", []),
|
||||||
"summary": related_node.get("summary", ""),
|
"summary": rn.get("summary", ""),
|
||||||
})
|
})
|
||||||
|
|
||||||
entity.related_nodes = related_nodes
|
entity.related_nodes = related_nodes
|
||||||
|
|
||||||
filtered_entities.append(entity)
|
filtered_entities.append(entity)
|
||||||
|
|
||||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
||||||
f"实体类型: {entity_types_found}")
|
f"实体类型: {entity_types_found}")
|
||||||
|
|
||||||
return FilteredEntities(
|
return FilteredEntities(
|
||||||
entities=filtered_entities,
|
entities=filtered_entities,
|
||||||
|
|
@ -335,79 +205,60 @@ class ZepEntityReader:
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
entity_uuid: str
|
entity_uuid: str
|
||||||
) -> Optional[EntityNode]:
|
) -> Optional[EntityNode]:
|
||||||
"""
|
"""获取单个实体及其完整上下文(边和关联节点)"""
|
||||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
entity_uuid: 实体UUID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
EntityNode或None
|
|
||||||
"""
|
|
||||||
try:
|
try:
|
||||||
# 使用重试机制获取节点
|
node = self.store.get_node(graph_id, entity_uuid)
|
||||||
node = self._call_with_retry(
|
|
||||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
|
||||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not node:
|
if not node:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取节点的边
|
edges = self.get_node_edges(graph_id, entity_uuid)
|
||||||
edges = self.get_node_edges(entity_uuid)
|
|
||||||
|
|
||||||
# 获取所有节点用于关联查找
|
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 处理相关边和节点
|
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids: Set[str] = set()
|
||||||
|
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
if edge["source_node_uuid"] == entity_uuid:
|
if edge.get("source_node_uuid") == entity_uuid:
|
||||||
related_edges.append({
|
related_edges.append({
|
||||||
"direction": "outgoing",
|
"direction": "outgoing",
|
||||||
"edge_name": edge["name"],
|
"edge_name": edge.get("name", ""),
|
||||||
"fact": edge["fact"],
|
"fact": edge.get("fact", ""),
|
||||||
"target_node_uuid": edge["target_node_uuid"],
|
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["target_node_uuid"])
|
related_node_uuids.add(edge.get("target_node_uuid", ""))
|
||||||
else:
|
else:
|
||||||
related_edges.append({
|
related_edges.append({
|
||||||
"direction": "incoming",
|
"direction": "incoming",
|
||||||
"edge_name": edge["name"],
|
"edge_name": edge.get("name", ""),
|
||||||
"fact": edge["fact"],
|
"fact": edge.get("fact", ""),
|
||||||
"source_node_uuid": edge["source_node_uuid"],
|
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["source_node_uuid"])
|
related_node_uuids.add(edge.get("source_node_uuid", ""))
|
||||||
|
|
||||||
# 获取关联节点信息
|
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid and related_uuid in node_map:
|
||||||
related_node = node_map[related_uuid]
|
rn = node_map[related_uuid]
|
||||||
related_nodes.append({
|
related_nodes.append({
|
||||||
"uuid": related_node["uuid"],
|
"uuid": rn["uuid"],
|
||||||
"name": related_node["name"],
|
"name": rn.get("name", ""),
|
||||||
"labels": related_node["labels"],
|
"labels": rn.get("labels", []),
|
||||||
"summary": related_node.get("summary", ""),
|
"summary": rn.get("summary", ""),
|
||||||
})
|
})
|
||||||
|
|
||||||
return EntityNode(
|
return EntityNode(
|
||||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
uuid=node["uuid"],
|
||||||
name=node.name or "",
|
name=node.get("name", ""),
|
||||||
labels=node.labels or [],
|
labels=node.get("labels", []),
|
||||||
summary=node.summary or "",
|
summary=node.get("summary", ""),
|
||||||
attributes=node.attributes or {},
|
attributes=node.get("attributes", {}),
|
||||||
related_edges=related_edges,
|
related_edges=related_edges,
|
||||||
related_nodes=related_nodes,
|
related_nodes=related_nodes,
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
logger.error(f"获取实体 {entity_uuid} 失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_entities_by_type(
|
def get_entities_by_type(
|
||||||
|
|
@ -416,22 +267,10 @@ class ZepEntityReader:
|
||||||
entity_type: str,
|
entity_type: str,
|
||||||
enrich_with_edges: bool = True
|
enrich_with_edges: bool = True
|
||||||
) -> List[EntityNode]:
|
) -> List[EntityNode]:
|
||||||
"""
|
"""获取指定类型的所有实体"""
|
||||||
获取指定类型的所有实体
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
|
||||||
enrich_with_edges: 是否获取相关边信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
实体列表
|
|
||||||
"""
|
|
||||||
result = self.filter_defined_entities(
|
result = self.filter_defined_entities(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
defined_entity_types=[entity_type],
|
defined_entity_types=[entity_type],
|
||||||
enrich_with_edges=enrich_with_edges
|
enrich_with_edges=enrich_with_edges
|
||||||
)
|
)
|
||||||
return result.entities
|
return result.entities
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Zep图谱记忆更新服务
|
图谱记忆更新服务
|
||||||
将模拟中的Agent活动动态更新到Zep图谱中
|
将模拟中的Agent活动动态写入本地JSON图谱文件
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -12,9 +12,8 @@ from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from queue import Queue, Empty
|
from queue import Queue, Empty
|
||||||
|
|
||||||
from zep_cloud.client import Zep
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.local_graph_store import LocalGraphStore
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from ..utils.locale import get_locale, set_locale
|
from ..utils.locale import get_locale, set_locale
|
||||||
|
|
||||||
|
|
@ -34,12 +33,11 @@ class AgentActivity:
|
||||||
|
|
||||||
def to_episode_text(self) -> str:
|
def to_episode_text(self) -> str:
|
||||||
"""
|
"""
|
||||||
将活动转换为可以发送给Zep的文本描述
|
将活动转换为自然语言描述文本
|
||||||
|
|
||||||
采用自然语言描述格式,让Zep能够从中提取实体和关系
|
采用自然语言描述格式,让图谱能够从中提取实体和关系
|
||||||
不添加模拟相关的前缀,避免误导图谱更新
|
不添加模拟相关的前缀,避免误导图谱更新
|
||||||
"""
|
"""
|
||||||
# 根据不同的动作类型生成不同的描述
|
|
||||||
action_descriptions = {
|
action_descriptions = {
|
||||||
"CREATE_POST": self._describe_create_post,
|
"CREATE_POST": self._describe_create_post,
|
||||||
"LIKE_POST": self._describe_like_post,
|
"LIKE_POST": self._describe_like_post,
|
||||||
|
|
@ -58,7 +56,6 @@ class AgentActivity:
|
||||||
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
||||||
description = describe_func()
|
description = describe_func()
|
||||||
|
|
||||||
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀
|
|
||||||
return f"{self.agent_name}: {description}"
|
return f"{self.agent_name}: {description}"
|
||||||
|
|
||||||
def _describe_create_post(self) -> str:
|
def _describe_create_post(self) -> str:
|
||||||
|
|
@ -68,7 +65,6 @@ class AgentActivity:
|
||||||
return "发布了一条帖子"
|
return "发布了一条帖子"
|
||||||
|
|
||||||
def _describe_like_post(self) -> str:
|
def _describe_like_post(self) -> str:
|
||||||
"""点赞帖子 - 包含帖子原文和作者信息"""
|
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
||||||
|
|
@ -81,7 +77,6 @@ class AgentActivity:
|
||||||
return "点赞了一条帖子"
|
return "点赞了一条帖子"
|
||||||
|
|
||||||
def _describe_dislike_post(self) -> str:
|
def _describe_dislike_post(self) -> str:
|
||||||
"""踩帖子 - 包含帖子原文和作者信息"""
|
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
||||||
|
|
@ -94,7 +89,6 @@ class AgentActivity:
|
||||||
return "踩了一条帖子"
|
return "踩了一条帖子"
|
||||||
|
|
||||||
def _describe_repost(self) -> str:
|
def _describe_repost(self) -> str:
|
||||||
"""转发帖子 - 包含原帖内容和作者信息"""
|
|
||||||
original_content = self.action_args.get("original_content", "")
|
original_content = self.action_args.get("original_content", "")
|
||||||
original_author = self.action_args.get("original_author_name", "")
|
original_author = self.action_args.get("original_author_name", "")
|
||||||
|
|
||||||
|
|
@ -107,7 +101,6 @@ class AgentActivity:
|
||||||
return "转发了一条帖子"
|
return "转发了一条帖子"
|
||||||
|
|
||||||
def _describe_quote_post(self) -> str:
|
def _describe_quote_post(self) -> str:
|
||||||
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
|
|
||||||
original_content = self.action_args.get("original_content", "")
|
original_content = self.action_args.get("original_content", "")
|
||||||
original_author = self.action_args.get("original_author_name", "")
|
original_author = self.action_args.get("original_author_name", "")
|
||||||
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
||||||
|
|
@ -127,15 +120,12 @@ class AgentActivity:
|
||||||
return base
|
return base
|
||||||
|
|
||||||
def _describe_follow(self) -> str:
|
def _describe_follow(self) -> str:
|
||||||
"""关注用户 - 包含被关注用户的名称"""
|
|
||||||
target_user_name = self.action_args.get("target_user_name", "")
|
target_user_name = self.action_args.get("target_user_name", "")
|
||||||
|
|
||||||
if target_user_name:
|
if target_user_name:
|
||||||
return f"关注了用户「{target_user_name}」"
|
return f"关注了用户「{target_user_name}」"
|
||||||
return "关注了一个用户"
|
return "关注了一个用户"
|
||||||
|
|
||||||
def _describe_create_comment(self) -> str:
|
def _describe_create_comment(self) -> str:
|
||||||
"""发表评论 - 包含评论内容和所评论的帖子信息"""
|
|
||||||
content = self.action_args.get("content", "")
|
content = self.action_args.get("content", "")
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
@ -151,7 +141,6 @@ class AgentActivity:
|
||||||
return "发表了评论"
|
return "发表了评论"
|
||||||
|
|
||||||
def _describe_like_comment(self) -> str:
|
def _describe_like_comment(self) -> str:
|
||||||
"""点赞评论 - 包含评论内容和作者信息"""
|
|
||||||
comment_content = self.action_args.get("comment_content", "")
|
comment_content = self.action_args.get("comment_content", "")
|
||||||
comment_author = self.action_args.get("comment_author_name", "")
|
comment_author = self.action_args.get("comment_author_name", "")
|
||||||
|
|
||||||
|
|
@ -164,7 +153,6 @@ class AgentActivity:
|
||||||
return "点赞了一条评论"
|
return "点赞了一条评论"
|
||||||
|
|
||||||
def _describe_dislike_comment(self) -> str:
|
def _describe_dislike_comment(self) -> str:
|
||||||
"""踩评论 - 包含评论内容和作者信息"""
|
|
||||||
comment_content = self.action_args.get("comment_content", "")
|
comment_content = self.action_args.get("comment_content", "")
|
||||||
comment_author = self.action_args.get("comment_author_name", "")
|
comment_author = self.action_args.get("comment_author_name", "")
|
||||||
|
|
||||||
|
|
@ -177,99 +165,74 @@ class AgentActivity:
|
||||||
return "踩了一条评论"
|
return "踩了一条评论"
|
||||||
|
|
||||||
def _describe_search(self) -> str:
|
def _describe_search(self) -> str:
|
||||||
"""搜索帖子 - 包含搜索关键词"""
|
|
||||||
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
||||||
return f"搜索了「{query}」" if query else "进行了搜索"
|
return f"搜索了「{query}」" if query else "进行了搜索"
|
||||||
|
|
||||||
def _describe_search_user(self) -> str:
|
def _describe_search_user(self) -> str:
|
||||||
"""搜索用户 - 包含搜索关键词"""
|
|
||||||
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
||||||
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
||||||
|
|
||||||
def _describe_mute(self) -> str:
|
def _describe_mute(self) -> str:
|
||||||
"""屏蔽用户 - 包含被屏蔽用户的名称"""
|
|
||||||
target_user_name = self.action_args.get("target_user_name", "")
|
target_user_name = self.action_args.get("target_user_name", "")
|
||||||
|
|
||||||
if target_user_name:
|
if target_user_name:
|
||||||
return f"屏蔽了用户「{target_user_name}」"
|
return f"屏蔽了用户「{target_user_name}」"
|
||||||
return "屏蔽了一个用户"
|
return "屏蔽了一个用户"
|
||||||
|
|
||||||
def _describe_generic(self) -> str:
|
def _describe_generic(self) -> str:
|
||||||
# 对于未知的动作类型,生成通用描述
|
|
||||||
return f"执行了{self.action_type}操作"
|
return f"执行了{self.action_type}操作"
|
||||||
|
|
||||||
|
|
||||||
class ZepGraphMemoryUpdater:
|
class GraphMemoryUpdater:
|
||||||
"""
|
"""
|
||||||
Zep图谱记忆更新器
|
图谱记忆更新器
|
||||||
|
|
||||||
监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。
|
监控模拟的actions日志文件,将新的agent活动实时写入本地图谱。
|
||||||
按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。
|
按平台分组,每累积BATCH_SIZE条活动后批量写入。
|
||||||
|
|
||||||
所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息:
|
|
||||||
- 点赞/踩的帖子原文
|
|
||||||
- 转发/引用的帖子原文
|
|
||||||
- 关注/屏蔽的用户名
|
|
||||||
- 点赞/踩的评论原文
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 批量发送大小(每个平台累积多少条后发送)
|
|
||||||
BATCH_SIZE = 5
|
BATCH_SIZE = 5
|
||||||
|
|
||||||
# 平台名称映射(用于控制台显示)
|
|
||||||
PLATFORM_DISPLAY_NAMES = {
|
PLATFORM_DISPLAY_NAMES = {
|
||||||
'twitter': '世界1',
|
'twitter': '世界1',
|
||||||
'reddit': '世界2',
|
'reddit': '世界2',
|
||||||
}
|
}
|
||||||
|
|
||||||
# 发送间隔(秒),避免请求过快
|
SEND_INTERVAL = 0.1 # 本地写入更快,间隔可以更短
|
||||||
SEND_INTERVAL = 0.5
|
|
||||||
|
|
||||||
# 重试配置
|
|
||||||
MAX_RETRIES = 3
|
MAX_RETRIES = 3
|
||||||
RETRY_DELAY = 2 # 秒
|
RETRY_DELAY = 1
|
||||||
|
|
||||||
def __init__(self, graph_id: str, api_key: Optional[str] = None):
|
def __init__(self, graph_id: str, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
初始化更新器
|
初始化更新器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: Zep图谱ID
|
graph_id: 本地图谱ID
|
||||||
api_key: Zep API Key(可选,默认从配置读取)
|
storage_dir: 图谱存储目录(可选,默认从配置读取)
|
||||||
|
api_key: 已废弃,保留以兼容旧调用代码
|
||||||
"""
|
"""
|
||||||
self.graph_id = graph_id
|
self.graph_id = graph_id
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||||
|
self.store = LocalGraphStore(storage_dir)
|
||||||
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("ZEP_API_KEY未配置")
|
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
|
||||||
|
|
||||||
# 活动队列
|
|
||||||
self._activity_queue: Queue = Queue()
|
self._activity_queue: Queue = Queue()
|
||||||
|
|
||||||
# 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送)
|
|
||||||
self._platform_buffers: Dict[str, List[AgentActivity]] = {
|
self._platform_buffers: Dict[str, List[AgentActivity]] = {
|
||||||
'twitter': [],
|
'twitter': [],
|
||||||
'reddit': [],
|
'reddit': [],
|
||||||
}
|
}
|
||||||
self._buffer_lock = threading.Lock()
|
self._buffer_lock = threading.Lock()
|
||||||
|
|
||||||
# 控制标志
|
|
||||||
self._running = False
|
self._running = False
|
||||||
self._worker_thread: Optional[threading.Thread] = None
|
self._worker_thread: Optional[threading.Thread] = None
|
||||||
|
|
||||||
# 统计
|
self._total_activities = 0
|
||||||
self._total_activities = 0 # 实际添加到队列的活动数
|
self._total_sent = 0
|
||||||
self._total_sent = 0 # 成功发送到Zep的批次数
|
self._total_items_sent = 0
|
||||||
self._total_items_sent = 0 # 成功发送到Zep的活动条数
|
self._failed_count = 0
|
||||||
self._failed_count = 0 # 发送失败的批次数
|
self._skipped_count = 0
|
||||||
self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING)
|
|
||||||
|
|
||||||
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
logger.info(f"GraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||||
|
|
||||||
def _get_platform_display_name(self, platform: str) -> str:
|
def _get_platform_display_name(self, platform: str) -> str:
|
||||||
"""获取平台的显示名称"""
|
|
||||||
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
|
|
@ -277,7 +240,6 @@ class ZepGraphMemoryUpdater:
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Capture locale before spawning background thread
|
|
||||||
current_locale = get_locale()
|
current_locale = get_locale()
|
||||||
|
|
||||||
self._running = True
|
self._running = True
|
||||||
|
|
@ -285,67 +247,39 @@ class ZepGraphMemoryUpdater:
|
||||||
target=self._worker_loop,
|
target=self._worker_loop,
|
||||||
args=(current_locale,),
|
args=(current_locale,),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
|
name=f"GraphMemoryUpdater-{self.graph_id[:8]}"
|
||||||
)
|
)
|
||||||
self._worker_thread.start()
|
self._worker_thread.start()
|
||||||
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
logger.info(f"GraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""停止后台工作线程"""
|
"""停止后台工作线程"""
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# 发送剩余的活动
|
|
||||||
self._flush_remaining()
|
self._flush_remaining()
|
||||||
|
|
||||||
if self._worker_thread and self._worker_thread.is_alive():
|
if self._worker_thread and self._worker_thread.is_alive():
|
||||||
self._worker_thread.join(timeout=10)
|
self._worker_thread.join(timeout=10)
|
||||||
|
|
||||||
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
logger.info(f"GraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
||||||
f"total_activities={self._total_activities}, "
|
f"total_activities={self._total_activities}, "
|
||||||
f"batches_sent={self._total_sent}, "
|
f"batches_sent={self._total_sent}, "
|
||||||
f"items_sent={self._total_items_sent}, "
|
f"items_sent={self._total_items_sent}, "
|
||||||
f"failed={self._failed_count}, "
|
f"failed={self._failed_count}, "
|
||||||
f"skipped={self._skipped_count}")
|
f"skipped={self._skipped_count}")
|
||||||
|
|
||||||
def add_activity(self, activity: AgentActivity):
|
def add_activity(self, activity: AgentActivity):
|
||||||
"""
|
"""添加一个agent活动到队列"""
|
||||||
添加一个agent活动到队列
|
|
||||||
|
|
||||||
所有有意义的行为都会被添加到队列,包括:
|
|
||||||
- CREATE_POST(发帖)
|
|
||||||
- CREATE_COMMENT(评论)
|
|
||||||
- QUOTE_POST(引用帖子)
|
|
||||||
- SEARCH_POSTS(搜索帖子)
|
|
||||||
- SEARCH_USER(搜索用户)
|
|
||||||
- LIKE_POST/DISLIKE_POST(点赞/踩帖子)
|
|
||||||
- REPOST(转发)
|
|
||||||
- FOLLOW(关注)
|
|
||||||
- MUTE(屏蔽)
|
|
||||||
- LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论)
|
|
||||||
|
|
||||||
action_args中会包含完整的上下文信息(如帖子原文、用户名等)。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activity: Agent活动记录
|
|
||||||
"""
|
|
||||||
# 跳过DO_NOTHING类型的活动
|
|
||||||
if activity.action_type == "DO_NOTHING":
|
if activity.action_type == "DO_NOTHING":
|
||||||
self._skipped_count += 1
|
self._skipped_count += 1
|
||||||
return
|
return
|
||||||
|
|
||||||
self._activity_queue.put(activity)
|
self._activity_queue.put(activity)
|
||||||
self._total_activities += 1
|
self._total_activities += 1
|
||||||
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")
|
logger.debug(f"添加活动到队列: {activity.agent_name} - {activity.action_type}")
|
||||||
|
|
||||||
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
||||||
"""
|
"""从字典数据添加活动"""
|
||||||
从字典数据添加活动
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 从actions.jsonl解析的字典数据
|
|
||||||
platform: 平台名称 (twitter/reddit)
|
|
||||||
"""
|
|
||||||
# 跳过事件类型的条目
|
|
||||||
if "event_type" in data:
|
if "event_type" in data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -362,28 +296,23 @@ class ZepGraphMemoryUpdater:
|
||||||
self.add_activity(activity)
|
self.add_activity(activity)
|
||||||
|
|
||||||
def _worker_loop(self, locale: str = 'zh'):
|
def _worker_loop(self, locale: str = 'zh'):
|
||||||
"""后台工作循环 - 按平台批量发送活动到Zep"""
|
"""后台工作循环 - 按平台批量写入活动"""
|
||||||
set_locale(locale)
|
set_locale(locale)
|
||||||
while self._running or not self._activity_queue.empty():
|
while self._running or not self._activity_queue.empty():
|
||||||
try:
|
try:
|
||||||
# 尝试从队列获取活动(超时1秒)
|
|
||||||
try:
|
try:
|
||||||
activity = self._activity_queue.get(timeout=1)
|
activity = self._activity_queue.get(timeout=1)
|
||||||
|
|
||||||
# 将活动添加到对应平台的缓冲区
|
|
||||||
platform = activity.platform.lower()
|
platform = activity.platform.lower()
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
if platform not in self._platform_buffers:
|
if platform not in self._platform_buffers:
|
||||||
self._platform_buffers[platform] = []
|
self._platform_buffers[platform] = []
|
||||||
self._platform_buffers[platform].append(activity)
|
self._platform_buffers[platform].append(activity)
|
||||||
|
|
||||||
# 检查该平台是否达到批量大小
|
|
||||||
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
|
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
|
||||||
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
|
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
|
||||||
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
|
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
|
||||||
# 释放锁后再发送
|
self._write_batch_activities(batch, platform)
|
||||||
self._send_batch_activities(batch, platform)
|
|
||||||
# 发送间隔,避免请求过快
|
|
||||||
time.sleep(self.SEND_INTERVAL)
|
time.sleep(self.SEND_INTERVAL)
|
||||||
|
|
||||||
except Empty:
|
except Empty:
|
||||||
|
|
@ -393,48 +322,64 @@ class ZepGraphMemoryUpdater:
|
||||||
logger.error(f"工作循环异常: {e}")
|
logger.error(f"工作循环异常: {e}")
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
def _write_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||||
"""
|
"""批量将活动写入本地图谱"""
|
||||||
批量发送活动到Zep图谱(合并为一条文本)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
activities: Agent活动列表
|
|
||||||
platform: 平台名称
|
|
||||||
"""
|
|
||||||
if not activities:
|
if not activities:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 将多条活动合并为一条文本,用换行分隔
|
|
||||||
episode_texts = [activity.to_episode_text() for activity in activities]
|
episode_texts = [activity.to_episode_text() for activity in activities]
|
||||||
combined_text = "\n".join(episode_texts)
|
combined_text = "\n".join(episode_texts)
|
||||||
|
|
||||||
# 带重试的发送
|
|
||||||
for attempt in range(self.MAX_RETRIES):
|
for attempt in range(self.MAX_RETRIES):
|
||||||
try:
|
try:
|
||||||
self.client.graph.add(
|
# 写入情节文本
|
||||||
graph_id=self.graph_id,
|
self.store.add_episode(self.graph_id, combined_text)
|
||||||
type="text",
|
|
||||||
data=combined_text
|
# 为每条活动创建可搜索的事实边
|
||||||
)
|
for activity in activities:
|
||||||
|
self._create_activity_edge(activity)
|
||||||
|
|
||||||
self._total_sent += 1
|
self._total_sent += 1
|
||||||
self._total_items_sent += len(activities)
|
self._total_items_sent += len(activities)
|
||||||
display_name = self._get_platform_display_name(platform)
|
display_name = self._get_platform_display_name(platform)
|
||||||
logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
logger.info(f"成功写入 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
||||||
logger.debug(f"批量内容预览: {combined_text[:200]}...")
|
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt < self.MAX_RETRIES - 1:
|
if attempt < self.MAX_RETRIES - 1:
|
||||||
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
logger.warning(f"写入活动失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||||
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
||||||
else:
|
else:
|
||||||
logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
logger.error(f"写入活动失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||||
self._failed_count += 1
|
self._failed_count += 1
|
||||||
|
|
||||||
|
def _create_activity_edge(self, activity: AgentActivity):
|
||||||
|
"""为单条活动在图谱中创建可搜索的事实边"""
|
||||||
|
fact = activity.to_episode_text()
|
||||||
|
|
||||||
|
# 创建或获取Agent节点
|
||||||
|
agent_uuid = self.store.upsert_node(
|
||||||
|
graph_id=self.graph_id,
|
||||||
|
name=activity.agent_name,
|
||||||
|
labels=["Agent", "Entity"],
|
||||||
|
summary=f"Agent {activity.agent_name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 为活动创建自环边(以便关键词搜索可以找到它)
|
||||||
|
self.store.add_edge(self.graph_id, {
|
||||||
|
"name": activity.action_type,
|
||||||
|
"fact": fact,
|
||||||
|
"source_node_uuid": agent_uuid,
|
||||||
|
"target_node_uuid": agent_uuid,
|
||||||
|
"attributes": {
|
||||||
|
"platform": activity.platform,
|
||||||
|
"round": activity.round_num,
|
||||||
|
"timestamp": activity.timestamp,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
def _flush_remaining(self):
|
def _flush_remaining(self):
|
||||||
"""发送队列和缓冲区中剩余的活动"""
|
"""发送队列和缓冲区中剩余的活动"""
|
||||||
# 首先处理队列中剩余的活动,添加到缓冲区
|
|
||||||
while not self._activity_queue.empty():
|
while not self._activity_queue.empty():
|
||||||
try:
|
try:
|
||||||
activity = self._activity_queue.get_nowait()
|
activity = self._activity_queue.get_nowait()
|
||||||
|
|
@ -446,14 +391,12 @@ class ZepGraphMemoryUpdater:
|
||||||
except Empty:
|
except Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条)
|
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
for platform, buffer in self._platform_buffers.items():
|
for platform, buffer in self._platform_buffers.items():
|
||||||
if buffer:
|
if buffer:
|
||||||
display_name = self._get_platform_display_name(platform)
|
display_name = self._get_platform_display_name(platform)
|
||||||
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
|
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
|
||||||
self._send_batch_activities(buffer, platform)
|
self._write_batch_activities(buffer, platform)
|
||||||
# 清空所有缓冲区
|
|
||||||
for platform in self._platform_buffers:
|
for platform in self._platform_buffers:
|
||||||
self._platform_buffers[platform] = []
|
self._platform_buffers[platform] = []
|
||||||
|
|
||||||
|
|
@ -465,45 +408,39 @@ class ZepGraphMemoryUpdater:
|
||||||
return {
|
return {
|
||||||
"graph_id": self.graph_id,
|
"graph_id": self.graph_id,
|
||||||
"batch_size": self.BATCH_SIZE,
|
"batch_size": self.BATCH_SIZE,
|
||||||
"total_activities": self._total_activities, # 添加到队列的活动总数
|
"total_activities": self._total_activities,
|
||||||
"batches_sent": self._total_sent, # 成功发送的批次数
|
"batches_sent": self._total_sent,
|
||||||
"items_sent": self._total_items_sent, # 成功发送的活动条数
|
"items_sent": self._total_items_sent,
|
||||||
"failed_count": self._failed_count, # 发送失败的批次数
|
"failed_count": self._failed_count,
|
||||||
"skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING)
|
"skipped_count": self._skipped_count,
|
||||||
"queue_size": self._activity_queue.qsize(),
|
"queue_size": self._activity_queue.qsize(),
|
||||||
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小
|
"buffer_sizes": buffer_sizes,
|
||||||
"running": self._running,
|
"running": self._running,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 向后兼容别名
|
||||||
|
ZepGraphMemoryUpdater = GraphMemoryUpdater
|
||||||
|
|
||||||
|
|
||||||
class ZepGraphMemoryManager:
|
class ZepGraphMemoryManager:
|
||||||
"""
|
"""
|
||||||
管理多个模拟的Zep图谱记忆更新器
|
管理多个模拟的图谱记忆更新器
|
||||||
|
|
||||||
每个模拟可以有自己的更新器实例
|
每个模拟可以有自己的更新器实例
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
|
_updaters: Dict[str, GraphMemoryUpdater] = {}
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
|
def create_updater(cls, simulation_id: str, graph_id: str) -> GraphMemoryUpdater:
|
||||||
"""
|
"""为模拟创建图谱记忆更新器"""
|
||||||
为模拟创建图谱记忆更新器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
simulation_id: 模拟ID
|
|
||||||
graph_id: Zep图谱ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
ZepGraphMemoryUpdater实例
|
|
||||||
"""
|
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
# 如果已存在,先停止旧的
|
|
||||||
if simulation_id in cls._updaters:
|
if simulation_id in cls._updaters:
|
||||||
cls._updaters[simulation_id].stop()
|
cls._updaters[simulation_id].stop()
|
||||||
|
|
||||||
updater = ZepGraphMemoryUpdater(graph_id)
|
updater = GraphMemoryUpdater(graph_id)
|
||||||
updater.start()
|
updater.start()
|
||||||
cls._updaters[simulation_id] = updater
|
cls._updaters[simulation_id] = updater
|
||||||
|
|
||||||
|
|
@ -511,26 +448,21 @@ class ZepGraphMemoryManager:
|
||||||
return updater
|
return updater
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
|
def get_updater(cls, simulation_id: str) -> Optional[GraphMemoryUpdater]:
|
||||||
"""获取模拟的更新器"""
|
|
||||||
return cls._updaters.get(simulation_id)
|
return cls._updaters.get(simulation_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def stop_updater(cls, simulation_id: str):
|
def stop_updater(cls, simulation_id: str):
|
||||||
"""停止并移除模拟的更新器"""
|
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if simulation_id in cls._updaters:
|
if simulation_id in cls._updaters:
|
||||||
cls._updaters[simulation_id].stop()
|
cls._updaters[simulation_id].stop()
|
||||||
del cls._updaters[simulation_id]
|
del cls._updaters[simulation_id]
|
||||||
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
|
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
|
||||||
|
|
||||||
# 防止 stop_all 重复调用的标志
|
|
||||||
_stop_all_done = False
|
_stop_all_done = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def stop_all(cls):
|
def stop_all(cls):
|
||||||
"""停止所有更新器"""
|
|
||||||
# 防止重复调用
|
|
||||||
if cls._stop_all_done:
|
if cls._stop_all_done:
|
||||||
return
|
return
|
||||||
cls._stop_all_done = True
|
cls._stop_all_done = True
|
||||||
|
|
@ -547,7 +479,6 @@ class ZepGraphMemoryManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
||||||
"""获取所有更新器的统计信息"""
|
|
||||||
return {
|
return {
|
||||||
sim_id: updater.get_stats()
|
sim_id: updater.get_stats()
|
||||||
for sim_id, updater in cls._updaters.items()
|
for sim_id, updater in cls._updaters.items()
|
||||||
|
|
|
||||||
|
|
@ -13,13 +13,11 @@ import json
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
from zep_cloud.client import Zep
|
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
from ..utils.local_graph_store import LocalGraphStore
|
||||||
from ..utils.logger import get_logger
|
from ..utils.logger import get_logger
|
||||||
from ..utils.llm_client import LLMClient
|
from ..utils.llm_client import LLMClient
|
||||||
from ..utils.locale import get_locale, t
|
from ..utils.locale import get_locale, t
|
||||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_tools')
|
logger = get_logger('mirofish.zep_tools')
|
||||||
|
|
||||||
|
|
@ -418,17 +416,11 @@ class ZepToolsService:
|
||||||
- get_entity_summary - 获取实体的关系摘要
|
- get_entity_summary - 获取实体的关系摘要
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 重试配置
|
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None,
|
||||||
MAX_RETRIES = 3
|
llm_client: Optional[LLMClient] = None):
|
||||||
RETRY_DELAY = 2.0
|
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||||
|
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||||
def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None):
|
self.store = LocalGraphStore(storage_dir)
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
|
||||||
if not self.api_key:
|
|
||||||
raise ValueError("ZEP_API_KEY 未配置")
|
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
|
||||||
# LLM客户端用于InsightForge生成子问题
|
|
||||||
self._llm_client = llm_client
|
self._llm_client = llm_client
|
||||||
logger.info(t("console.zepToolsInitialized"))
|
logger.info(t("console.zepToolsInitialized"))
|
||||||
|
|
||||||
|
|
@ -439,28 +431,6 @@ class ZepToolsService:
|
||||||
self._llm_client = LLMClient()
|
self._llm_client = LLMClient()
|
||||||
return self._llm_client
|
return self._llm_client
|
||||||
|
|
||||||
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
|
|
||||||
"""带重试机制的API调用"""
|
|
||||||
max_retries = max_retries or self.MAX_RETRIES
|
|
||||||
last_exception = None
|
|
||||||
delay = self.RETRY_DELAY
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return func()
|
|
||||||
except Exception as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.warning(
|
|
||||||
t("console.zepRetryAttempt", operation=operation_name, attempt=attempt + 1, error=str(e)[:100], delay=f"{delay:.1f}")
|
|
||||||
)
|
|
||||||
time.sleep(delay)
|
|
||||||
delay *= 2
|
|
||||||
else:
|
|
||||||
logger.error(t("console.zepAllRetriesFailed", operation=operation_name, retries=max_retries, error=str(e)))
|
|
||||||
|
|
||||||
raise last_exception
|
|
||||||
|
|
||||||
def search_graph(
|
def search_graph(
|
||||||
self,
|
self,
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
|
|
@ -469,13 +439,10 @@ class ZepToolsService:
|
||||||
scope: str = "edges"
|
scope: str = "edges"
|
||||||
) -> SearchResult:
|
) -> SearchResult:
|
||||||
"""
|
"""
|
||||||
图谱语义搜索
|
图谱关键词搜索
|
||||||
|
|
||||||
使用混合搜索(语义+BM25)在图谱中搜索相关信息。
|
|
||||||
如果Zep Cloud的search API不可用,则降级为本地关键词匹配。
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID (Standalone Graph)
|
graph_id: 图谱ID
|
||||||
query: 搜索查询
|
query: 搜索查询
|
||||||
limit: 返回结果数量
|
limit: 返回结果数量
|
||||||
scope: 搜索范围,"edges" 或 "nodes"
|
scope: 搜索范围,"edges" 或 "nodes"
|
||||||
|
|
@ -484,64 +451,7 @@ class ZepToolsService:
|
||||||
SearchResult: 搜索结果
|
SearchResult: 搜索结果
|
||||||
"""
|
"""
|
||||||
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
|
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
|
||||||
|
return self._local_search(graph_id, query, limit, scope)
|
||||||
# 尝试使用Zep Cloud Search API
|
|
||||||
try:
|
|
||||||
search_results = self._call_with_retry(
|
|
||||||
func=lambda: self.client.graph.search(
|
|
||||||
graph_id=graph_id,
|
|
||||||
query=query,
|
|
||||||
limit=limit,
|
|
||||||
scope=scope,
|
|
||||||
reranker="cross_encoder"
|
|
||||||
),
|
|
||||||
operation_name=t("console.graphSearchOp", graphId=graph_id)
|
|
||||||
)
|
|
||||||
|
|
||||||
facts = []
|
|
||||||
edges = []
|
|
||||||
nodes = []
|
|
||||||
|
|
||||||
# 解析边搜索结果
|
|
||||||
if hasattr(search_results, 'edges') and search_results.edges:
|
|
||||||
for edge in search_results.edges:
|
|
||||||
if hasattr(edge, 'fact') and edge.fact:
|
|
||||||
facts.append(edge.fact)
|
|
||||||
edges.append({
|
|
||||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
|
||||||
"name": getattr(edge, 'name', ''),
|
|
||||||
"fact": getattr(edge, 'fact', ''),
|
|
||||||
"source_node_uuid": getattr(edge, 'source_node_uuid', ''),
|
|
||||||
"target_node_uuid": getattr(edge, 'target_node_uuid', ''),
|
|
||||||
})
|
|
||||||
|
|
||||||
# 解析节点搜索结果
|
|
||||||
if hasattr(search_results, 'nodes') and search_results.nodes:
|
|
||||||
for node in search_results.nodes:
|
|
||||||
nodes.append({
|
|
||||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
|
||||||
"name": getattr(node, 'name', ''),
|
|
||||||
"labels": getattr(node, 'labels', []),
|
|
||||||
"summary": getattr(node, 'summary', ''),
|
|
||||||
})
|
|
||||||
# 节点摘要也算作事实
|
|
||||||
if hasattr(node, 'summary') and node.summary:
|
|
||||||
facts.append(f"[{node.name}]: {node.summary}")
|
|
||||||
|
|
||||||
logger.info(t("console.searchComplete", count=len(facts)))
|
|
||||||
|
|
||||||
return SearchResult(
|
|
||||||
facts=facts,
|
|
||||||
edges=edges,
|
|
||||||
nodes=nodes,
|
|
||||||
query=query,
|
|
||||||
total_count=len(facts)
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(t("console.zepSearchApiFallback", error=str(e)))
|
|
||||||
# 降级:使用本地关键词匹配搜索
|
|
||||||
return self._local_search(graph_id, query, limit, scope)
|
|
||||||
|
|
||||||
def _local_search(
|
def _local_search(
|
||||||
self,
|
self,
|
||||||
|
|
@ -550,94 +460,20 @@ class ZepToolsService:
|
||||||
limit: int = 10,
|
limit: int = 10,
|
||||||
scope: str = "edges"
|
scope: str = "edges"
|
||||||
) -> SearchResult:
|
) -> SearchResult:
|
||||||
"""
|
"""本地关键词匹配搜索"""
|
||||||
本地关键词匹配搜索(作为Zep Search API的降级方案)
|
|
||||||
|
|
||||||
获取所有边/节点,然后在本地进行关键词匹配
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
query: 搜索查询
|
|
||||||
limit: 返回结果数量
|
|
||||||
scope: 搜索范围
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
SearchResult: 搜索结果
|
|
||||||
"""
|
|
||||||
logger.info(t("console.usingLocalSearch", query=query[:30]))
|
logger.info(t("console.usingLocalSearch", query=query[:30]))
|
||||||
|
|
||||||
facts = []
|
|
||||||
edges_result = []
|
|
||||||
nodes_result = []
|
|
||||||
|
|
||||||
# 提取查询关键词(简单分词)
|
|
||||||
query_lower = query.lower()
|
|
||||||
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1]
|
|
||||||
|
|
||||||
def match_score(text: str) -> int:
|
|
||||||
"""计算文本与查询的匹配分数"""
|
|
||||||
if not text:
|
|
||||||
return 0
|
|
||||||
text_lower = text.lower()
|
|
||||||
# 完全匹配查询
|
|
||||||
if query_lower in text_lower:
|
|
||||||
return 100
|
|
||||||
# 关键词匹配
|
|
||||||
score = 0
|
|
||||||
for keyword in keywords:
|
|
||||||
if keyword in text_lower:
|
|
||||||
score += 10
|
|
||||||
return score
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if scope in ["edges", "both"]:
|
raw = self.store.search(graph_id, query, limit=limit, scope=scope)
|
||||||
# 获取所有边并匹配
|
|
||||||
all_edges = self.get_all_edges(graph_id)
|
|
||||||
scored_edges = []
|
|
||||||
for edge in all_edges:
|
|
||||||
score = match_score(edge.fact) + match_score(edge.name)
|
|
||||||
if score > 0:
|
|
||||||
scored_edges.append((score, edge))
|
|
||||||
|
|
||||||
# 按分数排序
|
facts = raw.get("facts", [])
|
||||||
scored_edges.sort(key=lambda x: x[0], reverse=True)
|
edges_result = raw.get("edges", [])
|
||||||
|
nodes_result = raw.get("nodes", [])
|
||||||
for score, edge in scored_edges[:limit]:
|
|
||||||
if edge.fact:
|
|
||||||
facts.append(edge.fact)
|
|
||||||
edges_result.append({
|
|
||||||
"uuid": edge.uuid,
|
|
||||||
"name": edge.name,
|
|
||||||
"fact": edge.fact,
|
|
||||||
"source_node_uuid": edge.source_node_uuid,
|
|
||||||
"target_node_uuid": edge.target_node_uuid,
|
|
||||||
})
|
|
||||||
|
|
||||||
if scope in ["nodes", "both"]:
|
|
||||||
# 获取所有节点并匹配
|
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
|
||||||
scored_nodes = []
|
|
||||||
for node in all_nodes:
|
|
||||||
score = match_score(node.name) + match_score(node.summary)
|
|
||||||
if score > 0:
|
|
||||||
scored_nodes.append((score, node))
|
|
||||||
|
|
||||||
scored_nodes.sort(key=lambda x: x[0], reverse=True)
|
|
||||||
|
|
||||||
for score, node in scored_nodes[:limit]:
|
|
||||||
nodes_result.append({
|
|
||||||
"uuid": node.uuid,
|
|
||||||
"name": node.name,
|
|
||||||
"labels": node.labels,
|
|
||||||
"summary": node.summary,
|
|
||||||
})
|
|
||||||
if node.summary:
|
|
||||||
facts.append(f"[{node.name}]: {node.summary}")
|
|
||||||
|
|
||||||
logger.info(t("console.localSearchComplete", count=len(facts)))
|
logger.info(t("console.localSearchComplete", count=len(facts)))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(t("console.localSearchFailed", error=str(e)))
|
logger.error(t("console.localSearchFailed", error=str(e)))
|
||||||
|
facts, edges_result, nodes_result = [], [], []
|
||||||
|
|
||||||
return SearchResult(
|
return SearchResult(
|
||||||
facts=facts,
|
facts=facts,
|
||||||
|
|
@ -648,77 +484,55 @@ class ZepToolsService:
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
|
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
|
||||||
"""
|
"""获取图谱的所有节点"""
|
||||||
获取图谱的所有节点(分页获取)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
节点列表
|
|
||||||
"""
|
|
||||||
logger.info(t("console.fetchingAllNodes", graphId=graph_id))
|
logger.info(t("console.fetchingAllNodes", graphId=graph_id))
|
||||||
|
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
nodes = self.store.get_nodes(graph_id)
|
||||||
|
result = [
|
||||||
result = []
|
NodeInfo(
|
||||||
for node in nodes:
|
uuid=n.get("uuid", ""),
|
||||||
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
|
name=n.get("name", ""),
|
||||||
result.append(NodeInfo(
|
labels=n.get("labels") or [],
|
||||||
uuid=str(node_uuid) if node_uuid else "",
|
summary=n.get("summary", ""),
|
||||||
name=node.name or "",
|
attributes=n.get("attributes") or {},
|
||||||
labels=node.labels or [],
|
)
|
||||||
summary=node.summary or "",
|
for n in nodes
|
||||||
attributes=node.attributes or {}
|
]
|
||||||
))
|
|
||||||
|
|
||||||
logger.info(t("console.fetchedNodes", count=len(result)))
|
logger.info(t("console.fetchedNodes", count=len(result)))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
|
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
|
||||||
"""
|
"""获取图谱的所有边(含时间信息)"""
|
||||||
获取图谱的所有边(分页获取,包含时间信息)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
graph_id: 图谱ID
|
|
||||||
include_temporal: 是否包含时间信息(默认True)
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
边列表(包含created_at, valid_at, invalid_at, expired_at)
|
|
||||||
"""
|
|
||||||
logger.info(t("console.fetchingAllEdges", graphId=graph_id))
|
logger.info(t("console.fetchingAllEdges", graphId=graph_id))
|
||||||
|
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
edges = self.store.get_edges(graph_id)
|
||||||
|
|
||||||
result = []
|
result = []
|
||||||
for edge in edges:
|
for e in edges:
|
||||||
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
|
|
||||||
edge_info = EdgeInfo(
|
edge_info = EdgeInfo(
|
||||||
uuid=str(edge_uuid) if edge_uuid else "",
|
uuid=e.get("uuid", ""),
|
||||||
name=edge.name or "",
|
name=e.get("name", ""),
|
||||||
fact=edge.fact or "",
|
fact=e.get("fact", ""),
|
||||||
source_node_uuid=edge.source_node_uuid or "",
|
source_node_uuid=e.get("source_node_uuid", ""),
|
||||||
target_node_uuid=edge.target_node_uuid or ""
|
target_node_uuid=e.get("target_node_uuid", ""),
|
||||||
)
|
)
|
||||||
|
|
||||||
# 添加时间信息
|
|
||||||
if include_temporal:
|
if include_temporal:
|
||||||
edge_info.created_at = getattr(edge, 'created_at', None)
|
edge_info.created_at = e.get("created_at")
|
||||||
edge_info.valid_at = getattr(edge, 'valid_at', None)
|
edge_info.valid_at = e.get("valid_at")
|
||||||
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
|
edge_info.invalid_at = e.get("invalid_at")
|
||||||
edge_info.expired_at = getattr(edge, 'expired_at', None)
|
edge_info.expired_at = e.get("expired_at")
|
||||||
|
|
||||||
result.append(edge_info)
|
result.append(edge_info)
|
||||||
|
|
||||||
logger.info(t("console.fetchedEdges", count=len(result)))
|
logger.info(t("console.fetchedEdges", count=len(result)))
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
|
def get_node_detail(self, node_uuid: str, graph_id: str = "") -> Optional[NodeInfo]:
|
||||||
"""
|
"""
|
||||||
获取单个节点的详细信息
|
获取单个节点的详细信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_uuid: 节点UUID
|
node_uuid: 节点UUID
|
||||||
|
graph_id: 图谱ID(从本地存储检索时需要)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点信息或None
|
节点信息或None
|
||||||
|
|
@ -726,21 +540,18 @@ class ZepToolsService:
|
||||||
logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))
|
logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))
|
||||||
|
|
||||||
try:
|
try:
|
||||||
node = self._call_with_retry(
|
# 若提供了graph_id,直接从该图谱查找
|
||||||
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
|
if graph_id:
|
||||||
operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8])
|
n = self.store.get_node(graph_id, node_uuid)
|
||||||
)
|
if n:
|
||||||
|
return NodeInfo(
|
||||||
if not node:
|
uuid=n.get("uuid", ""),
|
||||||
return None
|
name=n.get("name", ""),
|
||||||
|
labels=n.get("labels") or [],
|
||||||
return NodeInfo(
|
summary=n.get("summary", ""),
|
||||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
attributes=n.get("attributes") or {},
|
||||||
name=node.name or "",
|
)
|
||||||
labels=node.labels or [],
|
return None
|
||||||
summary=node.summary or "",
|
|
||||||
attributes=node.attributes or {}
|
|
||||||
)
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(t("console.fetchNodeDetailFailed", error=str(e)))
|
logger.error(t("console.fetchNodeDetailFailed", error=str(e)))
|
||||||
return None
|
return None
|
||||||
|
|
@ -1043,7 +854,7 @@ class ZepToolsService:
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
# 单独获取每个相关节点的信息
|
# 单独获取每个相关节点的信息
|
||||||
node = self.get_node_detail(uuid)
|
node = self.get_node_detail(uuid, graph_id=graph_id)
|
||||||
if node:
|
if node:
|
||||||
node_map[uuid] = node
|
node_map[uuid] = node
|
||||||
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")
|
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,290 @@
|
||||||
|
"""
|
||||||
|
本地JSON文件图谱存储
|
||||||
|
替代Zep Cloud,将图谱数据(节点、边、情节)存储在本地JSON文件中
|
||||||
|
|
||||||
|
存储目录结构:
|
||||||
|
{storage_dir}/
|
||||||
|
{graph_id}/
|
||||||
|
metadata.json - 图谱元数据和本体定义
|
||||||
|
nodes.json - 节点列表
|
||||||
|
edges.json - 边列表
|
||||||
|
episodes.jsonl - 情节文本日志(追加写入)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
import threading
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
|
from .logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger('mirofish.local_graph_store')
|
||||||
|
|
||||||
|
# 每个图谱一把锁,保证并发写入安全
|
||||||
|
_global_lock = threading.Lock()
|
||||||
|
_graph_locks: Dict[str, threading.Lock] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _lock_for(graph_id: str) -> threading.Lock:
|
||||||
|
with _global_lock:
|
||||||
|
if graph_id not in _graph_locks:
|
||||||
|
_graph_locks[graph_id] = threading.Lock()
|
||||||
|
return _graph_locks[graph_id]
|
||||||
|
|
||||||
|
|
||||||
|
class LocalGraphStore:
|
||||||
|
"""本地JSON文件图谱存储"""
|
||||||
|
|
||||||
|
def __init__(self, storage_dir: str):
|
||||||
|
self.storage_dir = storage_dir
|
||||||
|
os.makedirs(storage_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# ── 图谱生命周期 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
|
||||||
|
graph_dir = self._graph_dir(graph_id)
|
||||||
|
os.makedirs(graph_dir, exist_ok=True)
|
||||||
|
self._write_json(self._meta_path(graph_id), {
|
||||||
|
"graph_id": graph_id,
|
||||||
|
"name": name,
|
||||||
|
"description": description,
|
||||||
|
"created_at": datetime.now().isoformat(),
|
||||||
|
"ontology": None,
|
||||||
|
})
|
||||||
|
if not os.path.exists(self._nodes_path(graph_id)):
|
||||||
|
self._write_json(self._nodes_path(graph_id), [])
|
||||||
|
if not os.path.exists(self._edges_path(graph_id)):
|
||||||
|
self._write_json(self._edges_path(graph_id), [])
|
||||||
|
logger.info(f"本地图谱已创建: {graph_id}")
|
||||||
|
|
||||||
|
def delete_graph(self, graph_id: str) -> None:
|
||||||
|
graph_dir = self._graph_dir(graph_id)
|
||||||
|
if os.path.exists(graph_dir):
|
||||||
|
shutil.rmtree(graph_dir)
|
||||||
|
logger.info(f"本地图谱已删除: {graph_id}")
|
||||||
|
|
||||||
|
def graph_exists(self, graph_id: str) -> bool:
|
||||||
|
return os.path.exists(self._meta_path(graph_id))
|
||||||
|
|
||||||
|
# ── 本体 ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
|
||||||
|
meta = self._read_json(self._meta_path(graph_id)) or {}
|
||||||
|
meta["ontology"] = ontology
|
||||||
|
self._write_json(self._meta_path(graph_id), meta)
|
||||||
|
|
||||||
|
def get_ontology(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
meta = self._read_json(self._meta_path(graph_id)) or {}
|
||||||
|
return meta.get("ontology")
|
||||||
|
|
||||||
|
def get_metadata(self, graph_id: str) -> Optional[Dict[str, Any]]:
|
||||||
|
return self._read_json(self._meta_path(graph_id))
|
||||||
|
|
||||||
|
# ── 情节(Episode)────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def add_episode(self, graph_id: str, text: str) -> str:
|
||||||
|
"""追加一条情节文本,返回情节uuid(本地存储立即处理完成)"""
|
||||||
|
episode_id = uuid.uuid4().hex
|
||||||
|
record = {
|
||||||
|
"uuid": episode_id,
|
||||||
|
"text": text,
|
||||||
|
"created_at": datetime.now().isoformat(),
|
||||||
|
"processed": True,
|
||||||
|
}
|
||||||
|
ep_path = self._episodes_path(graph_id)
|
||||||
|
with _lock_for(graph_id):
|
||||||
|
with open(ep_path, 'a', encoding='utf-8') as f:
|
||||||
|
f.write(json.dumps(record, ensure_ascii=False) + '\n')
|
||||||
|
return episode_id
|
||||||
|
|
||||||
|
def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[str]:
|
||||||
|
return [self.add_episode(graph_id, t) for t in texts]
|
||||||
|
|
||||||
|
def episode_is_processed(self, graph_id: str, episode_uuid: str) -> bool:
|
||||||
|
"""本地存储中的情节总是立即处理完成"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── 节点 ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
|
return self._read_json(self._nodes_path(graph_id)) or []
|
||||||
|
|
||||||
|
def get_node(self, graph_id: str, node_uuid: str) -> Optional[Dict[str, Any]]:
|
||||||
|
for node in self.get_nodes(graph_id):
|
||||||
|
if node.get("uuid") == node_uuid:
|
||||||
|
return node
|
||||||
|
return None
|
||||||
|
|
||||||
|
def upsert_node(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
name: str,
|
||||||
|
labels: Optional[List[str]] = None,
|
||||||
|
summary: str = "",
|
||||||
|
attributes: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> str:
|
||||||
|
"""按名称查找节点,存在则更新,不存在则创建。返回uuid。"""
|
||||||
|
labels = labels or ["Entity"]
|
||||||
|
attributes = attributes or {}
|
||||||
|
|
||||||
|
with _lock_for(graph_id):
|
||||||
|
nodes = self._read_json(self._nodes_path(graph_id)) or []
|
||||||
|
# 按名称(不区分大小写)查找
|
||||||
|
for node in nodes:
|
||||||
|
if node.get("name", "").lower() == name.lower():
|
||||||
|
# 合并标签
|
||||||
|
existing = set(node.get("labels", []))
|
||||||
|
existing.update(labels)
|
||||||
|
node["labels"] = list(existing)
|
||||||
|
# 若原摘要为空则填充
|
||||||
|
if summary and not node.get("summary"):
|
||||||
|
node["summary"] = summary
|
||||||
|
# 合并属性
|
||||||
|
if attributes:
|
||||||
|
node.setdefault("attributes", {}).update(attributes)
|
||||||
|
self._write_json(self._nodes_path(graph_id), nodes)
|
||||||
|
return node["uuid"]
|
||||||
|
# 创建新节点
|
||||||
|
node_uuid = uuid.uuid4().hex
|
||||||
|
nodes.append({
|
||||||
|
"uuid": node_uuid,
|
||||||
|
"name": name,
|
||||||
|
"labels": labels,
|
||||||
|
"summary": summary,
|
||||||
|
"attributes": attributes,
|
||||||
|
"created_at": datetime.now().isoformat(),
|
||||||
|
})
|
||||||
|
self._write_json(self._nodes_path(graph_id), nodes)
|
||||||
|
return node_uuid
|
||||||
|
|
||||||
|
# ── 边 ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
|
return self._read_json(self._edges_path(graph_id)) or []
|
||||||
|
|
||||||
|
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
|
||||||
|
"""获取与指定节点相关的所有边(作为源或目标)"""
|
||||||
|
return [
|
||||||
|
e for e in self.get_edges(graph_id)
|
||||||
|
if e.get("source_node_uuid") == node_uuid or e.get("target_node_uuid") == node_uuid
|
||||||
|
]
|
||||||
|
|
||||||
|
def add_edge(self, graph_id: str, edge: Dict[str, Any]) -> str:
|
||||||
|
"""添加一条边,返回其uuid。"""
|
||||||
|
edge_uuid = edge.get("uuid") or uuid.uuid4().hex
|
||||||
|
edge = dict(edge)
|
||||||
|
edge["uuid"] = edge_uuid
|
||||||
|
edge.setdefault("created_at", datetime.now().isoformat())
|
||||||
|
edge.setdefault("valid_at", None)
|
||||||
|
edge.setdefault("invalid_at", None)
|
||||||
|
edge.setdefault("expired_at", None)
|
||||||
|
edge.setdefault("attributes", {})
|
||||||
|
|
||||||
|
with _lock_for(graph_id):
|
||||||
|
edges = self._read_json(self._edges_path(graph_id)) or []
|
||||||
|
edges.append(edge)
|
||||||
|
self._write_json(self._edges_path(graph_id), edges)
|
||||||
|
return edge_uuid
|
||||||
|
|
||||||
|
def add_fact_edge(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
source_uuid: str,
|
||||||
|
target_uuid: str,
|
||||||
|
name: str,
|
||||||
|
fact: str,
|
||||||
|
) -> str:
|
||||||
|
"""便利方法:在两个节点之间添加一条命名事实边。"""
|
||||||
|
return self.add_edge(graph_id, {
|
||||||
|
"name": name,
|
||||||
|
"fact": fact,
|
||||||
|
"source_node_uuid": source_uuid,
|
||||||
|
"target_node_uuid": target_uuid,
|
||||||
|
})
|
||||||
|
|
||||||
|
# ── 搜索 ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def search(
|
||||||
|
self,
|
||||||
|
graph_id: str,
|
||||||
|
query: str,
|
||||||
|
limit: int = 10,
|
||||||
|
scope: str = "edges",
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""基于关键词的本地搜索"""
|
||||||
|
query_lower = query.lower()
|
||||||
|
keywords = [
|
||||||
|
w.strip()
|
||||||
|
for w in query_lower.replace(',', ' ').replace(',', ' ').split()
|
||||||
|
if len(w.strip()) > 1
|
||||||
|
]
|
||||||
|
|
||||||
|
def score(text: str) -> int:
|
||||||
|
if not text:
|
||||||
|
return 0
|
||||||
|
tl = text.lower()
|
||||||
|
if query_lower in tl:
|
||||||
|
return 100
|
||||||
|
return sum(10 for kw in keywords if kw in tl)
|
||||||
|
|
||||||
|
result_edges: List[Dict] = []
|
||||||
|
result_nodes: List[Dict] = []
|
||||||
|
facts: List[str] = []
|
||||||
|
|
||||||
|
if scope in ("edges", "both"):
|
||||||
|
scored = sorted(
|
||||||
|
[(score(e.get("fact", "")) + score(e.get("name", "")), e)
|
||||||
|
for e in self.get_edges(graph_id)
|
||||||
|
if score(e.get("fact", "")) + score(e.get("name", "")) > 0],
|
||||||
|
key=lambda x: x[0], reverse=True,
|
||||||
|
)
|
||||||
|
for _, edge in scored[:limit]:
|
||||||
|
result_edges.append(edge)
|
||||||
|
if edge.get("fact"):
|
||||||
|
facts.append(edge["fact"])
|
||||||
|
|
||||||
|
if scope in ("nodes", "both"):
|
||||||
|
scored = sorted(
|
||||||
|
[(score(n.get("name", "")) + score(n.get("summary", "")), n)
|
||||||
|
for n in self.get_nodes(graph_id)
|
||||||
|
if score(n.get("name", "")) + score(n.get("summary", "")) > 0],
|
||||||
|
key=lambda x: x[0], reverse=True,
|
||||||
|
)
|
||||||
|
for _, node in scored[:limit]:
|
||||||
|
result_nodes.append(node)
|
||||||
|
if node.get("summary"):
|
||||||
|
facts.append(f"[{node['name']}]: {node['summary']}")
|
||||||
|
|
||||||
|
return {"facts": facts, "edges": result_edges, "nodes": result_nodes}
|
||||||
|
|
||||||
|
# ── 内部路径辅助 ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _graph_dir(self, graph_id: str) -> str:
|
||||||
|
return os.path.join(self.storage_dir, graph_id)
|
||||||
|
|
||||||
|
def _meta_path(self, graph_id: str) -> str:
|
||||||
|
return os.path.join(self._graph_dir(graph_id), "metadata.json")
|
||||||
|
|
||||||
|
def _nodes_path(self, graph_id: str) -> str:
|
||||||
|
return os.path.join(self._graph_dir(graph_id), "nodes.json")
|
||||||
|
|
||||||
|
def _edges_path(self, graph_id: str) -> str:
|
||||||
|
return os.path.join(self._graph_dir(graph_id), "edges.json")
|
||||||
|
|
||||||
|
def _episodes_path(self, graph_id: str) -> str:
|
||||||
|
return os.path.join(self._graph_dir(graph_id), "episodes.jsonl")
|
||||||
|
|
||||||
|
def _read_json(self, path: str) -> Any:
|
||||||
|
if not os.path.exists(path):
|
||||||
|
return None
|
||||||
|
with open(path, 'r', encoding='utf-8') as f:
|
||||||
|
return json.load(f)
|
||||||
|
|
||||||
|
def _write_json(self, path: str, data: Any) -> None:
|
||||||
|
with open(path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(data, f, ensure_ascii=False, indent=2)
|
||||||
|
|
@ -1,143 +1,25 @@
|
||||||
"""Zep Graph 分页读取工具。
|
"""
|
||||||
|
图谱分页读取工具(存根模块)
|
||||||
|
|
||||||
Zep 的 node/edge 列表接口使用 UUID cursor 分页,
|
原来封装 Zep Cloud 的分页逻辑。
|
||||||
本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。
|
现在图谱数据存储在本地 JSON 文件中,不再需要分页。
|
||||||
|
本模块保留以避免破坏未更新的旧导入。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
from ..utils.logger import get_logger
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from zep_cloud import InternalServerError
|
|
||||||
from zep_cloud.client import Zep
|
|
||||||
|
|
||||||
from .logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_paging')
|
logger = get_logger('mirofish.zep_paging')
|
||||||
|
|
||||||
_DEFAULT_PAGE_SIZE = 100
|
|
||||||
_MAX_NODES = 2000
|
def fetch_all_nodes(client, graph_id: str, **kwargs) -> list:
|
||||||
_DEFAULT_MAX_RETRIES = 3
|
"""已废弃:请直接使用 LocalGraphStore.get_nodes()"""
|
||||||
_DEFAULT_RETRY_DELAY = 2.0 # seconds, doubles each retry
|
logger.warning("fetch_all_nodes 已废弃,请使用 LocalGraphStore.get_nodes()")
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _fetch_page_with_retry(
|
def fetch_all_edges(client, graph_id: str, **kwargs) -> list:
|
||||||
api_call: Callable[..., list[Any]],
|
"""已废弃:请直接使用 LocalGraphStore.get_edges()"""
|
||||||
*args: Any,
|
logger.warning("fetch_all_edges 已废弃,请使用 LocalGraphStore.get_edges()")
|
||||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
return []
|
||||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
|
||||||
page_description: str = "page",
|
|
||||||
**kwargs: Any,
|
|
||||||
) -> list[Any]:
|
|
||||||
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。"""
|
|
||||||
if max_retries < 1:
|
|
||||||
raise ValueError("max_retries must be >= 1")
|
|
||||||
|
|
||||||
last_exception: Exception | None = None
|
|
||||||
delay = retry_delay
|
|
||||||
|
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
return api_call(*args, **kwargs)
|
|
||||||
except (ConnectionError, TimeoutError, OSError, InternalServerError) as e:
|
|
||||||
last_exception = e
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
logger.warning(
|
|
||||||
f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..."
|
|
||||||
)
|
|
||||||
time.sleep(delay)
|
|
||||||
delay *= 2
|
|
||||||
else:
|
|
||||||
logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}")
|
|
||||||
|
|
||||||
assert last_exception is not None
|
|
||||||
raise last_exception
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_all_nodes(
|
|
||||||
client: Zep,
|
|
||||||
graph_id: str,
|
|
||||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
|
||||||
max_items: int = _MAX_NODES,
|
|
||||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
|
||||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
|
||||||
) -> list[Any]:
|
|
||||||
"""分页获取图谱节点,最多返回 max_items 条(默认 2000)。每页请求自带重试。"""
|
|
||||||
all_nodes: list[Any] = []
|
|
||||||
cursor: str | None = None
|
|
||||||
page_num = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
kwargs: dict[str, Any] = {"limit": page_size}
|
|
||||||
if cursor is not None:
|
|
||||||
kwargs["uuid_cursor"] = cursor
|
|
||||||
|
|
||||||
page_num += 1
|
|
||||||
batch = _fetch_page_with_retry(
|
|
||||||
client.graph.node.get_by_graph_id,
|
|
||||||
graph_id,
|
|
||||||
max_retries=max_retries,
|
|
||||||
retry_delay=retry_delay,
|
|
||||||
page_description=f"fetch nodes page {page_num} (graph={graph_id})",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if not batch:
|
|
||||||
break
|
|
||||||
|
|
||||||
all_nodes.extend(batch)
|
|
||||||
if len(all_nodes) >= max_items:
|
|
||||||
all_nodes = all_nodes[:max_items]
|
|
||||||
logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
|
|
||||||
break
|
|
||||||
if len(batch) < page_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
|
|
||||||
if cursor is None:
|
|
||||||
logger.warning(f"Node missing uuid field, stopping pagination at {len(all_nodes)} nodes")
|
|
||||||
break
|
|
||||||
|
|
||||||
return all_nodes
|
|
||||||
|
|
||||||
|
|
||||||
def fetch_all_edges(
|
|
||||||
client: Zep,
|
|
||||||
graph_id: str,
|
|
||||||
page_size: int = _DEFAULT_PAGE_SIZE,
|
|
||||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
|
||||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
|
||||||
) -> list[Any]:
|
|
||||||
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。"""
|
|
||||||
all_edges: list[Any] = []
|
|
||||||
cursor: str | None = None
|
|
||||||
page_num = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
kwargs: dict[str, Any] = {"limit": page_size}
|
|
||||||
if cursor is not None:
|
|
||||||
kwargs["uuid_cursor"] = cursor
|
|
||||||
|
|
||||||
page_num += 1
|
|
||||||
batch = _fetch_page_with_retry(
|
|
||||||
client.graph.edge.get_by_graph_id,
|
|
||||||
graph_id,
|
|
||||||
max_retries=max_retries,
|
|
||||||
retry_delay=retry_delay,
|
|
||||||
page_description=f"fetch edges page {page_num} (graph={graph_id})",
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
if not batch:
|
|
||||||
break
|
|
||||||
|
|
||||||
all_edges.extend(batch)
|
|
||||||
if len(batch) < page_size:
|
|
||||||
break
|
|
||||||
|
|
||||||
cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
|
|
||||||
if cursor is None:
|
|
||||||
logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges")
|
|
||||||
break
|
|
||||||
|
|
||||||
return all_edges
|
|
||||||
|
|
|
||||||
|
|
@ -16,9 +16,6 @@ dependencies = [
|
||||||
# LLM 相关
|
# LLM 相关
|
||||||
"openai>=1.0.0",
|
"openai>=1.0.0",
|
||||||
|
|
||||||
# Zep Cloud
|
|
||||||
"zep-cloud==3.13.0",
|
|
||||||
|
|
||||||
# OASIS 社交媒体模拟
|
# OASIS 社交媒体模拟
|
||||||
"camel-oasis==0.2.5",
|
"camel-oasis==0.2.5",
|
||||||
"camel-ai==0.2.78",
|
"camel-ai==0.2.78",
|
||||||
|
|
|
||||||
|
|
@ -13,9 +13,6 @@ flask-cors>=6.0.0
|
||||||
# OpenAI SDK(统一使用 OpenAI 格式调用 LLM)
|
# OpenAI SDK(统一使用 OpenAI 格式调用 LLM)
|
||||||
openai>=1.0.0
|
openai>=1.0.0
|
||||||
|
|
||||||
# ============= Zep Cloud =============
|
|
||||||
zep-cloud==3.13.0
|
|
||||||
|
|
||||||
# ============= OASIS 社交媒体模拟 =============
|
# ============= OASIS 社交媒体模拟 =============
|
||||||
# OASIS 社交模拟框架
|
# OASIS 社交模拟框架
|
||||||
camel-oasis==0.2.5
|
camel-oasis==0.2.5
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue