MicroFish/backend/app/services/zep_graph_memory_updater.py

486 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
图谱记忆更新服务
将模拟中的Agent活动动态写入本地JSON图谱文件
"""
import os
import time
import threading
import json
from typing import Dict, Any, List, Optional, Callable
from dataclasses import dataclass
from datetime import datetime
from queue import Queue, Empty
from ..config import Config
from ..utils.local_graph_store import LocalGraphStore
from ..utils.logger import get_logger
from ..utils.locale import get_locale, set_locale
# Module-level logger used by the updater and manager classes in this file.
logger = get_logger('mirofish.zep_graph_memory_updater')
@dataclass
class AgentActivity:
    """One agent action, convertible to a natural-language sentence.

    The sentence produced by :meth:`to_episode_text` is deliberately free of
    any simulation-related prefix so that text written to the graph reads
    like a plain description of what the agent did.
    """

    platform: str  # twitter / reddit
    agent_id: int
    agent_name: str
    action_type: str  # CREATE_POST, LIKE_POST, etc.
    action_args: Dict[str, Any]
    round_num: int
    timestamp: str

    def to_episode_text(self) -> str:
        """Return ``"<agent_name>: <description>"`` for this activity.

        The description is chosen by ``action_type``; unknown action types
        fall back to a generic "performed <action_type>" sentence.
        """
        action_descriptions = {
            "CREATE_POST": self._describe_create_post,
            "LIKE_POST": self._describe_like_post,
            "DISLIKE_POST": self._describe_dislike_post,
            "REPOST": self._describe_repost,
            "QUOTE_POST": self._describe_quote_post,
            "FOLLOW": self._describe_follow,
            "CREATE_COMMENT": self._describe_create_comment,
            "LIKE_COMMENT": self._describe_like_comment,
            "DISLIKE_COMMENT": self._describe_dislike_comment,
            "SEARCH_POSTS": self._describe_search,
            "SEARCH_USER": self._describe_search_user,
            "MUTE": self._describe_mute,
        }
        describe_func = action_descriptions.get(self.action_type, self._describe_generic)
        return f"{self.agent_name}: {describe_func()}"

    def _arg(self, *keys: str) -> Any:
        """Return the first truthy ``action_args`` value among *keys*, else ``""``.

        A missing or ``None`` ``action_args`` is treated as an empty mapping so
        that describing an activity never fails on absent arguments.
        """
        args = self.action_args or {}
        for key in keys:
            value = args.get(key, "")
            if value:
                return value
        return ""

    def _describe_reaction(self, verb: str, noun: str, content_key: str, author_key: str) -> str:
        """Describe a verb applied to someone's post/comment.

        Shared by like/dislike/repost descriptions, which differ only in the
        verb, the noun (post or comment) and the argument keys they read.
        Quoted content is always wrapped in a balanced 「…」 pair.
        """
        content = self._arg(content_key)
        author = self._arg(author_key)
        if content and author:
            return f"{verb}{author}的{noun}:「{content}」"
        if content:
            return f"{verb}一条{noun}:「{content}」"
        if author:
            return f"{verb}{author}的一条{noun}"
        return f"{verb}一条{noun}"

    def _describe_create_post(self) -> str:
        """Describe publishing a post, quoting its content when present."""
        content = self._arg("content")
        if content:
            return f"发布了一条帖子:「{content}」"
        return "发布了一条帖子"

    def _describe_like_post(self) -> str:
        """Describe liking a post."""
        return self._describe_reaction("点赞了", "帖子", "post_content", "post_author_name")

    def _describe_dislike_post(self) -> str:
        """Describe disliking a post."""
        return self._describe_reaction("踩了", "帖子", "post_content", "post_author_name")

    def _describe_repost(self) -> str:
        """Describe reposting a post."""
        return self._describe_reaction("转发了", "帖子", "original_content", "original_author_name")

    def _describe_quote_post(self) -> str:
        """Describe quoting a post, appending the agent's own remark if any.

        The remark is read from ``quote_content`` and falls back to ``content``.
        """
        original_content = self._arg("original_content")
        original_author = self._arg("original_author_name")
        quote_content = self._arg("quote_content", "content")
        if original_content and original_author:
            base = f"引用了{original_author}的帖子「{original_content}」"
        elif original_content:
            base = f"引用了一条帖子「{original_content}」"
        elif original_author:
            base = f"引用了{original_author}的一条帖子"
        else:
            base = "引用了一条帖子"
        if quote_content:
            base += f",并评论道:「{quote_content}」"
        return base

    def _describe_follow(self) -> str:
        """Describe following a user."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"关注了用户「{target_user_name}」"
        return "关注了一个用户"

    def _describe_create_comment(self) -> str:
        """Describe writing a comment, with as much post context as is known."""
        content = self._arg("content")
        if not content:
            return "发表了评论"
        post_content = self._arg("post_content")
        post_author = self._arg("post_author_name")
        # All three contextual forms start with "在" ("under ..."), matching
        # the post-content-only branch.
        if post_content and post_author:
            return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
        if post_content:
            return f"在帖子「{post_content}」下评论道:「{content}」"
        if post_author:
            return f"在{post_author}的帖子下评论道:「{content}」"
        return f"评论道:「{content}」"

    def _describe_like_comment(self) -> str:
        """Describe liking a comment."""
        return self._describe_reaction("点赞了", "评论", "comment_content", "comment_author_name")

    def _describe_dislike_comment(self) -> str:
        """Describe disliking a comment."""
        return self._describe_reaction("踩了", "评论", "comment_content", "comment_author_name")

    def _describe_search(self) -> str:
        """Describe a post search; the term is read from ``query`` or ``keyword``."""
        query = self._arg("query", "keyword")
        return f"搜索了「{query}」" if query else "进行了搜索"

    def _describe_search_user(self) -> str:
        """Describe a user search; the term is read from ``query`` or ``username``."""
        query = self._arg("query", "username")
        return f"搜索了用户「{query}」" if query else "搜索了用户"

    def _describe_mute(self) -> str:
        """Describe muting a user."""
        target_user_name = self._arg("target_user_name")
        if target_user_name:
            return f"屏蔽了用户「{target_user_name}」"
        return "屏蔽了一个用户"

    def _describe_generic(self) -> str:
        """Fallback description for action types without a dedicated sentence."""
        return f"执行了{self.action_type}操作"
