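chore(i18n): replace all hardcoded Chinese strings with English in backend
Translate all Chinese comments, docstrings, log messages, error messages, and LLM prompt text to English across the entire backend codebase. Locale translation files (locales/*.json) are unchanged.
Note: the diff below also comments out the config.sh loading block (the CONFIG_FILE existence check and the `source "$CONFIG_FILE"` call) in two shell scripts; that is a behavior change, not a translation.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>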
This commit is contained in:
parent
e3943c7d7c
commit
7d172b9eec
|
|
@ -20,15 +20,15 @@ set -euo pipefail
|
||||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
|
||||||
# ── Carregar configuració ─────────────────────────────────────────────────────
|
# ── Carregar configuració ─────────────────────────────────────────────────────
|
||||||
CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
#CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||||
if [[ ! -f "$CONFIG_FILE" ]]; then
|
#if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||||
echo "ERROR: No s'ha trobat azure/config.sh"
|
# echo "ERROR: No s'ha trobat azure/config.sh"
|
||||||
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
# echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||||
echo " Després omple els valors i torna a executar."
|
# echo " Després omple els valors i torna a executar."
|
||||||
exit 1
|
# exit 1
|
||||||
fi
|
#fi
|
||||||
# shellcheck source=config.sh.example
|
# shellcheck source=config.sh.example
|
||||||
source "$CONFIG_FILE"
|
#source "$CONFIG_FILE"
|
||||||
|
|
||||||
# ── Validar variables obligatòries ───────────────────────────────────────────
|
# ── Validar variables obligatòries ───────────────────────────────────────────
|
||||||
REQUIRED_VARS=(
|
REQUIRED_VARS=(
|
||||||
|
|
|
||||||
|
|
@ -17,14 +17,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||||
|
|
||||||
# ── Carregar configuració ─────────────────────────────────────────────────────
|
# ── Carregar configuració ─────────────────────────────────────────────────────
|
||||||
CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
#CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||||
if [[ ! -f "$CONFIG_FILE" ]]; then
|
#if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||||
echo "ERROR: No s'ha trobat azure/config.sh"
|
# echo "ERROR: No s'ha trobat azure/config.sh"
|
||||||
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
# echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||||
exit 1
|
# exit 1
|
||||||
fi
|
#fi
|
||||||
# shellcheck source=config.sh.example
|
# shellcheck source=config.sh.example
|
||||||
source "$CONFIG_FILE"
|
#source "$CONFIG_FILE"
|
||||||
|
|
||||||
# ── Validar variables obligatòries ───────────────────────────────────────────
|
# ── Validar variables obligatòries ───────────────────────────────────────────
|
||||||
REQUIRED_VARS=(
|
REQUIRED_VARS=(
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
"""
|
"""
|
||||||
MiroFish Backend - Flask应用工厂
|
MiroFish Backend - Flask application factory
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers)
|
# Suppress multiprocessing resource_tracker warnings (from third-party libraries like transformers)
|
||||||
# 需要在所有其他导入之前设置
|
# Must be set before all other imports
|
||||||
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
|
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
|
||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
|
|
@ -21,36 +21,36 @@ _PUBLIC_PATHS = {'/health', '/api/auth/login'}
|
||||||
|
|
||||||
|
|
||||||
def create_app(config_class=Config):
|
def create_app(config_class=Config):
|
||||||
"""Flask应用工厂函数"""
|
"""Flask application factory"""
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config.from_object(config_class)
|
app.config.from_object(config_class)
|
||||||
|
|
||||||
# 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式)
|
# Configure JSON encoding: ensure non-ASCII characters are output directly (not as \uXXXX)
|
||||||
# Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置
|
# Flask >= 2.3 uses app.json.ensure_ascii; older versions use JSON_AS_ASCII config
|
||||||
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
||||||
app.json.ensure_ascii = False
|
app.json.ensure_ascii = False
|
||||||
|
|
||||||
# 设置日志
|
# Set up logging
|
||||||
logger = setup_logger('mirofish')
|
logger = setup_logger('mirofish')
|
||||||
|
|
||||||
# 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次)
|
# Only log startup info in the reloader subprocess (avoids double-printing in debug mode)
|
||||||
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
|
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
|
||||||
debug_mode = app.config.get('DEBUG', False)
|
debug_mode = app.config.get('DEBUG', False)
|
||||||
should_log_startup = not debug_mode or is_reloader_process
|
should_log_startup = not debug_mode or is_reloader_process
|
||||||
|
|
||||||
if should_log_startup:
|
if should_log_startup:
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
logger.info("MiroFish Backend 启动中...")
|
logger.info("MiroFish Backend starting...")
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
|
|
||||||
# 启用CORS
|
# Enable CORS
|
||||||
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||||
|
|
||||||
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程)
|
# Register simulation process cleanup (ensures all simulation processes are terminated on server shutdown)
|
||||||
from .services.simulation_runner import SimulationRunner
|
from .services.simulation_runner import SimulationRunner
|
||||||
SimulationRunner.register_cleanup()
|
SimulationRunner.register_cleanup()
|
||||||
if should_log_startup:
|
if should_log_startup:
|
||||||
logger.info("已注册模拟进程清理函数")
|
logger.info("Simulation process cleanup handler registered")
|
||||||
|
|
||||||
# Middleware d'autenticació JWT — s'executa ABANS del log_request (ordre FIFO)
|
# Middleware d'autenticació JWT — s'executa ABANS del log_request (ordre FIFO)
|
||||||
@app.before_request
|
@app.before_request
|
||||||
|
|
@ -70,28 +70,28 @@ def create_app(config_class=Config):
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
return jsonify({'success': False, 'error': 'Invalid token'}), 401
|
return jsonify({'success': False, 'error': 'Invalid token'}), 401
|
||||||
|
|
||||||
# 请求日志中间件
|
# Request logging middleware
|
||||||
@app.before_request
|
@app.before_request
|
||||||
def log_request():
|
def log_request():
|
||||||
logger = get_logger('mirofish.request')
|
logger = get_logger('mirofish.request')
|
||||||
logger.debug(f"请求: {request.method} {request.path}")
|
logger.debug(f"Request: {request.method} {request.path}")
|
||||||
if request.content_type and 'json' in request.content_type:
|
if request.content_type and 'json' in request.content_type:
|
||||||
logger.debug(f"请求体: {request.get_json(silent=True)}")
|
logger.debug(f"Request body: {request.get_json(silent=True)}")
|
||||||
|
|
||||||
@app.after_request
|
@app.after_request
|
||||||
def log_response(response):
|
def log_response(response):
|
||||||
logger = get_logger('mirofish.request')
|
logger = get_logger('mirofish.request')
|
||||||
logger.debug(f"响应: {response.status_code}")
|
logger.debug(f"Response: {response.status_code}")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# 注册蓝图 (auth primer, luego els existents)
|
# Register blueprints (auth first, then the rest)
|
||||||
from .api import graph_bp, simulation_bp, report_bp, auth_bp
|
from .api import graph_bp, simulation_bp, report_bp, auth_bp
|
||||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||||
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
||||||
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
||||||
app.register_blueprint(report_bp, url_prefix='/api/report')
|
app.register_blueprint(report_bp, url_prefix='/api/report')
|
||||||
|
|
||||||
# 健康检查
|
# Health check
|
||||||
@app.route('/health')
|
@app.route('/health')
|
||||||
def health():
|
def health():
|
||||||
return {'status': 'ok', 'service': 'MiroFish Backend'}
|
return {'status': 'ok', 'service': 'MiroFish Backend'}
|
||||||
|
|
@ -111,6 +111,6 @@ def create_app(config_class=Config):
|
||||||
return _send_file(_os.path.join(_dist, 'index.html'))
|
return _send_file(_os.path.join(_dist, 'index.html'))
|
||||||
|
|
||||||
if should_log_startup:
|
if should_log_startup:
|
||||||
logger.info("MiroFish Backend 启动完成")
|
logger.info("MiroFish Backend startup complete")
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
API路由模块
|
API routes module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
图谱相关API路由
|
Graph-related API routes
|
||||||
采用项目上下文机制,服务端持久化状态
|
Uses project context mechanism with server-side persistent state
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -19,24 +19,24 @@ from ..utils.locale import t, get_locale, set_locale
|
||||||
from ..models.task import TaskManager, TaskStatus
|
from ..models.task import TaskManager, TaskStatus
|
||||||
from ..models.project import ProjectManager, ProjectStatus
|
from ..models.project import ProjectManager, ProjectStatus
|
||||||
|
|
||||||
# 获取日志器
|
# Get logger
|
||||||
logger = get_logger('mirofish.api')
|
logger = get_logger('mirofish.api')
|
||||||
|
|
||||||
|
|
||||||
def allowed_file(filename: str) -> bool:
|
def allowed_file(filename: str) -> bool:
|
||||||
"""检查文件扩展名是否允许"""
|
"""Check if the file extension is allowed"""
|
||||||
if not filename or '.' not in filename:
|
if not filename or '.' not in filename:
|
||||||
return False
|
return False
|
||||||
ext = os.path.splitext(filename)[1].lower().lstrip('.')
|
ext = os.path.splitext(filename)[1].lower().lstrip('.')
|
||||||
return ext in Config.ALLOWED_EXTENSIONS
|
return ext in Config.ALLOWED_EXTENSIONS
|
||||||
|
|
||||||
|
|
||||||
# ============== 项目管理接口 ==============
|
# ============== Project management endpoints ==============
|
||||||
|
|
||||||
@graph_bp.route('/project/<project_id>', methods=['GET'])
|
@graph_bp.route('/project/<project_id>', methods=['GET'])
|
||||||
def get_project(project_id: str):
|
def get_project(project_id: str):
|
||||||
"""
|
"""
|
||||||
获取项目详情
|
Get project details
|
||||||
"""
|
"""
|
||||||
project = ProjectManager.get_project(project_id)
|
project = ProjectManager.get_project(project_id)
|
||||||
|
|
||||||
|
|
@ -55,7 +55,7 @@ def get_project(project_id: str):
|
||||||
@graph_bp.route('/project/list', methods=['GET'])
|
@graph_bp.route('/project/list', methods=['GET'])
|
||||||
def list_projects():
|
def list_projects():
|
||||||
"""
|
"""
|
||||||
列出所有项目
|
List all projects
|
||||||
"""
|
"""
|
||||||
limit = request.args.get('limit', 50, type=int)
|
limit = request.args.get('limit', 50, type=int)
|
||||||
projects = ProjectManager.list_projects(limit=limit)
|
projects = ProjectManager.list_projects(limit=limit)
|
||||||
|
|
@ -70,7 +70,7 @@ def list_projects():
|
||||||
@graph_bp.route('/project/<project_id>', methods=['DELETE'])
|
@graph_bp.route('/project/<project_id>', methods=['DELETE'])
|
||||||
def delete_project(project_id: str):
|
def delete_project(project_id: str):
|
||||||
"""
|
"""
|
||||||
删除项目
|
Delete a project
|
||||||
"""
|
"""
|
||||||
success = ProjectManager.delete_project(project_id)
|
success = ProjectManager.delete_project(project_id)
|
||||||
|
|
||||||
|
|
@ -89,7 +89,7 @@ def delete_project(project_id: str):
|
||||||
@graph_bp.route('/project/<project_id>/reset', methods=['POST'])
|
@graph_bp.route('/project/<project_id>/reset', methods=['POST'])
|
||||||
def reset_project(project_id: str):
|
def reset_project(project_id: str):
|
||||||
"""
|
"""
|
||||||
重置项目状态(用于重新构建图谱)
|
Reset project status (used to rebuild the graph)
|
||||||
"""
|
"""
|
||||||
project = ProjectManager.get_project(project_id)
|
project = ProjectManager.get_project(project_id)
|
||||||
|
|
||||||
|
|
@ -99,7 +99,7 @@ def reset_project(project_id: str):
|
||||||
"error": t('api.projectNotFound', id=project_id)
|
"error": t('api.projectNotFound', id=project_id)
|
||||||
}), 404
|
}), 404
|
||||||
|
|
||||||
# 重置到本体已生成状态
|
# Reset to ontology-generated status
|
||||||
if project.ontology:
|
if project.ontology:
|
||||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||||
else:
|
else:
|
||||||
|
|
@ -117,22 +117,22 @@ def reset_project(project_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# ============== 接口1:上传文件并生成本体 ==============
|
# ============== Endpoint 1: Upload files and generate ontology ==============
|
||||||
|
|
||||||
@graph_bp.route('/ontology/generate', methods=['POST'])
|
@graph_bp.route('/ontology/generate', methods=['POST'])
|
||||||
def generate_ontology():
|
def generate_ontology():
|
||||||
"""
|
"""
|
||||||
接口1:上传文件,分析生成本体定义
|
Endpoint 1: Upload files and generate ontology definition
|
||||||
|
|
||||||
请求方式:multipart/form-data
|
Request method: multipart/form-data
|
||||||
|
|
||||||
参数:
|
Parameters:
|
||||||
files: 上传的文件(PDF/MD/TXT),可多个
|
files: Uploaded files (PDF/MD/TXT), multiple allowed
|
||||||
simulation_requirement: 模拟需求描述(必填)
|
simulation_requirement: Simulation requirement description (required)
|
||||||
project_name: 项目名称(可选)
|
project_name: Project name (optional)
|
||||||
additional_context: 额外说明(可选)
|
additional_context: Additional context (optional)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -148,15 +148,15 @@ def generate_ontology():
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("=== 开始生成本体定义 ===")
|
logger.info("=== Starting ontology generation ===")
|
||||||
|
|
||||||
# 获取参数
|
# Get parameters
|
||||||
simulation_requirement = request.form.get('simulation_requirement', '')
|
simulation_requirement = request.form.get('simulation_requirement', '')
|
||||||
project_name = request.form.get('project_name', 'Unnamed Project')
|
project_name = request.form.get('project_name', 'Unnamed Project')
|
||||||
additional_context = request.form.get('additional_context', '')
|
additional_context = request.form.get('additional_context', '')
|
||||||
|
|
||||||
logger.debug(f"项目名称: {project_name}")
|
logger.debug(f"Project name: {project_name}")
|
||||||
logger.debug(f"模拟需求: {simulation_requirement[:100]}...")
|
logger.debug(f"Simulation requirement: {simulation_requirement[:100]}...")
|
||||||
|
|
||||||
if not simulation_requirement:
|
if not simulation_requirement:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -164,7 +164,7 @@ def generate_ontology():
|
||||||
"error": t('api.requireSimulationRequirement')
|
"error": t('api.requireSimulationRequirement')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 获取上传的文件
|
# Get uploaded files
|
||||||
uploaded_files = request.files.getlist('files')
|
uploaded_files = request.files.getlist('files')
|
||||||
if not uploaded_files or all(not f.filename for f in uploaded_files):
|
if not uploaded_files or all(not f.filename for f in uploaded_files):
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -172,18 +172,18 @@ def generate_ontology():
|
||||||
"error": t('api.requireFileUpload')
|
"error": t('api.requireFileUpload')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 创建项目
|
# Create project
|
||||||
project = ProjectManager.create_project(name=project_name)
|
project = ProjectManager.create_project(name=project_name)
|
||||||
project.simulation_requirement = simulation_requirement
|
project.simulation_requirement = simulation_requirement
|
||||||
logger.info(f"创建项目: {project.project_id}")
|
logger.info(f"Project created: {project.project_id}")
|
||||||
|
|
||||||
# 保存文件并提取文本
|
# Save files and extract text
|
||||||
document_texts = []
|
document_texts = []
|
||||||
all_text = ""
|
all_text = ""
|
||||||
|
|
||||||
for file in uploaded_files:
|
for file in uploaded_files:
|
||||||
if file and file.filename and allowed_file(file.filename):
|
if file and file.filename and allowed_file(file.filename):
|
||||||
# 保存文件到项目目录
|
# Save file to project directory
|
||||||
file_info = ProjectManager.save_file_to_project(
|
file_info = ProjectManager.save_file_to_project(
|
||||||
project.project_id,
|
project.project_id,
|
||||||
file,
|
file,
|
||||||
|
|
@ -194,7 +194,7 @@ def generate_ontology():
|
||||||
"size": file_info["size"]
|
"size": file_info["size"]
|
||||||
})
|
})
|
||||||
|
|
||||||
# 提取文本
|
# Extract text
|
||||||
text = FileParser.extract_text(file_info["path"])
|
text = FileParser.extract_text(file_info["path"])
|
||||||
text = TextProcessor.preprocess_text(text)
|
text = TextProcessor.preprocess_text(text)
|
||||||
document_texts.append(text)
|
document_texts.append(text)
|
||||||
|
|
@ -207,13 +207,13 @@ def generate_ontology():
|
||||||
"error": t('api.noDocProcessed')
|
"error": t('api.noDocProcessed')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 保存提取的文本
|
# Save extracted text
|
||||||
project.total_text_length = len(all_text)
|
project.total_text_length = len(all_text)
|
||||||
ProjectManager.save_extracted_text(project.project_id, all_text)
|
ProjectManager.save_extracted_text(project.project_id, all_text)
|
||||||
logger.info(f"文本提取完成,共 {len(all_text)} 字符")
|
logger.info(f"Text extraction complete, total {len(all_text)} characters")
|
||||||
|
|
||||||
# 生成本体
|
# Generate ontology
|
||||||
logger.info("调用 LLM 生成本体定义...")
|
logger.info("Calling LLM to generate ontology definition...")
|
||||||
generator = OntologyGenerator()
|
generator = OntologyGenerator()
|
||||||
ontology = generator.generate(
|
ontology = generator.generate(
|
||||||
document_texts=document_texts,
|
document_texts=document_texts,
|
||||||
|
|
@ -221,10 +221,10 @@ def generate_ontology():
|
||||||
additional_context=additional_context if additional_context else None
|
additional_context=additional_context if additional_context else None
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存本体到项目
|
# Save ontology to project
|
||||||
entity_count = len(ontology.get("entity_types", []))
|
entity_count = len(ontology.get("entity_types", []))
|
||||||
edge_count = len(ontology.get("edge_types", []))
|
edge_count = len(ontology.get("edge_types", []))
|
||||||
logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型")
|
logger.info(f"Ontology generation complete: {entity_count} entity types, {edge_count} relationship types")
|
||||||
|
|
||||||
project.ontology = {
|
project.ontology = {
|
||||||
"entity_types": ontology.get("entity_types", []),
|
"entity_types": ontology.get("entity_types", []),
|
||||||
|
|
@ -233,7 +233,7 @@ def generate_ontology():
|
||||||
project.analysis_summary = ontology.get("analysis_summary", "")
|
project.analysis_summary = ontology.get("analysis_summary", "")
|
||||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||||
ProjectManager.save_project(project)
|
ProjectManager.save_project(project)
|
||||||
logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}")
|
logger.info(f"=== Ontology generation complete === Project ID: {project.project_id}")
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": True,
|
"success": True,
|
||||||
|
|
@ -255,49 +255,49 @@ def generate_ontology():
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 接口2:构建图谱 ==============
|
# ============== Endpoint 2: Build graph ==============
|
||||||
|
|
||||||
@graph_bp.route('/build', methods=['POST'])
|
@graph_bp.route('/build', methods=['POST'])
|
||||||
def build_graph():
|
def build_graph():
|
||||||
"""
|
"""
|
||||||
接口2:根据project_id构建图谱
|
Endpoint 2: Build graph from project_id
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"project_id": "proj_xxxx", // 必填,来自接口1
|
"project_id": "proj_xxxx", // required, from endpoint 1
|
||||||
"graph_name": "图谱名称", // 可选
|
"graph_name": "Graph name", // optional
|
||||||
"chunk_size": 500, // 可选,默认500
|
"chunk_size": 500, // optional, default 500
|
||||||
"chunk_overlap": 50 // 可选,默认50
|
"chunk_overlap": 50 // optional, default 50
|
||||||
}
|
}
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"project_id": "proj_xxxx",
|
"project_id": "proj_xxxx",
|
||||||
"task_id": "task_xxxx",
|
"task_id": "task_xxxx",
|
||||||
"message": "图谱构建任务已启动"
|
"message": "Graph build task started"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
logger.info("=== 开始构建图谱 ===")
|
logger.info("=== Starting graph build ===")
|
||||||
|
|
||||||
# 检查配置
|
# Check configuration
|
||||||
errors = []
|
errors = []
|
||||||
if not Config.ZEP_API_KEY:
|
if not Config.ZEP_API_KEY:
|
||||||
errors.append(t('api.zepApiKeyMissing'))
|
errors.append(t('api.zepApiKeyMissing'))
|
||||||
if errors:
|
if errors:
|
||||||
logger.error(f"配置错误: {errors}")
|
logger.error(f"Configuration error: {errors}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": t('api.configError', details="; ".join(errors))
|
"error": t('api.configError', details="; ".join(errors))
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
# 解析请求
|
# Parse request
|
||||||
data = request.get_json() or {}
|
data = request.get_json() or {}
|
||||||
project_id = data.get('project_id')
|
project_id = data.get('project_id')
|
||||||
logger.debug(f"请求参数: project_id={project_id}")
|
logger.debug(f"Request parameters: project_id={project_id}")
|
||||||
|
|
||||||
if not project_id:
|
if not project_id:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -305,7 +305,7 @@ def build_graph():
|
||||||
"error": t('api.requireProjectId')
|
"error": t('api.requireProjectId')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 获取项目
|
# Get project
|
||||||
project = ProjectManager.get_project(project_id)
|
project = ProjectManager.get_project(project_id)
|
||||||
if not project:
|
if not project:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -313,8 +313,8 @@ def build_graph():
|
||||||
"error": t('api.projectNotFound', id=project_id)
|
"error": t('api.projectNotFound', id=project_id)
|
||||||
}), 404
|
}), 404
|
||||||
|
|
||||||
# 检查项目状态
|
# Check project status
|
||||||
force = data.get('force', False) # 强制重新构建
|
force = data.get('force', False) # Force rebuild
|
||||||
|
|
||||||
if project.status == ProjectStatus.CREATED:
|
if project.status == ProjectStatus.CREATED:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -329,23 +329,23 @@ def build_graph():
|
||||||
"task_id": project.graph_build_task_id
|
"task_id": project.graph_build_task_id
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 如果强制重建,重置状态
|
# If force rebuild, reset status
|
||||||
if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
|
if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
|
||||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||||
project.graph_id = None
|
project.graph_id = None
|
||||||
project.graph_build_task_id = None
|
project.graph_build_task_id = None
|
||||||
project.error = None
|
project.error = None
|
||||||
|
|
||||||
# 获取配置
|
# Get configuration
|
||||||
graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
|
graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
|
||||||
chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
|
chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
|
||||||
chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
|
chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
|
||||||
|
|
||||||
# 更新项目配置
|
# Update project configuration
|
||||||
project.chunk_size = chunk_size
|
project.chunk_size = chunk_size
|
||||||
project.chunk_overlap = chunk_overlap
|
project.chunk_overlap = chunk_overlap
|
||||||
|
|
||||||
# 获取提取的文本
|
# Get extracted text
|
||||||
text = ProjectManager.get_extracted_text(project_id)
|
text = ProjectManager.get_extracted_text(project_id)
|
||||||
if not text:
|
if not text:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -353,7 +353,7 @@ def build_graph():
|
||||||
"error": t('api.textNotFound')
|
"error": t('api.textNotFound')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 获取本体
|
# Get ontology
|
||||||
ontology = project.ontology
|
ontology = project.ontology
|
||||||
if not ontology:
|
if not ontology:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -361,12 +361,12 @@ def build_graph():
|
||||||
"error": t('api.ontologyNotFound')
|
"error": t('api.ontologyNotFound')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 创建异步任务
|
# Create async task
|
||||||
task_manager = TaskManager()
|
task_manager = TaskManager()
|
||||||
task_id = task_manager.create_task(f"构建图谱: {graph_name}")
|
task_id = task_manager.create_task(f"Build graph: {graph_name}")
|
||||||
logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}")
|
logger.info(f"Graph build task created: task_id={task_id}, project_id={project_id}")
|
||||||
|
|
||||||
# 更新项目状态
|
# Update project status
|
||||||
project.status = ProjectStatus.GRAPH_BUILDING
|
project.status = ProjectStatus.GRAPH_BUILDING
|
||||||
project.graph_build_task_id = task_id
|
project.graph_build_task_id = task_id
|
||||||
ProjectManager.save_project(project)
|
ProjectManager.save_project(project)
|
||||||
|
|
@ -374,22 +374,22 @@ def build_graph():
|
||||||
# Capture locale before spawning background thread
|
# Capture locale before spawning background thread
|
||||||
current_locale = get_locale()
|
current_locale = get_locale()
|
||||||
|
|
||||||
# 启动后台任务
|
# Start background task
|
||||||
def build_task():
|
def build_task():
|
||||||
set_locale(current_locale)
|
set_locale(current_locale)
|
||||||
build_logger = get_logger('mirofish.build')
|
build_logger = get_logger('mirofish.build')
|
||||||
try:
|
try:
|
||||||
build_logger.info(f"[{task_id}] 开始构建图谱...")
|
build_logger.info(f"[{task_id}] Starting graph build...")
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.PROCESSING,
|
status=TaskStatus.PROCESSING,
|
||||||
message=t('progress.initGraphService')
|
message=t('progress.initGraphService')
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建图谱构建服务
|
# Create graph builder service
|
||||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||||
|
|
||||||
# 分块
|
# Split into chunks
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
message=t('progress.textChunking'),
|
message=t('progress.textChunking'),
|
||||||
|
|
@ -402,7 +402,7 @@ def build_graph():
|
||||||
)
|
)
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
|
|
||||||
# 创建图谱
|
# Create graph
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
message=t('progress.creatingZepGraph'),
|
message=t('progress.creatingZepGraph'),
|
||||||
|
|
@ -410,11 +410,11 @@ def build_graph():
|
||||||
)
|
)
|
||||||
graph_id = builder.create_graph(name=graph_name)
|
graph_id = builder.create_graph(name=graph_name)
|
||||||
|
|
||||||
# 更新项目的graph_id
|
# Update project graph_id
|
||||||
project.graph_id = graph_id
|
project.graph_id = graph_id
|
||||||
ProjectManager.save_project(project)
|
ProjectManager.save_project(project)
|
||||||
|
|
||||||
# 设置本体
|
# Set ontology
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
message=t('progress.settingOntology'),
|
message=t('progress.settingOntology'),
|
||||||
|
|
@ -422,7 +422,7 @@ def build_graph():
|
||||||
)
|
)
|
||||||
builder.set_ontology(graph_id, ontology)
|
builder.set_ontology(graph_id, ontology)
|
||||||
|
|
||||||
# 添加文本(progress_callback 签名是 (msg, progress_ratio))
|
# Add text (progress_callback signature: (msg, progress_ratio))
|
||||||
def add_progress_callback(msg, progress_ratio):
|
def add_progress_callback(msg, progress_ratio):
|
||||||
progress = 15 + int(progress_ratio * 40) # 15% - 55%
|
progress = 15 + int(progress_ratio * 40) # 15% - 55%
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
|
|
@ -444,7 +444,7 @@ def build_graph():
|
||||||
progress_callback=add_progress_callback
|
progress_callback=add_progress_callback
|
||||||
)
|
)
|
||||||
|
|
||||||
# 等待Zep处理完成(查询每个episode的processed状态)
|
# Wait for Zep processing to complete (poll each episode's processed status)
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
message=t('progress.waitingZepProcess'),
|
message=t('progress.waitingZepProcess'),
|
||||||
|
|
@ -461,7 +461,7 @@ def build_graph():
|
||||||
|
|
||||||
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
||||||
|
|
||||||
# 获取图谱数据
|
# Fetch graph data
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
message=t('progress.fetchingGraphData'),
|
message=t('progress.fetchingGraphData'),
|
||||||
|
|
@ -469,15 +469,15 @@ def build_graph():
|
||||||
)
|
)
|
||||||
graph_data = builder.get_graph_data(graph_id)
|
graph_data = builder.get_graph_data(graph_id)
|
||||||
|
|
||||||
# 更新项目状态
|
# Update project status
|
||||||
project.status = ProjectStatus.GRAPH_COMPLETED
|
project.status = ProjectStatus.GRAPH_COMPLETED
|
||||||
ProjectManager.save_project(project)
|
ProjectManager.save_project(project)
|
||||||
|
|
||||||
node_count = graph_data.get("node_count", 0)
|
node_count = graph_data.get("node_count", 0)
|
||||||
edge_count = graph_data.get("edge_count", 0)
|
edge_count = graph_data.get("edge_count", 0)
|
||||||
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
|
build_logger.info(f"[{task_id}] Graph build complete: graph_id={graph_id}, nodes={node_count}, edges={edge_count}")
|
||||||
|
|
||||||
# 完成
|
# Complete
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.COMPLETED,
|
status=TaskStatus.COMPLETED,
|
||||||
|
|
@ -493,8 +493,8 @@ def build_graph():
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 更新项目状态为失败
|
# Update project status to failed
|
||||||
build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
|
build_logger.error(f"[{task_id}] Graph build failed: {str(e)}")
|
||||||
build_logger.debug(traceback.format_exc())
|
build_logger.debug(traceback.format_exc())
|
||||||
|
|
||||||
project.status = ProjectStatus.FAILED
|
project.status = ProjectStatus.FAILED
|
||||||
|
|
@ -508,7 +508,7 @@ def build_graph():
|
||||||
error=traceback.format_exc()
|
error=traceback.format_exc()
|
||||||
)
|
)
|
||||||
|
|
||||||
# 启动后台线程
|
# Start background thread
|
||||||
thread = threading.Thread(target=build_task, daemon=True)
|
thread = threading.Thread(target=build_task, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
|
@ -529,12 +529,12 @@ def build_graph():
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 任务查询接口 ==============
|
# ============== Task query endpoints ==============
|
||||||
|
|
||||||
@graph_bp.route('/task/<task_id>', methods=['GET'])
|
@graph_bp.route('/task/<task_id>', methods=['GET'])
|
||||||
def get_task(task_id: str):
|
def get_task(task_id: str):
|
||||||
"""
|
"""
|
||||||
查询任务状态
|
Query task status
|
||||||
"""
|
"""
|
||||||
task = TaskManager().get_task(task_id)
|
task = TaskManager().get_task(task_id)
|
||||||
|
|
||||||
|
|
@ -553,7 +553,7 @@ def get_task(task_id: str):
|
||||||
@graph_bp.route('/tasks', methods=['GET'])
|
@graph_bp.route('/tasks', methods=['GET'])
|
||||||
def list_tasks():
|
def list_tasks():
|
||||||
"""
|
"""
|
||||||
列出所有任务
|
List all tasks
|
||||||
"""
|
"""
|
||||||
tasks = TaskManager().list_tasks()
|
tasks = TaskManager().list_tasks()
|
||||||
|
|
||||||
|
|
@ -564,12 +564,12 @@ def list_tasks():
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
# ============== 图谱数据接口 ==============
|
# ============== Graph data endpoints ==============
|
||||||
|
|
||||||
@graph_bp.route('/data/<graph_id>', methods=['GET'])
|
@graph_bp.route('/data/<graph_id>', methods=['GET'])
|
||||||
def get_graph_data(graph_id: str):
|
def get_graph_data(graph_id: str):
|
||||||
"""
|
"""
|
||||||
获取图谱数据(节点和边)
|
Get graph data (nodes and edges)
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
if not Config.ZEP_API_KEY:
|
||||||
|
|
@ -597,7 +597,7 @@ def get_graph_data(graph_id: str):
|
||||||
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
||||||
def delete_graph(graph_id: str):
|
def delete_graph(graph_id: str):
|
||||||
"""
|
"""
|
||||||
删除Zep图谱
|
Delete a Zep graph
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if not Config.ZEP_API_KEY:
|
if not Config.ZEP_API_KEY:
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Report API路由
|
Report API routes
|
||||||
提供模拟报告生成、获取、对话等接口
|
Provides simulation report generation, retrieval, and chat endpoints
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -20,30 +20,30 @@ from ..utils.locale import t, get_locale, set_locale
|
||||||
logger = get_logger('mirofish.api.report')
|
logger = get_logger('mirofish.api.report')
|
||||||
|
|
||||||
|
|
||||||
# ============== 报告生成接口 ==============
|
# ============== Report generation endpoints ==============
|
||||||
|
|
||||||
@report_bp.route('/generate', methods=['POST'])
|
@report_bp.route('/generate', methods=['POST'])
|
||||||
def generate_report():
|
def generate_report():
|
||||||
"""
|
"""
|
||||||
生成模拟分析报告(异步任务)
|
Generate a simulation analysis report (async task)
|
||||||
|
|
||||||
这是一个耗时操作,接口会立即返回task_id,
|
This is a long-running operation; the endpoint returns task_id immediately.
|
||||||
使用 GET /api/report/generate/status 查询进度
|
Use POST /api/report/generate/status to poll progress.
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
"simulation_id": "sim_xxxx", // required, simulation ID
|
||||||
"force_regenerate": false // 可选,强制重新生成
|
"force_regenerate": false // optional, force regeneration
|
||||||
}
|
}
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"simulation_id": "sim_xxxx",
|
"simulation_id": "sim_xxxx",
|
||||||
"task_id": "task_xxxx",
|
"task_id": "task_xxxx",
|
||||||
"status": "generating",
|
"status": "generating",
|
||||||
"message": "报告生成任务已启动"
|
"message": "Report generation task started"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
@ -59,7 +59,7 @@ def generate_report():
|
||||||
|
|
||||||
force_regenerate = data.get('force_regenerate', False)
|
force_regenerate = data.get('force_regenerate', False)
|
||||||
|
|
||||||
# 获取模拟信息
|
# Get simulation info
|
||||||
manager = SimulationManager()
|
manager = SimulationManager()
|
||||||
state = manager.get_simulation(simulation_id)
|
state = manager.get_simulation(simulation_id)
|
||||||
|
|
||||||
|
|
@ -69,7 +69,7 @@ def generate_report():
|
||||||
"error": t('api.simulationNotFound', id=simulation_id)
|
"error": t('api.simulationNotFound', id=simulation_id)
|
||||||
}), 404
|
}), 404
|
||||||
|
|
||||||
# 检查是否已有报告
|
# Check if a report already exists
|
||||||
if not force_regenerate:
|
if not force_regenerate:
|
||||||
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
||||||
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
||||||
|
|
@ -84,7 +84,7 @@ def generate_report():
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
# 获取项目信息
|
# Get project info
|
||||||
project = ProjectManager.get_project(state.project_id)
|
project = ProjectManager.get_project(state.project_id)
|
||||||
if not project:
|
if not project:
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -106,11 +106,11 @@ def generate_report():
|
||||||
"error": t('api.missingSimRequirement')
|
"error": t('api.missingSimRequirement')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 提前生成 report_id,以便立即返回给前端
|
# Pre-generate report_id so it can be returned immediately
|
||||||
import uuid
|
import uuid
|
||||||
report_id = f"report_{uuid.uuid4().hex[:12]}"
|
report_id = f"report_{uuid.uuid4().hex[:12]}"
|
||||||
|
|
||||||
# 创建异步任务
|
# Create async task
|
||||||
task_manager = TaskManager()
|
task_manager = TaskManager()
|
||||||
task_id = task_manager.create_task(
|
task_id = task_manager.create_task(
|
||||||
task_type="report_generate",
|
task_type="report_generate",
|
||||||
|
|
@ -124,7 +124,7 @@ def generate_report():
|
||||||
# Capture locale before spawning background thread
|
# Capture locale before spawning background thread
|
||||||
current_locale = get_locale()
|
current_locale = get_locale()
|
||||||
|
|
||||||
# 定义后台任务
|
# Define background task
|
||||||
def run_generate():
|
def run_generate():
|
||||||
set_locale(current_locale)
|
set_locale(current_locale)
|
||||||
try:
|
try:
|
||||||
|
|
@ -135,14 +135,14 @@ def generate_report():
|
||||||
message=t('api.initReportAgent')
|
message=t('api.initReportAgent')
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建Report Agent
|
# Create Report Agent
|
||||||
agent = ReportAgent(
|
agent = ReportAgent(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
simulation_requirement=simulation_requirement
|
simulation_requirement=simulation_requirement
|
||||||
)
|
)
|
||||||
|
|
||||||
# 进度回调
|
# Progress callback
|
||||||
def progress_callback(stage, progress, message):
|
def progress_callback(stage, progress, message):
|
||||||
task_manager.update_task(
|
task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
|
|
@ -150,13 +150,13 @@ def generate_report():
|
||||||
message=f"[{stage}] {message}"
|
message=f"[{stage}] {message}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 生成报告(传入预先生成的 report_id)
|
# Generate report (pass pre-generated report_id)
|
||||||
report = agent.generate_report(
|
report = agent.generate_report(
|
||||||
progress_callback=progress_callback,
|
progress_callback=progress_callback,
|
||||||
report_id=report_id
|
report_id=report_id
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存报告
|
# Save report
|
||||||
ReportManager.save_report(report)
|
ReportManager.save_report(report)
|
||||||
|
|
||||||
if report.status == ReportStatus.COMPLETED:
|
if report.status == ReportStatus.COMPLETED:
|
||||||
|
|
@ -172,10 +172,10 @@ def generate_report():
|
||||||
task_manager.fail_task(task_id, report.error or t('api.reportGenerateFailed'))
|
task_manager.fail_task(task_id, report.error or t('api.reportGenerateFailed'))
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"报告生成失败: {str(e)}")
|
logger.error(f"Report generation failed: {str(e)}")
|
||||||
task_manager.fail_task(task_id, str(e))
|
task_manager.fail_task(task_id, str(e))
|
||||||
|
|
||||||
# 启动后台线程
|
# Start background thread
|
||||||
thread = threading.Thread(target=run_generate, daemon=True)
|
thread = threading.Thread(target=run_generate, daemon=True)
|
||||||
thread.start()
|
thread.start()
|
||||||
|
|
||||||
|
|
@ -192,7 +192,7 @@ def generate_report():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"启动报告生成任务失败: {str(e)}")
|
logger.error(f"Failed to start report generation task: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -203,15 +203,15 @@ def generate_report():
|
||||||
@report_bp.route('/generate/status', methods=['POST'])
|
@report_bp.route('/generate/status', methods=['POST'])
|
||||||
def get_generate_status():
|
def get_generate_status():
|
||||||
"""
|
"""
|
||||||
查询报告生成任务进度
|
Query report generation task progress
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"task_id": "task_xxxx", // 可选,generate返回的task_id
|
"task_id": "task_xxxx", // optional, task_id from generate
|
||||||
"simulation_id": "sim_xxxx" // 可选,模拟ID
|
"simulation_id": "sim_xxxx" // optional, simulation ID
|
||||||
}
|
}
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -228,7 +228,7 @@ def get_generate_status():
|
||||||
task_id = data.get('task_id')
|
task_id = data.get('task_id')
|
||||||
simulation_id = data.get('simulation_id')
|
simulation_id = data.get('simulation_id')
|
||||||
|
|
||||||
# 如果提供了simulation_id,先检查是否已有完成的报告
|
# If simulation_id is provided, check whether a completed report exists
|
||||||
if simulation_id:
|
if simulation_id:
|
||||||
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
||||||
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
||||||
|
|
@ -265,21 +265,21 @@ def get_generate_status():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"查询任务状态失败: {str(e)}")
|
logger.error(f"Failed to query task status: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e)
|
"error": str(e)
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 报告获取接口 ==============
|
# ============== Report retrieval endpoints ==============
|
||||||
|
|
||||||
@report_bp.route('/<report_id>', methods=['GET'])
|
@report_bp.route('/<report_id>', methods=['GET'])
|
||||||
def get_report(report_id: str):
|
def get_report(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取报告详情
|
Get report details
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -308,7 +308,7 @@ def get_report(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取报告失败: {str(e)}")
|
logger.error(f"Failed to get report: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -319,9 +319,9 @@ def get_report(report_id: str):
|
||||||
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
|
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
|
||||||
def get_report_by_simulation(simulation_id: str):
|
def get_report_by_simulation(simulation_id: str):
|
||||||
"""
|
"""
|
||||||
根据模拟ID获取报告
|
Get report by simulation ID
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -347,7 +347,7 @@ def get_report_by_simulation(simulation_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取报告失败: {str(e)}")
|
logger.error(f"Failed to get report: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -358,13 +358,13 @@ def get_report_by_simulation(simulation_id: str):
|
||||||
@report_bp.route('/list', methods=['GET'])
|
@report_bp.route('/list', methods=['GET'])
|
||||||
def list_reports():
|
def list_reports():
|
||||||
"""
|
"""
|
||||||
列出所有报告
|
List all reports
|
||||||
|
|
||||||
Query参数:
|
Query parameters:
|
||||||
simulation_id: 按模拟ID过滤(可选)
|
simulation_id: filter by simulation ID (optional)
|
||||||
limit: 返回数量限制(默认50)
|
limit: result count limit (default 50)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": [...],
|
"data": [...],
|
||||||
|
|
@ -387,7 +387,7 @@ def list_reports():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"列出报告失败: {str(e)}")
|
logger.error(f"Failed to list reports: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -398,9 +398,9 @@ def list_reports():
|
||||||
@report_bp.route('/<report_id>/download', methods=['GET'])
|
@report_bp.route('/<report_id>/download', methods=['GET'])
|
||||||
def download_report(report_id: str):
|
def download_report(report_id: str):
|
||||||
"""
|
"""
|
||||||
下载报告(Markdown格式)
|
Download report (Markdown format)
|
||||||
|
|
||||||
返回Markdown文件
|
Returns a Markdown file
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
report = ReportManager.get_report(report_id)
|
report = ReportManager.get_report(report_id)
|
||||||
|
|
@ -414,7 +414,7 @@ def download_report(report_id: str):
|
||||||
md_path = ReportManager._get_report_markdown_path(report_id)
|
md_path = ReportManager._get_report_markdown_path(report_id)
|
||||||
|
|
||||||
if not os.path.exists(md_path):
|
if not os.path.exists(md_path):
|
||||||
# 如果MD文件不存在,生成一个临时文件
|
# If MD file doesn't exist, create a temporary file
|
||||||
import tempfile
|
import tempfile
|
||||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
|
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
|
||||||
f.write(report.markdown_content)
|
f.write(report.markdown_content)
|
||||||
|
|
@ -433,7 +433,7 @@ def download_report(report_id: str):
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"下载报告失败: {str(e)}")
|
logger.error(f"Failed to download report: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -443,7 +443,7 @@ def download_report(report_id: str):
|
||||||
|
|
||||||
@report_bp.route('/<report_id>', methods=['DELETE'])
|
@report_bp.route('/<report_id>', methods=['DELETE'])
|
||||||
def delete_report(report_id: str):
|
def delete_report(report_id: str):
|
||||||
"""删除报告"""
|
"""Delete a report"""
|
||||||
try:
|
try:
|
||||||
success = ReportManager.delete_report(report_id)
|
success = ReportManager.delete_report(report_id)
|
||||||
|
|
||||||
|
|
@ -459,7 +459,7 @@ def delete_report(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"删除报告失败: {str(e)}")
|
logger.error(f"Failed to delete report: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -467,32 +467,32 @@ def delete_report(report_id: str):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== Report Agent对话接口 ==============
|
# ============== Report Agent chat endpoint ==============
|
||||||
|
|
||||||
@report_bp.route('/chat', methods=['POST'])
|
@report_bp.route('/chat', methods=['POST'])
|
||||||
def chat_with_report_agent():
|
def chat_with_report_agent():
|
||||||
"""
|
"""
|
||||||
与Report Agent对话
|
Chat with the Report Agent
|
||||||
|
|
||||||
Report Agent可以在对话中自主调用检索工具来回答问题
|
The Report Agent can autonomously call retrieval tools to answer questions.
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
"simulation_id": "sim_xxxx", // required, simulation ID
|
||||||
"message": "请解释一下舆情走向", // 必填,用户消息
|
"message": "Please explain the public opinion trend", // required, user message
|
||||||
"chat_history": [ // 可选,对话历史
|
"chat_history": [ // optional, conversation history
|
||||||
{"role": "user", "content": "..."},
|
{"role": "user", "content": "..."},
|
||||||
{"role": "assistant", "content": "..."}
|
{"role": "assistant", "content": "..."}
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"response": "Agent回复...",
|
"response": "Agent reply...",
|
||||||
"tool_calls": [调用的工具列表],
|
"tool_calls": [list of tools called],
|
||||||
"sources": [信息来源]
|
"sources": [information sources]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
@ -515,7 +515,7 @@ def chat_with_report_agent():
|
||||||
"error": t('api.requireMessage')
|
"error": t('api.requireMessage')
|
||||||
}), 400
|
}), 400
|
||||||
|
|
||||||
# 获取模拟和项目信息
|
# Get simulation and project info
|
||||||
manager = SimulationManager()
|
manager = SimulationManager()
|
||||||
state = manager.get_simulation(simulation_id)
|
state = manager.get_simulation(simulation_id)
|
||||||
|
|
||||||
|
|
@ -541,7 +541,7 @@ def chat_with_report_agent():
|
||||||
|
|
||||||
simulation_requirement = project.simulation_requirement or ""
|
simulation_requirement = project.simulation_requirement or ""
|
||||||
|
|
||||||
# 创建Agent并进行对话
|
# Create agent and start chat
|
||||||
agent = ReportAgent(
|
agent = ReportAgent(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
simulation_id=simulation_id,
|
simulation_id=simulation_id,
|
||||||
|
|
@ -556,7 +556,7 @@ def chat_with_report_agent():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"对话失败: {str(e)}")
|
logger.error(f"Chat failed: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -564,22 +564,22 @@ def chat_with_report_agent():
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 报告进度与分章节接口 ==============
|
# ============== Report progress and section endpoints ==============
|
||||||
|
|
||||||
@report_bp.route('/<report_id>/progress', methods=['GET'])
|
@report_bp.route('/<report_id>/progress', methods=['GET'])
|
||||||
def get_report_progress(report_id: str):
|
def get_report_progress(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取报告生成进度(实时)
|
Get report generation progress (real-time)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"status": "generating",
|
"status": "generating",
|
||||||
"progress": 45,
|
"progress": 45,
|
||||||
"message": "正在生成章节: 关键发现",
|
"message": "Generating section: Key Findings",
|
||||||
"current_section": "关键发现",
|
"current_section": "Key Findings",
|
||||||
"completed_sections": ["执行摘要", "模拟背景"],
|
"completed_sections": ["Executive Summary", "Simulation Background"],
|
||||||
"updated_at": "2025-12-09T..."
|
"updated_at": "2025-12-09T..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -599,7 +599,7 @@ def get_report_progress(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取报告进度失败: {str(e)}")
|
logger.error(f"Failed to get report progress: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -610,11 +610,12 @@ def get_report_progress(report_id: str):
|
||||||
@report_bp.route('/<report_id>/sections', methods=['GET'])
|
@report_bp.route('/<report_id>/sections', methods=['GET'])
|
||||||
def get_report_sections(report_id: str):
|
def get_report_sections(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取已生成的章节列表(分章节输出)
|
Get list of already-generated sections (section-by-section output)
|
||||||
|
|
||||||
前端可以轮询此接口获取已生成的章节内容,无需等待整个报告完成
|
The frontend can poll this endpoint to get section content as it is generated,
|
||||||
|
without waiting for the full report to complete.
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -623,7 +624,7 @@ def get_report_sections(report_id: str):
|
||||||
{
|
{
|
||||||
"filename": "section_01.md",
|
"filename": "section_01.md",
|
||||||
"section_index": 1,
|
"section_index": 1,
|
||||||
"content": "## 执行摘要\\n\\n..."
|
"content": "## Executive Summary\\n\\n..."
|
||||||
},
|
},
|
||||||
...
|
...
|
||||||
],
|
],
|
||||||
|
|
@ -635,7 +636,7 @@ def get_report_sections(report_id: str):
|
||||||
try:
|
try:
|
||||||
sections = ReportManager.get_generated_sections(report_id)
|
sections = ReportManager.get_generated_sections(report_id)
|
||||||
|
|
||||||
# 获取报告状态
|
# Get report status
|
||||||
report = ReportManager.get_report(report_id)
|
report = ReportManager.get_report(report_id)
|
||||||
is_complete = report is not None and report.status == ReportStatus.COMPLETED
|
is_complete = report is not None and report.status == ReportStatus.COMPLETED
|
||||||
|
|
||||||
|
|
@ -650,7 +651,7 @@ def get_report_sections(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取章节列表失败: {str(e)}")
|
logger.error(f"Failed to get section list: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -661,14 +662,14 @@ def get_report_sections(report_id: str):
|
||||||
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
|
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
|
||||||
def get_single_section(report_id: str, section_index: int):
|
def get_single_section(report_id: str, section_index: int):
|
||||||
"""
|
"""
|
||||||
获取单个章节内容
|
Get single section content
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"filename": "section_01.md",
|
"filename": "section_01.md",
|
||||||
"content": "## 执行摘要\\n\\n..."
|
"content": "## Executive Summary\\n\\n..."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
@ -694,7 +695,7 @@ def get_single_section(report_id: str, section_index: int):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取章节内容失败: {str(e)}")
|
logger.error(f"Failed to get section content: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -702,16 +703,16 @@ def get_single_section(report_id: str, section_index: int):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 报告状态检查接口 ==============
|
# ============== Report status check endpoint ==============
|
||||||
|
|
||||||
@report_bp.route('/check/<simulation_id>', methods=['GET'])
|
@report_bp.route('/check/<simulation_id>', methods=['GET'])
|
||||||
def check_report_status(simulation_id: str):
|
def check_report_status(simulation_id: str):
|
||||||
"""
|
"""
|
||||||
检查模拟是否有报告,以及报告状态
|
Check whether a simulation has a report and its status
|
||||||
|
|
||||||
用于前端判断是否解锁Interview功能
|
Used by the frontend to determine whether to unlock the Interview feature.
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -730,7 +731,7 @@ def check_report_status(simulation_id: str):
|
||||||
report_status = report.status.value if report else None
|
report_status = report.status.value if report else None
|
||||||
report_id = report.report_id if report else None
|
report_id = report.report_id if report else None
|
||||||
|
|
||||||
# 只有报告完成后才解锁interview
|
# Interview is unlocked only after the report is complete
|
||||||
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
|
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
|
||||||
|
|
||||||
return jsonify({
|
return jsonify({
|
||||||
|
|
@ -745,7 +746,7 @@ def check_report_status(simulation_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"检查报告状态失败: {str(e)}")
|
logger.error(f"Failed to check report status: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -753,22 +754,22 @@ def check_report_status(simulation_id: str):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== Agent 日志接口 ==============
|
# ============== Agent log endpoints ==============
|
||||||
|
|
||||||
@report_bp.route('/<report_id>/agent-log', methods=['GET'])
|
@report_bp.route('/<report_id>/agent-log', methods=['GET'])
|
||||||
def get_agent_log(report_id: str):
|
def get_agent_log(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取 Report Agent 的详细执行日志
|
Get detailed execution log of the Report Agent
|
||||||
|
|
||||||
实时获取报告生成过程中的每一步动作,包括:
|
Retrieves step-by-step actions during report generation, including:
|
||||||
- 报告开始、规划开始/完成
|
- Report start, planning start/complete
|
||||||
- 每个章节的开始、工具调用、LLM响应、完成
|
- Each section's start, tool calls, LLM response, completion
|
||||||
- 报告完成或失败
|
- Report completion or failure
|
||||||
|
|
||||||
Query参数:
|
Query parameters:
|
||||||
from_line: 从第几行开始读取(可选,默认0,用于增量获取)
|
from_line: start reading from this line (optional, default 0, for incremental fetch)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -779,7 +780,7 @@ def get_agent_log(report_id: str):
|
||||||
"report_id": "report_xxxx",
|
"report_id": "report_xxxx",
|
||||||
"action": "tool_call",
|
"action": "tool_call",
|
||||||
"stage": "generating",
|
"stage": "generating",
|
||||||
"section_title": "执行摘要",
|
"section_title": "Executive Summary",
|
||||||
"section_index": 1,
|
"section_index": 1,
|
||||||
"details": {
|
"details": {
|
||||||
"tool_name": "insight_forge",
|
"tool_name": "insight_forge",
|
||||||
|
|
@ -806,7 +807,7 @@ def get_agent_log(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取Agent日志失败: {str(e)}")
|
logger.error(f"Failed to get Agent log: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -817,9 +818,9 @@ def get_agent_log(report_id: str):
|
||||||
@report_bp.route('/<report_id>/agent-log/stream', methods=['GET'])
|
@report_bp.route('/<report_id>/agent-log/stream', methods=['GET'])
|
||||||
def stream_agent_log(report_id: str):
|
def stream_agent_log(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取完整的 Agent 日志(一次性获取全部)
|
Get the full Agent log (fetch all at once)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -840,7 +841,7 @@ def stream_agent_log(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取Agent日志失败: {str(e)}")
|
logger.error(f"Failed to get Agent log: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -848,27 +849,27 @@ def stream_agent_log(report_id: str):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 控制台日志接口 ==============
|
# ============== Console log endpoints ==============
|
||||||
|
|
||||||
@report_bp.route('/<report_id>/console-log', methods=['GET'])
|
@report_bp.route('/<report_id>/console-log', methods=['GET'])
|
||||||
def get_console_log(report_id: str):
|
def get_console_log(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取 Report Agent 的控制台输出日志
|
Get the console output log of the Report Agent
|
||||||
|
|
||||||
实时获取报告生成过程中的控制台输出(INFO、WARNING等),
|
Returns real-time console output (INFO, WARNING, etc.) during report generation.
|
||||||
这与 agent-log 接口返回的结构化 JSON 日志不同,
|
Unlike the agent-log endpoint which returns structured JSON logs,
|
||||||
是纯文本格式的控制台风格日志。
|
this returns plain-text console-style logs.
|
||||||
|
|
||||||
Query参数:
|
Query parameters:
|
||||||
from_line: 从第几行开始读取(可选,默认0,用于增量获取)
|
from_line: start reading from this line (optional, default 0, for incremental fetch)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
"logs": [
|
"logs": [
|
||||||
"[19:46:14] INFO: 搜索完成: 找到 15 条相关事实",
|
"[19:46:14] INFO: Search complete: found 15 relevant facts",
|
||||||
"[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...",
|
"[19:46:14] INFO: Graph search: graph_id=xxx, query=...",
|
||||||
...
|
...
|
||||||
],
|
],
|
||||||
"total_lines": 100,
|
"total_lines": 100,
|
||||||
|
|
@ -888,7 +889,7 @@ def get_console_log(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取控制台日志失败: {str(e)}")
|
logger.error(f"Failed to get console log: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -899,9 +900,9 @@ def get_console_log(report_id: str):
|
||||||
@report_bp.route('/<report_id>/console-log/stream', methods=['GET'])
|
@report_bp.route('/<report_id>/console-log/stream', methods=['GET'])
|
||||||
def stream_console_log(report_id: str):
|
def stream_console_log(report_id: str):
|
||||||
"""
|
"""
|
||||||
获取完整的控制台日志(一次性获取全部)
|
Get the full console log (fetch all at once)
|
||||||
|
|
||||||
返回:
|
Returns:
|
||||||
{
|
{
|
||||||
"success": true,
|
"success": true,
|
||||||
"data": {
|
"data": {
|
||||||
|
|
@ -922,7 +923,7 @@ def stream_console_log(report_id: str):
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取控制台日志失败: {str(e)}")
|
logger.error(f"Failed to get console log: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -930,17 +931,17 @@ def stream_console_log(report_id: str):
|
||||||
}), 500
|
}), 500
|
||||||
|
|
||||||
|
|
||||||
# ============== 工具调用接口(供调试使用)==============
|
# ============== Tool call endpoints (for debugging) ==============
|
||||||
|
|
||||||
@report_bp.route('/tools/search', methods=['POST'])
|
@report_bp.route('/tools/search', methods=['POST'])
|
||||||
def search_graph_tool():
|
def search_graph_tool():
|
||||||
"""
|
"""
|
||||||
图谱搜索工具接口(供调试使用)
|
Graph search tool endpoint (for debugging)
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"graph_id": "mirofish_xxxx",
|
"graph_id": "mirofish_xxxx",
|
||||||
"query": "搜索查询",
|
"query": "search query",
|
||||||
"limit": 10
|
"limit": 10
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
|
|
@ -972,7 +973,7 @@ def search_graph_tool():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"图谱搜索失败: {str(e)}")
|
logger.error(f"Graph search failed: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
@ -983,9 +984,9 @@ def search_graph_tool():
|
||||||
@report_bp.route('/tools/statistics', methods=['POST'])
|
@report_bp.route('/tools/statistics', methods=['POST'])
|
||||||
def get_graph_statistics_tool():
|
def get_graph_statistics_tool():
|
||||||
"""
|
"""
|
||||||
图谱统计工具接口(供调试使用)
|
Graph statistics tool endpoint (for debugging)
|
||||||
|
|
||||||
请求(JSON):
|
Request (JSON):
|
||||||
{
|
{
|
||||||
"graph_id": "mirofish_xxxx"
|
"graph_id": "mirofish_xxxx"
|
||||||
}
|
}
|
||||||
|
|
@ -1012,7 +1013,7 @@ def get_graph_statistics_tool():
|
||||||
})
|
})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取图谱统计失败: {str(e)}")
|
logger.error(f"Failed to get graph statistics: {str(e)}")
|
||||||
return jsonify({
|
return jsonify({
|
||||||
"success": False,
|
"success": False,
|
||||||
"error": str(e),
|
"error": str(e),
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,55 +1,55 @@
|
||||||
"""
|
"""
|
||||||
配置管理
|
Configuration management
|
||||||
统一从项目根目录的 .env 文件加载配置
|
Loads config uniformly from the .env file at the project root
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# 加载项目根目录的 .env 文件
|
# Load the .env file from the project root
|
||||||
# 路径: MiroFish/.env (相对于 backend/app/config.py)
|
# Path: MiroFish/.env (relative to backend/app/config.py)
|
||||||
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
|
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
|
||||||
|
|
||||||
if os.path.exists(project_root_env):
|
if os.path.exists(project_root_env):
|
||||||
load_dotenv(project_root_env, override=True)
|
load_dotenv(project_root_env, override=True)
|
||||||
else:
|
else:
|
||||||
# 如果根目录没有 .env,尝试加载环境变量(用于生产环境)
|
# If no root-level .env file found, load from environment variables (production)
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Flask配置类"""
|
"""Flask configuration class"""
|
||||||
|
|
||||||
# Flask配置
|
# Flask settings
|
||||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
||||||
DEMO_PASSWORD = os.environ.get('DEMO_PASSWORD', '')
|
DEMO_PASSWORD = os.environ.get('DEMO_PASSWORD', '')
|
||||||
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
||||||
|
|
||||||
# JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式)
|
# JSON settings - disable ASCII escaping so non-ASCII chars are output directly (not as \uXXXX)
|
||||||
JSON_AS_ASCII = False
|
JSON_AS_ASCII = False
|
||||||
|
|
||||||
# LLM配置(统一使用OpenAI格式)
|
# LLM settings (unified OpenAI-compatible format)
|
||||||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||||
|
|
||||||
# Zep配置
|
# Zep settings
|
||||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||||
|
|
||||||
# 文件上传配置
|
# File upload settings
|
||||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
|
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
|
||||||
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
|
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
|
||||||
|
|
||||||
# 文本处理配置
|
# Text processing settings
|
||||||
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
DEFAULT_CHUNK_SIZE = 500 # default chunk size
|
||||||
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
DEFAULT_CHUNK_OVERLAP = 50 # default overlap size
|
||||||
|
|
||||||
# OASIS模拟配置
|
# OASIS simulation settings
|
||||||
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
||||||
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
||||||
|
|
||||||
# OASIS平台可用动作配置
|
# OASIS platform available actions
|
||||||
OASIS_TWITTER_ACTIONS = [
|
OASIS_TWITTER_ACTIONS = [
|
||||||
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
||||||
]
|
]
|
||||||
|
|
@ -59,18 +59,18 @@ class Config:
|
||||||
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Report Agent配置
|
# Report Agent settings
|
||||||
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
||||||
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
||||||
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls):
|
def validate(cls):
|
||||||
"""验证必要配置"""
|
"""Validate required configuration"""
|
||||||
errors = []
|
errors = []
|
||||||
if not cls.LLM_API_KEY:
|
if not cls.LLM_API_KEY:
|
||||||
errors.append("LLM_API_KEY 未配置")
|
errors.append("LLM_API_KEY is not configured")
|
||||||
if not cls.ZEP_API_KEY:
|
if not cls.ZEP_API_KEY:
|
||||||
errors.append("ZEP_API_KEY 未配置")
|
errors.append("ZEP_API_KEY is not configured")
|
||||||
return errors
|
return errors
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
数据模型模块
|
Data models module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .task import TaskManager, TaskStatus
|
from .task import TaskManager, TaskStatus
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
项目上下文管理
|
Project context management
|
||||||
用于在服务端持久化项目状态,避免前端在接口间传递大量数据
|
Persists project state server-side so the frontend does not need to pass large amounts of data between endpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -15,45 +15,45 @@ from ..config import Config
|
||||||
|
|
||||||
|
|
||||||
class ProjectStatus(str, Enum):
|
class ProjectStatus(str, Enum):
|
||||||
"""项目状态"""
|
"""Project status"""
|
||||||
CREATED = "created" # 刚创建,文件已上传
|
CREATED = "created" # Just created; files uploaded
|
||||||
ONTOLOGY_GENERATED = "ontology_generated" # 本体已生成
|
ONTOLOGY_GENERATED = "ontology_generated" # Ontology generated
|
||||||
GRAPH_BUILDING = "graph_building" # 图谱构建中
|
GRAPH_BUILDING = "graph_building" # Graph building in progress
|
||||||
GRAPH_COMPLETED = "graph_completed" # 图谱构建完成
|
GRAPH_COMPLETED = "graph_completed" # Graph build complete
|
||||||
FAILED = "failed" # 失败
|
FAILED = "failed" # Failed
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Project:
|
class Project:
|
||||||
"""项目数据模型"""
|
"""Project data model"""
|
||||||
project_id: str
|
project_id: str
|
||||||
name: str
|
name: str
|
||||||
status: ProjectStatus
|
status: ProjectStatus
|
||||||
created_at: str
|
created_at: str
|
||||||
updated_at: str
|
updated_at: str
|
||||||
|
|
||||||
# 文件信息
|
# File info
|
||||||
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
|
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
|
||||||
total_text_length: int = 0
|
total_text_length: int = 0
|
||||||
|
|
||||||
# 本体信息(接口1生成后填充)
|
# Ontology info (populated after endpoint 1)
|
||||||
ontology: Optional[Dict[str, Any]] = None
|
ontology: Optional[Dict[str, Any]] = None
|
||||||
analysis_summary: Optional[str] = None
|
analysis_summary: Optional[str] = None
|
||||||
|
|
||||||
# 图谱信息(接口2完成后填充)
|
# Graph info (populated after endpoint 2 completes)
|
||||||
graph_id: Optional[str] = None
|
graph_id: Optional[str] = None
|
||||||
graph_build_task_id: Optional[str] = None
|
graph_build_task_id: Optional[str] = None
|
||||||
|
|
||||||
# 配置
|
# Configuration
|
||||||
simulation_requirement: Optional[str] = None
|
simulation_requirement: Optional[str] = None
|
||||||
chunk_size: int = 500
|
chunk_size: int = 500
|
||||||
chunk_overlap: int = 50
|
chunk_overlap: int = 50
|
||||||
|
|
||||||
# 错误信息
|
# Error info
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典"""
|
"""Convert to dictionary"""
|
||||||
return {
|
return {
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
|
|
@ -74,7 +74,7 @@ class Project:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_dict(cls, data: Dict[str, Any]) -> 'Project':
|
def from_dict(cls, data: Dict[str, Any]) -> 'Project':
|
||||||
"""从字典创建"""
|
"""Create from dictionary"""
|
||||||
status = data.get('status', 'created')
|
status = data.get('status', 'created')
|
||||||
if isinstance(status, str):
|
if isinstance(status, str):
|
||||||
status = ProjectStatus(status)
|
status = ProjectStatus(status)
|
||||||
|
|
@ -99,46 +99,46 @@ class Project:
|
||||||
|
|
||||||
|
|
||||||
class ProjectManager:
|
class ProjectManager:
|
||||||
"""项目管理器 - 负责项目的持久化存储和检索"""
|
"""Project manager - handles persistent storage and retrieval of projects"""
|
||||||
|
|
||||||
# 项目存储根目录
|
# Root directory for project storage
|
||||||
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
|
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _ensure_projects_dir(cls):
|
def _ensure_projects_dir(cls):
|
||||||
"""确保项目目录存在"""
|
"""Ensure the projects directory exists"""
|
||||||
os.makedirs(cls.PROJECTS_DIR, exist_ok=True)
|
os.makedirs(cls.PROJECTS_DIR, exist_ok=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_project_dir(cls, project_id: str) -> str:
|
def _get_project_dir(cls, project_id: str) -> str:
|
||||||
"""获取项目目录路径"""
|
"""Get project directory path"""
|
||||||
return os.path.join(cls.PROJECTS_DIR, project_id)
|
return os.path.join(cls.PROJECTS_DIR, project_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_project_meta_path(cls, project_id: str) -> str:
|
def _get_project_meta_path(cls, project_id: str) -> str:
|
||||||
"""获取项目元数据文件路径"""
|
"""Get project metadata file path"""
|
||||||
return os.path.join(cls._get_project_dir(project_id), 'project.json')
|
return os.path.join(cls._get_project_dir(project_id), 'project.json')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_project_files_dir(cls, project_id: str) -> str:
|
def _get_project_files_dir(cls, project_id: str) -> str:
|
||||||
"""获取项目文件存储目录"""
|
"""Get project files storage directory"""
|
||||||
return os.path.join(cls._get_project_dir(project_id), 'files')
|
return os.path.join(cls._get_project_dir(project_id), 'files')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_project_text_path(cls, project_id: str) -> str:
|
def _get_project_text_path(cls, project_id: str) -> str:
|
||||||
"""获取项目提取文本存储路径"""
|
"""Get path for storing the extracted project text"""
|
||||||
return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt')
|
return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt')
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_project(cls, name: str = "Unnamed Project") -> Project:
|
def create_project(cls, name: str = "Unnamed Project") -> Project:
|
||||||
"""
|
"""
|
||||||
创建新项目
|
Create a new project.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 项目名称
|
name: project name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
新创建的Project对象
|
newly created Project object
|
||||||
"""
|
"""
|
||||||
cls._ensure_projects_dir()
|
cls._ensure_projects_dir()
|
||||||
|
|
||||||
|
|
@ -153,20 +153,20 @@ class ProjectManager:
|
||||||
updated_at=now
|
updated_at=now
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建项目目录结构
|
# Create project directory structure
|
||||||
project_dir = cls._get_project_dir(project_id)
|
project_dir = cls._get_project_dir(project_id)
|
||||||
files_dir = cls._get_project_files_dir(project_id)
|
files_dir = cls._get_project_files_dir(project_id)
|
||||||
os.makedirs(project_dir, exist_ok=True)
|
os.makedirs(project_dir, exist_ok=True)
|
||||||
os.makedirs(files_dir, exist_ok=True)
|
os.makedirs(files_dir, exist_ok=True)
|
||||||
|
|
||||||
# 保存项目元数据
|
# Save project metadata
|
||||||
cls.save_project(project)
|
cls.save_project(project)
|
||||||
|
|
||||||
return project
|
return project
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_project(cls, project: Project) -> None:
|
def save_project(cls, project: Project) -> None:
|
||||||
"""保存项目元数据"""
|
"""Save project metadata"""
|
||||||
project.updated_at = datetime.now().isoformat()
|
project.updated_at = datetime.now().isoformat()
|
||||||
meta_path = cls._get_project_meta_path(project.project_id)
|
meta_path = cls._get_project_meta_path(project.project_id)
|
||||||
|
|
||||||
|
|
@ -176,13 +176,13 @@ class ProjectManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_project(cls, project_id: str) -> Optional[Project]:
|
def get_project(cls, project_id: str) -> Optional[Project]:
|
||||||
"""
|
"""
|
||||||
获取项目
|
Get a project.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: project ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Project对象,如果不存在返回None
|
Project object, or None if not found
|
||||||
"""
|
"""
|
||||||
meta_path = cls._get_project_meta_path(project_id)
|
meta_path = cls._get_project_meta_path(project_id)
|
||||||
|
|
||||||
|
|
@ -197,13 +197,13 @@ class ProjectManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def list_projects(cls, limit: int = 50) -> List[Project]:
|
def list_projects(cls, limit: int = 50) -> List[Project]:
|
||||||
"""
|
"""
|
||||||
列出所有项目
|
List all projects.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
limit: 返回数量限制
|
limit: result count limit
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
项目列表,按创建时间倒序
|
list of projects sorted by creation time, descending
|
||||||
"""
|
"""
|
||||||
cls._ensure_projects_dir()
|
cls._ensure_projects_dir()
|
||||||
|
|
||||||
|
|
@ -213,7 +213,7 @@ class ProjectManager:
|
||||||
if project:
|
if project:
|
||||||
projects.append(project)
|
projects.append(project)
|
||||||
|
|
||||||
# 按创建时间倒序排序
|
# Sort by creation time, descending
|
||||||
projects.sort(key=lambda p: p.created_at, reverse=True)
|
projects.sort(key=lambda p: p.created_at, reverse=True)
|
||||||
|
|
||||||
return projects[:limit]
|
return projects[:limit]
|
||||||
|
|
@ -221,13 +221,13 @@ class ProjectManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def delete_project(cls, project_id: str) -> bool:
|
def delete_project(cls, project_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
删除项目及其所有文件
|
Delete a project and all its files.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: project ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
是否删除成功
|
True if successfully deleted
|
||||||
"""
|
"""
|
||||||
project_dir = cls._get_project_dir(project_id)
|
project_dir = cls._get_project_dir(project_id)
|
||||||
|
|
||||||
|
|
@ -240,28 +240,28 @@ class ProjectManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
|
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
保存上传的文件到项目目录
|
Save an uploaded file to the project directory.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: project ID
|
||||||
file_storage: Flask的FileStorage对象
|
file_storage: Flask FileStorage object
|
||||||
original_filename: 原始文件名
|
original_filename: original filename
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文件信息字典 {filename, path, size}
|
file info dict {filename, path, size}
|
||||||
"""
|
"""
|
||||||
files_dir = cls._get_project_files_dir(project_id)
|
files_dir = cls._get_project_files_dir(project_id)
|
||||||
os.makedirs(files_dir, exist_ok=True)
|
os.makedirs(files_dir, exist_ok=True)
|
||||||
|
|
||||||
# 生成安全的文件名
|
# Generate a safe filename
|
||||||
ext = os.path.splitext(original_filename)[1].lower()
|
ext = os.path.splitext(original_filename)[1].lower()
|
||||||
safe_filename = f"{uuid.uuid4().hex[:8]}{ext}"
|
safe_filename = f"{uuid.uuid4().hex[:8]}{ext}"
|
||||||
file_path = os.path.join(files_dir, safe_filename)
|
file_path = os.path.join(files_dir, safe_filename)
|
||||||
|
|
||||||
# 保存文件
|
# Save file
|
||||||
file_storage.save(file_path)
|
file_storage.save(file_path)
|
||||||
|
|
||||||
# 获取文件大小
|
# Get file size
|
||||||
file_size = os.path.getsize(file_path)
|
file_size = os.path.getsize(file_path)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
|
|
@ -273,14 +273,14 @@ class ProjectManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def save_extracted_text(cls, project_id: str, text: str) -> None:
|
def save_extracted_text(cls, project_id: str, text: str) -> None:
|
||||||
"""保存提取的文本"""
|
"""Save extracted text"""
|
||||||
text_path = cls._get_project_text_path(project_id)
|
text_path = cls._get_project_text_path(project_id)
|
||||||
with open(text_path, 'w', encoding='utf-8') as f:
|
with open(text_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(text)
|
f.write(text)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_extracted_text(cls, project_id: str) -> Optional[str]:
|
def get_extracted_text(cls, project_id: str) -> Optional[str]:
|
||||||
"""获取提取的文本"""
|
"""Get extracted text"""
|
||||||
text_path = cls._get_project_text_path(project_id)
|
text_path = cls._get_project_text_path(project_id)
|
||||||
|
|
||||||
if not os.path.exists(text_path):
|
if not os.path.exists(text_path):
|
||||||
|
|
@ -291,7 +291,7 @@ class ProjectManager:
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_project_files(cls, project_id: str) -> List[str]:
|
def get_project_files(cls, project_id: str) -> List[str]:
|
||||||
"""获取项目的所有文件路径"""
|
"""Get all file paths for a project"""
|
||||||
files_dir = cls._get_project_files_dir(project_id)
|
files_dir = cls._get_project_files_dir(project_id)
|
||||||
|
|
||||||
if not os.path.exists(files_dir):
|
if not os.path.exists(files_dir):
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
任务状态管理
|
Task state management
|
||||||
用于跟踪长时间运行的任务(如图谱构建)
|
Used to track long-running tasks (e.g. graph building).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
@ -14,30 +14,30 @@ from ..utils.locale import t
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
class TaskStatus(str, Enum):
|
||||||
"""任务状态枚举"""
|
"""Task status enum"""
|
||||||
PENDING = "pending" # 等待中
|
PENDING = "pending" # Waiting
|
||||||
PROCESSING = "processing" # 处理中
|
PROCESSING = "processing" # In progress
|
||||||
COMPLETED = "completed" # 已完成
|
COMPLETED = "completed" # Completed
|
||||||
FAILED = "failed" # 失败
|
FAILED = "failed" # Failed
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Task:
|
class Task:
|
||||||
"""任务数据类"""
|
"""Task data class"""
|
||||||
task_id: str
|
task_id: str
|
||||||
task_type: str
|
task_type: str
|
||||||
status: TaskStatus
|
status: TaskStatus
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
progress: int = 0 # 总进度百分比 0-100
|
progress: int = 0 # Total progress percentage 0-100
|
||||||
message: str = "" # 状态消息
|
message: str = "" # Status message
|
||||||
result: Optional[Dict] = None # 任务结果
|
result: Optional[Dict] = None # Task result
|
||||||
error: Optional[str] = None # 错误信息
|
error: Optional[str] = None # Error info
|
||||||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
metadata: Dict = field(default_factory=dict) # Extra metadata
|
||||||
progress_detail: Dict = field(default_factory=dict) # 详细进度信息
|
progress_detail: Dict = field(default_factory=dict) # Detailed progress info
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""转换为字典"""
|
"""Convert to dictionary"""
|
||||||
return {
|
return {
|
||||||
"task_id": self.task_id,
|
"task_id": self.task_id,
|
||||||
"task_type": self.task_type,
|
"task_type": self.task_type,
|
||||||
|
|
@ -55,15 +55,15 @@ class Task:
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
"""
|
"""
|
||||||
任务管理器
|
Task manager
|
||||||
线程安全的任务状态管理
|
Thread-safe task state management
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
"""单例模式"""
|
"""Singleton pattern"""
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
|
|
@ -74,14 +74,14 @@ class TaskManager:
|
||||||
|
|
||||||
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
||||||
"""
|
"""
|
||||||
创建新任务
|
Create a new task.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_type: 任务类型
|
task_type: task type
|
||||||
metadata: 额外元数据
|
metadata: extra metadata
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
任务ID
|
task ID
|
||||||
"""
|
"""
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
@ -101,7 +101,7 @@ class TaskManager:
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
def get_task(self, task_id: str) -> Optional[Task]:
|
def get_task(self, task_id: str) -> Optional[Task]:
|
||||||
"""获取任务"""
|
"""Get a task"""
|
||||||
with self._task_lock:
|
with self._task_lock:
|
||||||
return self._tasks.get(task_id)
|
return self._tasks.get(task_id)
|
||||||
|
|
||||||
|
|
@ -116,16 +116,16 @@ class TaskManager:
|
||||||
progress_detail: Optional[Dict] = None
|
progress_detail: Optional[Dict] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
更新任务状态
|
Update task status.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
task_id: 任务ID
|
task_id: task ID
|
||||||
status: 新状态
|
status: new status
|
||||||
progress: 进度
|
progress: progress
|
||||||
message: 消息
|
message: message
|
||||||
result: 结果
|
result: result
|
||||||
error: 错误信息
|
error: error info
|
||||||
progress_detail: 详细进度信息
|
progress_detail: detailed progress info
|
||||||
"""
|
"""
|
||||||
with self._task_lock:
|
with self._task_lock:
|
||||||
task = self._tasks.get(task_id)
|
task = self._tasks.get(task_id)
|
||||||
|
|
@ -145,7 +145,7 @@ class TaskManager:
|
||||||
task.progress_detail = progress_detail
|
task.progress_detail = progress_detail
|
||||||
|
|
||||||
def complete_task(self, task_id: str, result: Dict):
|
def complete_task(self, task_id: str, result: Dict):
|
||||||
"""标记任务完成"""
|
"""Mark task as complete"""
|
||||||
self.update_task(
|
self.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.COMPLETED,
|
status=TaskStatus.COMPLETED,
|
||||||
|
|
@ -155,7 +155,7 @@ class TaskManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
def fail_task(self, task_id: str, error: str):
|
def fail_task(self, task_id: str, error: str):
|
||||||
"""标记任务失败"""
|
"""Mark task as failed"""
|
||||||
self.update_task(
|
self.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.FAILED,
|
status=TaskStatus.FAILED,
|
||||||
|
|
@ -164,7 +164,7 @@ class TaskManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_tasks(self, task_type: Optional[str] = None) -> list:
|
def list_tasks(self, task_type: Optional[str] = None) -> list:
|
||||||
"""列出任务"""
|
"""List tasks"""
|
||||||
with self._task_lock:
|
with self._task_lock:
|
||||||
tasks = list(self._tasks.values())
|
tasks = list(self._tasks.values())
|
||||||
if task_type:
|
if task_type:
|
||||||
|
|
@ -172,7 +172,7 @@ class TaskManager:
|
||||||
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
|
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
|
||||||
|
|
||||||
def cleanup_old_tasks(self, max_age_hours: int = 24):
|
def cleanup_old_tasks(self, max_age_hours: int = 24):
|
||||||
"""清理旧任务"""
|
"""Clean up old tasks"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
cutoff = datetime.now() - timedelta(hours=max_age_hours)
|
cutoff = datetime.now() - timedelta(hours=max_age_hours)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
业务服务模块
|
Business services module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .ontology_generator import OntologyGenerator
|
from .ontology_generator import OntologyGenerator
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
图谱构建服务
|
Graph building service
|
||||||
接口2:使用Zep API构建Standalone Graph
|
Endpoint 2: Build a Standalone Graph using the Zep API
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -22,7 +22,7 @@ from ..utils.locale import t, get_locale, set_locale
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class GraphInfo:
|
class GraphInfo:
|
||||||
"""图谱信息"""
|
"""Graph info"""
|
||||||
graph_id: str
|
graph_id: str
|
||||||
node_count: int
|
node_count: int
|
||||||
edge_count: int
|
edge_count: int
|
||||||
|
|
@ -39,14 +39,14 @@ class GraphInfo:
|
||||||
|
|
||||||
class GraphBuilderService:
|
class GraphBuilderService:
|
||||||
"""
|
"""
|
||||||
图谱构建服务
|
Graph building service
|
||||||
负责调用Zep API构建知识图谱
|
Responsible for calling the Zep API to build the knowledge graph.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
def __init__(self, api_key: Optional[str] = None):
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
self.api_key = api_key or Config.ZEP_API_KEY
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("ZEP_API_KEY 未配置")
|
raise ValueError("ZEP_API_KEY is not configured")
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
self.client = Zep(api_key=self.api_key)
|
||||||
self.task_manager = TaskManager()
|
self.task_manager = TaskManager()
|
||||||
|
|
@ -61,20 +61,20 @@ class GraphBuilderService:
|
||||||
batch_size: int = 3
|
batch_size: int = 3
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
异步构建图谱
|
Build the graph asynchronously.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 输入文本
|
text: input text
|
||||||
ontology: 本体定义(来自接口1的输出)
|
ontology: ontology definition (output from endpoint 1)
|
||||||
graph_name: 图谱名称
|
graph_name: graph name
|
||||||
chunk_size: 文本块大小
|
chunk_size: text chunk size
|
||||||
chunk_overlap: 块重叠大小
|
chunk_overlap: chunk overlap size
|
||||||
batch_size: 每批发送的块数量
|
batch_size: number of chunks per batch
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
任务ID
|
task ID
|
||||||
"""
|
"""
|
||||||
# 创建任务
|
# Create task
|
||||||
task_id = self.task_manager.create_task(
|
task_id = self.task_manager.create_task(
|
||||||
task_type="graph_build",
|
task_type="graph_build",
|
||||||
metadata={
|
metadata={
|
||||||
|
|
@ -87,7 +87,7 @@ class GraphBuilderService:
|
||||||
# Capture locale before spawning background thread
|
# Capture locale before spawning background thread
|
||||||
current_locale = get_locale()
|
current_locale = get_locale()
|
||||||
|
|
||||||
# 在后台线程中执行构建
|
# Run build in background thread
|
||||||
thread = threading.Thread(
|
thread = threading.Thread(
|
||||||
target=self._build_graph_worker,
|
target=self._build_graph_worker,
|
||||||
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
||||||
|
|
@ -108,7 +108,7 @@ class GraphBuilderService:
|
||||||
batch_size: int,
|
batch_size: int,
|
||||||
locale: str = 'zh'
|
locale: str = 'zh'
|
||||||
):
|
):
|
||||||
"""图谱构建工作线程"""
|
"""Graph build worker thread"""
|
||||||
set_locale(locale)
|
set_locale(locale)
|
||||||
try:
|
try:
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
|
|
@ -118,7 +118,7 @@ class GraphBuilderService:
|
||||||
message=t('progress.startBuildingGraph')
|
message=t('progress.startBuildingGraph')
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 创建图谱
|
# 1. Create graph
|
||||||
graph_id = self.create_graph(graph_name)
|
graph_id = self.create_graph(graph_name)
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
|
|
@ -126,7 +126,7 @@ class GraphBuilderService:
|
||||||
message=t('progress.graphCreated', graphId=graph_id)
|
message=t('progress.graphCreated', graphId=graph_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 2. 设置本体
|
# 2. Set ontology
|
||||||
self.set_ontology(graph_id, ontology)
|
self.set_ontology(graph_id, ontology)
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
|
|
@ -134,7 +134,7 @@ class GraphBuilderService:
|
||||||
message=t('progress.ontologySet')
|
message=t('progress.ontologySet')
|
||||||
)
|
)
|
||||||
|
|
||||||
# 3. 文本分块
|
# 3. Split text into chunks
|
||||||
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
|
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
|
|
@ -143,7 +143,7 @@ class GraphBuilderService:
|
||||||
message=t('progress.textSplit', count=total_chunks)
|
message=t('progress.textSplit', count=total_chunks)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 4. 分批发送数据
|
# 4. Send data in batches
|
||||||
episode_uuids = self.add_text_batches(
|
episode_uuids = self.add_text_batches(
|
||||||
graph_id, chunks, batch_size,
|
graph_id, chunks, batch_size,
|
||||||
lambda msg, prog: self.task_manager.update_task(
|
lambda msg, prog: self.task_manager.update_task(
|
||||||
|
|
@ -153,7 +153,7 @@ class GraphBuilderService:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 5. 等待Zep处理完成
|
# 5. Wait for Zep processing to complete
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
progress=60,
|
progress=60,
|
||||||
|
|
@ -169,7 +169,7 @@ class GraphBuilderService:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# 6. 获取图谱信息
|
# 6. Fetch graph info
|
||||||
self.task_manager.update_task(
|
self.task_manager.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
progress=90,
|
progress=90,
|
||||||
|
|
@ -178,7 +178,7 @@ class GraphBuilderService:
|
||||||
|
|
||||||
graph_info = self._get_graph_info(graph_id)
|
graph_info = self._get_graph_info(graph_id)
|
||||||
|
|
||||||
# 完成
|
# Complete
|
||||||
self.task_manager.complete_task(task_id, {
|
self.task_manager.complete_task(task_id, {
|
||||||
"graph_id": graph_id,
|
"graph_id": graph_id,
|
||||||
"graph_info": graph_info.to_dict(),
|
"graph_info": graph_info.to_dict(),
|
||||||
|
|
@ -191,7 +191,7 @@ class GraphBuilderService:
|
||||||
self.task_manager.fail_task(task_id, error_msg)
|
self.task_manager.fail_task(task_id, error_msg)
|
||||||
|
|
||||||
def create_graph(self, name: str) -> str:
|
def create_graph(self, name: str) -> str:
|
||||||
"""创建Zep图谱(公开方法)"""
|
"""Create a Zep graph (public method)"""
|
||||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||||
|
|
||||||
self.client.graph.create(
|
self.client.graph.create(
|
||||||
|
|
@ -203,74 +203,74 @@ class GraphBuilderService:
|
||||||
return graph_id
|
return graph_id
|
||||||
|
|
||||||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||||||
"""设置图谱本体(公开方法)"""
|
"""Set graph ontology (public method)"""
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
||||||
|
|
||||||
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
|
# Suppress Pydantic v2 warnings about Field(default=None)
|
||||||
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
|
# This is the usage required by the Zep SDK; warnings come from dynamic class creation and can be safely ignored
|
||||||
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
||||||
|
|
||||||
# Zep 保留名称,不能作为属性名
|
# Zep reserved names that cannot be used as attribute names
|
||||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||||
|
|
||||||
def safe_attr_name(attr_name: str) -> str:
|
def safe_attr_name(attr_name: str) -> str:
|
||||||
"""将保留名称转换为安全名称"""
|
"""Convert reserved names to safe attribute names"""
|
||||||
if attr_name.lower() in RESERVED_NAMES:
|
if attr_name.lower() in RESERVED_NAMES:
|
||||||
return f"entity_{attr_name}"
|
return f"entity_{attr_name}"
|
||||||
return attr_name
|
return attr_name
|
||||||
|
|
||||||
# 动态创建实体类型
|
# Dynamically create entity types
|
||||||
entity_types = {}
|
entity_types = {}
|
||||||
for entity_def in ontology.get("entity_types", []):
|
for entity_def in ontology.get("entity_types", []):
|
||||||
name = entity_def["name"]
|
name = entity_def["name"]
|
||||||
description = entity_def.get("description", f"A {name} entity.")
|
description = entity_def.get("description", f"A {name} entity.")
|
||||||
|
|
||||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
# Build attribute dict and type annotations (required by Pydantic v2)
|
||||||
attrs = {"__doc__": description}
|
attrs = {"__doc__": description}
|
||||||
annotations = {}
|
annotations = {}
|
||||||
|
|
||||||
for attr_def in entity_def.get("attributes", []):
|
for attr_def in entity_def.get("attributes", []):
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
attr_name = safe_attr_name(attr_def["name"]) # Use safe name
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
attr_desc = attr_def.get("description", attr_name)
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
# Zep API requires Field description — this is mandatory
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
annotations[attr_name] = Optional[EntityText] # Type annotation
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
attrs["__annotations__"] = annotations
|
||||||
|
|
||||||
# 动态创建类
|
# Dynamically create class
|
||||||
entity_class = type(name, (EntityModel,), attrs)
|
entity_class = type(name, (EntityModel,), attrs)
|
||||||
entity_class.__doc__ = description
|
entity_class.__doc__ = description
|
||||||
entity_types[name] = entity_class
|
entity_types[name] = entity_class
|
||||||
|
|
||||||
# 动态创建边类型
|
# Dynamically create edge types
|
||||||
edge_definitions = {}
|
edge_definitions = {}
|
||||||
for edge_def in ontology.get("edge_types", []):
|
for edge_def in ontology.get("edge_types", []):
|
||||||
name = edge_def["name"]
|
name = edge_def["name"]
|
||||||
description = edge_def.get("description", f"A {name} relationship.")
|
description = edge_def.get("description", f"A {name} relationship.")
|
||||||
|
|
||||||
# 创建属性字典和类型注解
|
# Build attribute dict and type annotations
|
||||||
attrs = {"__doc__": description}
|
attrs = {"__doc__": description}
|
||||||
annotations = {}
|
annotations = {}
|
||||||
|
|
||||||
for attr_def in edge_def.get("attributes", []):
|
for attr_def in edge_def.get("attributes", []):
|
||||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
attr_name = safe_attr_name(attr_def["name"]) # Use safe name
|
||||||
attr_desc = attr_def.get("description", attr_name)
|
attr_desc = attr_def.get("description", attr_name)
|
||||||
# Zep API 需要 Field 的 description,这是必需的
|
# Zep API requires Field description — this is mandatory
|
||||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
annotations[attr_name] = Optional[str] # Edge attributes use str type
|
||||||
|
|
||||||
attrs["__annotations__"] = annotations
|
attrs["__annotations__"] = annotations
|
||||||
|
|
||||||
# 动态创建类
|
# Dynamically create class
|
||||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||||
edge_class.__doc__ = description
|
edge_class.__doc__ = description
|
||||||
|
|
||||||
# 构建source_targets
|
# Build source_targets
|
||||||
source_targets = []
|
source_targets = []
|
||||||
for st in edge_def.get("source_targets", []):
|
for st in edge_def.get("source_targets", []):
|
||||||
source_targets.append(
|
source_targets.append(
|
||||||
|
|
@ -283,7 +283,7 @@ class GraphBuilderService:
|
||||||
if source_targets:
|
if source_targets:
|
||||||
edge_definitions[name] = (edge_class, source_targets)
|
edge_definitions[name] = (edge_class, source_targets)
|
||||||
|
|
||||||
# 调用Zep API设置本体
|
# Call Zep API to set ontology
|
||||||
if entity_types or edge_definitions:
|
if entity_types or edge_definitions:
|
||||||
self.client.graph.set_ontology(
|
self.client.graph.set_ontology(
|
||||||
graph_ids=[graph_id],
|
graph_ids=[graph_id],
|
||||||
|
|
@ -298,7 +298,7 @@ class GraphBuilderService:
|
||||||
batch_size: int = 3,
|
batch_size: int = 3,
|
||||||
progress_callback: Optional[Callable] = None
|
progress_callback: Optional[Callable] = None
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
"""Add text to the graph in batches; returns a list of all episode UUIDs"""
|
||||||
episode_uuids = []
|
episode_uuids = []
|
||||||
total_chunks = len(chunks)
|
total_chunks = len(chunks)
|
||||||
|
|
||||||
|
|
@ -314,27 +314,27 @@ class GraphBuilderService:
|
||||||
progress
|
progress
|
||||||
)
|
)
|
||||||
|
|
||||||
# 构建episode数据
|
# Build episode data
|
||||||
episodes = [
|
episodes = [
|
||||||
EpisodeData(data=chunk, type="text")
|
EpisodeData(data=chunk, type="text")
|
||||||
for chunk in batch_chunks
|
for chunk in batch_chunks
|
||||||
]
|
]
|
||||||
|
|
||||||
# 发送到Zep
|
# Send to Zep
|
||||||
try:
|
try:
|
||||||
batch_result = self.client.graph.add_batch(
|
batch_result = self.client.graph.add_batch(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
episodes=episodes
|
episodes=episodes
|
||||||
)
|
)
|
||||||
|
|
||||||
# 收集返回的 episode uuid
|
# Collect returned episode UUIDs
|
||||||
if batch_result and isinstance(batch_result, list):
|
if batch_result and isinstance(batch_result, list):
|
||||||
for ep in batch_result:
|
for ep in batch_result:
|
||||||
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
||||||
if ep_uuid:
|
if ep_uuid:
|
||||||
episode_uuids.append(ep_uuid)
|
episode_uuids.append(ep_uuid)
|
||||||
|
|
||||||
# 避免请求过快
|
# Avoid sending requests too quickly
|
||||||
time.sleep(1)
|
time.sleep(1)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -350,7 +350,7 @@ class GraphBuilderService:
|
||||||
progress_callback: Optional[Callable] = None,
|
progress_callback: Optional[Callable] = None,
|
||||||
timeout: int = 600
|
timeout: int = 600
|
||||||
):
|
):
|
||||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
"""Wait for all episodes to finish processing (by polling each episode's processed status)"""
|
||||||
if not episode_uuids:
|
if not episode_uuids:
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
||||||
|
|
@ -373,7 +373,7 @@ class GraphBuilderService:
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
|
|
||||||
# 检查每个 episode 的处理状态
|
# Check processing status of each episode
|
||||||
for ep_uuid in list(pending_episodes):
|
for ep_uuid in list(pending_episodes):
|
||||||
try:
|
try:
|
||||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||||||
|
|
@ -384,7 +384,7 @@ class GraphBuilderService:
|
||||||
completed_count += 1
|
completed_count += 1
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 忽略单个查询错误,继续
|
# Ignore individual query errors and continue
|
||||||
pass
|
pass
|
||||||
|
|
||||||
elapsed = int(time.time() - start_time)
|
elapsed = int(time.time() - start_time)
|
||||||
|
|
@ -395,20 +395,20 @@ class GraphBuilderService:
|
||||||
)
|
)
|
||||||
|
|
||||||
if pending_episodes:
|
if pending_episodes:
|
||||||
time.sleep(3) # 每3秒检查一次
|
time.sleep(3) # Check every 3 seconds
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
||||||
|
|
||||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||||
"""获取图谱信息"""
|
"""Retrieve graph info"""
|
||||||
# 获取节点(分页)
|
# Fetch nodes (paginated)
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
nodes = fetch_all_nodes(self.client, graph_id)
|
||||||
|
|
||||||
# 获取边(分页)
|
# Fetch edges (paginated)
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
edges = fetch_all_edges(self.client, graph_id)
|
||||||
|
|
||||||
# 统计实体类型
|
# Count entity types
|
||||||
entity_types = set()
|
entity_types = set()
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
if node.labels:
|
if node.labels:
|
||||||
|
|
@ -425,25 +425,25 @@ class GraphBuilderService:
|
||||||
|
|
||||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
获取完整图谱数据(包含详细信息)
|
Retrieve full graph data (with detailed information).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
Dictionary containing nodes and edges with timestamps, attributes, and other details
|
||||||
"""
|
"""
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
nodes = fetch_all_nodes(self.client, graph_id)
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
edges = fetch_all_edges(self.client, graph_id)
|
||||||
|
|
||||||
# 创建节点映射用于获取节点名称
|
# Build node map for looking up node names
|
||||||
node_map = {}
|
node_map = {}
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
node_map[node.uuid_] = node.name or ""
|
node_map[node.uuid_] = node.name or ""
|
||||||
|
|
||||||
nodes_data = []
|
nodes_data = []
|
||||||
for node in nodes:
|
for node in nodes:
|
||||||
# 获取创建时间
|
# Get creation timestamp
|
||||||
created_at = getattr(node, 'created_at', None)
|
created_at = getattr(node, 'created_at', None)
|
||||||
if created_at:
|
if created_at:
|
||||||
created_at = str(created_at)
|
created_at = str(created_at)
|
||||||
|
|
@ -459,20 +459,20 @@ class GraphBuilderService:
|
||||||
|
|
||||||
edges_data = []
|
edges_data = []
|
||||||
for edge in edges:
|
for edge in edges:
|
||||||
# 获取时间信息
|
# Get timestamps
|
||||||
created_at = getattr(edge, 'created_at', None)
|
created_at = getattr(edge, 'created_at', None)
|
||||||
valid_at = getattr(edge, 'valid_at', None)
|
valid_at = getattr(edge, 'valid_at', None)
|
||||||
invalid_at = getattr(edge, 'invalid_at', None)
|
invalid_at = getattr(edge, 'invalid_at', None)
|
||||||
expired_at = getattr(edge, 'expired_at', None)
|
expired_at = getattr(edge, 'expired_at', None)
|
||||||
|
|
||||||
# 获取 episodes
|
# Get episodes
|
||||||
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
||||||
if episodes and not isinstance(episodes, list):
|
if episodes and not isinstance(episodes, list):
|
||||||
episodes = [str(episodes)]
|
episodes = [str(episodes)]
|
||||||
elif episodes:
|
elif episodes:
|
||||||
episodes = [str(e) for e in episodes]
|
episodes = [str(e) for e in episodes]
|
||||||
|
|
||||||
# 获取 fact_type
|
# Get fact_type
|
||||||
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
||||||
|
|
||||||
edges_data.append({
|
edges_data.append({
|
||||||
|
|
@ -501,6 +501,6 @@ class GraphBuilderService:
|
||||||
}
|
}
|
||||||
|
|
||||||
def delete_graph(self, graph_id: str):
|
def delete_graph(self, graph_id: str):
|
||||||
"""删除图谱"""
|
"""Delete graph"""
|
||||||
self.client.graph.delete(graph_id=graph_id)
|
self.client.graph.delete(graph_id=graph_id)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
本体生成服务
|
Ontology generation service
|
||||||
接口1:分析文本内容,生成适合社会模拟的实体和关系类型定义
|
Endpoint 1: Analyze text content and generate entity and relationship type definitions suitable for social simulation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
|
@ -14,169 +14,169 @@ logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _to_pascal_case(name: str) -> str:
|
def _to_pascal_case(name: str) -> str:
|
||||||
"""将任意格式的名称转换为 PascalCase(如 'works_for' -> 'WorksFor', 'person' -> 'Person')"""
|
"""Convert a name in any format to PascalCase (e.g. 'works_for' -> 'WorksFor', 'person' -> 'Person')"""
|
||||||
# 按非字母数字字符分割
|
# Split on non-alphanumeric characters
|
||||||
parts = re.split(r'[^a-zA-Z0-9]+', name)
|
parts = re.split(r'[^a-zA-Z0-9]+', name)
|
||||||
# 再按 camelCase 边界分割(如 'camelCase' -> ['camel', 'Case'])
|
# Also split on camelCase boundaries (e.g. 'camelCase' -> ['camel', 'Case'])
|
||||||
words = []
|
words = []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
|
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
|
||||||
# 每个词首字母大写,过滤空串
|
# Capitalize each word and filter empty strings
|
||||||
result = ''.join(word.capitalize() for word in words if word)
|
result = ''.join(word.capitalize() for word in words if word)
|
||||||
return result if result else 'Unknown'
|
return result if result else 'Unknown'
|
||||||
|
|
||||||
|
|
||||||
# 本体生成的系统提示词
|
# System prompt for ontology generation
|
||||||
ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。
|
ONTOLOGY_SYSTEM_PROMPT = """You are a professional knowledge graph ontology design expert. Your task is to analyze the given text content and simulation requirements, and design entity types and relationship types suitable for **social media opinion simulation**.
|
||||||
|
|
||||||
**重要:你必须输出有效的JSON格式数据,不要输出任何其他内容。**
|
**Important: You must output valid JSON format data, and nothing else.**
|
||||||
|
|
||||||
## 核心任务背景
|
## Core Task Background
|
||||||
|
|
||||||
我们正在构建一个**社交媒体舆论模拟系统**。在这个系统中:
|
We are building a **social media opinion simulation system**. In this system:
|
||||||
- 每个实体都是一个可以在社交媒体上发声、互动、传播信息的"账号"或"主体"
|
- Every entity is an "account" or "subject" that can speak out, interact, and spread information on social media
|
||||||
- 实体之间会相互影响、转发、评论、回应
|
- Entities influence each other, repost, comment, and respond
|
||||||
- 我们需要模拟舆论事件中各方的反应和信息传播路径
|
- We need to simulate each party's reaction and the information propagation path during opinion events
|
||||||
|
|
||||||
因此,**实体必须是现实中真实存在的、可以在社媒上发声和互动的主体**:
|
Therefore, **entities must be real-world subjects that exist and can speak out and interact on social media**:
|
||||||
|
|
||||||
**可以是**:
|
**Can be**:
|
||||||
- 具体的个人(公众人物、当事人、意见领袖、专家学者、普通人)
|
- Specific individuals (public figures, persons involved, opinion leaders, experts and scholars, ordinary people)
|
||||||
- 公司、企业(包括其官方账号)
|
- Companies and enterprises (including their official accounts)
|
||||||
- 组织机构(大学、协会、NGO、工会等)
|
- Organizations (universities, associations, NGOs, unions, etc.)
|
||||||
- 政府部门、监管机构
|
- Government departments and regulatory agencies
|
||||||
- 媒体机构(报纸、电视台、自媒体、网站)
|
- Media organizations (newspapers, TV stations, self-media, websites)
|
||||||
- 社交媒体平台本身
|
- Social media platforms themselves
|
||||||
- 特定群体代表(如校友会、粉丝团、维权群体等)
|
- Representatives of specific groups (e.g. alumni associations, fan clubs, rights-protection groups, etc.)
|
||||||
|
|
||||||
**不可以是**:
|
**Cannot be**:
|
||||||
- 抽象概念(如"舆论"、"情绪"、"趋势")
|
- Abstract concepts (e.g. "public opinion", "emotion", "trend")
|
||||||
- 主题/话题(如"学术诚信"、"教育改革")
|
- Topics/themes (e.g. "academic integrity", "education reform")
|
||||||
- 观点/态度(如"支持方"、"反对方")
|
- Viewpoints/attitudes (e.g. "supporters", "opponents")
|
||||||
|
|
||||||
## 输出格式
|
## Output Format
|
||||||
|
|
||||||
请输出JSON格式,包含以下结构:
|
Please output JSON format with the following structure:
|
||||||
|
|
||||||
```json
|
```json
|
||||||
{
|
{
|
||||||
"entity_types": [
|
"entity_types": [
|
||||||
{
|
{
|
||||||
"name": "实体类型名称(英文,PascalCase)",
|
"name": "Entity type name (English, PascalCase)",
|
||||||
"description": "简短描述(英文,不超过100字符)",
|
"description": "Brief description (English, max 100 characters)",
|
||||||
"attributes": [
|
"attributes": [
|
||||||
{
|
{
|
||||||
"name": "属性名(英文,snake_case)",
|
"name": "Attribute name (English, snake_case)",
|
||||||
"type": "text",
|
"type": "text",
|
||||||
"description": "属性描述"
|
"description": "Attribute description"
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"examples": ["示例实体1", "示例实体2"]
|
"examples": ["Example entity 1", "Example entity 2"]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"edge_types": [
|
"edge_types": [
|
||||||
{
|
{
|
||||||
"name": "关系类型名称(英文,UPPER_SNAKE_CASE)",
|
"name": "Relationship type name (English, UPPER_SNAKE_CASE)",
|
||||||
"description": "简短描述(英文,不超过100字符)",
|
"description": "Brief description (English, max 100 characters)",
|
||||||
"source_targets": [
|
"source_targets": [
|
||||||
{"source": "源实体类型", "target": "目标实体类型"}
|
{"source": "Source entity type", "target": "Target entity type"}
|
||||||
],
|
],
|
||||||
"attributes": []
|
"attributes": []
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"analysis_summary": "对文本内容的简要分析说明"
|
"analysis_summary": "Brief analysis summary of the text content"
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
## 设计指南(极其重要!)
|
## Design Guidelines (Extremely Important!)
|
||||||
|
|
||||||
### 1. 实体类型设计 - 必须严格遵守
|
### 1. Entity Type Design — Must Be Strictly Followed
|
||||||
|
|
||||||
**数量要求:必须正好10个实体类型**
|
**Quantity requirement: exactly 10 entity types**
|
||||||
|
|
||||||
**层次结构要求(必须同时包含具体类型和兜底类型)**:
|
**Hierarchy requirement (must include both specific types and fallback types)**:
|
||||||
|
|
||||||
你的10个实体类型必须包含以下层次:
|
Your 10 entity types must include the following levels:
|
||||||
|
|
||||||
A. **兜底类型(必须包含,放在列表最后2个)**:
|
A. **Fallback types (required, placed as the last 2 in the list)**:
|
||||||
- `Person`: 任何自然人个体的兜底类型。当一个人不属于其他更具体的人物类型时,归入此类。
|
- `Person`: Fallback type for any individual person. Use this when a person does not fit any other more specific person type.
|
||||||
- `Organization`: 任何组织机构的兜底类型。当一个组织不属于其他更具体的组织类型时,归入此类。
|
- `Organization`: Fallback type for any organization. Use this when an organization does not fit any other more specific organization type.
|
||||||
|
|
||||||
B. **具体类型(8个,根据文本内容设计)**:
|
B. **Specific types (8 types, designed based on text content)**:
|
||||||
- 针对文本中出现的主要角色,设计更具体的类型
|
- Design more specific types for the main roles that appear in the text
|
||||||
- 例如:如果文本涉及学术事件,可以有 `Student`, `Professor`, `University`
|
- Example: if the text involves an academic event, you might have `Student`, `Professor`, `University`
|
||||||
- 例如:如果文本涉及商业事件,可以有 `Company`, `CEO`, `Employee`
|
- Example: if the text involves a business event, you might have `Company`, `CEO`, `Employee`
|
||||||
|
|
||||||
**为什么需要兜底类型**:
|
**Why fallback types are needed**:
|
||||||
- 文本中会出现各种人物,如"中小学教师"、"路人甲"、"某位网友"
|
- Various people appear in text, such as "primary and secondary school teachers", "passersby", "some netizen"
|
||||||
- 如果没有专门的类型匹配,他们应该被归入 `Person`
|
- If there is no dedicated type to match them, they should fall into `Person`
|
||||||
- 同理,小型组织、临时团体等应该归入 `Organization`
|
- Similarly, small organizations, ad hoc groups, etc. should fall into `Organization`
|
||||||
|
|
||||||
**具体类型的设计原则**:
|
**Principles for designing specific types**:
|
||||||
- 从文本中识别出高频出现或关键的角色类型
|
- Identify high-frequency or key role types from the text
|
||||||
- 每个具体类型应该有明确的边界,避免重叠
|
- Each specific type should have clear boundaries and avoid overlap
|
||||||
- description 必须清晰说明这个类型和兜底类型的区别
|
- The description must clearly explain the difference between this type and the fallback types
|
||||||
|
|
||||||
### 2. 关系类型设计
|
### 2. Relationship Type Design
|
||||||
|
|
||||||
- 数量:6-10个
|
- Quantity: 6-10
|
||||||
- 关系应该反映社媒互动中的真实联系
|
- Relationships should reflect real connections in social media interactions
|
||||||
- 确保关系的 source_targets 涵盖你定义的实体类型
|
- Ensure the source_targets in relationships cover the entity types you have defined
|
||||||
|
|
||||||
### 3. 属性设计
|
### 3. Attribute Design
|
||||||
|
|
||||||
- 每个实体类型1-3个关键属性
|
- 1-3 key attributes per entity type
|
||||||
- **注意**:属性名不能使用 `name`、`uuid`、`group_id`、`created_at`、`summary`(这些是系统保留字)
|
- **Note**: Attribute names must not use `name`, `uuid`, `group_id`, `created_at`, `summary` (these are system reserved words)
|
||||||
- 推荐使用:`full_name`, `title`, `role`, `position`, `location`, `description` 等
|
- Recommended: `full_name`, `title`, `role`, `position`, `location`, `description`, etc.
|
||||||
|
|
||||||
## 实体类型参考
|
## Entity Type Reference
|
||||||
|
|
||||||
**个人类(具体)**:
|
**Individual types (specific)**:
|
||||||
- Student: 学生
|
- Student: student
|
||||||
- Professor: 教授/学者
|
- Professor: professor/scholar
|
||||||
- Journalist: 记者
|
- Journalist: journalist
|
||||||
- Celebrity: 明星/网红
|
- Celebrity: celebrity/influencer
|
||||||
- Executive: 高管
|
- Executive: corporate executive
|
||||||
- Official: 政府官员
|
- Official: government official
|
||||||
- Lawyer: 律师
|
- Lawyer: lawyer
|
||||||
- Doctor: 医生
|
- Doctor: doctor
|
||||||
|
|
||||||
**个人类(兜底)**:
|
**Individual types (fallback)**:
|
||||||
- Person: 任何自然人(不属于上述具体类型时使用)
|
- Person: any individual (use when not fitting the specific types above)
|
||||||
|
|
||||||
**组织类(具体)**:
|
**Organization types (specific)**:
|
||||||
- University: 高校
|
- University: university/college
|
||||||
- Company: 公司企业
|
- Company: company/enterprise
|
||||||
- GovernmentAgency: 政府机构
|
- GovernmentAgency: government agency
|
||||||
- MediaOutlet: 媒体机构
|
- MediaOutlet: media organization
|
||||||
- Hospital: 医院
|
- Hospital: hospital
|
||||||
- School: 中小学
|
- School: primary/secondary school
|
||||||
- NGO: 非政府组织
|
- NGO: non-governmental organization
|
||||||
|
|
||||||
**组织类(兜底)**:
|
**Organization types (fallback)**:
|
||||||
- Organization: 任何组织机构(不属于上述具体类型时使用)
|
- Organization: any organization (use when not fitting the specific types above)
|
||||||
|
|
||||||
## 关系类型参考
|
## Relationship Type Reference
|
||||||
|
|
||||||
- WORKS_FOR: 工作于
|
- WORKS_FOR: works for
|
||||||
- STUDIES_AT: 就读于
|
- STUDIES_AT: studies at
|
||||||
- AFFILIATED_WITH: 隶属于
|
- AFFILIATED_WITH: affiliated with
|
||||||
- REPRESENTS: 代表
|
- REPRESENTS: represents
|
||||||
- REGULATES: 监管
|
- REGULATES: regulates
|
||||||
- REPORTS_ON: 报道
|
- REPORTS_ON: reports on
|
||||||
- COMMENTS_ON: 评论
|
- COMMENTS_ON: comments on
|
||||||
- RESPONDS_TO: 回应
|
- RESPONDS_TO: responds to
|
||||||
- SUPPORTS: 支持
|
- SUPPORTS: supports
|
||||||
- OPPOSES: 反对
|
- OPPOSES: opposes
|
||||||
- COLLABORATES_WITH: 合作
|
- COLLABORATES_WITH: collaborates with
|
||||||
- COMPETES_WITH: 竞争
|
- COMPETES_WITH: competes with
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
class OntologyGenerator:
|
class OntologyGenerator:
|
||||||
"""
|
"""
|
||||||
本体生成器
|
Ontology generator
|
||||||
分析文本内容,生成实体和关系类型定义
|
Analyzes text content and generates entity and relationship type definitions.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||||
|
|
@ -189,95 +189,100 @@ class OntologyGenerator:
|
||||||
additional_context: Optional[str] = None
|
additional_context: Optional[str] = None
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
生成本体定义
|
Generate ontology definition.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
document_texts: 文档文本列表
|
document_texts: list of document texts
|
||||||
simulation_requirement: 模拟需求描述
|
simulation_requirement: simulation requirement description
|
||||||
additional_context: 额外上下文
|
additional_context: additional context
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
本体定义(entity_types, edge_types等)
|
Ontology definition (entity_types, edge_types, etc.)
|
||||||
"""
|
"""
|
||||||
# 构建用户消息
|
lang_instruction = get_language_instruction()
|
||||||
|
|
||||||
|
# Build user message
|
||||||
user_message = self._build_user_message(
|
user_message = self._build_user_message(
|
||||||
document_texts,
|
document_texts,
|
||||||
simulation_requirement,
|
simulation_requirement,
|
||||||
additional_context
|
additional_context,
|
||||||
|
lang_instruction
|
||||||
)
|
)
|
||||||
|
|
||||||
lang_instruction = get_language_instruction()
|
system_prompt = f"LANGUAGE INSTRUCTION (HIGHEST PRIORITY — MUST BE FOLLOWED): {lang_instruction} All description fields, analysis_summary, and examples MUST be written in this language.\n\n{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
|
||||||
system_prompt = f"{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
|
|
||||||
messages = [
|
messages = [
|
||||||
{"role": "system", "content": system_prompt},
|
{"role": "system", "content": system_prompt},
|
||||||
{"role": "user", "content": user_message}
|
{"role": "user", "content": user_message}
|
||||||
]
|
]
|
||||||
|
|
||||||
# 调用LLM
|
# Call LLM
|
||||||
result = self.llm_client.chat_json(
|
result = self.llm_client.chat_json(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
temperature=0.3,
|
temperature=0.3,
|
||||||
max_tokens=4096
|
max_tokens=4096
|
||||||
)
|
)
|
||||||
|
|
||||||
# 验证和后处理
|
# Validate and post-process
|
||||||
result = self._validate_and_process(result)
|
result = self._validate_and_process(result)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
# 传给 LLM 的文本最大长度(5万字)
|
# Maximum text length passed to LLM (50,000 characters)
|
||||||
MAX_TEXT_LENGTH_FOR_LLM = 50000
|
MAX_TEXT_LENGTH_FOR_LLM = 50000
|
||||||
|
|
||||||
def _build_user_message(
|
def _build_user_message(
|
||||||
self,
|
self,
|
||||||
document_texts: List[str],
|
document_texts: List[str],
|
||||||
simulation_requirement: str,
|
simulation_requirement: str,
|
||||||
additional_context: Optional[str]
|
additional_context: Optional[str],
|
||||||
|
lang_instruction: str = ""
|
||||||
) -> str:
|
) -> str:
|
||||||
"""构建用户消息"""
|
"""Build user message"""
|
||||||
|
|
||||||
# 合并文本
|
# Merge texts
|
||||||
combined_text = "\n\n---\n\n".join(document_texts)
|
combined_text = "\n\n---\n\n".join(document_texts)
|
||||||
original_length = len(combined_text)
|
original_length = len(combined_text)
|
||||||
|
|
||||||
# 如果文本超过5万字,截断(仅影响传给LLM的内容,不影响图谱构建)
|
# If text exceeds 50,000 characters, truncate (only affects what is passed to LLM, not graph building)
|
||||||
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
|
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
|
||||||
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
|
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
|
||||||
combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..."
|
combined_text += f"\n\n...(text truncated at {self.MAX_TEXT_LENGTH_FOR_LLM} chars out of {original_length} total)..."
|
||||||
|
|
||||||
message = f"""## 模拟需求
|
message = f"""## Simulation requirement
|
||||||
|
|
||||||
{simulation_requirement}
|
{simulation_requirement}
|
||||||
|
|
||||||
## 文档内容
|
## Document content
|
||||||
|
|
||||||
{combined_text}
|
{combined_text}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if additional_context:
|
if additional_context:
|
||||||
message += f"""
|
message += f"""
|
||||||
## 额外说明
|
## Additional context
|
||||||
|
|
||||||
{additional_context}
|
{additional_context}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
message += """
|
message += f"""
|
||||||
请根据以上内容,设计适合社会舆论模拟的实体类型和关系类型。
|
Based on the content above, design entity types and relationship types suitable for social opinion simulation.
|
||||||
|
|
||||||
**必须遵守的规则**:
|
**Mandatory rules**:
|
||||||
1. 必须正好输出10个实体类型
|
1. Output exactly 10 entity types
|
||||||
2. 最后2个必须是兜底类型:Person(个人兜底)和 Organization(组织兜底)
|
2. The last 2 must be fallback types: Person (individual fallback) and Organization (organization fallback)
|
||||||
3. 前8个是根据文本内容设计的具体类型
|
3. The first 8 are specific types designed from the document content
|
||||||
4. 所有实体类型必须是现实中可以发声的主体,不能是抽象概念
|
4. All entity types must be real-world subjects capable of speaking out, not abstract concepts
|
||||||
5. 属性名不能使用 name、uuid、group_id 等保留字,用 full_name、org_name 等替代
|
5. Attribute names must not use reserved words: name, uuid, group_id — use full_name, org_name, etc. instead
|
||||||
|
|
||||||
|
{lang_instruction}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
return message
|
return message
|
||||||
|
|
||||||
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
"""验证和后处理结果"""
|
"""Validate and post-process the result"""
|
||||||
|
|
||||||
# 确保必要字段存在
|
# Ensure required fields exist
|
||||||
if "entity_types" not in result:
|
if "entity_types" not in result:
|
||||||
result["entity_types"] = []
|
result["entity_types"] = []
|
||||||
if "edge_types" not in result:
|
if "edge_types" not in result:
|
||||||
|
|
@ -285,11 +290,11 @@ class OntologyGenerator:
|
||||||
if "analysis_summary" not in result:
|
if "analysis_summary" not in result:
|
||||||
result["analysis_summary"] = ""
|
result["analysis_summary"] = ""
|
||||||
|
|
||||||
# 验证实体类型
|
# Validate entity types
|
||||||
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用
|
# Record mapping from original name to PascalCase for fixing edge source_targets references later
|
||||||
entity_name_map = {}
|
entity_name_map = {}
|
||||||
for entity in result["entity_types"]:
|
for entity in result["entity_types"]:
|
||||||
# 强制将 entity name 转为 PascalCase(Zep API 要求)
|
# Force entity name to PascalCase (required by Zep API)
|
||||||
if "name" in entity:
|
if "name" in entity:
|
||||||
original_name = entity["name"]
|
original_name = entity["name"]
|
||||||
entity["name"] = _to_pascal_case(original_name)
|
entity["name"] = _to_pascal_case(original_name)
|
||||||
|
|
@ -300,19 +305,19 @@ class OntologyGenerator:
|
||||||
entity["attributes"] = []
|
entity["attributes"] = []
|
||||||
if "examples" not in entity:
|
if "examples" not in entity:
|
||||||
entity["examples"] = []
|
entity["examples"] = []
|
||||||
# 确保description不超过100字符
|
# Ensure description does not exceed 100 characters
|
||||||
if len(entity.get("description", "")) > 100:
|
if len(entity.get("description", "")) > 100:
|
||||||
entity["description"] = entity["description"][:97] + "..."
|
entity["description"] = entity["description"][:97] + "..."
|
||||||
|
|
||||||
# 验证关系类型
|
# Validate relationship types
|
||||||
for edge in result["edge_types"]:
|
for edge in result["edge_types"]:
|
||||||
# 强制将 edge name 转为 SCREAMING_SNAKE_CASE(Zep API 要求)
|
# Force edge name to SCREAMING_SNAKE_CASE (required by Zep API)
|
||||||
if "name" in edge:
|
if "name" in edge:
|
||||||
original_name = edge["name"]
|
original_name = edge["name"]
|
||||||
edge["name"] = original_name.upper()
|
edge["name"] = original_name.upper()
|
||||||
if edge["name"] != original_name:
|
if edge["name"] != original_name:
|
||||||
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
|
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
|
||||||
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致
|
# Fix entity name references in source_targets to match converted PascalCase names
|
||||||
for st in edge.get("source_targets", []):
|
for st in edge.get("source_targets", []):
|
||||||
if st.get("source") in entity_name_map:
|
if st.get("source") in entity_name_map:
|
||||||
st["source"] = entity_name_map[st["source"]]
|
st["source"] = entity_name_map[st["source"]]
|
||||||
|
|
@ -325,11 +330,11 @@ class OntologyGenerator:
|
||||||
if len(edge.get("description", "")) > 100:
|
if len(edge.get("description", "")) > 100:
|
||||||
edge["description"] = edge["description"][:97] + "..."
|
edge["description"] = edge["description"][:97] + "..."
|
||||||
|
|
||||||
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
# Zep API limit: maximum 10 custom entity types and 10 custom edge types
|
||||||
MAX_ENTITY_TYPES = 10
|
MAX_ENTITY_TYPES = 10
|
||||||
MAX_EDGE_TYPES = 10
|
MAX_EDGE_TYPES = 10
|
||||||
|
|
||||||
# 去重:按 name 去重,保留首次出现的
|
# Deduplicate: keep first occurrence by name
|
||||||
seen_names = set()
|
seen_names = set()
|
||||||
deduped = []
|
deduped = []
|
||||||
for entity in result["entity_types"]:
|
for entity in result["entity_types"]:
|
||||||
|
|
@ -341,7 +346,7 @@ class OntologyGenerator:
|
||||||
logger.warning(f"Duplicate entity type '{name}' removed during validation")
|
logger.warning(f"Duplicate entity type '{name}' removed during validation")
|
||||||
result["entity_types"] = deduped
|
result["entity_types"] = deduped
|
||||||
|
|
||||||
# 兜底类型定义
|
# Fallback type definitions
|
||||||
person_fallback = {
|
person_fallback = {
|
||||||
"name": "Person",
|
"name": "Person",
|
||||||
"description": "Any individual person not fitting other specific person types.",
|
"description": "Any individual person not fitting other specific person types.",
|
||||||
|
|
@ -362,12 +367,12 @@ class OntologyGenerator:
|
||||||
"examples": ["small business", "community group"]
|
"examples": ["small business", "community group"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# 检查是否已有兜底类型
|
# Check whether fallback types already exist
|
||||||
entity_names = {e["name"] for e in result["entity_types"]}
|
entity_names = {e["name"] for e in result["entity_types"]}
|
||||||
has_person = "Person" in entity_names
|
has_person = "Person" in entity_names
|
||||||
has_organization = "Organization" in entity_names
|
has_organization = "Organization" in entity_names
|
||||||
|
|
||||||
# 需要添加的兜底类型
|
# Collect fallback types to add
|
||||||
fallbacks_to_add = []
|
fallbacks_to_add = []
|
||||||
if not has_person:
|
if not has_person:
|
||||||
fallbacks_to_add.append(person_fallback)
|
fallbacks_to_add.append(person_fallback)
|
||||||
|
|
@ -378,17 +383,17 @@ class OntologyGenerator:
|
||||||
current_count = len(result["entity_types"])
|
current_count = len(result["entity_types"])
|
||||||
needed_slots = len(fallbacks_to_add)
|
needed_slots = len(fallbacks_to_add)
|
||||||
|
|
||||||
# 如果添加后会超过 10 个,需要移除一些现有类型
|
# If adding them would exceed 10, remove some existing types
|
||||||
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
||||||
# 计算需要移除多少个
|
# Calculate how many to remove
|
||||||
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
||||||
# 从末尾移除(保留前面更重要的具体类型)
|
# Remove from the end (preserve the more important specific types at the front)
|
||||||
result["entity_types"] = result["entity_types"][:-to_remove]
|
result["entity_types"] = result["entity_types"][:-to_remove]
|
||||||
|
|
||||||
# 添加兜底类型
|
# Add fallback types
|
||||||
result["entity_types"].extend(fallbacks_to_add)
|
result["entity_types"].extend(fallbacks_to_add)
|
||||||
|
|
||||||
# 最终确保不超过限制(防御性编程)
|
# Final guard: ensure limits are not exceeded (defensive programming)
|
||||||
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
||||||
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
||||||
|
|
||||||
|
|
@ -399,29 +404,29 @@ class OntologyGenerator:
|
||||||
|
|
||||||
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
||||||
"""
|
"""
|
||||||
将本体定义转换为Python代码(类似ontology.py)
|
Convert the ontology definition to Python code (similar to ontology.py).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
ontology: 本体定义
|
ontology: ontology definition
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Python代码字符串
|
Python code string
|
||||||
"""
|
"""
|
||||||
code_lines = [
|
code_lines = [
|
||||||
'"""',
|
'"""',
|
||||||
'自定义实体类型定义',
|
'Custom entity type definitions',
|
||||||
'由MiroFish自动生成,用于社会舆论模拟',
|
'Auto-generated by MiroFish for social opinion simulation',
|
||||||
'"""',
|
'"""',
|
||||||
'',
|
'',
|
||||||
'from pydantic import Field',
|
'from pydantic import Field',
|
||||||
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
|
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
|
||||||
'',
|
'',
|
||||||
'',
|
'',
|
||||||
'# ============== 实体类型定义 ==============',
|
'# ============== Entity type definitions ==============',
|
||||||
'',
|
'',
|
||||||
]
|
]
|
||||||
|
|
||||||
# 生成实体类型
|
# Generate entity types
|
||||||
for entity in ontology.get("entity_types", []):
|
for entity in ontology.get("entity_types", []):
|
||||||
name = entity["name"]
|
name = entity["name"]
|
||||||
desc = entity.get("description", f"A {name} entity.")
|
desc = entity.get("description", f"A {name} entity.")
|
||||||
|
|
@ -444,13 +449,13 @@ class OntologyGenerator:
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
|
|
||||||
code_lines.append('# ============== 关系类型定义 ==============')
|
code_lines.append('# ============== Relationship type definitions ==============')
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
|
|
||||||
# 生成关系类型
|
# Generate relationship types
|
||||||
for edge in ontology.get("edge_types", []):
|
for edge in ontology.get("edge_types", []):
|
||||||
name = edge["name"]
|
name = edge["name"]
|
||||||
# 转换为PascalCase类名
|
# Convert to PascalCase class name
|
||||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||||
desc = edge.get("description", f"A {name} relationship.")
|
desc = edge.get("description", f"A {name} relationship.")
|
||||||
|
|
||||||
|
|
@ -472,8 +477,8 @@ class OntologyGenerator:
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
|
|
||||||
# 生成类型字典
|
# Generate type dictionaries
|
||||||
code_lines.append('# ============== 类型配置 ==============')
|
code_lines.append('# ============== Type configuration ==============')
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
code_lines.append('ENTITY_TYPES = {')
|
code_lines.append('ENTITY_TYPES = {')
|
||||||
for entity in ontology.get("entity_types", []):
|
for entity in ontology.get("entity_types", []):
|
||||||
|
|
@ -489,7 +494,7 @@ class OntologyGenerator:
|
||||||
code_lines.append('}')
|
code_lines.append('}')
|
||||||
code_lines.append('')
|
code_lines.append('')
|
||||||
|
|
||||||
# 生成边的source_targets映射
|
# Generate edge source_targets mapping
|
||||||
code_lines.append('EDGE_SOURCE_TARGETS = {')
|
code_lines.append('EDGE_SOURCE_TARGETS = {')
|
||||||
for edge in ontology.get("edge_types", []):
|
for edge in ontology.get("edge_types", []):
|
||||||
name = edge["name"]
|
name = edge["name"]
|
||||||
|
|
@ -503,4 +508,3 @@ class OntologyGenerator:
|
||||||
code_lines.append('}')
|
code_lines.append('}')
|
||||||
|
|
||||||
return '\n'.join(code_lines)
|
return '\n'.join(code_lines)
|
||||||
|
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,11 +1,11 @@
|
||||||
"""
|
"""
|
||||||
模拟IPC通信模块
|
Simulation IPC communication module
|
||||||
用于Flask后端和模拟脚本之间的进程间通信
|
Used for inter-process communication between the Flask backend and simulation scripts.
|
||||||
|
|
||||||
通过文件系统实现简单的命令/响应模式:
|
Implements a simple command/response pattern via the file system:
|
||||||
1. Flask写入命令到 commands/ 目录
|
1. Flask writes commands to the commands/ directory
|
||||||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
2. Simulation scripts poll the command directory, execute commands, and write responses to the responses/ directory
|
||||||
3. Flask轮询响应目录获取结果
|
3. Flask polls the response directory to get results
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -23,14 +23,14 @@ logger = get_logger('mirofish.simulation_ipc')
|
||||||
|
|
||||||
|
|
||||||
class CommandType(str, Enum):
|
class CommandType(str, Enum):
|
||||||
"""命令类型"""
|
"""Command type"""
|
||||||
INTERVIEW = "interview" # 单个Agent采访
|
INTERVIEW = "interview" # Single agent interview
|
||||||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
BATCH_INTERVIEW = "batch_interview" # Batch interview
|
||||||
CLOSE_ENV = "close_env" # 关闭环境
|
CLOSE_ENV = "close_env" # Close environment
|
||||||
|
|
||||||
|
|
||||||
class CommandStatus(str, Enum):
|
class CommandStatus(str, Enum):
|
||||||
"""命令状态"""
|
"""Command status"""
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
PROCESSING = "processing"
|
PROCESSING = "processing"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
|
@ -39,7 +39,7 @@ class CommandStatus(str, Enum):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IPCCommand:
|
class IPCCommand:
|
||||||
"""IPC命令"""
|
"""IPC command"""
|
||||||
command_id: str
|
command_id: str
|
||||||
command_type: CommandType
|
command_type: CommandType
|
||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
|
|
@ -65,7 +65,7 @@ class IPCCommand:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IPCResponse:
|
class IPCResponse:
|
||||||
"""IPC响应"""
|
"""IPC response"""
|
||||||
command_id: str
|
command_id: str
|
||||||
status: CommandStatus
|
status: CommandStatus
|
||||||
result: Optional[Dict[str, Any]] = None
|
result: Optional[Dict[str, Any]] = None
|
||||||
|
|
@ -94,23 +94,23 @@ class IPCResponse:
|
||||||
|
|
||||||
class SimulationIPCClient:
|
class SimulationIPCClient:
|
||||||
"""
|
"""
|
||||||
模拟IPC客户端(Flask端使用)
|
Simulation IPC client (used by the Flask side)
|
||||||
|
|
||||||
用于向模拟进程发送命令并等待响应
|
Used to send commands to the simulation process and wait for responses
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, simulation_dir: str):
|
def __init__(self, simulation_dir: str):
|
||||||
"""
|
"""
|
||||||
初始化IPC客户端
|
Initialize the IPC client
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_dir: 模拟数据目录
|
simulation_dir: simulation data directory
|
||||||
"""
|
"""
|
||||||
self.simulation_dir = simulation_dir
|
self.simulation_dir = simulation_dir
|
||||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||||
|
|
||||||
# 确保目录存在
|
# Ensure directories exist
|
||||||
os.makedirs(self.commands_dir, exist_ok=True)
|
os.makedirs(self.commands_dir, exist_ok=True)
|
||||||
os.makedirs(self.responses_dir, exist_ok=True)
|
os.makedirs(self.responses_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -122,19 +122,19 @@ class SimulationIPCClient:
|
||||||
poll_interval: float = 0.5
|
poll_interval: float = 0.5
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""
|
||||||
发送命令并等待响应
|
Send a command and wait for a response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
command_type: 命令类型
|
command_type: command type
|
||||||
args: 命令参数
|
args: command arguments
|
||||||
timeout: 超时时间(秒)
|
timeout: timeout in seconds
|
||||||
poll_interval: 轮询间隔(秒)
|
poll_interval: polling interval in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse
|
IPCResponse
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: 等待响应超时
|
TimeoutError: timed out waiting for a response
|
||||||
"""
|
"""
|
||||||
command_id = str(uuid.uuid4())
|
command_id = str(uuid.uuid4())
|
||||||
command = IPCCommand(
|
command = IPCCommand(
|
||||||
|
|
@ -143,14 +143,14 @@ class SimulationIPCClient:
|
||||||
args=args
|
args=args
|
||||||
)
|
)
|
||||||
|
|
||||||
# 写入命令文件
|
# Write command file
|
||||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||||
with open(command_file, 'w', encoding='utf-8') as f:
|
with open(command_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
|
logger.info(f"Sending IPC command: {command_type.value}, command_id={command_id}")
|
||||||
|
|
||||||
# 等待响应
|
# Wait for response
|
||||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
@ -161,30 +161,30 @@ class SimulationIPCClient:
|
||||||
response_data = json.load(f)
|
response_data = json.load(f)
|
||||||
response = IPCResponse.from_dict(response_data)
|
response = IPCResponse.from_dict(response_data)
|
||||||
|
|
||||||
# 清理命令和响应文件
|
# Clean up command and response files
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
os.remove(response_file)
|
os.remove(response_file)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
|
logger.info(f"Received IPC response: command_id={command_id}, status={response.status.value}")
|
||||||
return response
|
return response
|
||||||
except (json.JSONDecodeError, KeyError) as e:
|
except (json.JSONDecodeError, KeyError) as e:
|
||||||
logger.warning(f"解析响应失败: {e}")
|
logger.warning(f"Failed to parse response: {e}")
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
# 超时
|
# Timeout
|
||||||
logger.error(f"等待IPC响应超时: command_id={command_id}")
|
logger.error(f"Timed out waiting for IPC response: command_id={command_id}")
|
||||||
|
|
||||||
# 清理命令文件
|
# Clean up command file
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
|
raise TimeoutError(f"Timed out waiting for command response ({timeout}s)")
|
||||||
|
|
||||||
def send_interview(
|
def send_interview(
|
||||||
self,
|
self,
|
||||||
|
|
@ -194,19 +194,19 @@ class SimulationIPCClient:
|
||||||
timeout: float = 60.0
|
timeout: float = 60.0
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""
|
||||||
发送单个Agent采访命令
|
Send a single agent interview command
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_id: Agent ID
|
agent_id: Agent ID
|
||||||
prompt: 采访问题
|
prompt: interview question
|
||||||
platform: 指定平台(可选)
|
platform: target platform (optional)
|
||||||
- "twitter": 只采访Twitter平台
|
- "twitter": interview only the Twitter platform
|
||||||
- "reddit": 只采访Reddit平台
|
- "reddit": interview only the Reddit platform
|
||||||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
- None: in dual-platform mode, interview both; in single-platform mode, interview that platform
|
||||||
timeout: 超时时间
|
timeout: timeout in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse,result字段包含采访结果
|
IPCResponse with interview result in the result field
|
||||||
"""
|
"""
|
||||||
args = {
|
args = {
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
|
|
@ -228,18 +228,18 @@ class SimulationIPCClient:
|
||||||
timeout: float = 120.0
|
timeout: float = 120.0
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""
|
||||||
发送批量采访命令
|
Send a batch interview command
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
interviews: list of interviews, each containing {"agent_id": int, "prompt": str, "platform": str (optional)}
|
||||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
platform: default platform (optional; overridden per-item by each interview's platform)
|
||||||
- "twitter": 默认只采访Twitter平台
|
- "twitter": default to Twitter platform only
|
||||||
- "reddit": 默认只采访Reddit平台
|
- "reddit": default to Reddit platform only
|
||||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
- None: in dual-platform mode, interview each agent on both platforms
|
||||||
timeout: 超时时间
|
timeout: timeout in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse,result字段包含所有采访结果
|
IPCResponse with all interview results in the result field
|
||||||
"""
|
"""
|
||||||
args = {"interviews": interviews}
|
args = {"interviews": interviews}
|
||||||
if platform:
|
if platform:
|
||||||
|
|
@ -253,10 +253,10 @@ class SimulationIPCClient:
|
||||||
|
|
||||||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||||||
"""
|
"""
|
||||||
发送关闭环境命令
|
Send a close-environment command
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: 超时时间
|
timeout: timeout in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse
|
IPCResponse
|
||||||
|
|
@ -269,9 +269,9 @@ class SimulationIPCClient:
|
||||||
|
|
||||||
def check_env_alive(self) -> bool:
|
def check_env_alive(self) -> bool:
|
||||||
"""
|
"""
|
||||||
检查模拟环境是否存活
|
Check whether the simulation environment is alive
|
||||||
|
|
||||||
通过检查 env_status.json 文件来判断
|
Determined by checking the env_status.json file
|
||||||
"""
|
"""
|
||||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||||
if not os.path.exists(status_file):
|
if not os.path.exists(status_file):
|
||||||
|
|
@ -287,41 +287,41 @@ class SimulationIPCClient:
|
||||||
|
|
||||||
class SimulationIPCServer:
|
class SimulationIPCServer:
|
||||||
"""
|
"""
|
||||||
模拟IPC服务器(模拟脚本端使用)
|
Simulation IPC server (used by the simulation script side)
|
||||||
|
|
||||||
轮询命令目录,执行命令并返回响应
|
Polls the command directory, executes commands, and returns responses
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, simulation_dir: str):
|
def __init__(self, simulation_dir: str):
|
||||||
"""
|
"""
|
||||||
初始化IPC服务器
|
Initialize the IPC server
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_dir: 模拟数据目录
|
simulation_dir: simulation data directory
|
||||||
"""
|
"""
|
||||||
self.simulation_dir = simulation_dir
|
self.simulation_dir = simulation_dir
|
||||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||||
|
|
||||||
# 确保目录存在
|
# Ensure directories exist
|
||||||
os.makedirs(self.commands_dir, exist_ok=True)
|
os.makedirs(self.commands_dir, exist_ok=True)
|
||||||
os.makedirs(self.responses_dir, exist_ok=True)
|
os.makedirs(self.responses_dir, exist_ok=True)
|
||||||
|
|
||||||
# 环境状态
|
# Environment status
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""标记服务器为运行状态"""
|
"""Mark the server as running"""
|
||||||
self._running = True
|
self._running = True
|
||||||
self._update_env_status("alive")
|
self._update_env_status("alive")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""标记服务器为停止状态"""
|
"""Mark the server as stopped"""
|
||||||
self._running = False
|
self._running = False
|
||||||
self._update_env_status("stopped")
|
self._update_env_status("stopped")
|
||||||
|
|
||||||
def _update_env_status(self, status: str):
|
def _update_env_status(self, status: str):
|
||||||
"""更新环境状态文件"""
|
"""Update the environment status file"""
|
||||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||||
with open(status_file, 'w', encoding='utf-8') as f:
|
with open(status_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump({
|
json.dump({
|
||||||
|
|
@ -331,15 +331,15 @@ class SimulationIPCServer:
|
||||||
|
|
||||||
def poll_commands(self) -> Optional[IPCCommand]:
|
def poll_commands(self) -> Optional[IPCCommand]:
|
||||||
"""
|
"""
|
||||||
轮询命令目录,返回第一个待处理的命令
|
Poll the command directory and return the first pending command
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCCommand 或 None
|
IPCCommand or None
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.commands_dir):
|
if not os.path.exists(self.commands_dir):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 按时间排序获取命令文件
|
# Get command files sorted by modification time
|
||||||
command_files = []
|
command_files = []
|
||||||
for filename in os.listdir(self.commands_dir):
|
for filename in os.listdir(self.commands_dir):
|
||||||
if filename.endswith('.json'):
|
if filename.endswith('.json'):
|
||||||
|
|
@ -354,23 +354,23 @@ class SimulationIPCServer:
|
||||||
data = json.load(f)
|
data = json.load(f)
|
||||||
return IPCCommand.from_dict(data)
|
return IPCCommand.from_dict(data)
|
||||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||||
logger.warning(f"读取命令文件失败: {filepath}, {e}")
|
logger.warning(f"Failed to read command file: {filepath}, {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def send_response(self, response: IPCResponse):
|
def send_response(self, response: IPCResponse):
|
||||||
"""
|
"""
|
||||||
发送响应
|
Send a response
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: IPC响应
|
response: IPC response
|
||||||
"""
|
"""
|
||||||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||||||
with open(response_file, 'w', encoding='utf-8') as f:
|
with open(response_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 删除命令文件
|
# Delete the command file
|
||||||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
|
|
@ -378,7 +378,7 @@ class SimulationIPCServer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||||||
"""发送成功响应"""
|
"""Send a success response"""
|
||||||
self.send_response(IPCResponse(
|
self.send_response(IPCResponse(
|
||||||
command_id=command_id,
|
command_id=command_id,
|
||||||
status=CommandStatus.COMPLETED,
|
status=CommandStatus.COMPLETED,
|
||||||
|
|
@ -386,7 +386,7 @@ class SimulationIPCServer:
|
||||||
))
|
))
|
||||||
|
|
||||||
def send_error(self, command_id: str, error: str):
|
def send_error(self, command_id: str, error: str):
|
||||||
"""发送错误响应"""
|
"""Send an error response"""
|
||||||
self.send_response(IPCResponse(
|
self.send_response(IPCResponse(
|
||||||
command_id=command_id,
|
command_id=command_id,
|
||||||
status=CommandStatus.FAILED,
|
status=CommandStatus.FAILED,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""
|
||||||
OASIS模拟管理器
|
OASIS simulation manager
|
||||||
管理Twitter和Reddit双平台并行模拟
|
Manages parallel simulation on both Twitter and Reddit platforms.
|
||||||
使用预设脚本 + LLM智能生成配置参数
|
Uses preset scripts with LLM-generated configuration parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -23,60 +23,60 @@ logger = get_logger('mirofish.simulation')
|
||||||
|
|
||||||
|
|
||||||
class SimulationStatus(str, Enum):
|
class SimulationStatus(str, Enum):
|
||||||
"""模拟状态"""
|
"""Simulation status"""
|
||||||
CREATED = "created"
|
CREATED = "created"
|
||||||
PREPARING = "preparing"
|
PREPARING = "preparing"
|
||||||
READY = "ready"
|
READY = "ready"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
PAUSED = "paused"
|
PAUSED = "paused"
|
||||||
STOPPED = "stopped" # 模拟被手动停止
|
STOPPED = "stopped" # Simulation manually stopped
|
||||||
COMPLETED = "completed" # 模拟自然完成
|
COMPLETED = "completed" # Simulation naturally completed
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
class PlatformType(str, Enum):
|
class PlatformType(str, Enum):
|
||||||
"""平台类型"""
|
"""Platform type"""
|
||||||
TWITTER = "twitter"
|
TWITTER = "twitter"
|
||||||
REDDIT = "reddit"
|
REDDIT = "reddit"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimulationState:
|
class SimulationState:
|
||||||
"""模拟状态"""
|
"""Simulation state"""
|
||||||
simulation_id: str
|
simulation_id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
|
|
||||||
# 平台启用状态
|
# Platform enable flags
|
||||||
enable_twitter: bool = True
|
enable_twitter: bool = True
|
||||||
enable_reddit: bool = True
|
enable_reddit: bool = True
|
||||||
|
|
||||||
# 状态
|
# Status
|
||||||
status: SimulationStatus = SimulationStatus.CREATED
|
status: SimulationStatus = SimulationStatus.CREATED
|
||||||
|
|
||||||
# 准备阶段数据
|
# Preparation phase data
|
||||||
entities_count: int = 0
|
entities_count: int = 0
|
||||||
profiles_count: int = 0
|
profiles_count: int = 0
|
||||||
entity_types: List[str] = field(default_factory=list)
|
entity_types: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
# 配置生成信息
|
# Config generation info
|
||||||
config_generated: bool = False
|
config_generated: bool = False
|
||||||
config_reasoning: str = ""
|
config_reasoning: str = ""
|
||||||
|
|
||||||
# 运行时数据
|
# Runtime data
|
||||||
current_round: int = 0
|
current_round: int = 0
|
||||||
twitter_status: str = "not_started"
|
twitter_status: str = "not_started"
|
||||||
reddit_status: str = "not_started"
|
reddit_status: str = "not_started"
|
||||||
|
|
||||||
# 时间戳
|
# Timestamps
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
# 错误信息
|
# Error message
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""完整状态字典(内部使用)"""
|
"""Full state dictionary (internal use)"""
|
||||||
return {
|
return {
|
||||||
"simulation_id": self.simulation_id,
|
"simulation_id": self.simulation_id,
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
|
|
@ -98,7 +98,7 @@ class SimulationState:
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_simple_dict(self) -> Dict[str, Any]:
|
def to_simple_dict(self) -> Dict[str, Any]:
|
||||||
"""简化状态字典(API返回使用)"""
|
"""Simplified state dictionary (used for API responses)"""
|
||||||
return {
|
return {
|
||||||
"simulation_id": self.simulation_id,
|
"simulation_id": self.simulation_id,
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
|
|
@ -114,36 +114,36 @@ class SimulationState:
|
||||||
|
|
||||||
class SimulationManager:
|
class SimulationManager:
|
||||||
"""
|
"""
|
||||||
模拟管理器
|
Simulation manager
|
||||||
|
|
||||||
核心功能:
|
Core functions:
|
||||||
1. 从Zep图谱读取实体并过滤
|
1. Read and filter entities from the Zep graph
|
||||||
2. 生成OASIS Agent Profile
|
2. Generate OASIS Agent Profiles
|
||||||
3. 使用LLM智能生成模拟配置参数
|
3. Use LLM to intelligently generate simulation configuration parameters
|
||||||
4. 准备预设脚本所需的所有文件
|
4. Prepare all files required by the preset scripts
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 模拟数据存储目录
|
# Simulation data storage directory
|
||||||
SIMULATION_DATA_DIR = os.path.join(
|
SIMULATION_DATA_DIR = os.path.join(
|
||||||
os.path.dirname(__file__),
|
os.path.dirname(__file__),
|
||||||
'../../uploads/simulations'
|
'../../uploads/simulations'
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 确保目录存在
|
# Ensure directory exists
|
||||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 内存中的模拟状态缓存
|
# In-memory simulation state cache
|
||||||
self._simulations: Dict[str, SimulationState] = {}
|
self._simulations: Dict[str, SimulationState] = {}
|
||||||
|
|
||||||
def _get_simulation_dir(self, simulation_id: str) -> str:
|
def _get_simulation_dir(self, simulation_id: str) -> str:
|
||||||
"""获取模拟数据目录"""
|
"""Get the simulation data directory"""
|
||||||
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
||||||
os.makedirs(sim_dir, exist_ok=True)
|
os.makedirs(sim_dir, exist_ok=True)
|
||||||
return sim_dir
|
return sim_dir
|
||||||
|
|
||||||
def _save_simulation_state(self, state: SimulationState):
|
def _save_simulation_state(self, state: SimulationState):
|
||||||
"""保存模拟状态到文件"""
|
"""Save simulation state to file"""
|
||||||
sim_dir = self._get_simulation_dir(state.simulation_id)
|
sim_dir = self._get_simulation_dir(state.simulation_id)
|
||||||
state_file = os.path.join(sim_dir, "state.json")
|
state_file = os.path.join(sim_dir, "state.json")
|
||||||
|
|
||||||
|
|
@ -155,7 +155,7 @@ class SimulationManager:
|
||||||
self._simulations[state.simulation_id] = state
|
self._simulations[state.simulation_id] = state
|
||||||
|
|
||||||
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
||||||
"""从文件加载模拟状态"""
|
"""Load simulation state from file"""
|
||||||
if simulation_id in self._simulations:
|
if simulation_id in self._simulations:
|
||||||
return self._simulations[simulation_id]
|
return self._simulations[simulation_id]
|
||||||
|
|
||||||
|
|
@ -199,13 +199,13 @@ class SimulationManager:
|
||||||
enable_reddit: bool = True,
|
enable_reddit: bool = True,
|
||||||
) -> SimulationState:
|
) -> SimulationState:
|
||||||
"""
|
"""
|
||||||
创建新的模拟
|
Create a new simulation.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: project ID
|
||||||
graph_id: Zep图谱ID
|
graph_id: Zep graph ID
|
||||||
enable_twitter: 是否启用Twitter模拟
|
enable_twitter: whether to enable Twitter simulation
|
||||||
enable_reddit: 是否启用Reddit模拟
|
enable_reddit: whether to enable Reddit simulation
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SimulationState
|
SimulationState
|
||||||
|
|
@ -223,7 +223,7 @@ class SimulationManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")
|
logger.info(f"Simulation created: {simulation_id}, project={project_id}, graph={graph_id}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
@ -238,30 +238,30 @@ class SimulationManager:
|
||||||
parallel_profile_count: int = 3
|
parallel_profile_count: int = 3
|
||||||
) -> SimulationState:
|
) -> SimulationState:
|
||||||
"""
|
"""
|
||||||
准备模拟环境(全程自动化)
|
Prepare the simulation environment (fully automated).
|
||||||
|
|
||||||
步骤:
|
Steps:
|
||||||
1. 从Zep图谱读取并过滤实体
|
1. Read and filter entities from the Zep graph
|
||||||
2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行)
|
2. Generate an OASIS Agent Profile for each entity (optional LLM enhancement, supports parallelism)
|
||||||
3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等)
|
3. Use LLM to intelligently generate simulation configuration parameters (time, activity level, posting frequency, etc.)
|
||||||
4. 保存配置文件和Profile文件
|
4. Save configuration files and profile files
|
||||||
5. 复制预设脚本到模拟目录
|
5. Leave the preset run scripts in backend/scripts/ (they are run from there and are no longer copied to the simulation directory)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_id: 模拟ID
|
simulation_id: simulation ID
|
||||||
simulation_requirement: 模拟需求描述(用于LLM生成配置)
|
simulation_requirement: simulation requirement description (used for LLM config generation)
|
||||||
document_text: 原始文档内容(用于LLM理解背景)
|
document_text: original document content (used for LLM background understanding)
|
||||||
defined_entity_types: 预定义的实体类型(可选)
|
defined_entity_types: predefined entity types (optional)
|
||||||
use_llm_for_profiles: 是否使用LLM生成详细人设
|
use_llm_for_profiles: whether to use LLM to generate detailed personas
|
||||||
progress_callback: 进度回调函数 (stage, progress, message)
|
progress_callback: progress callback function (stage, progress, message)
|
||||||
parallel_profile_count: 并行生成人设的数量,默认3
|
parallel_profile_count: number of profiles to generate in parallel, default 3
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SimulationState
|
SimulationState
|
||||||
"""
|
"""
|
||||||
state = self._load_simulation_state(simulation_id)
|
state = self._load_simulation_state(simulation_id)
|
||||||
if not state:
|
if not state:
|
||||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
raise ValueError(f"Simulation not found: {simulation_id}")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
state.status = SimulationStatus.PREPARING
|
state.status = SimulationStatus.PREPARING
|
||||||
|
|
@ -269,7 +269,7 @@ class SimulationManager:
|
||||||
|
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
|
|
||||||
# ========== 阶段1: 读取并过滤实体 ==========
|
# ========== Stage 1: Read and filter entities ==========
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback("reading", 0, t('progress.connectingZepGraph'))
|
progress_callback("reading", 0, t('progress.connectingZepGraph'))
|
||||||
|
|
||||||
|
|
@ -297,11 +297,11 @@ class SimulationManager:
|
||||||
|
|
||||||
if filtered.filtered_count == 0:
|
if filtered.filtered_count == 0:
|
||||||
state.status = SimulationStatus.FAILED
|
state.status = SimulationStatus.FAILED
|
||||||
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
|
state.error = "No qualifying entities found. Please check that the graph was built correctly."
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# ========== 阶段2: 生成Agent Profile ==========
|
# ========== Stage 2: Generate Agent Profiles ==========
|
||||||
total_entities = len(filtered.entities)
|
total_entities = len(filtered.entities)
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
|
|
@ -312,7 +312,7 @@ class SimulationManager:
|
||||||
total=total_entities
|
total=total_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
# 传入graph_id以启用Zep检索功能,获取更丰富的上下文
|
# Pass graph_id to enable Zep retrieval for richer context
|
||||||
generator = OasisProfileGenerator(graph_id=state.graph_id)
|
generator = OasisProfileGenerator(graph_id=state.graph_id)
|
||||||
|
|
||||||
def profile_progress(current, total, msg):
|
def profile_progress(current, total, msg):
|
||||||
|
|
@ -326,7 +326,7 @@ class SimulationManager:
|
||||||
item_name=msg
|
item_name=msg
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式)
|
# Set real-time save path (prefer Reddit JSON format)
|
||||||
realtime_output_path = None
|
realtime_output_path = None
|
||||||
realtime_platform = "reddit"
|
realtime_platform = "reddit"
|
||||||
if state.enable_reddit:
|
if state.enable_reddit:
|
||||||
|
|
@ -340,16 +340,16 @@ class SimulationManager:
|
||||||
entities=filtered.entities,
|
entities=filtered.entities,
|
||||||
use_llm=use_llm_for_profiles,
|
use_llm=use_llm_for_profiles,
|
||||||
progress_callback=profile_progress,
|
progress_callback=profile_progress,
|
||||||
graph_id=state.graph_id, # 传入graph_id用于Zep检索
|
graph_id=state.graph_id, # Pass graph_id for Zep retrieval
|
||||||
parallel_count=parallel_profile_count, # 并行生成数量
|
parallel_count=parallel_profile_count, # Parallel generation count
|
||||||
realtime_output_path=realtime_output_path, # 实时保存路径
|
realtime_output_path=realtime_output_path, # Real-time save path
|
||||||
output_platform=realtime_platform # 输出格式
|
output_platform=realtime_platform # Output format
|
||||||
)
|
)
|
||||||
|
|
||||||
state.profiles_count = len(profiles)
|
state.profiles_count = len(profiles)
|
||||||
|
|
||||||
# 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
|
# Save profile files (note: Twitter uses CSV format, Reddit uses JSON format)
|
||||||
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
|
# Reddit has already been saved incrementally during generation; save once more to ensure completeness
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(
|
progress_callback(
|
||||||
"generating_profiles", 95,
|
"generating_profiles", 95,
|
||||||
|
|
@ -366,7 +366,7 @@ class SimulationManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.enable_twitter:
|
if state.enable_twitter:
|
||||||
# Twitter使用CSV格式!这是OASIS的要求
|
# Twitter uses CSV format — this is a requirement of OASIS
|
||||||
generator.save_profiles(
|
generator.save_profiles(
|
||||||
profiles=profiles,
|
profiles=profiles,
|
||||||
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
||||||
|
|
@ -381,7 +381,7 @@ class SimulationManager:
|
||||||
total=len(profiles)
|
total=len(profiles)
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========== 阶段3: LLM智能生成模拟配置 ==========
|
# ========== Stage 3: LLM intelligent simulation configuration generation ==========
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(
|
progress_callback(
|
||||||
"generating_config", 0,
|
"generating_config", 0,
|
||||||
|
|
@ -419,7 +419,7 @@ class SimulationManager:
|
||||||
total=3
|
total=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存配置文件
|
# Save configuration file
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
with open(config_path, 'w', encoding='utf-8') as f:
|
with open(config_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(sim_params.to_json())
|
f.write(sim_params.to_json())
|
||||||
|
|
@ -435,20 +435,20 @@ class SimulationManager:
|
||||||
total=3
|
total=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
|
# Note: run scripts remain in backend/scripts/; they are not copied to the simulation directory.
|
||||||
# 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本
|
# When starting a simulation, simulation_runner runs scripts from the scripts/ directory.
|
||||||
|
|
||||||
# 更新状态
|
# Update status
|
||||||
state.status = SimulationStatus.READY
|
state.status = SimulationStatus.READY
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
|
|
||||||
logger.info(f"模拟准备完成: {simulation_id}, "
|
logger.info(f"Simulation preparation complete: {simulation_id}, "
|
||||||
f"entities={state.entities_count}, profiles={state.profiles_count}")
|
f"entities={state.entities_count}, profiles={state.profiles_count}")
|
||||||
|
|
||||||
return state
|
return state
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
|
logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}")
|
||||||
import traceback
|
import traceback
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
state.status = SimulationStatus.FAILED
|
state.status = SimulationStatus.FAILED
|
||||||
|
|
@ -457,16 +457,16 @@ class SimulationManager:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||||
"""获取模拟状态"""
|
"""Get simulation state"""
|
||||||
return self._load_simulation_state(simulation_id)
|
return self._load_simulation_state(simulation_id)
|
||||||
|
|
||||||
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
||||||
"""列出所有模拟"""
|
"""List all simulations"""
|
||||||
simulations = []
|
simulations = []
|
||||||
|
|
||||||
if os.path.exists(self.SIMULATION_DATA_DIR):
|
if os.path.exists(self.SIMULATION_DATA_DIR):
|
||||||
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
||||||
# 跳过隐藏文件(如 .DS_Store)和非目录文件
|
# Skip hidden files (e.g. .DS_Store) and non-directory entries
|
||||||
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
|
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
|
||||||
if sim_id.startswith('.') or not os.path.isdir(sim_path):
|
if sim_id.startswith('.') or not os.path.isdir(sim_path):
|
||||||
continue
|
continue
|
||||||
|
|
@ -479,10 +479,10 @@ class SimulationManager:
|
||||||
return simulations
|
return simulations
|
||||||
|
|
||||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||||
"""获取模拟的Agent Profile"""
|
"""Get agent profiles for a simulation"""
|
||||||
state = self._load_simulation_state(simulation_id)
|
state = self._load_simulation_state(simulation_id)
|
||||||
if not state:
|
if not state:
|
||||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
raise ValueError(f"Simulation not found: {simulation_id}")
|
||||||
|
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
||||||
|
|
@ -494,7 +494,7 @@ class SimulationManager:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""获取模拟配置"""
|
"""Get simulation configuration"""
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
|
|
||||||
|
|
@ -505,7 +505,7 @@ class SimulationManager:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
||||||
"""获取运行说明"""
|
"""Get run instructions"""
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||||
|
|
@ -520,10 +520,10 @@ class SimulationManager:
|
||||||
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
|
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
|
||||||
},
|
},
|
||||||
"instructions": (
|
"instructions": (
|
||||||
f"1. 激活conda环境: conda activate MiroFish\n"
|
f"1. Activate conda environment: conda activate MiroFish\n"
|
||||||
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
|
f"2. Run simulation (scripts located at {scripts_dir}):\n"
|
||||||
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
f" - Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
||||||
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
f" - Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
||||||
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
f" - Both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
文本处理服务
|
Text processing service
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
|
|
@ -7,11 +7,11 @@ from ..utils.file_parser import FileParser, split_text_into_chunks
|
||||||
|
|
||||||
|
|
||||||
class TextProcessor:
|
class TextProcessor:
|
||||||
"""文本处理器"""
|
"""Text processor"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_from_files(file_paths: List[str]) -> str:
|
def extract_from_files(file_paths: List[str]) -> str:
|
||||||
"""从多个文件提取文本"""
|
"""Extract text from multiple files"""
|
||||||
return FileParser.extract_from_multiple(file_paths)
|
return FileParser.extract_from_multiple(file_paths)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
@ -21,40 +21,40 @@ class TextProcessor:
|
||||||
overlap: int = 50
|
overlap: int = 50
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
分割文本
|
Split text into chunks.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 原始文本
|
text: raw text
|
||||||
chunk_size: 块大小
|
chunk_size: chunk size
|
||||||
overlap: 重叠大小
|
overlap: overlap size
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文本块列表
|
list of text chunks
|
||||||
"""
|
"""
|
||||||
return split_text_into_chunks(text, chunk_size, overlap)
|
return split_text_into_chunks(text, chunk_size, overlap)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def preprocess_text(text: str) -> str:
|
def preprocess_text(text: str) -> str:
|
||||||
"""
|
"""
|
||||||
预处理文本
|
Preprocess text:
|
||||||
- 移除多余空白
|
- Remove excess whitespace
|
||||||
- 标准化换行
|
- Normalize line endings
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 原始文本
|
text: raw text
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
处理后的文本
|
processed text
|
||||||
"""
|
"""
|
||||||
import re
|
import re
|
||||||
|
|
||||||
# 标准化换行
|
# Normalize line endings
|
||||||
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||||
|
|
||||||
# 移除连续空行(保留最多两个换行)
|
# Remove consecutive blank lines (keep at most two newlines)
|
||||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||||
|
|
||||||
# 移除行首行尾空白
|
# Strip leading/trailing whitespace from each line
|
||||||
lines = [line.strip() for line in text.split('\n')]
|
lines = [line.strip() for line in text.split('\n')]
|
||||||
text = '\n'.join(lines)
|
text = '\n'.join(lines)
|
||||||
|
|
||||||
|
|
@ -62,7 +62,7 @@ class TextProcessor:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_text_stats(text: str) -> dict:
|
def get_text_stats(text: str) -> dict:
|
||||||
"""获取文本统计信息"""
|
"""Get text statistics"""
|
||||||
return {
|
return {
|
||||||
"total_chars": len(text),
|
"total_chars": len(text),
|
||||||
"total_lines": text.count('\n') + 1,
|
"total_lines": text.count('\n') + 1,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Zep实体读取与过滤服务
|
Zep entity read and filter service
|
||||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
Reads nodes from the Zep graph and keeps only the nodes that match predefined entity types
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
@ -15,21 +15,21 @@ from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_entity_reader')
|
logger = get_logger('mirofish.zep_entity_reader')
|
||||||
|
|
||||||
# 用于泛型返回类型
|
# Type variable used for generic return types
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityNode:
|
class EntityNode:
|
||||||
"""实体节点数据结构"""
|
"""Entity node data structure"""
|
||||||
uuid: str
|
uuid: str
|
||||||
name: str
|
name: str
|
||||||
labels: List[str]
|
labels: List[str]
|
||||||
summary: str
|
summary: str
|
||||||
attributes: Dict[str, Any]
|
attributes: Dict[str, Any]
|
||||||
# 相关的边信息
|
# Related edge info
|
||||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
# 相关的其他节点信息
|
# Related node info
|
||||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|
@ -44,7 +44,7 @@ class EntityNode:
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_entity_type(self) -> Optional[str]:
|
def get_entity_type(self) -> Optional[str]:
|
||||||
"""获取实体类型(排除默认的Entity标签)"""
|
"""Get entity type (excluding the default Entity label)"""
|
||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
if label not in ["Entity", "Node"]:
|
if label not in ["Entity", "Node"]:
|
||||||
return label
|
return label
|
||||||
|
|
@ -53,7 +53,7 @@ class EntityNode:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FilteredEntities:
|
class FilteredEntities:
|
||||||
"""过滤后的实体集合"""
|
"""Filtered entity collection"""
|
||||||
entities: List[EntityNode]
|
entities: List[EntityNode]
|
||||||
entity_types: Set[str]
|
entity_types: Set[str]
|
||||||
total_count: int
|
total_count: int
|
||||||
|
|
@ -70,18 +70,18 @@ class FilteredEntities:
|
||||||
|
|
||||||
class ZepEntityReader:
|
class ZepEntityReader:
|
||||||
"""
|
"""
|
||||||
Zep实体读取与过滤服务
|
Zep entity read and filter service
|
||||||
|
|
||||||
主要功能:
|
Main features:
|
||||||
1. 从Zep图谱读取所有节点
|
1. Read all nodes from the Zep graph
|
||||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
2. Select the nodes matching predefined entity types (nodes with labels beyond just "Entity")
|
||||||
3. 获取每个实体的相关边和关联节点信息
|
3. Fetch related edges and associated node info for each entity
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
def __init__(self, api_key: Optional[str] = None):
|
||||||
self.api_key = api_key or Config.ZEP_API_KEY
|
self.api_key = api_key or Config.ZEP_API_KEY
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("ZEP_API_KEY 未配置")
|
raise ValueError("ZEP_API_KEY is not configured")
|
||||||
|
|
||||||
self.client = Zep(api_key=self.api_key)
|
self.client = Zep(api_key=self.api_key)
|
||||||
|
|
||||||
|
|
@ -93,16 +93,16 @@ class ZepEntityReader:
|
||||||
initial_delay: float = 2.0
|
initial_delay: float = 2.0
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""
|
||||||
带重试机制的Zep API调用
|
Zep API call with retry logic
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 要执行的函数(无参数的lambda或callable)
|
func: function to execute (a lambda or callable with no arguments)
|
||||||
operation_name: 操作名称,用于日志
|
operation_name: operation name for logging
|
||||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
max_retries: maximum number of retries (default 3, meaning up to 3 attempts total)
|
||||||
initial_delay: 初始延迟秒数
|
initial_delay: initial delay in seconds
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
API调用结果
|
API call result
|
||||||
"""
|
"""
|
||||||
last_exception = None
|
last_exception = None
|
||||||
delay = initial_delay
|
delay = initial_delay
|
||||||
|
|
@ -114,27 +114,27 @@ class ZepEntityReader:
|
||||||
last_exception = e
|
last_exception = e
|
||||||
if attempt < max_retries - 1:
|
if attempt < max_retries - 1:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, "
|
||||||
f"{delay:.1f}秒后重试..."
|
f"retrying in {delay:.1f}s..."
|
||||||
)
|
)
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
delay *= 2 # 指数退避
|
delay *= 2 # Exponential backoff
|
||||||
else:
|
else:
|
||||||
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
logger.error(f"Zep {operation_name} still failing after {max_retries} attempts: {str(e)}")
|
||||||
|
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
||||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取图谱的所有节点(分页获取)
|
Get all nodes in the graph (paginated)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点列表
|
Node list
|
||||||
"""
|
"""
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
logger.info(f"Fetching all nodes for graph {graph_id}...")
|
||||||
|
|
||||||
nodes = fetch_all_nodes(self.client, graph_id)
|
nodes = fetch_all_nodes(self.client, graph_id)
|
||||||
|
|
||||||
|
|
@ -148,20 +148,20 @@ class ZepEntityReader:
|
||||||
"attributes": node.attributes or {},
|
"attributes": node.attributes or {},
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
logger.info(f"Fetched {len(nodes_data)} nodes")
|
||||||
return nodes_data
|
return nodes_data
|
||||||
|
|
||||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取图谱的所有边(分页获取)
|
Get all edges in the graph (paginated)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
边列表
|
Edge list
|
||||||
"""
|
"""
|
||||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
logger.info(f"Fetching all edges for graph {graph_id}...")
|
||||||
|
|
||||||
edges = fetch_all_edges(self.client, graph_id)
|
edges = fetch_all_edges(self.client, graph_id)
|
||||||
|
|
||||||
|
|
@ -176,24 +176,24 @@ class ZepEntityReader:
|
||||||
"attributes": edge.attributes or {},
|
"attributes": edge.attributes or {},
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
logger.info(f"Fetched {len(edges_data)} edges")
|
||||||
return edges_data
|
return edges_data
|
||||||
|
|
||||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""
|
||||||
获取指定节点的所有相关边(带重试机制)
|
Get all edges related to the specified node (with retry logic)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_uuid: 节点UUID
|
node_uuid: node UUID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
边列表
|
Edge list
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用重试机制调用Zep API
|
# Call Zep API with retry
|
||||||
edges = self._call_with_retry(
|
edges = self._call_with_retry(
|
||||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
operation_name=f"get node edges (node={node_uuid[:8]}...)"
|
||||||
)
|
)
|
||||||
|
|
||||||
edges_data = []
|
edges_data = []
|
||||||
|
|
@ -209,7 +209,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
return edges_data
|
return edges_data
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def filter_defined_entities(
|
def filter_defined_entities(
|
||||||
|
|
@ -219,47 +219,47 @@ class ZepEntityReader:
|
||||||
enrich_with_edges: bool = True
|
enrich_with_edges: bool = True
|
||||||
) -> FilteredEntities:
|
) -> FilteredEntities:
|
||||||
"""
|
"""
|
||||||
筛选出符合预定义实体类型的节点
|
Select the nodes that match predefined entity types
|
||||||
|
|
||||||
筛选逻辑:
|
Filter logic:
|
||||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
- If a node's Labels contain only "Entity", it does not match our predefined types; skip it
|
||||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
- If a node's Labels contain labels other than "Entity" and "Node", it matches a predefined type; keep it
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept)
|
||||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
enrich_with_edges: whether to fetch related edge info for each entity
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FilteredEntities: 过滤后的实体集合
|
FilteredEntities: filtered entity collection
|
||||||
"""
|
"""
|
||||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
logger.info(f"Starting entity filtering for graph {graph_id}...")
|
||||||
|
|
||||||
# 获取所有节点
|
# Get all nodes
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
total_count = len(all_nodes)
|
total_count = len(all_nodes)
|
||||||
|
|
||||||
# 获取所有边(用于后续关联查找)
|
# Get all edges (for relation lookup)
|
||||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||||
|
|
||||||
# 构建节点UUID到节点数据的映射
|
# Build UUID-to-node mapping
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 筛选符合条件的实体
|
# Filter matching entities
|
||||||
filtered_entities = []
|
filtered_entities = []
|
||||||
entity_types_found = set()
|
entity_types_found = set()
|
||||||
|
|
||||||
for node in all_nodes:
|
for node in all_nodes:
|
||||||
labels = node.get("labels", [])
|
labels = node.get("labels", [])
|
||||||
|
|
||||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
# Filter logic: Labels must contain at least one label other than "Entity" and "Node"
|
||||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||||
|
|
||||||
if not custom_labels:
|
if not custom_labels:
|
||||||
# 只有默认标签,跳过
|
# Only default labels; skip
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果指定了预定义类型,检查是否匹配
|
# If predefined types are specified, check for a match
|
||||||
if defined_entity_types:
|
if defined_entity_types:
|
||||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||||
if not matching_labels:
|
if not matching_labels:
|
||||||
|
|
@ -270,7 +270,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
entity_types_found.add(entity_type)
|
entity_types_found.add(entity_type)
|
||||||
|
|
||||||
# 创建实体节点对象
|
# Create entity node object
|
||||||
entity = EntityNode(
|
entity = EntityNode(
|
||||||
uuid=node["uuid"],
|
uuid=node["uuid"],
|
||||||
name=node["name"],
|
name=node["name"],
|
||||||
|
|
@ -279,7 +279,7 @@ class ZepEntityReader:
|
||||||
attributes=node["attributes"],
|
attributes=node["attributes"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取相关边和节点
|
# Fetch related edges and nodes
|
||||||
if enrich_with_edges:
|
if enrich_with_edges:
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids = set()
|
||||||
|
|
@ -304,7 +304,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
entity.related_edges = related_edges
|
entity.related_edges = related_edges
|
||||||
|
|
||||||
# 获取关联节点的基本信息
|
# Fetch basic info for related nodes
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid in node_map:
|
||||||
|
|
@ -320,8 +320,8 @@ class ZepEntityReader:
|
||||||
|
|
||||||
filtered_entities.append(entity)
|
filtered_entities.append(entity)
|
||||||
|
|
||||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, "
|
||||||
f"实体类型: {entity_types_found}")
|
f"entity types: {entity_types_found}")
|
||||||
|
|
||||||
return FilteredEntities(
|
return FilteredEntities(
|
||||||
entities=filtered_entities,
|
entities=filtered_entities,
|
||||||
|
|
@ -336,33 +336,33 @@ class ZepEntityReader:
|
||||||
entity_uuid: str
|
entity_uuid: str
|
||||||
) -> Optional[EntityNode]:
|
) -> Optional[EntityNode]:
|
||||||
"""
|
"""
|
||||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
Get a single entity and its full context (edges and related nodes, with retry)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
entity_uuid: 实体UUID
|
entity_uuid: entity UUID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EntityNode或None
|
EntityNode or None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用重试机制获取节点
|
# Get the node with retry
|
||||||
node = self._call_with_retry(
|
node = self._call_with_retry(
|
||||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
operation_name=f"get node detail (uuid={entity_uuid[:8]}...)"
|
||||||
)
|
)
|
||||||
|
|
||||||
if not node:
|
if not node:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取节点的边
|
# Get the node's edges
|
||||||
edges = self.get_node_edges(entity_uuid)
|
edges = self.get_node_edges(entity_uuid)
|
||||||
|
|
||||||
# 获取所有节点用于关联查找
|
# Get all nodes for relation lookup
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 处理相关边和节点
|
# Process related edges and nodes
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids = set()
|
||||||
|
|
||||||
|
|
@ -384,7 +384,7 @@ class ZepEntityReader:
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["source_node_uuid"])
|
related_node_uuids.add(edge["source_node_uuid"])
|
||||||
|
|
||||||
# 获取关联节点信息
|
# Fetch related node info
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid in node_map:
|
||||||
|
|
@ -407,7 +407,7 @@ class ZepEntityReader:
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def get_entities_by_type(
|
def get_entities_by_type(
|
||||||
|
|
@ -417,15 +417,15 @@ class ZepEntityReader:
|
||||||
enrich_with_edges: bool = True
|
enrich_with_edges: bool = True
|
||||||
) -> List[EntityNode]:
|
) -> List[EntityNode]:
|
||||||
"""
|
"""
|
||||||
获取指定类型的所有实体
|
Get all entities of a specified type
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: graph ID
|
||||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
entity_type: entity type (e.g. "Student", "PublicFigure")
|
||||||
enrich_with_edges: 是否获取相关边信息
|
enrich_with_edges: whether to fetch related edge info
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
实体列表
|
Entity list
|
||||||
"""
|
"""
|
||||||
result = self.filter_defined_entities(
|
result = self.filter_defined_entities(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
|
|
@ -433,5 +433,3 @@ class ZepEntityReader:
|
||||||
enrich_with_edges=enrich_with_edges
|
enrich_with_edges=enrich_with_edges
|
||||||
)
|
)
|
||||||
return result.entities
|
return result.entities
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
Zep图谱记忆更新服务
|
Zep graph memory update service
|
||||||
将模拟中的Agent活动动态更新到Zep图谱中
|
Dynamically updates agent activities from the simulation to the Zep graph
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -23,7 +23,7 @@ logger = get_logger('mirofish.zep_graph_memory_updater')
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class AgentActivity:
|
class AgentActivity:
|
||||||
"""Agent活动记录"""
|
"""Agent activity record"""
|
||||||
platform: str # twitter / reddit
|
platform: str # twitter / reddit
|
||||||
agent_id: int
|
agent_id: int
|
||||||
agent_name: str
|
agent_name: str
|
||||||
|
|
@ -34,12 +34,13 @@ class AgentActivity:
|
||||||
|
|
||||||
def to_episode_text(self) -> str:
|
def to_episode_text(self) -> str:
|
||||||
"""
|
"""
|
||||||
将活动转换为可以发送给Zep的文本描述
|
Convert the activity to a text description suitable for sending to Zep
|
||||||
|
|
||||||
采用自然语言描述格式,让Zep能够从中提取实体和关系
|
Uses a natural-language description format so Zep can extract entities and
|
||||||
不添加模拟相关的前缀,避免误导图谱更新
|
relationships. No simulation-specific prefix is added to avoid misleading
|
||||||
|
graph updates.
|
||||||
"""
|
"""
|
||||||
# 根据不同的动作类型生成不同的描述
|
# Generate a description based on the action type
|
||||||
action_descriptions = {
|
action_descriptions = {
|
||||||
"CREATE_POST": self._describe_create_post,
|
"CREATE_POST": self._describe_create_post,
|
||||||
"LIKE_POST": self._describe_like_post,
|
"LIKE_POST": self._describe_like_post,
|
||||||
|
|
@ -58,222 +59,223 @@ class AgentActivity:
|
||||||
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
||||||
description = describe_func()
|
description = describe_func()
|
||||||
|
|
||||||
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀
|
# Return "agent_name: activity description" format without a simulation prefix
|
||||||
return f"{self.agent_name}: {description}"
|
return f"{self.agent_name}: {description}"
|
||||||
|
|
||||||
def _describe_create_post(self) -> str:
|
def _describe_create_post(self) -> str:
|
||||||
content = self.action_args.get("content", "")
|
content = self.action_args.get("content", "")
|
||||||
if content:
|
if content:
|
||||||
return f"发布了一条帖子:「{content}」"
|
return f'posted: "{content}"'
|
||||||
return "发布了一条帖子"
|
return "created a post"
|
||||||
|
|
||||||
def _describe_like_post(self) -> str:
|
def _describe_like_post(self) -> str:
|
||||||
"""点赞帖子 - 包含帖子原文和作者信息"""
|
"""Like a post — includes post content and author info"""
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
||||||
if post_content and post_author:
|
if post_content and post_author:
|
||||||
return f"点赞了{post_author}的帖子:「{post_content}」"
|
return f'liked {post_author}\'s post: "{post_content}"'
|
||||||
elif post_content:
|
elif post_content:
|
||||||
return f"点赞了一条帖子:「{post_content}」"
|
return f'liked a post: "{post_content}"'
|
||||||
elif post_author:
|
elif post_author:
|
||||||
return f"点赞了{post_author}的一条帖子"
|
return f"liked a post by {post_author}"
|
||||||
return "点赞了一条帖子"
|
return "liked a post"
|
||||||
|
|
||||||
def _describe_dislike_post(self) -> str:
|
def _describe_dislike_post(self) -> str:
|
||||||
"""踩帖子 - 包含帖子原文和作者信息"""
|
"""Dislike a post — includes post content and author info"""
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
||||||
if post_content and post_author:
|
if post_content and post_author:
|
||||||
return f"踩了{post_author}的帖子:「{post_content}」"
|
return f'disliked {post_author}\'s post: "{post_content}"'
|
||||||
elif post_content:
|
elif post_content:
|
||||||
return f"踩了一条帖子:「{post_content}」"
|
return f'disliked a post: "{post_content}"'
|
||||||
elif post_author:
|
elif post_author:
|
||||||
return f"踩了{post_author}的一条帖子"
|
return f"disliked a post by {post_author}"
|
||||||
return "踩了一条帖子"
|
return "disliked a post"
|
||||||
|
|
||||||
def _describe_repost(self) -> str:
|
def _describe_repost(self) -> str:
|
||||||
"""转发帖子 - 包含原帖内容和作者信息"""
|
"""Repost — includes original post content and author info"""
|
||||||
original_content = self.action_args.get("original_content", "")
|
original_content = self.action_args.get("original_content", "")
|
||||||
original_author = self.action_args.get("original_author_name", "")
|
original_author = self.action_args.get("original_author_name", "")
|
||||||
|
|
||||||
if original_content and original_author:
|
if original_content and original_author:
|
||||||
return f"转发了{original_author}的帖子:「{original_content}」"
|
return f'reposted {original_author}\'s post: "{original_content}"'
|
||||||
elif original_content:
|
elif original_content:
|
||||||
return f"转发了一条帖子:「{original_content}」"
|
return f'reposted: "{original_content}"'
|
||||||
elif original_author:
|
elif original_author:
|
||||||
return f"转发了{original_author}的一条帖子"
|
return f"reposted a post by {original_author}"
|
||||||
return "转发了一条帖子"
|
return "reposted a post"
|
||||||
|
|
||||||
def _describe_quote_post(self) -> str:
|
def _describe_quote_post(self) -> str:
|
||||||
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
|
"""Quote post — includes original post content, author info, and quote comment"""
|
||||||
original_content = self.action_args.get("original_content", "")
|
original_content = self.action_args.get("original_content", "")
|
||||||
original_author = self.action_args.get("original_author_name", "")
|
original_author = self.action_args.get("original_author_name", "")
|
||||||
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
||||||
|
|
||||||
base = ""
|
base = ""
|
||||||
if original_content and original_author:
|
if original_content and original_author:
|
||||||
base = f"引用了{original_author}的帖子「{original_content}」"
|
base = f'quoted {original_author}\'s post "{original_content}"'
|
||||||
elif original_content:
|
elif original_content:
|
||||||
base = f"引用了一条帖子「{original_content}」"
|
base = f'quoted a post: "{original_content}"'
|
||||||
elif original_author:
|
elif original_author:
|
||||||
base = f"引用了{original_author}的一条帖子"
|
base = f"quoted a post by {original_author}"
|
||||||
else:
|
else:
|
||||||
base = "引用了一条帖子"
|
base = "quoted a post"
|
||||||
|
|
||||||
if quote_content:
|
if quote_content:
|
||||||
base += f",并评论道:「{quote_content}」"
|
base += f' with comment: "{quote_content}"'
|
||||||
return base
|
return base
|
||||||
|
|
||||||
def _describe_follow(self) -> str:
|
def _describe_follow(self) -> str:
|
||||||
"""关注用户 - 包含被关注用户的名称"""
|
"""Follow a user — includes the followed user's name"""
|
||||||
target_user_name = self.action_args.get("target_user_name", "")
|
target_user_name = self.action_args.get("target_user_name", "")
|
||||||
|
|
||||||
if target_user_name:
|
if target_user_name:
|
||||||
return f"关注了用户「{target_user_name}」"
|
return f'followed user "{target_user_name}"'
|
||||||
return "关注了一个用户"
|
return "followed a user"
|
||||||
|
|
||||||
def _describe_create_comment(self) -> str:
|
def _describe_create_comment(self) -> str:
|
||||||
"""发表评论 - 包含评论内容和所评论的帖子信息"""
|
"""Create a comment — includes comment content and the post being commented on"""
|
||||||
content = self.action_args.get("content", "")
|
content = self.action_args.get("content", "")
|
||||||
post_content = self.action_args.get("post_content", "")
|
post_content = self.action_args.get("post_content", "")
|
||||||
post_author = self.action_args.get("post_author_name", "")
|
post_author = self.action_args.get("post_author_name", "")
|
||||||
|
|
||||||
if content:
|
if content:
|
||||||
if post_content and post_author:
|
if post_content and post_author:
|
||||||
return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
|
return f'commented on {post_author}\'s post "{post_content}": "{content}"'
|
||||||
elif post_content:
|
elif post_content:
|
||||||
return f"在帖子「{post_content}」下评论道:「{content}」"
|
return f'commented on post "{post_content}": "{content}"'
|
||||||
elif post_author:
|
elif post_author:
|
||||||
return f"在{post_author}的帖子下评论道:「{content}」"
|
return f'commented on {post_author}\'s post: "{content}"'
|
||||||
return f"评论道:「{content}」"
|
return f'commented: "{content}"'
|
||||||
return "发表了评论"
|
return "posted a comment"
|
||||||
|
|
||||||
def _describe_like_comment(self) -> str:
|
def _describe_like_comment(self) -> str:
|
||||||
"""点赞评论 - 包含评论内容和作者信息"""
|
"""Like a comment — includes comment content and author info"""
|
||||||
comment_content = self.action_args.get("comment_content", "")
|
comment_content = self.action_args.get("comment_content", "")
|
||||||
comment_author = self.action_args.get("comment_author_name", "")
|
comment_author = self.action_args.get("comment_author_name", "")
|
||||||
|
|
||||||
if comment_content and comment_author:
|
if comment_content and comment_author:
|
||||||
return f"点赞了{comment_author}的评论:「{comment_content}」"
|
return f'liked {comment_author}\'s comment: "{comment_content}"'
|
||||||
elif comment_content:
|
elif comment_content:
|
||||||
return f"点赞了一条评论:「{comment_content}」"
|
return f'liked a comment: "{comment_content}"'
|
||||||
elif comment_author:
|
elif comment_author:
|
||||||
return f"点赞了{comment_author}的一条评论"
|
return f"liked a comment by {comment_author}"
|
||||||
return "点赞了一条评论"
|
return "liked a comment"
|
||||||
|
|
||||||
def _describe_dislike_comment(self) -> str:
|
def _describe_dislike_comment(self) -> str:
|
||||||
"""踩评论 - 包含评论内容和作者信息"""
|
"""Dislike a comment — includes comment content and author info"""
|
||||||
comment_content = self.action_args.get("comment_content", "")
|
comment_content = self.action_args.get("comment_content", "")
|
||||||
comment_author = self.action_args.get("comment_author_name", "")
|
comment_author = self.action_args.get("comment_author_name", "")
|
||||||
|
|
||||||
if comment_content and comment_author:
|
if comment_content and comment_author:
|
||||||
return f"踩了{comment_author}的评论:「{comment_content}」"
|
return f'disliked {comment_author}\'s comment: "{comment_content}"'
|
||||||
elif comment_content:
|
elif comment_content:
|
||||||
return f"踩了一条评论:「{comment_content}」"
|
return f'disliked a comment: "{comment_content}"'
|
||||||
elif comment_author:
|
elif comment_author:
|
||||||
return f"踩了{comment_author}的一条评论"
|
return f"disliked a comment by {comment_author}"
|
||||||
return "踩了一条评论"
|
return "disliked a comment"
|
||||||
|
|
||||||
def _describe_search(self) -> str:
|
def _describe_search(self) -> str:
|
||||||
"""搜索帖子 - 包含搜索关键词"""
|
"""Search posts — includes search keyword"""
|
||||||
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
||||||
return f"搜索了「{query}」" if query else "进行了搜索"
|
return f'searched for "{query}"' if query else "performed a search"
|
||||||
|
|
||||||
def _describe_search_user(self) -> str:
|
def _describe_search_user(self) -> str:
|
||||||
"""搜索用户 - 包含搜索关键词"""
|
"""Search users — includes search keyword"""
|
||||||
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
||||||
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
return f'searched for user "{query}"' if query else "searched for a user"
|
||||||
|
|
||||||
def _describe_mute(self) -> str:
|
def _describe_mute(self) -> str:
|
||||||
"""屏蔽用户 - 包含被屏蔽用户的名称"""
|
"""Mute a user — includes the muted user's name"""
|
||||||
target_user_name = self.action_args.get("target_user_name", "")
|
target_user_name = self.action_args.get("target_user_name", "")
|
||||||
|
|
||||||
if target_user_name:
|
if target_user_name:
|
||||||
return f"屏蔽了用户「{target_user_name}」"
|
return f'muted user "{target_user_name}"'
|
||||||
return "屏蔽了一个用户"
|
return "muted a user"
|
||||||
|
|
||||||
def _describe_generic(self) -> str:
|
def _describe_generic(self) -> str:
|
||||||
# 对于未知的动作类型,生成通用描述
|
# Generic description for unknown action types
|
||||||
return f"执行了{self.action_type}操作"
|
return f"performed action: {self.action_type}"
|
||||||
|
|
||||||
|
|
||||||
class ZepGraphMemoryUpdater:
|
class ZepGraphMemoryUpdater:
|
||||||
"""
|
"""
|
||||||
Zep图谱记忆更新器
|
Zep graph memory updater
|
||||||
|
|
||||||
监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。
|
Monitors the simulation's actions log file and updates new agent activities
|
||||||
按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。
|
to the Zep graph in real time. Activities are grouped by platform; each platform
|
||||||
|
batches up to BATCH_SIZE activities before sending them to Zep.
|
||||||
|
|
||||||
所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息:
|
All meaningful actions are updated to Zep. action_args contains full context:
|
||||||
- 点赞/踩的帖子原文
|
- Original post content for likes/dislikes
|
||||||
- 转发/引用的帖子原文
|
- Original post content for reposts/quotes
|
||||||
- 关注/屏蔽的用户名
|
- Usernames for follows/mutes
|
||||||
- 点赞/踩的评论原文
|
- Original comment content for comment likes/dislikes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 批量发送大小(每个平台累积多少条后发送)
|
# Batch send size (activities per platform before sending)
|
||||||
BATCH_SIZE = 5
|
BATCH_SIZE = 5
|
||||||
|
|
||||||
# 平台名称映射(用于控制台显示)
|
# Platform display names
|
||||||
PLATFORM_DISPLAY_NAMES = {
|
PLATFORM_DISPLAY_NAMES = {
|
||||||
'twitter': '世界1',
|
'twitter': 'World 1',
|
||||||
'reddit': '世界2',
|
'reddit': 'World 2',
|
||||||
}
|
}
|
||||||
|
|
||||||
# 发送间隔(秒),避免请求过快
|
# Send interval (seconds) to avoid sending too fast
|
||||||
SEND_INTERVAL = 0.5
|
SEND_INTERVAL = 0.5
|
||||||
|
|
||||||
# 重试配置
|
# Retry config
|
||||||
MAX_RETRIES = 3
|
MAX_RETRIES = 3
|
||||||
RETRY_DELAY = 2 # 秒
|
RETRY_DELAY = 2 # seconds
|
||||||
|
|
||||||
def __init__(self, graph_id: str, api_key: Optional[str] = None):
    """Set up the Zep client, the activity queue, per-platform buffers and counters.

    Args:
        graph_id: ID of the Zep graph this updater writes to.
        api_key: Zep API key; when omitted or empty, Config.ZEP_API_KEY is used.

    Raises:
        ValueError: if neither the argument nor the config supplies an API key.
    """
    self.graph_id = graph_id
    resolved_key = api_key or Config.ZEP_API_KEY
    self.api_key = resolved_key
    if not resolved_key:
        raise ValueError("ZEP_API_KEY is not configured")

    self.client = Zep(api_key=resolved_key)

    # Incoming activities wait here until they are moved into a platform buffer.
    self._activity_queue: Queue = Queue()

    # One buffer per platform; each accumulates up to BATCH_SIZE before a batch send.
    self._buffer_lock = threading.Lock()
    self._platform_buffers: Dict[str, List[AgentActivity]] = {
        name: [] for name in ('twitter', 'reddit')
    }

    # Worker state
    self._worker_thread: Optional[threading.Thread] = None
    self._running = False

    # Counters
    self._total_activities = 0   # activities added to the queue
    self._total_sent = 0         # batches successfully sent to Zep
    self._total_items_sent = 0   # individual activities successfully sent to Zep
    self._failed_count = 0       # batches that failed to send
    self._skipped_count = 0      # activities filtered out (DO_NOTHING)

    logger.info(f"ZepGraphMemoryUpdater initialized: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||||
|
|
||||||
def _get_platform_display_name(self, platform: str) -> str:
|
def _get_platform_display_name(self, platform: str) -> str:
|
||||||
"""获取平台的显示名称"""
|
"""Get the display name for a platform"""
|
||||||
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""启动后台工作线程"""
|
"""Start the background worker thread"""
|
||||||
if self._running:
|
if self._running:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -288,19 +290,19 @@ class ZepGraphMemoryUpdater:
|
||||||
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
|
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
|
||||||
)
|
)
|
||||||
self._worker_thread.start()
|
self._worker_thread.start()
|
||||||
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
logger.info(f"ZepGraphMemoryUpdater started: graph_id={self.graph_id}")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""停止后台工作线程"""
|
"""Stop the background worker thread"""
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
# 发送剩余的活动
|
# Send remaining activities
|
||||||
self._flush_remaining()
|
self._flush_remaining()
|
||||||
|
|
||||||
if self._worker_thread and self._worker_thread.is_alive():
|
if self._worker_thread and self._worker_thread.is_alive():
|
||||||
self._worker_thread.join(timeout=10)
|
self._worker_thread.join(timeout=10)
|
||||||
|
|
||||||
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
logger.info(f"ZepGraphMemoryUpdater stopped: graph_id={self.graph_id}, "
|
||||||
f"total_activities={self._total_activities}, "
|
f"total_activities={self._total_activities}, "
|
||||||
f"batches_sent={self._total_sent}, "
|
f"batches_sent={self._total_sent}, "
|
||||||
f"items_sent={self._total_items_sent}, "
|
f"items_sent={self._total_items_sent}, "
|
||||||
|
|
@ -309,43 +311,43 @@ class ZepGraphMemoryUpdater:
|
||||||
|
|
||||||
def add_activity(self, activity: AgentActivity):
    """Queue one agent activity for delivery to Zep.

    Every action type is queued except "DO_NOTHING", which is only counted in
    the skipped counter. Queued activities bump the total-activities counter.

    Args:
        activity: the agent activity record to enqueue.
    """
    if activity.action_type != "DO_NOTHING":
        self._activity_queue.put(activity)
        self._total_activities += 1
        logger.debug(f"Added activity to Zep queue: {activity.agent_name} - {activity.action_type}")
        return

    # DO_NOTHING carries no information worth sending; just count it.
    self._skipped_count += 1
|
||||||
|
|
||||||
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
||||||
"""
|
"""
|
||||||
从字典数据添加活动
|
Add an activity from a dictionary
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 从actions.jsonl解析的字典数据
|
data: dict parsed from actions.jsonl
|
||||||
platform: 平台名称 (twitter/reddit)
|
platform: platform name (twitter/reddit)
|
||||||
"""
|
"""
|
||||||
# 跳过事件类型的条目
|
# Skip event-type entries
|
||||||
if "event_type" in data:
|
if "event_type" in data:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
@ -362,53 +364,53 @@ class ZepGraphMemoryUpdater:
|
||||||
self.add_activity(activity)
|
self.add_activity(activity)
|
||||||
|
|
||||||
def _worker_loop(self, locale: str = 'zh'):
    """Background worker loop: batch-sends queued activities to Zep per platform.

    Passes ``locale`` to set_locale, then keeps running while the updater is
    flagged as running or the queue still holds items. Each activity is moved
    from the queue into its platform's buffer; once a buffer holds BATCH_SIZE
    items, that batch is sent and the loop sleeps SEND_INTERVAL seconds.

    The batch is cut from the buffer while holding the buffer lock, but the
    send and the throttle sleep happen after the lock is released, so other
    users of the lock are not blocked by network I/O or retry delays.

    Args:
        locale: locale code handed to set_locale for this thread.
    """
    set_locale(locale)
    while self._running or not self._activity_queue.empty():
        try:
            # Wait up to one second so the loop condition is re-checked regularly.
            try:
                activity = self._activity_queue.get(timeout=1)
            except Empty:
                continue

            platform = activity.platform.lower()
            batch = None
            with self._buffer_lock:
                buffer = self._platform_buffers.setdefault(platform, [])
                buffer.append(activity)

                # Cut a full batch out of the buffer; sending happens outside the lock.
                if len(buffer) >= self.BATCH_SIZE:
                    batch = buffer[:self.BATCH_SIZE]
                    self._platform_buffers[platform] = buffer[self.BATCH_SIZE:]

            if batch:
                self._send_batch_activities(batch, platform)
                # Throttle to avoid sending too fast
                time.sleep(self.SEND_INTERVAL)

        except Exception as e:
            logger.error(f"Worker loop exception: {e}")
            time.sleep(1)
|
||||||
|
|
||||||
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||||
"""
|
"""
|
||||||
批量发送活动到Zep图谱(合并为一条文本)
|
Batch-send activities to the Zep graph (merged into a single text block)
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
activities: Agent活动列表
|
activities: list of agent activities
|
||||||
platform: 平台名称
|
platform: platform name
|
||||||
"""
|
"""
|
||||||
if not activities:
|
if not activities:
|
||||||
return
|
return
|
||||||
|
|
||||||
# 将多条活动合并为一条文本,用换行分隔
|
# Merge multiple activities into a single text, separated by newlines
|
||||||
episode_texts = [activity.to_episode_text() for activity in activities]
|
episode_texts = [activity.to_episode_text() for activity in activities]
|
||||||
combined_text = "\n".join(episode_texts)
|
combined_text = "\n".join(episode_texts)
|
||||||
|
|
||||||
# 带重试的发送
|
# Send with retry
|
||||||
for attempt in range(self.MAX_RETRIES):
|
for attempt in range(self.MAX_RETRIES):
|
||||||
try:
|
try:
|
||||||
self.client.graph.add(
|
self.client.graph.add(
|
||||||
|
|
@ -420,21 +422,21 @@ class ZepGraphMemoryUpdater:
|
||||||
self._total_sent += 1
|
self._total_sent += 1
|
||||||
self._total_items_sent += len(activities)
|
self._total_items_sent += len(activities)
|
||||||
display_name = self._get_platform_display_name(platform)
|
display_name = self._get_platform_display_name(platform)
|
||||||
logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
logger.info(f"Successfully sent batch of {len(activities)} {display_name} activities to graph {self.graph_id}")
|
||||||
logger.debug(f"批量内容预览: {combined_text[:200]}...")
|
logger.debug(f"Batch content preview: {combined_text[:200]}...")
|
||||||
return
|
return
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if attempt < self.MAX_RETRIES - 1:
|
if attempt < self.MAX_RETRIES - 1:
|
||||||
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
logger.warning(f"Batch send to Zep failed (attempt {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||||
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
||||||
else:
|
else:
|
||||||
logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
logger.error(f"Batch send to Zep failed after {self.MAX_RETRIES} attempts: {e}")
|
||||||
self._failed_count += 1
|
self._failed_count += 1
|
||||||
|
|
||||||
def _flush_remaining(self):
|
def _flush_remaining(self):
|
||||||
"""发送队列和缓冲区中剩余的活动"""
|
"""Send remaining activities from the queue and buffers"""
|
||||||
# 首先处理队列中剩余的活动,添加到缓冲区
|
# First, drain the queue into the buffers
|
||||||
while not self._activity_queue.empty():
|
while not self._activity_queue.empty():
|
||||||
try:
|
try:
|
||||||
activity = self._activity_queue.get_nowait()
|
activity = self._activity_queue.get_nowait()
|
||||||
|
|
@ -446,41 +448,41 @@ class ZepGraphMemoryUpdater:
|
||||||
except Empty:
|
except Empty:
|
||||||
break
|
break
|
||||||
|
|
||||||
# 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条)
|
# Then send remaining activities in each platform buffer (even if below BATCH_SIZE)
|
||||||
with self._buffer_lock:
|
with self._buffer_lock:
|
||||||
for platform, buffer in self._platform_buffers.items():
|
for platform, buffer in self._platform_buffers.items():
|
||||||
if buffer:
|
if buffer:
|
||||||
display_name = self._get_platform_display_name(platform)
|
display_name = self._get_platform_display_name(platform)
|
||||||
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
|
logger.info(f"Sending {len(buffer)} remaining {display_name} platform activities")
|
||||||
self._send_batch_activities(buffer, platform)
|
self._send_batch_activities(buffer, platform)
|
||||||
# 清空所有缓冲区
|
# Clear all buffers
|
||||||
for platform in self._platform_buffers:
|
for platform in self._platform_buffers:
|
||||||
self._platform_buffers[platform] = []
|
self._platform_buffers[platform] = []
|
||||||
|
|
||||||
def get_stats(self) -> Dict[str, Any]:
    """Return a snapshot of the updater's counters and current backlog.

    Buffer lengths are read while holding the buffer lock; the remaining
    values are read without it.
    """
    with self._buffer_lock:
        pending = {name: len(items) for name, items in self._platform_buffers.items()}

    stats: Dict[str, Any] = {}
    stats["graph_id"] = self.graph_id
    stats["batch_size"] = self.BATCH_SIZE
    stats["total_activities"] = self._total_activities  # activities added to the queue
    stats["batches_sent"] = self._total_sent             # batches successfully sent
    stats["items_sent"] = self._total_items_sent         # individual activities successfully sent
    stats["failed_count"] = self._failed_count           # batches that failed to send
    stats["skipped_count"] = self._skipped_count         # activities filtered out (DO_NOTHING)
    stats["queue_size"] = self._activity_queue.qsize()
    stats["buffer_sizes"] = pending                      # per-platform buffer sizes
    stats["running"] = self._running
    return stats
|
||||||
|
|
||||||
|
|
||||||
class ZepGraphMemoryManager:
|
class ZepGraphMemoryManager:
|
||||||
"""
|
"""
|
||||||
管理多个模拟的Zep图谱记忆更新器
|
Manages Zep graph memory updaters for multiple simulations
|
||||||
|
|
||||||
每个模拟可以有自己的更新器实例
|
Each simulation can have its own updater instance
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
|
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
|
||||||
|
|
@ -489,17 +491,17 @@ class ZepGraphMemoryManager:
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
|
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
|
||||||
"""
|
"""
|
||||||
为模拟创建图谱记忆更新器
|
Create a graph memory updater for a simulation
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_id: 模拟ID
|
simulation_id: simulation ID
|
||||||
graph_id: Zep图谱ID
|
graph_id: Zep graph ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ZepGraphMemoryUpdater实例
|
ZepGraphMemoryUpdater instance
|
||||||
"""
|
"""
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
# 如果已存在,先停止旧的
|
# If one already exists, stop it first
|
||||||
if simulation_id in cls._updaters:
|
if simulation_id in cls._updaters:
|
||||||
cls._updaters[simulation_id].stop()
|
cls._updaters[simulation_id].stop()
|
||||||
|
|
||||||
|
|
@ -507,30 +509,30 @@ class ZepGraphMemoryManager:
|
||||||
updater.start()
|
updater.start()
|
||||||
cls._updaters[simulation_id] = updater
|
cls._updaters[simulation_id] = updater
|
||||||
|
|
||||||
logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
|
logger.info(f"Created graph memory updater: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||||
return updater
|
return updater
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
|
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
|
||||||
"""获取模拟的更新器"""
|
"""Get the updater for a simulation"""
|
||||||
return cls._updaters.get(simulation_id)
|
return cls._updaters.get(simulation_id)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def stop_updater(cls, simulation_id: str):
|
def stop_updater(cls, simulation_id: str):
|
||||||
"""停止并移除模拟的更新器"""
|
"""Stop and remove the updater for a simulation"""
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if simulation_id in cls._updaters:
|
if simulation_id in cls._updaters:
|
||||||
cls._updaters[simulation_id].stop()
|
cls._updaters[simulation_id].stop()
|
||||||
del cls._updaters[simulation_id]
|
del cls._updaters[simulation_id]
|
||||||
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
|
logger.info(f"Stopped graph memory updater: simulation_id={simulation_id}")
|
||||||
|
|
||||||
# 防止 stop_all 重复调用的标志
|
# Flag to prevent stop_all from being called more than once
|
||||||
_stop_all_done = False
|
_stop_all_done = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def stop_all(cls):
|
def stop_all(cls):
|
||||||
"""停止所有更新器"""
|
"""Stop all updaters"""
|
||||||
# 防止重复调用
|
# Prevent duplicate calls
|
||||||
if cls._stop_all_done:
|
if cls._stop_all_done:
|
||||||
return
|
return
|
||||||
cls._stop_all_done = True
|
cls._stop_all_done = True
|
||||||
|
|
@ -541,13 +543,13 @@ class ZepGraphMemoryManager:
|
||||||
try:
|
try:
|
||||||
updater.stop()
|
updater.stop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
|
logger.error(f"Failed to stop updater: simulation_id={simulation_id}, error={e}")
|
||||||
cls._updaters.clear()
|
cls._updaters.clear()
|
||||||
logger.info("已停止所有图谱记忆更新器")
|
logger.info("All graph memory updaters stopped")
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
||||||
"""获取所有更新器的统计信息"""
|
"""Get statistics for all updaters"""
|
||||||
return {
|
return {
|
||||||
sim_id: updater.get_stats()
|
sim_id: updater.get_stats()
|
||||||
for sim_id, updater in cls._updaters.items()
|
for sim_id, updater in cls._updaters.items()
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,5 @@
|
||||||
"""
|
"""
|
||||||
工具模块
|
Utilities module
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .file_parser import FileParser
|
from .file_parser import FileParser
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
文件解析工具
|
File parsing utilities
|
||||||
支持PDF、Markdown、TXT文件的文本提取
|
Supports text extraction from PDF, Markdown, and TXT files
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -10,29 +10,29 @@ from typing import List, Optional
|
||||||
|
|
||||||
def _read_text_with_fallback(file_path: str) -> str:
|
def _read_text_with_fallback(file_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
读取文本文件,UTF-8失败时自动探测编码。
|
Read a text file, automatically detecting encoding if UTF-8 fails.
|
||||||
|
|
||||||
采用多级回退策略:
|
Uses a multi-level fallback strategy:
|
||||||
1. 首先尝试 UTF-8 解码
|
1. First attempts UTF-8 decoding
|
||||||
2. 使用 charset_normalizer 检测编码
|
2. Uses charset_normalizer to detect encoding
|
||||||
3. 回退到 chardet 检测编码
|
3. Falls back to chardet for encoding detection
|
||||||
4. 最终使用 UTF-8 + errors='replace' 兜底
|
4. Final fallback: UTF-8 with errors='replace'
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: 文件路径
|
file_path: Path to the file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
解码后的文本内容
|
Decoded text content
|
||||||
"""
|
"""
|
||||||
data = Path(file_path).read_bytes()
|
data = Path(file_path).read_bytes()
|
||||||
|
|
||||||
# 首先尝试 UTF-8
|
# First attempt: UTF-8
|
||||||
try:
|
try:
|
||||||
return data.decode('utf-8')
|
return data.decode('utf-8')
|
||||||
except UnicodeDecodeError:
|
except UnicodeDecodeError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 尝试使用 charset_normalizer 检测编码
|
# Attempt encoding detection with charset_normalizer
|
||||||
encoding = None
|
encoding = None
|
||||||
try:
|
try:
|
||||||
from charset_normalizer import from_bytes
|
from charset_normalizer import from_bytes
|
||||||
|
|
@ -42,7 +42,7 @@ def _read_text_with_fallback(file_path: str) -> str:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 回退到 chardet
|
# Fall back to chardet
|
||||||
if not encoding:
|
if not encoding:
|
||||||
try:
|
try:
|
||||||
import chardet
|
import chardet
|
||||||
|
|
@ -51,7 +51,7 @@ def _read_text_with_fallback(file_path: str) -> str:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 最终兜底:使用 UTF-8 + replace
|
# Final fallback: UTF-8 with replace
|
||||||
if not encoding:
|
if not encoding:
|
||||||
encoding = 'utf-8'
|
encoding = 'utf-8'
|
||||||
|
|
||||||
|
|
@ -59,30 +59,30 @@ def _read_text_with_fallback(file_path: str) -> str:
|
||||||
|
|
||||||
|
|
||||||
class FileParser:
|
class FileParser:
|
||||||
"""文件解析器"""
|
"""File parser"""
|
||||||
|
|
||||||
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
|
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_text(cls, file_path: str) -> str:
|
def extract_text(cls, file_path: str) -> str:
|
||||||
"""
|
"""
|
||||||
从文件中提取文本
|
Extract text from a file
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_path: 文件路径
|
file_path: Path to the file
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
提取的文本内容
|
Extracted text content
|
||||||
"""
|
"""
|
||||||
path = Path(file_path)
|
path = Path(file_path)
|
||||||
|
|
||||||
if not path.exists():
|
if not path.exists():
|
||||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
raise FileNotFoundError(f"File not found: {file_path}")
|
||||||
|
|
||||||
suffix = path.suffix.lower()
|
suffix = path.suffix.lower()
|
||||||
|
|
||||||
if suffix not in cls.SUPPORTED_EXTENSIONS:
|
if suffix not in cls.SUPPORTED_EXTENSIONS:
|
||||||
raise ValueError(f"不支持的文件格式: {suffix}")
|
raise ValueError(f"Unsupported file format: {suffix}")
|
||||||
|
|
||||||
if suffix == '.pdf':
|
if suffix == '.pdf':
|
||||||
return cls._extract_from_pdf(file_path)
|
return cls._extract_from_pdf(file_path)
|
||||||
|
|
@ -91,15 +91,15 @@ class FileParser:
|
||||||
elif suffix == '.txt':
|
elif suffix == '.txt':
|
||||||
return cls._extract_from_txt(file_path)
|
return cls._extract_from_txt(file_path)
|
||||||
|
|
||||||
raise ValueError(f"无法处理的文件格式: {suffix}")
|
raise ValueError(f"Cannot process file format: {suffix}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_from_pdf(file_path: str) -> str:
|
def _extract_from_pdf(file_path: str) -> str:
|
||||||
"""从PDF提取文本"""
|
"""Extract text from a PDF file"""
|
||||||
try:
|
try:
|
||||||
import fitz # PyMuPDF
|
import fitz # PyMuPDF
|
||||||
except ImportError:
|
except ImportError:
|
||||||
raise ImportError("需要安装PyMuPDF: pip install PyMuPDF")
|
raise ImportError("PyMuPDF is required: pip install PyMuPDF")
|
||||||
|
|
||||||
text_parts = []
|
text_parts = []
|
||||||
with fitz.open(file_path) as doc:
|
with fitz.open(file_path) as doc:
|
||||||
|
|
@ -112,24 +112,24 @@ class FileParser:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_from_md(file_path: str) -> str:
|
def _extract_from_md(file_path: str) -> str:
|
||||||
"""从Markdown提取文本,支持自动编码检测"""
|
"""Extract text from a Markdown file with automatic encoding detection"""
|
||||||
return _read_text_with_fallback(file_path)
|
return _read_text_with_fallback(file_path)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _extract_from_txt(file_path: str) -> str:
|
def _extract_from_txt(file_path: str) -> str:
|
||||||
"""从TXT提取文本,支持自动编码检测"""
|
"""Extract text from a TXT file with automatic encoding detection"""
|
||||||
return _read_text_with_fallback(file_path)
|
return _read_text_with_fallback(file_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def extract_from_multiple(cls, file_paths: List[str]) -> str:
|
def extract_from_multiple(cls, file_paths: List[str]) -> str:
|
||||||
"""
|
"""
|
||||||
从多个文件提取文本并合并
|
Extract text from multiple files and merge the results
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
file_paths: 文件路径列表
|
file_paths: List of file paths
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
合并后的文本
|
Merged text content
|
||||||
"""
|
"""
|
||||||
all_texts = []
|
all_texts = []
|
||||||
|
|
||||||
|
|
@ -137,9 +137,9 @@ class FileParser:
|
||||||
try:
|
try:
|
||||||
text = cls.extract_text(file_path)
|
text = cls.extract_text(file_path)
|
||||||
filename = Path(file_path).name
|
filename = Path(file_path).name
|
||||||
all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}")
|
all_texts.append(f"=== Document {i}: {filename} ===\n{text}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===")
|
all_texts.append(f"=== Document {i}: {file_path} (extraction failed: {str(e)}) ===")
|
||||||
|
|
||||||
return "\n\n".join(all_texts)
|
return "\n\n".join(all_texts)
|
||||||
|
|
||||||
|
|
@ -150,15 +150,15 @@ def split_text_into_chunks(
|
||||||
overlap: int = 50
|
overlap: int = 50
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
"""
|
"""
|
||||||
将文本分割成小块
|
Split text into smaller chunks
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: 原始文本
|
text: Source text
|
||||||
chunk_size: 每块的字符数
|
chunk_size: Number of characters per chunk
|
||||||
overlap: 重叠字符数
|
overlap: Number of overlapping characters between chunks
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
文本块列表
|
List of text chunks
|
||||||
"""
|
"""
|
||||||
if len(text) <= chunk_size:
|
if len(text) <= chunk_size:
|
||||||
return [text] if text.strip() else []
|
return [text] if text.strip() else []
|
||||||
|
|
@ -169,9 +169,9 @@ def split_text_into_chunks(
|
||||||
while start < len(text):
|
while start < len(text):
|
||||||
end = start + chunk_size
|
end = start + chunk_size
|
||||||
|
|
||||||
# 尝试在句子边界处分割
|
# Try to split at sentence boundaries
|
||||||
if end < len(text):
|
if end < len(text):
|
||||||
# 查找最近的句子结束符
|
# Find the nearest sentence-ending separator
|
||||||
for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']:
|
for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']:
|
||||||
last_sep = text[start:end].rfind(sep)
|
last_sep = text[start:end].rfind(sep)
|
||||||
if last_sep != -1 and last_sep > chunk_size * 0.3:
|
if last_sep != -1 and last_sep > chunk_size * 0.3:
|
||||||
|
|
@ -182,8 +182,7 @@ def split_text_into_chunks(
|
||||||
if chunk:
|
if chunk:
|
||||||
chunks.append(chunk)
|
chunks.append(chunk)
|
||||||
|
|
||||||
# 下一个块从重叠位置开始
|
# Next chunk starts at the overlap position
|
||||||
start = end - overlap if end < len(text) else len(text)
|
start = end - overlap if end < len(text) else len(text)
|
||||||
|
|
||||||
return chunks
|
return chunks
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,19 @@
|
||||||
"""
|
"""
|
||||||
LLM客户端封装
|
LLM client wrapper
|
||||||
统一使用OpenAI格式调用
|
Unified interface using the OpenAI-compatible API format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
from typing import Optional, Dict, Any, List
|
from typing import Optional, Dict, Any, List
|
||||||
|
from urllib.parse import urlparse, parse_qs, urlunparse
|
||||||
from openai import OpenAI
|
from openai import OpenAI
|
||||||
|
|
||||||
from ..config import Config
|
from ..config import Config
|
||||||
|
|
||||||
|
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""LLM客户端"""
|
"""LLM client"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -21,15 +22,30 @@ class LLMClient:
|
||||||
model: Optional[str] = None
|
model: Optional[str] = None
|
||||||
):
|
):
|
||||||
self.api_key = api_key or Config.LLM_API_KEY
|
self.api_key = api_key or Config.LLM_API_KEY
|
||||||
self.base_url = base_url or Config.LLM_BASE_URL
|
raw_url = base_url or Config.LLM_BASE_URL
|
||||||
self.model = model or Config.LLM_MODEL_NAME
|
self.model = model or Config.LLM_MODEL_NAME
|
||||||
|
|
||||||
if not self.api_key:
|
if not self.api_key:
|
||||||
raise ValueError("LLM_API_KEY 未配置")
|
raise ValueError("LLM_API_KEY is not configured")
|
||||||
|
|
||||||
|
# Azure Portal provides full endpoint URLs like:
|
||||||
|
# https://<resource>.cognitiveservices.azure.com/openai/deployments/<model>/chat/completions?api-version=...
|
||||||
|
# The OpenAI SDK expects a base_url and appends /chat/completions itself,
|
||||||
|
# so we strip that suffix and extract api-version as a default query param.
|
||||||
|
default_query: Dict[str, str] = {}
|
||||||
|
if raw_url and '/chat/completions' in raw_url:
|
||||||
|
parsed = urlparse(raw_url)
|
||||||
|
qs = parse_qs(parsed.query)
|
||||||
|
if 'api-version' in qs:
|
||||||
|
default_query['api-version'] = qs['api-version'][0]
|
||||||
|
clean_path = parsed.path.replace('/chat/completions', '').rstrip('/')
|
||||||
|
raw_url = urlunparse(parsed._replace(path=clean_path, query=''))
|
||||||
|
|
||||||
|
self.base_url = raw_url
|
||||||
self.client = OpenAI(
|
self.client = OpenAI(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
base_url=self.base_url
|
base_url=self.base_url,
|
||||||
|
default_query=default_query if default_query else None
|
||||||
)
|
)
|
||||||
|
|
||||||
def chat(
|
def chat(
|
||||||
|
|
@ -40,22 +56,22 @@ class LLMClient:
|
||||||
response_format: Optional[Dict] = None
|
response_format: Optional[Dict] = None
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
发送聊天请求
|
Send a chat request
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: List of messages
|
||||||
temperature: 温度参数
|
temperature: Temperature parameter
|
||||||
max_tokens: 最大token数
|
max_tokens: Maximum number of tokens
|
||||||
response_format: 响应格式(如JSON模式)
|
response_format: Response format (e.g. JSON mode)
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
模型响应文本
|
Model response text
|
||||||
"""
|
"""
|
||||||
kwargs = {
|
kwargs = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"temperature": temperature,
|
"temperature": temperature,
|
||||||
"max_tokens": max_tokens,
|
"max_completion_tokens": max_tokens,
|
||||||
}
|
}
|
||||||
|
|
||||||
if response_format:
|
if response_format:
|
||||||
|
|
@ -63,7 +79,7 @@ class LLMClient:
|
||||||
|
|
||||||
response = self.client.chat.completions.create(**kwargs)
|
response = self.client.chat.completions.create(**kwargs)
|
||||||
content = response.choices[0].message.content
|
content = response.choices[0].message.content
|
||||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
# Some models (e.g. MiniMax M2.5) include <think> reasoning content in the response; strip it out
|
||||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
@ -74,15 +90,15 @@ class LLMClient:
|
||||||
max_tokens: int = 4096
|
max_tokens: int = 4096
|
||||||
) -> Dict[str, Any]:
|
) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
发送聊天请求并返回JSON
|
Send a chat request and return parsed JSON
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
messages: 消息列表
|
messages: List of messages
|
||||||
temperature: 温度参数
|
temperature: Temperature parameter
|
||||||
max_tokens: 最大token数
|
max_tokens: Maximum number of tokens
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
解析后的JSON对象
|
Parsed JSON object
|
||||||
"""
|
"""
|
||||||
response = self.chat(
|
response = self.chat(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
@ -90,7 +106,7 @@ class LLMClient:
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
response_format={"type": "json_object"}
|
response_format={"type": "json_object"}
|
||||||
)
|
)
|
||||||
# 清理markdown代码块标记
|
# Strip markdown code-block markers if present
|
||||||
cleaned_response = response.strip()
|
cleaned_response = response.strip()
|
||||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
||||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
||||||
|
|
@ -99,5 +115,4 @@ class LLMClient:
|
||||||
try:
|
try:
|
||||||
return json.loads(cleaned_response)
|
return json.loads(cleaned_response)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
raise ValueError(f"Invalid JSON returned by LLM: {cleaned_response}")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -66,4 +66,4 @@ def t(key: str, **kwargs) -> str:
|
||||||
def get_language_instruction() -> str:
|
def get_language_instruction() -> str:
|
||||||
locale = get_locale()
|
locale = get_locale()
|
||||||
lang_config = _languages.get(locale, _languages.get('zh', {}))
|
lang_config = _languages.get(locale, _languages.get('zh', {}))
|
||||||
return lang_config.get('llmInstruction', '请使用中文回答。')
|
return lang_config.get('llmInstruction', 'Please respond in Chinese.')
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
日志配置模块
|
Logging configuration module
|
||||||
提供统一的日志管理,同时输出到控制台和文件
|
Provides unified log management, writing to both console and file
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -12,47 +12,47 @@ from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
def _ensure_utf8_stdout():
|
def _ensure_utf8_stdout():
|
||||||
"""
|
"""
|
||||||
确保 stdout/stderr 使用 UTF-8 编码
|
Ensure stdout/stderr use UTF-8 encoding.
|
||||||
解决 Windows 控制台中文乱码问题
|
Fixes garbled output in Windows consoles.
|
||||||
"""
|
"""
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
# Windows 下重新配置标准输出为 UTF-8
|
# Reconfigure standard streams to UTF-8 on Windows
|
||||||
if hasattr(sys.stdout, 'reconfigure'):
|
if hasattr(sys.stdout, 'reconfigure'):
|
||||||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||||||
if hasattr(sys.stderr, 'reconfigure'):
|
if hasattr(sys.stderr, 'reconfigure'):
|
||||||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||||||
|
|
||||||
|
|
||||||
# 日志目录
|
# Log directory
|
||||||
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
|
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
|
||||||
|
|
||||||
|
|
||||||
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
|
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
设置日志器
|
Set up a logger
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 日志器名称
|
name: Logger name
|
||||||
level: 日志级别
|
level: Log level
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
配置好的日志器
|
Configured logger instance
|
||||||
"""
|
"""
|
||||||
# 确保日志目录存在
|
# Ensure the log directory exists
|
||||||
os.makedirs(LOG_DIR, exist_ok=True)
|
os.makedirs(LOG_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 创建日志器
|
# Create logger
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
logger.setLevel(level)
|
logger.setLevel(level)
|
||||||
|
|
||||||
# 阻止日志向上传播到根 logger,避免重复输出
|
# Prevent log records from propagating to the root logger to avoid duplicate output
|
||||||
logger.propagate = False
|
logger.propagate = False
|
||||||
|
|
||||||
# 如果已经有处理器,不重复添加
|
# Skip adding handlers if they already exist
|
||||||
if logger.handlers:
|
if logger.handlers:
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
# 日志格式
|
# Log formatters
|
||||||
detailed_formatter = logging.Formatter(
|
detailed_formatter = logging.Formatter(
|
||||||
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
|
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
|
||||||
datefmt='%Y-%m-%d %H:%M:%S'
|
datefmt='%Y-%m-%d %H:%M:%S'
|
||||||
|
|
@ -63,7 +63,7 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
|
||||||
datefmt='%H:%M:%S'
|
datefmt='%H:%M:%S'
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. 文件处理器 - 详细日志(按日期命名,带轮转)
|
# 1. File handler — detailed logs (date-stamped filename with rotation)
|
||||||
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
|
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
|
||||||
file_handler = RotatingFileHandler(
|
file_handler = RotatingFileHandler(
|
||||||
os.path.join(LOG_DIR, log_filename),
|
os.path.join(LOG_DIR, log_filename),
|
||||||
|
|
@ -74,14 +74,14 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
|
||||||
file_handler.setLevel(logging.DEBUG)
|
file_handler.setLevel(logging.DEBUG)
|
||||||
file_handler.setFormatter(detailed_formatter)
|
file_handler.setFormatter(detailed_formatter)
|
||||||
|
|
||||||
# 2. 控制台处理器 - 简洁日志(INFO及以上)
|
# 2. Console handler — concise logs (INFO and above)
|
||||||
# 确保 Windows 下使用 UTF-8 编码,避免中文乱码
|
# Ensure UTF-8 encoding on Windows to avoid garbled output
|
||||||
_ensure_utf8_stdout()
|
_ensure_utf8_stdout()
|
||||||
console_handler = logging.StreamHandler(sys.stdout)
|
console_handler = logging.StreamHandler(sys.stdout)
|
||||||
console_handler.setLevel(logging.INFO)
|
console_handler.setLevel(logging.INFO)
|
||||||
console_handler.setFormatter(simple_formatter)
|
console_handler.setFormatter(simple_formatter)
|
||||||
|
|
||||||
# 添加处理器
|
# Register handlers
|
||||||
logger.addHandler(file_handler)
|
logger.addHandler(file_handler)
|
||||||
logger.addHandler(console_handler)
|
logger.addHandler(console_handler)
|
||||||
|
|
||||||
|
|
@ -90,13 +90,13 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
|
||||||
|
|
||||||
def get_logger(name: str = 'mirofish') -> logging.Logger:
|
def get_logger(name: str = 'mirofish') -> logging.Logger:
|
||||||
"""
|
"""
|
||||||
获取日志器(如果不存在则创建)
|
Get a logger, creating it if it does not exist
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
name: 日志器名称
|
name: Logger name
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
日志器实例
|
Logger instance
|
||||||
"""
|
"""
|
||||||
logger = logging.getLogger(name)
|
logger = logging.getLogger(name)
|
||||||
if not logger.handlers:
|
if not logger.handlers:
|
||||||
|
|
@ -104,11 +104,11 @@ def get_logger(name: str = 'mirofish') -> logging.Logger:
|
||||||
return logger
|
return logger
|
||||||
|
|
||||||
|
|
||||||
# 创建默认日志器
|
# Create default logger
|
||||||
logger = setup_logger()
|
logger = setup_logger()
|
||||||
|
|
||||||
|
|
||||||
# 便捷方法
|
# Convenience functions
|
||||||
def debug(msg, *args, **kwargs):
|
def debug(msg, *args, **kwargs):
|
||||||
logger.debug(msg, *args, **kwargs)
|
logger.debug(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
@ -123,4 +123,3 @@ def error(msg, *args, **kwargs):
|
||||||
|
|
||||||
def critical(msg, *args, **kwargs):
|
def critical(msg, *args, **kwargs):
|
||||||
logger.critical(msg, *args, **kwargs)
|
logger.critical(msg, *args, **kwargs)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
"""
|
"""
|
||||||
API调用重试机制
|
API call retry mechanism
|
||||||
用于处理LLM等外部API调用的重试逻辑
|
Handles retry logic for external API calls such as LLM services
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
@ -22,16 +22,16 @@ def retry_with_backoff(
|
||||||
on_retry: Optional[Callable[[Exception, int], None]] = None
|
on_retry: Optional[Callable[[Exception, int], None]] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
带指数退避的重试装饰器
|
Retry decorator with exponential backoff
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
max_retries: 最大重试次数
|
max_retries: Maximum number of retries
|
||||||
initial_delay: 初始延迟(秒)
|
initial_delay: Initial delay in seconds
|
||||||
max_delay: 最大延迟(秒)
|
max_delay: Maximum delay in seconds
|
||||||
backoff_factor: 退避因子
|
backoff_factor: Backoff multiplier
|
||||||
jitter: 是否添加随机抖动
|
jitter: Whether to add random jitter
|
||||||
exceptions: 需要重试的异常类型
|
exceptions: Exception types that should trigger a retry
|
||||||
on_retry: 重试时的回调函数 (exception, retry_count)
|
on_retry: Callback invoked on each retry (exception, retry_count)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@retry_with_backoff(max_retries=3)
|
@retry_with_backoff(max_retries=3)
|
||||||
|
|
@ -52,17 +52,17 @@ def retry_with_backoff(
|
||||||
last_exception = e
|
last_exception = e
|
||||||
|
|
||||||
if attempt == max_retries:
|
if attempt == max_retries:
|
||||||
logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
|
logger.error(f"Function {func.__name__} failed after {max_retries} retries: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
# 计算延迟
|
# Calculate delay
|
||||||
current_delay = min(delay, max_delay)
|
current_delay = min(delay, max_delay)
|
||||||
if jitter:
|
if jitter:
|
||||||
current_delay = current_delay * (0.5 + random.random())
|
current_delay = current_delay * (0.5 + random.random())
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
|
f"Function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
|
||||||
f"{current_delay:.1f}秒后重试..."
|
f"retrying in {current_delay:.1f}s..."
|
||||||
)
|
)
|
||||||
|
|
||||||
if on_retry:
|
if on_retry:
|
||||||
|
|
@ -87,7 +87,7 @@ def retry_with_backoff_async(
|
||||||
on_retry: Optional[Callable[[Exception, int], None]] = None
|
on_retry: Optional[Callable[[Exception, int], None]] = None
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
异步版本的重试装饰器
|
Async version of the retry decorator
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
@ -105,7 +105,7 @@ def retry_with_backoff_async(
|
||||||
last_exception = e
|
last_exception = e
|
||||||
|
|
||||||
if attempt == max_retries:
|
if attempt == max_retries:
|
||||||
logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
|
logger.error(f"Async function {func.__name__} failed after {max_retries} retries: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
current_delay = min(delay, max_delay)
|
current_delay = min(delay, max_delay)
|
||||||
|
|
@ -113,8 +113,8 @@ def retry_with_backoff_async(
|
||||||
current_delay = current_delay * (0.5 + random.random())
|
current_delay = current_delay * (0.5 + random.random())
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
|
f"Async function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
|
||||||
f"{current_delay:.1f}秒后重试..."
|
f"retrying in {current_delay:.1f}s..."
|
||||||
)
|
)
|
||||||
|
|
||||||
if on_retry:
|
if on_retry:
|
||||||
|
|
@ -131,7 +131,7 @@ def retry_with_backoff_async(
|
||||||
|
|
||||||
class RetryableAPIClient:
|
class RetryableAPIClient:
|
||||||
"""
|
"""
|
||||||
可重试的API客户端封装
|
Retryable API client wrapper
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -154,16 +154,16 @@ class RetryableAPIClient:
|
||||||
**kwargs
|
**kwargs
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""
|
"""
|
||||||
执行函数调用并在失败时重试
|
Execute a function call and retry on failure
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 要调用的函数
|
func: Function to call
|
||||||
*args: 函数参数
|
*args: Positional arguments for the function
|
||||||
exceptions: 需要重试的异常类型
|
exceptions: Exception types that should trigger a retry
|
||||||
**kwargs: 函数关键字参数
|
**kwargs: Keyword arguments for the function
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
函数返回值
|
Return value of the function
|
||||||
"""
|
"""
|
||||||
last_exception = None
|
last_exception = None
|
||||||
delay = self.initial_delay
|
delay = self.initial_delay
|
||||||
|
|
@ -176,15 +176,15 @@ class RetryableAPIClient:
|
||||||
last_exception = e
|
last_exception = e
|
||||||
|
|
||||||
if attempt == self.max_retries:
|
if attempt == self.max_retries:
|
||||||
logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}")
|
logger.error(f"API call failed after {self.max_retries} retries: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
current_delay = min(delay, self.max_delay)
|
current_delay = min(delay, self.max_delay)
|
||||||
current_delay = current_delay * (0.5 + random.random())
|
current_delay = current_delay * (0.5 + random.random())
|
||||||
|
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, "
|
f"API call attempt {attempt + 1} failed: {str(e)}, "
|
||||||
f"{current_delay:.1f}秒后重试..."
|
f"retrying in {current_delay:.1f}s..."
|
||||||
)
|
)
|
||||||
|
|
||||||
time.sleep(current_delay)
|
time.sleep(current_delay)
|
||||||
|
|
@ -200,16 +200,16 @@ class RetryableAPIClient:
|
||||||
continue_on_failure: bool = True
|
continue_on_failure: bool = True
|
||||||
) -> Tuple[list, list]:
|
) -> Tuple[list, list]:
|
||||||
"""
|
"""
|
||||||
批量调用并对每个失败项单独重试
|
Process a batch of items, retrying individually on failure
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
items: 要处理的项目列表
|
items: List of items to process
|
||||||
process_func: 处理函数,接收单个item作为参数
|
process_func: Processing function that accepts a single item
|
||||||
exceptions: 需要重试的异常类型
|
exceptions: Exception types that should trigger a retry
|
||||||
continue_on_failure: 单项失败后是否继续处理其他项
|
continue_on_failure: Whether to continue processing remaining items after a failure
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
(成功结果列表, 失败项列表)
|
(list of successful results, list of failed items)
|
||||||
"""
|
"""
|
||||||
results = []
|
results = []
|
||||||
failures = []
|
failures = []
|
||||||
|
|
@ -224,7 +224,7 @@ class RetryableAPIClient:
|
||||||
results.append(result)
|
results.append(result)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理第 {idx + 1} 项失败: {str(e)}")
|
logger.error(f"Failed to process item {idx + 1}: {str(e)}")
|
||||||
failures.append({
|
failures.append({
|
||||||
"index": idx,
|
"index": idx,
|
||||||
"item": item,
|
"item": item,
|
||||||
|
|
@ -235,4 +235,3 @@ class RetryableAPIClient:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
return results, failures
|
return results, failures
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
"""Zep Graph 分页读取工具。
|
"""Zep Graph paginated fetch utilities.
|
||||||
|
|
||||||
Zep 的 node/edge 列表接口使用 UUID cursor 分页,
|
Zep's node/edge list endpoints use UUID-cursor-based pagination.
|
||||||
本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。
|
This module wraps the auto-pagination logic (with per-page retries) and
|
||||||
|
returns the full result list transparently to the caller.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
@ -31,7 +32,7 @@ def _fetch_page_with_retry(
|
||||||
page_description: str = "page",
|
page_description: str = "page",
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。"""
|
"""Fetch a single page with exponential-backoff retry on transient network/IO errors."""
|
||||||
if max_retries < 1:
|
if max_retries < 1:
|
||||||
raise ValueError("max_retries must be >= 1")
|
raise ValueError("max_retries must be >= 1")
|
||||||
|
|
||||||
|
|
@ -64,7 +65,7 @@ def fetch_all_nodes(
|
||||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
max_retries: int = _DEFAULT_MAX_RETRIES,
|
||||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
"""分页获取图谱节点,最多返回 max_items 条(默认 2000)。每页请求自带重试。"""
|
"""Fetch all graph nodes with pagination, returning at most max_items (default 2000). Each page request includes retries."""
|
||||||
all_nodes: list[Any] = []
|
all_nodes: list[Any] = []
|
||||||
cursor: str | None = None
|
cursor: str | None = None
|
||||||
page_num = 0
|
page_num = 0
|
||||||
|
|
@ -109,7 +110,7 @@ def fetch_all_edges(
|
||||||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
max_retries: int = _DEFAULT_MAX_RETRIES,
|
||||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
||||||
) -> list[Any]:
|
) -> list[Any]:
|
||||||
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。"""
|
"""Fetch all graph edges with pagination, returning the complete list. Each page request includes retries."""
|
||||||
all_edges: list[Any] = []
|
all_edges: list[Any] = []
|
||||||
cursor: str | None = None
|
cursor: str | None = None
|
||||||
page_num = 0
|
page_num = 0
|
||||||
|
|
|
||||||
|
|
@ -1435,7 +1435,6 @@
|
||||||
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz",
|
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz",
|
||||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||||
"license": "ISC",
|
"license": "ISC",
|
||||||
"peer": true,
|
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=12"
|
"node": ">=12"
|
||||||
}
|
}
|
||||||
|
|
@ -1913,7 +1912,6 @@
|
||||||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"peer": true,
|
|
||||||
"engines": {
|
"engines": {
|
||||||
"node": ">=12"
|
"node": ">=12"
|
||||||
},
|
},
|
||||||
|
|
@ -2053,7 +2051,6 @@
|
||||||
"integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==",
|
"integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==",
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"peer": true,
|
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"esbuild": "^0.25.0",
|
"esbuild": "^0.25.0",
|
||||||
"fdir": "^6.5.0",
|
"fdir": "^6.5.0",
|
||||||
|
|
@ -2128,7 +2125,6 @@
|
||||||
"resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz",
|
"resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz",
|
||||||
"integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==",
|
"integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==",
|
||||||
"license": "MIT",
|
"license": "MIT",
|
||||||
"peer": true,
|
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@vue/compiler-dom": "3.5.25",
|
"@vue/compiler-dom": "3.5.25",
|
||||||
"@vue/compiler-sfc": "3.5.25",
|
"@vue/compiler-sfc": "3.5.25",
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import authState, { clearToken } from '../store/auth'
|
||||||
|
|
||||||
// 创建axios实例
|
// 创建axios实例
|
||||||
const service = axios.create({
|
const service = axios.create({
|
||||||
baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:5001',
|
baseURL: import.meta.env.VITE_API_BASE_URL || '',
|
||||||
timeout: 300000, // 5分钟超时(本体生成可能需要较长时间)
|
timeout: 300000, // 5分钟超时(本体生成可能需要较长时间)
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ for (const path in localeFiles) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const savedLocale = localStorage.getItem('locale') || 'zh'
|
const savedLocale = localStorage.getItem('locale') || 'ca'
|
||||||
|
|
||||||
const i18n = createI18n({
|
const i18n = createI18n({
|
||||||
legacy: false,
|
legacy: false,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue