From 480621611089c2898e52bb21e995a545b887c76b Mon Sep 17 00:00:00 2001 From: Nader Elkhouri Date: Mon, 25 May 2026 22:40:08 -0300 Subject: [PATCH] feat: add local SQLite graph storage fallback - Add local graph store for running without Zep Cloud - Allow graph, simulation and memory flows to work in sqlite mode - Preserve Zep integration when API credentials are configured --- backend/app/api/graph.py | 12 +- backend/app/api/simulation.py | 6 +- backend/app/config.py | 16 +- backend/app/services/graph_builder.py | 108 ++- backend/app/services/local_graph_store.py | 869 ++++++++++++++++++ backend/app/services/zep_entity_reader.py | 44 +- .../app/services/zep_graph_memory_updater.py | 22 +- backend/app/services/zep_tools.py | 565 +++++++----- 8 files changed, 1365 insertions(+), 277 deletions(-) create mode 100644 backend/app/services/local_graph_store.py diff --git a/backend/app/api/graph.py b/backend/app/api/graph.py index 759ff48b..e699533b 100644 --- a/backend/app/api/graph.py +++ b/backend/app/api/graph.py @@ -285,7 +285,7 @@ def build_graph(): # 检查配置 errors = [] - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: errors.append(t('api.zepApiKeyMissing')) if errors: logger.error(f"配置错误: {errors}") @@ -387,7 +387,7 @@ def build_graph(): ) # 创建图谱构建服务 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = GraphBuilderService() # 分块 task_manager.update_task( @@ -572,13 +572,13 @@ def get_graph_data(graph_id: str): 获取图谱数据(节点和边) """ try: - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": t('api.zepApiKeyMissing') }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = GraphBuilderService() graph_data = builder.get_graph_data(graph_id) return jsonify({ @@ -600,13 +600,13 @@ def delete_graph(graph_id: str): 删除Zep图谱 """ try: - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": t('api.zepApiKeyMissing') }), 500 - builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) + builder = GraphBuilderService() builder.delete_graph(graph_id) return jsonify({ diff --git a/backend/app/api/simulation.py b/backend/app/api/simulation.py index 3a8e1e3f..f22b431c 100644 --- a/backend/app/api/simulation.py +++ b/backend/app/api/simulation.py @@ -57,7 +57,7 @@ def get_graph_entities(graph_id: str): enrich: 是否获取相关边信息(默认true) """ try: - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": t('api.zepApiKeyMissing') @@ -94,7 +94,7 @@ def get_graph_entities(graph_id: str): def get_entity_detail(graph_id: str, entity_uuid: str): """获取单个实体的详细信息""" try: - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": t('api.zepApiKeyMissing') @@ -127,7 +127,7 @@ def get_entity_detail(graph_id: str, entity_uuid: str): def get_entities_by_type(graph_id: str, entity_type: str): """获取指定类型的所有实体""" try: - if not Config.ZEP_API_KEY: + if Config.GRAPH_STORAGE_BACKEND != 'sqlite' and not Config.ZEP_API_KEY: return jsonify({ "success": False, "error": t('api.zepApiKeyMissing') diff --git a/backend/app/config.py b/backend/app/config.py index de63e2b4..55ddbf8f 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -32,9 +32,18 @@ class Config: LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') - # Zep配置 + # 图谱存储配置 + # 默认优先使用 SQLite 本地存储;如果显式提供 ZEP_API_KEY,也可以继续走远端 Zep。 ZEP_API_KEY = os.environ.get('ZEP_API_KEY') - + GRAPH_STORAGE_BACKEND = os.environ.get( + 'MIROFISH_GRAPH_STORAGE', + 'sqlite' if not os.environ.get('ZEP_API_KEY') else 'zep' + ).strip().lower() + LOCAL_GRAPH_DB_PATH = os.environ.get( + 'MIROFISH_GRAPH_DB_PATH', + os.path.join(os.path.dirname(__file__), '../uploads/local_graphs.sqlite3') + ) + # 文件上传配置 MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') @@ -69,7 +78,6 @@ class Config: errors: list[str] = [] if not cls.LLM_API_KEY: errors.append("LLM_API_KEY 未配置") - if not cls.ZEP_API_KEY: - errors.append("ZEP_API_KEY 未配置") + # 图谱存储现在默认可以使用 SQLite,本地模式下不再强制要求 ZEP_API_KEY。 return errors diff --git a/backend/app/services/graph_builder.py b/backend/app/services/graph_builder.py index 37c9969c..e90ae511 100644 --- a/backend/app/services/graph_builder.py +++ b/backend/app/services/graph_builder.py @@ -15,7 +15,9 @@ from zep_cloud import EpisodeData, EntityEdgeSourceTarget from ..config import Config from ..models.task import TaskManager, TaskStatus +from ..utils.llm_client import LLMClient from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .local_graph_store import LocalGraphStore from .text_processor import TextProcessor from ..utils.locale import t, get_locale, set_locale @@ -45,11 +47,23 @@ class GraphBuilderService: def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.use_local_storage = Config.GRAPH_STORAGE_BACKEND == 'sqlite' or not self.api_key self.task_manager = TaskManager() + self.local_store = LocalGraphStore() if self.use_local_storage else None + self.llm_client = None + + if self.use_local_storage: + # Local SQLite mode: no Zep client needed. + if Config.LLM_API_KEY: + try: + self.llm_client = LLMClient() + except Exception: + self.llm_client = None + self.client = None + else: + if not self.api_key: + raise ValueError("ZEP_API_KEY 未配置") + self.client = Zep(api_key=self.api_key) def build_graph_async( self, @@ -191,19 +205,31 @@ class GraphBuilderService: self.task_manager.fail_task(task_id, error_msg) def create_graph(self, name: str) -> str: - """创建Zep图谱(公开方法)""" + """创建图谱(公开方法)""" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" - + + if self.use_local_storage: + self.local_store.create_graph( + graph_id=graph_id, + name=name, + description="MiroFish Social Simulation Graph", + ) + return graph_id + self.client.graph.create( graph_id=graph_id, name=name, description="MiroFish Social Simulation Graph" ) - + return graph_id def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): """设置图谱本体(公开方法)""" + if self.use_local_storage: + self.local_store.set_ontology(graph_id, ontology) + return + import warnings from typing import Optional from pydantic import Field @@ -290,7 +316,6 @@ class GraphBuilderService: entities=entity_types if entity_types else None, edges=edge_definitions if edge_definitions else None, ) - def add_text_batches( self, graph_id: str, @@ -299,6 +324,25 @@ class GraphBuilderService: progress_callback: Optional[Callable] = None ) -> List[str]: """分批添加文本到图谱,返回所有 episode 的 uuid 列表""" + if self.use_local_storage: + total = len(chunks) + if total == 0: + return [] + if progress_callback: + progress_callback(t('progress.sendingBatch', current=1, total=1, chunks=total), 0.0) + + episode_uuids = self.local_store.extract_and_store_chunks( + graph_id=graph_id, + chunks=chunks, + ontology=(self.local_store.get_graph(graph_id) or {}).get('ontology', {}), + llm_client=self.llm_client, + progress_callback=None, + batch_size=batch_size, + ) + if progress_callback: + progress_callback(t('progress.processingComplete', completed=len(episode_uuids), total=len(episode_uuids)), 1.0) + return episode_uuids + episode_uuids = [] total_chunks = len(chunks) @@ -351,6 +395,11 @@ class GraphBuilderService: timeout: int = 600 ): """等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" + if self.use_local_storage: + if progress_callback: + progress_callback(t('progress.processingComplete', completed=len(episode_uuids), total=len(episode_uuids)), 1.0) + return + if not episode_uuids: if progress_callback: progress_callback(t('progress.noEpisodesWait'), 1.0) @@ -402,11 +451,29 @@ class GraphBuilderService: def _get_graph_info(self, graph_id: str) -> GraphInfo: """获取图谱信息""" + if self.use_local_storage: + graph = self.local_store.get_graph(graph_id) or {} + nodes = graph.get("nodes", []) + edges = graph.get("edges", []) + entity_types = set() + for node in nodes: + for label in node.get("labels", []): + if label not in ["Entity", "Node"]: + entity_types.add(label) + return GraphInfo( + graph_id=graph_id, + node_count=len(nodes), + edge_count=len(edges), + entity_types=list(entity_types), + ) + # 获取节点(分页) - nodes = fetch_all_nodes(self.client, graph_id) + client = self.client + assert client is not None + nodes = fetch_all_nodes(client, graph_id) # 获取边(分页) - edges = fetch_all_edges(self.client, graph_id) + edges = fetch_all_edges(client, graph_id) # 统计实体类型 entity_types = set() @@ -433,8 +500,22 @@ class GraphBuilderService: Returns: 包含nodes和edges的字典,包括时间信息、属性等详细数据 """ - nodes = fetch_all_nodes(self.client, graph_id) - edges = fetch_all_edges(self.client, graph_id) + if self.use_local_storage: + graph = self.local_store.get_graph(graph_id) or {} + nodes_data = graph.get("nodes", []) + edges_data = graph.get("edges", []) + return { + "graph_id": graph_id, + "nodes": nodes_data, + "edges": edges_data, + "node_count": len(nodes_data), + "edge_count": len(edges_data), + } + + client = self.client + assert client is not None + nodes = fetch_all_nodes(client, graph_id) + edges = fetch_all_edges(client, graph_id) # 创建节点映射用于获取节点名称 node_map = {} @@ -502,5 +583,8 @@ class GraphBuilderService: def delete_graph(self, graph_id: str): """删除图谱""" + if self.use_local_storage: + self.local_store.delete_graph(graph_id) + return self.client.graph.delete(graph_id=graph_id) diff --git a/backend/app/services/local_graph_store.py b/backend/app/services/local_graph_store.py new file mode 100644 index 00000000..28d12afa --- /dev/null +++ b/backend/app/services/local_graph_store.py @@ -0,0 +1,869 @@ +"""SQLite-backed local graph store used when Zep is disabled. + +This is a pragmatic replacement layer that keeps the app bootable and lets the +main graph/profile/report flows use local persistence instead of Zep Cloud. +It intentionally stores raw chunks, extracted nodes, extracted edges, and +lightweight episode metadata in SQLite. +""" + +from __future__ import annotations + +import json +import logging +import re +import sqlite3 +import threading +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple + +from ..config import Config + +logger = logging.getLogger(__name__) + + +@dataclass +class ExtractedChunk: + episode_uuid: str + nodes: List[Dict[str, Any]] + edges: List[Dict[str, Any]] + + +def _utcnow() -> str: + return datetime.now(timezone.utc).isoformat() + + +def _json_loads(value: Optional[str], default: Any) -> Any: + if not value: + return default + try: + return json.loads(value) + except Exception: + return default + + +def _json_dumps(value: Any) -> str: + return json.dumps(value, ensure_ascii=False, sort_keys=True) + + +def _stable_uuid(*parts: str) -> str: + payload = "::".join(part.strip().lower() for part in parts if part is not None) + return str(uuid.uuid5(uuid.NAMESPACE_URL, payload)) + + +class LocalGraphStore: + def __init__(self, db_path: Optional[str] = None): + self.db_path = Path(db_path or getattr(Config, "LOCAL_GRAPH_DB_PATH", "") or self._default_db_path()) + self.db_path.parent.mkdir(parents=True, exist_ok=True) + self._lock = threading.Lock() + self._ensure_schema() + + @staticmethod + def _default_db_path() -> str: + return str(Path(__file__).resolve().parents[2] / "uploads" / "local_graphs.sqlite3") + + def _connect(self) -> sqlite3.Connection: + conn = sqlite3.connect(self.db_path) + conn.row_factory = sqlite3.Row + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + return conn + + def _ensure_schema(self) -> None: + with self._lock: + conn = self._connect() + try: + conn.executescript( + """ + CREATE TABLE IF NOT EXISTS graphs ( + graph_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + description TEXT DEFAULT '', + ontology_json TEXT NOT NULL DEFAULT '{}', + source_text TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + mode TEXT NOT NULL DEFAULT 'sqlite' + ); + + CREATE TABLE IF NOT EXISTS graph_chunks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + graph_id TEXT NOT NULL, + chunk_index INTEGER NOT NULL, + content TEXT NOT NULL, + created_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS graph_episodes ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + kind TEXT NOT NULL DEFAULT 'chunk', + data TEXT NOT NULL, + processed INTEGER NOT NULL DEFAULT 1, + created_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS graph_nodes ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + name TEXT NOT NULL, + labels_json TEXT NOT NULL DEFAULT '[]', + summary TEXT NOT NULL DEFAULT '', + attributes_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE TABLE IF NOT EXISTS graph_edges ( + uuid TEXT PRIMARY KEY, + graph_id TEXT NOT NULL, + name TEXT NOT NULL DEFAULT '', + fact TEXT NOT NULL DEFAULT '', + fact_type TEXT NOT NULL DEFAULT '', + source_node_uuid TEXT NOT NULL, + target_node_uuid TEXT NOT NULL, + source_node_name TEXT NOT NULL DEFAULT '', + target_node_name TEXT NOT NULL DEFAULT '', + attributes_json TEXT NOT NULL DEFAULT '{}', + episodes_json TEXT NOT NULL DEFAULT '[]', + created_at TEXT NOT NULL, + valid_at TEXT, + invalid_at TEXT, + expired_at TEXT, + updated_at TEXT NOT NULL, + FOREIGN KEY(graph_id) REFERENCES graphs(graph_id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_graph_chunks_graph_id ON graph_chunks(graph_id); + CREATE INDEX IF NOT EXISTS idx_graph_episodes_graph_id ON graph_episodes(graph_id); + CREATE INDEX IF NOT EXISTS idx_graph_nodes_graph_id ON graph_nodes(graph_id); + CREATE INDEX IF NOT EXISTS idx_graph_edges_graph_id ON graph_edges(graph_id); + CREATE INDEX IF NOT EXISTS idx_graph_edges_source ON graph_edges(source_node_uuid); + CREATE INDEX IF NOT EXISTS idx_graph_edges_target ON graph_edges(target_node_uuid); + """ + ) + conn.commit() + finally: + conn.close() + + def create_graph(self, graph_id: str, name: str, description: str = "") -> str: + now = _utcnow() + with self._lock: + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO graphs (graph_id, name, description, ontology_json, source_text, created_at, updated_at, mode) + VALUES (?, ?, ?, ?, ?, ?, ?, 'sqlite') + ON CONFLICT(graph_id) DO UPDATE SET + name=excluded.name, + description=excluded.description, + updated_at=excluded.updated_at + """, + (graph_id, name, description, "{}", "", now, now), + ) + conn.commit() + return graph_id + finally: + conn.close() + + def set_ontology(self, graph_id: str, ontology: Dict[str, Any]) -> None: + now = _utcnow() + with self._lock: + conn = self._connect() + try: + conn.execute( + "UPDATE graphs SET ontology_json = ?, updated_at = ? WHERE graph_id = ?", + (_json_dumps(ontology or {}), now, graph_id), + ) + conn.commit() + finally: + conn.close() + + def append_source_text(self, graph_id: str, source_text: str) -> None: + now = _utcnow() + with self._lock: + conn = self._connect() + try: + conn.execute( + "UPDATE graphs SET source_text = COALESCE(source_text, '') || ?, updated_at = ? WHERE graph_id = ?", + (source_text or "", now, graph_id), + ) + conn.commit() + finally: + conn.close() + + def store_episode(self, graph_id: str, chunk_index: int, content: str, kind: str = "chunk") -> str: + episode_uuid = str(uuid.uuid4()) + now = _utcnow() + with self._lock: + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO graph_episodes (uuid, graph_id, kind, data, processed, created_at) + VALUES (?, ?, ?, ?, 1, ?) + """, + (episode_uuid, graph_id, kind, content, now), + ) + conn.execute( + """ + INSERT INTO graph_chunks (graph_id, chunk_index, content, created_at) + VALUES (?, ?, ?, ?) + """, + (graph_id, chunk_index, content, now), + ) + conn.commit() + return episode_uuid + finally: + conn.close() + + def upsert_node( + self, + graph_id: str, + name: str, + labels: Sequence[str], + summary: str = "", + attributes: Optional[Dict[str, Any]] = None, + node_uuid: Optional[str] = None, + ) -> str: + node_uuid = node_uuid or _stable_uuid(graph_id, "node", name, ",".join(labels or [])) + now = _utcnow() + payload = ( + node_uuid, + graph_id, + name.strip(), + _json_dumps(list(labels or [])), + summary or "", + _json_dumps(attributes or {}), + now, + now, + ) + with self._lock: + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO graph_nodes (uuid, graph_id, name, labels_json, summary, attributes_json, created_at, updated_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(uuid) DO UPDATE SET + name=excluded.name, + labels_json=excluded.labels_json, + summary=excluded.summary, + attributes_json=excluded.attributes_json, + updated_at=excluded.updated_at + """, + payload, + ) + conn.commit() + return node_uuid + finally: + conn.close() + + def upsert_edge( + self, + graph_id: str, + name: str, + fact: str, + source_node_uuid: str, + target_node_uuid: str, + source_node_name: str, + target_node_name: str, + attributes: Optional[Dict[str, Any]] = None, + episodes: Optional[Sequence[str]] = None, + fact_type: Optional[str] = None, + edge_uuid: Optional[str] = None, + created_at: Optional[str] = None, + valid_at: Optional[str] = None, + invalid_at: Optional[str] = None, + expired_at: Optional[str] = None, + ) -> str: + edge_uuid = edge_uuid or _stable_uuid( + graph_id, + "edge", + name, + source_node_uuid, + target_node_uuid, + fact[:160], + ) + now = _utcnow() + with self._lock: + conn = self._connect() + try: + existing = conn.execute( + "SELECT episodes_json, created_at FROM graph_edges WHERE uuid = ?", + (edge_uuid,), + ).fetchone() + merged_episodes = list(episodes or []) + if existing: + prev_episodes = _json_loads(existing["episodes_json"], []) + if isinstance(prev_episodes, list): + merged_episodes = list(dict.fromkeys([*prev_episodes, *merged_episodes])) + if not created_at: + created_at = existing["created_at"] + conn.execute( + """ + INSERT INTO graph_edges ( + uuid, graph_id, name, fact, fact_type, source_node_uuid, target_node_uuid, + source_node_name, target_node_name, attributes_json, episodes_json, + created_at, valid_at, invalid_at, expired_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + ON CONFLICT(uuid) DO UPDATE SET + name=excluded.name, + fact=excluded.fact, + fact_type=excluded.fact_type, + source_node_uuid=excluded.source_node_uuid, + target_node_uuid=excluded.target_node_uuid, + source_node_name=excluded.source_node_name, + target_node_name=excluded.target_node_name, + attributes_json=excluded.attributes_json, + episodes_json=excluded.episodes_json, + valid_at=excluded.valid_at, + invalid_at=excluded.invalid_at, + expired_at=excluded.expired_at, + updated_at=excluded.updated_at + """, + ( + edge_uuid, + graph_id, + name or "", + fact or "", + fact_type or name or "", + source_node_uuid, + target_node_uuid, + source_node_name or "", + target_node_name or "", + _json_dumps(attributes or {}), + _json_dumps(merged_episodes), + created_at or now, + valid_at, + invalid_at, + expired_at, + now, + ), + ) + conn.commit() + return edge_uuid + finally: + conn.close() + + def _load_graph(self, conn: sqlite3.Connection, graph_id: str) -> Optional[sqlite3.Row]: + return conn.execute("SELECT * FROM graphs WHERE graph_id = ?", (graph_id,)).fetchone() + + def get_graph(self, graph_id: str) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + graph = self._load_graph(conn, graph_id) + if not graph: + return None + nodes = self.list_nodes(graph_id, conn=conn) + edges = self.list_edges(graph_id, conn=conn) + chunks = conn.execute( + "SELECT chunk_index, content, created_at FROM graph_chunks WHERE graph_id = ? ORDER BY chunk_index, id", + (graph_id,), + ).fetchall() + return { + "graph_id": graph_id, + "name": graph["name"], + "description": graph["description"], + "ontology": _json_loads(graph["ontology_json"], {}), + "source_text": graph["source_text"], + "created_at": graph["created_at"], + "updated_at": graph["updated_at"], + "nodes": nodes, + "edges": edges, + "chunks": [dict(row) for row in chunks], + } + finally: + conn.close() + + def list_graphs(self) -> List[Dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + rows = conn.execute( + "SELECT graph_id, name, description, ontology_json, created_at, updated_at FROM graphs ORDER BY updated_at DESC" + ).fetchall() + return [ + { + "graph_id": row["graph_id"], + "name": row["name"], + "description": row["description"], + "ontology": _json_loads(row["ontology_json"], {}), + "created_at": row["created_at"], + "updated_at": row["updated_at"], + } + for row in rows + ] + finally: + conn.close() + + def list_nodes(self, graph_id: str, conn: Optional[sqlite3.Connection] = None) -> List[Dict[str, Any]]: + own_conn = None + if conn is None: + own_conn = self._connect() + conn = own_conn + try: + rows = conn.execute( + "SELECT * FROM graph_nodes WHERE graph_id = ? ORDER BY updated_at DESC, created_at DESC, name", + (graph_id,), + ).fetchall() + return [self._row_to_node(row) for row in rows] + finally: + if own_conn is not None: + own_conn.close() + + def list_edges(self, graph_id: str, conn: Optional[sqlite3.Connection] = None) -> List[Dict[str, Any]]: + own_conn = None + if conn is None: + own_conn = self._connect() + conn = own_conn + try: + rows = conn.execute( + "SELECT * FROM graph_edges WHERE graph_id = ? ORDER BY updated_at DESC, created_at DESC", + (graph_id,), + ).fetchall() + return [self._row_to_edge(row) for row in rows] + finally: + if own_conn is not None: + own_conn.close() + + def get_node(self, graph_id: str, node_uuid: str) -> Optional[Dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + row = conn.execute( + "SELECT * FROM graph_nodes WHERE graph_id = ? AND uuid = ?", + (graph_id, node_uuid), + ).fetchone() + return self._row_to_node(row) if row else None + finally: + conn.close() + + def get_edges_for_node(self, graph_id: str, node_uuid: str) -> List[Dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + rows = conn.execute( + """ + SELECT * FROM graph_edges + WHERE graph_id = ? AND (source_node_uuid = ? OR target_node_uuid = ?) + ORDER BY updated_at DESC, created_at DESC + """, + (graph_id, node_uuid, node_uuid), + ).fetchall() + return [self._row_to_edge(row) for row in rows] + finally: + conn.close() + + def get_edges_for_node_any(self, node_uuid: str) -> List[Dict[str, Any]]: + with self._lock: + conn = self._connect() + try: + rows = conn.execute( + """ + SELECT * FROM graph_edges + WHERE source_node_uuid = ? OR target_node_uuid = ? + ORDER BY updated_at DESC, created_at DESC + """, + (node_uuid, node_uuid), + ).fetchall() + return [self._row_to_edge(row) for row in rows] + finally: + conn.close() + + def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: + nodes = self.list_nodes(graph_id) + edges = self.list_edges(graph_id) + entity_types: Dict[str, int] = {} + relation_types: Dict[str, int] = {} + for node in nodes: + for label in node.get("labels", []): + if label not in ["Entity", "Node"]: + entity_types[label] = entity_types.get(label, 0) + 1 + for edge in edges: + relation_types[edge.get("name", "")] = relation_types.get(edge.get("name", ""), 0) + 1 + return { + "graph_id": graph_id, + "total_nodes": len(nodes), + "total_edges": len(edges), + "entity_types": entity_types, + "relation_types": relation_types, + } + + def delete_graph(self, graph_id: str) -> None: + with self._lock: + conn = self._connect() + try: + conn.execute("DELETE FROM graphs WHERE graph_id = ?", (graph_id,)) + conn.commit() + finally: + conn.close() + + def append_activity(self, graph_id: str, payload: Dict[str, Any]) -> str: + episode_uuid = str(uuid.uuid4()) + now = _utcnow() + with self._lock: + conn = self._connect() + try: + conn.execute( + """ + INSERT INTO graph_episodes (uuid, graph_id, kind, data, processed, created_at) + VALUES (?, ?, ?, ?, 1, ?) + """, + (episode_uuid, graph_id, "activity", _json_dumps(payload), now), + ) + conn.commit() + return episode_uuid + finally: + conn.close() + + def search(self, graph_id: str, query: str, scope: str = "edges", limit: int = 10) -> Dict[str, Any]: + query_lower = (query or "").lower() + keywords = [w.strip() for w in re.split(r"[\s,,]+", query_lower) if len(w.strip()) > 1] + + def score_text(text: str) -> int: + if not text: + return 0 + text_lower = text.lower() + score = 0 + if query_lower and query_lower in text_lower: + score += 100 + for kw in keywords: + if kw in text_lower: + score += 10 + return score + + nodes = self.list_nodes(graph_id) if scope in ["nodes", "both"] else [] + edges = self.list_edges(graph_id) if scope in ["edges", "both"] else [] + facts: List[str] = [] + scored_nodes: List[Tuple[int, Dict[str, Any]]] = [] + scored_edges: List[Tuple[int, Dict[str, Any]]] = [] + + for node in nodes: + text = " ".join([ + node.get("name", ""), + node.get("summary", ""), + _json_dumps(node.get("attributes", {})), + ]) + score = score_text(text) + if score > 0: + scored_nodes.append((score, node)) + + for edge in edges: + text = " ".join([ + edge.get("name", ""), + edge.get("fact", ""), + edge.get("source_node_name", ""), + edge.get("target_node_name", ""), + _json_dumps(edge.get("attributes", {})), + ]) + score = score_text(text) + if score > 0: + scored_edges.append((score, edge)) + + scored_nodes.sort(key=lambda item: item[0], reverse=True) + scored_edges.sort(key=lambda item: item[0], reverse=True) + + node_results = [node for _, node in scored_nodes[:limit]] + edge_results = [edge for _, edge in scored_edges[:limit]] + + for node in node_results: + if node.get("summary"): + facts.append(f"[{node['name']}]: {node['summary']}") + for edge in edge_results: + if edge.get("fact"): + facts.append(edge["fact"]) + + return { + "facts": facts[:limit], + "nodes": node_results, + "edges": edge_results, + "total_count": len(facts), + } + + def extract_and_store_chunks( + self, + graph_id: str, + chunks: Sequence[str], + ontology: Optional[Dict[str, Any]] = None, + llm_client: Optional[Any] = None, + progress_callback: Optional[Any] = None, + batch_size: int = 1, + ) -> List[str]: + episode_uuids: List[str] = [] + ontology = ontology or {} + + existing_nodes = {node["name"].strip().lower(): node for node in self.list_nodes(graph_id)} + total = len(chunks) + for idx, chunk in enumerate(chunks): + if progress_callback: + progress_callback(idx + 1, total) + episode_uuid = self.store_episode(graph_id, idx, chunk, kind="text") + episode_uuids.append(episode_uuid) + extracted = self._extract_chunk(chunk, ontology, llm_client) + node_map = self._store_extracted_chunk(graph_id, extracted, episode_uuid, existing_nodes) + existing_nodes.update(node_map) + return episode_uuids + + def _extract_chunk(self, chunk: str, ontology: Dict[str, Any], llm_client: Optional[Any]) -> ExtractedChunk: + if llm_client is not None: + try: + payload = self._extract_with_llm(chunk, ontology, llm_client) + return ExtractedChunk( + episode_uuid="", + nodes=payload.get("nodes", []), + edges=payload.get("edges", []), + ) + except Exception as exc: + logger.warning("Local graph extraction via LLM failed, falling back to heuristics: %s", exc) + nodes = self._heuristic_extract_nodes(chunk, ontology) + edges = self._heuristic_extract_edges(chunk, nodes, ontology) + return ExtractedChunk(episode_uuid="", nodes=nodes, edges=edges) + + def _extract_with_llm(self, chunk: str, ontology: Dict[str, Any], llm_client: Any) -> Dict[str, Any]: + entity_types = [e.get("name", "Entity") for e in ontology.get("entity_types", []) if isinstance(e, dict)] + edge_types = [e.get("name", "RELATED_TO") for e in ontology.get("edge_types", []) if isinstance(e, dict)] + system = ( + "You extract graph nodes and edges from source text for a social-simulation knowledge graph. " + "Return only JSON with keys nodes and edges. " + "Each node must have name, labels, summary, attributes. " + "Each edge must have name, fact, source_node_name, target_node_name, attributes." + ) + if entity_types: + system += " Allowed entity types: " + ", ".join(entity_types) + if edge_types: + system += " Allowed relation types: " + ", ".join(edge_types) + user = ( + f"Ontology:\n{_json_dumps(ontology)}\n\n" + f"Text chunk:\n{chunk}\n\n" + "Rules:\n" + "- Prefer real-world entities that can act on social media\n" + "- If the text mentions organizations or people, use those as nodes\n" + "- Use labels like ['Entity', '']\n" + "- Facts should be concise natural language statements\n" + "- If you cannot infer a relationship, return an empty edges list\n" + ) + result = llm_client.chat_json( + messages=[ + {"role": "system", "content": system}, + {"role": "user", "content": user}, + ], + temperature=0.1, + max_tokens=2500, + ) + if not isinstance(result, dict): + raise ValueError("Invalid extraction payload") + return result + + def _heuristic_extract_nodes(self, chunk: str, ontology: Dict[str, Any]) -> List[Dict[str, Any]]: + entity_types = [e.get("name", "Entity") for e in ontology.get("entity_types", []) if isinstance(e, dict)] + fallback_type = entity_types[0] if entity_types else "Entity" + candidates = self._guess_entities(chunk) + nodes: List[Dict[str, Any]] = [] + for candidate in candidates[:8]: + nodes.append( + { + "name": candidate, + "labels": ["Entity", fallback_type], + "summary": candidate, + "attributes": {}, + } + ) + return nodes + + def _heuristic_extract_edges( + self, + chunk: str, + nodes: Sequence[Dict[str, Any]], + ontology: Dict[str, Any], + ) -> List[Dict[str, Any]]: + relation_types = [e.get("name", "RELATED_TO") for e in ontology.get("edge_types", []) if isinstance(e, dict)] + relation_name = relation_types[0] if relation_types else "RELATED_TO" + node_names = [str(node.get("name", "")).strip() for node in nodes if str(node.get("name", "")).strip()] + if len(node_names) < 2: + return [] + + # Split on sentence boundaries first; fall back to whole chunk if needed. + sentences = [part.strip() for part in re.split(r"[\.!?。!?\n]+", chunk) if part.strip()] + if not sentences: + sentences = [chunk.strip()] + + edges: List[Dict[str, Any]] = [] + seen_pairs = set() + + for sentence in sentences: + matched: List[str] = [] + lower_sentence = sentence.lower() + for name in node_names: + needle = name.lower() + if needle and needle in lower_sentence and name not in matched: + matched.append(name) + if len(matched) < 2: + continue + for source_name, target_name in zip(matched, matched[1:]): + pair_key = (source_name.lower(), target_name.lower(), sentence.lower()) + if pair_key in seen_pairs: + continue + seen_pairs.add(pair_key) + edges.append( + { + "name": relation_name, + "fact": sentence.strip() or f"{source_name} {relation_name} {target_name}", + "source_node_name": source_name, + "target_node_name": target_name, + "attributes": {}, + "fact_type": relation_name, + } + ) + + if not edges: + for source_name, target_name in zip(node_names, node_names[1:]): + pair_key = (source_name.lower(), target_name.lower(), "fallback") + if pair_key in seen_pairs: + continue + seen_pairs.add(pair_key) + edges.append( + { + "name": relation_name, + "fact": f"{source_name} {relation_name} {target_name}", + "source_node_name": source_name, + "target_node_name": target_name, + "attributes": {}, + "fact_type": relation_name, + } + ) + + return edges + + def _guess_entities(self, text: str) -> List[str]: + # Simple fallback: quoted names, title-case words, and long nouns. + candidates: List[str] = [] + for match in re.findall(r"[A-Z][A-Za-z0-9&_-]*(?:\s+[A-Z][A-Za-z0-9&_-]*){0,3}", text): + match = match.strip().strip("'\"“”‘’,,。.!?;;::") + if len(match) > 2: + candidates.append(match) + for match in re.findall(r"[\u4e00-\u9fff]{2,10}", text): + match = match.strip() + if len(match) > 2: + candidates.append(match) + seen = set() + result = [] + for item in candidates: + key = item.lower() + if key not in seen: + seen.add(key) + result.append(item) + return result + + def _store_extracted_chunk( + self, + graph_id: str, + extracted: ExtractedChunk, + episode_uuid: str, + existing_nodes: Dict[str, Dict[str, Any]], + ) -> Dict[str, Dict[str, Any]]: + ontology = self.get_graph(graph_id) or {} + new_nodes: Dict[str, Dict[str, Any]] = {} + node_uuid_by_name: Dict[str, str] = {} + + # First, upsert nodes. + for node in extracted.nodes: + if not isinstance(node, dict): + continue + name = str(node.get("name", "")).strip() + if not name: + continue + labels = node.get("labels") or ["Entity", "Entity"] + if isinstance(labels, str): + labels = [labels] + summary = str(node.get("summary", "")) + attributes = node.get("attributes") or {} + entity_type = next((label for label in labels if label not in ["Entity", "Node"]), "Entity") + node_uuid = node.get("uuid") or _stable_uuid(graph_id, "node", entity_type, name) + self.upsert_node(graph_id, name, labels, summary=summary, attributes=attributes, node_uuid=node_uuid) + node_uuid_by_name[name.lower()] = node_uuid + node_uuid_by_name.setdefault(name.strip().lower(), node_uuid) + new_nodes[name.lower()] = { + "uuid": node_uuid, + "name": name, + "labels": list(labels), + "summary": summary, + "attributes": attributes, + } + + # Resolve existing names too. + for name, node in existing_nodes.items(): + node_uuid_by_name.setdefault(name, node.get("uuid", "")) + + # Now upsert edges. + for edge in extracted.edges: + if not isinstance(edge, dict): + continue + name = str(edge.get("name", "RELATED_TO")).strip() or "RELATED_TO" + fact = str(edge.get("fact", "")).strip() + source_name = str(edge.get("source_node_name", "")).strip() + target_name = str(edge.get("target_node_name", "")).strip() + if not source_name or not target_name: + continue + source_uuid = node_uuid_by_name.get(source_name.lower()) or existing_nodes.get(source_name.lower(), {}).get("uuid") + target_uuid = node_uuid_by_name.get(target_name.lower()) or existing_nodes.get(target_name.lower(), {}).get("uuid") + if not source_uuid: + source_uuid = self.upsert_node(graph_id, source_name, ["Entity", "Entity"], summary=source_name) + node_uuid_by_name[source_name.lower()] = source_uuid + if not target_uuid: + target_uuid = self.upsert_node(graph_id, target_name, ["Entity", "Entity"], summary=target_name) + node_uuid_by_name[target_name.lower()] = target_uuid + self.upsert_edge( + graph_id=graph_id, + name=name, + fact=fact or f"{source_name} {name} {target_name}", + source_node_uuid=source_uuid, + target_node_uuid=target_uuid, + source_node_name=source_name, + target_node_name=target_name, + attributes=edge.get("attributes") or {}, + episodes=[episode_uuid], + fact_type=str(edge.get("fact_type", name) or name), + edge_uuid=edge.get("uuid"), + ) + + return new_nodes + + def _row_to_node(self, row: Optional[sqlite3.Row]) -> Dict[str, Any]: + if row is None: + return {} + return { + "uuid": row["uuid"], + "name": row["name"], + "labels": _json_loads(row["labels_json"], []), + "summary": row["summary"], + "attributes": _json_loads(row["attributes_json"], {}), + "created_at": row["created_at"], + } + + def _row_to_edge(self, row: Optional[sqlite3.Row]) -> Dict[str, Any]: + if row is None: + return {} + return { + "uuid": row["uuid"], + "name": row["name"], + "fact": row["fact"], + "fact_type": row["fact_type"], + "source_node_uuid": row["source_node_uuid"], + "target_node_uuid": row["target_node_uuid"], + "source_node_name": row["source_node_name"], + "target_node_name": row["target_node_name"], + "attributes": _json_loads(row["attributes_json"], {}), + "created_at": row["created_at"], + "valid_at": row["valid_at"], + "invalid_at": row["invalid_at"], + "expired_at": row["expired_at"], + "episodes": _json_loads(row["episodes_json"], []), + } diff --git a/backend/app/services/zep_entity_reader.py b/backend/app/services/zep_entity_reader.py index 71661be4..cd781fa9 100644 --- a/backend/app/services/zep_entity_reader.py +++ b/backend/app/services/zep_entity_reader.py @@ -12,6 +12,7 @@ from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .local_graph_store import LocalGraphStore logger = get_logger('mirofish.zep_entity_reader') @@ -80,6 +81,13 @@ class ZepEntityReader: def __init__(self, api_key: Optional[str] = None): self.api_key = api_key or Config.ZEP_API_KEY + self.use_local_storage = Config.GRAPH_STORAGE_BACKEND == 'sqlite' or not self.api_key + self.local_store = LocalGraphStore() if self.use_local_storage else None + + if self.use_local_storage: + self.client = None + return + if not self.api_key: raise ValueError("ZEP_API_KEY 未配置") @@ -136,6 +144,11 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有节点...") + if self.use_local_storage: + nodes_data = self.local_store.list_nodes(graph_id) + logger.info(f"共获取 {len(nodes_data)} 个节点") + return nodes_data + nodes = fetch_all_nodes(self.client, graph_id) nodes_data = [] @@ -163,6 +176,11 @@ class ZepEntityReader: """ logger.info(f"获取图谱 {graph_id} 的所有边...") + if self.use_local_storage: + edges_data = self.local_store.list_edges(graph_id) + logger.info(f"共获取 {len(edges_data)} 条边") + return edges_data + edges = fetch_all_edges(self.client, graph_id) edges_data = [] @@ -190,6 +208,9 @@ class ZepEntityReader: 边列表 """ try: + if self.use_local_storage: + return self.local_store.get_edges_for_node_any(node_uuid) + # 使用重试机制调用Zep API edges = self._call_with_retry( func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), @@ -346,11 +367,13 @@ class ZepEntityReader: EntityNode或None """ try: - # 使用重试机制获取节点 - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=entity_uuid), - operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" - ) + if self.use_local_storage: + node = self.local_store.get_node(graph_id, entity_uuid) + else: + node = self._call_with_retry( + func=lambda: self.client.graph.node.get(uuid_=entity_uuid), + operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" + ) if not node: return None @@ -396,6 +419,17 @@ class ZepEntityReader: "summary": related_node.get("summary", ""), }) + if self.use_local_storage: + return EntityNode( + uuid=node["uuid"], + name=node["name"], + labels=node["labels"], + summary=node["summary"], + attributes=node["attributes"], + related_edges=related_edges, + related_nodes=related_nodes, + ) + return EntityNode( uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), name=node.name or "", diff --git a/backend/app/services/zep_graph_memory_updater.py b/backend/app/services/zep_graph_memory_updater.py index e034fee2..d16d2984 100644 --- a/backend/app/services/zep_graph_memory_updater.py +++ b/backend/app/services/zep_graph_memory_updater.py @@ -17,6 +17,7 @@ from zep_cloud.client import Zep from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_locale, set_locale +from .local_graph_store import LocalGraphStore logger = get_logger('mirofish.zep_graph_memory_updater') @@ -239,11 +240,13 @@ class ZepGraphMemoryUpdater: """ self.graph_id = graph_id self.api_key = api_key or Config.ZEP_API_KEY + self.use_local_storage = Config.GRAPH_STORAGE_BACKEND == 'sqlite' or not self.api_key + self.local_store = LocalGraphStore() if self.use_local_storage else None - if not self.api_key: + if not self.use_local_storage and not self.api_key: raise ValueError("ZEP_API_KEY未配置") - self.client = Zep(api_key=self.api_key) + self.client = None if self.use_local_storage else Zep(api_key=self.api_key) # 活动队列 self._activity_queue: Queue = Queue() @@ -403,6 +406,21 @@ class ZepGraphMemoryUpdater: """ if not activities: return + + if self.use_local_storage: + store = self.local_store + assert store is not None + combined_text = "\n".join(activity.to_episode_text() for activity in activities) + store.append_activity(self.graph_id, { + "platform": platform, + "activities": [activity.__dict__ for activity in activities], + "combined_text": combined_text, + }) + self._total_sent += 1 + self._total_items_sent += len(activities) + display_name = self._get_platform_display_name(platform) + logger.info(f"本地模式已记录 {len(activities)} 条{display_name}活动到 SQLite 图谱 {self.graph_id}") + return # 将多条活动合并为一条文本,用换行分隔 episode_texts = [activity.to_episode_text() for activity in activities] diff --git a/backend/app/services/zep_tools.py b/backend/app/services/zep_tools.py index 3bc8a57a..83f60522 100644 --- a/backend/app/services/zep_tools.py +++ b/backend/app/services/zep_tools.py @@ -20,6 +20,7 @@ from ..utils.logger import get_logger from ..utils.llm_client import LLMClient from ..utils.locale import get_locale, t from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges +from .local_graph_store import LocalGraphStore logger = get_logger('mirofish.zep_tools') @@ -32,7 +33,7 @@ class SearchResult: nodes: List[Dict[str, Any]] query: str total_count: int - + def to_dict(self) -> Dict[str, Any]: return { "facts": self.facts, @@ -41,16 +42,16 @@ class SearchResult: "query": self.query, "total_count": self.total_count } - + def to_text(self) -> str: """转换为文本格式,供LLM理解""" text_parts = [f"搜索查询: {self.query}", f"找到 {self.total_count} 条相关信息"] - + if self.facts: text_parts.append("\n### 相关事实:") for i, fact in enumerate(self.facts, 1): text_parts.append(f"{i}. {fact}") - + return "\n".join(text_parts) @@ -62,7 +63,7 @@ class NodeInfo: labels: List[str] summary: str attributes: Dict[str, Any] - + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -71,7 +72,7 @@ class NodeInfo: "summary": self.summary, "attributes": self.attributes } - + def to_text(self) -> str: """转换为文本格式""" entity_type = next((l for l in self.labels if l not in ["Entity", "Node"]), "未知类型") @@ -93,7 +94,7 @@ class EdgeInfo: valid_at: Optional[str] = None invalid_at: Optional[str] = None expired_at: Optional[str] = None - + def to_dict(self) -> Dict[str, Any]: return { "uuid": self.uuid, @@ -108,27 +109,27 @@ class EdgeInfo: "invalid_at": self.invalid_at, "expired_at": self.expired_at } - + def to_text(self, include_temporal: bool = False) -> str: """转换为文本格式""" source = self.source_node_name or self.source_node_uuid[:8] target = self.target_node_name or self.target_node_uuid[:8] base_text = f"关系: {source} --[{self.name}]--> {target}\n事实: {self.fact}" - + if include_temporal: valid_at = self.valid_at or "未知" invalid_at = self.invalid_at or "至今" base_text += f"\n时效: {valid_at} - {invalid_at}" if self.expired_at: base_text += f" (已过期: {self.expired_at})" - + return base_text - + @property def is_expired(self) -> bool: """是否已过期""" return self.expired_at is not None - + @property def is_invalid(self) -> bool: """是否已失效""" @@ -144,17 +145,17 @@ class InsightForgeResult: query: str simulation_requirement: str sub_queries: List[str] - + # 各维度检索结果 semantic_facts: List[str] = field(default_factory=list) # 语义搜索结果 entity_insights: List[Dict[str, Any]] = field(default_factory=list) # 实体洞察 relationship_chains: List[str] = field(default_factory=list) # 关系链 - + # 统计信息 total_facts: int = 0 total_entities: int = 0 total_relationships: int = 0 - + def to_dict(self) -> Dict[str, Any]: return { "query": self.query, @@ -167,7 +168,7 @@ class InsightForgeResult: "total_entities": self.total_entities, "total_relationships": self.total_relationships } - + def to_text(self) -> str: """转换为详细的文本格式,供LLM理解""" text_parts = [ @@ -179,19 +180,19 @@ class InsightForgeResult: f"- 涉及实体: {self.total_entities}个", f"- 关系链: {self.total_relationships}条" ] - + # 子问题 if self.sub_queries: text_parts.append(f"\n### 分析的子问题") for i, sq in enumerate(self.sub_queries, 1): text_parts.append(f"{i}. {sq}") - + # 语义搜索结果 if self.semantic_facts: text_parts.append(f"\n### 【关键事实】(请在报告中引用这些原文)") for i, fact in enumerate(self.semantic_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - + # 实体洞察 if self.entity_insights: text_parts.append(f"\n### 【核心实体】") @@ -201,13 +202,13 @@ class InsightForgeResult: text_parts.append(f" 摘要: \"{entity.get('summary')}\"") if entity.get('related_facts'): text_parts.append(f" 相关事实: {len(entity.get('related_facts', []))}条") - + # 关系链 if self.relationship_chains: text_parts.append(f"\n### 【关系链】") for chain in self.relationship_chains: text_parts.append(f"- {chain}") - + return "\n".join(text_parts) @@ -218,7 +219,7 @@ class PanoramaResult: 包含所有相关信息,包括过期内容 """ query: str - + # 全部节点 all_nodes: List[NodeInfo] = field(default_factory=list) # 全部边(包括过期的) @@ -227,13 +228,13 @@ class PanoramaResult: active_facts: List[str] = field(default_factory=list) # 已过期/失效的事实(历史记录) historical_facts: List[str] = field(default_factory=list) - + # 统计 total_nodes: int = 0 total_edges: int = 0 active_count: int = 0 historical_count: int = 0 - + def to_dict(self) -> Dict[str, Any]: return { "query": self.query, @@ -246,7 +247,7 @@ class PanoramaResult: "active_count": self.active_count, "historical_count": self.historical_count } - + def to_text(self) -> str: """转换为文本格式(完整版本,不截断)""" text_parts = [ @@ -258,26 +259,26 @@ class PanoramaResult: f"- 当前有效事实: {self.active_count}条", f"- 历史/过期事实: {self.historical_count}条" ] - + # 当前有效的事实(完整输出,不截断) if self.active_facts: text_parts.append(f"\n### 【当前有效事实】(模拟结果原文)") for i, fact in enumerate(self.active_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - + # 历史/过期事实(完整输出,不截断) if self.historical_facts: text_parts.append(f"\n### 【历史/过期事实】(演变过程记录)") for i, fact in enumerate(self.historical_facts, 1): text_parts.append(f"{i}. \"{fact}\"") - + # 关键实体(完整输出,不截断) if self.all_nodes: text_parts.append(f"\n### 【涉及实体】") for node in self.all_nodes: entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") text_parts.append(f"- **{node.name}** ({entity_type})") - + return "\n".join(text_parts) @@ -290,7 +291,7 @@ class AgentInterview: question: str # 采访问题 response: str # 采访回答 key_quotes: List[str] = field(default_factory=list) # 关键引言 - + def to_dict(self) -> Dict[str, Any]: return { "agent_name": self.agent_name, @@ -300,7 +301,7 @@ class AgentInterview: "response": self.response, "key_quotes": self.key_quotes } - + def to_text(self) -> str: text = f"**{self.agent_name}** ({self.agent_role})\n" # 显示完整的agent_bio,不截断 @@ -345,21 +346,21 @@ class InterviewResult: """ interview_topic: str # 采访主题 interview_questions: List[str] # 采访问题列表 - + # 采访选择的Agent selected_agents: List[Dict[str, Any]] = field(default_factory=list) # 各Agent的采访回答 interviews: List[AgentInterview] = field(default_factory=list) - + # 选择Agent的理由 selection_reasoning: str = "" # 整合后的采访摘要 summary: str = "" - + # 统计 total_agents: int = 0 interviewed_count: int = 0 - + def to_dict(self) -> Dict[str, Any]: return { "interview_topic": self.interview_topic, @@ -371,7 +372,7 @@ class InterviewResult: "total_agents": self.total_agents, "interviewed_count": self.interviewed_count } - + def to_text(self) -> str: """转换为详细的文本格式,供LLM理解和报告引用""" text_parts = [ @@ -401,13 +402,13 @@ class InterviewResult: class ZepToolsService: """ Zep检索工具服务 - + 【核心检索工具 - 优化后】 1. insight_forge - 深度洞察检索(最强大,自动生成子问题,多维度检索) 2. panorama_search - 广度搜索(获取全貌,包括过期内容) 3. quick_search - 简单搜索(快速检索) 4. interview_agents - 深度采访(采访模拟Agent,获取多视角观点) - + 【基础工具】 - search_graph - 图谱语义搜索 - get_all_nodes - 获取图谱所有节点 @@ -417,34 +418,41 @@ class ZepToolsService: - get_entities_by_type - 按类型获取实体 - get_entity_summary - 获取实体的关系摘要 """ - + # 重试配置 MAX_RETRIES = 3 RETRY_DELAY = 2.0 - + def __init__(self, api_key: Optional[str] = None, llm_client: Optional[LLMClient] = None): self.api_key = api_key or Config.ZEP_API_KEY - if not self.api_key: - raise ValueError("ZEP_API_KEY 未配置") - - self.client = Zep(api_key=self.api_key) + self.use_local_storage = Config.GRAPH_STORAGE_BACKEND == 'sqlite' or not self.api_key + self.local_store = LocalGraphStore() if self.use_local_storage else None + + if self.use_local_storage: + self.client = None + else: + if not self.api_key: + raise ValueError("ZEP_API_KEY 未配置") + + self.client = Zep(api_key=self.api_key) + # LLM客户端用于InsightForge生成子问题 self._llm_client = llm_client logger.info(t("console.zepToolsInitialized")) - + @property def llm(self) -> LLMClient: """延迟初始化LLM客户端""" if self._llm_client is None: self._llm_client = LLMClient() return self._llm_client - + def _call_with_retry(self, func, operation_name: str, max_retries: int = None): """带重试机制的API调用""" max_retries = max_retries or self.MAX_RETRIES last_exception = None delay = self.RETRY_DELAY - + for attempt in range(max_retries): try: return func() @@ -458,33 +466,36 @@ class ZepToolsService: delay *= 2 else: logger.error(t("console.zepAllRetriesFailed", operation=operation_name, retries=max_retries, error=str(e))) - + raise last_exception - + def search_graph( - self, - graph_id: str, - query: str, + self, + graph_id: str, + query: str, limit: int = 10, scope: str = "edges" ) -> SearchResult: """ 图谱语义搜索 - + 使用混合搜索(语义+BM25)在图谱中搜索相关信息。 如果Zep Cloud的search API不可用,则降级为本地关键词匹配。 - + Args: graph_id: 图谱ID (Standalone Graph) query: 搜索查询 limit: 返回结果数量 scope: 搜索范围,"edges" 或 "nodes" - + Returns: SearchResult: 搜索结果 """ logger.info(t("console.graphSearch", graphId=graph_id, query=query[:50])) - + + if self.use_local_storage: + return self._local_search(graph_id, query, limit, scope) + # 尝试使用Zep Cloud Search API try: search_results = self._call_with_retry( @@ -497,11 +508,11 @@ class ZepToolsService: ), operation_name=t("console.graphSearchOp", graphId=graph_id) ) - + facts = [] edges = [] nodes = [] - + # 解析边搜索结果 if hasattr(search_results, 'edges') and search_results.edges: for edge in search_results.edges: @@ -514,7 +525,7 @@ class ZepToolsService: "source_node_uuid": getattr(edge, 'source_node_uuid', ''), "target_node_uuid": getattr(edge, 'target_node_uuid', ''), }) - + # 解析节点搜索结果 if hasattr(search_results, 'nodes') and search_results.nodes: for node in search_results.nodes: @@ -527,9 +538,9 @@ class ZepToolsService: # 节点摘要也算作事实 if hasattr(node, 'summary') and node.summary: facts.append(f"[{node.name}]: {node.summary}") - + logger.info(t("console.searchComplete", count=len(facts))) - + return SearchResult( facts=facts, edges=edges, @@ -537,43 +548,43 @@ class ZepToolsService: query=query, total_count=len(facts) ) - + except Exception as e: logger.warning(t("console.zepSearchApiFallback", error=str(e))) # 降级:使用本地关键词匹配搜索 return self._local_search(graph_id, query, limit, scope) - + def _local_search( - self, - graph_id: str, - query: str, + self, + graph_id: str, + query: str, limit: int = 10, scope: str = "edges" ) -> SearchResult: """ 本地关键词匹配搜索(作为Zep Search API的降级方案) - + 获取所有边/节点,然后在本地进行关键词匹配 - + Args: graph_id: 图谱ID query: 搜索查询 limit: 返回结果数量 scope: 搜索范围 - + Returns: SearchResult: 搜索结果 """ logger.info(t("console.usingLocalSearch", query=query[:30])) - + facts = [] edges_result = [] nodes_result = [] - + # 提取查询关键词(简单分词) query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] - + def match_score(text: str) -> int: """计算文本与查询的匹配分数""" if not text: @@ -588,7 +599,7 @@ class ZepToolsService: if keyword in text_lower: score += 10 return score - + try: if scope in ["edges", "both"]: # 获取所有边并匹配 @@ -598,10 +609,10 @@ class ZepToolsService: score = match_score(edge.fact) + match_score(edge.name) if score > 0: scored_edges.append((score, edge)) - + # 按分数排序 scored_edges.sort(key=lambda x: x[0], reverse=True) - + for score, edge in scored_edges[:limit]: if edge.fact: facts.append(edge.fact) @@ -612,7 +623,7 @@ class ZepToolsService: "source_node_uuid": edge.source_node_uuid, "target_node_uuid": edge.target_node_uuid, }) - + if scope in ["nodes", "both"]: # 获取所有节点并匹配 all_nodes = self.get_all_nodes(graph_id) @@ -621,9 +632,9 @@ class ZepToolsService: score = match_score(node.name) + match_score(node.summary) if score > 0: scored_nodes.append((score, node)) - + scored_nodes.sort(key=lambda x: x[0], reverse=True) - + for score, node in scored_nodes[:limit]: nodes_result.append({ "uuid": node.uuid, @@ -633,12 +644,12 @@ class ZepToolsService: }) if node.summary: facts.append(f"[{node.name}]: {node.summary}") - + logger.info(t("console.localSearchComplete", count=len(facts))) - + except Exception as e: logger.error(t("console.localSearchFailed", error=str(e))) - + return SearchResult( facts=facts, edges=edges_result, @@ -646,7 +657,7 @@ class ZepToolsService: query=query, total_count=len(facts) ) - + def get_all_nodes(self, graph_id: str) -> List[NodeInfo]: """ 获取图谱的所有节点(分页获取) @@ -659,7 +670,26 @@ class ZepToolsService: """ logger.info(t("console.fetchingAllNodes", graphId=graph_id)) - nodes = fetch_all_nodes(self.client, graph_id) + if self.use_local_storage: + store = self.local_store + assert store is not None + nodes = store.list_nodes(graph_id) + result = [ + NodeInfo( + uuid=node["uuid"], + name=node["name"], + labels=node["labels"], + summary=node["summary"], + attributes=node["attributes"], + ) + for node in nodes + ] + logger.info(t("console.fetchedNodes", count=len(result))) + return result + + client = self.client + assert client is not None + nodes = fetch_all_nodes(client, graph_id) result = [] for node in nodes: @@ -688,7 +718,33 @@ class ZepToolsService: """ logger.info(t("console.fetchingAllEdges", graphId=graph_id)) - edges = fetch_all_edges(self.client, graph_id) + if self.use_local_storage: + store = self.local_store + assert store is not None + edges = store.list_edges(graph_id) + result = [] + for edge in edges: + edge_info = EdgeInfo( + uuid=edge["uuid"], + name=edge["name"], + fact=edge["fact"], + source_node_uuid=edge["source_node_uuid"], + target_node_uuid=edge["target_node_uuid"], + source_node_name=edge.get("source_node_name"), + target_node_name=edge.get("target_node_name"), + ) + if include_temporal: + edge_info.created_at = edge.get("created_at") + edge_info.valid_at = edge.get("valid_at") + edge_info.invalid_at = edge.get("invalid_at") + edge_info.expired_at = edge.get("expired_at") + result.append(edge_info) + logger.info(t("console.fetchedEdges", count=len(result))) + return result + + client = self.client + assert client is not None + edges = fetch_all_edges(client, graph_id) result = [] for edge in edges: @@ -712,28 +768,47 @@ class ZepToolsService: logger.info(t("console.fetchedEdges", count=len(result))) return result - + def get_node_detail(self, node_uuid: str) -> Optional[NodeInfo]: """ 获取单个节点的详细信息 - + Args: node_uuid: 节点UUID - + Returns: 节点信息或None """ logger.info(t("console.fetchingNodeDetail", uuid=node_uuid[:8])) - + try: - node = self._call_with_retry( - func=lambda: self.client.graph.node.get(uuid_=node_uuid), - operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8]) - ) - + if self.use_local_storage: + store = self.local_store + assert store is not None + # 在本地模式下,遍历所有图谱查找节点 + node = None + for graph in store.list_graphs(): + node = store.get_node(graph["graph_id"], node_uuid) + if node: + break + else: + node = self._call_with_retry( + func=lambda: self.client.graph.node.get(uuid_=node_uuid), + operation_name=t("console.fetchNodeDetailOp", uuid=node_uuid[:8]) + ) + if not node: return None - + + if self.use_local_storage: + return NodeInfo( + uuid=node["uuid"], + name=node["name"], + labels=node["labels"], + summary=node["summary"], + attributes=node["attributes"], + ) + return NodeInfo( uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''), name=node.name or "", @@ -744,93 +819,93 @@ class ZepToolsService: except Exception as e: logger.error(t("console.fetchNodeDetailFailed", error=str(e))) return None - + def get_node_edges(self, graph_id: str, node_uuid: str) -> List[EdgeInfo]: """ 获取节点相关的所有边 - + 通过获取图谱所有边,然后过滤出与指定节点相关的边 - + Args: graph_id: 图谱ID node_uuid: 节点UUID - + Returns: 边列表 """ logger.info(t("console.fetchingNodeEdges", uuid=node_uuid[:8])) - + try: # 获取图谱所有边,然后过滤 all_edges = self.get_all_edges(graph_id) - + result = [] for edge in all_edges: # 检查边是否与指定节点相关(作为源或目标) if edge.source_node_uuid == node_uuid or edge.target_node_uuid == node_uuid: result.append(edge) - + logger.info(t("console.foundNodeEdges", count=len(result))) return result - + except Exception as e: logger.warning(t("console.fetchNodeEdgesFailed", error=str(e))) return [] - + def get_entities_by_type( - self, - graph_id: str, + self, + graph_id: str, entity_type: str ) -> List[NodeInfo]: """ 按类型获取实体 - + Args: graph_id: 图谱ID entity_type: 实体类型(如 Student, PublicFigure 等) - + Returns: 符合类型的实体列表 """ logger.info(t("console.fetchingEntitiesByType", type=entity_type)) - + all_nodes = self.get_all_nodes(graph_id) - + filtered = [] for node in all_nodes: # 检查labels是否包含指定类型 if entity_type in node.labels: filtered.append(node) - + logger.info(t("console.foundEntitiesByType", count=len(filtered), type=entity_type)) return filtered - + def get_entity_summary( - self, - graph_id: str, + self, + graph_id: str, entity_name: str ) -> Dict[str, Any]: """ 获取指定实体的关系摘要 - + 搜索与该实体相关的所有信息,并生成摘要 - + Args: graph_id: 图谱ID entity_name: 实体名称 - + Returns: 实体摘要信息 """ logger.info(t("console.fetchingEntitySummary", name=entity_name)) - + # 先搜索该实体相关的信息 search_result = self.search_graph( graph_id=graph_id, query=entity_name, limit=20 ) - + # 尝试在所有节点中找到该实体 all_nodes = self.get_all_nodes(graph_id) entity_node = None @@ -838,12 +913,12 @@ class ZepToolsService: if node.name.lower() == entity_name.lower(): entity_node = node break - + related_edges = [] if entity_node: # 传入graph_id参数 related_edges = self.get_node_edges(graph_id, entity_node.uuid) - + return { "entity_name": entity_name, "entity_info": entity_node.to_dict() if entity_node else None, @@ -851,34 +926,34 @@ class ZepToolsService: "related_edges": [e.to_dict() for e in related_edges], "total_relations": len(related_edges) } - + def get_graph_statistics(self, graph_id: str) -> Dict[str, Any]: """ 获取图谱的统计信息 - + Args: graph_id: 图谱ID - + Returns: 统计信息 """ logger.info(t("console.fetchingGraphStats", graphId=graph_id)) - + nodes = self.get_all_nodes(graph_id) edges = self.get_all_edges(graph_id) - + # 统计实体类型分布 entity_types = {} for node in nodes: for label in node.labels: if label not in ["Entity", "Node"]: entity_types[label] = entity_types.get(label, 0) + 1 - + # 统计关系类型分布 relation_types = {} for edge in edges: relation_types[edge.name] = relation_types.get(edge.name, 0) + 1 - + return { "graph_id": graph_id, "total_nodes": len(nodes), @@ -886,41 +961,41 @@ class ZepToolsService: "entity_types": entity_types, "relation_types": relation_types } - + def get_simulation_context( - self, + self, graph_id: str, simulation_requirement: str, limit: int = 30 ) -> Dict[str, Any]: """ 获取模拟相关的上下文信息 - + 综合搜索与模拟需求相关的所有信息 - + Args: graph_id: 图谱ID simulation_requirement: 模拟需求描述 limit: 每类信息的数量限制 - + Returns: 模拟上下文信息 """ logger.info(t("console.fetchingSimContext", requirement=simulation_requirement[:50])) - + # 搜索与模拟需求相关的信息 search_result = self.search_graph( graph_id=graph_id, query=simulation_requirement, limit=limit ) - + # 获取图谱统计 stats = self.get_graph_statistics(graph_id) - + # 获取所有实体节点 all_nodes = self.get_all_nodes(graph_id) - + # 筛选有实际类型的实体(非纯Entity节点) entities = [] for node in all_nodes: @@ -931,7 +1006,7 @@ class ZepToolsService: "type": custom_labels[0], "summary": node.summary }) - + return { "simulation_requirement": simulation_requirement, "related_facts": search_result.facts, @@ -939,9 +1014,9 @@ class ZepToolsService: "entities": entities[:limit], # 限制数量 "total_entities": len(entities) } - + # ========== 核心检索工具(优化后) ========== - + def insight_forge( self, graph_id: str, @@ -952,32 +1027,32 @@ class ZepToolsService: ) -> InsightForgeResult: """ 【InsightForge - 深度洞察检索】 - + 最强大的混合检索函数,自动分解问题并多维度检索: 1. 使用LLM将问题分解为多个子问题 2. 对每个子问题进行语义搜索 3. 提取相关实体并获取其详细信息 4. 追踪关系链 5. 整合所有结果,生成深度洞察 - + Args: graph_id: 图谱ID query: 用户问题 simulation_requirement: 模拟需求描述 report_context: 报告上下文(可选,用于更精准的子问题生成) max_sub_queries: 最大子问题数量 - + Returns: InsightForgeResult: 深度洞察检索结果 """ logger.info(t("console.insightForgeStart", query=query[:50])) - + result = InsightForgeResult( query=query, simulation_requirement=simulation_requirement, sub_queries=[] ) - + # Step 1: 使用LLM生成子问题 sub_queries = self._generate_sub_queries( query=query, @@ -987,12 +1062,12 @@ class ZepToolsService: ) result.sub_queries = sub_queries logger.info(t("console.generatedSubQueries", count=len(sub_queries))) - + # Step 2: 对每个子问题进行语义搜索 all_facts = [] all_edges = [] seen_facts = set() - + for sub_query in sub_queries: search_result = self.search_graph( graph_id=graph_id, @@ -1000,14 +1075,14 @@ class ZepToolsService: limit=15, scope="edges" ) - + for fact in search_result.facts: if fact not in seen_facts: all_facts.append(fact) seen_facts.add(fact) - + all_edges.extend(search_result.edges) - + # 对原始问题也进行搜索 main_search = self.search_graph( graph_id=graph_id, @@ -1019,10 +1094,10 @@ class ZepToolsService: if fact not in seen_facts: all_facts.append(fact) seen_facts.add(fact) - + result.semantic_facts = all_facts result.total_facts = len(all_facts) - + # Step 3: 从边中提取相关实体UUID,只获取这些实体的信息(不获取全部节点) entity_uuids = set() for edge_data in all_edges: @@ -1033,11 +1108,11 @@ class ZepToolsService: entity_uuids.add(source_uuid) if target_uuid: entity_uuids.add(target_uuid) - + # 获取所有相关实体的详情(不限制数量,完整输出) entity_insights = [] node_map = {} # 用于后续关系链构建 - + for uuid in list(entity_uuids): # 处理所有实体,不截断 if not uuid: continue @@ -1047,13 +1122,13 @@ class ZepToolsService: if node: node_map[uuid] = node entity_type = next((l for l in node.labels if l not in ["Entity", "Node"]), "实体") - + # 获取该实体相关的所有事实(不截断) related_facts = [ - f for f in all_facts + f for f in all_facts if node.name.lower() in f.lower() ] - + entity_insights.append({ "uuid": node.uuid, "name": node.name, @@ -1064,10 +1139,10 @@ class ZepToolsService: except Exception as e: logger.debug(f"获取节点 {uuid} 失败: {e}") continue - + result.entity_insights = entity_insights result.total_entities = len(entity_insights) - + # Step 4: 构建所有关系链(不限制数量) relationship_chains = [] for edge_data in all_edges: # 处理所有边,不截断 @@ -1075,20 +1150,20 @@ class ZepToolsService: source_uuid = edge_data.get('source_node_uuid', '') target_uuid = edge_data.get('target_node_uuid', '') relation_name = edge_data.get('name', '') - + source_name = node_map.get(source_uuid, NodeInfo('', '', [], '', {})).name or source_uuid[:8] target_name = node_map.get(target_uuid, NodeInfo('', '', [], '', {})).name or target_uuid[:8] - + chain = f"{source_name} --[{relation_name}]--> {target_name}" if chain not in relationship_chains: relationship_chains.append(chain) - + result.relationship_chains = relationship_chains result.total_relationships = len(relationship_chains) - + logger.info(t("console.insightForgeComplete", facts=result.total_facts, entities=result.total_entities, relationships=result.total_relationships)) return result - + def _generate_sub_queries( self, query: str, @@ -1098,7 +1173,7 @@ class ZepToolsService: ) -> List[str]: """ 使用LLM生成子问题 - + 将复杂问题分解为多个可以独立检索的子问题 """ system_prompt = """你是一个专业的问题分析专家。你的任务是将一个复杂问题分解为多个可以在模拟世界中独立观察的子问题。 @@ -1127,11 +1202,11 @@ class ZepToolsService: ], temperature=0.3 ) - + sub_queries = response.get("sub_queries", []) # 确保是字符串列表 return [str(sq) for sq in sub_queries[:max_queries]] - + except Exception as e: logger.warning(t("console.generateSubQueriesFailed", error=str(e))) # 降级:返回基于原问题的变体 @@ -1141,7 +1216,7 @@ class ZepToolsService: f"{query} 的原因和影响", f"{query} 的发展过程" ][:max_queries] - + def panorama_search( self, graph_id: str, @@ -1151,53 +1226,53 @@ class ZepToolsService: ) -> PanoramaResult: """ 【PanoramaSearch - 广度搜索】 - + 获取全貌视图,包括所有相关内容和历史/过期信息: 1. 获取所有相关节点 2. 获取所有边(包括已过期/失效的) 3. 分类整理当前有效和历史信息 - + 这个工具适用于需要了解事件全貌、追踪演变过程的场景。 - + Args: graph_id: 图谱ID query: 搜索查询(用于相关性排序) include_expired: 是否包含过期内容(默认True) limit: 返回结果数量限制 - + Returns: PanoramaResult: 广度搜索结果 """ logger.info(t("console.panoramaSearchStart", query=query[:50])) - + result = PanoramaResult(query=query) - + # 获取所有节点 all_nodes = self.get_all_nodes(graph_id) node_map = {n.uuid: n for n in all_nodes} result.all_nodes = all_nodes result.total_nodes = len(all_nodes) - + # 获取所有边(包含时间信息) all_edges = self.get_all_edges(graph_id, include_temporal=True) result.all_edges = all_edges result.total_edges = len(all_edges) - + # 分类事实 active_facts = [] historical_facts = [] - + for edge in all_edges: if not edge.fact: continue - + # 为事实添加实体名称 source_name = node_map.get(edge.source_node_uuid, NodeInfo('', '', [], '', {})).name or edge.source_node_uuid[:8] target_name = node_map.get(edge.target_node_uuid, NodeInfo('', '', [], '', {})).name or edge.target_node_uuid[:8] - + # 判断是否过期/失效 is_historical = edge.is_expired or edge.is_invalid - + if is_historical: # 历史/过期事实,添加时间标记 valid_at = edge.valid_at or "未知" @@ -1207,11 +1282,11 @@ class ZepToolsService: else: # 当前有效事实 active_facts.append(edge.fact) - + # 基于查询进行相关性排序 query_lower = query.lower() keywords = [w.strip() for w in query_lower.replace(',', ' ').replace(',', ' ').split() if len(w.strip()) > 1] - + def relevance_score(fact: str) -> int: fact_lower = fact.lower() score = 0 @@ -1221,19 +1296,19 @@ class ZepToolsService: if kw in fact_lower: score += 10 return score - + # 排序并限制数量 active_facts.sort(key=relevance_score, reverse=True) historical_facts.sort(key=relevance_score, reverse=True) - + result.active_facts = active_facts[:limit] result.historical_facts = historical_facts[:limit] if include_expired else [] result.active_count = len(active_facts) result.historical_count = len(historical_facts) - + logger.info(t("console.panoramaSearchComplete", active=result.active_count, historical=result.historical_count)) return result - + def quick_search( self, graph_id: str, @@ -1242,22 +1317,22 @@ class ZepToolsService: ) -> SearchResult: """ 【QuickSearch - 简单搜索】 - + 快速、轻量级的检索工具: 1. 直接调用Zep语义搜索 2. 返回最相关的结果 3. 适用于简单、直接的检索需求 - + Args: graph_id: 图谱ID query: 搜索查询 limit: 返回结果数量 - + Returns: SearchResult: 搜索结果 """ logger.info(t("console.quickSearchStart", query=query[:50])) - + # 直接调用现有的search_graph方法 result = self.search_graph( graph_id=graph_id, @@ -1265,10 +1340,10 @@ class ZepToolsService: limit=limit, scope="edges" ) - + logger.info(t("console.quickSearchComplete", count=result.total_count)) return result - + def interview_agents( self, simulation_id: str, @@ -1279,51 +1354,51 @@ class ZepToolsService: ) -> InterviewResult: """ 【InterviewAgents - 深度采访】 - + 调用真实的OASIS采访API,采访模拟中正在运行的Agent: 1. 自动读取人设文件,了解所有模拟Agent 2. 使用LLM分析采访需求,智能选择最相关的Agent 3. 使用LLM生成采访问题 4. 调用 /api/simulation/interview/batch 接口进行真实采访(双平台同时采访) 5. 整合所有采访结果,生成采访报告 - + 【重要】此功能需要模拟环境处于运行状态(OASIS环境未关闭) - + 【使用场景】 - 需要从不同角色视角了解事件看法 - 需要收集多方意见和观点 - 需要获取模拟Agent的真实回答(非LLM模拟) - + Args: simulation_id: 模拟ID(用于定位人设文件和调用采访API) interview_requirement: 采访需求描述(非结构化,如"了解学生对事件的看法") simulation_requirement: 模拟需求背景(可选) max_agents: 最多采访的Agent数量 custom_questions: 自定义采访问题(可选,若不提供则自动生成) - + Returns: InterviewResult: 采访结果 """ from .simulation_runner import SimulationRunner - + logger.info(t("console.interviewAgentsStart", requirement=interview_requirement[:50])) - + result = InterviewResult( interview_topic=interview_requirement, interview_questions=custom_questions or [] ) - + # Step 1: 读取人设文件 profiles = self._load_agent_profiles(simulation_id) - + if not profiles: logger.warning(t("console.profilesNotFound", simId=simulation_id)) result.summary = "未找到可采访的Agent人设文件" return result - + result.total_agents = len(profiles) logger.info(t("console.loadedProfiles", count=len(profiles))) - + # Step 2: 使用LLM选择要采访的Agent(返回agent_id列表) selected_agents, selected_indices, selection_reasoning = self._select_agents_for_interview( profiles=profiles, @@ -1331,11 +1406,11 @@ class ZepToolsService: simulation_requirement=simulation_requirement, max_agents=max_agents ) - + result.selected_agents = selected_agents result.selection_reasoning = selection_reasoning logger.info(t("console.selectedAgentsForInterview", count=len(selected_agents), indices=selected_indices)) - + # Step 3: 生成采访问题(如果没有提供) if not result.interview_questions: result.interview_questions = self._generate_interview_questions( @@ -1344,10 +1419,10 @@ class ZepToolsService: selected_agents=selected_agents ) logger.info(t("console.generatedInterviewQuestions", count=len(result.interview_questions))) - + # 将问题合并为一个采访prompt combined_prompt = "\n".join([f"{i+1}. {q}" for i, q in enumerate(result.interview_questions)]) - + # 添加优化前缀,约束Agent回复格式 INTERVIEW_PROMPT_PREFIX = ( "你正在接受一次采访。请结合你的人设、所有的过往记忆与行动," @@ -1361,7 +1436,7 @@ class ZepToolsService: "6. 回答要有实质内容,每个问题至少回答2-3句话\n\n" ) optimized_prompt = f"{INTERVIEW_PROMPT_PREFIX}{combined_prompt}" - + # Step 4: 调用真实的采访API(不指定platform,默认双平台同时采访) try: # 构建批量采访列表(不指定platform,双平台采访) @@ -1372,9 +1447,9 @@ class ZepToolsService: "prompt": optimized_prompt # 使用优化后的prompt # 不指定platform,API会在twitter和reddit两个平台都采访 }) - + logger.info(t("console.callingBatchInterviewApi", count=len(interviews_request))) - + # 调用 SimulationRunner 的批量采访方法(不传platform,双平台采访) api_result = SimulationRunner.interview_agents_batch( simulation_id=simulation_id, @@ -1382,31 +1457,31 @@ class ZepToolsService: platform=None, # 不指定platform,双平台采访 timeout=180.0 # 双平台需要更长超时 ) - + logger.info(t("console.interviewApiReturned", count=api_result.get('interviews_count', 0), success=api_result.get('success'))) - + # 检查API调用是否成功 if not api_result.get("success", False): error_msg = api_result.get("error", "未知错误") logger.warning(t("console.interviewApiReturnedFailure", error=error_msg)) result.summary = f"采访API调用失败:{error_msg}。请检查OASIS模拟环境状态。" return result - + # Step 5: 解析API返回结果,构建AgentInterview对象 # 双平台模式返回格式: {"twitter_0": {...}, "reddit_0": {...}, "twitter_1": {...}, ...} api_data = api_result.get("result", {}) results_dict = api_data.get("results", {}) if isinstance(api_data, dict) else {} - + for i, agent_idx in enumerate(selected_indices): agent = selected_agents[i] agent_name = agent.get("realname", agent.get("username", f"Agent_{agent_idx}")) agent_role = agent.get("profession", "未知") agent_bio = agent.get("bio", "") - + # 获取该Agent在两个平台的采访结果 twitter_result = results_dict.get(f"twitter_{agent_idx}", {}) reddit_result = results_dict.get(f"reddit_{agent_idx}", {}) - + twitter_response = twitter_result.get("response", "") reddit_response = reddit_result.get("response", "") @@ -1446,7 +1521,7 @@ class ZepToolsService: paired = re.findall(r'\u201c([^\u201c\u201d]{15,100})\u201d', clean_text) paired += re.findall(r'\u300c([^\u300c\u300d]{15,100})\u300d', clean_text) key_quotes = [q for q in paired if not re.match(r'^[,,;;::、]', q)][:3] - + interview = AgentInterview( agent_name=agent_name, agent_role=agent_role, @@ -1456,9 +1531,9 @@ class ZepToolsService: key_quotes=key_quotes[:5] ) result.interviews.append(interview) - + result.interviewed_count = len(result.interviews) - + except ValueError as e: # 模拟环境未运行 logger.warning(t("console.interviewApiCallFailed", error=e)) @@ -1470,17 +1545,17 @@ class ZepToolsService: logger.error(traceback.format_exc()) result.summary = f"采访过程发生错误:{str(e)}" return result - + # Step 6: 生成采访摘要 if result.interviews: result.summary = self._generate_interview_summary( interviews=result.interviews, interview_requirement=interview_requirement ) - + logger.info(t("console.interviewAgentsComplete", count=result.interviewed_count)) return result - + @staticmethod def _clean_tool_call_response(response: str) -> str: """清理 Agent 回复中的 JSON 工具调用包裹,提取实际内容""" @@ -1506,15 +1581,15 @@ class ZepToolsService: """加载模拟的Agent人设文件""" import os import csv - + # 构建人设文件路径 sim_dir = os.path.join( - os.path.dirname(__file__), + os.path.dirname(__file__), f'../../uploads/simulations/{simulation_id}' ) - + profiles = [] - + # 优先尝试读取Reddit JSON格式 reddit_profile_path = os.path.join(sim_dir, "reddit_profiles.json") if os.path.exists(reddit_profile_path): @@ -1525,7 +1600,7 @@ class ZepToolsService: return profiles except Exception as e: logger.warning(t("console.readRedditProfilesFailed", error=e)) - + # 尝试读取Twitter CSV格式 twitter_profile_path = os.path.join(sim_dir, "twitter_profiles.csv") if os.path.exists(twitter_profile_path): @@ -1545,9 +1620,9 @@ class ZepToolsService: return profiles except Exception as e: logger.warning(t("console.readTwitterProfilesFailed", error=e)) - + return profiles - + def _select_agents_for_interview( self, profiles: List[Dict[str, Any]], @@ -1557,14 +1632,14 @@ class ZepToolsService: ) -> tuple: """ 使用LLM选择要采访的Agent - + Returns: tuple: (selected_agents, selected_indices, reasoning) - selected_agents: 选中Agent的完整信息列表 - selected_indices: 选中Agent的索引列表(用于API调用) - reasoning: 选择理由 """ - + # 构建Agent摘要列表 agent_summaries = [] for i, profile in enumerate(profiles): @@ -1576,7 +1651,7 @@ class ZepToolsService: "interested_topics": profile.get("interested_topics", []) } agent_summaries.append(summary) - + system_prompt = """你是一个专业的采访策划专家。你的任务是根据采访需求,从模拟Agent列表中选择最适合采访的对象。 选择标准: @@ -1610,10 +1685,10 @@ class ZepToolsService: ], temperature=0.3 ) - + selected_indices = response.get("selected_indices", [])[:max_agents] reasoning = response.get("reasoning", "基于相关性自动选择") - + # 获取选中的Agent完整信息 selected_agents = [] valid_indices = [] @@ -1621,16 +1696,16 @@ class ZepToolsService: if 0 <= idx < len(profiles): selected_agents.append(profiles[idx]) valid_indices.append(idx) - + return selected_agents, valid_indices, reasoning - + except Exception as e: logger.warning(t("console.llmSelectAgentFailed", error=e)) # 降级:选择前N个 selected = profiles[:max_agents] indices = list(range(min(max_agents, len(profiles)))) return selected, indices, "使用默认选择策略" - + def _generate_interview_questions( self, interview_requirement: str, @@ -1638,9 +1713,9 @@ class ZepToolsService: selected_agents: List[Dict[str, Any]] ) -> List[str]: """使用LLM生成采访问题""" - + agent_roles = [a.get("profession", "未知") for a in selected_agents] - + system_prompt = """你是一个专业的记者/采访者。根据采访需求,生成3-5个深度采访问题。 问题要求: @@ -1669,9 +1744,9 @@ class ZepToolsService: ], temperature=0.5 ) - + return response.get("questions", [f"关于{interview_requirement},您有什么看法?"]) - + except Exception as e: logger.warning(t("console.generateInterviewQuestionsFailed", error=e)) return [ @@ -1679,22 +1754,22 @@ class ZepToolsService: "这件事对您或您所代表的群体有什么影响?", "您认为应该如何解决或改进这个问题?" ] - + def _generate_interview_summary( self, interviews: List[AgentInterview], interview_requirement: str ) -> str: """生成采访摘要""" - + if not interviews: return "未完成任何采访" - + # 收集所有采访内容 interview_texts = [] for interview in interviews: interview_texts.append(f"【{interview.agent_name}({interview.agent_role})】\n{interview.response[:500]}") - + quote_instruction = "引用受访者原话时使用中文引号「」" if get_locale() == 'zh' else 'Use quotation marks "" when quoting interviewees' system_prompt = f"""你是一个专业的新闻编辑。请根据多位受访者的回答,生成一份采访摘要。 @@ -1729,7 +1804,7 @@ class ZepToolsService: max_tokens=800 ) return summary - + except Exception as e: logger.warning(t("console.generateInterviewSummaryFailed", error=e)) # 降级:简单拼接