class GraphMemoryUpdater:
    """Buffer agent activities and write them to a local graph in batches.

    Activities are queued by :meth:`add_activity`, grouped per platform by a
    background worker thread, and written to the store once ``BATCH_SIZE``
    activities have accumulated for a platform. :meth:`stop` writes whatever
    is left over.
    """

    BATCH_SIZE = 5
    PLATFORM_DISPLAY_NAMES = {
        'twitter': '世界1',
        'reddit': '世界2',
    }
    SEND_INTERVAL = 0.1  # local writes are fast, so the pause between batches is short
    MAX_RETRIES = 3
    RETRY_DELAY = 1

    def __init__(self, graph_id: str, storage_dir: Optional[str] = None, api_key: Optional[str] = None):
        """Create an updater for one graph.

        Args:
            graph_id: ID of the local graph to write to.
            storage_dir: Graph storage directory; defaults to
                ``Config.GRAPH_STORAGE_DIR`` when omitted.
            api_key: Deprecated and unused; kept so older call sites still work.
        """
        self.graph_id = graph_id
        storage_dir = storage_dir or Config.GRAPH_STORAGE_DIR
        self.store = LocalGraphStore(storage_dir)
        self._activity_queue: Queue = Queue()
        self._platform_buffers: Dict[str, List["AgentActivity"]] = {
            'twitter': [],
            'reddit': [],
        }
        self._buffer_lock = threading.Lock()
        self._running = False
        self._worker_thread: Optional[threading.Thread] = None
        # Counters reported by get_stats() and logged by stop().
        self._total_activities = 0
        self._total_sent = 0
        self._total_items_sent = 0
        self._failed_count = 0
        self._skipped_count = 0
        logger.info(f"GraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")

    def _get_platform_display_name(self, platform: str) -> str:
        """Return the display name for *platform*, or *platform* itself if unmapped."""
        return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)

    def start(self):
        """Start the background worker thread; a no-op if already running."""
        if self._running:
            return
        # Captured here and handed to the worker, which applies it with
        # set_locale() as its first step.
        current_locale = get_locale()
        self._running = True
        self._worker_thread = threading.Thread(
            target=self._worker_loop,
            args=(current_locale,),
            daemon=True,
            name=f"GraphMemoryUpdater-{self.graph_id[:8]}"
        )
        self._worker_thread.start()
        logger.info(f"GraphMemoryUpdater 已启动: graph_id={self.graph_id}")

    def stop(self):
        """Stop the worker thread, then write any activities still pending."""
        self._running = False
        # Join BEFORE flushing: the worker keeps draining the queue after
        # _running is cleared, so flushing first would race with it and any
        # activity the worker buffered after the flush would never be written.
        if self._worker_thread and self._worker_thread.is_alive():
            self._worker_thread.join(timeout=10)
        self._flush_remaining()
        logger.info(f"GraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
                    f"total_activities={self._total_activities}, "
                    f"batches_sent={self._total_sent}, "
                    f"items_sent={self._total_items_sent}, "
                    f"failed={self._failed_count}, "
                    f"skipped={self._skipped_count}")

    def add_activity(self, activity: "AgentActivity"):
        """Queue one activity for writing; ``DO_NOTHING`` actions are only counted as skipped."""
        if activity.action_type == "DO_NOTHING":
            self._skipped_count += 1
            return
        self._activity_queue.put(activity)
        self._total_activities += 1
        logger.debug(f"添加活动到队列: {activity.agent_name} - {activity.action_type}")

    def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
        """Build an activity from a dict record and queue it.

        Records that carry an ``event_type`` key are ignored. Missing fields
        fall back to neutral defaults; a missing timestamp becomes "now".
        """
        if "event_type" in data:
            return
        activity = AgentActivity(
            platform=platform,
            agent_id=data.get("agent_id", 0),
            agent_name=data.get("agent_name", ""),
            action_type=data.get("action_type", ""),
            # `or {}` also covers an explicit null, which would otherwise
            # break the later action_args.get(...) lookups.
            action_args=data.get("action_args") or {},
            round_num=data.get("round", 0),
            timestamp=data.get("timestamp", datetime.now().isoformat()),
        )
        self.add_activity(activity)

    def _worker_loop(self, locale: str = 'zh'):
        """Worker thread body: move queued activities into per-platform buffers
        and write a batch whenever a buffer reaches ``BATCH_SIZE``.

        Keeps running after ``_running`` is cleared until the queue is empty.
        """
        set_locale(locale)
        while self._running or not self._activity_queue.empty():
            try:
                try:
                    activity = self._activity_queue.get(timeout=1)
                except Empty:
                    continue
                platform = activity.platform.lower()
                batch = None
                with self._buffer_lock:
                    buffer = self._platform_buffers.setdefault(platform, [])
                    buffer.append(activity)
                    if len(buffer) >= self.BATCH_SIZE:
                        batch = buffer[:self.BATCH_SIZE]
                        self._platform_buffers[platform] = buffer[self.BATCH_SIZE:]
                # Write outside the lock so get_stats() is not blocked for the
                # duration of the write and its retry sleeps.
                if batch:
                    self._write_batch_activities(batch, platform)
                    time.sleep(self.SEND_INTERVAL)
            except Exception as e:
                logger.error(f"工作循环异常: {e}")
                time.sleep(1)

    def _write_batch_activities(self, activities: List["AgentActivity"], platform: str):
        """Write a batch to the graph: one combined episode plus one edge per activity.

        Retries up to ``MAX_RETRIES`` times with a linearly growing delay. A
        retry resumes where the previous attempt failed instead of starting
        over, so steps that already succeeded are not written twice. After the
        final failure the batch is counted in ``_failed_count`` and dropped.
        """
        if not activities:
            return
        combined_text = "\n".join(activity.to_episode_text() for activity in activities)
        episode_written = False
        edges_written = 0
        for attempt in range(self.MAX_RETRIES):
            try:
                if not episode_written:
                    self.store.add_episode(self.graph_id, combined_text)
                    episode_written = True
                # One searchable fact edge per activity.
                while edges_written < len(activities):
                    self._create_activity_edge(activities[edges_written])
                    edges_written += 1
                self._total_sent += 1
                self._total_items_sent += len(activities)
                display_name = self._get_platform_display_name(platform)
                logger.info(f"成功写入 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
                return
            except Exception as e:
                if attempt < self.MAX_RETRIES - 1:
                    logger.warning(f"写入活动失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
                    time.sleep(self.RETRY_DELAY * (attempt + 1))
                else:
                    logger.error(f"写入活动失败,已重试{self.MAX_RETRIES}次: {e}")
                    self._failed_count += 1

    def _create_activity_edge(self, activity: "AgentActivity"):
        """Record one activity as a fact edge on the agent's node."""
        fact = activity.to_episode_text()
        # Create or fetch the node that represents this agent.
        agent_uuid = self.store.upsert_node(
            graph_id=self.graph_id,
            name=activity.agent_name,
            labels=["Agent", "Entity"],
            summary=f"Agent {activity.agent_name}",
        )
        # Self-loop edge (source == target) carrying the fact text, so that
        # keyword search can find the activity.
        self.store.add_edge(self.graph_id, {
            "name": activity.action_type,
            "fact": fact,
            "source_node_uuid": agent_uuid,
            "target_node_uuid": agent_uuid,
            "attributes": {
                "platform": activity.platform,
                "round": activity.round_num,
                "timestamp": activity.timestamp,
            },
        })

    def _flush_remaining(self):
        """Write everything still sitting in the queue and the platform buffers."""
        while True:
            try:
                activity = self._activity_queue.get_nowait()
            except Empty:
                break
            platform = activity.platform.lower()
            with self._buffer_lock:
                self._platform_buffers.setdefault(platform, []).append(activity)
        # Take the pending buffers and reset them under the lock, then write
        # outside it so slow writes and retry sleeps do not hold the lock.
        with self._buffer_lock:
            pending = {p: b for p, b in self._platform_buffers.items() if b}
            for platform in self._platform_buffers:
                self._platform_buffers[platform] = []
        for platform, buffer in pending.items():
            display_name = self._get_platform_display_name(platform)
            logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
            self._write_batch_activities(buffer, platform)

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of counters, queue size and per-platform buffer sizes."""
        with self._buffer_lock:
            buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
        return {
            "graph_id": self.graph_id,
            "batch_size": self.BATCH_SIZE,
            "total_activities": self._total_activities,
            "batches_sent": self._total_sent,
            "items_sent": self._total_items_sent,
            "failed_count": self._failed_count,
            "skipped_count": self._skipped_count,
            "queue_size": self._activity_queue.qsize(),
            "buffer_sizes": buffer_sizes,
            "running": self._running,
        }
# Backward-compatible alias: the old Zep-prefixed name refers to the same class.
ZepGraphMemoryUpdater = GraphMemoryUpdater
class ZepGraphMemoryManager:
    """Class-level registry of graph memory updaters, one per simulation.

    All state lives on the class itself, so every caller shares the same
    registry. Mutating operations hold ``_lock``.
    """

    _updaters: Dict[str, "GraphMemoryUpdater"] = {}
    _lock = threading.Lock()
    # Set by stop_all() so that repeated calls return immediately; cleared
    # again by create_updater() so later updaters can still be stopped.
    _stop_all_done = False

    @classmethod
    def create_updater(cls, simulation_id: str, graph_id: str) -> "GraphMemoryUpdater":
        """Create, start and register an updater for *simulation_id*.

        An updater already registered under the same ID is stopped first and
        then replaced.
        """
        with cls._lock:
            existing = cls._updaters.get(simulation_id)
            if existing is not None:
                existing.stop()
            updater = GraphMemoryUpdater(graph_id)
            updater.start()
            cls._updaters[simulation_id] = updater
            # Re-arm stop_all(): without this, an updater created after the
            # first stop_all() call could never be stopped by a later one.
            cls._stop_all_done = False
            logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
            return updater

    @classmethod
    def get_updater(cls, simulation_id: str) -> Optional["GraphMemoryUpdater"]:
        """Return the updater registered for *simulation_id*, or ``None``."""
        return cls._updaters.get(simulation_id)

    @classmethod
    def stop_updater(cls, simulation_id: str):
        """Stop and unregister the updater for *simulation_id*, if there is one."""
        with cls._lock:
            if simulation_id in cls._updaters:
                cls._updaters[simulation_id].stop()
                del cls._updaters[simulation_id]
                logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")

    @classmethod
    def stop_all(cls):
        """Stop every registered updater and empty the registry.

        A failure to stop one updater is logged and does not prevent the
        others from being stopped. The done-flag is checked before taking the
        lock, so a repeated call returns without waiting on it.
        """
        if cls._stop_all_done:
            return
        cls._stop_all_done = True
        with cls._lock:
            if cls._updaters:
                for simulation_id, updater in list(cls._updaters.items()):
                    try:
                        updater.stop()
                    except Exception as e:
                        logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
                cls._updaters.clear()
            logger.info("已停止所有图谱记忆更新器")

    @classmethod
    def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
        """Return ``{simulation_id: updater.get_stats()}`` for every registered updater."""
        # Iterate over a snapshot: the registry can gain or lose entries from
        # other threads, and iterating the live dict while that happens raises
        # "dictionary changed size during iteration".
        return {
            sim_id: updater.get_stats()
            for sim_id, updater in list(cls._updaters.items())
        }