refactor: replace Zep Cloud memory with local JSON file storage

Removes the zep-cloud dependency entirely and replaces it with a
local file-based graph store (LocalGraphStore) that persists nodes,
edges, and episodes as JSON files under uploads/graphs/{graph_id}/.

- Add backend/app/utils/local_graph_store.py: thread-safe JSON store
  with keyword search, node upsert, and episode append
- Rewrite graph_builder.py: LLM-based entity/relationship extraction
  from text batches, stored locally instead of sent to Zep Cloud
- Rewrite zep_graph_memory_updater.py: agent activities written as
  episodes + searchable fact edges in local JSON
- Rewrite zep_entity_reader.py: reads nodes/edges from local JSON
- Rewrite zep_tools.py: keyword search on local JSON replaces
  Zep semantic search; _local_search is now the primary path
- Update oasis_profile_generator.py: local store replaces Zep client
  for entity context enrichment
- Update ontology_generator.py: generated code template uses
  pydantic BaseModel instead of Zep EntityModel/EdgeModel
- Convert zep_paging.py to a no-op stub (pagination not needed)
- Remove ZEP_API_KEY from config.py, add GRAPH_STORAGE_DIR
- Remove ZEP_API_KEY guards from api/graph.py and api/simulation.py
- Remove zep-cloud==3.13.0 from requirements.txt and pyproject.toml

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
PMA 2026-05-21 00:25:21 -03:00
parent fa0f6519b1
commit 2bec63be1b
13 changed files with 936 additions and 1427 deletions

View File

