Merge 2bec63be1b into 3f4d56116c
This commit is contained in:
commit
f619ebeb97
|
|
@ -283,17 +283,6 @@ def build_graph():
|
|||
try:
|
||||
logger.info("=== 开始构建图谱 ===")
|
||||
|
||||
# 检查配置
|
||||
errors = []
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors.append(t('api.zepApiKeyMissing'))
|
||||
if errors:
|
||||
logger.error(f"配置错误: {errors}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.configError', details="; ".join(errors))
|
||||
}), 500
|
||||
|
||||
# 解析请求
|
||||
data = request.get_json() or {}
|
||||
project_id = data.get('project_id')
|
||||
|
|
@ -387,8 +376,8 @@ def build_graph():
|
|||
)
|
||||
|
||||
# 创建图谱构建服务
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
||||
builder = GraphBuilderService()
|
||||
|
||||
# 分块
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -572,20 +561,14 @@ def get_graph_data(graph_id: str):
|
|||
获取图谱数据(节点和边)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
builder = GraphBuilderService()
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
"data": graph_data
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
|
|
@ -597,16 +580,10 @@ def get_graph_data(graph_id: str):
|
|||
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
||||
def delete_graph(graph_id: str):
|
||||
"""
|
||||
删除Zep图谱
|
||||
删除本地图谱
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
}), 500
|
||||
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
builder = GraphBuilderService()
|
||||
builder.delete_graph(graph_id)
|
||||
|
||||
return jsonify({
|
||||
|
|
|
|||
|
|
@ -57,18 +57,12 @@ def get_graph_entities(graph_id: str):
|
|||
enrich: 是否获取相关边信息(默认true)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
}), 500
|
||||
|
||||
entity_types_str = request.args.get('entity_types', '')
|
||||
entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None
|
||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||
|
||||
|
||||
logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}")
|
||||
|
||||
|
||||
reader = ZepEntityReader()
|
||||
result = reader.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
|
|
@ -94,12 +88,6 @@ def get_graph_entities(graph_id: str):
|
|||
def get_entity_detail(graph_id: str, entity_uuid: str):
|
||||
"""获取单个实体的详细信息"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
}), 500
|
||||
|
||||
reader = ZepEntityReader()
|
||||
entity = reader.get_entity_with_context(graph_id, entity_uuid)
|
||||
|
||||
|
|
@ -127,14 +115,8 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
|
|||
def get_entities_by_type(graph_id: str, entity_type: str):
|
||||
"""获取指定类型的所有实体"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.zepApiKeyMissing')
|
||||
}), 500
|
||||
|
||||
enrich = request.args.get('enrich', 'true').lower() == 'true'
|
||||
|
||||
|
||||
reader = ZepEntityReader()
|
||||
entities = reader.get_entities_by_type(
|
||||
graph_id=graph_id,
|
||||
|
|
|
|||
|
|
@ -32,8 +32,8 @@ class Config:
|
|||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||
|
||||
# Zep配置
|
||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||
# 本地图谱存储目录
|
||||
GRAPH_STORAGE_DIR = os.path.join(os.path.dirname(__file__), '../uploads/graphs')
|
||||
|
||||
# 文件上传配置
|
||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||
|
|
@ -69,7 +69,7 @@ class Config:
|
|||
errors = []
|
||||
if not cls.LLM_API_KEY:
|
||||
errors.append("LLM_API_KEY 未配置")
|
||||
if not cls.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY 未配置")
|
||||
# 确保图谱存储目录存在
|
||||
os.makedirs(cls.GRAPH_STORAGE_DIR, exist_ok=True)
|
||||
return errors
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
图谱构建服务
|
||||
接口2:使用Zep API构建Standalone Graph
|
||||
使用本地JSON文件存储替代Zep Cloud
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -10,14 +10,15 @@ import threading
|
|||
from typing import Dict, Any, List, Optional, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
|
||||
|
||||
from ..config import Config
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
from ..utils.local_graph_store import LocalGraphStore
|
||||
from ..utils.llm_client import LLMClient
|
||||
from .text_processor import TextProcessor
|
||||
from ..utils.locale import t, get_locale, set_locale
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.graph_builder')
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -27,7 +28,7 @@ class GraphInfo:
|
|||
node_count: int
|
||||
edge_count: int
|
||||
entity_types: List[str]
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"graph_id": self.graph_id,
|
||||
|
|
@ -40,17 +41,23 @@ class GraphInfo:
|
|||
class GraphBuilderService:
|
||||
"""
|
||||
图谱构建服务
|
||||
负责调用Zep API构建知识图谱
|
||||
使用本地JSON文件存储构建知识图谱
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||
self.storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||
self.store = LocalGraphStore(self.storage_dir)
|
||||
self.task_manager = TaskManager()
|
||||
|
||||
self._llm: Optional[LLMClient] = None
|
||||
|
||||
@property
def llm(self) -> LLMClient:
    """Return the LLM client, creating it on first access.

    The client is cached on ``self._llm`` so that later accesses reuse the
    same instance instead of constructing a new one.
    """
    client = self._llm
    if client is None:
        # First access: build the client now and remember it.
        client = LLMClient()
        self._llm = client
    return client
|
||||
|
||||
def build_graph_async(
|
||||
self,
|
||||
text: str,
|
||||
|
|
@ -62,19 +69,10 @@ class GraphBuilderService:
|
|||
) -> str:
|
||||
"""
|
||||
异步构建图谱
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
ontology: 本体定义(来自接口1的输出)
|
||||
graph_name: 图谱名称
|
||||
chunk_size: 文本块大小
|
||||
chunk_overlap: 块重叠大小
|
||||
batch_size: 每批发送的块数量
|
||||
|
||||
|
||||
Returns:
|
||||
任务ID
|
||||
"""
|
||||
# 创建任务
|
||||
task_id = self.task_manager.create_task(
|
||||
task_type="graph_build",
|
||||
metadata={
|
||||
|
|
@ -83,20 +81,18 @@ class GraphBuilderService:
|
|||
"text_length": len(text),
|
||||
}
|
||||
)
|
||||
|
||||
# Capture locale before spawning background thread
|
||||
|
||||
current_locale = get_locale()
|
||||
|
||||
# 在后台线程中执行构建
|
||||
thread = threading.Thread(
|
||||
target=self._build_graph_worker,
|
||||
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
||||
)
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
|
||||
return task_id
|
||||
|
||||
|
||||
def _build_graph_worker(
|
||||
self,
|
||||
task_id: str,
|
||||
|
|
@ -117,7 +113,7 @@ class GraphBuilderService:
|
|||
progress=5,
|
||||
message=t('progress.startBuildingGraph')
|
||||
)
|
||||
|
||||
|
||||
# 1. 创建图谱
|
||||
graph_id = self.create_graph(graph_name)
|
||||
self.task_manager.update_task(
|
||||
|
|
@ -125,15 +121,15 @@ class GraphBuilderService:
|
|||
progress=10,
|
||||
message=t('progress.graphCreated', graphId=graph_id)
|
||||
)
|
||||
|
||||
# 2. 设置本体
|
||||
|
||||
# 2. 保存本体
|
||||
self.set_ontology(graph_id, ontology)
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=15,
|
||||
message=t('progress.ontologySet')
|
||||
)
|
||||
|
||||
|
||||
# 3. 文本分块
|
||||
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
|
||||
total_chunks = len(chunks)
|
||||
|
|
@ -142,155 +138,47 @@ class GraphBuilderService:
|
|||
progress=20,
|
||||
message=t('progress.textSplit', count=total_chunks)
|
||||
)
|
||||
|
||||
# 4. 分批发送数据
|
||||
episode_uuids = self.add_text_batches(
|
||||
|
||||
# 4. 分批处理:提取实体并存储
|
||||
self.add_text_batches(
|
||||
graph_id, chunks, batch_size,
|
||||
lambda msg, prog: self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=20 + int(prog * 0.4), # 20-60%
|
||||
progress=20 + int(prog * 0.7), # 20-90%
|
||||
message=msg
|
||||
)
|
||||
)
|
||||
|
||||
# 5. 等待Zep处理完成
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=60,
|
||||
message=t('progress.waitingZepProcess')
|
||||
)
|
||||
|
||||
self._wait_for_episodes(
|
||||
episode_uuids,
|
||||
lambda msg, prog: self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=60 + int(prog * 0.3), # 60-90%
|
||||
message=msg
|
||||
)
|
||||
)
|
||||
|
||||
# 6. 获取图谱信息
|
||||
|
||||
# 5. 获取图谱信息
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=90,
|
||||
message=t('progress.fetchingGraphInfo')
|
||||
)
|
||||
|
||||
|
||||
graph_info = self._get_graph_info(graph_id)
|
||||
|
||||
# 完成
|
||||
|
||||
self.task_manager.complete_task(task_id, {
|
||||
"graph_id": graph_id,
|
||||
"graph_info": graph_info.to_dict(),
|
||||
"chunks_processed": total_chunks,
|
||||
})
|
||||
|
||||
|
||||
except Exception as e:
|
||||
import traceback
|
||||
error_msg = f"{str(e)}\n{traceback.format_exc()}"
|
||||
self.task_manager.fail_task(task_id, error_msg)
|
||||
|
||||
|
||||
def create_graph(self, name: str) -> str:
|
||||
"""创建Zep图谱(公开方法)"""
|
||||
"""创建本地图谱"""
|
||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
self.client.graph.create(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
description="MiroFish Social Simulation Graph"
|
||||
)
|
||||
|
||||
self.store.create_graph(graph_id, name, "MiroFish Social Simulation Graph")
|
||||
return graph_id
|
||||
|
||||
|
||||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||||
"""设置图谱本体(公开方法)"""
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
||||
|
||||
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
|
||||
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
||||
|
||||
# Zep 保留名称,不能作为属性名
|
||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||
|
||||
def safe_attr_name(attr_name: str) -> str:
|
||||
"""将保留名称转换为安全名称"""
|
||||
if attr_name.lower() in RESERVED_NAMES:
|
||||
return f"entity_{attr_name}"
|
||||
return attr_name
|
||||
|
||||
# 动态创建实体类型
|
||||
entity_types = {}
|
||||
for entity_def in ontology.get("entity_types", []):
|
||||
name = entity_def["name"]
|
||||
description = entity_def.get("description", f"A {name} entity.")
|
||||
|
||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in entity_def.get("attributes", []):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
# 动态创建类
|
||||
entity_class = type(name, (EntityModel,), attrs)
|
||||
entity_class.__doc__ = description
|
||||
entity_types[name] = entity_class
|
||||
|
||||
# 动态创建边类型
|
||||
edge_definitions = {}
|
||||
for edge_def in ontology.get("edge_types", []):
|
||||
name = edge_def["name"]
|
||||
description = edge_def.get("description", f"A {name} relationship.")
|
||||
|
||||
# 创建属性字典和类型注解
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
for attr_def in edge_def.get("attributes", []):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
# 动态创建类
|
||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||
edge_class.__doc__ = description
|
||||
|
||||
# 构建source_targets
|
||||
source_targets = []
|
||||
for st in edge_def.get("source_targets", []):
|
||||
source_targets.append(
|
||||
EntityEdgeSourceTarget(
|
||||
source=st.get("source", "Entity"),
|
||||
target=st.get("target", "Entity")
|
||||
)
|
||||
)
|
||||
|
||||
if source_targets:
|
||||
edge_definitions[name] = (edge_class, source_targets)
|
||||
|
||||
# 调用Zep API设置本体
|
||||
if entity_types or edge_definitions:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
entities=entity_types if entity_types else None,
|
||||
edges=edge_definitions if edge_definitions else None,
|
||||
)
|
||||
|
||||
"""保存本体定义"""
|
||||
self.store.set_ontology(graph_id, ontology)
|
||||
|
||||
def add_text_batches(
|
||||
self,
|
||||
graph_id: str,
|
||||
|
|
@ -298,209 +186,206 @@ class GraphBuilderService:
|
|||
batch_size: int = 3,
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> List[str]:
|
||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
||||
"""分批处理文本:提取实体/关系并存储,返回情节uuid列表"""
|
||||
episode_uuids = []
|
||||
ontology = self.store.get_ontology(graph_id) or {}
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
||||
for i in range(0, total_chunks, batch_size):
|
||||
batch_chunks = chunks[i:i + batch_size]
|
||||
batch = chunks[i:i + batch_size]
|
||||
batch_num = i // batch_size + 1
|
||||
total_batches = (total_chunks + batch_size - 1) // batch_size
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress = (i + len(batch_chunks)) / total_chunks
|
||||
progress = (i + len(batch)) / total_chunks
|
||||
progress_callback(
|
||||
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
|
||||
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch)),
|
||||
progress
|
||||
)
|
||||
|
||||
# 构建episode数据
|
||||
episodes = [
|
||||
EpisodeData(data=chunk, type="text")
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
|
||||
# 发送到Zep
|
||||
try:
|
||||
batch_result = self.client.graph.add_batch(
|
||||
graph_id=graph_id,
|
||||
episodes=episodes
|
||||
)
|
||||
|
||||
# 收集返回的 episode uuid
|
||||
if batch_result and isinstance(batch_result, list):
|
||||
for ep in batch_result:
|
||||
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
||||
if ep_uuid:
|
||||
episode_uuids.append(ep_uuid)
|
||||
|
||||
# 避免请求过快
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
|
||||
raise
|
||||
|
||||
|
||||
# 存储情节文本
|
||||
for text in batch:
|
||||
ep_uuid = self.store.add_episode(graph_id, text)
|
||||
episode_uuids.append(ep_uuid)
|
||||
|
||||
# 使用LLM从批次文本中提取实体和关系
|
||||
if ontology.get("entity_types") or ontology.get("edge_types"):
|
||||
try:
|
||||
extracted = self._extract_entities_from_batch(batch, ontology)
|
||||
self._store_extracted(graph_id, extracted)
|
||||
except Exception as e:
|
||||
logger.warning(f"批次 {batch_num} 实体提取失败: {e}")
|
||||
|
||||
# 轻微延迟,避免LLM请求过快
|
||||
time.sleep(0.3)
|
||||
|
||||
return episode_uuids
|
||||
|
||||
|
||||
def _extract_entities_from_batch(self, texts: List[str], ontology: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""使用LLM从文本批次中提取实体和关系"""
|
||||
combined_text = "\n\n".join(texts)
|
||||
|
||||
entity_types_desc = "\n".join(
|
||||
f"- {et['name']}: {et.get('description', '')}"
|
||||
for et in ontology.get("entity_types", [])
|
||||
) or "- Entity (通用实体)"
|
||||
|
||||
edge_types_desc = "\n".join(
|
||||
f"- {rt['name']}: {rt.get('description', '')}"
|
||||
for rt in ontology.get("edge_types", [])
|
||||
) or "- RELATED_TO"
|
||||
|
||||
user_prompt = f"""从以下文本中提取实体和关系,仅使用给定的本体类型。
|
||||
|
||||
实体类型(只能使用这些):
|
||||
{entity_types_desc}
|
||||
|
||||
关系类型(只能使用这些):
|
||||
{edge_types_desc}
|
||||
|
||||
文本:
|
||||
{combined_text[:4000]}
|
||||
|
||||
返回JSON格式:
|
||||
{{
|
||||
"entities": [
|
||||
{{"name": "实体名称", "type": "实体类型", "summary": "一句话描述", "attributes": {{}}}}
|
||||
],
|
||||
"relationships": [
|
||||
{{"source": "源实体名称", "target": "目标实体名称", "type": "关系类型", "fact": "事实描述"}}
|
||||
]
|
||||
}}
|
||||
|
||||
规则:
|
||||
- 仅使用本体中定义的实体类型和关系类型
|
||||
- 实体名称应具体(人名、地名、组织名等)
|
||||
- fact字段应是简洁的事实陈述
|
||||
- 若找不到匹配项,返回空列表"""
|
||||
|
||||
try:
|
||||
result = self.llm.chat_json(
|
||||
messages=[{"role": "user", "content": user_prompt}],
|
||||
temperature=0.1
|
||||
)
|
||||
return result if isinstance(result, dict) else {"entities": [], "relationships": []}
|
||||
except Exception as e:
|
||||
logger.warning(f"实体提取LLM调用失败: {e}")
|
||||
return {"entities": [], "relationships": []}
|
||||
|
||||
def _store_extracted(self, graph_id: str, extracted: Dict[str, Any]):
|
||||
"""将LLM提取的实体和关系存储到本地图谱"""
|
||||
entities = extracted.get("entities", []) or []
|
||||
relationships = extracted.get("relationships", []) or []
|
||||
|
||||
name_to_uuid: Dict[str, str] = {}
|
||||
|
||||
for entity in entities:
|
||||
name = (entity.get("name") or "").strip()
|
||||
if not name:
|
||||
continue
|
||||
entity_type = entity.get("type") or "Entity"
|
||||
summary = entity.get("summary") or ""
|
||||
attributes = entity.get("attributes") or {}
|
||||
|
||||
labels = [entity_type, "Entity"] if entity_type != "Entity" else ["Entity"]
|
||||
node_uuid = self.store.upsert_node(
|
||||
graph_id=graph_id,
|
||||
name=name,
|
||||
labels=labels,
|
||||
summary=summary,
|
||||
attributes=attributes,
|
||||
)
|
||||
name_to_uuid[name.lower()] = node_uuid
|
||||
|
||||
for rel in relationships:
|
||||
source_name = (rel.get("source") or "").strip()
|
||||
target_name = (rel.get("target") or "").strip()
|
||||
rel_type = rel.get("type") or "RELATED_TO"
|
||||
fact = rel.get("fact") or ""
|
||||
|
||||
if not source_name or not target_name or not fact:
|
||||
continue
|
||||
|
||||
source_uuid = name_to_uuid.get(source_name.lower()) or \
|
||||
self.store.upsert_node(graph_id, source_name, ["Entity"])
|
||||
name_to_uuid[source_name.lower()] = source_uuid
|
||||
|
||||
target_uuid = name_to_uuid.get(target_name.lower()) or \
|
||||
self.store.upsert_node(graph_id, target_name, ["Entity"])
|
||||
name_to_uuid[target_name.lower()] = target_uuid
|
||||
|
||||
self.store.add_fact_edge(
|
||||
graph_id=graph_id,
|
||||
source_uuid=source_uuid,
|
||||
target_uuid=target_uuid,
|
||||
name=rel_type,
|
||||
fact=fact,
|
||||
)
|
||||
|
||||
def _wait_for_episodes(
|
||||
self,
|
||||
episode_uuids: List[str],
|
||||
progress_callback: Optional[Callable] = None,
|
||||
timeout: int = 600
|
||||
):
|
||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
||||
if not episode_uuids:
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
||||
return
|
||||
|
||||
start_time = time.time()
|
||||
pending_episodes = set(episode_uuids)
|
||||
completed_count = 0
|
||||
total_episodes = len(episode_uuids)
|
||||
|
||||
"""本地存储中情节立即处理完成,无需等待"""
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)
|
||||
|
||||
while pending_episodes:
|
||||
if time.time() - start_time > timeout:
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
|
||||
completed_count / total_episodes
|
||||
)
|
||||
break
|
||||
|
||||
# 检查每个 episode 的处理状态
|
||||
for ep_uuid in list(pending_episodes):
|
||||
try:
|
||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||||
is_processed = getattr(episode, 'processed', False)
|
||||
|
||||
if is_processed:
|
||||
pending_episodes.remove(ep_uuid)
|
||||
completed_count += 1
|
||||
|
||||
except Exception as e:
|
||||
# 忽略单个查询错误,继续
|
||||
pass
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
)
|
||||
|
||||
if pending_episodes:
|
||||
time.sleep(3) # 每3秒检查一次
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
||||
|
||||
progress_callback(t('progress.processingComplete',
|
||||
completed=len(episode_uuids),
|
||||
total=len(episode_uuids)), 1.0)
|
||||
|
||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||
"""获取图谱信息"""
|
||||
# 获取节点(分页)
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
"""获取图谱统计信息"""
|
||||
nodes = self.store.get_nodes(graph_id)
|
||||
edges = self.store.get_edges(graph_id)
|
||||
|
||||
# 获取边(分页)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 统计实体类型
|
||||
entity_types = set()
|
||||
for node in nodes:
|
||||
if node.labels:
|
||||
for label in node.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
entity_types.add(label)
|
||||
for label in (node.get("labels") or []):
|
||||
if label not in ("Entity", "Node"):
|
||||
entity_types.add(label)
|
||||
|
||||
return GraphInfo(
|
||||
graph_id=graph_id,
|
||||
node_count=len(nodes),
|
||||
edge_count=len(edges),
|
||||
entity_types=list(entity_types)
|
||||
entity_types=list(entity_types),
|
||||
)
|
||||
|
||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取完整图谱数据(包含详细信息)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||||
"""
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 创建节点映射用于获取节点名称
|
||||
node_map = {}
|
||||
for node in nodes:
|
||||
node_map[node.uuid_] = node.name or ""
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
# 获取创建时间
|
||||
created_at = getattr(node, 'created_at', None)
|
||||
if created_at:
|
||||
created_at = str(created_at)
|
||||
|
||||
nodes_data.append({
|
||||
"uuid": node.uuid_,
|
||||
"name": node.name,
|
||||
"labels": node.labels or [],
|
||||
"summary": node.summary or "",
|
||||
"attributes": node.attributes or {},
|
||||
"created_at": created_at,
|
||||
})
|
||||
|
||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||
"""获取完整图谱数据(含节点和边详情)"""
|
||||
nodes = self.store.get_nodes(graph_id)
|
||||
edges = self.store.get_edges(graph_id)
|
||||
|
||||
node_map = {n["uuid"]: n.get("name", "") for n in nodes}
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
# 获取时间信息
|
||||
created_at = getattr(edge, 'created_at', None)
|
||||
valid_at = getattr(edge, 'valid_at', None)
|
||||
invalid_at = getattr(edge, 'invalid_at', None)
|
||||
expired_at = getattr(edge, 'expired_at', None)
|
||||
|
||||
# 获取 episodes
|
||||
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
||||
if episodes and not isinstance(episodes, list):
|
||||
episodes = [str(episodes)]
|
||||
elif episodes:
|
||||
episodes = [str(e) for e in episodes]
|
||||
|
||||
# 获取 fact_type
|
||||
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
||||
|
||||
edges_data.append({
|
||||
"uuid": edge.uuid_,
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"fact_type": fact_type,
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"source_node_name": node_map.get(edge.source_node_uuid, ""),
|
||||
"target_node_name": node_map.get(edge.target_node_uuid, ""),
|
||||
"attributes": edge.attributes or {},
|
||||
"created_at": str(created_at) if created_at else None,
|
||||
"valid_at": str(valid_at) if valid_at else None,
|
||||
"invalid_at": str(invalid_at) if invalid_at else None,
|
||||
"expired_at": str(expired_at) if expired_at else None,
|
||||
"episodes": episodes or [],
|
||||
"uuid": edge.get("uuid", ""),
|
||||
"name": edge.get("name", ""),
|
||||
"fact": edge.get("fact", ""),
|
||||
"fact_type": edge.get("name", ""),
|
||||
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||
"source_node_name": node_map.get(edge.get("source_node_uuid", ""), ""),
|
||||
"target_node_name": node_map.get(edge.get("target_node_uuid", ""), ""),
|
||||
"attributes": edge.get("attributes", {}),
|
||||
"created_at": edge.get("created_at"),
|
||||
"valid_at": edge.get("valid_at"),
|
||||
"invalid_at": edge.get("invalid_at"),
|
||||
"expired_at": edge.get("expired_at"),
|
||||
"episodes": [],
|
||||
})
|
||||
|
||||
|
||||
return {
|
||||
"graph_id": graph_id,
|
||||
"nodes": nodes_data,
|
||||
"nodes": nodes,
|
||||
"edges": edges_data,
|
||||
"node_count": len(nodes_data),
|
||||
"edge_count": len(edges_data),
|
||||
"node_count": len(nodes),
|
||||
"edge_count": len(edges),
|
||||
}
|
||||
|
||||
|
||||
def delete_graph(self, graph_id: str):
|
||||
"""删除图谱"""
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
self.store.delete_graph(graph_id)
|
||||
|
|
|
|||
|
|
@ -16,9 +16,9 @@ from dataclasses import dataclass, field
|
|||
from datetime import datetime
|
||||
|
||||
from openai import OpenAI
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.local_graph_store import LocalGraphStore
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
|
||||
from .zep_entity_reader import EntityNode, ZepEntityReader
|
||||
|
|
@ -179,35 +179,30 @@ class OasisProfileGenerator:
|
|||
]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
zep_api_key: Optional[str] = None,
|
||||
graph_id: Optional[str] = None
|
||||
zep_api_key: Optional[str] = None, # 已废弃,保留以兼容旧调用
|
||||
graph_id: Optional[str] = None,
|
||||
storage_dir: Optional[str] = None,
|
||||
):
|
||||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
self.model_name = model_name or Config.LLM_MODEL_NAME
|
||||
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
)
|
||||
|
||||
# Zep客户端用于检索丰富上下文
|
||||
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
|
||||
self.zep_client = None
|
||||
|
||||
# 本地图谱存储
|
||||
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||
self.store = LocalGraphStore(storage_dir)
|
||||
self.graph_id = graph_id
|
||||
|
||||
if self.zep_api_key:
|
||||
try:
|
||||
self.zep_client = Zep(api_key=self.zep_api_key)
|
||||
except Exception as e:
|
||||
logger.warning(f"Zep客户端初始化失败: {e}")
|
||||
|
||||
def generate_profile_from_entity(
|
||||
self,
|
||||
|
|
@ -285,130 +280,53 @@ class OasisProfileGenerator:
|
|||
|
||||
def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
|
||||
"""
|
||||
使用Zep图谱混合搜索功能获取实体相关的丰富信息
|
||||
|
||||
Zep没有内置混合搜索接口,需要分别搜索edges和nodes然后合并结果。
|
||||
使用并行请求同时搜索,提高效率。
|
||||
|
||||
使用本地图谱关键词搜索获取实体相关的丰富信息
|
||||
|
||||
Args:
|
||||
entity: 实体节点对象
|
||||
|
||||
|
||||
Returns:
|
||||
包含facts, node_summaries, context的字典
|
||||
"""
|
||||
import concurrent.futures
|
||||
|
||||
if not self.zep_client:
|
||||
return {"facts": [], "node_summaries": [], "context": ""}
|
||||
|
||||
entity_name = entity.name
|
||||
|
||||
results = {
|
||||
"facts": [],
|
||||
"node_summaries": [],
|
||||
"context": ""
|
||||
}
|
||||
|
||||
# 必须有graph_id才能进行搜索
|
||||
results: Dict[str, Any] = {"facts": [], "node_summaries": [], "context": ""}
|
||||
|
||||
if not self.graph_id:
|
||||
logger.debug(f"跳过Zep检索:未设置graph_id")
|
||||
logger.debug("跳过本地检索:未设置graph_id")
|
||||
return results
|
||||
|
||||
comprehensive_query = t('progress.zepSearchQuery', name=entity_name)
|
||||
|
||||
def search_edges():
|
||||
"""搜索边(事实/关系)- 带重试机制"""
|
||||
max_retries = 3
|
||||
last_exception = None
|
||||
delay = 2.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=30,
|
||||
scope="edges",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||
return None
|
||||
|
||||
def search_nodes():
|
||||
"""搜索节点(实体摘要)- 带重试机制"""
|
||||
max_retries = 3
|
||||
last_exception = None
|
||||
delay = 2.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return self.zep_client.graph.search(
|
||||
query=comprehensive_query,
|
||||
graph_id=self.graph_id,
|
||||
limit=20,
|
||||
scope="nodes",
|
||||
reranker="rrf"
|
||||
)
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
entity_name = entity.name
|
||||
query = t('progress.zepSearchQuery', name=entity_name)
|
||||
|
||||
try:
|
||||
# 并行执行edges和nodes搜索
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
|
||||
edge_future = executor.submit(search_edges)
|
||||
node_future = executor.submit(search_nodes)
|
||||
|
||||
# 获取结果
|
||||
edge_result = edge_future.result(timeout=30)
|
||||
node_result = node_future.result(timeout=30)
|
||||
|
||||
# 处理边搜索结果
|
||||
all_facts = set()
|
||||
if edge_result and hasattr(edge_result, 'edges') and edge_result.edges:
|
||||
for edge in edge_result.edges:
|
||||
if hasattr(edge, 'fact') and edge.fact:
|
||||
all_facts.add(edge.fact)
|
||||
results["facts"] = list(all_facts)
|
||||
|
||||
# 处理节点搜索结果
|
||||
all_summaries = set()
|
||||
if node_result and hasattr(node_result, 'nodes') and node_result.nodes:
|
||||
for node in node_result.nodes:
|
||||
if hasattr(node, 'summary') and node.summary:
|
||||
all_summaries.add(node.summary)
|
||||
if hasattr(node, 'name') and node.name and node.name != entity_name:
|
||||
all_summaries.add(f"相关实体: {node.name}")
|
||||
results["node_summaries"] = list(all_summaries)
|
||||
|
||||
# 构建综合上下文
|
||||
# 搜索边(事实)
|
||||
edge_raw = self.store.search(self.graph_id, query, limit=30, scope="edges")
|
||||
facts = list({e.get("fact", "") for e in edge_raw.get("edges", []) if e.get("fact")})
|
||||
results["facts"] = facts
|
||||
|
||||
# 搜索节点(摘要)
|
||||
node_raw = self.store.search(self.graph_id, query, limit=20, scope="nodes")
|
||||
summaries = set()
|
||||
for n in node_raw.get("nodes", []):
|
||||
if n.get("summary"):
|
||||
summaries.add(n["summary"])
|
||||
if n.get("name") and n["name"] != entity_name:
|
||||
summaries.add(f"相关实体: {n['name']}")
|
||||
results["node_summaries"] = list(summaries)
|
||||
|
||||
# 构建上下文
|
||||
context_parts = []
|
||||
if results["facts"]:
|
||||
context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20]))
|
||||
if results["node_summaries"]:
|
||||
context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10]))
|
||||
results["context"] = "\n\n".join(context_parts)
|
||||
|
||||
logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点")
|
||||
|
||||
except concurrent.futures.TimeoutError:
|
||||
logger.warning(f"Zep检索超时 ({entity_name})")
|
||||
|
||||
logger.info(f"本地检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, "
|
||||
f"{len(results['node_summaries'])} 个相关节点")
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"Zep检索失败 ({entity_name}): {e}")
|
||||
|
||||
logger.warning(f"本地检索失败 ({entity_name}): {e}")
|
||||
|
||||
return results
|
||||
|
||||
def _build_entity_context(self, entity: EntityNode) -> str:
|
||||
|
|
|
|||
|
|
@ -413,8 +413,8 @@ class OntologyGenerator:
|
|||
'由MiroFish自动生成,用于社会舆论模拟',
|
||||
'"""',
|
||||
'',
|
||||
'from pydantic import Field',
|
||||
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
|
||||
'from typing import Optional',
|
||||
'from pydantic import BaseModel, Field',
|
||||
'',
|
||||
'',
|
||||
'# ============== 实体类型定义 ==============',
|
||||
|
|
@ -426,15 +426,15 @@ class OntologyGenerator:
|
|||
name = entity["name"]
|
||||
desc = entity.get("description", f"A {name} entity.")
|
||||
|
||||
code_lines.append(f'class {name}(EntityModel):')
|
||||
code_lines.append(f'class {name}(BaseModel):')
|
||||
code_lines.append(f' """{desc}"""')
|
||||
|
||||
|
||||
attrs = entity.get("attributes", [])
|
||||
if attrs:
|
||||
for attr in attrs:
|
||||
attr_name = attr["name"]
|
||||
attr_desc = attr.get("description", attr_name)
|
||||
code_lines.append(f' {attr_name}: EntityText = Field(')
|
||||
code_lines.append(f' {attr_name}: Optional[str] = Field(')
|
||||
code_lines.append(f' description="{attr_desc}",')
|
||||
code_lines.append(f' default=None')
|
||||
code_lines.append(f' )')
|
||||
|
|
@ -454,15 +454,15 @@ class OntologyGenerator:
|
|||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
desc = edge.get("description", f"A {name} relationship.")
|
||||
|
||||
code_lines.append(f'class {class_name}(EdgeModel):')
|
||||
code_lines.append(f'class {class_name}(BaseModel):')
|
||||
code_lines.append(f' """{desc}"""')
|
||||
|
||||
|
||||
attrs = edge.get("attributes", [])
|
||||
if attrs:
|
||||
for attr in attrs:
|
||||
attr_name = attr["name"]
|
||||
attr_desc = attr.get("description", attr_name)
|
||||
code_lines.append(f' {attr_name}: EntityText = Field(')
|
||||
code_lines.append(f' {attr_name}: Optional[str] = Field(')
|
||||
code_lines.append(f' description="{attr_desc}",')
|
||||
code_lines.append(f' default=None')
|
||||
code_lines.append(f' )')
|
||||
|
|
|
|||
|
|
@ -1,23 +1,17 @@
|
|||
"""
|
||||
Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
实体读取与过滤服务
|
||||
从本地JSON图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
|
||||
from typing import Dict, Any, List, Optional, Set
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.local_graph_store import LocalGraphStore
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
# 用于泛型返回类型
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityNode:
|
||||
|
|
@ -27,11 +21,9 @@ class EntityNode:
|
|||
labels: List[str]
|
||||
summary: str
|
||||
attributes: Dict[str, Any]
|
||||
# 相关的边信息
|
||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# 相关的其他节点信息
|
||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
|
|
@ -42,11 +34,11 @@ class EntityNode:
|
|||
"related_edges": self.related_edges,
|
||||
"related_nodes": self.related_nodes,
|
||||
}
|
||||
|
||||
|
||||
def get_entity_type(self) -> Optional[str]:
|
||||
"""获取实体类型(排除默认的Entity标签)"""
|
||||
"""获取实体类型(排除默认的Entity/Node标签)"""
|
||||
for label in self.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
if label not in ("Entity", "Node"):
|
||||
return label
|
||||
return None
|
||||
|
||||
|
|
@ -58,7 +50,7 @@ class FilteredEntities:
|
|||
entity_types: Set[str]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entities": [e.to_dict() for e in self.entities],
|
||||
|
|
@ -70,368 +62,215 @@ class FilteredEntities:
|
|||
|
||||
class ZepEntityReader:
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
|
||||
实体读取与过滤服务
|
||||
|
||||
主要功能:
|
||||
1. 从Zep图谱读取所有节点
|
||||
1. 从本地图谱读取所有节点
|
||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||
3. 获取每个实体的相关边和关联节点信息
|
||||
"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
def _call_with_retry(
|
||||
self,
|
||||
func: Callable[[], T],
|
||||
operation_name: str,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 2.0
|
||||
) -> T:
|
||||
"""
|
||||
带重试机制的Zep API调用
|
||||
|
||||
Args:
|
||||
func: 要执行的函数(无参数的lambda或callable)
|
||||
operation_name: 操作名称,用于日志
|
||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
||||
initial_delay: 初始延迟秒数
|
||||
|
||||
Returns:
|
||||
API调用结果
|
||||
"""
|
||||
last_exception = None
|
||||
delay = initial_delay
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func()
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
||||
f"{delay:.1f}秒后重试..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay *= 2 # 指数退避
|
||||
else:
|
||||
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
||||
|
||||
raise last_exception
|
||||
|
||||
|
||||
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||
self.store = LocalGraphStore(storage_dir)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
"""获取图谱的所有节点"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
nodes_data.append({
|
||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
"name": node.name or "",
|
||||
"labels": node.labels or [],
|
||||
"summary": node.summary or "",
|
||||
"attributes": node.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
||||
return nodes_data
|
||||
nodes = self.store.get_nodes(graph_id)
|
||||
logger.info(f"共获取 {len(nodes)} 个节点")
|
||||
return nodes
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
"""获取图谱的所有边"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
edges = self.store.get_edges(graph_id)
|
||||
logger.info(f"共获取 {len(edges)} 条边")
|
||||
return edges
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
||||
return edges_data
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定节点的所有相关边(带重试机制)
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
"""
|
||||
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""获取指定节点的所有相关边"""
|
||||
try:
|
||||
# 使用重试机制调用Zep API
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": edge.name or "",
|
||||
"fact": edge.fact or "",
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
return edges_data
|
||||
return self.store.get_node_edges(graph_id, node_uuid)
|
||||
except Exception as e:
|
||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def filter_defined_entities(
|
||||
self,
|
||||
self,
|
||||
graph_id: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
enrich_with_edges: bool = True
|
||||
) -> FilteredEntities:
|
||||
"""
|
||||
筛选出符合预定义实体类型的节点
|
||||
|
||||
|
||||
筛选逻辑:
|
||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
||||
|
||||
- 节点的Labels包含除"Entity"和"Node"之外的标签 → 符合预定义类型,保留
|
||||
- 节点的Labels只有"Entity"/"Node" → 不符合,跳过
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
||||
defined_entity_types: 预定义实体类型列表(可选,若提供则只保留这些类型)
|
||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
||||
|
||||
Returns:
|
||||
FilteredEntities: 过滤后的实体集合
|
||||
"""
|
||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
||||
|
||||
# 获取所有节点
|
||||
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
total_count = len(all_nodes)
|
||||
|
||||
# 获取所有边(用于后续关联查找)
|
||||
|
||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||
|
||||
# 构建节点UUID到节点数据的映射
|
||||
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 筛选符合条件的实体
|
||||
|
||||
filtered_entities = []
|
||||
entity_types_found = set()
|
||||
|
||||
entity_types_found: Set[str] = set()
|
||||
|
||||
for node in all_nodes:
|
||||
labels = node.get("labels", [])
|
||||
|
||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||
|
||||
labels = node.get("labels") or []
|
||||
custom_labels = [l for l in labels if l not in ("Entity", "Node")]
|
||||
|
||||
if not custom_labels:
|
||||
# 只有默认标签,跳过
|
||||
continue
|
||||
|
||||
# 如果指定了预定义类型,检查是否匹配
|
||||
|
||||
if defined_entity_types:
|
||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||
if not matching_labels:
|
||||
matching = [l for l in custom_labels if l in defined_entity_types]
|
||||
if not matching:
|
||||
continue
|
||||
entity_type = matching_labels[0]
|
||||
entity_type = matching[0]
|
||||
else:
|
||||
entity_type = custom_labels[0]
|
||||
|
||||
|
||||
entity_types_found.add(entity_type)
|
||||
|
||||
# 创建实体节点对象
|
||||
|
||||
entity = EntityNode(
|
||||
uuid=node["uuid"],
|
||||
name=node["name"],
|
||||
name=node.get("name", ""),
|
||||
labels=labels,
|
||||
summary=node["summary"],
|
||||
attributes=node["attributes"],
|
||||
summary=node.get("summary", ""),
|
||||
attributes=node.get("attributes", {}),
|
||||
)
|
||||
|
||||
# 获取相关边和节点
|
||||
|
||||
if enrich_with_edges:
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
related_node_uuids: Set[str] = set()
|
||||
|
||||
for edge in all_edges:
|
||||
if edge["source_node_uuid"] == node["uuid"]:
|
||||
if edge.get("source_node_uuid") == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
"edge_name": edge.get("name", ""),
|
||||
"fact": edge.get("fact", ""),
|
||||
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
elif edge["target_node_uuid"] == node["uuid"]:
|
||||
related_node_uuids.add(edge.get("target_node_uuid", ""))
|
||||
elif edge.get("target_node_uuid") == node["uuid"]:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
"edge_name": edge.get("name", ""),
|
||||
"fact": edge.get("fact", ""),
|
||||
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
related_node_uuids.add(edge.get("source_node_uuid", ""))
|
||||
|
||||
entity.related_edges = related_edges
|
||||
|
||||
# 获取关联节点的基本信息
|
||||
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
if related_uuid and related_uuid in node_map:
|
||||
rn = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
"uuid": rn["uuid"],
|
||||
"name": rn.get("name", ""),
|
||||
"labels": rn.get("labels", []),
|
||||
"summary": rn.get("summary", ""),
|
||||
})
|
||||
|
||||
entity.related_nodes = related_nodes
|
||||
|
||||
|
||||
filtered_entities.append(entity)
|
||||
|
||||
|
||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
||||
f"实体类型: {entity_types_found}")
|
||||
|
||||
f"实体类型: {entity_types_found}")
|
||||
|
||||
return FilteredEntities(
|
||||
entities=filtered_entities,
|
||||
entity_types=entity_types_found,
|
||||
total_count=total_count,
|
||||
filtered_count=len(filtered_entities),
|
||||
)
|
||||
|
||||
|
||||
def get_entity_with_context(
|
||||
self,
|
||||
graph_id: str,
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_uuid: str
|
||||
) -> Optional[EntityNode]:
|
||||
"""
|
||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_uuid: 实体UUID
|
||||
|
||||
Returns:
|
||||
EntityNode或None
|
||||
"""
|
||||
"""获取单个实体及其完整上下文(边和关联节点)"""
|
||||
try:
|
||||
# 使用重试机制获取节点
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
node = self.store.get_node(graph_id, entity_uuid)
|
||||
if not node:
|
||||
return None
|
||||
|
||||
# 获取节点的边
|
||||
edges = self.get_node_edges(entity_uuid)
|
||||
|
||||
# 获取所有节点用于关联查找
|
||||
|
||||
edges = self.get_node_edges(graph_id, entity_uuid)
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 处理相关边和节点
|
||||
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
related_node_uuids: Set[str] = set()
|
||||
|
||||
for edge in edges:
|
||||
if edge["source_node_uuid"] == entity_uuid:
|
||||
if edge.get("source_node_uuid") == entity_uuid:
|
||||
related_edges.append({
|
||||
"direction": "outgoing",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"target_node_uuid": edge["target_node_uuid"],
|
||||
"edge_name": edge.get("name", ""),
|
||||
"fact": edge.get("fact", ""),
|
||||
"target_node_uuid": edge.get("target_node_uuid", ""),
|
||||
})
|
||||
related_node_uuids.add(edge["target_node_uuid"])
|
||||
related_node_uuids.add(edge.get("target_node_uuid", ""))
|
||||
else:
|
||||
related_edges.append({
|
||||
"direction": "incoming",
|
||||
"edge_name": edge["name"],
|
||||
"fact": edge["fact"],
|
||||
"source_node_uuid": edge["source_node_uuid"],
|
||||
"edge_name": edge.get("name", ""),
|
||||
"fact": edge.get("fact", ""),
|
||||
"source_node_uuid": edge.get("source_node_uuid", ""),
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
# 获取关联节点信息
|
||||
related_node_uuids.add(edge.get("source_node_uuid", ""))
|
||||
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
related_node = node_map[related_uuid]
|
||||
if related_uuid and related_uuid in node_map:
|
||||
rn = node_map[related_uuid]
|
||||
related_nodes.append({
|
||||
"uuid": related_node["uuid"],
|
||||
"name": related_node["name"],
|
||||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
"uuid": rn["uuid"],
|
||||
"name": rn.get("name", ""),
|
||||
"labels": rn.get("labels", []),
|
||||
"summary": rn.get("summary", ""),
|
||||
})
|
||||
|
||||
|
||||
return EntityNode(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {},
|
||||
uuid=node["uuid"],
|
||||
name=node.get("name", ""),
|
||||
labels=node.get("labels", []),
|
||||
summary=node.get("summary", ""),
|
||||
attributes=node.get("attributes", {}),
|
||||
related_edges=related_edges,
|
||||
related_nodes=related_nodes,
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
||||
logger.error(f"获取实体 {entity_uuid} 失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_entities_by_type(
|
||||
self,
|
||||
graph_id: str,
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_type: str,
|
||||
enrich_with_edges: bool = True
|
||||
) -> List[EntityNode]:
|
||||
"""
|
||||
获取指定类型的所有实体
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
||||
enrich_with_edges: 是否获取相关边信息
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
"""
|
||||
"""获取指定类型的所有实体"""
|
||||
result = self.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
defined_entity_types=[entity_type],
|
||||
enrich_with_edges=enrich_with_edges
|
||||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Zep图谱记忆更新服务
|
||||
将模拟中的Agent活动动态更新到Zep图谱中
|
||||
图谱记忆更新服务
|
||||
将模拟中的Agent活动动态写入本地JSON图谱文件
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -12,9 +12,8 @@ from dataclasses import dataclass
|
|||
from datetime import datetime
|
||||
from queue import Queue, Empty
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.local_graph_store import LocalGraphStore
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.locale import get_locale, set_locale
|
||||
|
||||
|
|
@ -31,15 +30,14 @@ class AgentActivity:
|
|||
action_args: Dict[str, Any]
|
||||
round_num: int
|
||||
timestamp: str
|
||||
|
||||
|
||||
def to_episode_text(self) -> str:
|
||||
"""
|
||||
将活动转换为可以发送给Zep的文本描述
|
||||
|
||||
采用自然语言描述格式,让Zep能够从中提取实体和关系
|
||||
将活动转换为自然语言描述文本
|
||||
|
||||
采用自然语言描述格式,让图谱能够从中提取实体和关系
|
||||
不添加模拟相关的前缀,避免误导图谱更新
|
||||
"""
|
||||
# 根据不同的动作类型生成不同的描述
|
||||
action_descriptions = {
|
||||
"CREATE_POST": self._describe_create_post,
|
||||
"LIKE_POST": self._describe_like_post,
|
||||
|
|
@ -54,24 +52,22 @@ class AgentActivity:
|
|||
"SEARCH_USER": self._describe_search_user,
|
||||
"MUTE": self._describe_mute,
|
||||
}
|
||||
|
||||
|
||||
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
||||
description = describe_func()
|
||||
|
||||
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀
|
||||
|
||||
return f"{self.agent_name}: {description}"
|
||||
|
||||
|
||||
def _describe_create_post(self) -> str:
|
||||
content = self.action_args.get("content", "")
|
||||
if content:
|
||||
return f"发布了一条帖子:「{content}」"
|
||||
return "发布了一条帖子"
|
||||
|
||||
|
||||
def _describe_like_post(self) -> str:
|
||||
"""点赞帖子 - 包含帖子原文和作者信息"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if post_content and post_author:
|
||||
return f"点赞了{post_author}的帖子:「{post_content}」"
|
||||
elif post_content:
|
||||
|
|
@ -79,12 +75,11 @@ class AgentActivity:
|
|||
elif post_author:
|
||||
return f"点赞了{post_author}的一条帖子"
|
||||
return "点赞了一条帖子"
|
||||
|
||||
|
||||
def _describe_dislike_post(self) -> str:
|
||||
"""踩帖子 - 包含帖子原文和作者信息"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if post_content and post_author:
|
||||
return f"踩了{post_author}的帖子:「{post_content}」"
|
||||
elif post_content:
|
||||
|
|
@ -92,12 +87,11 @@ class AgentActivity:
|
|||
elif post_author:
|
||||
return f"踩了{post_author}的一条帖子"
|
||||
return "踩了一条帖子"
|
||||
|
||||
|
||||
def _describe_repost(self) -> str:
|
||||
"""转发帖子 - 包含原帖内容和作者信息"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
|
||||
|
||||
if original_content and original_author:
|
||||
return f"转发了{original_author}的帖子:「{original_content}」"
|
||||
elif original_content:
|
||||
|
|
@ -105,13 +99,12 @@ class AgentActivity:
|
|||
elif original_author:
|
||||
return f"转发了{original_author}的一条帖子"
|
||||
return "转发了一条帖子"
|
||||
|
||||
|
||||
def _describe_quote_post(self) -> str:
|
||||
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
||||
|
||||
|
||||
base = ""
|
||||
if original_content and original_author:
|
||||
base = f"引用了{original_author}的帖子「{original_content}」"
|
||||
|
|
@ -121,25 +114,22 @@ class AgentActivity:
|
|||
base = f"引用了{original_author}的一条帖子"
|
||||
else:
|
||||
base = "引用了一条帖子"
|
||||
|
||||
|
||||
if quote_content:
|
||||
base += f",并评论道:「{quote_content}」"
|
||||
return base
|
||||
|
||||
|
||||
def _describe_follow(self) -> str:
|
||||
"""关注用户 - 包含被关注用户的名称"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
if target_user_name:
|
||||
return f"关注了用户「{target_user_name}」"
|
||||
return "关注了一个用户"
|
||||
|
||||
|
||||
def _describe_create_comment(self) -> str:
|
||||
"""发表评论 - 包含评论内容和所评论的帖子信息"""
|
||||
content = self.action_args.get("content", "")
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if content:
|
||||
if post_content and post_author:
|
||||
return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
|
||||
|
|
@ -149,12 +139,11 @@ class AgentActivity:
|
|||
return f"在{post_author}的帖子下评论道:「{content}」"
|
||||
return f"评论道:「{content}」"
|
||||
return "发表了评论"
|
||||
|
||||
|
||||
def _describe_like_comment(self) -> str:
|
||||
"""点赞评论 - 包含评论内容和作者信息"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"点赞了{comment_author}的评论:「{comment_content}」"
|
||||
elif comment_content:
|
||||
|
|
@ -162,12 +151,11 @@ class AgentActivity:
|
|||
elif comment_author:
|
||||
return f"点赞了{comment_author}的一条评论"
|
||||
return "点赞了一条评论"
|
||||
|
||||
|
||||
def _describe_dislike_comment(self) -> str:
|
||||
"""踩评论 - 包含评论内容和作者信息"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"踩了{comment_author}的评论:「{comment_content}」"
|
||||
elif comment_content:
|
||||
|
|
@ -175,109 +163,83 @@ class AgentActivity:
|
|||
elif comment_author:
|
||||
return f"踩了{comment_author}的一条评论"
|
||||
return "踩了一条评论"
|
||||
|
||||
|
||||
def _describe_search(self) -> str:
|
||||
"""搜索帖子 - 包含搜索关键词"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
||||
return f"搜索了「{query}」" if query else "进行了搜索"
|
||||
|
||||
|
||||
def _describe_search_user(self) -> str:
|
||||
"""搜索用户 - 包含搜索关键词"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
||||
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
||||
|
||||
|
||||
def _describe_mute(self) -> str:
|
||||
"""屏蔽用户 - 包含被屏蔽用户的名称"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
if target_user_name:
|
||||
return f"屏蔽了用户「{target_user_name}」"
|
||||
return "屏蔽了一个用户"
|
||||
|
||||
|
||||
def _describe_generic(self) -> str:
|
||||
# 对于未知的动作类型,生成通用描述
|
||||
return f"执行了{self.action_type}操作"
|
||||
|
||||
|
||||
class ZepGraphMemoryUpdater:
|
||||
class GraphMemoryUpdater:
|
||||
"""
|
||||
Zep图谱记忆更新器
|
||||
|
||||
监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。
|
||||
按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。
|
||||
|
||||
所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息:
|
||||
- 点赞/踩的帖子原文
|
||||
- 转发/引用的帖子原文
|
||||
- 关注/屏蔽的用户名
|
||||
- 点赞/踩的评论原文
|
||||
图谱记忆更新器
|
||||
|
||||
监控模拟的actions日志文件,将新的agent活动实时写入本地图谱。
|
||||
按平台分组,每累积BATCH_SIZE条活动后批量写入。
|
||||
"""
|
||||
|
||||
# 批量发送大小(每个平台累积多少条后发送)
|
||||
|
||||
BATCH_SIZE = 5
|
||||
|
||||
# 平台名称映射(用于控制台显示)
|
||||
|
||||
PLATFORM_DISPLAY_NAMES = {
|
||||
'twitter': '世界1',
|
||||
'reddit': '世界2',
|
||||
}
|
||||
|
||||
# 发送间隔(秒),避免请求过快
|
||||
SEND_INTERVAL = 0.5
|
||||
|
||||
# 重试配置
|
||||
|
||||
SEND_INTERVAL = 0.1 # 本地写入更快,间隔可以更短
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2 # 秒
|
||||
|
||||
def __init__(self, graph_id: str, api_key: Optional[str] = None):
|
||||
RETRY_DELAY = 1
|
||||
|
||||
def __init__(self, graph_id: str, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
|
||||
"""
|
||||
初始化更新器
|
||||
|
||||
|
||||
Args:
|
||||
graph_id: Zep图谱ID
|
||||
api_key: Zep API Key(可选,默认从配置读取)
|
||||
graph_id: 本地图谱ID
|
||||
storage_dir: 图谱存储目录(可选,默认从配置读取)
|
||||
api_key: 已废弃,保留以兼容旧调用代码
|
||||
"""
|
||||
self.graph_id = graph_id
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
# 活动队列
|
||||
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||
self.store = LocalGraphStore(storage_dir)
|
||||
|
||||
self._activity_queue: Queue = Queue()
|
||||
|
||||
# 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送)
|
||||
self._platform_buffers: Dict[str, List[AgentActivity]] = {
|
||||
'twitter': [],
|
||||
'reddit': [],
|
||||
}
|
||||
self._buffer_lock = threading.Lock()
|
||||
|
||||
# 控制标志
|
||||
|
||||
self._running = False
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
|
||||
# 统计
|
||||
self._total_activities = 0 # 实际添加到队列的活动数
|
||||
self._total_sent = 0 # 成功发送到Zep的批次数
|
||||
self._total_items_sent = 0 # 成功发送到Zep的活动条数
|
||||
self._failed_count = 0 # 发送失败的批次数
|
||||
self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||
|
||||
|
||||
self._total_activities = 0
|
||||
self._total_sent = 0
|
||||
self._total_items_sent = 0
|
||||
self._failed_count = 0
|
||||
self._skipped_count = 0
|
||||
|
||||
logger.info(f"GraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||
|
||||
def _get_platform_display_name(self, platform: str) -> str:
|
||||
"""获取平台的显示名称"""
|
||||
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动后台工作线程"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
# Capture locale before spawning background thread
|
||||
current_locale = get_locale()
|
||||
|
||||
self._running = True
|
||||
|
|
@ -285,70 +247,42 @@ class ZepGraphMemoryUpdater:
|
|||
target=self._worker_loop,
|
||||
args=(current_locale,),
|
||||
daemon=True,
|
||||
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
|
||||
name=f"GraphMemoryUpdater-{self.graph_id[:8]}"
|
||||
)
|
||||
self._worker_thread.start()
|
||||
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
||||
|
||||
logger.info(f"GraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
||||
|
||||
def stop(self):
|
||||
"""停止后台工作线程"""
|
||||
self._running = False
|
||||
|
||||
# 发送剩余的活动
|
||||
|
||||
self._flush_remaining()
|
||||
|
||||
|
||||
if self._worker_thread and self._worker_thread.is_alive():
|
||||
self._worker_thread.join(timeout=10)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
||||
f"total_activities={self._total_activities}, "
|
||||
f"batches_sent={self._total_sent}, "
|
||||
f"items_sent={self._total_items_sent}, "
|
||||
f"failed={self._failed_count}, "
|
||||
f"skipped={self._skipped_count}")
|
||||
|
||||
|
||||
logger.info(f"GraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
||||
f"total_activities={self._total_activities}, "
|
||||
f"batches_sent={self._total_sent}, "
|
||||
f"items_sent={self._total_items_sent}, "
|
||||
f"failed={self._failed_count}, "
|
||||
f"skipped={self._skipped_count}")
|
||||
|
||||
def add_activity(self, activity: AgentActivity):
|
||||
"""
|
||||
添加一个agent活动到队列
|
||||
|
||||
所有有意义的行为都会被添加到队列,包括:
|
||||
- CREATE_POST(发帖)
|
||||
- CREATE_COMMENT(评论)
|
||||
- QUOTE_POST(引用帖子)
|
||||
- SEARCH_POSTS(搜索帖子)
|
||||
- SEARCH_USER(搜索用户)
|
||||
- LIKE_POST/DISLIKE_POST(点赞/踩帖子)
|
||||
- REPOST(转发)
|
||||
- FOLLOW(关注)
|
||||
- MUTE(屏蔽)
|
||||
- LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论)
|
||||
|
||||
action_args中会包含完整的上下文信息(如帖子原文、用户名等)。
|
||||
|
||||
Args:
|
||||
activity: Agent活动记录
|
||||
"""
|
||||
# 跳过DO_NOTHING类型的活动
|
||||
"""添加一个agent活动到队列"""
|
||||
if activity.action_type == "DO_NOTHING":
|
||||
self._skipped_count += 1
|
||||
return
|
||||
|
||||
|
||||
self._activity_queue.put(activity)
|
||||
self._total_activities += 1
|
||||
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")
|
||||
|
||||
logger.debug(f"添加活动到队列: {activity.agent_name} - {activity.action_type}")
|
||||
|
||||
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
||||
"""
|
||||
从字典数据添加活动
|
||||
|
||||
Args:
|
||||
data: 从actions.jsonl解析的字典数据
|
||||
platform: 平台名称 (twitter/reddit)
|
||||
"""
|
||||
# 跳过事件类型的条目
|
||||
"""从字典数据添加活动"""
|
||||
if "event_type" in data:
|
||||
return
|
||||
|
||||
|
||||
activity = AgentActivity(
|
||||
platform=platform,
|
||||
agent_id=data.get("agent_id", 0),
|
||||
|
|
@ -358,83 +292,94 @@ class ZepGraphMemoryUpdater:
|
|||
round_num=data.get("round", 0),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat()),
|
||||
)
|
||||
|
||||
|
||||
self.add_activity(activity)
|
||||
|
||||
|
||||
def _worker_loop(self, locale: str = 'zh'):
|
||||
"""后台工作循环 - 按平台批量发送活动到Zep"""
|
||||
"""后台工作循环 - 按平台批量写入活动"""
|
||||
set_locale(locale)
|
||||
while self._running or not self._activity_queue.empty():
|
||||
try:
|
||||
# 尝试从队列获取活动(超时1秒)
|
||||
try:
|
||||
activity = self._activity_queue.get(timeout=1)
|
||||
|
||||
# 将活动添加到对应平台的缓冲区
|
||||
|
||||
platform = activity.platform.lower()
|
||||
with self._buffer_lock:
|
||||
if platform not in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
self._platform_buffers[platform].append(activity)
|
||||
|
||||
# 检查该平台是否达到批量大小
|
||||
|
||||
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
|
||||
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
|
||||
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
|
||||
# 释放锁后再发送
|
||||
self._send_batch_activities(batch, platform)
|
||||
# 发送间隔,避免请求过快
|
||||
self._write_batch_activities(batch, platform)
|
||||
time.sleep(self.SEND_INTERVAL)
|
||||
|
||||
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作循环异常: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||
"""
|
||||
批量发送活动到Zep图谱(合并为一条文本)
|
||||
|
||||
Args:
|
||||
activities: Agent活动列表
|
||||
platform: 平台名称
|
||||
"""
|
||||
|
||||
def _write_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||
"""批量将活动写入本地图谱"""
|
||||
if not activities:
|
||||
return
|
||||
|
||||
# 将多条活动合并为一条文本,用换行分隔
|
||||
|
||||
episode_texts = [activity.to_episode_text() for activity in activities]
|
||||
combined_text = "\n".join(episode_texts)
|
||||
|
||||
# 带重试的发送
|
||||
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
self.client.graph.add(
|
||||
graph_id=self.graph_id,
|
||||
type="text",
|
||||
data=combined_text
|
||||
)
|
||||
|
||||
# 写入情节文本
|
||||
self.store.add_episode(self.graph_id, combined_text)
|
||||
|
||||
# 为每条活动创建可搜索的事实边
|
||||
for activity in activities:
|
||||
self._create_activity_edge(activity)
|
||||
|
||||
self._total_sent += 1
|
||||
self._total_items_sent += len(activities)
|
||||
display_name = self._get_platform_display_name(platform)
|
||||
logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
||||
logger.debug(f"批量内容预览: {combined_text[:200]}...")
|
||||
logger.info(f"成功写入 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
||||
return
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
logger.warning(f"写入活动失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
||||
else:
|
||||
logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||
logger.error(f"写入活动失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||
self._failed_count += 1
|
||||
|
||||
|
||||
def _create_activity_edge(self, activity: AgentActivity):
    """Record one activity in the graph as a searchable fact edge.

    The agent's node is looked up (or created) by name through
    ``self.store.upsert_node``; the activity is then stored as an edge whose
    source and target are both that node, with the activity's episode text
    as the edge ``fact``.

    Args:
        activity: The agent activity to persist.
    """
    store = self.store
    graph_id = self.graph_id
    fact_text = activity.to_episode_text()

    # Create or fetch the node that represents the acting agent.
    agent_name = activity.agent_name
    node_id = store.upsert_node(
        graph_id=graph_id,
        name=agent_name,
        labels=["Agent", "Entity"],
        summary=f"Agent {activity.agent_name}",
    )

    # Self-loop edge: it only exists so keyword search can find the fact.
    edge_attributes = {
        "platform": activity.platform,
        "round": activity.round_num,
        "timestamp": activity.timestamp,
    }
    edge_payload = {
        "name": activity.action_type,
        "fact": fact_text,
        "source_node_uuid": node_id,
        "target_node_uuid": node_id,
        "attributes": edge_attributes,
    }
    store.add_edge(graph_id, edge_payload)
|
||||
|
||||
def _flush_remaining(self):
|
||||
"""发送队列和缓冲区中剩余的活动"""
|
||||
# 首先处理队列中剩余的活动,添加到缓冲区
|
||||
while not self._activity_queue.empty():
|
||||
try:
|
||||
activity = self._activity_queue.get_nowait()
|
||||
|
|
@ -445,96 +390,83 @@ class ZepGraphMemoryUpdater:
|
|||
self._platform_buffers[platform].append(activity)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条)
|
||||
|
||||
with self._buffer_lock:
|
||||
for platform, buffer in self._platform_buffers.items():
|
||||
if buffer:
|
||||
display_name = self._get_platform_display_name(platform)
|
||||
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
|
||||
self._send_batch_activities(buffer, platform)
|
||||
# 清空所有缓冲区
|
||||
self._write_batch_activities(buffer, platform)
|
||||
for platform in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
with self._buffer_lock:
|
||||
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
|
||||
|
||||
|
||||
return {
|
||||
"graph_id": self.graph_id,
|
||||
"batch_size": self.BATCH_SIZE,
|
||||
"total_activities": self._total_activities, # 添加到队列的活动总数
|
||||
"batches_sent": self._total_sent, # 成功发送的批次数
|
||||
"items_sent": self._total_items_sent, # 成功发送的活动条数
|
||||
"failed_count": self._failed_count, # 发送失败的批次数
|
||||
"skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING)
|
||||
"total_activities": self._total_activities,
|
||||
"batches_sent": self._total_sent,
|
||||
"items_sent": self._total_items_sent,
|
||||
"failed_count": self._failed_count,
|
||||
"skipped_count": self._skipped_count,
|
||||
"queue_size": self._activity_queue.qsize(),
|
||||
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小
|
||||
"buffer_sizes": buffer_sizes,
|
||||
"running": self._running,
|
||||
}
|
||||
|
||||
|
||||
# 向后兼容别名
|
||||
ZepGraphMemoryUpdater = GraphMemoryUpdater
|
||||
|
||||
|
||||
class ZepGraphMemoryManager:
|
||||
"""
|
||||
管理多个模拟的Zep图谱记忆更新器
|
||||
|
||||
管理多个模拟的图谱记忆更新器
|
||||
|
||||
每个模拟可以有自己的更新器实例
|
||||
"""
|
||||
|
||||
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
|
||||
|
||||
_updaters: Dict[str, GraphMemoryUpdater] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
|
||||
"""
|
||||
为模拟创建图谱记忆更新器
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
graph_id: Zep图谱ID
|
||||
|
||||
Returns:
|
||||
ZepGraphMemoryUpdater实例
|
||||
"""
|
||||
def create_updater(cls, simulation_id: str, graph_id: str) -> GraphMemoryUpdater:
|
||||
"""为模拟创建图谱记忆更新器"""
|
||||
with cls._lock:
|
||||
# 如果已存在,先停止旧的
|
||||
if simulation_id in cls._updaters:
|
||||
cls._updaters[simulation_id].stop()
|
||||
|
||||
updater = ZepGraphMemoryUpdater(graph_id)
|
||||
|
||||
updater = GraphMemoryUpdater(graph_id)
|
||||
updater.start()
|
||||
cls._updaters[simulation_id] = updater
|
||||
|
||||
|
||||
logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||
return updater
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
|
||||
"""获取模拟的更新器"""
|
||||
def get_updater(cls, simulation_id: str) -> Optional[GraphMemoryUpdater]:
|
||||
return cls._updaters.get(simulation_id)
|
||||
|
||||
|
||||
@classmethod
def stop_updater(cls, simulation_id: str):
    """Stop the updater registered for ``simulation_id`` and unregister it.

    Does nothing when no updater is registered under that id. The registry
    is modified while holding ``cls._lock``.
    """
    with cls._lock:
        registry = cls._updaters
        if simulation_id not in registry:
            return
        # Stop first, then remove: if stop() raises, the entry stays registered.
        registry[simulation_id].stop()
        del registry[simulation_id]
        logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
|
||||
|
||||
# 防止 stop_all 重复调用的标志
|
||||
|
||||
_stop_all_done = False
|
||||
|
||||
|
||||
@classmethod
|
||||
def stop_all(cls):
|
||||
"""停止所有更新器"""
|
||||
# 防止重复调用
|
||||
if cls._stop_all_done:
|
||||
return
|
||||
cls._stop_all_done = True
|
||||
|
||||
|
||||
with cls._lock:
|
||||
if cls._updaters:
|
||||
for simulation_id, updater in list(cls._updaters.items()):
|
||||
|
|
@ -544,11 +476,10 @@ class ZepGraphMemoryManager:
|
|||
logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
|
||||
cls._updaters.clear()
|
||||
logger.info("已停止所有图谱记忆更新器")
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取所有更新器的统计信息"""
|
||||
return {
|
||||
sim_id: updater.get_stats()
|
||||
sim_id: updater.get_stats()
|
||||
for sim_id, updater in cls._updaters.items()
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,13 +13,11 @@ import json
|
|||
from typing import Dict, Any, List, Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from ..config import Config
|
||||
from ..utils.local_graph_store import LocalGraphStore
|
||||
from ..utils.logger import get_logger
|
||||
from ..utils.llm_client import LLMClient
|
||||
from ..utils.locale import get_locale, t
|
||||
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||
|
||||
logger = get_logger('mirofish.zep_tools')
|
||||
|
||||
|
|
@ -418,20 +416,14 @@ class ZepToolsService:
|
|||
- get_entity_summary - 获取实体的关系摘要
|
||||
"""
|
||||
|
||||
# 重试配置
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2.0
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
# LLM客户端用于InsightForge生成子问题
|
||||
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None,
|
||||
llm_client: Optional[LLMClient] = None):
|
||||
# api_key参数保留以兼容旧调用方式,但不再使用
|
||||
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
|
||||
self.store = LocalGraphStore(storage_dir)
|
||||
self._llm_client = llm_client
|
||||
logger.info(t("console.zepToolsInitialized"))
|
||||
|
||||
|
||||
@property
|
||||
def llm(self) -> LLMClient:
|
||||
"""延迟初始化LLM客户端"""
|
||||
|
|
@ -439,206 +431,50 @@ class ZepToolsService:
|
|||
self._llm_client = LLMClient()
|
||||
return self._llm_client
|
||||
|
||||
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
|
||||
"""带重试机制的API调用"""
|
||||
max_retries = max_retries or self.MAX_RETRIES
|
||||
last_exception = None
|
||||
delay = self.RETRY_DELAY
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func()
|
||||
except Exception as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
t("console.zepRetryAttempt", operation=operation_name, attempt=attempt + 1, error=str(e)[:100], delay=f"{delay:.1f}")
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.error(t("console.zepAllRetriesFailed", operation=operation_name, retries=max_retries, error=str(e)))
|
||||
|
||||
raise last_exception
|
||||
|
||||
def search_graph(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges"
|
||||
) -> SearchResult:
|
||||
"""
|
||||
图谱语义搜索
|
||||
|
||||
使用混合搜索(语义+BM25)在图谱中搜索相关信息。
|
||||
如果Zep Cloud的search API不可用,则降级为本地关键词匹配。
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID (Standalone Graph)
|
||||
query: 搜索查询
|
||||
limit: 返回结果数量
|
||||
scope: 搜索范围,"edges" 或 "nodes"
|
||||
|
||||
Returns:
|
||||
SearchResult: 搜索结果
|
||||
"""
|
||||
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
|
||||
|
||||
# 尝试使用Zep Cloud Search API
|
||||
try:
|
||||
search_results = self._call_with_retry(
|
||||
func=lambda: self.client.graph.search(
|
||||
graph_id=graph_id,
|
||||
query=query,
|
||||
limit=limit,
|
||||
scope=scope,
|
||||
reranker="cross_encoder"
|
||||
),
|
||||
operation_name=t("console.graphSearchOp", graphId=graph_id)
|
||||
)
|
||||
|
||||
facts = []
|
||||
edges = []
|
||||
nodes = []
|
||||
|
||||
# 解析边搜索结果
|
||||
if hasattr(search_results, 'edges') and search_results.edges:
|
||||
for edge in search_results.edges:
|
||||
if hasattr(edge, 'fact') and edge.fact:
|
||||
facts.append(edge.fact)
|
||||
edges.append({
|
||||
"uuid": getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', ''),
|
||||
"name": getattr(edge, 'name', ''),
|
||||
"fact": getattr(edge, 'fact', ''),
|
||||
"source_node_uuid": getattr(edge, 'source_node_uuid', ''),
|
||||
"target_node_uuid": getattr(edge, 'target_node_uuid', ''),
|
||||
})
|
||||
|
||||
# 解析节点搜索结果
|
||||
if hasattr(search_results, 'nodes') and search_results.nodes:
|
||||
for node in search_results.nodes:
|
||||
nodes.append({
|
||||
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
"name": getattr(node, 'name', ''),
|
||||
"labels": getattr(node, 'labels', []),
|
||||
"summary": getattr(node, 'summary', ''),
|
||||
})
|
||||
# 节点摘要也算作事实
|
||||
if hasattr(node, 'summary') and node.summary:
|
||||
facts.append(f"[{node.name}]: {node.summary}")
|
||||
|
||||
logger.info(t("console.searchComplete", count=len(facts)))
|
||||
|
||||
return SearchResult(
|
||||
facts=facts,
|
||||
edges=edges,
|
||||
nodes=nodes,
|
||||
query=query,
|
||||
total_count=len(facts)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(t("console.zepSearchApiFallback", error=str(e)))
|
||||
# 降级:使用本地关键词匹配搜索
|
||||
return self._local_search(graph_id, query, limit, scope)
|
||||
|
||||
def _local_search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges"
|
||||
) -> SearchResult:
|
||||
"""
|
||||
本地关键词匹配搜索(作为Zep Search API的降级方案)
|
||||
|
||||
获取所有边/节点,然后在本地进行关键词匹配
|
||||
|
||||
图谱关键词搜索
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
query: 搜索查询
|
||||
limit: 返回结果数量
|
||||
scope: 搜索范围
|
||||
|
||||
scope: 搜索范围,"edges" 或 "nodes"
|
||||
|
||||
Returns:
|
||||
SearchResult: 搜索结果
|
||||
"""
|
||||
logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50]))
|
||||
return self._local_search(graph_id, query, limit, scope)
|
||||
|
||||
def _local_search(
|
||||
self,
|
||||
graph_id: str,
|
||||
query: str,
|
||||
limit: int = 10,
|
||||
scope: str = "edges"
|
||||
) -> SearchResult:
|
||||
"""本地关键词匹配搜索"""
|
||||
logger.info(t("console.usingLocalSearch", query=query[:30]))
|
||||
|
||||
facts = []
|
||||
edges_result = []
|
||||
nodes_result = []
|
||||
|
||||
# 提取查询关键词(简单分词)
|
||||
query_lower = query.lower()
|
||||
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1]
|
||||
|
||||
def match_score(text: str) -> int:
|
||||
"""计算文本与查询的匹配分数"""
|
||||
if not text:
|
||||
return 0
|
||||
text_lower = text.lower()
|
||||
# 完全匹配查询
|
||||
if query_lower in text_lower:
|
||||
return 100
|
||||
# 关键词匹配
|
||||
score = 0
|
||||
for keyword in keywords:
|
||||
if keyword in text_lower:
|
||||
score += 10
|
||||
return score
|
||||
|
||||
|
||||
try:
|
||||
if scope in ["edges", "both"]:
|
||||
# 获取所有边并匹配
|
||||
all_edges = self.get_all_edges(graph_id)
|
||||
scored_edges = []
|
||||
for edge in all_edges:
|
||||
score = match_score(edge.fact) + match_score(edge.name)
|
||||
if score > 0:
|
||||
scored_edges.append((score, edge))
|
||||
|
||||
# 按分数排序
|
||||
scored_edges.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
for score, edge in scored_edges[:limit]:
|
||||
if edge.fact:
|
||||
facts.append(edge.fact)
|
||||
edges_result.append({
|
||||
"uuid": edge.uuid,
|
||||
"name": edge.name,
|
||||
"fact": edge.fact,
|
||||
"source_node_uuid": edge.source_node_uuid,
|
||||
"target_node_uuid": edge.target_node_uuid,
|
||||
})
|
||||
|
||||
if scope in ["nodes", "both"]:
|
||||
# 获取所有节点并匹配
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
scored_nodes = []
|
||||
for node in all_nodes:
|
||||
score = match_score(node.name) + match_score(node.summary)
|
||||
if score > 0:
|
||||
scored_nodes.append((score, node))
|
||||
|
||||
scored_nodes.sort(key=lambda x: x[0], reverse=True)
|
||||
|
||||
for score, node in scored_nodes[:limit]:
|
||||
nodes_result.append({
|
||||
"uuid": node.uuid,
|
||||
"name": node.name,
|
||||
"labels": node.labels,
|
||||
"summary": node.summary,
|
||||
})
|
||||
if node.summary:
|
||||
facts.append(f"[{node.name}]: {node.summary}")
|
||||
|
||||
raw = self.store.search(graph_id, query, limit=limit, scope=scope)
|
||||
|
||||
facts = raw.get("facts", [])
|
||||
edges_result = raw.get("edges", [])
|
||||
nodes_result = raw.get("nodes", [])
|
||||
|
||||
logger.info(t("console.localSearchComplete", count=len(facts)))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(t("console.localSearchFailed", error=str(e)))
|
||||
|
||||
facts, edges_result, nodes_result = [], [], []
|
||||
|
||||
return SearchResult(
|
||||
facts=facts,
|
||||
edges=edges_result,
|
||||
|
|
@ -648,99 +484,74 @@ class ZepToolsService:
|
|||
)
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
|
||||
"""
|
||||
获取图谱的所有节点(分页获取)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
"""
|
||||
"""获取图谱的所有节点"""
|
||||
logger.info(t("console.fetchingAllNodes", graphId=graph_id))
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
result = []
|
||||
for node in nodes:
|
||||
node_uuid = getattr(node, 'uuid_', None) or getattr(node, 'uuid', None) or ""
|
||||
result.append(NodeInfo(
|
||||
uuid=str(node_uuid) if node_uuid else "",
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {}
|
||||
))
|
||||
nodes = self.store.get_nodes(graph_id)
|
||||
result = [
|
||||
NodeInfo(
|
||||
uuid=n.get("uuid", ""),
|
||||
name=n.get("name", ""),
|
||||
labels=n.get("labels") or [],
|
||||
summary=n.get("summary", ""),
|
||||
attributes=n.get("attributes") or {},
|
||||
)
|
||||
for n in nodes
|
||||
]
|
||||
|
||||
logger.info(t("console.fetchedNodes", count=len(result)))
|
||||
return result
|
||||
|
||||
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
|
||||
"""
|
||||
获取图谱的所有边(分页获取,包含时间信息)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
include_temporal: 是否包含时间信息(默认True)
|
||||
|
||||
Returns:
|
||||
边列表(包含created_at, valid_at, invalid_at, expired_at)
|
||||
"""
|
||||
"""获取图谱的所有边(含时间信息)"""
|
||||
logger.info(t("console.fetchingAllEdges", graphId=graph_id))
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
edges = self.store.get_edges(graph_id)
|
||||
result = []
|
||||
for edge in edges:
|
||||
edge_uuid = getattr(edge, 'uuid_', None) or getattr(edge, 'uuid', None) or ""
|
||||
for e in edges:
|
||||
edge_info = EdgeInfo(
|
||||
uuid=str(edge_uuid) if edge_uuid else "",
|
||||
name=edge.name or "",
|
||||
fact=edge.fact or "",
|
||||
source_node_uuid=edge.source_node_uuid or "",
|
||||
target_node_uuid=edge.target_node_uuid or ""
|
||||
uuid=e.get("uuid", ""),
|
||||
name=e.get("name", ""),
|
||||
fact=e.get("fact", ""),
|
||||
source_node_uuid=e.get("source_node_uuid", ""),
|
||||
target_node_uuid=e.get("target_node_uuid", ""),
|
||||
)
|
||||
|
||||
# 添加时间信息
|
||||
if include_temporal:
|
||||
edge_info.created_at = getattr(edge, 'created_at', None)
|
||||
edge_info.valid_at = getattr(edge, 'valid_at', None)
|
||||
edge_info.invalid_at = getattr(edge, 'invalid_at', None)
|
||||
edge_info.expired_at = getattr(edge, 'expired_at', None)
|
||||
|
||||
edge_info.created_at = e.get("created_at")
|
||||
edge_info.valid_at = e.get("valid_at")
|
||||
edge_info.invalid_at = e.get("invalid_at")
|
||||
edge_info.expired_at = e.get("expired_at")
|
||||
result.append(edge_info)
|
||||
|
||||
logger.info(t("console.fetchedEdges", count=len(result)))
|
||||
return result
|
||||
|
||||
def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]:
|
||||
def get_node_detail(self, node_uuid: str, graph_id: str = "") -> Optional[NodeInfo]:
|
||||
"""
|
||||
获取单个节点的详细信息
|
||||
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
|
||||
graph_id: 图谱ID(从本地存储检索时需要)
|
||||
|
||||
Returns:
|
||||
节点信息或None
|
||||
"""
|
||||
logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))
|
||||
|
||||
|
||||
try:
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=node_uuid),
|
||||
operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8])
|
||||
)
|
||||
|
||||
if not node:
|
||||
return None
|
||||
|
||||
return NodeInfo(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
name=node.name or "",
|
||||
labels=node.labels or [],
|
||||
summary=node.summary or "",
|
||||
attributes=node.attributes or {}
|
||||
)
|
||||
# 若提供了graph_id,直接从该图谱查找
|
||||
if graph_id:
|
||||
n = self.store.get_node(graph_id, node_uuid)
|
||||
if n:
|
||||
return NodeInfo(
|
||||
uuid=n.get("uuid", ""),
|
||||
name=n.get("name", ""),
|
||||
labels=n.get("labels") or [],
|
||||
summary=n.get("summary", ""),
|
||||
attributes=n.get("attributes") or {},
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(t("console.fetchNodeDetailFailed", error=str(e)))
|
||||
return None
|
||||
|
|
@ -1043,7 +854,7 @@ class ZepToolsService:
|
|||
continue
|
||||
try:
|
||||
# 单独获取每个相关节点的信息
|
||||
node = self.get_node_detail(uuid)
|
||||
node = self.get_node_detail(uuid, graph_id=graph_id)
|
||||
if node:
|
||||
node_map[uuid] = node
|
||||
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")
|
||||
|
|
|
|||
|
|
@ -0,0 +1,290 @@
|
|||
"""
|
||||
本地JSON文件图谱存储
|
||||
替代Zep Cloud,将图谱数据(节点、边、情节)存储在本地JSON文件中
|
||||
|
||||
存储目录结构:
|
||||
{storage_dir}/
|
||||
{graph_id}/
|
||||
metadata.json - 图谱元数据和本体定义
|
||||
nodes.json - 节点列表
|
||||
edges.json - 边列表
|
||||
episodes.jsonl - 情节文本日志(追加写入)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
import threading
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from .logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.local_graph_store')
|
||||
|
||||
# 每个图谱一把锁,保证并发写入安全
|
||||
_global_lock = threading.Lock()
|
||||
_graph_locks: Dict[str, threading.Lock] = {}
|
||||
|
||||
|
||||
def _lock_for(graph_id: str) -> threading.Lock:
    """Return the lock dedicated to ``graph_id``, creating it on first use.

    The lock registry itself is guarded by ``_global_lock`` so that two
    threads asking for the same graph always receive the same lock object.
    """
    with _global_lock:
        try:
            return _graph_locks[graph_id]
        except KeyError:
            new_lock = threading.Lock()
            _graph_locks[graph_id] = new_lock
            return new_lock
|
||||
|
||||
|
||||
class LocalGraphStore:
    """Graph store backed by local JSON files.

    Every graph lives in ``{storage_dir}/{graph_id}/``:

    * ``metadata.json``  - graph metadata plus the ontology definition
    * ``nodes.json``     - list of node dicts
    * ``edges.json``     - list of edge dicts
    * ``episodes.jsonl`` - append-only log of episode texts

    ``graph_id`` is used as a directory name, so every method that builds a
    path rejects ids that are empty, ``.``/``..`` or contain a path separator
    by raising ``ValueError`` (``graph_exists`` returns ``False`` instead).
    """

    def __init__(self, storage_dir: str):
        """Remember ``storage_dir`` and create it if it does not exist yet."""
        self.storage_dir = storage_dir
        os.makedirs(storage_dir, exist_ok=True)

    # ── Graph lifecycle ──────────────────────────────────────────────────────

    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
        """Create the graph directory and its initial files.

        ``metadata.json`` is always (re)written with a ``None`` ontology;
        existing ``nodes.json`` / ``edges.json`` files are left untouched.
        """
        graph_dir = self._graph_dir(graph_id)
        os.makedirs(graph_dir, exist_ok=True)
        self._write_json(self._meta_path(graph_id), {
            "graph_id": graph_id,
            "name": name,
            "description": description,
            "created_at": datetime.now().isoformat(),
            "ontology": None,
        })
        for list_path in (self._nodes_path(graph_id), self._edges_path(graph_id)):
            if not os.path.exists(list_path):
                self._write_json(list_path, [])
        logger.info(f"本地图谱已创建: {graph_id}")

    def delete_graph(self, graph_id: str) -> None:
        """Remove the graph directory and everything in it, if it exists.

        Raises:
            ValueError: If ``graph_id`` is not a plain directory name. This
                guards the recursive delete: an empty id would otherwise
                resolve to the storage directory itself.
        """
        graph_dir = self._graph_dir(graph_id)
        if os.path.exists(graph_dir):
            shutil.rmtree(graph_dir)
            logger.info(f"本地图谱已删除: {graph_id}")

    def graph_exists(self, graph_id: str) -> bool:
        """Return True when the graph's ``metadata.json`` exists."""
        try:
            meta_path = self._meta_path(graph_id)
        except ValueError:
            # An invalid id can never name an existing graph.
            return False
        return os.path.exists(meta_path)

    # ── Ontology ─────────────────────────────────────────────────────────────

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
        """Store ``ontology`` under the ``"ontology"`` key of the metadata.

        The read-modify-write runs under the per-graph lock, like the other
        writers in this class, so concurrent updates cannot overwrite each
        other's metadata.
        """
        meta_path = self._meta_path(graph_id)
        with _lock_for(graph_id):
            meta = self._read_json(meta_path) or {}
            meta["ontology"] = ontology
            self._write_json(meta_path, meta)

    def get_ontology(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored ontology, or None when none has been set."""
        meta = self._read_json(self._meta_path(graph_id)) or {}
        return meta.get("ontology")

    def get_metadata(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """Return the parsed ``metadata.json``, or None if the file is missing."""
        return self._read_json(self._meta_path(graph_id))

    # ── Episodes ─────────────────────────────────────────────────────────────

    def add_episode(self, graph_id: str, text: str) -> str:
        """Append one episode text to ``episodes.jsonl`` and return its uuid.

        The record is written with ``"processed": True``.
        """
        episode_id = uuid.uuid4().hex
        record = {
            "uuid": episode_id,
            "text": text,
            "created_at": datetime.now().isoformat(),
            "processed": True,
        }
        ep_path = self._episodes_path(graph_id)
        with _lock_for(graph_id):
            with open(ep_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')
        return episode_id

    def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[str]:
        """Append every text as its own episode; return the uuids in order."""
        return [self.add_episode(graph_id, text) for text in texts]

    def episode_is_processed(self, graph_id: str, episode_uuid: str) -> bool:
        """Always return True; this method does not inspect the stored episodes."""
        return True

    # ── Nodes ────────────────────────────────────────────────────────────────

    def get_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return all node dicts of the graph (empty list if none stored)."""
        return self._read_json(self._nodes_path(graph_id)) or []

    def get_node(self, graph_id: str, node_uuid: str) -> Optional[Dict[str, Any]]:
        """Return the node whose ``uuid`` equals ``node_uuid``, or None."""
        for node in self.get_nodes(graph_id):
            if node.get("uuid") == node_uuid:
                return node
        return None

    def upsert_node(
        self,
        graph_id: str,
        name: str,
        labels: Optional[List[str]] = None,
        summary: str = "",
        attributes: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Find a node by name (case-insensitive); update it or create it.

        On a match, new labels are appended after the existing ones (order
        preserved, duplicates dropped), ``summary`` is filled in only if the
        node has none, and ``attributes`` are merged over the existing ones.
        ``nodes.json`` is rewritten only when something actually changed.

        Returns:
            The uuid of the matched or newly created node.
        """
        labels = list(labels) if labels else ["Entity"]
        attributes = dict(attributes) if attributes else {}
        wanted = name.lower()
        nodes_path = self._nodes_path(graph_id)

        with _lock_for(graph_id):
            nodes = self._read_json(nodes_path) or []
            for node in nodes:
                # `or ""` tolerates nodes stored with a null name.
                if (node.get("name") or "").lower() != wanted:
                    continue

                changed = False

                # Ordered merge: a set would shuffle the labels between runs,
                # and consumers pick "the first non-generic label".
                current_labels = node.get("labels") or []
                merged_labels = list(dict.fromkeys([*current_labels, *labels]))
                if merged_labels != current_labels:
                    node["labels"] = merged_labels
                    changed = True

                if summary and not node.get("summary"):
                    node["summary"] = summary
                    changed = True

                if attributes:
                    current_attrs = node.get("attributes") or {}
                    merged_attrs = {**current_attrs, **attributes}
                    if merged_attrs != current_attrs:
                        node["attributes"] = merged_attrs
                        changed = True

                if changed:
                    self._write_json(nodes_path, nodes)
                return node["uuid"]

            # No node with that name yet: create one.
            node_uuid = uuid.uuid4().hex
            nodes.append({
                "uuid": node_uuid,
                "name": name,
                "labels": labels,
                "summary": summary,
                "attributes": attributes,
                "created_at": datetime.now().isoformat(),
            })
            self._write_json(nodes_path, nodes)
            return node_uuid

    # ── Edges ────────────────────────────────────────────────────────────────

    def get_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return all edge dicts of the graph (empty list if none stored)."""
        return self._read_json(self._edges_path(graph_id)) or []

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
        """Return every edge that has ``node_uuid`` as its source or target."""
        return [
            e for e in self.get_edges(graph_id)
            if e.get("source_node_uuid") == node_uuid or e.get("target_node_uuid") == node_uuid
        ]

    def add_edge(self, graph_id: str, edge: Dict[str, Any]) -> str:
        """Append an edge and return its uuid.

        The caller's dict is copied, not mutated. A uuid is generated when
        the edge has none, and ``created_at``, ``valid_at``, ``invalid_at``,
        ``expired_at`` and ``attributes`` receive defaults when absent.
        """
        edge_uuid = edge.get("uuid") or uuid.uuid4().hex
        edge = dict(edge)
        edge["uuid"] = edge_uuid
        edge.setdefault("created_at", datetime.now().isoformat())
        edge.setdefault("valid_at", None)
        edge.setdefault("invalid_at", None)
        edge.setdefault("expired_at", None)
        edge.setdefault("attributes", {})

        edges_path = self._edges_path(graph_id)
        with _lock_for(graph_id):
            edges = self._read_json(edges_path) or []
            edges.append(edge)
            self._write_json(edges_path, edges)
        return edge_uuid

    def add_fact_edge(
        self,
        graph_id: str,
        source_uuid: str,
        target_uuid: str,
        name: str,
        fact: str,
    ) -> str:
        """Convenience wrapper: add a named fact edge between two nodes."""
        return self.add_edge(graph_id, {
            "name": name,
            "fact": fact,
            "source_node_uuid": source_uuid,
            "target_node_uuid": target_uuid,
        })

    # ── Search ───────────────────────────────────────────────────────────────

    def search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges",
    ) -> Dict[str, Any]:
        """Keyword search over edges and/or nodes.

        A text scores 100 when it contains the whole lower-cased query,
        otherwise 10 per query keyword it contains. Keywords are the query
        split on whitespace and commas (ASCII or full-width), keeping tokens
        longer than one character. Edges are scored on ``fact`` + ``name``,
        nodes on ``name`` + ``summary``; items scoring 0 are dropped and the
        best ``limit`` items per kind are returned. Note that an empty query
        is a substring of every text, so it matches every non-empty text.

        Args:
            graph_id: Graph to search.
            query: Search text.
            limit: Maximum number of edges and of nodes to return.
            scope: ``"edges"``, ``"nodes"`` or ``"both"``.

        Returns:
            ``{"facts": [...], "edges": [...], "nodes": [...]}`` where facts
            holds the matched edge facts followed by ``"[name]: summary"``
            lines for matched nodes that have a summary.
        """
        query_lower = query.lower()
        # '\uff0c' is the full-width comma used in Chinese text.
        normalized = query_lower.replace(',', ' ').replace('\uff0c', ' ')
        keywords = [w.strip() for w in normalized.split() if len(w.strip()) > 1]

        def score(text: str) -> int:
            if not text:
                return 0
            tl = text.lower()
            if query_lower in tl:
                return 100
            return sum(10 for kw in keywords if kw in tl)

        def top_matches(items: List[Dict[str, Any]], fields) -> List[Dict[str, Any]]:
            # Score each item once, keep the positive ones, best first.
            # list.sort is stable, so ties keep their stored order.
            scored = []
            for item in items:
                total = sum(score(item.get(field, "")) for field in fields)
                if total > 0:
                    scored.append((total, item))
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [item for _, item in scored[:limit]]

        result_edges: List[Dict] = []
        result_nodes: List[Dict] = []
        facts: List[str] = []

        if scope in ("edges", "both"):
            result_edges = top_matches(self.get_edges(graph_id), ("fact", "name"))
            facts.extend(edge["fact"] for edge in result_edges if edge.get("fact"))

        if scope in ("nodes", "both"):
            result_nodes = top_matches(self.get_nodes(graph_id), ("name", "summary"))
            for node in result_nodes:
                if node.get("summary"):
                    # .get: a node matched on its summary may lack a name.
                    facts.append(f"[{node.get('name', '')}]: {node['summary']}")

        return {"facts": facts, "edges": result_edges, "nodes": result_nodes}

    # ── Internal path helpers ────────────────────────────────────────────────

    def _graph_dir(self, graph_id: str) -> str:
        """Return the directory of ``graph_id`` inside ``storage_dir``.

        Raises:
            ValueError: If ``graph_id`` is not a string, is empty, is ``.``
                or ``..``, or contains ``/``, ``\\`` or a NUL character.
                Such ids would resolve to the storage directory itself or
                to a location outside it.
        """
        if (
            not isinstance(graph_id, str)
            or graph_id in ("", ".", "..")
            or any(ch in graph_id for ch in ("/", "\\", "\x00"))
        ):
            raise ValueError(f"Invalid graph_id: {graph_id!r}")
        return os.path.join(self.storage_dir, graph_id)

    def _meta_path(self, graph_id: str) -> str:
        """Path of the graph's ``metadata.json``."""
        return os.path.join(self._graph_dir(graph_id), "metadata.json")

    def _nodes_path(self, graph_id: str) -> str:
        """Path of the graph's ``nodes.json``."""
        return os.path.join(self._graph_dir(graph_id), "nodes.json")

    def _edges_path(self, graph_id: str) -> str:
        """Path of the graph's ``edges.json``."""
        return os.path.join(self._graph_dir(graph_id), "edges.json")

    def _episodes_path(self, graph_id: str) -> str:
        """Path of the graph's ``episodes.jsonl``."""
        return os.path.join(self._graph_dir(graph_id), "episodes.jsonl")

    def _read_json(self, path: str) -> Any:
        """Parse the JSON file at ``path``; return None if it does not exist."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            # Opening directly avoids the gap between an exists() check and
            # the open() in which the file could be deleted.
            return None

    def _write_json(self, path: str, data: Any) -> None:
        """Serialize ``data`` to ``path`` without exposing a partial file.

        The JSON is written to a uniquely named temporary file in the same
        directory and then moved over ``path`` with ``os.replace``. Readers
        that do not take the graph lock therefore see either the old or the
        new content, and a failed dump leaves the previous file intact.
        """
        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        finally:
            # Only present if the dump or the replace failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
|
||||
|
|
@ -1,143 +1,25 @@
|
|||
"""Zep Graph 分页读取工具。
|
||||
"""
|
||||
图谱分页读取工具(存根模块)
|
||||
|
||||
Zep 的 node/edge 列表接口使用 UUID cursor 分页,
|
||||
本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。
|
||||
原来封装 Zep Cloud 的分页逻辑。
|
||||
现在图谱数据存储在本地 JSON 文件中,不再需要分页。
|
||||
本模块保留以避免破坏未更新的旧导入。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from zep_cloud import InternalServerError
|
||||
from zep_cloud.client import Zep
|
||||
|
||||
from .logger import get_logger
|
||||
from ..utils.logger import get_logger
|
||||
|
||||
logger = get_logger('mirofish.zep_paging')
|
||||
|
||||
_DEFAULT_PAGE_SIZE = 100
|
||||
_MAX_NODES = 2000
|
||||
_DEFAULT_MAX_RETRIES = 3
|
||||
_DEFAULT_RETRY_DELAY = 2.0 # seconds, doubles each retry
|
||||
|
||||
def fetch_all_nodes(client, graph_id: str, **kwargs) -> list:
|
||||
"""已废弃:请直接使用 LocalGraphStore.get_nodes()"""
|
||||
logger.warning("fetch_all_nodes 已废弃,请使用 LocalGraphStore.get_nodes()")
|
||||
return []
|
||||
|
||||
|
||||
def _fetch_page_with_retry(
|
||||
api_call: Callable[..., list[Any]],
|
||||
*args: Any,
|
||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
||||
page_description: str = "page",
|
||||
**kwargs: Any,
|
||||
) -> list[Any]:
|
||||
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。"""
|
||||
if max_retries < 1:
|
||||
raise ValueError("max_retries must be >= 1")
|
||||
|
||||
last_exception: Exception | None = None
|
||||
delay = retry_delay
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return api_call(*args, **kwargs)
|
||||
except (ConnectionError, TimeoutError, OSError, InternalServerError) as e:
|
||||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay *= 2
|
||||
else:
|
||||
logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}")
|
||||
|
||||
assert last_exception is not None
|
||||
raise last_exception
|
||||
|
||||
|
||||
def fetch_all_nodes(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_items: int = _MAX_NODES,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Collect a graph's nodes page by page, returning at most ``max_items``.

    Each page is requested through ``_fetch_page_with_retry``. Paging ends when
    a page comes back empty, when the cap is reached, when a page holds fewer
    than ``page_size`` items, or when no cursor can be read from the last item
    of a page.

    Args:
        client: Zep client whose ``graph.node.get_by_graph_id`` is called.
        graph_id: Identifier of the graph to read.
        page_size: Value sent as ``limit`` on every request.
        max_items: Upper bound on the number of nodes returned.
        max_retries: Attempts allowed per page request.
        retry_delay: Initial backoff, in seconds, per page request.

    Returns:
        The accumulated nodes, truncated to ``max_items`` once the cap is hit.
    """
    collected: list[Any] = []
    next_cursor: str | None = None
    page_index = 1

    while True:
        request_args: dict[str, Any] = {"limit": page_size}
        if next_cursor is not None:
            request_args["uuid_cursor"] = next_cursor

        page = _fetch_page_with_retry(
            client.graph.node.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch nodes page {page_index} (graph={graph_id})",
            **request_args,
        )
        page_index += 1

        if not page:
            return collected

        collected.extend(page)
        if len(collected) >= max_items:
            logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
            return collected[:max_items]
        if len(page) < page_size:
            return collected

        # The next request resumes after the last node of this page.
        tail = page[-1]
        next_cursor = getattr(tail, "uuid_", None) or getattr(tail, "uuid", None)
        if next_cursor is None:
            logger.warning(f"Node missing uuid field, stopping pagination at {len(collected)} nodes")
            return collected
|
||||
|
||||
|
||||
def fetch_all_edges(
    client: Zep,
    graph_id: str,
    page_size: int = _DEFAULT_PAGE_SIZE,
    max_retries: int = _DEFAULT_MAX_RETRIES,
    retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
    """Collect a graph's edges page by page, with no cap on the total.

    Each page is requested through ``_fetch_page_with_retry``. Paging ends when
    a page comes back empty, when a page holds fewer than ``page_size`` items,
    or when no cursor can be read from the last item of a page.

    Args:
        client: Zep client whose ``graph.edge.get_by_graph_id`` is called.
        graph_id: Identifier of the graph to read.
        page_size: Value sent as ``limit`` on every request.
        max_retries: Attempts allowed per page request.
        retry_delay: Initial backoff, in seconds, per page request.

    Returns:
        The list of all edges gathered before paging stopped.
    """
    collected: list[Any] = []
    next_cursor: str | None = None
    page_index = 1

    while True:
        request_args: dict[str, Any] = {"limit": page_size}
        if next_cursor is not None:
            request_args["uuid_cursor"] = next_cursor

        page = _fetch_page_with_retry(
            client.graph.edge.get_by_graph_id,
            graph_id,
            max_retries=max_retries,
            retry_delay=retry_delay,
            page_description=f"fetch edges page {page_index} (graph={graph_id})",
            **request_args,
        )
        page_index += 1

        if not page:
            return collected

        collected.extend(page)
        if len(page) < page_size:
            return collected

        # The next request resumes after the last edge of this page.
        tail = page[-1]
        next_cursor = getattr(tail, "uuid_", None) or getattr(tail, "uuid", None)
        if next_cursor is None:
            logger.warning(f"Edge missing uuid field, stopping pagination at {len(collected)} edges")
            return collected
|
||||
def fetch_all_edges(client, graph_id: str, **kwargs) -> list:
    """Deprecated compatibility shim; use ``LocalGraphStore.get_edges()`` instead.

    Every argument is accepted only so that old call sites keep working; none
    of them is read. Each call logs a deprecation warning and returns a new
    empty list.
    """
    deprecation_notice = "fetch_all_edges 已废弃,请使用 LocalGraphStore.get_edges()"
    logger.warning(deprecation_notice)

    no_edges: list = []
    return no_edges
|
||||
|
|
|
|||
|
|
@ -16,9 +16,6 @@ dependencies = [
|
|||
# LLM 相关
|
||||
"openai>=1.0.0",
|
||||
|
||||
# Zep Cloud
|
||||
"zep-cloud==3.13.0",
|
||||
|
||||
# OASIS 社交媒体模拟
|
||||
"camel-oasis==0.2.5",
|
||||
"camel-ai==0.2.78",
|
||||
|
|
|
|||
|
|
@ -13,9 +13,6 @@ flask-cors>=6.0.0
|
|||
# OpenAI SDK(统一使用 OpenAI 格式调用 LLM)
|
||||
openai>=1.0.0
|
||||
|
||||
# ============= Zep Cloud =============
|
||||
zep-cloud==3.13.0
|
||||
|
||||
# ============= OASIS 社交媒体模拟 =============
|
||||
# OASIS 社交模拟框架
|
||||
camel-oasis==0.2.5
|
||||
|
|
|
|||
Loading…
Reference in New Issue