chore(i18n): replace all hardcoded Chinese strings with English in backend

Translate all Chinese comments, docstrings, log messages, error messages,
and LLM prompt text to English across the entire backend codebase.
Locale translation files (locales/*.json) are unchanged.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-04-24 23:40:58 +00:00
parent e3943c7d7c
commit 7d172b9eec
34 changed files with 4882 additions and 4843 deletions

View File

@ -20,15 +20,15 @@ set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
# ── Carregar configuració ───────────────────────────────────────────────────── # ── Carregar configuració ─────────────────────────────────────────────────────
CONFIG_FILE="${SCRIPT_DIR}/config.sh" #CONFIG_FILE="${SCRIPT_DIR}/config.sh"
if [[ ! -f "$CONFIG_FILE" ]]; then #if [[ ! -f "$CONFIG_FILE" ]]; then
echo "ERROR: No s'ha trobat azure/config.sh" # echo "ERROR: No s'ha trobat azure/config.sh"
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh" # echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
echo " Després omple els valors i torna a executar." # echo " Després omple els valors i torna a executar."
exit 1 # exit 1
fi #fi
# shellcheck source=config.sh.example # shellcheck source=config.sh.example
source "$CONFIG_FILE" #source "$CONFIG_FILE"
# ── Validar variables obligatòries ─────────────────────────────────────────── # ── Validar variables obligatòries ───────────────────────────────────────────
REQUIRED_VARS=( REQUIRED_VARS=(

View File

@ -17,14 +17,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)" REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
# ── Carregar configuració ───────────────────────────────────────────────────── # ── Carregar configuració ─────────────────────────────────────────────────────
CONFIG_FILE="${SCRIPT_DIR}/config.sh" #CONFIG_FILE="${SCRIPT_DIR}/config.sh"
if [[ ! -f "$CONFIG_FILE" ]]; then #if [[ ! -f "$CONFIG_FILE" ]]; then
echo "ERROR: No s'ha trobat azure/config.sh" # echo "ERROR: No s'ha trobat azure/config.sh"
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh" # echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
exit 1 # exit 1
fi #fi
# shellcheck source=config.sh.example # shellcheck source=config.sh.example
source "$CONFIG_FILE" #source "$CONFIG_FILE"
# ── Validar variables obligatòries ─────────────────────────────────────────── # ── Validar variables obligatòries ───────────────────────────────────────────
REQUIRED_VARS=( REQUIRED_VARS=(

View File

@ -1,12 +1,12 @@
""" """
MiroFish Backend - Flask应用工厂 MiroFish Backend - Flask application factory
""" """
import os import os
import warnings import warnings
# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers # Suppress multiprocessing resource_tracker warnings (from third-party libraries like transformers)
# 需要在所有其他导入之前设置 # Must be set before all other imports
warnings.filterwarnings("ignore", message=".*resource_tracker.*") warnings.filterwarnings("ignore", message=".*resource_tracker.*")
import jwt import jwt
@ -21,36 +21,36 @@ _PUBLIC_PATHS = {'/health', '/api/auth/login'}
def create_app(config_class=Config): def create_app(config_class=Config):
"""Flask应用工厂函数""" """Flask application factory"""
app = Flask(__name__) app = Flask(__name__)
app.config.from_object(config_class) app.config.from_object(config_class)
# 设置JSON编码确保中文直接显示而不是 \uXXXX 格式) # Configure JSON encoding: ensure non-ASCII characters are output directly (not as \uXXXX)
# Flask >= 2.3 使用 app.json.ensure_ascii旧版本使用 JSON_AS_ASCII 配置 # Flask >= 2.3 uses app.json.ensure_ascii; older versions use JSON_AS_ASCII config
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'): if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
app.json.ensure_ascii = False app.json.ensure_ascii = False
# 设置日志 # Set up logging
logger = setup_logger('mirofish') logger = setup_logger('mirofish')
# 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次) # Only log startup info in the reloader subprocess (avoids double-printing in debug mode)
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
debug_mode = app.config.get('DEBUG', False) debug_mode = app.config.get('DEBUG', False)
should_log_startup = not debug_mode or is_reloader_process should_log_startup = not debug_mode or is_reloader_process
if should_log_startup: if should_log_startup:
logger.info("=" * 50) logger.info("=" * 50)
logger.info("MiroFish Backend 启动中...") logger.info("MiroFish Backend starting...")
logger.info("=" * 50) logger.info("=" * 50)
# 启用CORS # Enable CORS
CORS(app, resources={r"/api/*": {"origins": "*"}}) CORS(app, resources={r"/api/*": {"origins": "*"}})
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程) # Register simulation process cleanup (ensures all simulation processes are terminated on server shutdown)
from .services.simulation_runner import SimulationRunner from .services.simulation_runner import SimulationRunner
SimulationRunner.register_cleanup() SimulationRunner.register_cleanup()
if should_log_startup: if should_log_startup:
logger.info("已注册模拟进程清理函数") logger.info("Simulation process cleanup handler registered")
# Middleware d'autenticació JWT — s'executa ABANS del log_request (ordre FIFO) # Middleware d'autenticació JWT — s'executa ABANS del log_request (ordre FIFO)
@app.before_request @app.before_request
@ -70,28 +70,28 @@ def create_app(config_class=Config):
except jwt.InvalidTokenError: except jwt.InvalidTokenError:
return jsonify({'success': False, 'error': 'Invalid token'}), 401 return jsonify({'success': False, 'error': 'Invalid token'}), 401
# 请求日志中间件 # Request logging middleware
@app.before_request @app.before_request
def log_request(): def log_request():
logger = get_logger('mirofish.request') logger = get_logger('mirofish.request')
logger.debug(f"请求: {request.method} {request.path}") logger.debug(f"Request: {request.method} {request.path}")
if request.content_type and 'json' in request.content_type: if request.content_type and 'json' in request.content_type:
logger.debug(f"请求体: {request.get_json(silent=True)}") logger.debug(f"Request body: {request.get_json(silent=True)}")
@app.after_request @app.after_request
def log_response(response): def log_response(response):
logger = get_logger('mirofish.request') logger = get_logger('mirofish.request')
logger.debug(f"响应: {response.status_code}") logger.debug(f"Response: {response.status_code}")
return response return response
# 注册蓝图 (auth primer, luego els existents) # Register blueprints (auth first, then the rest)
from .api import graph_bp, simulation_bp, report_bp, auth_bp from .api import graph_bp, simulation_bp, report_bp, auth_bp
app.register_blueprint(auth_bp, url_prefix='/api/auth') app.register_blueprint(auth_bp, url_prefix='/api/auth')
app.register_blueprint(graph_bp, url_prefix='/api/graph') app.register_blueprint(graph_bp, url_prefix='/api/graph')
app.register_blueprint(simulation_bp, url_prefix='/api/simulation') app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
app.register_blueprint(report_bp, url_prefix='/api/report') app.register_blueprint(report_bp, url_prefix='/api/report')
# 健康检查 # Health check
@app.route('/health') @app.route('/health')
def health(): def health():
return {'status': 'ok', 'service': 'MiroFish Backend'} return {'status': 'ok', 'service': 'MiroFish Backend'}
@ -111,6 +111,6 @@ def create_app(config_class=Config):
return _send_file(_os.path.join(_dist, 'index.html')) return _send_file(_os.path.join(_dist, 'index.html'))
if should_log_startup: if should_log_startup:
logger.info("MiroFish Backend 启动完成") logger.info("MiroFish Backend startup complete")
return app return app

View File

@ -1,5 +1,5 @@
""" """
API路由模块 API routes module
""" """
from flask import Blueprint from flask import Blueprint

View File

@ -1,6 +1,6 @@
""" """
图谱相关API路由 Graph-related API routes
采用项目上下文机制服务端持久化状态 Uses project context mechanism with server-side persistent state
""" """
import os import os
@ -19,24 +19,24 @@ from ..utils.locale import t, get_locale, set_locale
from ..models.task import TaskManager, TaskStatus from ..models.task import TaskManager, TaskStatus
from ..models.project import ProjectManager, ProjectStatus from ..models.project import ProjectManager, ProjectStatus
# 获取日志器 # Get logger
logger = get_logger('mirofish.api') logger = get_logger('mirofish.api')
def allowed_file(filename: str) -> bool: def allowed_file(filename: str) -> bool:
"""检查文件扩展名是否允许""" """Check if the file extension is allowed"""
if not filename or '.' not in filename: if not filename or '.' not in filename:
return False return False
ext = os.path.splitext(filename)[1].lower().lstrip('.') ext = os.path.splitext(filename)[1].lower().lstrip('.')
return ext in Config.ALLOWED_EXTENSIONS return ext in Config.ALLOWED_EXTENSIONS
# ============== 项目管理接口 ============== # ============== Project management endpoints ==============
@graph_bp.route('/project/<project_id>', methods=['GET']) @graph_bp.route('/project/<project_id>', methods=['GET'])
def get_project(project_id: str): def get_project(project_id: str):
""" """
获取项目详情 Get project details
""" """
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
@ -55,7 +55,7 @@ def get_project(project_id: str):
@graph_bp.route('/project/list', methods=['GET']) @graph_bp.route('/project/list', methods=['GET'])
def list_projects(): def list_projects():
""" """
列出所有项目 List all projects
""" """
limit = request.args.get('limit', 50, type=int) limit = request.args.get('limit', 50, type=int)
projects = ProjectManager.list_projects(limit=limit) projects = ProjectManager.list_projects(limit=limit)
@ -70,7 +70,7 @@ def list_projects():
@graph_bp.route('/project/<project_id>', methods=['DELETE']) @graph_bp.route('/project/<project_id>', methods=['DELETE'])
def delete_project(project_id: str): def delete_project(project_id: str):
""" """
删除项目 Delete a project
""" """
success = ProjectManager.delete_project(project_id) success = ProjectManager.delete_project(project_id)
@ -89,7 +89,7 @@ def delete_project(project_id: str):
@graph_bp.route('/project/<project_id>/reset', methods=['POST']) @graph_bp.route('/project/<project_id>/reset', methods=['POST'])
def reset_project(project_id: str): def reset_project(project_id: str):
""" """
重置项目状态用于重新构建图谱 Reset project status (used to rebuild the graph)
""" """
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
@ -99,7 +99,7 @@ def reset_project(project_id: str):
"error": t('api.projectNotFound', id=project_id) "error": t('api.projectNotFound', id=project_id)
}), 404 }), 404
# 重置到本体已生成状态 # Reset to ontology-generated status
if project.ontology: if project.ontology:
project.status = ProjectStatus.ONTOLOGY_GENERATED project.status = ProjectStatus.ONTOLOGY_GENERATED
else: else:
@ -117,22 +117,22 @@ def reset_project(project_id: str):
}) })
# ============== 接口1上传文件并生成本体 ============== # ============== Endpoint 1: Upload files and generate ontology ==============
@graph_bp.route('/ontology/generate', methods=['POST']) @graph_bp.route('/ontology/generate', methods=['POST'])
def generate_ontology(): def generate_ontology():
""" """
接口1上传文件分析生成本体定义 Endpoint 1: Upload files and generate ontology definition
请求方式multipart/form-data Request method: multipart/form-data
参数 Parameters:
files: 上传的文件PDF/MD/TXT可多个 files: Uploaded files (PDF/MD/TXT), multiple allowed
simulation_requirement: 模拟需求描述必填 simulation_requirement: Simulation requirement description (required)
project_name: 项目名称可选 project_name: Project name (optional)
additional_context: 额外说明可选 additional_context: Additional context (optional)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -148,15 +148,15 @@ def generate_ontology():
} }
""" """
try: try:
logger.info("=== 开始生成本体定义 ===") logger.info("=== Starting ontology generation ===")
# 获取参数 # Get parameters
simulation_requirement = request.form.get('simulation_requirement', '') simulation_requirement = request.form.get('simulation_requirement', '')
project_name = request.form.get('project_name', 'Unnamed Project') project_name = request.form.get('project_name', 'Unnamed Project')
additional_context = request.form.get('additional_context', '') additional_context = request.form.get('additional_context', '')
logger.debug(f"项目名称: {project_name}") logger.debug(f"Project name: {project_name}")
logger.debug(f"模拟需求: {simulation_requirement[:100]}...") logger.debug(f"Simulation requirement: {simulation_requirement[:100]}...")
if not simulation_requirement: if not simulation_requirement:
return jsonify({ return jsonify({
@ -164,7 +164,7 @@ def generate_ontology():
"error": t('api.requireSimulationRequirement') "error": t('api.requireSimulationRequirement')
}), 400 }), 400
# 获取上传的文件 # Get uploaded files
uploaded_files = request.files.getlist('files') uploaded_files = request.files.getlist('files')
if not uploaded_files or all(not f.filename for f in uploaded_files): if not uploaded_files or all(not f.filename for f in uploaded_files):
return jsonify({ return jsonify({
@ -172,18 +172,18 @@ def generate_ontology():
"error": t('api.requireFileUpload') "error": t('api.requireFileUpload')
}), 400 }), 400
# 创建项目 # Create project
project = ProjectManager.create_project(name=project_name) project = ProjectManager.create_project(name=project_name)
project.simulation_requirement = simulation_requirement project.simulation_requirement = simulation_requirement
logger.info(f"创建项目: {project.project_id}") logger.info(f"Project created: {project.project_id}")
# 保存文件并提取文本 # Save files and extract text
document_texts = [] document_texts = []
all_text = "" all_text = ""
for file in uploaded_files: for file in uploaded_files:
if file and file.filename and allowed_file(file.filename): if file and file.filename and allowed_file(file.filename):
# 保存文件到项目目录 # Save file to project directory
file_info = ProjectManager.save_file_to_project( file_info = ProjectManager.save_file_to_project(
project.project_id, project.project_id,
file, file,
@ -194,7 +194,7 @@ def generate_ontology():
"size": file_info["size"] "size": file_info["size"]
}) })
# 提取文本 # Extract text
text = FileParser.extract_text(file_info["path"]) text = FileParser.extract_text(file_info["path"])
text = TextProcessor.preprocess_text(text) text = TextProcessor.preprocess_text(text)
document_texts.append(text) document_texts.append(text)
@ -207,13 +207,13 @@ def generate_ontology():
"error": t('api.noDocProcessed') "error": t('api.noDocProcessed')
}), 400 }), 400
# 保存提取的文本 # Save extracted text
project.total_text_length = len(all_text) project.total_text_length = len(all_text)
ProjectManager.save_extracted_text(project.project_id, all_text) ProjectManager.save_extracted_text(project.project_id, all_text)
logger.info(f"文本提取完成,共 {len(all_text)} 字符") logger.info(f"Text extraction complete, total {len(all_text)} characters")
# 生成本体 # Generate ontology
logger.info("调用 LLM 生成本体定义...") logger.info("Calling LLM to generate ontology definition...")
generator = OntologyGenerator() generator = OntologyGenerator()
ontology = generator.generate( ontology = generator.generate(
document_texts=document_texts, document_texts=document_texts,
@ -221,10 +221,10 @@ def generate_ontology():
additional_context=additional_context if additional_context else None additional_context=additional_context if additional_context else None
) )
# 保存本体到项目 # Save ontology to project
entity_count = len(ontology.get("entity_types", [])) entity_count = len(ontology.get("entity_types", []))
edge_count = len(ontology.get("edge_types", [])) edge_count = len(ontology.get("edge_types", []))
logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型") logger.info(f"Ontology generation complete: {entity_count} entity types, {edge_count} relationship types")
project.ontology = { project.ontology = {
"entity_types": ontology.get("entity_types", []), "entity_types": ontology.get("entity_types", []),
@ -233,7 +233,7 @@ def generate_ontology():
project.analysis_summary = ontology.get("analysis_summary", "") project.analysis_summary = ontology.get("analysis_summary", "")
project.status = ProjectStatus.ONTOLOGY_GENERATED project.status = ProjectStatus.ONTOLOGY_GENERATED
ProjectManager.save_project(project) ProjectManager.save_project(project)
logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}") logger.info(f"=== Ontology generation complete === Project ID: {project.project_id}")
return jsonify({ return jsonify({
"success": True, "success": True,
@ -255,49 +255,49 @@ def generate_ontology():
}), 500 }), 500
# ============== 接口2构建图谱 ============== # ============== Endpoint 2: Build graph ==============
@graph_bp.route('/build', methods=['POST']) @graph_bp.route('/build', methods=['POST'])
def build_graph(): def build_graph():
""" """
接口2根据project_id构建图谱 Endpoint 2: Build graph from project_id
请求JSON Request (JSON):
{ {
"project_id": "proj_xxxx", // 必填来自接口1 "project_id": "proj_xxxx", // required, from endpoint 1
"graph_name": "图谱名称", // 可选 "graph_name": "Graph name", // optional
"chunk_size": 500, // 可选默认500 "chunk_size": 500, // optional, default 500
"chunk_overlap": 50 // 可选默认50 "chunk_overlap": 50 // optional, default 50
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"project_id": "proj_xxxx", "project_id": "proj_xxxx",
"task_id": "task_xxxx", "task_id": "task_xxxx",
"message": "图谱构建任务已启动" "message": "Graph build task started"
} }
} }
""" """
try: try:
logger.info("=== 开始构建图谱 ===") logger.info("=== Starting graph build ===")
# 检查配置 # Check configuration
errors = [] errors = []
if not Config.ZEP_API_KEY: if not Config.ZEP_API_KEY:
errors.append(t('api.zepApiKeyMissing')) errors.append(t('api.zepApiKeyMissing'))
if errors: if errors:
logger.error(f"配置错误: {errors}") logger.error(f"Configuration error: {errors}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": t('api.configError', details="; ".join(errors)) "error": t('api.configError', details="; ".join(errors))
}), 500 }), 500
# 解析请求 # Parse request
data = request.get_json() or {} data = request.get_json() or {}
project_id = data.get('project_id') project_id = data.get('project_id')
logger.debug(f"请求参数: project_id={project_id}") logger.debug(f"Request parameters: project_id={project_id}")
if not project_id: if not project_id:
return jsonify({ return jsonify({
@ -305,7 +305,7 @@ def build_graph():
"error": t('api.requireProjectId') "error": t('api.requireProjectId')
}), 400 }), 400
# 获取项目 # Get project
project = ProjectManager.get_project(project_id) project = ProjectManager.get_project(project_id)
if not project: if not project:
return jsonify({ return jsonify({
@ -313,8 +313,8 @@ def build_graph():
"error": t('api.projectNotFound', id=project_id) "error": t('api.projectNotFound', id=project_id)
}), 404 }), 404
# 检查项目状态 # Check project status
force = data.get('force', False) # 强制重新构建 force = data.get('force', False) # Force rebuild
if project.status == ProjectStatus.CREATED: if project.status == ProjectStatus.CREATED:
return jsonify({ return jsonify({
@ -329,23 +329,23 @@ def build_graph():
"task_id": project.graph_build_task_id "task_id": project.graph_build_task_id
}), 400 }), 400
# 如果强制重建,重置状态 # If force rebuild, reset status
if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]: if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
project.status = ProjectStatus.ONTOLOGY_GENERATED project.status = ProjectStatus.ONTOLOGY_GENERATED
project.graph_id = None project.graph_id = None
project.graph_build_task_id = None project.graph_build_task_id = None
project.error = None project.error = None
# 获取配置 # Get configuration
graph_name = data.get('graph_name', project.name or 'MiroFish Graph') graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE) chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP) chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
# 更新项目配置 # Update project configuration
project.chunk_size = chunk_size project.chunk_size = chunk_size
project.chunk_overlap = chunk_overlap project.chunk_overlap = chunk_overlap
# 获取提取的文本 # Get extracted text
text = ProjectManager.get_extracted_text(project_id) text = ProjectManager.get_extracted_text(project_id)
if not text: if not text:
return jsonify({ return jsonify({
@ -353,7 +353,7 @@ def build_graph():
"error": t('api.textNotFound') "error": t('api.textNotFound')
}), 400 }), 400
# 获取本体 # Get ontology
ontology = project.ontology ontology = project.ontology
if not ontology: if not ontology:
return jsonify({ return jsonify({
@ -361,12 +361,12 @@ def build_graph():
"error": t('api.ontologyNotFound') "error": t('api.ontologyNotFound')
}), 400 }), 400
# 创建异步任务 # Create async task
task_manager = TaskManager() task_manager = TaskManager()
task_id = task_manager.create_task(f"构建图谱: {graph_name}") task_id = task_manager.create_task(f"Build graph: {graph_name}")
logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}") logger.info(f"Graph build task created: task_id={task_id}, project_id={project_id}")
# 更新项目状态 # Update project status
project.status = ProjectStatus.GRAPH_BUILDING project.status = ProjectStatus.GRAPH_BUILDING
project.graph_build_task_id = task_id project.graph_build_task_id = task_id
ProjectManager.save_project(project) ProjectManager.save_project(project)
@ -374,22 +374,22 @@ def build_graph():
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 启动后台任务 # Start background task
def build_task(): def build_task():
set_locale(current_locale) set_locale(current_locale)
build_logger = get_logger('mirofish.build') build_logger = get_logger('mirofish.build')
try: try:
build_logger.info(f"[{task_id}] 开始构建图谱...") build_logger.info(f"[{task_id}] Starting graph build...")
task_manager.update_task( task_manager.update_task(
task_id, task_id,
status=TaskStatus.PROCESSING, status=TaskStatus.PROCESSING,
message=t('progress.initGraphService') message=t('progress.initGraphService')
) )
# 创建图谱构建服务 # Create graph builder service
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY) builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
# 分块 # Split into chunks
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=t('progress.textChunking'), message=t('progress.textChunking'),
@ -402,7 +402,7 @@ def build_graph():
) )
total_chunks = len(chunks) total_chunks = len(chunks)
# 创建图谱 # Create graph
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=t('progress.creatingZepGraph'), message=t('progress.creatingZepGraph'),
@ -410,11 +410,11 @@ def build_graph():
) )
graph_id = builder.create_graph(name=graph_name) graph_id = builder.create_graph(name=graph_name)
# 更新项目的graph_id # Update project graph_id
project.graph_id = graph_id project.graph_id = graph_id
ProjectManager.save_project(project) ProjectManager.save_project(project)
# 设置本体 # Set ontology
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=t('progress.settingOntology'), message=t('progress.settingOntology'),
@ -422,7 +422,7 @@ def build_graph():
) )
builder.set_ontology(graph_id, ontology) builder.set_ontology(graph_id, ontology)
# 添加文本progress_callback 签名是 (msg, progress_ratio) # Add text (progress_callback signature: (msg, progress_ratio))
def add_progress_callback(msg, progress_ratio): def add_progress_callback(msg, progress_ratio):
progress = 15 + int(progress_ratio * 40) # 15% - 55% progress = 15 + int(progress_ratio * 40) # 15% - 55%
task_manager.update_task( task_manager.update_task(
@ -444,7 +444,7 @@ def build_graph():
progress_callback=add_progress_callback progress_callback=add_progress_callback
) )
# 等待Zep处理完成查询每个episode的processed状态 # Wait for Zep processing to complete (poll each episode's processed status)
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=t('progress.waitingZepProcess'), message=t('progress.waitingZepProcess'),
@ -461,7 +461,7 @@ def build_graph():
builder._wait_for_episodes(episode_uuids, wait_progress_callback) builder._wait_for_episodes(episode_uuids, wait_progress_callback)
# 获取图谱数据 # Fetch graph data
task_manager.update_task( task_manager.update_task(
task_id, task_id,
message=t('progress.fetchingGraphData'), message=t('progress.fetchingGraphData'),
@ -469,15 +469,15 @@ def build_graph():
) )
graph_data = builder.get_graph_data(graph_id) graph_data = builder.get_graph_data(graph_id)
# 更新项目状态 # Update project status
project.status = ProjectStatus.GRAPH_COMPLETED project.status = ProjectStatus.GRAPH_COMPLETED
ProjectManager.save_project(project) ProjectManager.save_project(project)
node_count = graph_data.get("node_count", 0) node_count = graph_data.get("node_count", 0)
edge_count = graph_data.get("edge_count", 0) edge_count = graph_data.get("edge_count", 0)
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}") build_logger.info(f"[{task_id}] Graph build complete: graph_id={graph_id}, nodes={node_count}, edges={edge_count}")
# 完成 # Complete
task_manager.update_task( task_manager.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
@ -493,8 +493,8 @@ def build_graph():
) )
except Exception as e: except Exception as e:
# 更新项目状态为失败 # Update project status to failed
build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}") build_logger.error(f"[{task_id}] Graph build failed: {str(e)}")
build_logger.debug(traceback.format_exc()) build_logger.debug(traceback.format_exc())
project.status = ProjectStatus.FAILED project.status = ProjectStatus.FAILED
@ -508,7 +508,7 @@ def build_graph():
error=traceback.format_exc() error=traceback.format_exc()
) )
# 启动后台线程 # Start background thread
thread = threading.Thread(target=build_task, daemon=True) thread = threading.Thread(target=build_task, daemon=True)
thread.start() thread.start()
@ -529,12 +529,12 @@ def build_graph():
}), 500 }), 500
# ============== 任务查询接口 ============== # ============== Task query endpoints ==============
@graph_bp.route('/task/<task_id>', methods=['GET']) @graph_bp.route('/task/<task_id>', methods=['GET'])
def get_task(task_id: str): def get_task(task_id: str):
""" """
查询任务状态 Query task status
""" """
task = TaskManager().get_task(task_id) task = TaskManager().get_task(task_id)
@ -553,7 +553,7 @@ def get_task(task_id: str):
@graph_bp.route('/tasks', methods=['GET']) @graph_bp.route('/tasks', methods=['GET'])
def list_tasks(): def list_tasks():
""" """
列出所有任务 List all tasks
""" """
tasks = TaskManager().list_tasks() tasks = TaskManager().list_tasks()
@ -564,12 +564,12 @@ def list_tasks():
}) })
# ============== 图谱数据接口 ============== # ============== Graph data endpoints ==============
@graph_bp.route('/data/<graph_id>', methods=['GET']) @graph_bp.route('/data/<graph_id>', methods=['GET'])
def get_graph_data(graph_id: str): def get_graph_data(graph_id: str):
""" """
获取图谱数据节点和边 Get graph data (nodes and edges)
""" """
try: try:
if not Config.ZEP_API_KEY: if not Config.ZEP_API_KEY:
@ -597,7 +597,7 @@ def get_graph_data(graph_id: str):
@graph_bp.route('/delete/<graph_id>', methods=['DELETE']) @graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
def delete_graph(graph_id: str): def delete_graph(graph_id: str):
""" """
删除Zep图谱 Delete a Zep graph
""" """
try: try:
if not Config.ZEP_API_KEY: if not Config.ZEP_API_KEY:

View File

@ -1,6 +1,6 @@
""" """
Report API路由 Report API routes
提供模拟报告生成获取对话等接口 Provides simulation report generation, retrieval, and chat endpoints
""" """
import os import os
@ -20,30 +20,30 @@ from ..utils.locale import t, get_locale, set_locale
logger = get_logger('mirofish.api.report') logger = get_logger('mirofish.api.report')
# ============== 报告生成接口 ============== # ============== Report generation endpoints ==============
@report_bp.route('/generate', methods=['POST']) @report_bp.route('/generate', methods=['POST'])
def generate_report(): def generate_report():
""" """
生成模拟分析报告异步任务 Generate a simulation analysis report (async task)
这是一个耗时操作接口会立即返回task_id This is a long-running operation; the endpoint returns task_id immediately.
使用 GET /api/report/generate/status 查询进度 Use GET /api/report/generate/status to poll progress.
请求JSON Request (JSON):
{ {
"simulation_id": "sim_xxxx", // 必填模拟ID "simulation_id": "sim_xxxx", // required, simulation ID
"force_regenerate": false // 可选强制重新生成 "force_regenerate": false // optional, force regeneration
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"simulation_id": "sim_xxxx", "simulation_id": "sim_xxxx",
"task_id": "task_xxxx", "task_id": "task_xxxx",
"status": "generating", "status": "generating",
"message": "报告生成任务已启动" "message": "Report generation task started"
} }
} }
""" """
@ -59,7 +59,7 @@ def generate_report():
force_regenerate = data.get('force_regenerate', False) force_regenerate = data.get('force_regenerate', False)
# 获取模拟信息 # Get simulation info
manager = SimulationManager() manager = SimulationManager()
state = manager.get_simulation(simulation_id) state = manager.get_simulation(simulation_id)
@ -69,7 +69,7 @@ def generate_report():
"error": t('api.simulationNotFound', id=simulation_id) "error": t('api.simulationNotFound', id=simulation_id)
}), 404 }), 404
# 检查是否已有报告 # Check if a report already exists
if not force_regenerate: if not force_regenerate:
existing_report = ReportManager.get_report_by_simulation(simulation_id) existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED: if existing_report and existing_report.status == ReportStatus.COMPLETED:
@ -84,7 +84,7 @@ def generate_report():
} }
}) })
# 获取项目信息 # Get project info
project = ProjectManager.get_project(state.project_id) project = ProjectManager.get_project(state.project_id)
if not project: if not project:
return jsonify({ return jsonify({
@ -106,11 +106,11 @@ def generate_report():
"error": t('api.missingSimRequirement') "error": t('api.missingSimRequirement')
}), 400 }), 400
# 提前生成 report_id以便立即返回给前端 # Pre-generate report_id so it can be returned immediately
import uuid import uuid
report_id = f"report_{uuid.uuid4().hex[:12]}" report_id = f"report_{uuid.uuid4().hex[:12]}"
# 创建异步任务 # Create async task
task_manager = TaskManager() task_manager = TaskManager()
task_id = task_manager.create_task( task_id = task_manager.create_task(
task_type="report_generate", task_type="report_generate",
@ -124,7 +124,7 @@ def generate_report():
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 定义后台任务 # Define background task
def run_generate(): def run_generate():
set_locale(current_locale) set_locale(current_locale)
try: try:
@ -135,14 +135,14 @@ def generate_report():
message=t('api.initReportAgent') message=t('api.initReportAgent')
) )
# 创建Report Agent # Create Report Agent
agent = ReportAgent( agent = ReportAgent(
graph_id=graph_id, graph_id=graph_id,
simulation_id=simulation_id, simulation_id=simulation_id,
simulation_requirement=simulation_requirement simulation_requirement=simulation_requirement
) )
# 进度回调 # Progress callback
def progress_callback(stage, progress, message): def progress_callback(stage, progress, message):
task_manager.update_task( task_manager.update_task(
task_id, task_id,
@ -150,13 +150,13 @@ def generate_report():
message=f"[{stage}] {message}" message=f"[{stage}] {message}"
) )
# 生成报告(传入预先生成的 report_id # Generate report (pass pre-generated report_id)
report = agent.generate_report( report = agent.generate_report(
progress_callback=progress_callback, progress_callback=progress_callback,
report_id=report_id report_id=report_id
) )
# 保存报告 # Save report
ReportManager.save_report(report) ReportManager.save_report(report)
if report.status == ReportStatus.COMPLETED: if report.status == ReportStatus.COMPLETED:
@ -172,10 +172,10 @@ def generate_report():
task_manager.fail_task(task_id, report.error or t('api.reportGenerateFailed')) task_manager.fail_task(task_id, report.error or t('api.reportGenerateFailed'))
except Exception as e: except Exception as e:
logger.error(f"报告生成失败: {str(e)}") logger.error(f"Report generation failed: {str(e)}")
task_manager.fail_task(task_id, str(e)) task_manager.fail_task(task_id, str(e))
# 启动后台线程 # Start background thread
thread = threading.Thread(target=run_generate, daemon=True) thread = threading.Thread(target=run_generate, daemon=True)
thread.start() thread.start()
@ -192,7 +192,7 @@ def generate_report():
}) })
except Exception as e: except Exception as e:
logger.error(f"启动报告生成任务失败: {str(e)}") logger.error(f"Failed to start report generation task: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -203,15 +203,15 @@ def generate_report():
@report_bp.route('/generate/status', methods=['POST']) @report_bp.route('/generate/status', methods=['POST'])
def get_generate_status(): def get_generate_status():
""" """
查询报告生成任务进度 Query report generation task progress
请求JSON Request (JSON):
{ {
"task_id": "task_xxxx", // 可选generate返回的task_id "task_id": "task_xxxx", // optional, task_id from generate
"simulation_id": "sim_xxxx" // 可选模拟ID "simulation_id": "sim_xxxx" // optional, simulation ID
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -228,7 +228,7 @@ def get_generate_status():
task_id = data.get('task_id') task_id = data.get('task_id')
simulation_id = data.get('simulation_id') simulation_id = data.get('simulation_id')
# 如果提供了simulation_id先检查是否已有完成的报告 # If simulation_id is provided, check whether a completed report exists
if simulation_id: if simulation_id:
existing_report = ReportManager.get_report_by_simulation(simulation_id) existing_report = ReportManager.get_report_by_simulation(simulation_id)
if existing_report and existing_report.status == ReportStatus.COMPLETED: if existing_report and existing_report.status == ReportStatus.COMPLETED:
@ -265,21 +265,21 @@ def get_generate_status():
}) })
except Exception as e: except Exception as e:
logger.error(f"查询任务状态失败: {str(e)}") logger.error(f"Failed to query task status: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e) "error": str(e)
}), 500 }), 500
# ============== 报告获取接口 ============== # ============== Report retrieval endpoints ==============
@report_bp.route('/<report_id>', methods=['GET']) @report_bp.route('/<report_id>', methods=['GET'])
def get_report(report_id: str): def get_report(report_id: str):
""" """
获取报告详情 Get report details
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -308,7 +308,7 @@ def get_report(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取报告失败: {str(e)}") logger.error(f"Failed to get report: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -319,9 +319,9 @@ def get_report(report_id: str):
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET']) @report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
def get_report_by_simulation(simulation_id: str): def get_report_by_simulation(simulation_id: str):
""" """
根据模拟ID获取报告 Get report by simulation ID
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -347,7 +347,7 @@ def get_report_by_simulation(simulation_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取报告失败: {str(e)}") logger.error(f"Failed to get report: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -358,13 +358,13 @@ def get_report_by_simulation(simulation_id: str):
@report_bp.route('/list', methods=['GET']) @report_bp.route('/list', methods=['GET'])
def list_reports(): def list_reports():
""" """
列出所有报告 List all reports
Query参数 Query parameters:
simulation_id: 按模拟ID过滤可选 simulation_id: filter by simulation ID (optional)
limit: 返回数量限制默认50 limit: result count limit (default 50)
返回 Returns:
{ {
"success": true, "success": true,
"data": [...], "data": [...],
@ -387,7 +387,7 @@ def list_reports():
}) })
except Exception as e: except Exception as e:
logger.error(f"列出报告失败: {str(e)}") logger.error(f"Failed to list reports: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -398,9 +398,9 @@ def list_reports():
@report_bp.route('/<report_id>/download', methods=['GET']) @report_bp.route('/<report_id>/download', methods=['GET'])
def download_report(report_id: str): def download_report(report_id: str):
""" """
下载报告Markdown格式 Download report (Markdown format)
返回Markdown文件 Returns a Markdown file
""" """
try: try:
report = ReportManager.get_report(report_id) report = ReportManager.get_report(report_id)
@ -414,7 +414,7 @@ def download_report(report_id: str):
md_path = ReportManager._get_report_markdown_path(report_id) md_path = ReportManager._get_report_markdown_path(report_id)
if not os.path.exists(md_path): if not os.path.exists(md_path):
# 如果MD文件不存在生成一个临时文件 # If MD file doesn't exist, create a temporary file
import tempfile import tempfile
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f: with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
f.write(report.markdown_content) f.write(report.markdown_content)
@ -433,7 +433,7 @@ def download_report(report_id: str):
) )
except Exception as e: except Exception as e:
logger.error(f"下载报告失败: {str(e)}") logger.error(f"Failed to download report: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -443,7 +443,7 @@ def download_report(report_id: str):
@report_bp.route('/<report_id>', methods=['DELETE']) @report_bp.route('/<report_id>', methods=['DELETE'])
def delete_report(report_id: str): def delete_report(report_id: str):
"""删除报告""" """Delete a report"""
try: try:
success = ReportManager.delete_report(report_id) success = ReportManager.delete_report(report_id)
@ -459,7 +459,7 @@ def delete_report(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"删除报告失败: {str(e)}") logger.error(f"Failed to delete report: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -467,32 +467,32 @@ def delete_report(report_id: str):
}), 500 }), 500
# ============== Report Agent对话接口 ============== # ============== Report Agent chat endpoint ==============
@report_bp.route('/chat', methods=['POST']) @report_bp.route('/chat', methods=['POST'])
def chat_with_report_agent(): def chat_with_report_agent():
""" """
与Report Agent对话 Chat with the Report Agent
Report Agent可以在对话中自主调用检索工具来回答问题 The Report Agent can autonomously call retrieval tools to answer questions.
请求JSON Request (JSON):
{ {
"simulation_id": "sim_xxxx", // 必填模拟ID "simulation_id": "sim_xxxx", // required, simulation ID
"message": "请解释一下舆情走向", // 必填用户消息 "message": "Explain the trend...", // required, user message
"chat_history": [ // 可选对话历史 "chat_history": [ // optional, conversation history
{"role": "user", "content": "..."}, {"role": "user", "content": "..."},
{"role": "assistant", "content": "..."} {"role": "assistant", "content": "..."}
] ]
} }
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"response": "Agent回复...", "response": "Agent reply...",
"tool_calls": [调用的工具列表], "tool_calls": [list of tools called],
"sources": [信息来源] "sources": [information sources]
} }
} }
""" """
@ -515,7 +515,7 @@ def chat_with_report_agent():
"error": t('api.requireMessage') "error": t('api.requireMessage')
}), 400 }), 400
# 获取模拟和项目信息 # Get simulation and project info
manager = SimulationManager() manager = SimulationManager()
state = manager.get_simulation(simulation_id) state = manager.get_simulation(simulation_id)
@ -541,7 +541,7 @@ def chat_with_report_agent():
simulation_requirement = project.simulation_requirement or "" simulation_requirement = project.simulation_requirement or ""
# 创建Agent并进行对话 # Create agent and start chat
agent = ReportAgent( agent = ReportAgent(
graph_id=graph_id, graph_id=graph_id,
simulation_id=simulation_id, simulation_id=simulation_id,
@ -556,7 +556,7 @@ def chat_with_report_agent():
}) })
except Exception as e: except Exception as e:
logger.error(f"对话失败: {str(e)}") logger.error(f"Chat failed: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -564,22 +564,22 @@ def chat_with_report_agent():
}), 500 }), 500
# ============== 报告进度与分章节接口 ============== # ============== Report progress and section endpoints ==============
@report_bp.route('/<report_id>/progress', methods=['GET']) @report_bp.route('/<report_id>/progress', methods=['GET'])
def get_report_progress(report_id: str): def get_report_progress(report_id: str):
""" """
获取报告生成进度实时 Get report generation progress (real-time)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"status": "generating", "status": "generating",
"progress": 45, "progress": 45,
"message": "正在生成章节: 关键发现", "message": "Generating section: Key Findings",
"current_section": "关键发现", "current_section": "Key Findings",
"completed_sections": ["执行摘要", "模拟背景"], "completed_sections": ["Executive Summary", "Simulation Background"],
"updated_at": "2025-12-09T..." "updated_at": "2025-12-09T..."
} }
} }
@ -599,7 +599,7 @@ def get_report_progress(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取报告进度失败: {str(e)}") logger.error(f"Failed to get report progress: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -610,11 +610,12 @@ def get_report_progress(report_id: str):
@report_bp.route('/<report_id>/sections', methods=['GET']) @report_bp.route('/<report_id>/sections', methods=['GET'])
def get_report_sections(report_id: str): def get_report_sections(report_id: str):
""" """
获取已生成的章节列表分章节输出 Get list of already-generated sections (section-by-section output)
前端可以轮询此接口获取已生成的章节内容无需等待整个报告完成 The frontend can poll this endpoint to get section content as it is generated,
without waiting for the full report to complete.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -623,7 +624,7 @@ def get_report_sections(report_id: str):
{ {
"filename": "section_01.md", "filename": "section_01.md",
"section_index": 1, "section_index": 1,
"content": "## 执行摘要\\n\\n..." "content": "## Executive Summary\\n\\n..."
}, },
... ...
], ],
@ -635,7 +636,7 @@ def get_report_sections(report_id: str):
try: try:
sections = ReportManager.get_generated_sections(report_id) sections = ReportManager.get_generated_sections(report_id)
# 获取报告状态 # Get report status
report = ReportManager.get_report(report_id) report = ReportManager.get_report(report_id)
is_complete = report is not None and report.status == ReportStatus.COMPLETED is_complete = report is not None and report.status == ReportStatus.COMPLETED
@ -650,7 +651,7 @@ def get_report_sections(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取章节列表失败: {str(e)}") logger.error(f"Failed to get section list: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -661,14 +662,14 @@ def get_report_sections(report_id: str):
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET']) @report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
def get_single_section(report_id: str, section_index: int): def get_single_section(report_id: str, section_index: int):
""" """
获取单个章节内容 Get single section content
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"filename": "section_01.md", "filename": "section_01.md",
"content": "## 执行摘要\\n\\n..." "content": "## Executive Summary\\n\\n..."
} }
} }
""" """
@ -694,7 +695,7 @@ def get_single_section(report_id: str, section_index: int):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取章节内容失败: {str(e)}") logger.error(f"Failed to get section content: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -702,16 +703,16 @@ def get_single_section(report_id: str, section_index: int):
}), 500 }), 500
# ============== 报告状态检查接口 ============== # ============== Report status check endpoint ==============
@report_bp.route('/check/<simulation_id>', methods=['GET']) @report_bp.route('/check/<simulation_id>', methods=['GET'])
def check_report_status(simulation_id: str): def check_report_status(simulation_id: str):
""" """
检查模拟是否有报告以及报告状态 Check whether a simulation has a report and its status
用于前端判断是否解锁Interview功能 Used by the frontend to determine whether to unlock the Interview feature.
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -730,7 +731,7 @@ def check_report_status(simulation_id: str):
report_status = report.status.value if report else None report_status = report.status.value if report else None
report_id = report.report_id if report else None report_id = report.report_id if report else None
# 只有报告完成后才解锁interview # Interview is unlocked only after the report is complete
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
return jsonify({ return jsonify({
@ -745,7 +746,7 @@ def check_report_status(simulation_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"检查报告状态失败: {str(e)}") logger.error(f"Failed to check report status: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -753,22 +754,22 @@ def check_report_status(simulation_id: str):
}), 500 }), 500
# ============== Agent 日志接口 ============== # ============== Agent log endpoints ==============
@report_bp.route('/<report_id>/agent-log', methods=['GET']) @report_bp.route('/<report_id>/agent-log', methods=['GET'])
def get_agent_log(report_id: str): def get_agent_log(report_id: str):
""" """
获取 Report Agent 的详细执行日志 Get detailed execution log of the Report Agent
实时获取报告生成过程中的每一步动作包括 Retrieves step-by-step actions during report generation, including:
- 报告开始规划开始/完成 - Report start, planning start/complete
- 每个章节的开始工具调用LLM响应完成 - Each section's start, tool calls, LLM response, completion
- 报告完成或失败 - Report completion or failure
Query参数 Query parameters:
from_line: 从第几行开始读取可选默认0用于增量获取 from_line: start reading from this line (optional, default 0, for incremental fetch)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -779,7 +780,7 @@ def get_agent_log(report_id: str):
"report_id": "report_xxxx", "report_id": "report_xxxx",
"action": "tool_call", "action": "tool_call",
"stage": "generating", "stage": "generating",
"section_title": "执行摘要", "section_title": "Executive Summary",
"section_index": 1, "section_index": 1,
"details": { "details": {
"tool_name": "insight_forge", "tool_name": "insight_forge",
@ -806,7 +807,7 @@ def get_agent_log(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取Agent日志失败: {str(e)}") logger.error(f"Failed to get Agent log: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -817,9 +818,9 @@ def get_agent_log(report_id: str):
@report_bp.route('/<report_id>/agent-log/stream', methods=['GET']) @report_bp.route('/<report_id>/agent-log/stream', methods=['GET'])
def stream_agent_log(report_id: str): def stream_agent_log(report_id: str):
""" """
获取完整的 Agent 日志一次性获取全部 Get the full Agent log (fetch all at once)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -840,7 +841,7 @@ def stream_agent_log(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取Agent日志失败: {str(e)}") logger.error(f"Failed to get Agent log: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -848,27 +849,27 @@ def stream_agent_log(report_id: str):
}), 500 }), 500
# ============== 控制台日志接口 ============== # ============== Console log endpoints ==============
@report_bp.route('/<report_id>/console-log', methods=['GET']) @report_bp.route('/<report_id>/console-log', methods=['GET'])
def get_console_log(report_id: str): def get_console_log(report_id: str):
""" """
获取 Report Agent 的控制台输出日志 Get the console output log of the Report Agent
实时获取报告生成过程中的控制台输出INFOWARNING等 Returns real-time console output (INFO, WARNING, etc.) during report generation.
这与 agent-log 接口返回的结构化 JSON 日志不同 Unlike the agent-log endpoint which returns structured JSON logs,
是纯文本格式的控制台风格日志 this returns plain-text console-style logs.
Query参数 Query parameters:
from_line: 从第几行开始读取可选默认0用于增量获取 from_line: start reading from this line (optional, default 0, for incremental fetch)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
"logs": [ "logs": [
"[19:46:14] INFO: 搜索完成: 找到 15 条相关事实", "[19:46:14] INFO: Search complete: found 15 relevant facts",
"[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...", "[19:46:14] INFO: Graph search: graph_id=xxx, query=...",
... ...
], ],
"total_lines": 100, "total_lines": 100,
@ -888,7 +889,7 @@ def get_console_log(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取控制台日志失败: {str(e)}") logger.error(f"Failed to get console log: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -899,9 +900,9 @@ def get_console_log(report_id: str):
@report_bp.route('/<report_id>/console-log/stream', methods=['GET']) @report_bp.route('/<report_id>/console-log/stream', methods=['GET'])
def stream_console_log(report_id: str): def stream_console_log(report_id: str):
""" """
获取完整的控制台日志一次性获取全部 Get the full console log (fetch all at once)
返回 Returns:
{ {
"success": true, "success": true,
"data": { "data": {
@ -922,7 +923,7 @@ def stream_console_log(report_id: str):
}) })
except Exception as e: except Exception as e:
logger.error(f"获取控制台日志失败: {str(e)}") logger.error(f"Failed to get console log: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -930,17 +931,17 @@ def stream_console_log(report_id: str):
}), 500 }), 500
# ============== 工具调用接口(供调试使用)============== # ============== Tool call endpoints (for debugging) ==============
@report_bp.route('/tools/search', methods=['POST']) @report_bp.route('/tools/search', methods=['POST'])
def search_graph_tool(): def search_graph_tool():
""" """
图谱搜索工具接口供调试使用 Graph search tool endpoint (for debugging)
请求JSON Request (JSON):
{ {
"graph_id": "mirofish_xxxx", "graph_id": "mirofish_xxxx",
"query": "搜索查询", "query": "search query",
"limit": 10 "limit": 10
} }
""" """
@ -972,7 +973,7 @@ def search_graph_tool():
}) })
except Exception as e: except Exception as e:
logger.error(f"图谱搜索失败: {str(e)}") logger.error(f"Graph search failed: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),
@ -983,9 +984,9 @@ def search_graph_tool():
@report_bp.route('/tools/statistics', methods=['POST']) @report_bp.route('/tools/statistics', methods=['POST'])
def get_graph_statistics_tool(): def get_graph_statistics_tool():
""" """
图谱统计工具接口供调试使用 Graph statistics tool endpoint (for debugging)
请求JSON Request (JSON):
{ {
"graph_id": "mirofish_xxxx" "graph_id": "mirofish_xxxx"
} }
@ -1012,7 +1013,7 @@ def get_graph_statistics_tool():
}) })
except Exception as e: except Exception as e:
logger.error(f"获取图谱统计失败: {str(e)}") logger.error(f"Failed to get graph statistics: {str(e)}")
return jsonify({ return jsonify({
"success": False, "success": False,
"error": str(e), "error": str(e),

File diff suppressed because it is too large Load Diff

View File

@ -1,55 +1,55 @@
""" """
配置管理 Configuration management
统一从项目根目录的 .env 文件加载配置 Loads config uniformly from the .env file at the project root
""" """
import os import os
from dotenv import load_dotenv from dotenv import load_dotenv
# 加载项目根目录的 .env 文件 # Load the .env file from the project root
# 路径: MiroFish/.env (相对于 backend/app/config.py) # Path: MiroFish/.env (relative to backend/app/config.py)
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env') project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
if os.path.exists(project_root_env): if os.path.exists(project_root_env):
load_dotenv(project_root_env, override=True) load_dotenv(project_root_env, override=True)
else: else:
# 如果根目录没有 .env尝试加载环境变量用于生产环境 # If no root-level .env file found, load from environment variables (production)
load_dotenv(override=True) load_dotenv(override=True)
class Config: class Config:
"""Flask配置类""" """Flask configuration class"""
# Flask配置 # Flask settings
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key') SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
DEMO_PASSWORD = os.environ.get('DEMO_PASSWORD', '') DEMO_PASSWORD = os.environ.get('DEMO_PASSWORD', '')
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true' DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
# JSON配置 - 禁用ASCII转义让中文直接显示而不是 \uXXXX 格式) # JSON settings - disable ASCII escaping so non-ASCII chars are output directly (not as \uXXXX)
JSON_AS_ASCII = False JSON_AS_ASCII = False
# LLM配置统一使用OpenAI格式 # LLM settings (unified OpenAI-compatible format)
LLM_API_KEY = os.environ.get('LLM_API_KEY') LLM_API_KEY = os.environ.get('LLM_API_KEY')
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1') LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini') LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
# Zep配置 # Zep settings
ZEP_API_KEY = os.environ.get('ZEP_API_KEY') ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
# 文件上传配置 # File upload settings
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads') UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'} ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
# 文本处理配置 # Text processing settings
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小 DEFAULT_CHUNK_SIZE = 500 # default chunk size
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小 DEFAULT_CHUNK_OVERLAP = 50 # default overlap size
# OASIS模拟配置 # OASIS simulation settings
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10')) OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations') OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
# OASIS平台可用动作配置 # OASIS platform available actions
OASIS_TWITTER_ACTIONS = [ OASIS_TWITTER_ACTIONS = [
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST' 'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
] ]
@ -59,18 +59,18 @@ class Config:
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE' 'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
] ]
# Report Agent配置 # Report Agent settings
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5')) REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2')) REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5')) REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
@classmethod @classmethod
def validate(cls): def validate(cls):
"""验证必要配置""" """Validate required configuration"""
errors = [] errors = []
if not cls.LLM_API_KEY: if not cls.LLM_API_KEY:
errors.append("LLM_API_KEY 未配置") errors.append("LLM_API_KEY is not configured")
if not cls.ZEP_API_KEY: if not cls.ZEP_API_KEY:
errors.append("ZEP_API_KEY 未配置") errors.append("ZEP_API_KEY is not configured")
return errors return errors

View File

@ -1,5 +1,5 @@
""" """
数据模型模块 Data models module
""" """
from .task import TaskManager, TaskStatus from .task import TaskManager, TaskStatus

View File

@ -1,6 +1,6 @@
""" """
项目上下文管理 Project context management
用于在服务端持久化项目状态避免前端在接口间传递大量数据 Persists project state server-side so the frontend does not need to pass large amounts of data between endpoints.
""" """
import os import os
@ -15,45 +15,45 @@ from ..config import Config
class ProjectStatus(str, Enum): class ProjectStatus(str, Enum):
"""项目状态""" """Project status"""
CREATED = "created" # 刚创建,文件已上传 CREATED = "created" # Just created; files uploaded
ONTOLOGY_GENERATED = "ontology_generated" # 本体已生成 ONTOLOGY_GENERATED = "ontology_generated" # Ontology generated
GRAPH_BUILDING = "graph_building" # 图谱构建中 GRAPH_BUILDING = "graph_building" # Graph building in progress
GRAPH_COMPLETED = "graph_completed" # 图谱构建完成 GRAPH_COMPLETED = "graph_completed" # Graph build complete
FAILED = "failed" # 失败 FAILED = "failed" # Failed
@dataclass @dataclass
class Project: class Project:
"""项目数据模型""" """Project data model"""
project_id: str project_id: str
name: str name: str
status: ProjectStatus status: ProjectStatus
created_at: str created_at: str
updated_at: str updated_at: str
# 文件信息 # File info
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}] files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
total_text_length: int = 0 total_text_length: int = 0
# 本体信息接口1生成后填充 # Ontology info (populated after endpoint 1)
ontology: Optional[Dict[str, Any]] = None ontology: Optional[Dict[str, Any]] = None
analysis_summary: Optional[str] = None analysis_summary: Optional[str] = None
# 图谱信息接口2完成后填充 # Graph info (populated after endpoint 2 completes)
graph_id: Optional[str] = None graph_id: Optional[str] = None
graph_build_task_id: Optional[str] = None graph_build_task_id: Optional[str] = None
# 配置 # Configuration
simulation_requirement: Optional[str] = None simulation_requirement: Optional[str] = None
chunk_size: int = 500 chunk_size: int = 500
chunk_overlap: int = 50 chunk_overlap: int = 50
# 错误信息 # Error info
error: Optional[str] = None error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Convert to dictionary"""
return { return {
"project_id": self.project_id, "project_id": self.project_id,
"name": self.name, "name": self.name,
@ -74,7 +74,7 @@ class Project:
@classmethod @classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Project': def from_dict(cls, data: Dict[str, Any]) -> 'Project':
"""从字典创建""" """Create from dictionary"""
status = data.get('status', 'created') status = data.get('status', 'created')
if isinstance(status, str): if isinstance(status, str):
status = ProjectStatus(status) status = ProjectStatus(status)
@ -99,46 +99,46 @@ class Project:
class ProjectManager: class ProjectManager:
"""项目管理器 - 负责项目的持久化存储和检索""" """Project manager - handles persistent storage and retrieval of projects"""
# 项目存储根目录 # Root directory for project storage
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects') PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
@classmethod @classmethod
def _ensure_projects_dir(cls): def _ensure_projects_dir(cls):
"""确保项目目录存在""" """Ensure the projects directory exists"""
os.makedirs(cls.PROJECTS_DIR, exist_ok=True) os.makedirs(cls.PROJECTS_DIR, exist_ok=True)
@classmethod @classmethod
def _get_project_dir(cls, project_id: str) -> str: def _get_project_dir(cls, project_id: str) -> str:
"""获取项目目录路径""" """Get project directory path"""
return os.path.join(cls.PROJECTS_DIR, project_id) return os.path.join(cls.PROJECTS_DIR, project_id)
@classmethod @classmethod
def _get_project_meta_path(cls, project_id: str) -> str: def _get_project_meta_path(cls, project_id: str) -> str:
"""获取项目元数据文件路径""" """Get project metadata file path"""
return os.path.join(cls._get_project_dir(project_id), 'project.json') return os.path.join(cls._get_project_dir(project_id), 'project.json')
@classmethod @classmethod
def _get_project_files_dir(cls, project_id: str) -> str: def _get_project_files_dir(cls, project_id: str) -> str:
"""获取项目文件存储目录""" """Get project files storage directory"""
return os.path.join(cls._get_project_dir(project_id), 'files') return os.path.join(cls._get_project_dir(project_id), 'files')
@classmethod @classmethod
def _get_project_text_path(cls, project_id: str) -> str: def _get_project_text_path(cls, project_id: str) -> str:
"""获取项目提取文本存储路径""" """Get path for storing the extracted project text"""
return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt') return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt')
@classmethod @classmethod
def create_project(cls, name: str = "Unnamed Project") -> Project: def create_project(cls, name: str = "Unnamed Project") -> Project:
""" """
创建新项目 Create a new project.
Args: Args:
name: 项目名称 name: project name
Returns: Returns:
新创建的Project对象 newly created Project object
""" """
cls._ensure_projects_dir() cls._ensure_projects_dir()
@ -153,20 +153,20 @@ class ProjectManager:
updated_at=now updated_at=now
) )
# 创建项目目录结构 # Create project directory structure
project_dir = cls._get_project_dir(project_id) project_dir = cls._get_project_dir(project_id)
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
os.makedirs(project_dir, exist_ok=True) os.makedirs(project_dir, exist_ok=True)
os.makedirs(files_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True)
# 保存项目元数据 # Save project metadata
cls.save_project(project) cls.save_project(project)
return project return project
@classmethod @classmethod
def save_project(cls, project: Project) -> None: def save_project(cls, project: Project) -> None:
"""保存项目元数据""" """Save project metadata"""
project.updated_at = datetime.now().isoformat() project.updated_at = datetime.now().isoformat()
meta_path = cls._get_project_meta_path(project.project_id) meta_path = cls._get_project_meta_path(project.project_id)
@ -176,13 +176,13 @@ class ProjectManager:
@classmethod @classmethod
def get_project(cls, project_id: str) -> Optional[Project]: def get_project(cls, project_id: str) -> Optional[Project]:
""" """
获取项目 Get a project.
Args: Args:
project_id: 项目ID project_id: project ID
Returns: Returns:
Project对象如果不存在返回None Project object, or None if not found
""" """
meta_path = cls._get_project_meta_path(project_id) meta_path = cls._get_project_meta_path(project_id)
@ -197,13 +197,13 @@ class ProjectManager:
@classmethod @classmethod
def list_projects(cls, limit: int = 50) -> List[Project]: def list_projects(cls, limit: int = 50) -> List[Project]:
""" """
列出所有项目 List all projects.
Args: Args:
limit: 返回数量限制 limit: result count limit
Returns: Returns:
项目列表按创建时间倒序 list of projects sorted by creation time, descending
""" """
cls._ensure_projects_dir() cls._ensure_projects_dir()
@ -213,7 +213,7 @@ class ProjectManager:
if project: if project:
projects.append(project) projects.append(project)
# 按创建时间倒序排序 # Sort by creation time, descending
projects.sort(key=lambda p: p.created_at, reverse=True) projects.sort(key=lambda p: p.created_at, reverse=True)
return projects[:limit] return projects[:limit]
@ -221,13 +221,13 @@ class ProjectManager:
@classmethod @classmethod
def delete_project(cls, project_id: str) -> bool: def delete_project(cls, project_id: str) -> bool:
""" """
删除项目及其所有文件 Delete a project and all its files.
Args: Args:
project_id: 项目ID project_id: project ID
Returns: Returns:
是否删除成功 True if successfully deleted
""" """
project_dir = cls._get_project_dir(project_id) project_dir = cls._get_project_dir(project_id)
@ -240,28 +240,28 @@ class ProjectManager:
@classmethod @classmethod
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]: def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
""" """
保存上传的文件到项目目录 Save an uploaded file to the project directory.
Args: Args:
project_id: 项目ID project_id: project ID
file_storage: Flask的FileStorage对象 file_storage: Flask FileStorage object
original_filename: 原始文件名 original_filename: original filename
Returns: Returns:
文件信息字典 {filename, path, size} file info dict {filename, path, size}
""" """
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
os.makedirs(files_dir, exist_ok=True) os.makedirs(files_dir, exist_ok=True)
# 生成安全的文件名 # Generate a safe filename
ext = os.path.splitext(original_filename)[1].lower() ext = os.path.splitext(original_filename)[1].lower()
safe_filename = f"{uuid.uuid4().hex[:8]}{ext}" safe_filename = f"{uuid.uuid4().hex[:8]}{ext}"
file_path = os.path.join(files_dir, safe_filename) file_path = os.path.join(files_dir, safe_filename)
# 保存文件 # Save file
file_storage.save(file_path) file_storage.save(file_path)
# 获取文件大小 # Get file size
file_size = os.path.getsize(file_path) file_size = os.path.getsize(file_path)
return { return {
@ -273,14 +273,14 @@ class ProjectManager:
@classmethod @classmethod
def save_extracted_text(cls, project_id: str, text: str) -> None: def save_extracted_text(cls, project_id: str, text: str) -> None:
"""保存提取的文本""" """Save extracted text"""
text_path = cls._get_project_text_path(project_id) text_path = cls._get_project_text_path(project_id)
with open(text_path, 'w', encoding='utf-8') as f: with open(text_path, 'w', encoding='utf-8') as f:
f.write(text) f.write(text)
@classmethod @classmethod
def get_extracted_text(cls, project_id: str) -> Optional[str]: def get_extracted_text(cls, project_id: str) -> Optional[str]:
"""获取提取的文本""" """Get extracted text"""
text_path = cls._get_project_text_path(project_id) text_path = cls._get_project_text_path(project_id)
if not os.path.exists(text_path): if not os.path.exists(text_path):
@ -291,7 +291,7 @@ class ProjectManager:
@classmethod @classmethod
def get_project_files(cls, project_id: str) -> List[str]: def get_project_files(cls, project_id: str) -> List[str]:
"""获取项目的所有文件路径""" """Get all file paths for a project"""
files_dir = cls._get_project_files_dir(project_id) files_dir = cls._get_project_files_dir(project_id)
if not os.path.exists(files_dir): if not os.path.exists(files_dir):

View File

@ -1,6 +1,6 @@
""" """
任务状态管理 Task state management
用于跟踪长时间运行的任务如图谱构建 Used to track long-running tasks (e.g. graph building).
""" """
import uuid import uuid
@ -14,30 +14,30 @@ from ..utils.locale import t
class TaskStatus(str, Enum): class TaskStatus(str, Enum):
"""任务状态枚举""" """Task status enum"""
PENDING = "pending" # 等待中 PENDING = "pending" # Waiting
PROCESSING = "processing" # 处理中 PROCESSING = "processing" # In progress
COMPLETED = "completed" # 已完成 COMPLETED = "completed" # Completed
FAILED = "failed" # 失败 FAILED = "failed" # Failed
@dataclass @dataclass
class Task: class Task:
"""任务数据类""" """Task data class"""
task_id: str task_id: str
task_type: str task_type: str
status: TaskStatus status: TaskStatus
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
progress: int = 0 # 总进度百分比 0-100 progress: int = 0 # Total progress percentage 0-100
message: str = "" # 状态消息 message: str = "" # Status message
result: Optional[Dict] = None # 任务结果 result: Optional[Dict] = None # Task result
error: Optional[str] = None # 错误信息 error: Optional[str] = None # Error info
metadata: Dict = field(default_factory=dict) # 额外元数据 metadata: Dict = field(default_factory=dict) # Extra metadata
progress_detail: Dict = field(default_factory=dict) # 详细进度信息 progress_detail: Dict = field(default_factory=dict) # Detailed progress info
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""转换为字典""" """Convert to dictionary"""
return { return {
"task_id": self.task_id, "task_id": self.task_id,
"task_type": self.task_type, "task_type": self.task_type,
@ -55,15 +55,15 @@ class Task:
class TaskManager: class TaskManager:
""" """
任务管理器 Task manager
线程安全的任务状态管理 Thread-safe task state management
""" """
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls):
"""单例模式""" """Singleton pattern"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
@ -74,14 +74,14 @@ class TaskManager:
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
""" """
创建新任务 Create a new task.
Args: Args:
task_type: 任务类型 task_type: task type
metadata: 额外元数据 metadata: extra metadata
Returns: Returns:
任务ID task ID
""" """
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
now = datetime.now() now = datetime.now()
@ -101,7 +101,7 @@ class TaskManager:
return task_id return task_id
def get_task(self, task_id: str) -> Optional[Task]: def get_task(self, task_id: str) -> Optional[Task]:
"""获取任务""" """Get a task"""
with self._task_lock: with self._task_lock:
return self._tasks.get(task_id) return self._tasks.get(task_id)
@ -116,16 +116,16 @@ class TaskManager:
progress_detail: Optional[Dict] = None progress_detail: Optional[Dict] = None
): ):
""" """
更新任务状态 Update task status.
Args: Args:
task_id: 任务ID task_id: task ID
status: 新状态 status: new status
progress: 进度 progress: progress
message: 消息 message: message
result: 结果 result: result
error: 错误信息 error: error info
progress_detail: 详细进度信息 progress_detail: detailed progress info
""" """
with self._task_lock: with self._task_lock:
task = self._tasks.get(task_id) task = self._tasks.get(task_id)
@ -145,7 +145,7 @@ class TaskManager:
task.progress_detail = progress_detail task.progress_detail = progress_detail
def complete_task(self, task_id: str, result: Dict): def complete_task(self, task_id: str, result: Dict):
"""标记任务完成""" """Mark task as complete"""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
@ -155,7 +155,7 @@ class TaskManager:
) )
def fail_task(self, task_id: str, error: str): def fail_task(self, task_id: str, error: str):
"""标记任务失败""" """Mark task as failed"""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.FAILED, status=TaskStatus.FAILED,
@ -164,7 +164,7 @@ class TaskManager:
) )
def list_tasks(self, task_type: Optional[str] = None) -> list: def list_tasks(self, task_type: Optional[str] = None) -> list:
"""列出任务""" """List tasks"""
with self._task_lock: with self._task_lock:
tasks = list(self._tasks.values()) tasks = list(self._tasks.values())
if task_type: if task_type:
@ -172,7 +172,7 @@ class TaskManager:
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)] return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
def cleanup_old_tasks(self, max_age_hours: int = 24): def cleanup_old_tasks(self, max_age_hours: int = 24):
"""清理旧任务""" """Clean up old tasks"""
from datetime import timedelta from datetime import timedelta
cutoff = datetime.now() - timedelta(hours=max_age_hours) cutoff = datetime.now() - timedelta(hours=max_age_hours)

View File

@ -1,5 +1,5 @@
""" """
业务服务模块 Business services module
""" """
from .ontology_generator import OntologyGenerator from .ontology_generator import OntologyGenerator

View File

@ -1,6 +1,6 @@
""" """
图谱构建服务 Graph building service
接口2使用Zep API构建Standalone Graph Endpoint 2: Build a Standalone Graph using the Zep API
""" """
import os import os
@ -22,7 +22,7 @@ from ..utils.locale import t, get_locale, set_locale
@dataclass @dataclass
class GraphInfo: class GraphInfo:
"""图谱信息""" """Graph info"""
graph_id: str graph_id: str
node_count: int node_count: int
edge_count: int edge_count: int
@ -39,14 +39,14 @@ class GraphInfo:
class GraphBuilderService: class GraphBuilderService:
""" """
图谱构建服务 Graph building service
负责调用Zep API构建知识图谱 Responsible for calling the Zep API to build the knowledge graph.
""" """
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key: if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置") raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key) self.client = Zep(api_key=self.api_key)
self.task_manager = TaskManager() self.task_manager = TaskManager()
@ -61,20 +61,20 @@ class GraphBuilderService:
batch_size: int = 3 batch_size: int = 3
) -> str: ) -> str:
""" """
异步构建图谱 Build the graph asynchronously.
Args: Args:
text: 输入文本 text: input text
ontology: 本体定义来自接口1的输出 ontology: ontology definition (output from endpoint 1)
graph_name: 图谱名称 graph_name: graph name
chunk_size: 文本块大小 chunk_size: text chunk size
chunk_overlap: 块重叠大小 chunk_overlap: chunk overlap size
batch_size: 每批发送的块数量 batch_size: number of chunks per batch
Returns: Returns:
任务ID task ID
""" """
# 创建任务 # Create task
task_id = self.task_manager.create_task( task_id = self.task_manager.create_task(
task_type="graph_build", task_type="graph_build",
metadata={ metadata={
@ -87,7 +87,7 @@ class GraphBuilderService:
# Capture locale before spawning background thread # Capture locale before spawning background thread
current_locale = get_locale() current_locale = get_locale()
# 在后台线程中执行构建 # Run build in background thread
thread = threading.Thread( thread = threading.Thread(
target=self._build_graph_worker, target=self._build_graph_worker,
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale) args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
@ -108,7 +108,7 @@ class GraphBuilderService:
batch_size: int, batch_size: int,
locale: str = 'zh' locale: str = 'zh'
): ):
"""图谱构建工作线程""" """Graph build worker thread"""
set_locale(locale) set_locale(locale)
try: try:
self.task_manager.update_task( self.task_manager.update_task(
@ -118,7 +118,7 @@ class GraphBuilderService:
message=t('progress.startBuildingGraph') message=t('progress.startBuildingGraph')
) )
# 1. 创建图谱 # 1. Create graph
graph_id = self.create_graph(graph_name) graph_id = self.create_graph(graph_name)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -126,7 +126,7 @@ class GraphBuilderService:
message=t('progress.graphCreated', graphId=graph_id) message=t('progress.graphCreated', graphId=graph_id)
) )
# 2. 设置本体 # 2. Set ontology
self.set_ontology(graph_id, ontology) self.set_ontology(graph_id, ontology)
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
@ -134,7 +134,7 @@ class GraphBuilderService:
message=t('progress.ontologySet') message=t('progress.ontologySet')
) )
# 3. 文本分块 # 3. Split text into chunks
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap) chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
total_chunks = len(chunks) total_chunks = len(chunks)
self.task_manager.update_task( self.task_manager.update_task(
@ -143,7 +143,7 @@ class GraphBuilderService:
message=t('progress.textSplit', count=total_chunks) message=t('progress.textSplit', count=total_chunks)
) )
# 4. 分批发送数据 # 4. Send data in batches
episode_uuids = self.add_text_batches( episode_uuids = self.add_text_batches(
graph_id, chunks, batch_size, graph_id, chunks, batch_size,
lambda msg, prog: self.task_manager.update_task( lambda msg, prog: self.task_manager.update_task(
@ -153,7 +153,7 @@ class GraphBuilderService:
) )
) )
# 5. 等待Zep处理完成 # 5. Wait for Zep processing to complete
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=60, progress=60,
@ -169,7 +169,7 @@ class GraphBuilderService:
) )
) )
# 6. 获取图谱信息 # 6. Fetch graph info
self.task_manager.update_task( self.task_manager.update_task(
task_id, task_id,
progress=90, progress=90,
@ -178,7 +178,7 @@ class GraphBuilderService:
graph_info = self._get_graph_info(graph_id) graph_info = self._get_graph_info(graph_id)
# 完成 # Complete
self.task_manager.complete_task(task_id, { self.task_manager.complete_task(task_id, {
"graph_id": graph_id, "graph_id": graph_id,
"graph_info": graph_info.to_dict(), "graph_info": graph_info.to_dict(),
@ -191,7 +191,7 @@ class GraphBuilderService:
self.task_manager.fail_task(task_id, error_msg) self.task_manager.fail_task(task_id, error_msg)
def create_graph(self, name: str) -> str: def create_graph(self, name: str) -> str:
"""创建Zep图谱公开方法""" """Create a Zep graph (public method)"""
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}" graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
self.client.graph.create( self.client.graph.create(
@ -203,74 +203,74 @@ class GraphBuilderService:
return graph_id return graph_id
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]): def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
"""设置图谱本体(公开方法)""" """Set graph ontology (public method)"""
import warnings import warnings
from typing import Optional from typing import Optional
from pydantic import Field from pydantic import Field
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
# 抑制 Pydantic v2 关于 Field(default=None) 的警告 # Suppress Pydantic v2 warnings about Field(default=None)
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略 # This is the usage required by the Zep SDK; warnings come from dynamic class creation and can be safely ignored
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic') warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
# Zep 保留名称,不能作为属性名 # Zep reserved names that cannot be used as attribute names
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'} RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
def safe_attr_name(attr_name: str) -> str: def safe_attr_name(attr_name: str) -> str:
"""将保留名称转换为安全名称""" """Convert reserved names to safe attribute names"""
if attr_name.lower() in RESERVED_NAMES: if attr_name.lower() in RESERVED_NAMES:
return f"entity_{attr_name}" return f"entity_{attr_name}"
return attr_name return attr_name
# 动态创建实体类型 # Dynamically create entity types
entity_types = {} entity_types = {}
for entity_def in ontology.get("entity_types", []): for entity_def in ontology.get("entity_types", []):
name = entity_def["name"] name = entity_def["name"]
description = entity_def.get("description", f"A {name} entity.") description = entity_def.get("description", f"A {name} entity.")
# 创建属性字典和类型注解Pydantic v2 需要) # Build attribute dict and type annotations (required by Pydantic v2)
attrs = {"__doc__": description} attrs = {"__doc__": description}
annotations = {} annotations = {}
for attr_def in entity_def.get("attributes", []): for attr_def in entity_def.get("attributes", []):
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_name = safe_attr_name(attr_def["name"]) # Use safe name
attr_desc = attr_def.get("description", attr_name) attr_desc = attr_def.get("description", attr_name)
# Zep API 需要 Field 的 description这是必需的 # Zep API requires Field description — this is mandatory
attrs[attr_name] = Field(description=attr_desc, default=None) attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[EntityText] # 类型注解 annotations[attr_name] = Optional[EntityText] # Type annotation
attrs["__annotations__"] = annotations attrs["__annotations__"] = annotations
# 动态创建类 # Dynamically create class
entity_class = type(name, (EntityModel,), attrs) entity_class = type(name, (EntityModel,), attrs)
entity_class.__doc__ = description entity_class.__doc__ = description
entity_types[name] = entity_class entity_types[name] = entity_class
# 动态创建边类型 # Dynamically create edge types
edge_definitions = {} edge_definitions = {}
for edge_def in ontology.get("edge_types", []): for edge_def in ontology.get("edge_types", []):
name = edge_def["name"] name = edge_def["name"]
description = edge_def.get("description", f"A {name} relationship.") description = edge_def.get("description", f"A {name} relationship.")
# 创建属性字典和类型注解 # Build attribute dict and type annotations
attrs = {"__doc__": description} attrs = {"__doc__": description}
annotations = {} annotations = {}
for attr_def in edge_def.get("attributes", []): for attr_def in edge_def.get("attributes", []):
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称 attr_name = safe_attr_name(attr_def["name"]) # Use safe name
attr_desc = attr_def.get("description", attr_name) attr_desc = attr_def.get("description", attr_name)
# Zep API 需要 Field 的 description这是必需的 # Zep API requires Field description — this is mandatory
attrs[attr_name] = Field(description=attr_desc, default=None) attrs[attr_name] = Field(description=attr_desc, default=None)
annotations[attr_name] = Optional[str] # 边属性用str类型 annotations[attr_name] = Optional[str] # Edge attributes use str type
attrs["__annotations__"] = annotations attrs["__annotations__"] = annotations
# 动态创建类 # Dynamically create class
class_name = ''.join(word.capitalize() for word in name.split('_')) class_name = ''.join(word.capitalize() for word in name.split('_'))
edge_class = type(class_name, (EdgeModel,), attrs) edge_class = type(class_name, (EdgeModel,), attrs)
edge_class.__doc__ = description edge_class.__doc__ = description
# 构建source_targets # Build source_targets
source_targets = [] source_targets = []
for st in edge_def.get("source_targets", []): for st in edge_def.get("source_targets", []):
source_targets.append( source_targets.append(
@ -283,7 +283,7 @@ class GraphBuilderService:
if source_targets: if source_targets:
edge_definitions[name] = (edge_class, source_targets) edge_definitions[name] = (edge_class, source_targets)
# 调用Zep API设置本体 # Call Zep API to set ontology
if entity_types or edge_definitions: if entity_types or edge_definitions:
self.client.graph.set_ontology( self.client.graph.set_ontology(
graph_ids=[graph_id], graph_ids=[graph_id],
@ -298,7 +298,7 @@ class GraphBuilderService:
batch_size: int = 3, batch_size: int = 3,
progress_callback: Optional[Callable] = None progress_callback: Optional[Callable] = None
) -> List[str]: ) -> List[str]:
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表""" """Add text to the graph in batches; returns a list of all episode UUIDs"""
episode_uuids = [] episode_uuids = []
total_chunks = len(chunks) total_chunks = len(chunks)
@ -314,27 +314,27 @@ class GraphBuilderService:
progress progress
) )
# 构建episode数据 # Build episode data
episodes = [ episodes = [
EpisodeData(data=chunk, type="text") EpisodeData(data=chunk, type="text")
for chunk in batch_chunks for chunk in batch_chunks
] ]
# 发送到Zep # Send to Zep
try: try:
batch_result = self.client.graph.add_batch( batch_result = self.client.graph.add_batch(
graph_id=graph_id, graph_id=graph_id,
episodes=episodes episodes=episodes
) )
# 收集返回的 episode uuid # Collect returned episode UUIDs
if batch_result and isinstance(batch_result, list): if batch_result and isinstance(batch_result, list):
for ep in batch_result: for ep in batch_result:
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None) ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
if ep_uuid: if ep_uuid:
episode_uuids.append(ep_uuid) episode_uuids.append(ep_uuid)
# 避免请求过快 # Avoid sending requests too quickly
time.sleep(1) time.sleep(1)
except Exception as e: except Exception as e:
@ -350,7 +350,7 @@ class GraphBuilderService:
progress_callback: Optional[Callable] = None, progress_callback: Optional[Callable] = None,
timeout: int = 600 timeout: int = 600
): ):
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)""" """Wait for all episodes to finish processing (by polling each episode's processed status)"""
if not episode_uuids: if not episode_uuids:
if progress_callback: if progress_callback:
progress_callback(t('progress.noEpisodesWait'), 1.0) progress_callback(t('progress.noEpisodesWait'), 1.0)
@ -373,7 +373,7 @@ class GraphBuilderService:
) )
break break
# 检查每个 episode 的处理状态 # Check processing status of each episode
for ep_uuid in list(pending_episodes): for ep_uuid in list(pending_episodes):
try: try:
episode = self.client.graph.episode.get(uuid_=ep_uuid) episode = self.client.graph.episode.get(uuid_=ep_uuid)
@ -384,7 +384,7 @@ class GraphBuilderService:
completed_count += 1 completed_count += 1
except Exception as e: except Exception as e:
# 忽略单个查询错误,继续 # Ignore individual query errors and continue
pass pass
elapsed = int(time.time() - start_time) elapsed = int(time.time() - start_time)
@ -395,20 +395,20 @@ class GraphBuilderService:
) )
if pending_episodes: if pending_episodes:
time.sleep(3) # 每3秒检查一次 time.sleep(3) # Check every 3 seconds
if progress_callback: if progress_callback:
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0) progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
def _get_graph_info(self, graph_id: str) -> GraphInfo: def _get_graph_info(self, graph_id: str) -> GraphInfo:
"""获取图谱信息""" """Retrieve graph info"""
# 获取节点(分页) # Fetch nodes (paginated)
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
# 获取边(分页) # Fetch edges (paginated)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 统计实体类型 # Count entity types
entity_types = set() entity_types = set()
for node in nodes: for node in nodes:
if node.labels: if node.labels:
@ -425,25 +425,25 @@ class GraphBuilderService:
def get_graph_data(self, graph_id: str) -> Dict[str, Any]: def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
""" """
获取完整图谱数据包含详细信息 Retrieve full graph data (with detailed information).
Args: Args:
graph_id: 图谱ID graph_id: graph ID
Returns: Returns:
包含nodes和edges的字典包括时间信息属性等详细数据 Dictionary containing nodes and edges with timestamps, attributes, and other details
""" """
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
# 创建节点映射用于获取节点名称 # Build node map for looking up node names
node_map = {} node_map = {}
for node in nodes: for node in nodes:
node_map[node.uuid_] = node.name or "" node_map[node.uuid_] = node.name or ""
nodes_data = [] nodes_data = []
for node in nodes: for node in nodes:
# 获取创建时间 # Get creation timestamp
created_at = getattr(node, 'created_at', None) created_at = getattr(node, 'created_at', None)
if created_at: if created_at:
created_at = str(created_at) created_at = str(created_at)
@ -459,20 +459,20 @@ class GraphBuilderService:
edges_data = [] edges_data = []
for edge in edges: for edge in edges:
# 获取时间信息 # Get timestamps
created_at = getattr(edge, 'created_at', None) created_at = getattr(edge, 'created_at', None)
valid_at = getattr(edge, 'valid_at', None) valid_at = getattr(edge, 'valid_at', None)
invalid_at = getattr(edge, 'invalid_at', None) invalid_at = getattr(edge, 'invalid_at', None)
expired_at = getattr(edge, 'expired_at', None) expired_at = getattr(edge, 'expired_at', None)
# 获取 episodes # Get episodes
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None) episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
if episodes and not isinstance(episodes, list): if episodes and not isinstance(episodes, list):
episodes = [str(episodes)] episodes = [str(episodes)]
elif episodes: elif episodes:
episodes = [str(e) for e in episodes] episodes = [str(e) for e in episodes]
# 获取 fact_type # Get fact_type
fact_type = getattr(edge, 'fact_type', None) or edge.name or "" fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
edges_data.append({ edges_data.append({
@ -501,6 +501,6 @@ class GraphBuilderService:
} }
def delete_graph(self, graph_id: str): def delete_graph(self, graph_id: str):
"""删除图谱""" """Delete graph"""
self.client.graph.delete(graph_id=graph_id) self.client.graph.delete(graph_id=graph_id)

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,6 @@
""" """
本体生成服务 Ontology generation service
接口1分析文本内容生成适合社会模拟的实体和关系类型定义 Endpoint 1: Analyze text content and generate entity and relationship type definitions suitable for social simulation.
""" """
import json import json
@ -14,169 +14,169 @@ logger = logging.getLogger(__name__)
def _to_pascal_case(name: str) -> str: def _to_pascal_case(name: str) -> str:
"""将任意格式的名称转换为 PascalCase 'works_for' -> 'WorksFor', 'person' -> 'Person'""" """Convert a name in any format to PascalCase (e.g. 'works_for' -> 'WorksFor', 'person' -> 'Person')"""
# 按非字母数字字符分割 # Split on non-alphanumeric characters
parts = re.split(r'[^a-zA-Z0-9]+', name) parts = re.split(r'[^a-zA-Z0-9]+', name)
# 再按 camelCase 边界分割(如 'camelCase' -> ['camel', 'Case'] # Also split on camelCase boundaries (e.g. 'camelCase' -> ['camel', 'Case'])
words = [] words = []
for part in parts: for part in parts:
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_')) words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
# 每个词首字母大写,过滤空串 # Capitalize each word and filter empty strings
result = ''.join(word.capitalize() for word in words if word) result = ''.join(word.capitalize() for word in words if word)
return result if result else 'Unknown' return result if result else 'Unknown'
# 本体生成的系统提示词 # System prompt for ontology generation
ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。 ONTOLOGY_SYSTEM_PROMPT = """You are a professional knowledge graph ontology design expert. Your task is to analyze the given text content and simulation requirements, and design entity types and relationship types suitable for **social media opinion simulation**.
**重要你必须输出有效的JSON格式数据不要输出任何其他内容** **Important: You must output valid JSON format data, and nothing else.**
## 核心任务背景 ## Core Task Background
我们正在构建一个**社交媒体舆论模拟系统**在这个系统中 We are building a **social media opinion simulation system**. In this system:
- 每个实体都是一个可以在社交媒体上发声互动传播信息的"账号""主体" - Every entity is an "account" or "subject" that can speak out, interact, and spread information on social media
- 实体之间会相互影响转发评论回应 - Entities influence each other, repost, comment, and respond
- 我们需要模拟舆论事件中各方的反应和信息传播路径 - We need to simulate each party's reaction and the information propagation path during opinion events
因此**实体必须是现实中真实存在的可以在社媒上发声和互动的主体** Therefore, **entities must be real-world subjects that exist and can speak out and interact on social media**:
**可以是** **Can be**:
- 具体的个人公众人物当事人意见领袖专家学者普通人 - Specific individuals (public figures, persons involved, opinion leaders, experts and scholars, ordinary people)
- 公司企业包括其官方账号 - Companies and enterprises (including their official accounts)
- 组织机构大学协会NGO工会等 - Organizations (universities, associations, NGOs, unions, etc.)
- 政府部门监管机构 - Government departments and regulatory agencies
- 媒体机构报纸电视台自媒体网站 - Media organizations (newspapers, TV stations, self-media, websites)
- 社交媒体平台本身 - Social media platforms themselves
- 特定群体代表如校友会粉丝团维权群体等 - Representatives of specific groups (e.g. alumni associations, fan clubs, rights-protection groups, etc.)
**不可以是** **Cannot be**:
- 抽象概念"舆论""情绪""趋势" - Abstract concepts (e.g. "public opinion", "emotion", "trend")
- 主题/话题"学术诚信""教育改革" - Topics/themes (e.g. "academic integrity", "education reform")
- 观点/态度"支持方""反对方" - Viewpoints/attitudes (e.g. "supporters", "opponents")
## 输出格式 ## Output Format
请输出JSON格式包含以下结构 Please output JSON format with the following structure:
```json ```json
{ {
"entity_types": [ "entity_types": [
{ {
"name": "实体类型名称英文PascalCase", "name": "Entity type name (English, PascalCase)",
"description": "简短描述英文不超过100字符", "description": "Brief description (English, max 100 characters)",
"attributes": [ "attributes": [
{ {
"name": "属性名英文snake_case", "name": "Attribute name (English, snake_case)",
"type": "text", "type": "text",
"description": "属性描述" "description": "Attribute description"
} }
], ],
"examples": ["示例实体1", "示例实体2"] "examples": ["Example entity 1", "Example entity 2"]
} }
], ],
"edge_types": [ "edge_types": [
{ {
"name": "关系类型名称英文UPPER_SNAKE_CASE", "name": "Relationship type name (English, UPPER_SNAKE_CASE)",
"description": "简短描述英文不超过100字符", "description": "Brief description (English, max 100 characters)",
"source_targets": [ "source_targets": [
{"source": "源实体类型", "target": "目标实体类型"} {"source": "Source entity type", "target": "Target entity type"}
], ],
"attributes": [] "attributes": []
} }
], ],
"analysis_summary": "对文本内容的简要分析说明" "analysis_summary": "Brief analysis summary of the text content"
} }
``` ```
## 设计指南(极其重要!) ## Design Guidelines (Extremely Important!)
### 1. 实体类型设计 - 必须严格遵守 ### 1. Entity Type Design — Must Be Strictly Followed
**数量要求必须正好10个实体类型** **Quantity requirement: exactly 10 entity types**
**层次结构要求必须同时包含具体类型和兜底类型** **Hierarchy requirement (must include both specific types and fallback types)**:
你的10个实体类型必须包含以下层次 Your 10 entity types must include the following levels:
A. **兜底类型必须包含放在列表最后2个** A. **Fallback types (required, placed as the last 2 in the list)**:
- `Person`: 任何自然人个体的兜底类型当一个人不属于其他更具体的人物类型时归入此类 - `Person`: Fallback type for any individual person. Use this when a person does not fit any other more specific person type.
- `Organization`: 任何组织机构的兜底类型当一个组织不属于其他更具体的组织类型时归入此类 - `Organization`: Fallback type for any organization. Use this when an organization does not fit any other more specific organization type.
B. **具体类型8根据文本内容设计** B. **Specific types (8 types, designed based on text content)**:
- 针对文本中出现的主要角色设计更具体的类型 - Design more specific types for the main roles that appear in the text
- 例如如果文本涉及学术事件可以有 `Student`, `Professor`, `University` - Example: if the text involves an academic event, you might have `Student`, `Professor`, `University`
- 例如如果文本涉及商业事件可以有 `Company`, `CEO`, `Employee` - Example: if the text involves a business event, you might have `Company`, `CEO`, `Employee`
**为什么需要兜底类型** **Why fallback types are needed**:
- 文本中会出现各种人物"中小学教师""路人甲""某位网友" - Various people appear in text, such as "primary and secondary school teachers", "passersby", "some netizen"
- 如果没有专门的类型匹配他们应该被归入 `Person` - If there is no dedicated type to match them, they should fall into `Person`
- 同理小型组织临时团体等应该归入 `Organization` - Similarly, small organizations, ad hoc groups, etc. should fall into `Organization`
**具体类型的设计原则** **Principles for designing specific types**:
- 从文本中识别出高频出现或关键的角色类型 - Identify high-frequency or key role types from the text
- 每个具体类型应该有明确的边界避免重叠 - Each specific type should have clear boundaries and avoid overlap
- description 必须清晰说明这个类型和兜底类型的区别 - The description must clearly explain the difference between this type and the fallback types
### 2. 关系类型设计 ### 2. Relationship Type Design
- 数量6-10 - Quantity: 6-10
- 关系应该反映社媒互动中的真实联系 - Relationships should reflect real connections in social media interactions
- 确保关系的 source_targets 涵盖你定义的实体类型 - Ensure the source_targets in relationships cover the entity types you have defined
### 3. 属性设计 ### 3. Attribute Design
- 每个实体类型1-3个关键属性 - 1-3 key attributes per entity type
- **注意**属性名不能使用 `name``uuid``group_id``created_at``summary`这些是系统保留字 - **Note**: Attribute names must not use `name`, `uuid`, `group_id`, `created_at`, `summary` (these are system reserved words)
- 推荐使用`full_name`, `title`, `role`, `position`, `location`, `description` - Recommended: `full_name`, `title`, `role`, `position`, `location`, `description`, etc.
## 实体类型参考 ## Entity Type Reference
**个人类具体** **Individual types (specific)**:
- Student: 学生 - Student: student
- Professor: 教授/学者 - Professor: professor/scholar
- Journalist: 记者 - Journalist: journalist
- Celebrity: 明星/网红 - Celebrity: celebrity/influencer
- Executive: 高管 - Executive: corporate executive
- Official: 政府官员 - Official: government official
- Lawyer: 律师 - Lawyer: lawyer
- Doctor: 医生 - Doctor: doctor
**个人类兜底** **Individual types (fallback)**:
- Person: 任何自然人不属于上述具体类型时使用 - Person: any individual (use when not fitting the specific types above)
**组织类具体** **Organization types (specific)**:
- University: 高校 - University: university/college
- Company: 公司企业 - Company: company/enterprise
- GovernmentAgency: 政府机构 - GovernmentAgency: government agency
- MediaOutlet: 媒体机构 - MediaOutlet: media organization
- Hospital: 医院 - Hospital: hospital
- School: 中小学 - School: primary/secondary school
- NGO: 非政府组织 - NGO: non-governmental organization
**组织类兜底** **Organization types (fallback)**:
- Organization: 任何组织机构不属于上述具体类型时使用 - Organization: any organization (use when not fitting the specific types above)
## 关系类型参考 ## Relationship Type Reference
- WORKS_FOR: 工作于 - WORKS_FOR: works for
- STUDIES_AT: 就读于 - STUDIES_AT: studies at
- AFFILIATED_WITH: 隶属于 - AFFILIATED_WITH: affiliated with
- REPRESENTS: 代表 - REPRESENTS: represents
- REGULATES: 监管 - REGULATES: regulates
- REPORTS_ON: 报道 - REPORTS_ON: reports on
- COMMENTS_ON: 评论 - COMMENTS_ON: comments on
- RESPONDS_TO: 回应 - RESPONDS_TO: responds to
- SUPPORTS: 支持 - SUPPORTS: supports
- OPPOSES: 反对 - OPPOSES: opposes
- COLLABORATES_WITH: 合作 - COLLABORATES_WITH: collaborates with
- COMPETES_WITH: 竞争 - COMPETES_WITH: competes with
""" """
class OntologyGenerator: class OntologyGenerator:
""" """
本体生成器 Ontology generator
分析文本内容生成实体和关系类型定义 Analyzes text content and generates entity and relationship type definitions.
""" """
def __init__(self, llm_client: Optional[LLMClient] = None): def __init__(self, llm_client: Optional[LLMClient] = None):
@ -189,95 +189,100 @@ class OntologyGenerator:
additional_context: Optional[str] = None additional_context: Optional[str] = None
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
生成本体定义 Generate ontology definition.
Args: Args:
document_texts: 文档文本列表 document_texts: list of document texts
simulation_requirement: 模拟需求描述 simulation_requirement: simulation requirement description
additional_context: 额外上下文 additional_context: additional context
Returns: Returns:
本体定义entity_types, edge_types等 Ontology definition (entity_types, edge_types, etc.)
""" """
# 构建用户消息 lang_instruction = get_language_instruction()
# Build user message
user_message = self._build_user_message( user_message = self._build_user_message(
document_texts, document_texts,
simulation_requirement, simulation_requirement,
additional_context additional_context,
lang_instruction
) )
lang_instruction = get_language_instruction() system_prompt = f"LANGUAGE INSTRUCTION (HIGHEST PRIORITY — MUST BE FOLLOWED): {lang_instruction} All description fields, analysis_summary, and examples MUST be written in this language.\n\n{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
system_prompt = f"{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
messages = [ messages = [
{"role": "system", "content": system_prompt}, {"role": "system", "content": system_prompt},
{"role": "user", "content": user_message} {"role": "user", "content": user_message}
] ]
# 调用LLM # Call LLM
result = self.llm_client.chat_json( result = self.llm_client.chat_json(
messages=messages, messages=messages,
temperature=0.3, temperature=0.3,
max_tokens=4096 max_tokens=4096
) )
# 验证和后处理 # Validate and post-process
result = self._validate_and_process(result) result = self._validate_and_process(result)
return result return result
# 传给 LLM 的文本最大长度5万字 # Maximum text length passed to LLM (50,000 characters)
MAX_TEXT_LENGTH_FOR_LLM = 50000 MAX_TEXT_LENGTH_FOR_LLM = 50000
def _build_user_message( def _build_user_message(
self, self,
document_texts: List[str], document_texts: List[str],
simulation_requirement: str, simulation_requirement: str,
additional_context: Optional[str] additional_context: Optional[str],
lang_instruction: str = ""
) -> str: ) -> str:
"""构建用户消息""" """Build user message"""
# 合并文本 # Merge texts
combined_text = "\n\n---\n\n".join(document_texts) combined_text = "\n\n---\n\n".join(document_texts)
original_length = len(combined_text) original_length = len(combined_text)
# 如果文本超过5万字截断仅影响传给LLM的内容不影响图谱构建 # If text exceeds 50,000 characters, truncate (only affects what is passed to LLM, not graph building)
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM: if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM] combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..." combined_text += f"\n\n...(text truncated at {self.MAX_TEXT_LENGTH_FOR_LLM} chars out of {original_length} total)..."
message = f"""## 模拟需求 message = f"""## Simulation requirement
{simulation_requirement} {simulation_requirement}
## 文档内容 ## Document content
{combined_text} {combined_text}
""" """
if additional_context: if additional_context:
message += f""" message += f"""
## 额外说明 ## Additional context
{additional_context} {additional_context}
""" """
message += """ message += f"""
请根据以上内容设计适合社会舆论模拟的实体类型和关系类型 Based on the content above, design entity types and relationship types suitable for social opinion simulation.
**必须遵守的规则** **Mandatory rules**:
1. 必须正好输出10个实体类型 1. Output exactly 10 entity types
2. 最后2个必须是兜底类型Person个人兜底 Organization组织兜底 2. The last 2 must be fallback types: Person (individual fallback) and Organization (organization fallback)
3. 前8个是根据文本内容设计的具体类型 3. The first 8 are specific types designed from the document content
4. 所有实体类型必须是现实中可以发声的主体不能是抽象概念 4. All entity types must be real-world subjects capable of speaking out, not abstract concepts
5. 属性名不能使用 nameuuidgroup_id 等保留字 full_nameorg_name 等替代 5. Attribute names must not use reserved words: name, uuid, group_id use full_name, org_name, etc. instead
{lang_instruction}
""" """
return message return message
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]: def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
"""验证和后处理结果""" """Validate and post-process the result"""
# 确保必要字段存在 # Ensure required fields exist
if "entity_types" not in result: if "entity_types" not in result:
result["entity_types"] = [] result["entity_types"] = []
if "edge_types" not in result: if "edge_types" not in result:
@ -285,11 +290,11 @@ class OntologyGenerator:
if "analysis_summary" not in result: if "analysis_summary" not in result:
result["analysis_summary"] = "" result["analysis_summary"] = ""
# 验证实体类型 # Validate entity types
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用 # Record mapping from original name to PascalCase for fixing edge source_targets references later
entity_name_map = {} entity_name_map = {}
for entity in result["entity_types"]: for entity in result["entity_types"]:
# 强制将 entity name 转为 PascalCaseZep API 要求) # Force entity name to PascalCase (required by Zep API)
if "name" in entity: if "name" in entity:
original_name = entity["name"] original_name = entity["name"]
entity["name"] = _to_pascal_case(original_name) entity["name"] = _to_pascal_case(original_name)
@ -300,19 +305,19 @@ class OntologyGenerator:
entity["attributes"] = [] entity["attributes"] = []
if "examples" not in entity: if "examples" not in entity:
entity["examples"] = [] entity["examples"] = []
# 确保description不超过100字符 # Ensure description does not exceed 100 characters
if len(entity.get("description", "")) > 100: if len(entity.get("description", "")) > 100:
entity["description"] = entity["description"][:97] + "..." entity["description"] = entity["description"][:97] + "..."
# 验证关系类型 # Validate relationship types
for edge in result["edge_types"]: for edge in result["edge_types"]:
# 强制将 edge name 转为 SCREAMING_SNAKE_CASEZep API 要求) # Force edge name to SCREAMING_SNAKE_CASE (required by Zep API)
if "name" in edge: if "name" in edge:
original_name = edge["name"] original_name = edge["name"]
edge["name"] = original_name.upper() edge["name"] = original_name.upper()
if edge["name"] != original_name: if edge["name"] != original_name:
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'") logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致 # Fix entity name references in source_targets to match converted PascalCase names
for st in edge.get("source_targets", []): for st in edge.get("source_targets", []):
if st.get("source") in entity_name_map: if st.get("source") in entity_name_map:
st["source"] = entity_name_map[st["source"]] st["source"] = entity_name_map[st["source"]]
@ -325,11 +330,11 @@ class OntologyGenerator:
if len(edge.get("description", "")) > 100: if len(edge.get("description", "")) > 100:
edge["description"] = edge["description"][:97] + "..." edge["description"] = edge["description"][:97] + "..."
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型 # Zep API limit: maximum 10 custom entity types and 10 custom edge types
MAX_ENTITY_TYPES = 10 MAX_ENTITY_TYPES = 10
MAX_EDGE_TYPES = 10 MAX_EDGE_TYPES = 10
# 去重:按 name 去重,保留首次出现的 # Deduplicate: keep first occurrence by name
seen_names = set() seen_names = set()
deduped = [] deduped = []
for entity in result["entity_types"]: for entity in result["entity_types"]:
@ -341,7 +346,7 @@ class OntologyGenerator:
logger.warning(f"Duplicate entity type '{name}' removed during validation") logger.warning(f"Duplicate entity type '{name}' removed during validation")
result["entity_types"] = deduped result["entity_types"] = deduped
# 兜底类型定义 # Fallback type definitions
person_fallback = { person_fallback = {
"name": "Person", "name": "Person",
"description": "Any individual person not fitting other specific person types.", "description": "Any individual person not fitting other specific person types.",
@ -362,12 +367,12 @@ class OntologyGenerator:
"examples": ["small business", "community group"] "examples": ["small business", "community group"]
} }
# 检查是否已有兜底类型 # Check whether fallback types already exist
entity_names = {e["name"] for e in result["entity_types"]} entity_names = {e["name"] for e in result["entity_types"]}
has_person = "Person" in entity_names has_person = "Person" in entity_names
has_organization = "Organization" in entity_names has_organization = "Organization" in entity_names
# 需要添加的兜底类型 # Collect fallback types to add
fallbacks_to_add = [] fallbacks_to_add = []
if not has_person: if not has_person:
fallbacks_to_add.append(person_fallback) fallbacks_to_add.append(person_fallback)
@ -378,17 +383,17 @@ class OntologyGenerator:
current_count = len(result["entity_types"]) current_count = len(result["entity_types"])
needed_slots = len(fallbacks_to_add) needed_slots = len(fallbacks_to_add)
# 如果添加后会超过 10 个,需要移除一些现有类型 # If adding them would exceed 10, remove some existing types
if current_count + needed_slots > MAX_ENTITY_TYPES: if current_count + needed_slots > MAX_ENTITY_TYPES:
# 计算需要移除多少个 # Calculate how many to remove
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
# 从末尾移除(保留前面更重要的具体类型) # Remove from the end (preserve the more important specific types at the front)
result["entity_types"] = result["entity_types"][:-to_remove] result["entity_types"] = result["entity_types"][:-to_remove]
# 添加兜底类型 # Add fallback types
result["entity_types"].extend(fallbacks_to_add) result["entity_types"].extend(fallbacks_to_add)
# 最终确保不超过限制(防御性编程) # Final guard: ensure limits are not exceeded (defensive programming)
if len(result["entity_types"]) > MAX_ENTITY_TYPES: if len(result["entity_types"]) > MAX_ENTITY_TYPES:
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES] result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
@ -399,29 +404,29 @@ class OntologyGenerator:
def generate_python_code(self, ontology: Dict[str, Any]) -> str: def generate_python_code(self, ontology: Dict[str, Any]) -> str:
""" """
将本体定义转换为Python代码类似ontology.py Convert the ontology definition to Python code (similar to ontology.py).
Args: Args:
ontology: 本体定义 ontology: ontology definition
Returns: Returns:
Python代码字符串 Python code string
""" """
code_lines = [ code_lines = [
'"""', '"""',
'自定义实体类型定义', 'Custom entity type definitions',
'由MiroFish自动生成用于社会舆论模拟', 'Auto-generated by MiroFish for social opinion simulation',
'"""', '"""',
'', '',
'from pydantic import Field', 'from pydantic import Field',
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel', 'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
'', '',
'', '',
'# ============== 实体类型定义 ==============', '# ============== Entity type definitions ==============',
'', '',
] ]
# 生成实体类型 # Generate entity types
for entity in ontology.get("entity_types", []): for entity in ontology.get("entity_types", []):
name = entity["name"] name = entity["name"]
desc = entity.get("description", f"A {name} entity.") desc = entity.get("description", f"A {name} entity.")
@ -444,13 +449,13 @@ class OntologyGenerator:
code_lines.append('') code_lines.append('')
code_lines.append('') code_lines.append('')
code_lines.append('# ============== 关系类型定义 ==============') code_lines.append('# ============== Relationship type definitions ==============')
code_lines.append('') code_lines.append('')
# 生成关系类型 # Generate relationship types
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]
# 转换为PascalCase类名 # Convert to PascalCase class name
class_name = ''.join(word.capitalize() for word in name.split('_')) class_name = ''.join(word.capitalize() for word in name.split('_'))
desc = edge.get("description", f"A {name} relationship.") desc = edge.get("description", f"A {name} relationship.")
@ -472,8 +477,8 @@ class OntologyGenerator:
code_lines.append('') code_lines.append('')
code_lines.append('') code_lines.append('')
# 生成类型字典 # Generate type dictionaries
code_lines.append('# ============== 类型配置 ==============') code_lines.append('# ============== Type configuration ==============')
code_lines.append('') code_lines.append('')
code_lines.append('ENTITY_TYPES = {') code_lines.append('ENTITY_TYPES = {')
for entity in ontology.get("entity_types", []): for entity in ontology.get("entity_types", []):
@ -489,7 +494,7 @@ class OntologyGenerator:
code_lines.append('}') code_lines.append('}')
code_lines.append('') code_lines.append('')
# 生成边的source_targets映射 # Generate edge source_targets mapping
code_lines.append('EDGE_SOURCE_TARGETS = {') code_lines.append('EDGE_SOURCE_TARGETS = {')
for edge in ontology.get("edge_types", []): for edge in ontology.get("edge_types", []):
name = edge["name"] name = edge["name"]
@ -503,4 +508,3 @@ class OntologyGenerator:
code_lines.append('}') code_lines.append('}')
return '\n'.join(code_lines) return '\n'.join(code_lines)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -1,11 +1,11 @@
""" """
模拟IPC通信模块 Simulation IPC communication module
用于Flask后端和模拟脚本之间的进程间通信 Used for inter-process communication between the Flask backend and simulation scripts.
通过文件系统实现简单的命令/响应模式 Implements a simple command/response pattern via the file system:
1. Flask写入命令到 commands/ 目录 1. Flask writes commands to the commands/ directory
2. 模拟脚本轮询命令目录执行命令并写入响应到 responses/ 目录 2. Simulation scripts poll the command directory, execute commands, and write responses to the responses/ directory
3. Flask轮询响应目录获取结果 3. Flask polls the response directory to get results
""" """
import os import os
@ -23,14 +23,14 @@ logger = get_logger('mirofish.simulation_ipc')
class CommandType(str, Enum): class CommandType(str, Enum):
"""命令类型""" """Command type"""
INTERVIEW = "interview" # 单个Agent采访 INTERVIEW = "interview" # Single agent interview
BATCH_INTERVIEW = "batch_interview" # 批量采访 BATCH_INTERVIEW = "batch_interview" # Batch interview
CLOSE_ENV = "close_env" # 关闭环境 CLOSE_ENV = "close_env" # Close environment
class CommandStatus(str, Enum): class CommandStatus(str, Enum):
"""命令状态""" """Command status"""
PENDING = "pending" PENDING = "pending"
PROCESSING = "processing" PROCESSING = "processing"
COMPLETED = "completed" COMPLETED = "completed"
@ -39,7 +39,7 @@ class CommandStatus(str, Enum):
@dataclass @dataclass
class IPCCommand: class IPCCommand:
"""IPC命令""" """IPC command"""
command_id: str command_id: str
command_type: CommandType command_type: CommandType
args: Dict[str, Any] args: Dict[str, Any]
@ -65,7 +65,7 @@ class IPCCommand:
@dataclass @dataclass
class IPCResponse: class IPCResponse:
"""IPC响应""" """IPC response"""
command_id: str command_id: str
status: CommandStatus status: CommandStatus
result: Optional[Dict[str, Any]] = None result: Optional[Dict[str, Any]] = None
@ -94,23 +94,23 @@ class IPCResponse:
class SimulationIPCClient: class SimulationIPCClient:
""" """
模拟IPC客户端Flask端使用 Simulation IPC client (used by the Flask side)
用于向模拟进程发送命令并等待响应 Used to send commands to the simulation process and wait for responses
""" """
def __init__(self, simulation_dir: str): def __init__(self, simulation_dir: str):
""" """
初始化IPC客户端 Initialize the IPC client
Args: Args:
simulation_dir: 模拟数据目录 simulation_dir: simulation data directory
""" """
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses") self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在 # Ensure directories exist
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
@ -122,19 +122,19 @@ class SimulationIPCClient:
poll_interval: float = 0.5 poll_interval: float = 0.5
) -> IPCResponse: ) -> IPCResponse:
""" """
发送命令并等待响应 Send a command and wait for a response
Args: Args:
command_type: 命令类型 command_type: command type
args: 命令参数 args: command arguments
timeout: 超时时间 timeout: timeout in seconds
poll_interval: 轮询间隔 poll_interval: polling interval in seconds
Returns: Returns:
IPCResponse IPCResponse
Raises: Raises:
TimeoutError: 等待响应超时 TimeoutError: timed out waiting for a response
""" """
command_id = str(uuid.uuid4()) command_id = str(uuid.uuid4())
command = IPCCommand( command = IPCCommand(
@ -143,14 +143,14 @@ class SimulationIPCClient:
args=args args=args
) )
# 写入命令文件 # Write command file
command_file = os.path.join(self.commands_dir, f"{command_id}.json") command_file = os.path.join(self.commands_dir, f"{command_id}.json")
with open(command_file, 'w', encoding='utf-8') as f: with open(command_file, 'w', encoding='utf-8') as f:
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2) json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}") logger.info(f"Sending IPC command: {command_type.value}, command_id={command_id}")
# 等待响应 # Wait for response
response_file = os.path.join(self.responses_dir, f"{command_id}.json") response_file = os.path.join(self.responses_dir, f"{command_id}.json")
start_time = time.time() start_time = time.time()
@ -161,30 +161,30 @@ class SimulationIPCClient:
response_data = json.load(f) response_data = json.load(f)
response = IPCResponse.from_dict(response_data) response = IPCResponse.from_dict(response_data)
# 清理命令和响应文件 # Clean up command and response files
try: try:
os.remove(command_file) os.remove(command_file)
os.remove(response_file) os.remove(response_file)
except OSError: except OSError:
pass pass
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}") logger.info(f"Received IPC response: command_id={command_id}, status={response.status.value}")
return response return response
except (json.JSONDecodeError, KeyError) as e: except (json.JSONDecodeError, KeyError) as e:
logger.warning(f"解析响应失败: {e}") logger.warning(f"Failed to parse response: {e}")
time.sleep(poll_interval) time.sleep(poll_interval)
# 超时 # Timeout
logger.error(f"等待IPC响应超时: command_id={command_id}") logger.error(f"Timed out waiting for IPC response: command_id={command_id}")
# 清理命令文件 # Clean up command file
try: try:
os.remove(command_file) os.remove(command_file)
except OSError: except OSError:
pass pass
raise TimeoutError(f"等待命令响应超时 ({timeout})") raise TimeoutError(f"Timed out waiting for command response ({timeout}s)")
def send_interview( def send_interview(
self, self,
@ -194,19 +194,19 @@ class SimulationIPCClient:
timeout: float = 60.0 timeout: float = 60.0
) -> IPCResponse: ) -> IPCResponse:
""" """
发送单个Agent采访命令 Send a single agent interview command
Args: Args:
agent_id: Agent ID agent_id: Agent ID
prompt: 采访问题 prompt: interview question
platform: 指定平台可选 platform: target platform (optional)
- "twitter": 只采访Twitter平台 - "twitter": interview only the Twitter platform
- "reddit": 只采访Reddit平台 - "reddit": interview only the Reddit platform
- None: 双平台模拟时同时采访两个平台单平台模拟时采访该平台 - None: in dual-platform mode, interview both; in single-platform mode, interview that platform
timeout: 超时时间 timeout: timeout in seconds
Returns: Returns:
IPCResponseresult字段包含采访结果 IPCResponse with interview result in the result field
""" """
args = { args = {
"agent_id": agent_id, "agent_id": agent_id,
@ -228,18 +228,18 @@ class SimulationIPCClient:
timeout: float = 120.0 timeout: float = 120.0
) -> IPCResponse: ) -> IPCResponse:
""" """
发送批量采访命令 Send a batch interview command
Args: Args:
interviews: 采访列表每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)} interviews: list of interviews, each containing {"agent_id": int, "prompt": str, "platform": str (optional)}
platform: 默认平台可选会被每个采访项的platform覆盖 platform: default platform (optional; overridden per-item by each interview's platform)
- "twitter": 默认只采访Twitter平台 - "twitter": default to Twitter platform only
- "reddit": 默认只采访Reddit平台 - "reddit": default to Reddit platform only
- None: 双平台模拟时每个Agent同时采访两个平台 - None: in dual-platform mode, interview each agent on both platforms
timeout: 超时时间 timeout: timeout in seconds
Returns: Returns:
IPCResponseresult字段包含所有采访结果 IPCResponse with all interview results in the result field
""" """
args = {"interviews": interviews} args = {"interviews": interviews}
if platform: if platform:
@ -253,10 +253,10 @@ class SimulationIPCClient:
def send_close_env(self, timeout: float = 30.0) -> IPCResponse: def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
""" """
发送关闭环境命令 Send a close-environment command
Args: Args:
timeout: 超时时间 timeout: timeout in seconds
Returns: Returns:
IPCResponse IPCResponse
@ -269,9 +269,9 @@ class SimulationIPCClient:
def check_env_alive(self) -> bool: def check_env_alive(self) -> bool:
""" """
检查模拟环境是否存活 Check whether the simulation environment is alive
通过检查 env_status.json 文件来判断 Determined by checking the env_status.json file
""" """
status_file = os.path.join(self.simulation_dir, "env_status.json") status_file = os.path.join(self.simulation_dir, "env_status.json")
if not os.path.exists(status_file): if not os.path.exists(status_file):
@ -287,41 +287,41 @@ class SimulationIPCClient:
class SimulationIPCServer: class SimulationIPCServer:
""" """
模拟IPC服务器模拟脚本端使用 Simulation IPC server (used by the simulation script side)
轮询命令目录执行命令并返回响应 Polls the command directory, executes commands, and returns responses
""" """
def __init__(self, simulation_dir: str): def __init__(self, simulation_dir: str):
""" """
初始化IPC服务器 Initialize the IPC server
Args: Args:
simulation_dir: 模拟数据目录 simulation_dir: simulation data directory
""" """
self.simulation_dir = simulation_dir self.simulation_dir = simulation_dir
self.commands_dir = os.path.join(simulation_dir, "ipc_commands") self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
self.responses_dir = os.path.join(simulation_dir, "ipc_responses") self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
# 确保目录存在 # Ensure directories exist
os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.commands_dir, exist_ok=True)
os.makedirs(self.responses_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True)
# 环境状态 # Environment status
self._running = False self._running = False
def start(self): def start(self):
"""标记服务器为运行状态""" """Mark the server as running"""
self._running = True self._running = True
self._update_env_status("alive") self._update_env_status("alive")
def stop(self): def stop(self):
"""标记服务器为停止状态""" """Mark the server as stopped"""
self._running = False self._running = False
self._update_env_status("stopped") self._update_env_status("stopped")
def _update_env_status(self, status: str): def _update_env_status(self, status: str):
"""更新环境状态文件""" """Update the environment status file"""
status_file = os.path.join(self.simulation_dir, "env_status.json") status_file = os.path.join(self.simulation_dir, "env_status.json")
with open(status_file, 'w', encoding='utf-8') as f: with open(status_file, 'w', encoding='utf-8') as f:
json.dump({ json.dump({
@ -331,15 +331,15 @@ class SimulationIPCServer:
def poll_commands(self) -> Optional[IPCCommand]: def poll_commands(self) -> Optional[IPCCommand]:
""" """
轮询命令目录返回第一个待处理的命令 Poll the command directory and return the first pending command
Returns: Returns:
IPCCommand None IPCCommand or None
""" """
if not os.path.exists(self.commands_dir): if not os.path.exists(self.commands_dir):
return None return None
# 按时间排序获取命令文件 # Get command files sorted by modification time
command_files = [] command_files = []
for filename in os.listdir(self.commands_dir): for filename in os.listdir(self.commands_dir):
if filename.endswith('.json'): if filename.endswith('.json'):
@ -354,23 +354,23 @@ class SimulationIPCServer:
data = json.load(f) data = json.load(f)
return IPCCommand.from_dict(data) return IPCCommand.from_dict(data)
except (json.JSONDecodeError, KeyError, OSError) as e: except (json.JSONDecodeError, KeyError, OSError) as e:
logger.warning(f"读取命令文件失败: {filepath}, {e}") logger.warning(f"Failed to read command file: {filepath}, {e}")
continue continue
return None return None
def send_response(self, response: IPCResponse): def send_response(self, response: IPCResponse):
""" """
发送响应 Send a response
Args: Args:
response: IPC响应 response: IPC response
""" """
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json") response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
with open(response_file, 'w', encoding='utf-8') as f: with open(response_file, 'w', encoding='utf-8') as f:
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2) json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
# 删除命令文件 # Delete the command file
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json") command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
try: try:
os.remove(command_file) os.remove(command_file)
@ -378,7 +378,7 @@ class SimulationIPCServer:
pass pass
def send_success(self, command_id: str, result: Dict[str, Any]): def send_success(self, command_id: str, result: Dict[str, Any]):
"""发送成功响应""" """Send a success response"""
self.send_response(IPCResponse( self.send_response(IPCResponse(
command_id=command_id, command_id=command_id,
status=CommandStatus.COMPLETED, status=CommandStatus.COMPLETED,
@ -386,7 +386,7 @@ class SimulationIPCServer:
)) ))
def send_error(self, command_id: str, error: str): def send_error(self, command_id: str, error: str):
"""发送错误响应""" """Send an error response"""
self.send_response(IPCResponse( self.send_response(IPCResponse(
command_id=command_id, command_id=command_id,
status=CommandStatus.FAILED, status=CommandStatus.FAILED,

View File

@ -1,7 +1,7 @@
""" """
OASIS模拟管理器 OASIS simulation manager
管理Twitter和Reddit双平台并行模拟 Manages parallel simulation on both Twitter and Reddit platforms.
使用预设脚本 + LLM智能生成配置参数 Uses preset scripts with LLM-generated configuration parameters.
""" """
import os import os
@ -23,60 +23,60 @@ logger = get_logger('mirofish.simulation')
class SimulationStatus(str, Enum): class SimulationStatus(str, Enum):
"""模拟状态""" """Simulation status"""
CREATED = "created" CREATED = "created"
PREPARING = "preparing" PREPARING = "preparing"
READY = "ready" READY = "ready"
RUNNING = "running" RUNNING = "running"
PAUSED = "paused" PAUSED = "paused"
STOPPED = "stopped" # 模拟被手动停止 STOPPED = "stopped" # Simulation manually stopped
COMPLETED = "completed" # 模拟自然完成 COMPLETED = "completed" # Simulation naturally completed
FAILED = "failed" FAILED = "failed"
class PlatformType(str, Enum): class PlatformType(str, Enum):
"""平台类型""" """Platform type"""
TWITTER = "twitter" TWITTER = "twitter"
REDDIT = "reddit" REDDIT = "reddit"
@dataclass @dataclass
class SimulationState: class SimulationState:
"""模拟状态""" """Simulation state"""
simulation_id: str simulation_id: str
project_id: str project_id: str
graph_id: str graph_id: str
# 平台启用状态 # Platform enable flags
enable_twitter: bool = True enable_twitter: bool = True
enable_reddit: bool = True enable_reddit: bool = True
# 状态 # Status
status: SimulationStatus = SimulationStatus.CREATED status: SimulationStatus = SimulationStatus.CREATED
# 准备阶段数据 # Preparation phase data
entities_count: int = 0 entities_count: int = 0
profiles_count: int = 0 profiles_count: int = 0
entity_types: List[str] = field(default_factory=list) entity_types: List[str] = field(default_factory=list)
# 配置生成信息 # Config generation info
config_generated: bool = False config_generated: bool = False
config_reasoning: str = "" config_reasoning: str = ""
# 运行时数据 # Runtime data
current_round: int = 0 current_round: int = 0
twitter_status: str = "not_started" twitter_status: str = "not_started"
reddit_status: str = "not_started" reddit_status: str = "not_started"
# 时间戳 # Timestamps
created_at: str = field(default_factory=lambda: datetime.now().isoformat()) created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
# 错误信息 # Error message
error: Optional[str] = None error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
"""完整状态字典(内部使用)""" """Full state dictionary (internal use)"""
return { return {
"simulation_id": self.simulation_id, "simulation_id": self.simulation_id,
"project_id": self.project_id, "project_id": self.project_id,
@ -98,7 +98,7 @@ class SimulationState:
} }
def to_simple_dict(self) -> Dict[str, Any]: def to_simple_dict(self) -> Dict[str, Any]:
"""简化状态字典API返回使用""" """Simplified state dictionary (used for API responses)"""
return { return {
"simulation_id": self.simulation_id, "simulation_id": self.simulation_id,
"project_id": self.project_id, "project_id": self.project_id,
@ -114,36 +114,36 @@ class SimulationState:
class SimulationManager: class SimulationManager:
""" """
模拟管理器 Simulation manager
核心功能 Core functions:
1. 从Zep图谱读取实体并过滤 1. Read and filter entities from the Zep graph
2. 生成OASIS Agent Profile 2. Generate OASIS Agent Profiles
3. 使用LLM智能生成模拟配置参数 3. Use LLM to intelligently generate simulation configuration parameters
4. 准备预设脚本所需的所有文件 4. Prepare all files required by the preset scripts
""" """
# 模拟数据存储目录 # Simulation data storage directory
SIMULATION_DATA_DIR = os.path.join( SIMULATION_DATA_DIR = os.path.join(
os.path.dirname(__file__), os.path.dirname(__file__),
'../../uploads/simulations' '../../uploads/simulations'
) )
def __init__(self): def __init__(self):
# 确保目录存在 # Ensure directory exists
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True) os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
# 内存中的模拟状态缓存 # In-memory simulation state cache
self._simulations: Dict[str, SimulationState] = {} self._simulations: Dict[str, SimulationState] = {}
def _get_simulation_dir(self, simulation_id: str) -> str: def _get_simulation_dir(self, simulation_id: str) -> str:
"""获取模拟数据目录""" """Get the simulation data directory"""
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id) sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
os.makedirs(sim_dir, exist_ok=True) os.makedirs(sim_dir, exist_ok=True)
return sim_dir return sim_dir
def _save_simulation_state(self, state: SimulationState): def _save_simulation_state(self, state: SimulationState):
"""保存模拟状态到文件""" """Save simulation state to file"""
sim_dir = self._get_simulation_dir(state.simulation_id) sim_dir = self._get_simulation_dir(state.simulation_id)
state_file = os.path.join(sim_dir, "state.json") state_file = os.path.join(sim_dir, "state.json")
@ -155,7 +155,7 @@ class SimulationManager:
self._simulations[state.simulation_id] = state self._simulations[state.simulation_id] = state
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]: def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
"""从文件加载模拟状态""" """Load simulation state from file"""
if simulation_id in self._simulations: if simulation_id in self._simulations:
return self._simulations[simulation_id] return self._simulations[simulation_id]
@ -199,13 +199,13 @@ class SimulationManager:
enable_reddit: bool = True, enable_reddit: bool = True,
) -> SimulationState: ) -> SimulationState:
""" """
创建新的模拟 Create a new simulation.
Args: Args:
project_id: 项目ID project_id: project ID
graph_id: Zep图谱ID graph_id: Zep graph ID
enable_twitter: 是否启用Twitter模拟 enable_twitter: whether to enable Twitter simulation
enable_reddit: 是否启用Reddit模拟 enable_reddit: whether to enable Reddit simulation
Returns: Returns:
SimulationState SimulationState
@ -223,7 +223,7 @@ class SimulationManager:
) )
self._save_simulation_state(state) self._save_simulation_state(state)
logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}") logger.info(f"Simulation created: {simulation_id}, project={project_id}, graph={graph_id}")
return state return state
@ -238,30 +238,30 @@ class SimulationManager:
parallel_profile_count: int = 3 parallel_profile_count: int = 3
) -> SimulationState: ) -> SimulationState:
""" """
准备模拟环境全程自动化 Prepare the simulation environment (fully automated).
步骤 Steps:
1. 从Zep图谱读取并过滤实体 1. Read and filter entities from the Zep graph
2. 为每个实体生成OASIS Agent Profile可选LLM增强支持并行 2. Generate an OASIS Agent Profile for each entity (optional LLM enhancement, supports parallelism)
3. 使用LLM智能生成模拟配置参数时间活跃度发言频率等 3. Use LLM to intelligently generate simulation configuration parameters (time, activity level, posting frequency, etc.)
4. 保存配置文件和Profile文件 4. Save configuration files and profile files
5. 复制预设脚本到模拟目录 5. Copy preset scripts to the simulation directory
Args: Args:
simulation_id: 模拟ID simulation_id: simulation ID
simulation_requirement: 模拟需求描述用于LLM生成配置 simulation_requirement: simulation requirement description (used for LLM config generation)
document_text: 原始文档内容用于LLM理解背景 document_text: original document content (used for LLM background understanding)
defined_entity_types: 预定义的实体类型可选 defined_entity_types: predefined entity types (optional)
use_llm_for_profiles: 是否使用LLM生成详细人设 use_llm_for_profiles: whether to use LLM to generate detailed personas
progress_callback: 进度回调函数 (stage, progress, message) progress_callback: progress callback function (stage, progress, message)
parallel_profile_count: 并行生成人设的数量默认3 parallel_profile_count: number of profiles to generate in parallel, default 3
Returns: Returns:
SimulationState SimulationState
""" """
state = self._load_simulation_state(simulation_id) state = self._load_simulation_state(simulation_id)
if not state: if not state:
raise ValueError(f"模拟不存在: {simulation_id}") raise ValueError(f"Simulation not found: {simulation_id}")
try: try:
state.status = SimulationStatus.PREPARING state.status = SimulationStatus.PREPARING
@ -269,7 +269,7 @@ class SimulationManager:
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
# ========== 阶段1: 读取并过滤实体 ========== # ========== Stage 1: Read and filter entities ==========
if progress_callback: if progress_callback:
progress_callback("reading", 0, t('progress.connectingZepGraph')) progress_callback("reading", 0, t('progress.connectingZepGraph'))
@ -297,11 +297,11 @@ class SimulationManager:
if filtered.filtered_count == 0: if filtered.filtered_count == 0:
state.status = SimulationStatus.FAILED state.status = SimulationStatus.FAILED
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建" state.error = "No qualifying entities found. Please check that the graph was built correctly."
self._save_simulation_state(state) self._save_simulation_state(state)
return state return state
# ========== 阶段2: 生成Agent Profile ========== # ========== Stage 2: Generate Agent Profiles ==========
total_entities = len(filtered.entities) total_entities = len(filtered.entities)
if progress_callback: if progress_callback:
@ -312,7 +312,7 @@ class SimulationManager:
total=total_entities total=total_entities
) )
# 传入graph_id以启用Zep检索功能获取更丰富的上下文 # Pass graph_id to enable Zep retrieval for richer context
generator = OasisProfileGenerator(graph_id=state.graph_id) generator = OasisProfileGenerator(graph_id=state.graph_id)
def profile_progress(current, total, msg): def profile_progress(current, total, msg):
@ -326,7 +326,7 @@ class SimulationManager:
item_name=msg item_name=msg
) )
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式) # Set real-time save path (prefer Reddit JSON format)
realtime_output_path = None realtime_output_path = None
realtime_platform = "reddit" realtime_platform = "reddit"
if state.enable_reddit: if state.enable_reddit:
@ -340,16 +340,16 @@ class SimulationManager:
entities=filtered.entities, entities=filtered.entities,
use_llm=use_llm_for_profiles, use_llm=use_llm_for_profiles,
progress_callback=profile_progress, progress_callback=profile_progress,
graph_id=state.graph_id, # 传入graph_id用于Zep检索 graph_id=state.graph_id, # Pass graph_id for Zep retrieval
parallel_count=parallel_profile_count, # 并行生成数量 parallel_count=parallel_profile_count, # Parallel generation count
realtime_output_path=realtime_output_path, # 实时保存路径 realtime_output_path=realtime_output_path, # Real-time save path
output_platform=realtime_platform # 输出格式 output_platform=realtime_platform # Output format
) )
state.profiles_count = len(profiles) state.profiles_count = len(profiles)
# 保存Profile文件注意Twitter使用CSV格式Reddit使用JSON格式 # Save profile files (note: Twitter uses CSV format, Reddit uses JSON format)
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性 # Reddit has already been saved incrementally during generation; save once more to ensure completeness
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_profiles", 95, "generating_profiles", 95,
@ -366,7 +366,7 @@ class SimulationManager:
) )
if state.enable_twitter: if state.enable_twitter:
# Twitter使用CSV格式这是OASIS的要求 # Twitter uses CSV format — this is a requirement of OASIS
generator.save_profiles( generator.save_profiles(
profiles=profiles, profiles=profiles,
file_path=os.path.join(sim_dir, "twitter_profiles.csv"), file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
@ -381,7 +381,7 @@ class SimulationManager:
total=len(profiles) total=len(profiles)
) )
# ========== 阶段3: LLM智能生成模拟配置 ========== # ========== Stage 3: LLM intelligent simulation configuration generation ==========
if progress_callback: if progress_callback:
progress_callback( progress_callback(
"generating_config", 0, "generating_config", 0,
@ -419,7 +419,7 @@ class SimulationManager:
total=3 total=3
) )
# 保存配置文件 # Save configuration file
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
with open(config_path, 'w', encoding='utf-8') as f: with open(config_path, 'w', encoding='utf-8') as f:
f.write(sim_params.to_json()) f.write(sim_params.to_json())
@ -435,20 +435,20 @@ class SimulationManager:
total=3 total=3
) )
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录 # Note: run scripts remain in backend/scripts/; they are not copied to the simulation directory.
# 启动模拟时simulation_runner 会从 scripts/ 目录运行脚本 # When starting a simulation, simulation_runner runs scripts from the scripts/ directory.
# 更新状态 # Update status
state.status = SimulationStatus.READY state.status = SimulationStatus.READY
self._save_simulation_state(state) self._save_simulation_state(state)
logger.info(f"模拟准备完成: {simulation_id}, " logger.info(f"Simulation preparation complete: {simulation_id}, "
f"entities={state.entities_count}, profiles={state.profiles_count}") f"entities={state.entities_count}, profiles={state.profiles_count}")
return state return state
except Exception as e: except Exception as e:
logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}") logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
state.status = SimulationStatus.FAILED state.status = SimulationStatus.FAILED
@ -457,16 +457,16 @@ class SimulationManager:
raise raise
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]: def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
"""获取模拟状态""" """Get simulation state"""
return self._load_simulation_state(simulation_id) return self._load_simulation_state(simulation_id)
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]: def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
"""列出所有模拟""" """List all simulations"""
simulations = [] simulations = []
if os.path.exists(self.SIMULATION_DATA_DIR): if os.path.exists(self.SIMULATION_DATA_DIR):
for sim_id in os.listdir(self.SIMULATION_DATA_DIR): for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
# 跳过隐藏文件(如 .DS_Store和非目录文件 # Skip hidden files (e.g. .DS_Store) and non-directory entries
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id) sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
if sim_id.startswith('.') or not os.path.isdir(sim_path): if sim_id.startswith('.') or not os.path.isdir(sim_path):
continue continue
@ -479,10 +479,10 @@ class SimulationManager:
return simulations return simulations
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]: def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
"""获取模拟的Agent Profile""" """Get agent profiles for a simulation"""
state = self._load_simulation_state(simulation_id) state = self._load_simulation_state(simulation_id)
if not state: if not state:
raise ValueError(f"模拟不存在: {simulation_id}") raise ValueError(f"Simulation not found: {simulation_id}")
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json") profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
@ -494,7 +494,7 @@ class SimulationManager:
return json.load(f) return json.load(f)
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]: def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
"""获取模拟配置""" """Get simulation configuration"""
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
@ -505,7 +505,7 @@ class SimulationManager:
return json.load(f) return json.load(f)
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]: def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
"""获取运行说明""" """Get run instructions"""
sim_dir = self._get_simulation_dir(simulation_id) sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json") config_path = os.path.join(sim_dir, "simulation_config.json")
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts')) scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
@ -520,10 +520,10 @@ class SimulationManager:
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}", "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
}, },
"instructions": ( "instructions": (
f"1. 激活conda环境: conda activate MiroFish\n" f"1. Activate conda environment: conda activate MiroFish\n"
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n" f"2. Run simulation (scripts located at {scripts_dir}):\n"
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n" f" - Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n" f" - Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}" f" - Both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
) )
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
""" """
文本处理服务 Text processing service
""" """
from typing import List, Optional from typing import List, Optional
@ -7,11 +7,11 @@ from ..utils.file_parser import FileParser, split_text_into_chunks
class TextProcessor: class TextProcessor:
"""文本处理器""" """Text processor"""
@staticmethod @staticmethod
def extract_from_files(file_paths: List[str]) -> str: def extract_from_files(file_paths: List[str]) -> str:
"""从多个文件提取文本""" """Extract text from multiple files"""
return FileParser.extract_from_multiple(file_paths) return FileParser.extract_from_multiple(file_paths)
@staticmethod @staticmethod
@ -21,40 +21,40 @@ class TextProcessor:
overlap: int = 50 overlap: int = 50
) -> List[str]: ) -> List[str]:
""" """
分割文本 Split text into chunks.
Args: Args:
text: 原始文本 text: raw text
chunk_size: 块大小 chunk_size: chunk size
overlap: 重叠大小 overlap: overlap size
Returns: Returns:
文本块列表 list of text chunks
""" """
return split_text_into_chunks(text, chunk_size, overlap) return split_text_into_chunks(text, chunk_size, overlap)
@staticmethod @staticmethod
def preprocess_text(text: str) -> str: def preprocess_text(text: str) -> str:
""" """
预处理文本 Preprocess text:
- 移除多余空白 - Remove excess whitespace
- 标准化换行 - Normalize line endings
Args: Args:
text: 原始文本 text: raw text
Returns: Returns:
处理后的文本 processed text
""" """
import re import re
# 标准化换行 # Normalize line endings
text = text.replace('\r\n', '\n').replace('\r', '\n') text = text.replace('\r\n', '\n').replace('\r', '\n')
# 移除连续空行(保留最多两个换行) # Remove consecutive blank lines (keep at most two newlines)
text = re.sub(r'\n{3,}', '\n\n', text) text = re.sub(r'\n{3,}', '\n\n', text)
# 移除行首行尾空白 # Strip leading/trailing whitespace from each line
lines = [line.strip() for line in text.split('\n')] lines = [line.strip() for line in text.split('\n')]
text = '\n'.join(lines) text = '\n'.join(lines)
@ -62,7 +62,7 @@ class TextProcessor:
@staticmethod @staticmethod
def get_text_stats(text: str) -> dict: def get_text_stats(text: str) -> dict:
"""获取文本统计信息""" """Get text statistics"""
return { return {
"total_chars": len(text), "total_chars": len(text),
"total_lines": text.count('\n') + 1, "total_lines": text.count('\n') + 1,

View File

@ -1,6 +1,6 @@
""" """
Zep实体读取与过滤服务 Zep entity read and filter service
从Zep图谱中读取节点筛选出符合预定义实体类型的节点 Reads nodes from the Zep graph and filters out nodes that match predefined entity types
""" """
import time import time
@ -15,21 +15,21 @@ from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
logger = get_logger('mirofish.zep_entity_reader') logger = get_logger('mirofish.zep_entity_reader')
# 用于泛型返回类型 # Generic return type
T = TypeVar('T') T = TypeVar('T')
@dataclass @dataclass
class EntityNode: class EntityNode:
"""实体节点数据结构""" """Entity node data structure"""
uuid: str uuid: str
name: str name: str
labels: List[str] labels: List[str]
summary: str summary: str
attributes: Dict[str, Any] attributes: Dict[str, Any]
# 相关的边信息 # Related edge info
related_edges: List[Dict[str, Any]] = field(default_factory=list) related_edges: List[Dict[str, Any]] = field(default_factory=list)
# 相关的其他节点信息 # Related node info
related_nodes: List[Dict[str, Any]] = field(default_factory=list) related_nodes: List[Dict[str, Any]] = field(default_factory=list)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> Dict[str, Any]:
@ -44,7 +44,7 @@ class EntityNode:
} }
def get_entity_type(self) -> Optional[str]: def get_entity_type(self) -> Optional[str]:
"""获取实体类型排除默认的Entity标签""" """Get entity type (excluding the default Entity label)"""
for label in self.labels: for label in self.labels:
if label not in ["Entity", "Node"]: if label not in ["Entity", "Node"]:
return label return label
@ -53,7 +53,7 @@ class EntityNode:
@dataclass @dataclass
class FilteredEntities: class FilteredEntities:
"""过滤后的实体集合""" """Filtered entity collection"""
entities: List[EntityNode] entities: List[EntityNode]
entity_types: Set[str] entity_types: Set[str]
total_count: int total_count: int
@ -70,18 +70,18 @@ class FilteredEntities:
class ZepEntityReader: class ZepEntityReader:
""" """
Zep实体读取与过滤服务 Zep entity read and filter service
主要功能 Main features:
1. 从Zep图谱读取所有节点 1. Read all nodes from the Zep graph
2. 筛选出符合预定义实体类型的节点Labels不只是Entity的节点 2. Filter out nodes matching predefined entity types (nodes with labels beyond just "Entity")
3. 获取每个实体的相关边和关联节点信息 3. Fetch related edges and associated node info for each entity
""" """
def __init__(self, api_key: Optional[str] = None): def __init__(self, api_key: Optional[str] = None):
self.api_key = api_key or Config.ZEP_API_KEY self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key: if not self.api_key:
raise ValueError("ZEP_API_KEY 未配置") raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key) self.client = Zep(api_key=self.api_key)
@ -93,16 +93,16 @@ class ZepEntityReader:
initial_delay: float = 2.0 initial_delay: float = 2.0
) -> T: ) -> T:
""" """
带重试机制的Zep API调用 Zep API call with retry logic
Args: Args:
func: 要执行的函数无参数的lambda或callable func: function to execute (a lambda or callable with no arguments)
operation_name: 操作名称用于日志 operation_name: operation name for logging
max_retries: 最大重试次数默认3次即最多尝试3次 max_retries: maximum number of retries (default 3, meaning up to 3 attempts total)
initial_delay: 初始延迟秒数 initial_delay: initial delay in seconds
Returns: Returns:
API调用结果 API call result
""" """
last_exception = None last_exception = None
delay = initial_delay delay = initial_delay
@ -114,27 +114,27 @@ class ZepEntityReader:
last_exception = e last_exception = e
if attempt < max_retries - 1: if attempt < max_retries - 1:
logger.warning( logger.warning(
f"Zep {operation_name} {attempt + 1} 次尝试失败: {str(e)[:100]}, " f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, "
f"{delay:.1f}秒后重试..." f"retrying in {delay:.1f}s..."
) )
time.sleep(delay) time.sleep(delay)
delay *= 2 # 指数退避 delay *= 2 # Exponential backoff
else: else:
logger.error(f"Zep {operation_name} {max_retries} 次尝试后仍失败: {str(e)}") logger.error(f"Zep {operation_name} still failing after {max_retries} attempts: {str(e)}")
raise last_exception raise last_exception
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
""" """
获取图谱的所有节点分页获取 Get all nodes in the graph (paginated)
Args: Args:
graph_id: 图谱ID graph_id: graph ID
Returns: Returns:
节点列表 Node list
""" """
logger.info(f"获取图谱 {graph_id} 的所有节点...") logger.info(f"Fetching all nodes for graph {graph_id}...")
nodes = fetch_all_nodes(self.client, graph_id) nodes = fetch_all_nodes(self.client, graph_id)
@ -148,20 +148,20 @@ class ZepEntityReader:
"attributes": node.attributes or {}, "attributes": node.attributes or {},
}) })
logger.info(f"共获取 {len(nodes_data)} 个节点") logger.info(f"Fetched {len(nodes_data)} nodes")
return nodes_data return nodes_data
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]: def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
""" """
获取图谱的所有边分页获取 Get all edges in the graph (paginated)
Args: Args:
graph_id: 图谱ID graph_id: graph ID
Returns: Returns:
边列表 Edge list
""" """
logger.info(f"获取图谱 {graph_id} 的所有边...") logger.info(f"Fetching all edges for graph {graph_id}...")
edges = fetch_all_edges(self.client, graph_id) edges = fetch_all_edges(self.client, graph_id)
@ -176,24 +176,24 @@ class ZepEntityReader:
"attributes": edge.attributes or {}, "attributes": edge.attributes or {},
}) })
logger.info(f"共获取 {len(edges_data)} 条边") logger.info(f"Fetched {len(edges_data)} edges")
return edges_data return edges_data
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]: def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
""" """
获取指定节点的所有相关边带重试机制 Get all edges related to the specified node (with retry logic)
Args: Args:
node_uuid: 节点UUID node_uuid: node UUID
Returns: Returns:
边列表 Edge list
""" """
try: try:
# 使用重试机制调用Zep API # Call Zep API with retry
edges = self._call_with_retry( edges = self._call_with_retry(
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid), func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
operation_name=f"获取节点边(node={node_uuid[:8]}...)" operation_name=f"get node edges (node={node_uuid[:8]}...)"
) )
edges_data = [] edges_data = []
@ -209,7 +209,7 @@ class ZepEntityReader:
return edges_data return edges_data
except Exception as e: except Exception as e:
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}") logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
return [] return []
def filter_defined_entities( def filter_defined_entities(
@ -219,47 +219,47 @@ class ZepEntityReader:
enrich_with_edges: bool = True enrich_with_edges: bool = True
) -> FilteredEntities: ) -> FilteredEntities:
""" """
筛选出符合预定义实体类型的节点 Filter out nodes that match predefined entity types
筛选逻辑 Filter logic:
- 如果节点的Labels只有一个"Entity"说明这个实体不符合我们预定义的类型跳过 - If a node's Labels contain only "Entity", it does not match our predefined types; skip it
- 如果节点的Labels包含除"Entity""Node"之外的标签说明符合预定义类型保留 - If a node's Labels contain labels other than "Entity" and "Node", it matches a predefined type; keep it
Args: Args:
graph_id: 图谱ID graph_id: graph ID
defined_entity_types: 预定义的实体类型列表可选如果提供则只保留这些类型 defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept)
enrich_with_edges: 是否获取每个实体的相关边信息 enrich_with_edges: whether to fetch related edge info for each entity
Returns: Returns:
FilteredEntities: 过滤后的实体集合 FilteredEntities: filtered entity collection
""" """
logger.info(f"开始筛选图谱 {graph_id} 的实体...") logger.info(f"Starting entity filtering for graph {graph_id}...")
# 获取所有节点 # Get all nodes
all_nodes = self.get_all_nodes(graph_id) all_nodes = self.get_all_nodes(graph_id)
total_count = len(all_nodes) total_count = len(all_nodes)
# 获取所有边(用于后续关联查找) # Get all edges (for relation lookup)
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else [] all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
# 构建节点UUID到节点数据的映射 # Build UUID-to-node mapping
node_map = {n["uuid"]: n for n in all_nodes} node_map = {n["uuid"]: n for n in all_nodes}
# 筛选符合条件的实体 # Filter matching entities
filtered_entities = [] filtered_entities = []
entity_types_found = set() entity_types_found = set()
for node in all_nodes: for node in all_nodes:
labels = node.get("labels", []) labels = node.get("labels", [])
# 筛选逻辑Labels必须包含除"Entity"和"Node"之外的标签 # Filter logic: Labels must contain at least one label other than "Entity" and "Node"
custom_labels = [l for l in labels if l not in ["Entity", "Node"]] custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
if not custom_labels: if not custom_labels:
# 只有默认标签,跳过 # Only default labels; skip
continue continue
# 如果指定了预定义类型,检查是否匹配 # If predefined types are specified, check for a match
if defined_entity_types: if defined_entity_types:
matching_labels = [l for l in custom_labels if l in defined_entity_types] matching_labels = [l for l in custom_labels if l in defined_entity_types]
if not matching_labels: if not matching_labels:
@ -270,7 +270,7 @@ class ZepEntityReader:
entity_types_found.add(entity_type) entity_types_found.add(entity_type)
# 创建实体节点对象 # Create entity node object
entity = EntityNode( entity = EntityNode(
uuid=node["uuid"], uuid=node["uuid"],
name=node["name"], name=node["name"],
@ -279,7 +279,7 @@ class ZepEntityReader:
attributes=node["attributes"], attributes=node["attributes"],
) )
# 获取相关边和节点 # Fetch related edges and nodes
if enrich_with_edges: if enrich_with_edges:
related_edges = [] related_edges = []
related_node_uuids = set() related_node_uuids = set()
@ -304,7 +304,7 @@ class ZepEntityReader:
entity.related_edges = related_edges entity.related_edges = related_edges
# 获取关联节点的基本信息 # Fetch basic info for related nodes
related_nodes = [] related_nodes = []
for related_uuid in related_node_uuids: for related_uuid in related_node_uuids:
if related_uuid in node_map: if related_uuid in node_map:
@ -320,8 +320,8 @@ class ZepEntityReader:
filtered_entities.append(entity) filtered_entities.append(entity)
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, " logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, "
f"实体类型: {entity_types_found}") f"entity types: {entity_types_found}")
return FilteredEntities( return FilteredEntities(
entities=filtered_entities, entities=filtered_entities,
@ -336,33 +336,33 @@ class ZepEntityReader:
entity_uuid: str entity_uuid: str
) -> Optional[EntityNode]: ) -> Optional[EntityNode]:
""" """
获取单个实体及其完整上下文边和关联节点带重试机制 Get a single entity and its full context (edges and related nodes, with retry)
Args: Args:
graph_id: 图谱ID graph_id: graph ID
entity_uuid: 实体UUID entity_uuid: entity UUID
Returns: Returns:
EntityNodeNone EntityNode or None
""" """
try: try:
# 使用重试机制获取节点 # Get the node with retry
node = self._call_with_retry( node = self._call_with_retry(
func=lambda: self.client.graph.node.get(uuid_=entity_uuid), func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)" operation_name=f"get node detail (uuid={entity_uuid[:8]}...)"
) )
if not node: if not node:
return None return None
# 获取节点的边 # Get the node's edges
edges = self.get_node_edges(entity_uuid) edges = self.get_node_edges(entity_uuid)
# 获取所有节点用于关联查找 # Get all nodes for relation lookup
all_nodes = self.get_all_nodes(graph_id) all_nodes = self.get_all_nodes(graph_id)
node_map = {n["uuid"]: n for n in all_nodes} node_map = {n["uuid"]: n for n in all_nodes}
# 处理相关边和节点 # Process related edges and nodes
related_edges = [] related_edges = []
related_node_uuids = set() related_node_uuids = set()
@ -384,7 +384,7 @@ class ZepEntityReader:
}) })
related_node_uuids.add(edge["source_node_uuid"]) related_node_uuids.add(edge["source_node_uuid"])
# 获取关联节点信息 # Fetch related node info
related_nodes = [] related_nodes = []
for related_uuid in related_node_uuids: for related_uuid in related_node_uuids:
if related_uuid in node_map: if related_uuid in node_map:
@ -407,7 +407,7 @@ class ZepEntityReader:
) )
except Exception as e: except Exception as e:
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}") logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
return None return None
def get_entities_by_type( def get_entities_by_type(
@ -417,15 +417,15 @@ class ZepEntityReader:
enrich_with_edges: bool = True enrich_with_edges: bool = True
) -> List[EntityNode]: ) -> List[EntityNode]:
""" """
获取指定类型的所有实体 Get all entities of a specified type
Args: Args:
graph_id: 图谱ID graph_id: graph ID
entity_type: 实体类型 "Student", "PublicFigure" entity_type: entity type (e.g. "Student", "PublicFigure")
enrich_with_edges: 是否获取相关边信息 enrich_with_edges: whether to fetch related edge info
Returns: Returns:
实体列表 Entity list
""" """
result = self.filter_defined_entities( result = self.filter_defined_entities(
graph_id=graph_id, graph_id=graph_id,
@ -433,5 +433,3 @@ class ZepEntityReader:
enrich_with_edges=enrich_with_edges enrich_with_edges=enrich_with_edges
) )
return result.entities return result.entities

View File

@ -1,6 +1,6 @@
""" """
Zep图谱记忆更新服务 Zep graph memory update service
将模拟中的Agent活动动态更新到Zep图谱中 Dynamically updates agent activities from the simulation to the Zep graph
""" """
import os import os
@ -23,7 +23,7 @@ logger = get_logger('mirofish.zep_graph_memory_updater')
@dataclass @dataclass
class AgentActivity: class AgentActivity:
"""Agent活动记录""" """Agent activity record"""
platform: str # twitter / reddit platform: str # twitter / reddit
agent_id: int agent_id: int
agent_name: str agent_name: str
@ -34,12 +34,13 @@ class AgentActivity:
def to_episode_text(self) -> str: def to_episode_text(self) -> str:
""" """
将活动转换为可以发送给Zep的文本描述 Convert the activity to a text description suitable for sending to Zep
采用自然语言描述格式让Zep能够从中提取实体和关系 Uses a natural-language description format so Zep can extract entities and
不添加模拟相关的前缀避免误导图谱更新 relationships. No simulation-specific prefix is added to avoid misleading
graph updates.
""" """
# 根据不同的动作类型生成不同的描述 # Generate a description based on the action type
action_descriptions = { action_descriptions = {
"CREATE_POST": self._describe_create_post, "CREATE_POST": self._describe_create_post,
"LIKE_POST": self._describe_like_post, "LIKE_POST": self._describe_like_post,
@ -58,222 +59,223 @@ class AgentActivity:
describe_func = action_descriptions.get(self.action_type, self._describe_generic) describe_func = action_descriptions.get(self.action_type, self._describe_generic)
description = describe_func() description = describe_func()
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀 # Return "agent_name: activity description" format without a simulation prefix
return f"{self.agent_name}: {description}" return f"{self.agent_name}: {description}"
def _describe_create_post(self) -> str: def _describe_create_post(self) -> str:
content = self.action_args.get("content", "") content = self.action_args.get("content", "")
if content: if content:
return f"发布了一条帖子:「{content}" return f'posted: "{content}"'
return "发布了一条帖子" return "created a post"
def _describe_like_post(self) -> str: def _describe_like_post(self) -> str:
"""点赞帖子 - 包含帖子原文和作者信息""" """Like a post — includes post content and author info"""
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
if post_content and post_author: if post_content and post_author:
return f"点赞了{post_author}的帖子:「{post_content}" return f'liked {post_author}\'s post: "{post_content}"'
elif post_content: elif post_content:
return f"点赞了一条帖子:「{post_content}" return f'liked a post: "{post_content}"'
elif post_author: elif post_author:
return f"点赞了{post_author}的一条帖子" return f"liked a post by {post_author}"
return "点赞了一条帖子" return "liked a post"
def _describe_dislike_post(self) -> str: def _describe_dislike_post(self) -> str:
"""踩帖子 - 包含帖子原文和作者信息""" """Dislike a post — includes post content and author info"""
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
if post_content and post_author: if post_content and post_author:
return f"踩了{post_author}的帖子:「{post_content}" return f'disliked {post_author}\'s post: "{post_content}"'
elif post_content: elif post_content:
return f"踩了一条帖子:「{post_content}" return f'disliked a post: "{post_content}"'
elif post_author: elif post_author:
return f"踩了{post_author}的一条帖子" return f"disliked a post by {post_author}"
return "踩了一条帖子" return "disliked a post"
def _describe_repost(self) -> str: def _describe_repost(self) -> str:
"""转发帖子 - 包含原帖内容和作者信息""" """Repost — includes original post content and author info"""
original_content = self.action_args.get("original_content", "") original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "") original_author = self.action_args.get("original_author_name", "")
if original_content and original_author: if original_content and original_author:
return f"转发了{original_author}的帖子:「{original_content}" return f'reposted {original_author}\'s post: "{original_content}"'
elif original_content: elif original_content:
return f"转发了一条帖子:「{original_content}" return f'reposted: "{original_content}"'
elif original_author: elif original_author:
return f"转发了{original_author}的一条帖子" return f"reposted a post by {original_author}"
return "转发了一条帖子" return "reposted a post"
def _describe_quote_post(self) -> str: def _describe_quote_post(self) -> str:
"""引用帖子 - 包含原帖内容、作者信息和引用评论""" """Quote post — includes original post content, author info, and quote comment"""
original_content = self.action_args.get("original_content", "") original_content = self.action_args.get("original_content", "")
original_author = self.action_args.get("original_author_name", "") original_author = self.action_args.get("original_author_name", "")
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "") quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
base = "" base = ""
if original_content and original_author: if original_content and original_author:
base = f"引用了{original_author}的帖子「{original_content}" base = f'quoted {original_author}\'s post "{original_content}"'
elif original_content: elif original_content:
base = f"引用了一条帖子「{original_content}" base = f'quoted a post: "{original_content}"'
elif original_author: elif original_author:
base = f"引用了{original_author}的一条帖子" base = f"quoted a post by {original_author}"
else: else:
base = "引用了一条帖子" base = "quoted a post"
if quote_content: if quote_content:
base += f",并评论道:「{quote_content}" base += f' with comment: "{quote_content}"'
return base return base
def _describe_follow(self) -> str: def _describe_follow(self) -> str:
"""关注用户 - 包含被关注用户的名称""" """Follow a user — includes the followed user's name"""
target_user_name = self.action_args.get("target_user_name", "") target_user_name = self.action_args.get("target_user_name", "")
if target_user_name: if target_user_name:
return f"关注了用户「{target_user_name}" return f'followed user "{target_user_name}"'
return "关注了一个用户" return "followed a user"
def _describe_create_comment(self) -> str: def _describe_create_comment(self) -> str:
"""发表评论 - 包含评论内容和所评论的帖子信息""" """Create a comment — includes comment content and the post being commented on"""
content = self.action_args.get("content", "") content = self.action_args.get("content", "")
post_content = self.action_args.get("post_content", "") post_content = self.action_args.get("post_content", "")
post_author = self.action_args.get("post_author_name", "") post_author = self.action_args.get("post_author_name", "")
if content: if content:
if post_content and post_author: if post_content and post_author:
return f"{post_author}的帖子「{post_content}」下评论道:「{content}" return f'commented on {post_author}\'s post "{post_content}": "{content}"'
elif post_content: elif post_content:
return f"在帖子「{post_content}」下评论道:「{content}" return f'commented on post "{post_content}": "{content}"'
elif post_author: elif post_author:
return f"{post_author}的帖子下评论道:「{content}" return f'commented on {post_author}\'s post: "{content}"'
return f"评论道:「{content}" return f'commented: "{content}"'
return "发表了评论" return "posted a comment"
def _describe_like_comment(self) -> str: def _describe_like_comment(self) -> str:
"""点赞评论 - 包含评论内容和作者信息""" """Like a comment — includes comment content and author info"""
comment_content = self.action_args.get("comment_content", "") comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "") comment_author = self.action_args.get("comment_author_name", "")
if comment_content and comment_author: if comment_content and comment_author:
return f"点赞了{comment_author}的评论:「{comment_content}" return f'liked {comment_author}\'s comment: "{comment_content}"'
elif comment_content: elif comment_content:
return f"点赞了一条评论:「{comment_content}" return f'liked a comment: "{comment_content}"'
elif comment_author: elif comment_author:
return f"点赞了{comment_author}的一条评论" return f"liked a comment by {comment_author}"
return "点赞了一条评论" return "liked a comment"
def _describe_dislike_comment(self) -> str: def _describe_dislike_comment(self) -> str:
"""踩评论 - 包含评论内容和作者信息""" """Dislike a comment — includes comment content and author info"""
comment_content = self.action_args.get("comment_content", "") comment_content = self.action_args.get("comment_content", "")
comment_author = self.action_args.get("comment_author_name", "") comment_author = self.action_args.get("comment_author_name", "")
if comment_content and comment_author: if comment_content and comment_author:
return f"踩了{comment_author}的评论:「{comment_content}" return f'disliked {comment_author}\'s comment: "{comment_content}"'
elif comment_content: elif comment_content:
return f"踩了一条评论:「{comment_content}" return f'disliked a comment: "{comment_content}"'
elif comment_author: elif comment_author:
return f"踩了{comment_author}的一条评论" return f"disliked a comment by {comment_author}"
return "踩了一条评论" return "disliked a comment"
def _describe_search(self) -> str: def _describe_search(self) -> str:
"""搜索帖子 - 包含搜索关键词""" """Search posts — includes search keyword"""
query = self.action_args.get("query", "") or self.action_args.get("keyword", "") query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
return f"搜索了「{query}" if query else "进行了搜索" return f'searched for "{query}"' if query else "performed a search"
def _describe_search_user(self) -> str: def _describe_search_user(self) -> str:
"""搜索用户 - 包含搜索关键词""" """Search users — includes search keyword"""
query = self.action_args.get("query", "") or self.action_args.get("username", "") query = self.action_args.get("query", "") or self.action_args.get("username", "")
return f"搜索了用户「{query}" if query else "搜索了用户" return f'searched for user "{query}"' if query else "searched for a user"
def _describe_mute(self) -> str: def _describe_mute(self) -> str:
"""屏蔽用户 - 包含被屏蔽用户的名称""" """Mute a user — includes the muted user's name"""
target_user_name = self.action_args.get("target_user_name", "") target_user_name = self.action_args.get("target_user_name", "")
if target_user_name: if target_user_name:
return f"屏蔽了用户「{target_user_name}" return f'muted user "{target_user_name}"'
return "屏蔽了一个用户" return "muted a user"
def _describe_generic(self) -> str: def _describe_generic(self) -> str:
# 对于未知的动作类型,生成通用描述 # Generic description for unknown action types
return f"执行了{self.action_type}操作" return f"performed action: {self.action_type}"
class ZepGraphMemoryUpdater: class ZepGraphMemoryUpdater:
""" """
Zep图谱记忆更新器 Zep graph memory updater
监控模拟的actions日志文件将新的agent活动实时更新到Zep图谱中 Monitors the simulation's actions log file and updates new agent activities
按平台分组每累积BATCH_SIZE条活动后批量发送到Zep to the Zep graph in real time. Activities are grouped by platform; each platform
batches up to BATCH_SIZE activities before sending them to Zep.
所有有意义的行为都会被更新到Zepaction_args中会包含完整的上下文信息 All meaningful actions are updated to Zep. action_args contains full context:
- 点赞/踩的帖子原文 - Original post content for likes/dislikes
- 转发/引用的帖子原文 - Original post content for reposts/quotes
- 关注/屏蔽的用户名 - Usernames for follows/mutes
- 点赞/踩的评论原文 - Original comment content for comment likes/dislikes
""" """
# 批量发送大小(每个平台累积多少条后发送) # Batch send size (activities per platform before sending)
BATCH_SIZE = 5 BATCH_SIZE = 5
# 平台名称映射(用于控制台显示) # Platform display names
PLATFORM_DISPLAY_NAMES = { PLATFORM_DISPLAY_NAMES = {
'twitter': '世界1', 'twitter': 'World 1',
'reddit': '世界2', 'reddit': 'World 2',
} }
# 发送间隔(秒),避免请求过快 # Send interval (seconds) to avoid sending too fast
SEND_INTERVAL = 0.5 SEND_INTERVAL = 0.5
# 重试配置 # Retry config
MAX_RETRIES = 3 MAX_RETRIES = 3
RETRY_DELAY = 2 # RETRY_DELAY = 2 # seconds
def __init__(self, graph_id: str, api_key: Optional[str] = None): def __init__(self, graph_id: str, api_key: Optional[str] = None):
""" """
初始化更新器 Initialize the updater
Args: Args:
graph_id: Zep图谱ID graph_id: Zep graph ID
api_key: Zep API Key可选默认从配置读取 api_key: Zep API key (optional; defaults to config value)
""" """
self.graph_id = graph_id self.graph_id = graph_id
self.api_key = api_key or Config.ZEP_API_KEY self.api_key = api_key or Config.ZEP_API_KEY
if not self.api_key: if not self.api_key:
raise ValueError("ZEP_API_KEY未配置") raise ValueError("ZEP_API_KEY is not configured")
self.client = Zep(api_key=self.api_key) self.client = Zep(api_key=self.api_key)
# 活动队列 # Activity queue
self._activity_queue: Queue = Queue() self._activity_queue: Queue = Queue()
# 按平台分组的活动缓冲区每个平台各自累积到BATCH_SIZE后批量发送 # Per-platform activity buffers (each platform accumulates to BATCH_SIZE before batch sending)
self._platform_buffers: Dict[str, List[AgentActivity]] = { self._platform_buffers: Dict[str, List[AgentActivity]] = {
'twitter': [], 'twitter': [],
'reddit': [], 'reddit': [],
} }
self._buffer_lock = threading.Lock() self._buffer_lock = threading.Lock()
# 控制标志 # Control flags
self._running = False self._running = False
self._worker_thread: Optional[threading.Thread] = None self._worker_thread: Optional[threading.Thread] = None
# 统计 # Statistics
self._total_activities = 0 # 实际添加到队列的活动数 self._total_activities = 0 # Activities added to queue
self._total_sent = 0 # 成功发送到Zep的批次数 self._total_sent = 0 # Batches successfully sent to Zep
self._total_items_sent = 0 # 成功发送到Zep的活动条数 self._total_items_sent = 0 # Individual activities successfully sent to Zep
self._failed_count = 0 # 发送失败的批次数 self._failed_count = 0 # Batches that failed to send
self._skipped_count = 0 # 被过滤跳过的活动数DO_NOTHING self._skipped_count = 0 # Activities filtered out (DO_NOTHING)
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}") logger.info(f"ZepGraphMemoryUpdater initialized: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
def _get_platform_display_name(self, platform: str) -> str: def _get_platform_display_name(self, platform: str) -> str:
"""获取平台的显示名称""" """Get the display name for a platform"""
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform) return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
def start(self): def start(self):
"""启动后台工作线程""" """Start the background worker thread"""
if self._running: if self._running:
return return
@ -288,19 +290,19 @@ class ZepGraphMemoryUpdater:
name=f"ZepMemoryUpdater-{self.graph_id[:8]}" name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
) )
self._worker_thread.start() self._worker_thread.start()
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}") logger.info(f"ZepGraphMemoryUpdater started: graph_id={self.graph_id}")
def stop(self): def stop(self):
"""停止后台工作线程""" """Stop the background worker thread"""
self._running = False self._running = False
# 发送剩余的活动 # Send remaining activities
self._flush_remaining() self._flush_remaining()
if self._worker_thread and self._worker_thread.is_alive(): if self._worker_thread and self._worker_thread.is_alive():
self._worker_thread.join(timeout=10) self._worker_thread.join(timeout=10)
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, " logger.info(f"ZepGraphMemoryUpdater stopped: graph_id={self.graph_id}, "
f"total_activities={self._total_activities}, " f"total_activities={self._total_activities}, "
f"batches_sent={self._total_sent}, " f"batches_sent={self._total_sent}, "
f"items_sent={self._total_items_sent}, " f"items_sent={self._total_items_sent}, "
@ -309,43 +311,43 @@ class ZepGraphMemoryUpdater:
def add_activity(self, activity: AgentActivity): def add_activity(self, activity: AgentActivity):
""" """
添加一个agent活动到队列 Add an agent activity to the queue
所有有意义的行为都会被添加到队列包括 All meaningful actions are added to the queue, including:
- CREATE_POST发帖 - CREATE_POST
- CREATE_COMMENT评论 - CREATE_COMMENT
- QUOTE_POST引用帖子 - QUOTE_POST
- SEARCH_POSTS搜索帖子 - SEARCH_POSTS
- SEARCH_USER搜索用户 - SEARCH_USER
- LIKE_POST/DISLIKE_POST点赞/踩帖子 - LIKE_POST/DISLIKE_POST
- REPOST转发 - REPOST
- FOLLOW关注 - FOLLOW
- MUTE屏蔽 - MUTE
- LIKE_COMMENT/DISLIKE_COMMENT点赞/踩评论 - LIKE_COMMENT/DISLIKE_COMMENT
action_args中会包含完整的上下文信息如帖子原文用户名等 action_args contains full context (e.g. post content, usernames, etc.).
Args: Args:
activity: Agent活动记录 activity: agent activity record
""" """
# 跳过DO_NOTHING类型的活动 # Skip DO_NOTHING activities
if activity.action_type == "DO_NOTHING": if activity.action_type == "DO_NOTHING":
self._skipped_count += 1 self._skipped_count += 1
return return
self._activity_queue.put(activity) self._activity_queue.put(activity)
self._total_activities += 1 self._total_activities += 1
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}") logger.debug(f"Added activity to Zep queue: {activity.agent_name} - {activity.action_type}")
def add_activity_from_dict(self, data: Dict[str, Any], platform: str): def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
""" """
从字典数据添加活动 Add an activity from a dictionary
Args: Args:
data: 从actions.jsonl解析的字典数据 data: dict parsed from actions.jsonl
platform: 平台名称 (twitter/reddit) platform: platform name (twitter/reddit)
""" """
# 跳过事件类型的条目 # Skip event-type entries
if "event_type" in data: if "event_type" in data:
return return
@ -362,53 +364,53 @@ class ZepGraphMemoryUpdater:
self.add_activity(activity) self.add_activity(activity)
def _worker_loop(self, locale: str = 'zh'): def _worker_loop(self, locale: str = 'zh'):
"""后台工作循环 - 按平台批量发送活动到Zep""" """Background worker loop — batch-sends activities to Zep per platform"""
set_locale(locale) set_locale(locale)
while self._running or not self._activity_queue.empty(): while self._running or not self._activity_queue.empty():
try: try:
# 尝试从队列获取活动超时1秒 # Try to get an activity from the queue (1 second timeout)
try: try:
activity = self._activity_queue.get(timeout=1) activity = self._activity_queue.get(timeout=1)
# 将活动添加到对应平台的缓冲区 # Add activity to the corresponding platform buffer
platform = activity.platform.lower() platform = activity.platform.lower()
with self._buffer_lock: with self._buffer_lock:
if platform not in self._platform_buffers: if platform not in self._platform_buffers:
self._platform_buffers[platform] = [] self._platform_buffers[platform] = []
self._platform_buffers[platform].append(activity) self._platform_buffers[platform].append(activity)
# 检查该平台是否达到批量大小 # Check if this platform has reached the batch size
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE: if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
batch = self._platform_buffers[platform][:self.BATCH_SIZE] batch = self._platform_buffers[platform][:self.BATCH_SIZE]
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:] self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
# 释放锁后再发送 # Release lock before sending
self._send_batch_activities(batch, platform) self._send_batch_activities(batch, platform)
# 发送间隔,避免请求过快 # Throttle to avoid sending too fast
time.sleep(self.SEND_INTERVAL) time.sleep(self.SEND_INTERVAL)
except Empty: except Empty:
pass pass
except Exception as e: except Exception as e:
logger.error(f"工作循环异常: {e}") logger.error(f"Worker loop exception: {e}")
time.sleep(1) time.sleep(1)
def _send_batch_activities(self, activities: List[AgentActivity], platform: str): def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
""" """
批量发送活动到Zep图谱合并为一条文本 Batch-send activities to the Zep graph (merged into a single text block)
Args: Args:
activities: Agent活动列表 activities: list of agent activities
platform: 平台名称 platform: platform name
""" """
if not activities: if not activities:
return return
# 将多条活动合并为一条文本,用换行分隔 # Merge multiple activities into a single text, separated by newlines
episode_texts = [activity.to_episode_text() for activity in activities] episode_texts = [activity.to_episode_text() for activity in activities]
combined_text = "\n".join(episode_texts) combined_text = "\n".join(episode_texts)
# 带重试的发送 # Send with retry
for attempt in range(self.MAX_RETRIES): for attempt in range(self.MAX_RETRIES):
try: try:
self.client.graph.add( self.client.graph.add(
@ -420,21 +422,21 @@ class ZepGraphMemoryUpdater:
self._total_sent += 1 self._total_sent += 1
self._total_items_sent += len(activities) self._total_items_sent += len(activities)
display_name = self._get_platform_display_name(platform) display_name = self._get_platform_display_name(platform)
logger.info(f"成功批量发送 {len(activities)}{display_name}活动到图谱 {self.graph_id}") logger.info(f"Successfully sent batch of {len(activities)} {display_name} activities to graph {self.graph_id}")
logger.debug(f"批量内容预览: {combined_text[:200]}...") logger.debug(f"Batch content preview: {combined_text[:200]}...")
return return
except Exception as e: except Exception as e:
if attempt < self.MAX_RETRIES - 1: if attempt < self.MAX_RETRIES - 1:
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}") logger.warning(f"Batch send to Zep failed (attempt {attempt + 1}/{self.MAX_RETRIES}): {e}")
time.sleep(self.RETRY_DELAY * (attempt + 1)) time.sleep(self.RETRY_DELAY * (attempt + 1))
else: else:
logger.error(f"批量发送到Zep失败已重试{self.MAX_RETRIES}: {e}") logger.error(f"Batch send to Zep failed after {self.MAX_RETRIES} attempts: {e}")
self._failed_count += 1 self._failed_count += 1
def _flush_remaining(self): def _flush_remaining(self):
"""发送队列和缓冲区中剩余的活动""" """Send remaining activities from the queue and buffers"""
# 首先处理队列中剩余的活动,添加到缓冲区 # First, drain the queue into the buffers
while not self._activity_queue.empty(): while not self._activity_queue.empty():
try: try:
activity = self._activity_queue.get_nowait() activity = self._activity_queue.get_nowait()
@ -446,41 +448,41 @@ class ZepGraphMemoryUpdater:
except Empty: except Empty:
break break
# 然后发送各平台缓冲区中剩余的活动即使不足BATCH_SIZE条 # Then send remaining activities in each platform buffer (even if below BATCH_SIZE)
with self._buffer_lock: with self._buffer_lock:
for platform, buffer in self._platform_buffers.items(): for platform, buffer in self._platform_buffers.items():
if buffer: if buffer:
display_name = self._get_platform_display_name(platform) display_name = self._get_platform_display_name(platform)
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动") logger.info(f"Sending {len(buffer)} remaining {display_name} platform activities")
self._send_batch_activities(buffer, platform) self._send_batch_activities(buffer, platform)
# 清空所有缓冲区 # Clear all buffers
for platform in self._platform_buffers: for platform in self._platform_buffers:
self._platform_buffers[platform] = [] self._platform_buffers[platform] = []
def get_stats(self) -> Dict[str, Any]: def get_stats(self) -> Dict[str, Any]:
"""获取统计信息""" """Get statistics"""
with self._buffer_lock: with self._buffer_lock:
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()} buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
return { return {
"graph_id": self.graph_id, "graph_id": self.graph_id,
"batch_size": self.BATCH_SIZE, "batch_size": self.BATCH_SIZE,
"total_activities": self._total_activities, # 添加到队列的活动总数 "total_activities": self._total_activities, # Total activities added to queue
"batches_sent": self._total_sent, # 成功发送的批次数 "batches_sent": self._total_sent, # Batches successfully sent
"items_sent": self._total_items_sent, # 成功发送的活动条数 "items_sent": self._total_items_sent, # Individual activities successfully sent
"failed_count": self._failed_count, # 发送失败的批次数 "failed_count": self._failed_count, # Batches that failed to send
"skipped_count": self._skipped_count, # 被过滤跳过的活动数DO_NOTHING "skipped_count": self._skipped_count, # Activities filtered out (DO_NOTHING)
"queue_size": self._activity_queue.qsize(), "queue_size": self._activity_queue.qsize(),
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小 "buffer_sizes": buffer_sizes, # Per-platform buffer sizes
"running": self._running, "running": self._running,
} }
class ZepGraphMemoryManager: class ZepGraphMemoryManager:
""" """
管理多个模拟的Zep图谱记忆更新器 Manages Zep graph memory updaters for multiple simulations
每个模拟可以有自己的更新器实例 Each simulation can have its own updater instance
""" """
_updaters: Dict[str, ZepGraphMemoryUpdater] = {} _updaters: Dict[str, ZepGraphMemoryUpdater] = {}
@ -489,17 +491,17 @@ class ZepGraphMemoryManager:
@classmethod @classmethod
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater: def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
""" """
为模拟创建图谱记忆更新器 Create a graph memory updater for a simulation
Args: Args:
simulation_id: 模拟ID simulation_id: simulation ID
graph_id: Zep图谱ID graph_id: Zep graph ID
Returns: Returns:
ZepGraphMemoryUpdater实例 ZepGraphMemoryUpdater instance
""" """
with cls._lock: with cls._lock:
# 如果已存在,先停止旧的 # If one already exists, stop it first
if simulation_id in cls._updaters: if simulation_id in cls._updaters:
cls._updaters[simulation_id].stop() cls._updaters[simulation_id].stop()
@ -507,30 +509,30 @@ class ZepGraphMemoryManager:
updater.start() updater.start()
cls._updaters[simulation_id] = updater cls._updaters[simulation_id] = updater
logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}") logger.info(f"Created graph memory updater: simulation_id={simulation_id}, graph_id={graph_id}")
return updater return updater
@classmethod @classmethod
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]: def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
"""获取模拟的更新器""" """Get the updater for a simulation"""
return cls._updaters.get(simulation_id) return cls._updaters.get(simulation_id)
@classmethod @classmethod
def stop_updater(cls, simulation_id: str): def stop_updater(cls, simulation_id: str):
"""停止并移除模拟的更新器""" """Stop and remove the updater for a simulation"""
with cls._lock: with cls._lock:
if simulation_id in cls._updaters: if simulation_id in cls._updaters:
cls._updaters[simulation_id].stop() cls._updaters[simulation_id].stop()
del cls._updaters[simulation_id] del cls._updaters[simulation_id]
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}") logger.info(f"Stopped graph memory updater: simulation_id={simulation_id}")
# 防止 stop_all 重复调用的标志 # Flag to prevent stop_all from being called more than once
_stop_all_done = False _stop_all_done = False
@classmethod @classmethod
def stop_all(cls): def stop_all(cls):
"""停止所有更新器""" """Stop all updaters"""
# 防止重复调用 # Prevent duplicate calls
if cls._stop_all_done: if cls._stop_all_done:
return return
cls._stop_all_done = True cls._stop_all_done = True
@ -541,13 +543,13 @@ class ZepGraphMemoryManager:
try: try:
updater.stop() updater.stop()
except Exception as e: except Exception as e:
logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}") logger.error(f"Failed to stop updater: simulation_id={simulation_id}, error={e}")
cls._updaters.clear() cls._updaters.clear()
logger.info("已停止所有图谱记忆更新器") logger.info("All graph memory updaters stopped")
@classmethod @classmethod
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]: def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
"""获取所有更新器的统计信息""" """Get statistics for all updaters"""
return { return {
sim_id: updater.get_stats() sim_id: updater.get_stats()
for sim_id, updater in cls._updaters.items() for sim_id, updater in cls._updaters.items()

File diff suppressed because it is too large Load Diff

View File

@ -1,5 +1,5 @@
""" """
工具模块 Utilities module
""" """
from .file_parser import FileParser from .file_parser import FileParser

View File

@ -1,6 +1,6 @@
""" """
文件解析工具 File parsing utilities
支持PDFMarkdownTXT文件的文本提取 Supports text extraction from PDF, Markdown, and TXT files
""" """
import os import os
@ -10,29 +10,29 @@ from typing import List, Optional
def _read_text_with_fallback(file_path: str) -> str: def _read_text_with_fallback(file_path: str) -> str:
""" """
读取文本文件UTF-8失败时自动探测编码 Read a text file, automatically detecting encoding if UTF-8 fails.
采用多级回退策略 Uses a multi-level fallback strategy:
1. 首先尝试 UTF-8 解码 1. First attempts UTF-8 decoding
2. 使用 charset_normalizer 检测编码 2. Uses charset_normalizer to detect encoding
3. 回退到 chardet 检测编码 3. Falls back to chardet for encoding detection
4. 最终使用 UTF-8 + errors='replace' 兜底 4. Final fallback: UTF-8 with errors='replace'
Args: Args:
file_path: 文件路径 file_path: Path to the file
Returns: Returns:
解码后的文本内容 Decoded text content
""" """
data = Path(file_path).read_bytes() data = Path(file_path).read_bytes()
# 首先尝试 UTF-8 # First attempt: UTF-8
try: try:
return data.decode('utf-8') return data.decode('utf-8')
except UnicodeDecodeError: except UnicodeDecodeError:
pass pass
# 尝试使用 charset_normalizer 检测编码 # Attempt encoding detection with charset_normalizer
encoding = None encoding = None
try: try:
from charset_normalizer import from_bytes from charset_normalizer import from_bytes
@ -42,7 +42,7 @@ def _read_text_with_fallback(file_path: str) -> str:
except Exception: except Exception:
pass pass
# 回退到 chardet # Fall back to chardet
if not encoding: if not encoding:
try: try:
import chardet import chardet
@ -51,7 +51,7 @@ def _read_text_with_fallback(file_path: str) -> str:
except Exception: except Exception:
pass pass
# 最终兜底:使用 UTF-8 + replace # Final fallback: UTF-8 with replace
if not encoding: if not encoding:
encoding = 'utf-8' encoding = 'utf-8'
@ -59,30 +59,30 @@ def _read_text_with_fallback(file_path: str) -> str:
class FileParser: class FileParser:
"""文件解析器""" """File parser"""
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'} SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
@classmethod @classmethod
def extract_text(cls, file_path: str) -> str: def extract_text(cls, file_path: str) -> str:
""" """
从文件中提取文本 Extract text from a file
Args: Args:
file_path: 文件路径 file_path: Path to the file
Returns: Returns:
提取的文本内容 Extracted text content
""" """
path = Path(file_path) path = Path(file_path)
if not path.exists(): if not path.exists():
raise FileNotFoundError(f"文件不存在: {file_path}") raise FileNotFoundError(f"File not found: {file_path}")
suffix = path.suffix.lower() suffix = path.suffix.lower()
if suffix not in cls.SUPPORTED_EXTENSIONS: if suffix not in cls.SUPPORTED_EXTENSIONS:
raise ValueError(f"不支持的文件格式: {suffix}") raise ValueError(f"Unsupported file format: {suffix}")
if suffix == '.pdf': if suffix == '.pdf':
return cls._extract_from_pdf(file_path) return cls._extract_from_pdf(file_path)
@ -91,15 +91,15 @@ class FileParser:
elif suffix == '.txt': elif suffix == '.txt':
return cls._extract_from_txt(file_path) return cls._extract_from_txt(file_path)
raise ValueError(f"无法处理的文件格式: {suffix}") raise ValueError(f"Cannot process file format: {suffix}")
@staticmethod @staticmethod
def _extract_from_pdf(file_path: str) -> str: def _extract_from_pdf(file_path: str) -> str:
"""从PDF提取文本""" """Extract text from a PDF file"""
try: try:
import fitz # PyMuPDF import fitz # PyMuPDF
except ImportError: except ImportError:
raise ImportError("需要安装PyMuPDF: pip install PyMuPDF") raise ImportError("PyMuPDF is required: pip install PyMuPDF")
text_parts = [] text_parts = []
with fitz.open(file_path) as doc: with fitz.open(file_path) as doc:
@ -112,24 +112,24 @@ class FileParser:
@staticmethod @staticmethod
def _extract_from_md(file_path: str) -> str: def _extract_from_md(file_path: str) -> str:
"""从Markdown提取文本支持自动编码检测""" """Extract text from a Markdown file with automatic encoding detection"""
return _read_text_with_fallback(file_path) return _read_text_with_fallback(file_path)
@staticmethod @staticmethod
def _extract_from_txt(file_path: str) -> str: def _extract_from_txt(file_path: str) -> str:
"""从TXT提取文本支持自动编码检测""" """Extract text from a TXT file with automatic encoding detection"""
return _read_text_with_fallback(file_path) return _read_text_with_fallback(file_path)
@classmethod @classmethod
def extract_from_multiple(cls, file_paths: List[str]) -> str: def extract_from_multiple(cls, file_paths: List[str]) -> str:
""" """
从多个文件提取文本并合并 Extract text from multiple files and merge the results
Args: Args:
file_paths: 文件路径列表 file_paths: List of file paths
Returns: Returns:
合并后的文本 Merged text content
""" """
all_texts = [] all_texts = []
@ -137,9 +137,9 @@ class FileParser:
try: try:
text = cls.extract_text(file_path) text = cls.extract_text(file_path)
filename = Path(file_path).name filename = Path(file_path).name
all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}") all_texts.append(f"=== Document {i}: {filename} ===\n{text}")
except Exception as e: except Exception as e:
all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===") all_texts.append(f"=== Document {i}: {file_path} (extraction failed: {str(e)}) ===")
return "\n\n".join(all_texts) return "\n\n".join(all_texts)
@ -150,15 +150,15 @@ def split_text_into_chunks(
overlap: int = 50 overlap: int = 50
) -> List[str]: ) -> List[str]:
""" """
将文本分割成小块 Split text into smaller chunks
Args: Args:
text: 原始文本 text: Source text
chunk_size: 每块的字符数 chunk_size: Number of characters per chunk
overlap: 重叠字符数 overlap: Number of overlapping characters between chunks
Returns: Returns:
文本块列表 List of text chunks
""" """
if len(text) <= chunk_size: if len(text) <= chunk_size:
return [text] if text.strip() else [] return [text] if text.strip() else []
@ -169,9 +169,9 @@ def split_text_into_chunks(
while start < len(text): while start < len(text):
end = start + chunk_size end = start + chunk_size
# 尝试在句子边界处分割 # Try to split at sentence boundaries
if end < len(text): if end < len(text):
# 查找最近的句子结束符 # Find the nearest sentence-ending separator
for sep in ['', '', '', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']: for sep in ['', '', '', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']:
last_sep = text[start:end].rfind(sep) last_sep = text[start:end].rfind(sep)
if last_sep != -1 and last_sep > chunk_size * 0.3: if last_sep != -1 and last_sep > chunk_size * 0.3:
@ -182,8 +182,7 @@ def split_text_into_chunks(
if chunk: if chunk:
chunks.append(chunk) chunks.append(chunk)
# 下一个块从重叠位置开始 # Next chunk starts at the overlap position
start = end - overlap if end < len(text) else len(text) start = end - overlap if end < len(text) else len(text)
return chunks return chunks

View File

@ -1,18 +1,19 @@
""" """
LLM客户端封装 LLM client wrapper
统一使用OpenAI格式调用 Unified interface using the OpenAI-compatible API format
""" """
import json import json
import re import re
from typing import Optional, Dict, Any, List from typing import Optional, Dict, Any, List
from urllib.parse import urlparse, parse_qs, urlunparse
from openai import OpenAI from openai import OpenAI
from ..config import Config from ..config import Config
class LLMClient: class LLMClient:
"""LLM客户端""" """LLM client"""
def __init__( def __init__(
self, self,
@ -21,15 +22,30 @@ class LLMClient:
model: Optional[str] = None model: Optional[str] = None
): ):
self.api_key = api_key or Config.LLM_API_KEY self.api_key = api_key or Config.LLM_API_KEY
self.base_url = base_url or Config.LLM_BASE_URL raw_url = base_url or Config.LLM_BASE_URL
self.model = model or Config.LLM_MODEL_NAME self.model = model or Config.LLM_MODEL_NAME
if not self.api_key: if not self.api_key:
raise ValueError("LLM_API_KEY 未配置") raise ValueError("LLM_API_KEY is not configured")
# Azure Portal provides full endpoint URLs like:
# https://<resource>.cognitiveservices.azure.com/openai/deployments/<model>/chat/completions?api-version=...
# The OpenAI SDK expects a base_url and appends /chat/completions itself,
# so we strip that suffix and extract api-version as a default query param.
default_query: Dict[str, str] = {}
if raw_url and '/chat/completions' in raw_url:
parsed = urlparse(raw_url)
qs = parse_qs(parsed.query)
if 'api-version' in qs:
default_query['api-version'] = qs['api-version'][0]
clean_path = parsed.path.replace('/chat/completions', '').rstrip('/')
raw_url = urlunparse(parsed._replace(path=clean_path, query=''))
self.base_url = raw_url
self.client = OpenAI( self.client = OpenAI(
api_key=self.api_key, api_key=self.api_key,
base_url=self.base_url base_url=self.base_url,
default_query=default_query if default_query else None
) )
def chat( def chat(
@ -40,22 +56,22 @@ class LLMClient:
response_format: Optional[Dict] = None response_format: Optional[Dict] = None
) -> str: ) -> str:
""" """
发送聊天请求 Send a chat request
Args: Args:
messages: 消息列表 messages: List of messages
temperature: 温度参数 temperature: Temperature parameter
max_tokens: 最大token数 max_tokens: Maximum number of tokens
response_format: 响应格式如JSON模式 response_format: Response format (e.g. JSON mode)
Returns: Returns:
模型响应文本 Model response text
""" """
kwargs = { kwargs = {
"model": self.model, "model": self.model,
"messages": messages, "messages": messages,
"temperature": temperature, "temperature": temperature,
"max_tokens": max_tokens, "max_completion_tokens": max_tokens,
} }
if response_format: if response_format:
@ -63,7 +79,7 @@ class LLMClient:
response = self.client.chat.completions.create(**kwargs) response = self.client.chat.completions.create(**kwargs)
content = response.choices[0].message.content content = response.choices[0].message.content
# 部分模型如MiniMax M2.5会在content中包含<think>思考内容,需要移除 # Some models (e.g. MiniMax M2.5) include <think> reasoning content in the response; strip it out
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip() content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
return content return content
@ -74,15 +90,15 @@ class LLMClient:
max_tokens: int = 4096 max_tokens: int = 4096
) -> Dict[str, Any]: ) -> Dict[str, Any]:
""" """
发送聊天请求并返回JSON Send a chat request and return parsed JSON
Args: Args:
messages: 消息列表 messages: List of messages
temperature: 温度参数 temperature: Temperature parameter
max_tokens: 最大token数 max_tokens: Maximum number of tokens
Returns: Returns:
解析后的JSON对象 Parsed JSON object
""" """
response = self.chat( response = self.chat(
messages=messages, messages=messages,
@ -90,7 +106,7 @@ class LLMClient:
max_tokens=max_tokens, max_tokens=max_tokens,
response_format={"type": "json_object"} response_format={"type": "json_object"}
) )
# 清理markdown代码块标记 # Strip markdown code-block markers if present
cleaned_response = response.strip() cleaned_response = response.strip()
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE) cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response) cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
@ -99,5 +115,4 @@ class LLMClient:
try: try:
return json.loads(cleaned_response) return json.loads(cleaned_response)
except json.JSONDecodeError: except json.JSONDecodeError:
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}") raise ValueError(f"Invalid JSON returned by LLM: {cleaned_response}")

View File

@ -66,4 +66,4 @@ def t(key: str, **kwargs) -> str:
def get_language_instruction() -> str: def get_language_instruction() -> str:
locale = get_locale() locale = get_locale()
lang_config = _languages.get(locale, _languages.get('zh', {})) lang_config = _languages.get(locale, _languages.get('zh', {}))
return lang_config.get('llmInstruction', '请使用中文回答。') return lang_config.get('llmInstruction', 'Please respond in Chinese.')

View File

@ -1,6 +1,6 @@
""" """
日志配置模块 Logging configuration module
提供统一的日志管理同时输出到控制台和文件 Provides unified log management, writing to both console and file
""" """
import os import os
@ -12,47 +12,47 @@ from logging.handlers import RotatingFileHandler
def _ensure_utf8_stdout(): def _ensure_utf8_stdout():
""" """
确保 stdout/stderr 使用 UTF-8 编码 Ensure stdout/stderr use UTF-8 encoding.
解决 Windows 控制台中文乱码问题 Fixes garbled output in Windows consoles.
""" """
if sys.platform == 'win32': if sys.platform == 'win32':
# Windows 下重新配置标准输出为 UTF-8 # Reconfigure standard streams to UTF-8 on Windows
if hasattr(sys.stdout, 'reconfigure'): if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stdout.reconfigure(encoding='utf-8', errors='replace')
if hasattr(sys.stderr, 'reconfigure'): if hasattr(sys.stderr, 'reconfigure'):
sys.stderr.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace')
# 日志目录 # Log directory
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs') LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger: def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
""" """
设置日志器 Set up a logger
Args: Args:
name: 日志器名称 name: Logger name
level: 日志级别 level: Log level
Returns: Returns:
配置好的日志器 Configured logger instance
""" """
# 确保日志目录存在 # Ensure the log directory exists
os.makedirs(LOG_DIR, exist_ok=True) os.makedirs(LOG_DIR, exist_ok=True)
# 创建日志器 # Create logger
logger = logging.getLogger(name) logger = logging.getLogger(name)
logger.setLevel(level) logger.setLevel(level)
# 阻止日志向上传播到根 logger避免重复输出 # Prevent log records from propagating to the root logger to avoid duplicate output
logger.propagate = False logger.propagate = False
# 如果已经有处理器,不重复添加 # Skip adding handlers if they already exist
if logger.handlers: if logger.handlers:
return logger return logger
# 日志格式 # Log formatters
detailed_formatter = logging.Formatter( detailed_formatter = logging.Formatter(
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s', '[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S' datefmt='%Y-%m-%d %H:%M:%S'
@ -63,7 +63,7 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
datefmt='%H:%M:%S' datefmt='%H:%M:%S'
) )
# 1. 文件处理器 - 详细日志(按日期命名,带轮转) # 1. File handler — detailed logs (date-stamped filename with rotation)
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log' log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
file_handler = RotatingFileHandler( file_handler = RotatingFileHandler(
os.path.join(LOG_DIR, log_filename), os.path.join(LOG_DIR, log_filename),
@ -74,14 +74,14 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
file_handler.setLevel(logging.DEBUG) file_handler.setLevel(logging.DEBUG)
file_handler.setFormatter(detailed_formatter) file_handler.setFormatter(detailed_formatter)
# 2. 控制台处理器 - 简洁日志INFO及以上 # 2. Console handler — concise logs (INFO and above)
# 确保 Windows 下使用 UTF-8 编码,避免中文乱码 # Ensure UTF-8 encoding on Windows to avoid garbled output
_ensure_utf8_stdout() _ensure_utf8_stdout()
console_handler = logging.StreamHandler(sys.stdout) console_handler = logging.StreamHandler(sys.stdout)
console_handler.setLevel(logging.INFO) console_handler.setLevel(logging.INFO)
console_handler.setFormatter(simple_formatter) console_handler.setFormatter(simple_formatter)
# 添加处理器 # Register handlers
logger.addHandler(file_handler) logger.addHandler(file_handler)
logger.addHandler(console_handler) logger.addHandler(console_handler)
@ -90,13 +90,13 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
def get_logger(name: str = 'mirofish') -> logging.Logger: def get_logger(name: str = 'mirofish') -> logging.Logger:
""" """
获取日志器如果不存在则创建 Get a logger, creating it if it does not exist
Args: Args:
name: 日志器名称 name: Logger name
Returns: Returns:
日志器实例 Logger instance
""" """
logger = logging.getLogger(name) logger = logging.getLogger(name)
if not logger.handlers: if not logger.handlers:
@ -104,11 +104,11 @@ def get_logger(name: str = 'mirofish') -> logging.Logger:
return logger return logger
# 创建默认日志器 # Create default logger
logger = setup_logger() logger = setup_logger()
# 便捷方法 # Convenience functions
def debug(msg, *args, **kwargs): def debug(msg, *args, **kwargs):
logger.debug(msg, *args, **kwargs) logger.debug(msg, *args, **kwargs)
@ -123,4 +123,3 @@ def error(msg, *args, **kwargs):
def critical(msg, *args, **kwargs): def critical(msg, *args, **kwargs):
logger.critical(msg, *args, **kwargs) logger.critical(msg, *args, **kwargs)

View File

@ -1,6 +1,6 @@
""" """
API调用重试机制 API call retry mechanism
用于处理LLM等外部API调用的重试逻辑 Handles retry logic for external API calls such as LLM services
""" """
import time import time
@ -22,16 +22,16 @@ def retry_with_backoff(
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """
带指数退避的重试装饰器 Retry decorator with exponential backoff
Args: Args:
max_retries: 最大重试次数 max_retries: Maximum number of retries
initial_delay: 初始延迟 initial_delay: Initial delay in seconds
max_delay: 最大延迟 max_delay: Maximum delay in seconds
backoff_factor: 退避因子 backoff_factor: Backoff multiplier
jitter: 是否添加随机抖动 jitter: Whether to add random jitter
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry
on_retry: 重试时的回调函数 (exception, retry_count) on_retry: Callback invoked on each retry (exception, retry_count)
Usage: Usage:
@retry_with_backoff(max_retries=3) @retry_with_backoff(max_retries=3)
@ -52,17 +52,17 @@ def retry_with_backoff(
last_exception = e last_exception = e
if attempt == max_retries: if attempt == max_retries:
logger.error(f"函数 {func.__name__}{max_retries} 次重试后仍失败: {str(e)}") logger.error(f"Function {func.__name__} failed after {max_retries} retries: {str(e)}")
raise raise
# 计算延迟 # Calculate delay
current_delay = min(delay, max_delay) current_delay = min(delay, max_delay)
if jitter: if jitter:
current_delay = current_delay * (0.5 + random.random()) current_delay = current_delay * (0.5 + random.random())
logger.warning( logger.warning(
f"函数 {func.__name__}{attempt + 1} 次尝试失败: {str(e)}, " f"Function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
f"{current_delay:.1f}秒后重试..." f"retrying in {current_delay:.1f}s..."
) )
if on_retry: if on_retry:
@ -87,7 +87,7 @@ def retry_with_backoff_async(
on_retry: Optional[Callable[[Exception, int], None]] = None on_retry: Optional[Callable[[Exception, int], None]] = None
): ):
""" """
异步版本的重试装饰器 Async version of the retry decorator
""" """
import asyncio import asyncio
@ -105,7 +105,7 @@ def retry_with_backoff_async(
last_exception = e last_exception = e
if attempt == max_retries: if attempt == max_retries:
logger.error(f"异步函数 {func.__name__}{max_retries} 次重试后仍失败: {str(e)}") logger.error(f"Async function {func.__name__} failed after {max_retries} retries: {str(e)}")
raise raise
current_delay = min(delay, max_delay) current_delay = min(delay, max_delay)
@ -113,8 +113,8 @@ def retry_with_backoff_async(
current_delay = current_delay * (0.5 + random.random()) current_delay = current_delay * (0.5 + random.random())
logger.warning( logger.warning(
f"异步函数 {func.__name__}{attempt + 1} 次尝试失败: {str(e)}, " f"Async function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
f"{current_delay:.1f}秒后重试..." f"retrying in {current_delay:.1f}s..."
) )
if on_retry: if on_retry:
@ -131,7 +131,7 @@ def retry_with_backoff_async(
class RetryableAPIClient: class RetryableAPIClient:
""" """
可重试的API客户端封装 Retryable API client wrapper
""" """
def __init__( def __init__(
@ -154,16 +154,16 @@ class RetryableAPIClient:
**kwargs **kwargs
) -> Any: ) -> Any:
""" """
执行函数调用并在失败时重试 Execute a function call and retry on failure
Args: Args:
func: 要调用的函数 func: Function to call
*args: 函数参数 *args: Positional arguments for the function
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry
**kwargs: 函数关键字参数 **kwargs: Keyword arguments for the function
Returns: Returns:
函数返回值 Return value of the function
""" """
last_exception = None last_exception = None
delay = self.initial_delay delay = self.initial_delay
@ -176,15 +176,15 @@ class RetryableAPIClient:
last_exception = e last_exception = e
if attempt == self.max_retries: if attempt == self.max_retries:
logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}") logger.error(f"API call failed after {self.max_retries} retries: {str(e)}")
raise raise
current_delay = min(delay, self.max_delay) current_delay = min(delay, self.max_delay)
current_delay = current_delay * (0.5 + random.random()) current_delay = current_delay * (0.5 + random.random())
logger.warning( logger.warning(
f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, " f"API call attempt {attempt + 1} failed: {str(e)}, "
f"{current_delay:.1f}秒后重试..." f"retrying in {current_delay:.1f}s..."
) )
time.sleep(current_delay) time.sleep(current_delay)
@ -200,16 +200,16 @@ class RetryableAPIClient:
continue_on_failure: bool = True continue_on_failure: bool = True
) -> Tuple[list, list]: ) -> Tuple[list, list]:
""" """
批量调用并对每个失败项单独重试 Process a batch of items, retrying individually on failure
Args: Args:
items: 要处理的项目列表 items: List of items to process
process_func: 处理函数接收单个item作为参数 process_func: Processing function that accepts a single item
exceptions: 需要重试的异常类型 exceptions: Exception types that should trigger a retry
continue_on_failure: 单项失败后是否继续处理其他项 continue_on_failure: Whether to continue processing remaining items after a failure
Returns: Returns:
(成功结果列表, 失败项列表) (list of successful results, list of failed items)
""" """
results = [] results = []
failures = [] failures = []
@ -224,7 +224,7 @@ class RetryableAPIClient:
results.append(result) results.append(result)
except Exception as e: except Exception as e:
logger.error(f"处理第 {idx + 1} 项失败: {str(e)}") logger.error(f"Failed to process item {idx + 1}: {str(e)}")
failures.append({ failures.append({
"index": idx, "index": idx,
"item": item, "item": item,
@ -235,4 +235,3 @@ class RetryableAPIClient:
raise raise
return results, failures return results, failures

View File

@ -1,7 +1,8 @@
"""Zep Graph 分页读取工具。 """Zep Graph paginated fetch utilities.
Zep node/edge 列表接口使用 UUID cursor 分页 Zep's node/edge list endpoints use UUID-cursor-based pagination.
本模块封装自动翻页逻辑含单页重试对调用方透明地返回完整列表 This module wraps the auto-pagination logic (with per-page retries) and
returns the full result list transparently to the caller.
""" """
from __future__ import annotations from __future__ import annotations
@ -31,7 +32,7 @@ def _fetch_page_with_retry(
page_description: str = "page", page_description: str = "page",
**kwargs: Any, **kwargs: Any,
) -> list[Any]: ) -> list[Any]:
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。""" """Fetch a single page with exponential-backoff retry on transient network/IO errors."""
if max_retries < 1: if max_retries < 1:
raise ValueError("max_retries must be >= 1") raise ValueError("max_retries must be >= 1")
@ -64,7 +65,7 @@ def fetch_all_nodes(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱节点,最多返回 max_items 条(默认 2000。每页请求自带重试。""" """Fetch all graph nodes with pagination, returning at most max_items (default 2000). Each page request includes retries."""
all_nodes: list[Any] = [] all_nodes: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0
@ -109,7 +110,7 @@ def fetch_all_edges(
max_retries: int = _DEFAULT_MAX_RETRIES, max_retries: int = _DEFAULT_MAX_RETRIES,
retry_delay: float = _DEFAULT_RETRY_DELAY, retry_delay: float = _DEFAULT_RETRY_DELAY,
) -> list[Any]: ) -> list[Any]:
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。""" """Fetch all graph edges with pagination, returning the complete list. Each page request includes retries."""
all_edges: list[Any] = [] all_edges: list[Any] = []
cursor: str | None = None cursor: str | None = None
page_num = 0 page_num = 0

View File

@ -1435,7 +1435,6 @@
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz", "resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz",
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==", "integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
"license": "ISC", "license": "ISC",
"peer": true,
"engines": { "engines": {
"node": ">=12" "node": ">=12"
} }
@ -1913,7 +1912,6 @@
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==", "integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"engines": { "engines": {
"node": ">=12" "node": ">=12"
}, },
@ -2053,7 +2051,6 @@
"integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==", "integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==",
"dev": true, "dev": true,
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"esbuild": "^0.25.0", "esbuild": "^0.25.0",
"fdir": "^6.5.0", "fdir": "^6.5.0",
@ -2128,7 +2125,6 @@
"resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz", "resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz",
"integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==", "integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==",
"license": "MIT", "license": "MIT",
"peer": true,
"dependencies": { "dependencies": {
"@vue/compiler-dom": "3.5.25", "@vue/compiler-dom": "3.5.25",
"@vue/compiler-sfc": "3.5.25", "@vue/compiler-sfc": "3.5.25",

View File

@ -4,7 +4,7 @@ import authState, { clearToken } from '../store/auth'
// 创建axios实例 // 创建axios实例
const service = axios.create({ const service = axios.create({
baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:5001', baseURL: import.meta.env.VITE_API_BASE_URL || '',
timeout: 300000, // 5分钟超时本体生成可能需要较长时间 timeout: 300000, // 5分钟超时本体生成可能需要较长时间
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'

View File

@ -14,7 +14,7 @@ for (const path in localeFiles) {
} }
} }
const savedLocale = localStorage.getItem('locale') || 'zh' const savedLocale = localStorage.getItem('locale') || 'ca'
const i18n = createI18n({ const i18n = createI18n({
legacy: false, legacy: false,