@ -283,17 +283,6 @@ def build_graph():
try:
logger.info("=== 开始构建图谱 ===")
# 检查配置
errors = []
if not Config.ZEP_API_KEY:
errors.append(t('api.zepApiKeyMissing'))
if errors:
logger.error(f"配置错误: {errors}")
return jsonify({
"success": False,
"error": t('api.configError', details="; ".join(errors))
}), 500
# 解析请求
data = request.get_json() or {}
project_id = data.get('project_id')
@ -387,8 +376,8 @@ def build_graph():
)
# 创建图谱构建服务
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()
# 分块
task_manager.update_task(
task_id,
@ -572,20 +561,14 @@ def get_graph_data(graph_id: str):
获取图谱数据节点和边
"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()
graph_data = builder.get_graph_data(graph_id)
return jsonify({
"success": True,
"data": graph_data
})
except Exception as e:
return jsonify({
"success": False,
@ -597,16 +580,10 @@ def get_graph_data(graph_id: str):
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
def delete_graph(graph_id: str):
"""
删除Zep图谱
删除本地图谱
"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
builder = GraphBuilderService()
builder.delete_graph(graph_id)
return jsonify({

View File

@ -57,18 +57,12 @@ def get_graph_entities(graph_id: str):
enrich: 是否获取相关边信息默认true
"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500
entity_types_str = request.args.get('entity_types', '')
entity_types = [t.strip() for t in entity_types_str.split(',') if t.strip()] if entity_types_str else None
enrich = request.args.get('enrich', 'true').lower() == 'true'
logger.info(f"获取图谱实体: graph_id={graph_id}, entity_types={entity_types}, enrich={enrich}")
reader = ZepEntityReader()
result = reader.filter_defined_entities(
graph_id=graph_id,
@ -94,12 +88,6 @@ def get_graph_entities(graph_id: str):
def get_entity_detail(graph_id: str, entity_uuid: str):
"""获取单个实体的详细信息"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500
reader = ZepEntityReader()
entity = reader.get_entity_with_context(graph_id, entity_uuid)
@ -127,14 +115,8 @@ def get_entity_detail(graph_id: str, entity_uuid: str):
def get_entities_by_type(graph_id: str, entity_type: str):
"""获取指定类型的所有实体"""
try:
if not Config.ZEP_API_KEY:
return jsonify({
"success": False,
"error": t('api.zepApiKeyMissing')
}), 500
enrich = request.args.get('enrich', 'true').lower() == 'true'
reader = ZepEntityReader()
entities = reader.get_entities_by_type(
graph_id=graph_id,

View File

@ -32,8 +32,8 @@ class Config:
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
# Zep配置
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
# 本地图谱存储目录
GRAPH_STORAGE_DIR = os.path.join(os.path.dirname(__file__), '../uploads/graphs')
# 文件上传配置
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
@ -69,7 +69,7 @@ class Config:
errors = []
if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置")
if not cls.ZEP_API_KEY:
errors.append("ZEP_API_KEY 未配置")
# 确保图谱存储目录存在
os.makedirs(cls.GRAPH_STORAGE_DIR, exist_ok=True)
return errors

View File

@ -1,6 +1,6 @@
"""
图谱构建服务
接口2使用Zep API构建Standalone Graph
使用本地JSON文件存储替代Zep Cloud
"""
import os
@ -10,14 +10,15 @@ import threading
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from zep_cloud.client import Zep
from zep_cloud import EpisodeData, EntityEdgeSourceTarget
from ..config import Config
from ..models.task import TaskManager, TaskStatus
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
from ..utils.local_graph_store import LocalGraphStore
from ..utils.llm_client import LLMClient
from .text_processor import TextProcessor
from ..utils.locale import t, get_locale, set_locale
from ..utils.logger import get_logger
logger = get_logger('mirofish.graph_builder')
@dataclass
@ -27,7 +28,7 @@ class GraphInfo:
node_count: int
edge_count: int
entity_types: List[str]
def to_dict(self) -> Dict[str, Any]:
return {
"graph_id": self.graph_id,
@ -40,17 +41,23 @@ class GraphInfo:
class GraphBuilderService:
"""
图谱构建服务
负责调用Zep API构建知识图谱
使用本地JSON文件存储构建知识图谱
"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.client = Zep(api_key=self.api_key)
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
    """Set up the graph builder on top of a local JSON graph store.

    Args:
        storage_dir: Directory handed to LocalGraphStore. Falls back to
            Config.GRAPH_STORAGE_DIR when omitted or empty.
        api_key: Ignored. Accepted only so that legacy callers passing
            ``api_key=...`` keep working.
    """
    # The api_key parameter is retained for compatibility with old call sites
    # but is no longer used.
    # NOTE(review): the pre-refactor signature was __init__(self, api_key=None).
    # A legacy *positional* call such as GraphBuilderService(some_key) now binds
    # the key to storage_dir and would be used as a directory path — confirm
    # every caller passes these arguments by keyword.
    self.storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
    self.store = LocalGraphStore(self.storage_dir)
    self.task_manager = TaskManager()
    # The LLM client is not created here; the `llm` property builds it on first use.
    self._llm: Optional[LLMClient] = None
@property
def llm(self) -> LLMClient:
    """Return the LLM client, constructing it on first access (lazy initialization)."""
    # self._llm starts as None in __init__; the client is only built once it is
    # actually needed, and the same instance is returned on later accesses.
    if self._llm is None:
        self._llm = LLMClient()
    return self._llm
def build_graph_async(
self,
text: str,
@ -62,19 +69,10 @@ class GraphBuilderService:
) -> str:
"""
异步构建图谱
Args:
text: 输入文本
ontology: 本体定义来自接口1的输出
graph_name: 图谱名称
chunk_size: 文本块大小
chunk_overlap: 块重叠大小
batch_size: 每批发送的块数量
Returns:
任务ID
"""
# 创建任务
task_id = self.task_manager.create_task(
task_type="graph_build",
metadata={
@ -83,20 +81,18 @@ class GraphBuilderService:
"text_length": len(text),
}
)
# Capture locale before spawning background thread
current_locale = get_locale()
# 在后台线程中执行构建
thread = threading.Thread(
target=self._build_graph_worker,
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
)
thread.daemon = True
thread.start()
return task_id
def _build_graph_worker(
self,
task_id: str,
@ -117,7 +113,7 @@ class GraphBuilderService:
progress=5,
message=t('progress.startBuildingGraph')
)
# 1. 创建图谱
graph_id = self.create_graph(graph_name)
self.task_manager.update_task(
@ -125,15 +121,15 @@ class GraphBuilderService:
progress=10,
message=t('progress.graphCreated', graphId=graph_id)
)
# 2. 设置本体
# 2. 保存本体
self.set_ontology(graph_id, ontology)
self.task_manager.update_task(
task_id,
progress=15,
message=t('progress.ontologySet')
)
# 3. 文本分块
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
total_chunks = len(chunks)
@ -142,155 +138,47 @@ class GraphBuilderService:
progress=20,
message=t('progress.textSplit', count=total_chunks)
)
# 4. 分批发送数据
episode_uuids = self.add_text_batches(
# 4. 分批处理:提取实体并存储
self.add_text_batches(
graph_id, chunks, batch_size,
lambda msg, prog: self.task_manager.update_task(
task_id,
progress=20 + int(prog * 0.4), # 20-60%
progress=20 + int(prog * 0.7), # 20-90%
message=msg
)
)
# 5. 等待Zep处理完成
self.task_manager.update_task(
task_id,
progress=60,
message=t('progress.waitingZepProcess')
)
self._wait_for_episodes(
episode_uuids,
lambda msg, prog: self.task_manager.update_task(
task_id,
progress=60 + int(prog * 0.3), # 60-90%
message=msg
)
)
# 6. 获取图谱信息
# 5. 获取图谱信息
self.task_manager.update_task(
task_id,
progress=90,
message=t('progress.fetchingGraphInfo')
)
graph_info = self._get_graph_info(graph_id)
# 完成
self.task_manager.complete_task(task_id, {
"graph_id": graph_id,
"graph_info": graph_info.to_dict(),
"chunks_processed": total_chunks,
})
except Exception as e:
import traceback
error_msg = f"{str(e)}\n{traceback.format_exc()}"
self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str:
"""创建Zep图谱公开方法"""
"""创建本地图谱"""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self.client.graph.create(
graph_id=graph_id,
name=name,
description="MiroFish Social Simulation Graph"
)
self.store.create_graph(graph_id, name, "MiroFish Social Simulation Graph")
return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
"""设置图谱本体(公开方法)"""
import warnings
from typing import Optional
from pydantic import Field
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
# Zep 保留名称,不能作为属性名
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
def safe_attr_name(attr_name: str) -> str:
"""将保留名称转换为安全名称"""
if attr_name.lower() in RESERVED_NAMES:
return f"entity_{attr_name}"
return attr_name
# 动态创建实体类型
entity_types = {}
for entity_def in ontology.get("entity_types", []):
name = entity_def["name"]
description = entity_def.get("description", f"A {name} entity.")
# 创建属性字典和类型注解Pydantic v2 需要)
attrs = {"__doc__": description}
annotations = {}
for attr_def in entity_def.get("attributes", []):
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
attr_desc = attr_def.get("description", attr_name)
# Zep API 需要 Field 的 description这是必需的
attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[EntityText] # 类型注解
attrs["__annotations__"] = annotations
# 动态创建类
entity_class = type(name, (EntityModel,), attrs)
entity_class.__doc__ = description
entity_types[name] = entity_class
# 动态创建边类型
edge_definitions = {}
for edge_def in ontology.get("edge_types", []):
name = edge_def["name"]
description = edge_def.get("description", f"A {name} relationship.")
# 创建属性字典和类型注解
attrs = {"__doc__": description}
annotations = {}
for attr_def in edge_def.get("attributes", []):
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
attr_desc = attr_def.get("description", attr_name)
# Zep API 需要 Field 的 description这是必需的
attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[str] # 边属性用str类型
attrs["__annotations__"] = annotations
# 动态创建类
class_name = ''.join(word.capitalize() for word in name.split('_'))
edge_class = type(class_name, (EdgeModel,), attrs)
edge_class.__doc__ = description
# 构建source_targets
source_targets = []
for st in edge_def.get("source_targets", []):
source_targets.append(
EntityEdgeSourceTarget(
source=st.get("source", "Entity"),
target=st.get("target", "Entity")
)
)
if source_targets:
edge_definitions[name] = (edge_class, source_targets)
# 调用Zep API设置本体
if entity_types or edge_definitions:
self.client.graph.set_ontology(
graph_ids=[graph_id],
entities=entity_types if entity_types else None,
edges=edge_definitions if edge_definitions else None,
)
"""保存本体定义"""
self.store.set_ontology(graph_id, ontology)
def add_text_batches(
self,
graph_id: str,
@ -298,209 +186,206 @@ class GraphBuilderService:
batch_size: int = 3,
progress_callback: Optional[Callable] = None
) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
"""分批处理文本:提取实体/关系并存储返回情节uuid列表"""
episode_uuids = []
ontology = self.store.get_ontology(graph_id) or {}
total_chunks = len(chunks)
for i in range(0, total_chunks, batch_size):
batch_chunks = chunks[i:i + batch_size]
batch = chunks[i:i + batch_size]
batch_num = i // batch_size + 1
total_batches = (total_chunks + batch_size - 1) // batch_size
if progress_callback:
progress = (i + len(batch_chunks)) / total_chunks
progress = (i + len(batch)) / total_chunks
progress_callback(
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch_chunks)),
t('progress.sendingBatch', current=batch_num, total=total_batches, chunks=len(batch)),
progress
)
# 构建episode数据
episodes = [
EpisodeData(data=chunk, type="text")
for chunk in batch_chunks
]
# 发送到Zep
try:
batch_result = self.client.graph.add_batch(
graph_id=graph_id,
episodes=episodes
)
# 收集返回的 episode uuid
if batch_result and isinstance(batch_result, list):
for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid:
episode_uuids.append(ep_uuid)
# 避免请求过快
time.sleep(1)
except Exception as e:
if progress_callback:
progress_callback(t('progress.batchFailed', batch=batch_num, error=str(e)), 0)
raise
# 存储情节文本
for text in batch:
ep_uuid = self.store.add_episode(graph_id, text)
episode_uuids.append(ep_uuid)
# 使用LLM从批次文本中提取实体和关系
if ontology.get("entity_types") or ontology.get("edge_types"):
try:
extracted = self._extract_entities_from_batch(batch, ontology)
self._store_extracted(graph_id, extracted)
except Exception as e:
logger.warning(f"批次 {batch_num} 实体提取失败: {e}")
# 轻微延迟避免LLM请求过快
time.sleep(0.3)
return episode_uuids
def _extract_entities_from_batch(self, texts: List[str], ontology: Dict[str, Any]) -> Dict[str, Any]:
    """Ask the LLM to extract entities and relationships from a batch of text chunks.

    Args:
        texts: Text chunks of one batch; they are joined with blank lines
            into a single prompt body.
        ontology: Ontology definition. Its "entity_types" and "edge_types"
            lists (each item read via ``['name']`` and optional
            ``'description'``) are rendered into the prompt as the only
            allowed types.

    Returns:
        The value returned by ``self.llm.chat_json`` when it is a dict;
        otherwise (non-dict result, or any exception from the LLM call)
        ``{"entities": [], "relationships": []}``.
    """
    combined_text = "\n\n".join(texts)

    # One bullet per ontology entity type; a generic "Entity" bullet is used
    # when the ontology defines none.
    entity_types_desc = "\n".join(
        f"- {et['name']}: {et.get('description', '')}"
        for et in ontology.get("entity_types", [])
    ) or "- Entity (通用实体)"
    # Same for relationship types, defaulting to RELATED_TO.
    edge_types_desc = "\n".join(
        f"- {rt['name']}: {rt.get('description', '')}"
        for rt in ontology.get("edge_types", [])
    ) or "- RELATED_TO"
    # NOTE: only the first 4000 characters of the combined batch text are
    # placed in the prompt; anything beyond that is not seen by the LLM.
    user_prompt = f"""从以下文本中提取实体和关系,仅使用给定的本体类型。
实体类型只能使用这些
{entity_types_desc}
关系类型只能使用这些
{edge_types_desc}
文本
{combined_text[:4000]}
返回JSON格式
{{
"entities": [
{{"name": "实体名称", "type": "实体类型", "summary": "一句话描述", "attributes": {{}}}}
],
"relationships": [
{{"source": "源实体名称", "target": "目标实体名称", "type": "关系类型", "fact": "事实描述"}}
]
}}
规则
- 仅使用本体中定义的实体类型和关系类型
- 实体名称应具体人名地名组织名等
- fact字段应是简洁的事实陈述
- 若找不到匹配项返回空列表"""
    try:
        # presumably chat_json returns the parsed JSON response — verify
        # against LLMClient; a non-dict result is replaced by the empty shape.
        result = self.llm.chat_json(
            messages=[{"role": "user", "content": user_prompt}],
            temperature=0.1
        )
        return result if isinstance(result, dict) else {"entities": [], "relationships": []}
    except Exception as e:
        # Best effort: a failed LLM call is logged and yields an empty
        # extraction instead of propagating to the caller.
        logger.warning(f"实体提取LLM调用失败: {e}")
        return {"entities": [], "relationships": []}
def _store_extracted(self, graph_id: str, extracted: Dict[str, Any]):
    """Persist LLM-extracted entities and relationships into the local graph.

    The payload comes straight from an LLM, so its shape is not trusted:
    malformed items (non-dict entries, non-string names, non-dict
    attributes, non-list collections) are skipped or defaulted
    individually instead of raising and discarding the rest of the batch.

    Args:
        graph_id: Identifier of the graph to write into.
        extracted: Mapping with optional "entities" and "relationships"
            lists. Entities are dicts with "name", "type", "summary",
            "attributes"; relationships are dicts with "source", "target",
            "type", "fact".

    Side effects:
        Calls ``self.store.upsert_node`` for each valid entity (and for any
        relationship endpoint not seen among the entities of this payload)
        and ``self.store.add_fact_edge`` for each relationship that has a
        source, a target and a non-empty fact.
    """
    def _as_str(value: Any) -> str:
        # Anything that is not a string (None, numbers, lists...) counts as missing.
        return value if isinstance(value, str) else ""

    def _as_list(value: Any) -> list:
        return value if isinstance(value, list) else []

    if not isinstance(extracted, dict):
        return

    entities = _as_list(extracted.get("entities"))
    relationships = _as_list(extracted.get("relationships"))

    # Lower-cased entity name -> node uuid, so relationship endpoints can be
    # matched case-insensitively against entities stored from this payload.
    name_to_uuid: Dict[str, str] = {}

    def _resolve(node_name: str) -> str:
        # Reuse the uuid of an already-stored entity; otherwise upsert a
        # generic "Entity" node for the endpoint and remember it.
        key = node_name.lower()
        node_uuid = name_to_uuid.get(key)
        if not node_uuid:
            node_uuid = self.store.upsert_node(graph_id, node_name, ["Entity"])
            name_to_uuid[key] = node_uuid
        return node_uuid

    for entity in entities:
        if not isinstance(entity, dict):
            continue
        name = _as_str(entity.get("name")).strip()
        if not name:
            continue
        entity_type = _as_str(entity.get("type")) or "Entity"
        attributes = entity.get("attributes")
        if not isinstance(attributes, dict):
            attributes = {}
        # Typed entities carry both their specific label and the generic one.
        labels = [entity_type, "Entity"] if entity_type != "Entity" else ["Entity"]
        name_to_uuid[name.lower()] = self.store.upsert_node(
            graph_id=graph_id,
            name=name,
            labels=labels,
            summary=_as_str(entity.get("summary")),
            attributes=attributes,
        )

    for rel in relationships:
        if not isinstance(rel, dict):
            continue
        source_name = _as_str(rel.get("source")).strip()
        target_name = _as_str(rel.get("target")).strip()
        fact = _as_str(rel.get("fact"))
        # An edge without both endpoints or without a fact is skipped.
        if not source_name or not target_name or not fact:
            continue
        source_uuid = _resolve(source_name)
        target_uuid = _resolve(target_name)
        self.store.add_fact_edge(
            graph_id=graph_id,
            source_uuid=source_uuid,
            target_uuid=target_uuid,
            name=_as_str(rel.get("type")) or "RELATED_TO",
            fact=fact,
        )
def _wait_for_episodes(
self,
episode_uuids: List[str],
progress_callback: Optional[Callable] = None,
timeout: int = 600
):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
if not episode_uuids:
if progress_callback:
progress_callback(t('progress.noEpisodesWait'), 1.0)
return
start_time = time.time()
pending_episodes = set(episode_uuids)
completed_count = 0
total_episodes = len(episode_uuids)
"""本地存储中情节立即处理完成,无需等待"""
if progress_callback:
progress_callback(t('progress.waitingEpisodes', count=total_episodes), 0)
while pending_episodes:
if time.time() - start_time > timeout:
if progress_callback:
progress_callback(
t('progress.episodesTimeout', completed=completed_count, total=total_episodes),
completed_count / total_episodes
)
break
# 检查每个 episode 的处理状态
for ep_uuid in list(pending_episodes):
try:
episode = self.client.graph.episode.get(uuid_=ep_uuid)
is_processed = getattr(episode, 'processed', False)
if is_processed:
pending_episodes.remove(ep_uuid)
completed_count += 1
except Exception as e:
# 忽略单个查询错误,继续
pass
elapsed = int(time.time() - start_time)
if progress_callback:
progress_callback(
t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
completed_count / total_episodes if total_episodes > 0 else 0
)
if pending_episodes:
time.sleep(3) # 每3秒检查一次
if progress_callback:
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
progress_callback(t('progress.processingComplete',
completed=len(episode_uuids),
total=len(episode_uuids)), 1.0)
def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息"""
# 获取节点(分页)
nodes = fetch_all_nodes(self.client, graph_id)
"""获取图谱统计信息"""
nodes = self.store.get_nodes(graph_id)
edges = self.store.get_edges(graph_id)
# 获取边(分页)
edges = fetch_all_edges(self.client, graph_id)
# 统计实体类型
entity_types = set()
for node in nodes:
if node.labels:
for label in node.labels:
if label not in ["Entity", "Node"]:
entity_types.add(label)
for label in (node.get("labels") or []):
if label not in ("Entity", "Node"):
entity_types.add(label)
return GraphInfo(
graph_id=graph_id,
node_count=len(nodes),
edge_count=len(edges),
entity_types=list(entity_types)
entity_types=list(entity_types),
)
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
"""
获取完整图谱数据包含详细信息
Args:
graph_id: 图谱ID
Returns:
包含nodes和edges的字典包括时间信息属性等详细数据
"""
nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id)
# 创建节点映射用于获取节点名称
node_map = {}
for node in nodes:
node_map[node.uuid_] = node.name or ""
nodes_data = []
for node in nodes:
# 获取创建时间
created_at = getattr(node, 'created_at', None)
if created_at:
created_at = str(created_at)
nodes_data.append({
"uuid": node.uuid_,
"name": node.name,
"labels": node.labels or [],
"summary": node.summary or "",
"attributes": node.attributes or {},
"created_at": created_at,
})
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
"""获取完整图谱数据(含节点和边详情)"""
nodes = self.store.get_nodes(graph_id)
edges = self.store.get_edges(graph_id)
node_map = {n["uuid"]: n.get("name", "") for n in nodes}
edges_data = []
for edge in edges:
# 获取时间信息
created_at = getattr(edge, 'created_at', None)
valid_at = getattr(edge, 'valid_at', None)
invalid_at = getattr(edge, 'invalid_at', None)
expired_at = getattr(edge, 'expired_at', None)
# 获取 episodes
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
if episodes and not isinstance(episodes, list):
episodes = [str(episodes)]
elif episodes:
episodes = [str(e) for e in episodes]
# 获取 fact_type
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
edges_data.append({
"uuid": edge.uuid_,
"name": edge.name or "",
"fact": edge.fact or "",
"fact_type": fact_type,
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
"source_node_name": node_map.get(edge.source_node_uuid, ""),
"target_node_name": node_map.get(edge.target_node_uuid, ""),
"attributes": edge.attributes or {},
"created_at": str(created_at) if created_at else None,
"valid_at": str(valid_at) if valid_at else None,
"invalid_at": str(invalid_at) if invalid_at else None,
"expired_at": str(expired_at) if expired_at else None,
"episodes": episodes or [],
"uuid": edge.get("uuid", ""),
"name": edge.get("name", ""),
"fact": edge.get("fact", ""),
"fact_type": edge.get("name", ""),
"source_node_uuid": edge.get("source_node_uuid", ""),
"target_node_uuid": edge.get("target_node_uuid", ""),
"source_node_name": node_map.get(edge.get("source_node_uuid", ""), ""),
"target_node_name": node_map.get(edge.get("target_node_uuid", ""), ""),
"attributes": edge.get("attributes", {}),
"created_at": edge.get("created_at"),
"valid_at": edge.get("valid_at"),
"invalid_at": edge.get("invalid_at"),
"expired_at": edge.get("expired_at"),
"episodes": [],
})
return {
"graph_id": graph_id,
"nodes": nodes_data,
"nodes": nodes,
"edges": edges_data,
"node_count": len(nodes_data),
"edge_count": len(edges_data),
"node_count": len(nodes),
"edge_count": len(edges),
}
def delete_graph(self, graph_id: str):
"""删除图谱"""
self.client.graph.delete(graph_id=graph_id)
self.store.delete_graph(graph_id)

View File

@ -16,9 +16,9 @@ from dataclasses import dataclass, field
from datetime import datetime
from openai import OpenAI
from zep_cloud.client import Zep
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
from ..utils.locale import get_language_instruction, get_locale, set_locale, t
from .zep_entity_reader import EntityNode, ZepEntityReader
@ -179,35 +179,30 @@ class OasisProfileGenerator:
]
def __init__(
self,
self,
api_key: Optional[str] = None,
base_url: Optional[str] = None,
model_name: Optional[str] = None,
zep_api_key: Optional[str] = None,
graph_id: Optional[str] = None
zep_api_key: Optional[str] = None, # 已废弃,保留以兼容旧调用
graph_id: Optional[str] = None,
storage_dir: Optional[str] = None,
):
self.api_key = api_key or Config.LLM_API_KEY
self.base_url = base_url or Config.LLM_BASE_URL
self.model_name = model_name or Config.LLM_MODEL_NAME
if not self.api_key:
raise ValueError("LLM_API_KEY 未配置")
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url
)
# Zep客户端用于检索丰富上下文
self.zep_api_key = zep_api_key or Config.ZEP_API_KEY
self.zep_client = None
# 本地图谱存储
storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
self.store = LocalGraphStore(storage_dir)
self.graph_id = graph_id
if self.zep_api_key:
try:
self.zep_client = Zep(api_key=self.zep_api_key)
except Exception as e:
logger.warning(f"Zep客户端初始化失败: {e}")
def generate_profile_from_entity(
self,
@ -285,130 +280,53 @@ class OasisProfileGenerator:
def _search_zep_for_entity(self, entity: EntityNode) -> Dict[str, Any]:
"""
使用Zep图谱混合搜索功能获取实体相关的丰富信息
Zep没有内置混合搜索接口需要分别搜索edges和nodes然后合并结果
使用并行请求同时搜索提高效率
使用本地图谱关键词搜索获取实体相关的丰富信息
Args:
entity: 实体节点对象
Returns:
包含facts, node_summaries, context的字典
"""
import concurrent.futures
if not self.zep_client:
return {"facts": [], "node_summaries": [], "context": ""}
entity_name = entity.name
results = {
"facts": [],
"node_summaries": [],
"context": ""
}
# 必须有graph_id才能进行搜索
results: Dict[str, Any] = {"facts": [], "node_summaries": [], "context": ""}
if not self.graph_id:
logger.debug(f"跳过Zep检索未设置graph_id")
logger.debug("跳过本地检索未设置graph_id")
return results
comprehensive_query = t('progress.zepSearchQuery', name=entity_name)
def search_edges():
"""搜索边(事实/关系)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=30,
scope="edges",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep边搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep边搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
def search_nodes():
"""搜索节点(实体摘要)- 带重试机制"""
max_retries = 3
last_exception = None
delay = 2.0
for attempt in range(max_retries):
try:
return self.zep_client.graph.search(
query=comprehensive_query,
graph_id=self.graph_id,
limit=20,
scope="nodes",
reranker="rrf"
)
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.debug(f"Zep节点搜索第 {attempt + 1} 次失败: {str(e)[:80]}, 重试中...")
time.sleep(delay)
delay *= 2
else:
logger.debug(f"Zep节点搜索在 {max_retries} 次尝试后仍失败: {e}")
return None
entity_name = entity.name
query = t('progress.zepSearchQuery', name=entity_name)
try:
# 并行执行edges和nodes搜索
with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
edge_future = executor.submit(search_edges)
node_future = executor.submit(search_nodes)
# 获取结果
edge_result = edge_future.result(timeout=30)
node_result = node_future.result(timeout=30)
# 处理边搜索结果
all_facts = set()
if edge_result and hasattr(edge_result, 'edges') and edge_result.edges:
for edge in edge_result.edges:
if hasattr(edge, 'fact') and edge.fact:
all_facts.add(edge.fact)
results["facts"] = list(all_facts)
# 处理节点搜索结果
all_summaries = set()
if node_result and hasattr(node_result, 'nodes') and node_result.nodes:
for node in node_result.nodes:
if hasattr(node, 'summary') and node.summary:
all_summaries.add(node.summary)
if hasattr(node, 'name') and node.name and node.name != entity_name:
all_summaries.add(f"相关实体: {node.name}")
results["node_summaries"] = list(all_summaries)
# 构建综合上下文
# 搜索边(事实)
edge_raw = self.store.search(self.graph_id, query, limit=30, scope="edges")
facts = list({e.get("fact", "") for e in edge_raw.get("edges", []) if e.get("fact")})
results["facts"] = facts
# 搜索节点(摘要)
node_raw = self.store.search(self.graph_id, query, limit=20, scope="nodes")
summaries = set()
for n in node_raw.get("nodes", []):
if n.get("summary"):
summaries.add(n["summary"])
if n.get("name") and n["name"] != entity_name:
summaries.add(f"相关实体: {n['name']}")
results["node_summaries"] = list(summaries)
# 构建上下文
context_parts = []
if results["facts"]:
context_parts.append("事实信息:\n" + "\n".join(f"- {f}" for f in results["facts"][:20]))
if results["node_summaries"]:
context_parts.append("相关实体:\n" + "\n".join(f"- {s}" for s in results["node_summaries"][:10]))
results["context"] = "\n\n".join(context_parts)
logger.info(f"Zep混合检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, {len(results['node_summaries'])} 个相关节点")
except concurrent.futures.TimeoutError:
logger.warning(f"Zep检索超时 ({entity_name})")
logger.info(f"本地检索完成: {entity_name}, 获取 {len(results['facts'])} 条事实, "
f"{len(results['node_summaries'])} 个相关节点")
except Exception as e:
logger.warning(f"Zep检索失败 ({entity_name}): {e}")
logger.warning(f"本地检索失败 ({entity_name}): {e}")
return results
def _build_entity_context(self, entity: EntityNode) -> str:

View File

@ -413,8 +413,8 @@ class OntologyGenerator:
'由MiroFish自动生成用于社会舆论模拟',
'"""',
'',
'from pydantic import Field',
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
'from typing import Optional',
'from pydantic import BaseModel, Field',
'',
'',
'# ============== 实体类型定义 ==============',
@ -426,15 +426,15 @@ class OntologyGenerator:
name = entity["name"]
desc = entity.get("description", f"A {name} entity.")
code_lines.append(f'class {name}(EntityModel):')
code_lines.append(f'class {name}(BaseModel):')
code_lines.append(f' """{desc}"""')
attrs = entity.get("attributes", [])
if attrs:
for attr in attrs:
attr_name = attr["name"]
attr_desc = attr.get("description", attr_name)
code_lines.append(f' {attr_name}: EntityText = Field(')
code_lines.append(f' {attr_name}: Optional[str] = Field(')
code_lines.append(f' description="{attr_desc}",')
code_lines.append(f' default=None')
code_lines.append(f' )')
@ -454,15 +454,15 @@ class OntologyGenerator:
class_name = ''.join(word.capitalize() for word in name.split('_'))
desc = edge.get("description", f"A {name} relationship.")
code_lines.append(f'class {class_name}(EdgeModel):')
code_lines.append(f'class {class_name}(BaseModel):')
code_lines.append(f' """{desc}"""')
attrs = edge.get("attributes", [])
if attrs:
for attr in attrs:
attr_name = attr["name"]
attr_desc = attr.get("description", attr_name)
code_lines.append(f' {attr_name}: EntityText = Field(')
code_lines.append(f' {attr_name}: Optional[str] = Field(')
code_lines.append(f' description="{attr_desc}",')
code_lines.append(f' default=None')
code_lines.append(f' )')

View File

@ -1,23 +1,17 @@
"""
Zep实体读取与过滤服务
Zep图谱中读取节点筛选出符合预定义实体类型的节点
实体读取与过滤服务
本地JSON图谱中读取节点筛选出符合预定义实体类型的节点
"""
import time
from typing import Dict, Any, List, Optional, Set, Callable, TypeVar
from typing import Dict, Any, List, Optional, Set
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader')
# 用于泛型返回类型
T = TypeVar('T')
@dataclass
class EntityNode:
@ -27,11 +21,9 @@ class EntityNode:
labels: List[str]
summary: str
attributes: Dict[str, Any]
# 相关的边信息
related_edges: List[Dict[str, Any]] = field(default_factory=list)
# 相关的其他节点信息
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]:
return {
"uuid": self.uuid,
@ -42,11 +34,11 @@ class EntityNode:
"related_edges": self.related_edges,
"related_nodes": self.related_nodes,
}
def get_entity_type(self) -> Optional[str]:
"""获取实体类型排除默认的Entity标签)"""
"""获取实体类型排除默认的Entity/Node标签)"""
for label in self.labels:
if label not in ["Entity", "Node"]:
if label not in ("Entity", "Node"):
return label
return None
@ -58,7 +50,7 @@ class FilteredEntities:
entity_types: Set[str]
total_count: int
filtered_count: int
def to_dict(self) -> Dict[str, Any]:
return {
"entities": [e.to_dict() for e in self.entities],
@ -70,368 +62,215 @@ class FilteredEntities:
class ZepEntityReader:
"""
Zep实体读取与过滤服务
实体读取与过滤服务
主要功能
1. Zep图谱读取所有节点
1. 本地图谱读取所有节点
2. 筛选出符合预定义实体类型的节点Labels不只是Entity的节点
3. 获取每个实体的相关边和关联节点信息
"""
def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置")
self.client = Zep(api_key=self.api_key)
def _call_with_retry(
self,
func: Callable[[], T],
operation_name: str,
max_retries: int = 3,
initial_delay: float = 2.0
) -> T:
"""
带重试机制的Zep API调用
Args:
func: 要执行的函数无参数的lambda或callable
operation_name: 操作名称用于日志
max_retries: 最大重试次数默认3次即最多尝试3次
initial_delay: 初始延迟秒数
Returns:
API调用结果
"""
last_exception = None
delay = initial_delay
for attempt in range(max_retries):
try:
return func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.warning(
f"Zep {operation_name}{attempt + 1} 次尝试失败: {str(e)[:100]}, "
f"{delay:.1f}秒后重试..."
)
time.sleep(delay)
delay *= 2 # 指数退避
else:
logger.error(f"Zep {operation_name}{max_retries} 次尝试后仍失败: {str(e)}")
raise last_exception
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
    """Create a reader backed by the local JSON graph store.

    Args:
        storage_dir: Directory handed to LocalGraphStore. Falls back to
            Config.GRAPH_STORAGE_DIR when omitted or empty.
        api_key: Ignored. Accepted only so that legacy callers passing
            ``api_key=...`` keep working.
    """
    # The api_key parameter is retained for compatibility with old call sites
    # but is no longer used.
    # NOTE(review): the pre-refactor signature was __init__(self, api_key=None).
    # A legacy *positional* call such as ZepEntityReader(some_key) now binds the
    # key to storage_dir and would be used as a directory path — confirm every
    # caller passes these arguments by keyword.
    storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
    self.store = LocalGraphStore(storage_dir)
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
"""
获取图谱的所有节点分页获取
Args:
graph_id: 图谱ID
Returns:
节点列表
"""
"""获取图谱的所有节点"""
logger.info(f"获取图谱 {graph_id} 的所有节点...")
nodes = fetch_all_nodes(self.client, graph_id)
nodes_data = []
for node in nodes:
nodes_data.append({
"uuid": getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
"name": node.name or "",
"labels": node.labels or [],
"summary": node.summary or "",
"attributes": node.attributes or {},
})
logger.info(f"共获取 {len(nodes_data)} 个节点")
return nodes_data
nodes = self.store.get_nodes(graph_id)
logger.info(f"共获取 {len(nodes)} 个节点")
return nodes
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
    """Return every edge stored for ``graph_id``.

    The list comes straight from ``self.store.get_edges``; the edge count
    is logged before returning.
    """
    logger.info(f"获取图谱 {graph_id} 的所有边...")
    fetched = self.store.get_edges(graph_id)
    edge_total = len(fetched)
    logger.info(f"共获取 {edge_total} 条边")
    return fetched
def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
    """Return the edges the store reports for one node.

    Best-effort: any exception raised by the store is logged as a warning
    and an empty list is returned instead.

    Args:
        graph_id: Graph the node belongs to.
        node_uuid: UUID of the node whose edges are wanted.
    """
    try:
        node_edges = self.store.get_node_edges(graph_id, node_uuid)
    except Exception as e:
        logger.warning(f"获取节点 {node_uuid} 的边失败: {e}")
        return []
    else:
        return node_edges
def filter_defined_entities(
    self,
    graph_id: str,
    defined_entity_types: Optional[List[str]] = None,
    enrich_with_edges: bool = True
) -> FilteredEntities:
    """Select the nodes that carry a predefined (custom) entity label.

    A node is kept when its labels contain something other than the
    default "Entity"/"Node" labels. When ``defined_entity_types`` is
    given, at least one custom label must also appear in that list; the
    first matching label becomes the entity type.

    Args:
        graph_id: Graph to read nodes and edges from.
        defined_entity_types: Optional whitelist of entity type labels.
        enrich_with_edges: When True, attach each entity's touching edges
            and a short description of the nodes on the other end.

    Returns:
        FilteredEntities with the kept entities, the set of entity types
        found, the total node count and the kept count.
    """
    logger.info(f"开始筛选图谱 {graph_id} 的实体...")

    all_nodes = self.get_all_nodes(graph_id)
    total_count = len(all_nodes)
    all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
    node_map = {n["uuid"]: n for n in all_nodes}

    # Index edges by endpoint once so each node looks up only its own
    # edges instead of rescanning the whole edge list. Buckets keep the
    # store's edge order; a self-loop is stored once so it is reported
    # once (as outgoing), exactly like a linear scan would.
    edges_by_node: Dict[Any, List[Dict[str, Any]]] = {}
    for edge in all_edges:
        source_uuid = edge.get("source_node_uuid")
        target_uuid = edge.get("target_node_uuid")
        edges_by_node.setdefault(source_uuid, []).append(edge)
        if target_uuid != source_uuid:
            edges_by_node.setdefault(target_uuid, []).append(edge)

    filtered_entities = []
    entity_types_found: Set[str] = set()

    for node in all_nodes:
        labels = node.get("labels") or []
        custom_labels = [l for l in labels if l not in ("Entity", "Node")]
        if not custom_labels:
            # Only default labels: not one of the predefined types.
            continue

        if defined_entity_types:
            matching = [l for l in custom_labels if l in defined_entity_types]
            if not matching:
                continue
            entity_type = matching[0]
        else:
            entity_type = custom_labels[0]

        entity_types_found.add(entity_type)

        entity = EntityNode(
            uuid=node["uuid"],
            name=node.get("name", ""),
            labels=labels,
            summary=node.get("summary", ""),
            attributes=node.get("attributes", {}),
        )

        if enrich_with_edges:
            node_uuid = node["uuid"]
            related_edges = []
            related_node_uuids: Set[str] = set()

            for edge in edges_by_node.get(node_uuid, ()):
                if edge.get("source_node_uuid") == node_uuid:
                    related_edges.append({
                        "direction": "outgoing",
                        "edge_name": edge.get("name", ""),
                        "fact": edge.get("fact", ""),
                        "target_node_uuid": edge.get("target_node_uuid", ""),
                    })
                    related_node_uuids.add(edge.get("target_node_uuid", ""))
                else:
                    # The bucket only holds edges touching this node, so a
                    # non-source match means the node is the target.
                    related_edges.append({
                        "direction": "incoming",
                        "edge_name": edge.get("name", ""),
                        "fact": edge.get("fact", ""),
                        "source_node_uuid": edge.get("source_node_uuid", ""),
                    })
                    related_node_uuids.add(edge.get("source_node_uuid", ""))

            entity.related_edges = related_edges

            related_nodes = []
            for related_uuid in related_node_uuids:
                # Skip the "" placeholder and endpoints missing from the graph.
                if related_uuid and related_uuid in node_map:
                    rn = node_map[related_uuid]
                    related_nodes.append({
                        "uuid": rn["uuid"],
                        "name": rn.get("name", ""),
                        "labels": rn.get("labels", []),
                        "summary": rn.get("summary", ""),
                    })
            entity.related_nodes = related_nodes

        filtered_entities.append(entity)

    logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
                f"实体类型: {entity_types_found}")

    return FilteredEntities(
        entities=filtered_entities,
        entity_types=entity_types_found,
        total_count=total_count,
        filtered_count=len(filtered_entities),
    )
def get_entity_with_context(
    self,
    graph_id: str,
    entity_uuid: str
) -> Optional[EntityNode]:
    """Load one entity together with its edges and neighbouring nodes.

    Returns None when the store has no such node, or when anything raises
    while assembling the context (the error is logged).

    Args:
        graph_id: Graph the entity lives in.
        entity_uuid: UUID of the entity node.
    """
    try:
        record = self.store.get_node(graph_id, entity_uuid)
        if not record:
            return None

        touching_edges = self.get_node_edges(graph_id, entity_uuid)
        nodes_by_uuid = {n["uuid"]: n for n in self.get_all_nodes(graph_id)}

        edge_views = []
        neighbour_uuids: Set[str] = set()
        for item in touching_edges:
            # Edges whose source is not this entity are reported as incoming.
            is_outgoing = item.get("source_node_uuid") == entity_uuid
            far_key = "target_node_uuid" if is_outgoing else "source_node_uuid"
            far_uuid = item.get(far_key, "")
            edge_views.append({
                "direction": "outgoing" if is_outgoing else "incoming",
                "edge_name": item.get("name", ""),
                "fact": item.get("fact", ""),
                far_key: far_uuid,
            })
            neighbour_uuids.add(far_uuid)

        neighbour_views = []
        for candidate in neighbour_uuids:
            if not candidate or candidate not in nodes_by_uuid:
                continue
            other = nodes_by_uuid[candidate]
            neighbour_views.append({
                "uuid": other["uuid"],
                "name": other.get("name", ""),
                "labels": other.get("labels", []),
                "summary": other.get("summary", ""),
            })

        return EntityNode(
            uuid=record["uuid"],
            name=record.get("name", ""),
            labels=record.get("labels", []),
            summary=record.get("summary", ""),
            attributes=record.get("attributes", {}),
            related_edges=edge_views,
            related_nodes=neighbour_views,
        )
    except Exception as e:
        logger.error(f"获取实体 {entity_uuid} 失败: {e}")
        return None
def get_entities_by_type(
self,
graph_id: str,
self,
graph_id: str,
entity_type: str,
enrich_with_edges: bool = True
) -> List[EntityNode]:
"""
获取指定类型的所有实体
Args:
graph_id: 图谱ID
entity_type: 实体类型 "Student", "PublicFigure"
enrich_with_edges: 是否获取相关边信息
Returns:
实体列表
"""
"""获取指定类型的所有实体"""
result = self.filter_defined_entities(
graph_id=graph_id,
defined_entity_types=[entity_type],
enrich_with_edges=enrich_with_edges
)
return result.entities

View File

@ -1,6 +1,6 @@
"""
Zep图谱记忆更新服务
将模拟中的Agent活动动态更新到Zep图谱中
图谱记忆更新服务
将模拟中的Agent活动动态写入本地JSON图谱文件
"""
import os
@ -12,9 +12,8 @@ from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Empty
from zep_cloud.client import Zep
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
from ..utils.locale import get_locale, set_locale
@ -31,15 +30,14 @@ class AgentActivity:
action_args: Dict[str, Any]
round_num: int
timestamp: str
def to_episode_text(self) -> str:
"""
将活动转换为可以发送给Zep的文本描述
采用自然语言描述格式Zep能够从中提取实体和关系
将活动转换为自然语言描述文本
采用自然语言描述格式图谱能够从中提取实体和关系
不添加模拟相关的前缀避免误导图谱更新
"""
# 根据不同的动作类型生成不同的描述
action_descriptions = {
"CREATE_POST": self._describe_create_post,
"LIKE_POST": self._describe_like_post,
@ -54,24 +52,22 @@ class AgentActivity:
"SEARCH_USER": self._describe_search_user,
"MUTE": self._describe_mute,
}
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
description = describe_func()
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀
return f"{self.agent_name}: {description}"
def _describe_create_post(self) -> str:
content = self.action_args.get("content", "")
if content:
return f"发布了一条帖子:「{content}"
return "发布了一条帖子"
def _describe_like_post(self) -> str:
"""点赞帖子 - 包含帖子原文和作者信息"""
post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "")
if post_content and post_author:
return f"点赞了{post_author}的帖子:「{post_content}"
elif post_content:
@ -79,12 +75,11 @@ class AgentActivity:
elif post_author:
return f"点赞了{post_author}的一条帖子"
return "点赞了一条帖子"
def _describe_dislike_post(self) -> str:
"""踩帖子 - 包含帖子原文和作者信息"""
post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "")
if post_content and post_author:
return f"踩了{post_author}的帖子:「{post_content}"
elif post_content:
@ -92,12 +87,11 @@ class AgentActivity:
elif post_author:
return f"踩了{post_author}的一条帖子"
return "踩了一条帖子"
def _describe_repost(self) -> str:
"""转发帖子 - 包含原帖内容和作者信息"""
original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "")
if original_content and original_author:
return f"转发了{original_author}的帖子:「{original_content}"
elif original_content:
@ -105,13 +99,12 @@ class AgentActivity:
elif original_author:
return f"转发了{original_author}的一条帖子"
return "转发了一条帖子"
def _describe_quote_post(self) -> str:
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "")
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
base = ""
if original_content and original_author:
base = f"引用了{original_author}的帖子「{original_content}"
@ -121,25 +114,22 @@ class AgentActivity:
base = f"引用了{original_author}的一条帖子"
else:
base = "引用了一条帖子"
if quote_content:
base += f",并评论道:「{quote_content}"
return base
def _describe_follow(self) -> str:
"""关注用户 - 包含被关注用户的名称"""
target_user_name = self.action_args.get("target_user_name", "")
if target_user_name:
return f"关注了用户「{target_user_name}"
return "关注了一个用户"
def _describe_create_comment(self) -> str:
"""发表评论 - 包含评论内容和所评论的帖子信息"""
content = self.action_args.get("content", "")
post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "")
if content:
if post_content and post_author:
return f"{post_author}的帖子「{post_content}」下评论道:「{content}"
@ -149,12 +139,11 @@ class AgentActivity:
return f"{post_author}的帖子下评论道:「{content}"
return f"评论道:「{content}"
return "发表了评论"
def _describe_like_comment(self) -> str:
"""点赞评论 - 包含评论内容和作者信息"""
comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "")
if comment_content and comment_author:
return f"点赞了{comment_author}的评论:「{comment_content}"
elif comment_content:
@ -162,12 +151,11 @@ class AgentActivity:
elif comment_author:
return f"点赞了{comment_author}的一条评论"
return "点赞了一条评论"
def _describe_dislike_comment(self) -> str:
"""踩评论 - 包含评论内容和作者信息"""
comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "")
if comment_content and comment_author:
return f"踩了{comment_author}的评论:「{comment_content}"
elif comment_content:
@ -175,109 +163,83 @@ class AgentActivity:
elif comment_author:
return f"踩了{comment_author}的一条评论"
return "踩了一条评论"
def _describe_search(self) -> str:
"""搜索帖子 - 包含搜索关键词"""
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
return f"搜索了「{query}" if query else "进行了搜索"
def _describe_search_user(self) -> str:
"""搜索用户 - 包含搜索关键词"""
query = self.action_args.get("query", "") or self.action_args.get("username", "")
return f"搜索了用户「{query}" if query else "搜索了用户"
def _describe_mute(self) -> str:
"""屏蔽用户 - 包含被屏蔽用户的名称"""
target_user_name = self.action_args.get("target_user_name", "")
if target_user_name:
return f"屏蔽了用户「{target_user_name}"
return "屏蔽了一个用户"
def _describe_generic(self) -> str:
# 对于未知的动作类型,生成通用描述
return f"执行了{self.action_type}操作"
class ZepGraphMemoryUpdater:
class GraphMemoryUpdater:
"""
Zep图谱记忆更新器
监控模拟的actions日志文件将新的agent活动实时更新到Zep图谱中
按平台分组每累积BATCH_SIZE条活动后批量发送到Zep
所有有意义的行为都会被更新到Zepaction_args中会包含完整的上下文信息
- 点赞/踩的帖子原文
- 转发/引用的帖子原文
- 关注/屏蔽的用户名
- 点赞/踩的评论原文
图谱记忆更新器
监控模拟的actions日志文件将新的agent活动实时写入本地图谱
按平台分组每累积BATCH_SIZE条活动后批量写入
"""
# 批量发送大小(每个平台累积多少条后发送)
BATCH_SIZE = 5
# 平台名称映射(用于控制台显示)
PLATFORM_DISPLAY_NAMES = {
'twitter': '世界1',
'reddit': '世界2',
}
# 发送间隔(秒),避免请求过快
SEND_INTERVAL = 0.5
# 重试配置
SEND_INTERVAL = 0.1 # 本地写入更快,间隔可以更短
MAX_RETRIES = 3
RETRY_DELAY = 2 # 秒
def __init__(self, graph_id: str, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
    """Set up the updater; no thread is started here.

    Args:
        graph_id: ID of the local graph that activities are written to.
        storage_dir: Directory handed to ``LocalGraphStore``. When falsy,
            ``Config.GRAPH_STORAGE_DIR`` is used.
        api_key: Deprecated; accepted for older call sites and never read.
    """
    self.graph_id = graph_id

    if not storage_dir:
        storage_dir = Config.GRAPH_STORAGE_DIR
    self.store = LocalGraphStore(storage_dir)

    # Incoming activities wait here until the worker moves them to a buffer.
    self._activity_queue: Queue = Queue()

    # One buffer per platform; guarded by _buffer_lock.
    self._platform_buffers: Dict[str, List[AgentActivity]] = {
        'twitter': [],
        'reddit': [],
    }
    self._buffer_lock = threading.Lock()

    # Worker control.
    self._running = False
    self._worker_thread: Optional[threading.Thread] = None

    # Counters reported by get_stats().
    self._total_activities = 0
    self._total_sent = 0
    self._total_items_sent = 0
    self._failed_count = 0
    self._skipped_count = 0

    logger.info(f"GraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
def _get_platform_display_name(self, platform: str) -> str:
"""获取平台的显示名称"""
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
def start(self):
"""启动后台工作线程"""
if self._running:
return
# Capture locale before spawning background thread
current_locale = get_locale()
self._running = True
@ -285,70 +247,42 @@ class ZepGraphMemoryUpdater:
target=self._worker_loop,
args=(current_locale,),
daemon=True,
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
name=f"GraphMemoryUpdater-{self.graph_id[:8]}"
)
self._worker_thread.start()
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")
logger.info(f"GraphMemoryUpdater 已启动: graph_id={self.graph_id}")
def stop(self):
    """Stop the background worker, then write out whatever is left.

    The worker is joined (up to 10 seconds) BEFORE the final flush. If
    the flush ran first, the worker could still be holding an activity
    it had already dequeued and append it to a buffer after the flush,
    leaving it below the batch threshold and never written.
    """
    self._running = False

    worker = self._worker_thread
    # Joining the current thread raises RuntimeError, so skip the join if
    # stop() is ever invoked from the worker itself.
    if worker and worker.is_alive() and worker is not threading.current_thread():
        worker.join(timeout=10)

    # With the worker finished (or timed out), drain the queue and buffers.
    self._flush_remaining()

    logger.info(f"GraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
                f"total_activities={self._total_activities}, "
                f"batches_sent={self._total_sent}, "
                f"items_sent={self._total_items_sent}, "
                f"failed={self._failed_count}, "
                f"skipped={self._skipped_count}")
def add_activity(self, activity: AgentActivity):
"""
添加一个agent活动到队列
所有有意义的行为都会被添加到队列包括
- CREATE_POST发帖
- CREATE_COMMENT评论
- QUOTE_POST引用帖子
- SEARCH_POSTS搜索帖子
- SEARCH_USER搜索用户
- LIKE_POST/DISLIKE_POST点赞/踩帖子
- REPOST转发
- FOLLOW关注
- MUTE屏蔽
- LIKE_COMMENT/DISLIKE_COMMENT点赞/踩评论
action_args中会包含完整的上下文信息如帖子原文用户名等
Args:
activity: Agent活动记录
"""
# 跳过DO_NOTHING类型的活动
"""添加一个agent活动到队列"""
if activity.action_type == "DO_NOTHING":
self._skipped_count += 1
return
self._activity_queue.put(activity)
self._total_activities += 1
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")
logger.debug(f"添加活动到队列: {activity.agent_name} - {activity.action_type}")
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
"""
从字典数据添加活动
Args:
data: 从actions.jsonl解析的字典数据
platform: 平台名称 (twitter/reddit)
"""
# 跳过事件类型的条目
"""从字典数据添加活动"""
if "event_type" in data:
return
activity = AgentActivity(
platform=platform,
agent_id=data.get("agent_id", 0),
@ -358,83 +292,94 @@ class ZepGraphMemoryUpdater:
round_num=data.get("round", 0),
timestamp=data.get("timestamp", datetime.now().isoformat()),
)
self.add_activity(activity)
def _worker_loop(self, locale: str = 'zh'):
# Background loop: applies the given locale, then keeps pulling activities
# off the queue into per-platform buffers while the updater is running or
# the queue is non-empty; a platform's buffer is handed to the batch writer
# once it holds BATCH_SIZE items.
"""后台工作循环 - 按平台批量发送活动到Zep"""
"""后台工作循环 - 按平台批量写入活动"""
set_locale(locale)
while self._running or not self._activity_queue.empty():
try:
# Try to take one activity from the queue (1 second timeout)
try:
activity = self._activity_queue.get(timeout=1)
# Append the activity to the buffer of its platform
platform = activity.platform.lower()
with self._buffer_lock:
if platform not in self._platform_buffers:
self._platform_buffers[platform] = []
self._platform_buffers[platform].append(activity)
# Check whether this platform's buffer has reached the batch size
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
# Original note: "send after releasing the lock".
# NOTE(review): nesting is not visible here -- confirm the write really happens outside the lock
self._send_batch_activities(batch, platform)
# Interval between batch writes (see the sleep below)
self._write_batch_activities(batch, platform)
time.sleep(self.SEND_INTERVAL)
except Empty:
pass
except Exception as e:
logger.error(f"工作循环异常: {e}")
time.sleep(1)
def _write_batch_activities(self, activities: List[AgentActivity], platform: str):
    """Write a batch of activities to the local graph, with retries.

    The batch is stored as one episode (one line per activity) and each
    activity additionally gets its own searchable fact edge. Progress is
    remembered across retry attempts, so a retry resumes at the step
    that failed instead of appending the episode and the already-written
    edges a second time.

    Args:
        activities: Activities to persist; an empty list is a no-op.
        platform: Platform name, used only for the success log line.
    """
    if not activities:
        return

    combined_text = "\n".join(activity.to_episode_text() for activity in activities)

    episode_written = False  # set once add_episode has returned
    edges_written = 0        # number of activities whose edge already exists

    for attempt in range(self.MAX_RETRIES):
        try:
            if not episode_written:
                self.store.add_episode(self.graph_id, combined_text)
                episode_written = True

            # Resume after the last edge that was written successfully.
            while edges_written < len(activities):
                self._create_activity_edge(activities[edges_written])
                edges_written += 1
        except Exception as e:
            if attempt < self.MAX_RETRIES - 1:
                logger.warning(f"写入活动失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
                # Linear back-off between attempts.
                time.sleep(self.RETRY_DELAY * (attempt + 1))
            else:
                logger.error(f"写入活动失败,已重试{self.MAX_RETRIES}次: {e}")
                self._failed_count += 1
        else:
            self._total_sent += 1
            self._total_items_sent += len(activities)
            display_name = self._get_platform_display_name(platform)
            logger.info(f"成功写入 {len(activities)}{display_name}活动到图谱 {self.graph_id}")
            return
def _create_activity_edge(self, activity: AgentActivity):
"""为单条活动在图谱中创建可搜索的事实边"""
fact = activity.to_episode_text()
# 创建或获取Agent节点
agent_uuid = self.store.upsert_node(
graph_id=self.graph_id,
name=activity.agent_name,
labels=["Agent", "Entity"],
summary=f"Agent {activity.agent_name}",
)
# 为活动创建自环边(以便关键词搜索可以找到它)
self.store.add_edge(self.graph_id, {
"name": activity.action_type,
"fact": fact,
"source_node_uuid": agent_uuid,
"target_node_uuid": agent_uuid,
"attributes": {
"platform": activity.platform,
"round": activity.round_num,
"timestamp": activity.timestamp,
},
})
def _flush_remaining(self):
"""发送队列和缓冲区中剩余的活动"""
# 首先处理队列中剩余的活动,添加到缓冲区
while not self._activity_queue.empty():
try:
activity = self._activity_queue.get_nowait()
@ -445,96 +390,83 @@ class ZepGraphMemoryUpdater:
self._platform_buffers[platform].append(activity)
except Empty:
break
# 然后发送各平台缓冲区中剩余的活动即使不足BATCH_SIZE条
with self._buffer_lock:
for platform, buffer in self._platform_buffers.items():
if buffer:
display_name = self._get_platform_display_name(platform)
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
self._send_batch_activities(buffer, platform)
# 清空所有缓冲区
self._write_batch_activities(buffer, platform)
for platform in self._platform_buffers:
self._platform_buffers[platform] = []
def get_stats(self) -> Dict[str, Any]:
"""获取统计信息"""
with self._buffer_lock:
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
return {
"graph_id": self.graph_id,
"batch_size": self.BATCH_SIZE,
"total_activities": self._total_activities, # 添加到队列的活动总数
"batches_sent": self._total_sent, # 成功发送的批次数
"items_sent": self._total_items_sent, # 成功发送的活动条数
"failed_count": self._failed_count, # 发送失败的批次数
"skipped_count": self._skipped_count, # 被过滤跳过的活动数DO_NOTHING
"total_activities": self._total_activities,
"batches_sent": self._total_sent,
"items_sent": self._total_items_sent,
"failed_count": self._failed_count,
"skipped_count": self._skipped_count,
"queue_size": self._activity_queue.qsize(),
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小
"buffer_sizes": buffer_sizes,
"running": self._running,
}
# 向后兼容别名
ZepGraphMemoryUpdater = GraphMemoryUpdater
class ZepGraphMemoryManager:
"""
管理多个模拟的Zep图谱记忆更新器
管理多个模拟的图谱记忆更新器
每个模拟可以有自己的更新器实例
"""
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
_updaters: Dict[str, GraphMemoryUpdater] = {}
_lock = threading.Lock()
@classmethod
def create_updater(cls, simulation_id: str, graph_id: str) -> GraphMemoryUpdater:
    """Create, start and register the updater for a simulation.

    An updater already registered under ``simulation_id`` is stopped
    first and then replaced. The whole operation runs under the class lock.
    """
    with cls._lock:
        registry = cls._updaters
        if simulation_id in registry:
            # Stop the previous instance before replacing it.
            registry[simulation_id].stop()

        fresh = GraphMemoryUpdater(graph_id)
        fresh.start()
        registry[simulation_id] = fresh

        logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
        return fresh
@classmethod
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
"""获取模拟的更新器"""
def get_updater(cls, simulation_id: str) -> Optional[GraphMemoryUpdater]:
return cls._updaters.get(simulation_id)
@classmethod
def stop_updater(cls, simulation_id: str):
"""停止并移除模拟的更新器"""
with cls._lock:
if simulation_id in cls._updaters:
cls._updaters[simulation_id].stop()
del cls._updaters[simulation_id]
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
# 防止 stop_all 重复调用的标志
_stop_all_done = False
@classmethod
def stop_all(cls):
"""停止所有更新器"""
# 防止重复调用
if cls._stop_all_done:
return
cls._stop_all_done = True
with cls._lock:
if cls._updaters:
for simulation_id, updater in list(cls._updaters.items()):
@ -544,11 +476,10 @@ class ZepGraphMemoryManager:
logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
cls._updaters.clear()
logger.info("已停止所有图谱记忆更新器")
@classmethod
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
"""获取所有更新器的统计信息"""
return {
sim_id: updater.get_stats()
sim_id: updater.get_stats()
for sim_id, updater in cls._updaters.items()
}

View File

@ -13,13 +13,11 @@ import json
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from zep_cloud.client import Zep
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
from ..utils.llm_client import LLMClient
from ..utils.locale import get_locale, t
from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_tools')
@ -418,20 +416,14 @@ class ZepToolsService:
- get_entity_summary - 获取实体的关系摘要
"""
# 重试配置
MAX_RETRIES = 3
RETRY_DELAY = 2.0
def __init__(self, storage_dir: Optional[str] = None, api_key: Optional[str] = None,
             llm_client: Optional[LLMClient] = None):
    """Create the tools service on top of the local JSON graph store.

    Args:
        storage_dir: Directory handed to ``LocalGraphStore``. When falsy,
            ``Config.GRAPH_STORAGE_DIR`` is used.
        api_key: Accepted so older keyword call sites keep working; the
            value is never read.
        llm_client: Optional LLM client stored for later use.
    """
    # NOTE(review): storage_dir was inserted ahead of api_key/llm_client, so
    # positional legacy calls now bind to different parameters -- confirm
    # call sites pass keywords.
    if not storage_dir:
        storage_dir = Config.GRAPH_STORAGE_DIR
    self.store = LocalGraphStore(storage_dir)
    self._llm_client = llm_client
    logger.info(t("console.zepToolsInitialized"))
@property
def llm(self) -> LLMClient:
"""延迟初始化LLM客户端"""
@ -439,206 +431,50 @@ class ZepToolsService:
self._llm_client = LLMClient()
return self._llm_client
def _call_with_retry(self, func, operation_name: str, max_retries: int = None):
"""带重试机制的API调用"""
max_retries = max_retries or self.MAX_RETRIES
last_exception = None
delay = self.RETRY_DELAY
for attempt in range(max_retries):
try:
return func()
except Exception as e:
last_exception = e
if attempt < max_retries - 1:
logger.warning(
t("console.zepRetryAttempt", operation=operation_name, attempt=attempt + 1, error=str(e)[:100], delay=f"{delay:.1f}")
)
time.sleep(delay)
delay *= 2
else:
logger.error(t("console.zepAllRetriesFailed", operation=operation_name, retries=max_retries, error=str(e)))
raise last_exception
def search_graph(
    self,
    graph_id: str,
    query: str,
    limit: int = 10,
    scope: str = "edges"
) -> SearchResult:
    """Keyword search over the graph.

    Logs the request (query truncated to 50 characters) and delegates to
    ``_local_search``.

    Args:
        graph_id: Graph to search.
        query: Search text.
        limit: Maximum number of results to return.
        scope: Search scope, "edges" or "nodes".

    Returns:
        SearchResult produced by ``_local_search``.
    """
    query_preview = query[:50]
    logger.info(t("console.graphSearch", graphId=graph_id, query=query_preview))
    outcome = self._local_search(graph_id, query, limit, scope)
    return outcome
def _local_search(
self,
graph_id: str,
query: str,
limit: int = 10,
scope: str = "edges"
) -> SearchResult:
"""本地关键词匹配搜索"""
logger.info(t("console.usingLocalSearch", query=query[:30]))
facts = []
edges_result = []
nodes_result = []
# 提取查询关键词(简单分词)
query_lower = query.lower()
keywords = [w.strip() for w in query_lower.replace(',', ' ').replace('', ' ').split() if len(w.strip()) > 1]
def match_score(text: str) -> int:
"""计算文本与查询的匹配分数"""
if not text:
return 0
text_lower = text.lower()
# 完全匹配查询
if query_lower in text_lower:
return 100
# 关键词匹配
score = 0
for keyword in keywords:
if keyword in text_lower:
score += 10
return score
try:
if scope in ["edges", "both"]:
# 获取所有边并匹配
all_edges = self.get_all_edges(graph_id)
scored_edges = []
for edge in all_edges:
score = match_score(edge.fact) + match_score(edge.name)
if score > 0:
scored_edges.append((score, edge))
# 按分数排序
scored_edges.sort(key=lambda x: x[0], reverse=True)
for score, edge in scored_edges[:limit]:
if edge.fact:
facts.append(edge.fact)
edges_result.append({
"uuid": edge.uuid,
"name": edge.name,
"fact": edge.fact,
"source_node_uuid": edge.source_node_uuid,
"target_node_uuid": edge.target_node_uuid,
})
if scope in ["nodes", "both"]:
# 获取所有节点并匹配
all_nodes = self.get_all_nodes(graph_id)
scored_nodes = []
for node in all_nodes:
score = match_score(node.name) + match_score(node.summary)
if score > 0:
scored_nodes.append((score, node))
scored_nodes.sort(key=lambda x: x[0], reverse=True)
for score, node in scored_nodes[:limit]:
nodes_result.append({
"uuid": node.uuid,
"name": node.name,
"labels": node.labels,
"summary": node.summary,
})
if node.summary:
facts.append(f"[{node.name}]: {node.summary}")
raw = self.store.search(graph_id, query, limit=limit, scope=scope)
facts = raw.get("facts", [])
edges_result = raw.get("edges", [])
nodes_result = raw.get("nodes", [])
logger.info(t("console.localSearchComplete", count=len(facts)))
except Exception as e:
logger.error(t("console.localSearchFailed", error=str(e)))
facts, edges_result, nodes_result = [], [], []
return SearchResult(
facts=facts,
edges=edges_result,
@ -648,99 +484,74 @@ class ZepToolsService:
)
def get_all_nodes(self, graph_id: str) -> List[NodeInfo]:
    """Return every node of the graph as ``NodeInfo`` objects.

    Missing fields default to empty values; ``labels`` and ``attributes``
    that are present but falsy are replaced by an empty list / dict.
    """
    logger.info(t("console.fetchingAllNodes", graphId=graph_id))

    result = []
    for raw in self.store.get_nodes(graph_id):
        info = NodeInfo(
            uuid=raw.get("uuid", ""),
            name=raw.get("name", ""),
            labels=raw.get("labels") or [],
            summary=raw.get("summary", ""),
            attributes=raw.get("attributes") or {},
        )
        result.append(info)

    logger.info(t("console.fetchedNodes", count=len(result)))
    return result
def get_all_edges(self, graph_id: str, include_temporal: bool = True) -> List[EdgeInfo]:
    """Return every edge of a graph, read from the local graph store.

    Args:
        graph_id: Identifier of the graph whose edges are read.
        include_temporal: When True (default), copy ``created_at``,
            ``valid_at``, ``invalid_at`` and ``expired_at`` from the stored
            record onto each ``EdgeInfo``.

    Returns:
        One ``EdgeInfo`` per record returned by ``self.store.get_edges``.
        Missing or ``null`` text fields are coerced to "".
    """
    logger.info(t("console.fetchingAllEdges", graphId=graph_id))

    result = []
    for e in self.store.get_edges(graph_id):
        # Stored records are JSON dicts; a key may be absent or null.
        edge_info = EdgeInfo(
            uuid=str(e.get("uuid") or ""),
            name=e.get("name") or "",
            fact=e.get("fact") or "",
            source_node_uuid=e.get("source_node_uuid") or "",
            target_node_uuid=e.get("target_node_uuid") or "",
        )

        # Temporal fields stay None when the record does not carry them.
        if include_temporal:
            edge_info.created_at = e.get("created_at")
            edge_info.valid_at = e.get("valid_at")
            edge_info.invalid_at = e.get("invalid_at")
            edge_info.expired_at = e.get("expired_at")

        result.append(edge_info)

    logger.info(t("console.fetchedEdges", count=len(result)))
    return result
def get_node_detail(self, node_uuid: str, graph_id: str = "") -> Optional[NodeInfo]:
    """Look up a single node by UUID in the local graph store.

    Args:
        node_uuid: UUID of the node to fetch.
        graph_id: Graph that contains the node. The local store is keyed by
            graph, so without it no lookup is attempted.

    Returns:
        The matching ``NodeInfo``; ``None`` when ``graph_id`` is empty, the
        node is not found, or the store lookup raises (the error is logged).
    """
    logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8]))

    # NOTE(review): callers that still use the old one-argument form get None
    # here, because the store cannot resolve a node without its graph.
    if not graph_id:
        return None

    try:
        n = self.store.get_node(graph_id, node_uuid)
        if not n:
            return None
        # Stored records are JSON dicts; a key may be absent or null.
        return NodeInfo(
            uuid=str(n.get("uuid") or ""),
            name=n.get("name") or "",
            labels=n.get("labels") or [],
            summary=n.get("summary") or "",
            attributes=n.get("attributes") or {},
        )
    except Exception as e:
        logger.error(t("console.fetchNodeDetailFailed", error=str(e)))
        return None
@ -1043,7 +854,7 @@ class ZepToolsService:
continue
try:
# 单独获取每个相关节点的信息
node = self.get_node_detail(uuid)
node = self.get_node_detail(uuid, graph_id=graph_id)
if node:
node_map[uuid] = node
entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体")

View File

@ -0,0 +1,290 @@
"""
本地JSON文件图谱存储
替代Zep Cloud将图谱数据节点情节存储在本地JSON文件中
存储目录结构:
{storage_dir}/
{graph_id}/
metadata.json - 图谱元数据和本体定义
nodes.json - 节点列表
edges.json - 边列表
episodes.jsonl - 情节文本日志追加写入
"""
from __future__ import annotations
import json
import os
import shutil
import threading
import uuid
from datetime import datetime
from typing import Any, Dict, List, Optional
from .logger import get_logger
logger = get_logger('mirofish.local_graph_store')
# 每个图谱一把锁,保证并发写入安全
_global_lock = threading.Lock()
_graph_locks: Dict[str, threading.Lock] = {}
def _lock_for(graph_id: str) -> threading.Lock:
    """Return the lock dedicated to ``graph_id``, creating it on first use.

    The registry itself is guarded by ``_global_lock`` so two threads asking
    for the same new graph id always receive the same lock object.
    """
    with _global_lock:
        lock = _graph_locks.get(graph_id)
        if lock is None:
            lock = threading.Lock()
            _graph_locks[graph_id] = lock
        return lock
class LocalGraphStore:
    """Graph store that persists each graph as JSON files on local disk.

    Every graph lives in ``{storage_dir}/{graph_id}/`` and consists of
    ``metadata.json``, ``nodes.json``, ``edges.json`` and ``episodes.jsonl``.

    Mutating methods serialize on the per-graph lock from ``_lock_for``.
    Read methods take no lock; they rely on ``_write_json`` swapping files
    in with ``os.replace`` so a reader never sees a half-written document.
    """

    def __init__(self, storage_dir: str):
        """Remember ``storage_dir`` and create it if it does not exist."""
        self.storage_dir = storage_dir
        os.makedirs(storage_dir, exist_ok=True)

    # ── Graph lifecycle ──────────────────────────────────────────────────────

    def create_graph(self, graph_id: str, name: str, description: str = "") -> None:
        """Create the graph directory, its metadata and empty node/edge files.

        Existing ``nodes.json`` / ``edges.json`` files are left untouched.
        """
        graph_dir = self._graph_dir(graph_id)
        with _lock_for(graph_id):
            os.makedirs(graph_dir, exist_ok=True)
            # NOTE(review): metadata (including any stored ontology) is
            # rewritten on every call while nodes/edges are preserved --
            # confirm this asymmetry is intended for repeated calls.
            self._write_json(self._meta_path(graph_id), {
                "graph_id": graph_id,
                "name": name,
                "description": description,
                "created_at": datetime.now().isoformat(),
                "ontology": None,
            })
            for path in (self._nodes_path(graph_id), self._edges_path(graph_id)):
                if not os.path.exists(path):
                    self._write_json(path, [])
        logger.info(f"本地图谱已创建: {graph_id}")

    def delete_graph(self, graph_id: str) -> None:
        """Remove a graph directory and everything in it; no-op if absent.

        Raises:
            ValueError: if ``graph_id`` resolves to the storage root itself
                (e.g. an empty id) or to a path outside it (e.g. ``../x``).
                Deleting such a path would destroy unrelated data.
        """
        root = os.path.abspath(self.storage_dir)
        graph_dir = os.path.abspath(self._graph_dir(graph_id))
        if graph_dir == root or os.path.commonpath([root, graph_dir]) != root:
            raise ValueError(
                f"Refusing to delete a path outside the graph storage: {graph_id!r}"
            )
        with _lock_for(graph_id):
            if os.path.exists(graph_dir):
                shutil.rmtree(graph_dir)
                logger.info(f"本地图谱已删除: {graph_id}")

    def graph_exists(self, graph_id: str) -> bool:
        """Return True when the graph's ``metadata.json`` exists."""
        return os.path.exists(self._meta_path(graph_id))

    # ── Ontology ─────────────────────────────────────────────────────────────

    def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None:
        """Store ``ontology`` under the ``"ontology"`` key of the metadata."""
        meta_path = self._meta_path(graph_id)
        # Read-modify-write must be atomic with respect to other writers.
        with _lock_for(graph_id):
            meta = self._read_json(meta_path) or {}
            meta["ontology"] = ontology
            self._write_json(meta_path, meta)

    def get_ontology(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """Return the stored ontology, or None when none is recorded."""
        meta = self._read_json(self._meta_path(graph_id)) or {}
        return meta.get("ontology")

    def get_metadata(self, graph_id: str) -> Optional[Dict[str, Any]]:
        """Return the parsed ``metadata.json``, or None when it is missing."""
        return self._read_json(self._meta_path(graph_id))

    # ── Episodes ─────────────────────────────────────────────────────────────

    def add_episode(self, graph_id: str, text: str) -> str:
        """Append one episode record to ``episodes.jsonl`` and return its uuid.

        Records are written with ``"processed": True``; the local store has
        no asynchronous processing step.
        """
        episode_id = uuid.uuid4().hex
        record = {
            "uuid": episode_id,
            "text": text,
            "created_at": datetime.now().isoformat(),
            "processed": True,
        }
        ep_path = self._episodes_path(graph_id)
        with _lock_for(graph_id):
            with open(ep_path, 'a', encoding='utf-8') as f:
                f.write(json.dumps(record, ensure_ascii=False) + '\n')
        return episode_id

    def add_episodes_batch(self, graph_id: str, texts: List[str]) -> List[str]:
        """Append each text as its own episode; return the uuids in order."""
        return [self.add_episode(graph_id, text) for text in texts]

    def episode_is_processed(self, graph_id: str, episode_uuid: str) -> bool:
        """Always True: local episodes are complete as soon as they are written."""
        return True

    # ── Nodes ────────────────────────────────────────────────────────────────

    def get_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return all node records, or [] when ``nodes.json`` is missing."""
        return self._read_json(self._nodes_path(graph_id)) or []

    def get_node(self, graph_id: str, node_uuid: str) -> Optional[Dict[str, Any]]:
        """Return the node whose ``uuid`` equals ``node_uuid``, else None."""
        for node in self.get_nodes(graph_id):
            if node.get("uuid") == node_uuid:
                return node
        return None

    def upsert_node(
        self,
        graph_id: str,
        name: str,
        labels: Optional[List[str]] = None,
        summary: str = "",
        attributes: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Create a node, or merge into the node with the same name.

        Names are compared case-insensitively. On a match, labels are unioned
        (existing order kept, new labels appended), ``summary`` is filled only
        when the stored one is empty, and ``attributes`` are merged with the
        new values winning.

        Returns:
            The uuid of the existing or newly created node.
        """
        labels = list(labels) if labels else ["Entity"]
        attributes = dict(attributes) if attributes else {}
        wanted = name.lower()
        nodes_path = self._nodes_path(graph_id)

        with _lock_for(graph_id):
            nodes = self._read_json(nodes_path) or []

            for node in nodes:
                if (node.get("name") or "").lower() != wanted:
                    continue
                # dict.fromkeys de-duplicates while keeping a stable order,
                # so the file does not churn between runs.
                node["labels"] = list(
                    dict.fromkeys([*(node.get("labels") or []), *labels])
                )
                if summary and not node.get("summary"):
                    node["summary"] = summary
                if attributes:
                    merged = node.get("attributes") or {}
                    merged.update(attributes)
                    node["attributes"] = merged
                self._write_json(nodes_path, nodes)
                return node["uuid"]

            node_uuid = uuid.uuid4().hex
            nodes.append({
                "uuid": node_uuid,
                "name": name,
                "labels": labels,
                "summary": summary,
                "attributes": attributes,
                "created_at": datetime.now().isoformat(),
            })
            self._write_json(nodes_path, nodes)
            return node_uuid

    # ── Edges ────────────────────────────────────────────────────────────────

    def get_edges(self, graph_id: str) -> List[Dict[str, Any]]:
        """Return all edge records, or [] when ``edges.json`` is missing."""
        return self._read_json(self._edges_path(graph_id)) or []

    def get_node_edges(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]:
        """Return the edges that have ``node_uuid`` as source or target."""
        return [
            e for e in self.get_edges(graph_id)
            if node_uuid in (e.get("source_node_uuid"), e.get("target_node_uuid"))
        ]

    def add_edge(self, graph_id: str, edge: Dict[str, Any]) -> str:
        """Append an edge record and return its uuid.

        The caller's dict is copied, not mutated. A uuid is generated when the
        record has none, and ``created_at``, ``valid_at``, ``invalid_at``,
        ``expired_at`` and ``attributes`` are given defaults when absent.
        """
        edge_uuid = edge.get("uuid") or uuid.uuid4().hex
        record = dict(edge)
        record["uuid"] = edge_uuid
        record.setdefault("created_at", datetime.now().isoformat())
        record.setdefault("valid_at", None)
        record.setdefault("invalid_at", None)
        record.setdefault("expired_at", None)
        record.setdefault("attributes", {})

        edges_path = self._edges_path(graph_id)
        with _lock_for(graph_id):
            edges = self._read_json(edges_path) or []
            edges.append(record)
            self._write_json(edges_path, edges)
        return edge_uuid

    def add_fact_edge(
        self,
        graph_id: str,
        source_uuid: str,
        target_uuid: str,
        name: str,
        fact: str,
    ) -> str:
        """Convenience wrapper: add a named fact edge between two nodes."""
        return self.add_edge(graph_id, {
            "name": name,
            "fact": fact,
            "source_node_uuid": source_uuid,
            "target_node_uuid": target_uuid,
        })

    # ── Search ───────────────────────────────────────────────────────────────

    def search(
        self,
        graph_id: str,
        query: str,
        limit: int = 10,
        scope: str = "edges",
    ) -> Dict[str, Any]:
        """Keyword search over edges and/or nodes of one graph.

        A text field scores 100 when it contains the whole (lower-cased,
        stripped) query, otherwise 10 per query keyword it contains. Keywords
        are the query split on whitespace and on ASCII / full-width commas,
        keeping tokens longer than one character. Edges are scored on
        ``fact`` + ``name``, nodes on ``name`` + ``summary``; records scoring
        0 are dropped and the best ``limit`` per kind are returned.

        Args:
            graph_id: Graph to search.
            query: Free-text query; a blank query yields no results.
            limit: Maximum number of edges and of nodes to return.
            scope: ``"edges"``, ``"nodes"`` or ``"both"``.

        Returns:
            ``{"facts": [...], "edges": [...], "nodes": [...]}`` where facts
            holds edge facts followed by ``"[name]: summary"`` node lines.
        """
        result_edges: List[Dict] = []
        result_nodes: List[Dict] = []
        facts: List[str] = []

        query_lower = query.strip().lower()
        if not query_lower:
            # "" is a substring of every string, so without this guard an
            # empty query would "match" every record with the top score.
            return {"facts": facts, "edges": result_edges, "nodes": result_nodes}

        # U+FF0C is the full-width comma used in CJK text.
        keywords = [
            w
            for w in query_lower.replace(',', ' ').replace('\uff0c', ' ').split()
            if len(w) > 1
        ]
        top_n = max(limit, 0)

        def score(text: Optional[str]) -> int:
            if not text:
                return 0
            tl = text.lower()
            if query_lower in tl:
                return 100
            return sum(10 for kw in keywords if kw in tl)

        def best(records: List[Dict[str, Any]], fields: tuple) -> List[Dict[str, Any]]:
            # Score each record once; the stable sort keeps file order on ties.
            scored = []
            for rec in records:
                total = sum(score(rec.get(field)) for field in fields)
                if total > 0:
                    scored.append((total, rec))
            scored.sort(key=lambda pair: pair[0], reverse=True)
            return [rec for _, rec in scored[:top_n]]

        if scope in ("edges", "both"):
            for edge in best(self.get_edges(graph_id), ("fact", "name")):
                result_edges.append(edge)
                if edge.get("fact"):
                    facts.append(edge["fact"])

        if scope in ("nodes", "both"):
            for node in best(self.get_nodes(graph_id), ("name", "summary")):
                result_nodes.append(node)
                if node.get("summary"):
                    facts.append(f"[{node.get('name', '')}]: {node['summary']}")

        return {"facts": facts, "edges": result_edges, "nodes": result_nodes}

    # ── Internal path helpers ────────────────────────────────────────────────

    def _graph_dir(self, graph_id: str) -> str:
        # NOTE(review): graph_id is joined unvalidated; only delete_graph
        # rejects ids that escape storage_dir. Confirm ids are always
        # generated server-side before exposing other methods to user input.
        return os.path.join(self.storage_dir, graph_id)

    def _meta_path(self, graph_id: str) -> str:
        return os.path.join(self._graph_dir(graph_id), "metadata.json")

    def _nodes_path(self, graph_id: str) -> str:
        return os.path.join(self._graph_dir(graph_id), "nodes.json")

    def _edges_path(self, graph_id: str) -> str:
        return os.path.join(self._graph_dir(graph_id), "edges.json")

    def _episodes_path(self, graph_id: str) -> str:
        return os.path.join(self._graph_dir(graph_id), "episodes.jsonl")

    def _read_json(self, path: str) -> Any:
        """Return the parsed JSON at ``path``, or None when the file is absent."""
        try:
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    def _write_json(self, path: str, data: Any) -> None:
        """Serialize ``data`` to ``path`` without exposing a partial file.

        The document is written to a uniquely named sibling file and moved
        over ``path`` with ``os.replace``. If serialization fails, the
        previous content of ``path`` is left intact.
        """
        tmp_path = f"{path}.{uuid.uuid4().hex}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(data, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except BaseException:
            # Do not leave stray temp files behind on failure.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

View File

@ -1,143 +1,25 @@
"""Zep Graph 分页读取工具。
"""
图谱分页读取工具存根模块
Zep node/edge 列表接口使用 UUID cursor 分页
本模块封装自动翻页逻辑含单页重试对调用方透明地返回完整列表
原来封装 Zep Cloud 的分页逻辑
现在图谱数据存储在本地 JSON 文件中不再需要分页
本模块保留以避免破坏未更新的旧导入
"""
from __future__ import annotations
import time
from collections.abc import Callable
from typing import Any
from zep_cloud import InternalServerError
from zep_cloud.client import Zep
from .logger import get_logger
from ..utils.logger import get_logger
logger = get_logger('mirofish.zep_paging')
_DEFAULT_PAGE_SIZE = 100
_MAX_NODES = 2000
_DEFAULT_MAX_RETRIES = 3
_DEFAULT_RETRY_DELAY = 2.0 # seconds, doubles each retry
def fetch_all_nodes(client, graph_id: str, **kwargs) -> list:
    """Deprecated stub kept for old imports; use ``LocalGraphStore.get_nodes()``.

    Ignores every argument, logs a deprecation warning and returns an empty
    list.
    """
    no_nodes: list = []
    logger.warning("fetch_all_nodes 已废弃,请使用 LocalGraphStore.get_nodes()")
    return no_nodes
def _fetch_page_with_retry(
api_call: Callable[..., list[Any]],
*args: Any,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY,
page_description: str = "page",
**kwargs: Any,
) -> list[Any]:
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。"""
if max_retries < 1:
raise ValueError("max_retries must be >= 1")
last_exception: Exception | None = None
delay = retry_delay
for attempt in range(max_retries):
try:
return api_call(*args, **kwargs)
except (ConnectionError, TimeoutError, OSError, InternalServerError) as e:
last_exception = e
if attempt < max_retries - 1:
logger.warning(
f"Zep {page_description} attempt {attempt + 1} failed: {str(e)[:100]}, retrying in {delay:.1f}s..."
)
time.sleep(delay)
delay *= 2
else:
logger.error(f"Zep {page_description} failed after {max_retries} attempts: {str(e)}")
assert last_exception is not None
raise last_exception
def fetch_all_nodes(
client: Zep,
graph_id: str,
page_size: int = _DEFAULT_PAGE_SIZE,
max_items: int = _MAX_NODES,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
"""分页获取图谱节点,最多返回 max_items 条(默认 2000。每页请求自带重试。"""
all_nodes: list[Any] = []
cursor: str | None = None
page_num = 0
while True:
kwargs: dict[str, Any] = {"limit": page_size}
if cursor is not None:
kwargs["uuid_cursor"] = cursor
page_num += 1
batch = _fetch_page_with_retry(
client.graph.node.get_by_graph_id,
graph_id,
max_retries=max_retries,
retry_delay=retry_delay,
page_description=f"fetch nodes page {page_num} (graph={graph_id})",
**kwargs,
)
if not batch:
break
all_nodes.extend(batch)
if len(all_nodes) >= max_items:
all_nodes = all_nodes[:max_items]
logger.warning(f"Node count reached limit ({max_items}), stopping pagination for graph {graph_id}")
break
if len(batch) < page_size:
break
cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
if cursor is None:
logger.warning(f"Node missing uuid field, stopping pagination at {len(all_nodes)} nodes")
break
return all_nodes
def fetch_all_edges(
client: Zep,
graph_id: str,
page_size: int = _DEFAULT_PAGE_SIZE,
max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]:
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。"""
all_edges: list[Any] = []
cursor: str | None = None
page_num = 0
while True:
kwargs: dict[str, Any] = {"limit": page_size}
if cursor is not None:
kwargs["uuid_cursor"] = cursor
page_num += 1
batch = _fetch_page_with_retry(
client.graph.edge.get_by_graph_id,
graph_id,
max_retries=max_retries,
retry_delay=retry_delay,
page_description=f"fetch edges page {page_num} (graph={graph_id})",
**kwargs,
)
if not batch:
break
all_edges.extend(batch)
if len(batch) < page_size:
break
cursor = getattr(batch[-1], "uuid_", None) or getattr(batch[-1], "uuid", None)
if cursor is None:
logger.warning(f"Edge missing uuid field, stopping pagination at {len(all_edges)} edges")
break
return all_edges
def fetch_all_edges(client, graph_id: str, **kwargs) -> list:
    """Deprecated stub kept for old imports; use ``LocalGraphStore.get_edges()``.

    Ignores every argument, logs a deprecation warning and returns an empty
    list.
    """
    no_edges: list = []
    logger.warning("fetch_all_edges 已废弃,请使用 LocalGraphStore.get_edges()")
    return no_edges

View File

@ -16,9 +16,6 @@ dependencies = [
# LLM 相关
"openai>=1.0.0",
# Zep Cloud
"zep-cloud==3.13.0",
# OASIS 社交媒体模拟
"camel-oasis==0.2.5",
"camel-ai==0.2.78",

View File

@ -13,9 +13,6 @@ flask-cors>=6.0.0
# OpenAI SDK统一使用 OpenAI 格式调用 LLM
openai>=1.0.0
# ============= Zep Cloud =============
zep-cloud==3.13.0
# ============= OASIS 社交媒体模拟 =============
# OASIS 社交模拟框架
camel-oasis==0.2.5