chore(i18n): replace all hardcoded Chinese strings with English in backend
Translate all Chinese comments, docstrings, log messages, error messages, and LLM prompt text to English across the entire backend codebase. Locale translation files (locales/*.json) are unchanged. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
e3943c7d7c
commit
7d172b9eec
|
|
@ -20,15 +20,15 @@ set -euo pipefail
|
|||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||
|
||||
# ── Carregar configuració ─────────────────────────────────────────────────────
|
||||
CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||
if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||
echo "ERROR: No s'ha trobat azure/config.sh"
|
||||
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||
echo " Després omple els valors i torna a executar."
|
||||
exit 1
|
||||
fi
|
||||
#CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||
#if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||
# echo "ERROR: No s'ha trobat azure/config.sh"
|
||||
# echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||
# echo " Després omple els valors i torna a executar."
|
||||
# exit 1
|
||||
#fi
|
||||
# shellcheck source=config.sh.example
|
||||
source "$CONFIG_FILE"
|
||||
#source "$CONFIG_FILE"
|
||||
|
||||
# ── Validar variables obligatòries ───────────────────────────────────────────
|
||||
REQUIRED_VARS=(
|
||||
|
|
|
|||
|
|
@ -17,14 +17,14 @@ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
|||
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
|
||||
|
||||
# ── Carregar configuració ─────────────────────────────────────────────────────
|
||||
CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||
if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||
echo "ERROR: No s'ha trobat azure/config.sh"
|
||||
echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||
exit 1
|
||||
fi
|
||||
#CONFIG_FILE="${SCRIPT_DIR}/config.sh"
|
||||
#if [[ ! -f "$CONFIG_FILE" ]]; then
|
||||
# echo "ERROR: No s'ha trobat azure/config.sh"
|
||||
# echo " Còpia l'exemple: cp azure/config.sh.example azure/config.sh"
|
||||
# exit 1
|
||||
#fi
|
||||
# shellcheck source=config.sh.example
|
||||
source "$CONFIG_FILE"
|
||||
#source "$CONFIG_FILE"
|
||||
|
||||
# ── Validar variables obligatòries ───────────────────────────────────────────
|
||||
REQUIRED_VARS=(
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
"""
|
||||
MiroFish Backend - Flask应用工厂
|
||||
MiroFish Backend - Flask application factory
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers)
|
||||
# 需要在所有其他导入之前设置
|
||||
# Suppress multiprocessing resource_tracker warnings (from third-party libraries like transformers)
|
||||
# Must be set before all other imports
|
||||
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
|
||||
|
||||
import jwt
|
||||
|
|
@ -21,36 +21,36 @@ _PUBLIC_PATHS = {'/health', '/api/auth/login'}
|
|||
|
||||
|
||||
def create_app(config_class=Config):
|
||||
"""Flask应用工厂函数"""
|
||||
"""Flask application factory"""
|
||||
app = Flask(__name__)
|
||||
app.config.from_object(config_class)
|
||||
|
||||
# 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式)
|
||||
# Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置
|
||||
# Configure JSON encoding: ensure non-ASCII characters are output directly (not as \uXXXX)
|
||||
# Flask >= 2.3 uses app.json.ensure_ascii; older versions use JSON_AS_ASCII config
|
||||
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
||||
app.json.ensure_ascii = False
|
||||
|
||||
# 设置日志
|
||||
# Set up logging
|
||||
logger = setup_logger('mirofish')
|
||||
|
||||
# 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次)
|
||||
# Only log startup info in the reloader subprocess (avoids double-printing in debug mode)
|
||||
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
|
||||
debug_mode = app.config.get('DEBUG', False)
|
||||
should_log_startup = not debug_mode or is_reloader_process
|
||||
|
||||
if should_log_startup:
|
||||
logger.info("=" * 50)
|
||||
logger.info("MiroFish Backend 启动中...")
|
||||
logger.info("MiroFish Backend starting...")
|
||||
logger.info("=" * 50)
|
||||
|
||||
# 启用CORS
|
||||
# Enable CORS
|
||||
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||
|
||||
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程)
|
||||
# Register simulation process cleanup (ensures all simulation processes are terminated on server shutdown)
|
||||
from .services.simulation_runner import SimulationRunner
|
||||
SimulationRunner.register_cleanup()
|
||||
if should_log_startup:
|
||||
logger.info("已注册模拟进程清理函数")
|
||||
logger.info("Simulation process cleanup handler registered")
|
||||
|
||||
# Middleware d'autenticació JWT — s'executa ABANS del log_request (ordre FIFO)
|
||||
@app.before_request
|
||||
|
|
@ -70,28 +70,28 @@ def create_app(config_class=Config):
|
|||
except jwt.InvalidTokenError:
|
||||
return jsonify({'success': False, 'error': 'Invalid token'}), 401
|
||||
|
||||
# 请求日志中间件
|
||||
# Request logging middleware
|
||||
@app.before_request
|
||||
def log_request():
|
||||
logger = get_logger('mirofish.request')
|
||||
logger.debug(f"请求: {request.method} {request.path}")
|
||||
logger.debug(f"Request: {request.method} {request.path}")
|
||||
if request.content_type and 'json' in request.content_type:
|
||||
logger.debug(f"请求体: {request.get_json(silent=True)}")
|
||||
logger.debug(f"Request body: {request.get_json(silent=True)}")
|
||||
|
||||
@app.after_request
|
||||
def log_response(response):
|
||||
logger = get_logger('mirofish.request')
|
||||
logger.debug(f"响应: {response.status_code}")
|
||||
logger.debug(f"Response: {response.status_code}")
|
||||
return response
|
||||
|
||||
# 注册蓝图 (auth primer, luego els existents)
|
||||
# Register blueprints (auth first, then the rest)
|
||||
from .api import graph_bp, simulation_bp, report_bp, auth_bp
|
||||
app.register_blueprint(auth_bp, url_prefix='/api/auth')
|
||||
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
||||
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
||||
app.register_blueprint(report_bp, url_prefix='/api/report')
|
||||
|
||||
# 健康检查
|
||||
# Health check
|
||||
@app.route('/health')
|
||||
def health():
|
||||
return {'status': 'ok', 'service': 'MiroFish Backend'}
|
||||
|
|
@ -111,6 +111,6 @@ def create_app(config_class=Config):
|
|||
return _send_file(_os.path.join(_dist, 'index.html'))
|
||||
|
||||
if should_log_startup:
|
||||
logger.info("MiroFish Backend 启动完成")
|
||||
logger.info("MiroFish Backend startup complete")
|
||||
|
||||
return app
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
API路由模块
|
||||
API routes module
|
||||
"""
|
||||
|
||||
from flask import Blueprint
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
图谱相关API路由
|
||||
采用项目上下文机制,服务端持久化状态
|
||||
Graph-related API routes
|
||||
Uses project context mechanism with server-side persistent state
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -19,24 +19,24 @@ from ..utils.locale import t, get_locale, set_locale
|
|||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..models.project import ProjectManager, ProjectStatus
|
||||
|
||||
# 获取日志器
|
||||
# Get logger
|
||||
logger = get_logger('mirofish.api')
|
||||
|
||||
|
||||
def allowed_file(filename: str) -> bool:
|
||||
"""检查文件扩展名是否允许"""
|
||||
"""Check if the file extension is allowed"""
|
||||
if not filename or '.' not in filename:
|
||||
return False
|
||||
ext = os.path.splitext(filename)[1].lower().lstrip('.')
|
||||
return ext in Config.ALLOWED_EXTENSIONS
|
||||
|
||||
|
||||
# ============== 项目管理接口 ==============
|
||||
# ============== Project management endpoints ==============
|
||||
|
||||
@graph_bp.route('/project/<project_id>', methods=['GET'])
|
||||
def get_project(project_id: str):
|
||||
"""
|
||||
获取项目详情
|
||||
Get project details
|
||||
"""
|
||||
project = ProjectManager.get_project(project_id)
|
||||
|
||||
|
|
@ -55,7 +55,7 @@ def get_project(project_id: str):
|
|||
@graph_bp.route('/project/list', methods=['GET'])
|
||||
def list_projects():
|
||||
"""
|
||||
列出所有项目
|
||||
List all projects
|
||||
"""
|
||||
limit = request.args.get('limit', 50, type=int)
|
||||
projects = ProjectManager.list_projects(limit=limit)
|
||||
|
|
@ -70,7 +70,7 @@ def list_projects():
|
|||
@graph_bp.route('/project/<project_id>', methods=['DELETE'])
|
||||
def delete_project(project_id: str):
|
||||
"""
|
||||
删除项目
|
||||
Delete a project
|
||||
"""
|
||||
success = ProjectManager.delete_project(project_id)
|
||||
|
||||
|
|
@ -89,7 +89,7 @@ def delete_project(project_id: str):
|
|||
@graph_bp.route('/project/<project_id>/reset', methods=['POST'])
|
||||
def reset_project(project_id: str):
|
||||
"""
|
||||
重置项目状态(用于重新构建图谱)
|
||||
Reset project status (used to rebuild the graph)
|
||||
"""
|
||||
project = ProjectManager.get_project(project_id)
|
||||
|
||||
|
|
@ -99,7 +99,7 @@ def reset_project(project_id: str):
|
|||
"error": t('api.projectNotFound', id=project_id)
|
||||
}), 404
|
||||
|
||||
# 重置到本体已生成状态
|
||||
# Reset to ontology-generated status
|
||||
if project.ontology:
|
||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||
else:
|
||||
|
|
@ -117,22 +117,22 @@ def reset_project(project_id: str):
|
|||
})
|
||||
|
||||
|
||||
# ============== 接口1:上传文件并生成本体 ==============
|
||||
# ============== Endpoint 1: Upload files and generate ontology ==============
|
||||
|
||||
@graph_bp.route('/ontology/generate', methods=['POST'])
|
||||
def generate_ontology():
|
||||
"""
|
||||
接口1:上传文件,分析生成本体定义
|
||||
|
||||
请求方式:multipart/form-data
|
||||
|
||||
参数:
|
||||
files: 上传的文件(PDF/MD/TXT),可多个
|
||||
simulation_requirement: 模拟需求描述(必填)
|
||||
project_name: 项目名称(可选)
|
||||
additional_context: 额外说明(可选)
|
||||
|
||||
返回:
|
||||
Endpoint 1: Upload files and generate ontology definition
|
||||
|
||||
Request method: multipart/form-data
|
||||
|
||||
Parameters:
|
||||
files: Uploaded files (PDF/MD/TXT), multiple allowed
|
||||
simulation_requirement: Simulation requirement description (required)
|
||||
project_name: Project name (optional)
|
||||
additional_context: Additional context (optional)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -148,15 +148,15 @@ def generate_ontology():
|
|||
}
|
||||
"""
|
||||
try:
|
||||
logger.info("=== 开始生成本体定义 ===")
|
||||
|
||||
# 获取参数
|
||||
logger.info("=== Starting ontology generation ===")
|
||||
|
||||
# Get parameters
|
||||
simulation_requirement = request.form.get('simulation_requirement', '')
|
||||
project_name = request.form.get('project_name', 'Unnamed Project')
|
||||
additional_context = request.form.get('additional_context', '')
|
||||
|
||||
logger.debug(f"项目名称: {project_name}")
|
||||
logger.debug(f"模拟需求: {simulation_requirement[:100]}...")
|
||||
|
||||
logger.debug(f"Project name: {project_name}")
|
||||
logger.debug(f"Simulation requirement: {simulation_requirement[:100]}...")
|
||||
|
||||
if not simulation_requirement:
|
||||
return jsonify({
|
||||
|
|
@ -164,68 +164,68 @@ def generate_ontology():
|
|||
"error": t('api.requireSimulationRequirement')
|
||||
}), 400
|
||||
|
||||
# 获取上传的文件
|
||||
# Get uploaded files
|
||||
uploaded_files = request.files.getlist('files')
|
||||
if not uploaded_files or all(not f.filename for f in uploaded_files):
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.requireFileUpload')
|
||||
}), 400
|
||||
|
||||
# 创建项目
|
||||
|
||||
# Create project
|
||||
project = ProjectManager.create_project(name=project_name)
|
||||
project.simulation_requirement = simulation_requirement
|
||||
logger.info(f"创建项目: {project.project_id}")
|
||||
|
||||
# 保存文件并提取文本
|
||||
logger.info(f"Project created: {project.project_id}")
|
||||
|
||||
# Save files and extract text
|
||||
document_texts = []
|
||||
all_text = ""
|
||||
|
||||
|
||||
for file in uploaded_files:
|
||||
if file and file.filename and allowed_file(file.filename):
|
||||
# 保存文件到项目目录
|
||||
# Save file to project directory
|
||||
file_info = ProjectManager.save_file_to_project(
|
||||
project.project_id,
|
||||
file,
|
||||
project.project_id,
|
||||
file,
|
||||
file.filename
|
||||
)
|
||||
project.files.append({
|
||||
"filename": file_info["original_filename"],
|
||||
"size": file_info["size"]
|
||||
})
|
||||
|
||||
# 提取文本
|
||||
|
||||
# Extract text
|
||||
text = FileParser.extract_text(file_info["path"])
|
||||
text = TextProcessor.preprocess_text(text)
|
||||
document_texts.append(text)
|
||||
all_text += f"\n\n=== {file_info['original_filename']} ===\n{text}"
|
||||
|
||||
|
||||
if not document_texts:
|
||||
ProjectManager.delete_project(project.project_id)
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.noDocProcessed')
|
||||
}), 400
|
||||
|
||||
# 保存提取的文本
|
||||
|
||||
# Save extracted text
|
||||
project.total_text_length = len(all_text)
|
||||
ProjectManager.save_extracted_text(project.project_id, all_text)
|
||||
logger.info(f"文本提取完成,共 {len(all_text)} 字符")
|
||||
|
||||
# 生成本体
|
||||
logger.info("调用 LLM 生成本体定义...")
|
||||
logger.info(f"Text extraction complete, total {len(all_text)} characters")
|
||||
|
||||
# Generate ontology
|
||||
logger.info("Calling LLM to generate ontology definition...")
|
||||
generator = OntologyGenerator()
|
||||
ontology = generator.generate(
|
||||
document_texts=document_texts,
|
||||
simulation_requirement=simulation_requirement,
|
||||
additional_context=additional_context if additional_context else None
|
||||
)
|
||||
|
||||
# 保存本体到项目
|
||||
|
||||
# Save ontology to project
|
||||
entity_count = len(ontology.get("entity_types", []))
|
||||
edge_count = len(ontology.get("edge_types", []))
|
||||
logger.info(f"本体生成完成: {entity_count} 个实体类型, {edge_count} 个关系类型")
|
||||
|
||||
logger.info(f"Ontology generation complete: {entity_count} entity types, {edge_count} relationship types")
|
||||
|
||||
project.ontology = {
|
||||
"entity_types": ontology.get("entity_types", []),
|
||||
"edge_types": ontology.get("edge_types", [])
|
||||
|
|
@ -233,7 +233,7 @@ def generate_ontology():
|
|||
project.analysis_summary = ontology.get("analysis_summary", "")
|
||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||
ProjectManager.save_project(project)
|
||||
logger.info(f"=== 本体生成完成 === 项目ID: {project.project_id}")
|
||||
logger.info(f"=== Ontology generation complete === Project ID: {project.project_id}")
|
||||
|
||||
return jsonify({
|
||||
"success": True,
|
||||
|
|
@ -255,49 +255,49 @@ def generate_ontology():
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 接口2:构建图谱 ==============
|
||||
# ============== Endpoint 2: Build graph ==============
|
||||
|
||||
@graph_bp.route('/build', methods=['POST'])
|
||||
def build_graph():
|
||||
"""
|
||||
接口2:根据project_id构建图谱
|
||||
|
||||
请求(JSON):
|
||||
Endpoint 2: Build graph from project_id
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"project_id": "proj_xxxx", // 必填,来自接口1
|
||||
"graph_name": "图谱名称", // 可选
|
||||
"chunk_size": 500, // 可选,默认500
|
||||
"chunk_overlap": 50 // 可选,默认50
|
||||
"project_id": "proj_xxxx", // required, from endpoint 1
|
||||
"graph_name": "Graph name", // optional
|
||||
"chunk_size": 500, // optional, default 500
|
||||
"chunk_overlap": 50 // optional, default 50
|
||||
}
|
||||
|
||||
返回:
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"project_id": "proj_xxxx",
|
||||
"task_id": "task_xxxx",
|
||||
"message": "图谱构建任务已启动"
|
||||
"message": "Graph build task started"
|
||||
}
|
||||
}
|
||||
"""
|
||||
try:
|
||||
logger.info("=== 开始构建图谱 ===")
|
||||
|
||||
# 检查配置
|
||||
logger.info("=== Starting graph build ===")
|
||||
|
||||
# Check configuration
|
||||
errors = []
|
||||
if not Config.ZEP_API_KEY:
|
||||
errors.append(t('api.zepApiKeyMissing'))
|
||||
if errors:
|
||||
logger.error(f"配置错误: {errors}")
|
||||
logger.error(f"Configuration error: {errors}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.configError', details="; ".join(errors))
|
||||
}), 500
|
||||
|
||||
# 解析请求
|
||||
|
||||
# Parse request
|
||||
data = request.get_json() or {}
|
||||
project_id = data.get('project_id')
|
||||
logger.debug(f"请求参数: project_id={project_id}")
|
||||
logger.debug(f"Request parameters: project_id={project_id}")
|
||||
|
||||
if not project_id:
|
||||
return jsonify({
|
||||
|
|
@ -305,7 +305,7 @@ def build_graph():
|
|||
"error": t('api.requireProjectId')
|
||||
}), 400
|
||||
|
||||
# 获取项目
|
||||
# Get project
|
||||
project = ProjectManager.get_project(project_id)
|
||||
if not project:
|
||||
return jsonify({
|
||||
|
|
@ -313,83 +313,83 @@ def build_graph():
|
|||
"error": t('api.projectNotFound', id=project_id)
|
||||
}), 404
|
||||
|
||||
# 检查项目状态
|
||||
force = data.get('force', False) # 强制重新构建
|
||||
|
||||
# Check project status
|
||||
force = data.get('force', False) # Force rebuild
|
||||
|
||||
if project.status == ProjectStatus.CREATED:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.ontologyNotGenerated')
|
||||
}), 400
|
||||
|
||||
|
||||
if project.status == ProjectStatus.GRAPH_BUILDING and not force:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.graphBuilding'),
|
||||
"task_id": project.graph_build_task_id
|
||||
}), 400
|
||||
|
||||
# 如果强制重建,重置状态
|
||||
|
||||
# If force rebuild, reset status
|
||||
if force and project.status in [ProjectStatus.GRAPH_BUILDING, ProjectStatus.FAILED, ProjectStatus.GRAPH_COMPLETED]:
|
||||
project.status = ProjectStatus.ONTOLOGY_GENERATED
|
||||
project.graph_id = None
|
||||
project.graph_build_task_id = None
|
||||
project.error = None
|
||||
|
||||
# 获取配置
|
||||
|
||||
# Get configuration
|
||||
graph_name = data.get('graph_name', project.name or 'MiroFish Graph')
|
||||
chunk_size = data.get('chunk_size', project.chunk_size or Config.DEFAULT_CHUNK_SIZE)
|
||||
chunk_overlap = data.get('chunk_overlap', project.chunk_overlap or Config.DEFAULT_CHUNK_OVERLAP)
|
||||
|
||||
# 更新项目配置
|
||||
|
||||
# Update project configuration
|
||||
project.chunk_size = chunk_size
|
||||
project.chunk_overlap = chunk_overlap
|
||||
|
||||
# 获取提取的文本
|
||||
|
||||
# Get extracted text
|
||||
text = ProjectManager.get_extracted_text(project_id)
|
||||
if not text:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.textNotFound')
|
||||
}), 400
|
||||
|
||||
# 获取本体
|
||||
|
||||
# Get ontology
|
||||
ontology = project.ontology
|
||||
if not ontology:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.ontologyNotFound')
|
||||
}), 400
|
||||
|
||||
# 创建异步任务
|
||||
|
||||
# Create async task
|
||||
task_manager = TaskManager()
|
||||
task_id = task_manager.create_task(f"构建图谱: {graph_name}")
|
||||
logger.info(f"创建图谱构建任务: task_id={task_id}, project_id={project_id}")
|
||||
|
||||
# 更新项目状态
|
||||
task_id = task_manager.create_task(f"Build graph: {graph_name}")
|
||||
logger.info(f"Graph build task created: task_id={task_id}, project_id={project_id}")
|
||||
|
||||
# Update project status
|
||||
project.status = ProjectStatus.GRAPH_BUILDING
|
||||
project.graph_build_task_id = task_id
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
|
||||
# Capture locale before spawning background thread
|
||||
current_locale = get_locale()
|
||||
|
||||
# 启动后台任务
|
||||
# Start background task
|
||||
def build_task():
|
||||
set_locale(current_locale)
|
||||
build_logger = get_logger('mirofish.build')
|
||||
try:
|
||||
build_logger.info(f"[{task_id}] 开始构建图谱...")
|
||||
build_logger.info(f"[{task_id}] Starting graph build...")
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.PROCESSING,
|
||||
message=t('progress.initGraphService')
|
||||
)
|
||||
|
||||
# 创建图谱构建服务
|
||||
# Create graph builder service
|
||||
builder = GraphBuilderService(api_key=Config.ZEP_API_KEY)
|
||||
|
||||
# 分块
|
||||
|
||||
# Split into chunks
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=t('progress.textChunking'),
|
||||
|
|
@ -402,7 +402,7 @@ def build_graph():
|
|||
)
|
||||
total_chunks = len(chunks)
|
||||
|
||||
# 创建图谱
|
||||
# Create graph
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=t('progress.creatingZepGraph'),
|
||||
|
|
@ -410,11 +410,11 @@ def build_graph():
|
|||
)
|
||||
graph_id = builder.create_graph(name=graph_name)
|
||||
|
||||
# 更新项目的graph_id
|
||||
# Update project graph_id
|
||||
project.graph_id = graph_id
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
# 设置本体
|
||||
|
||||
# Set ontology
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=t('progress.settingOntology'),
|
||||
|
|
@ -422,7 +422,7 @@ def build_graph():
|
|||
)
|
||||
builder.set_ontology(graph_id, ontology)
|
||||
|
||||
# 添加文本(progress_callback 签名是 (msg, progress_ratio))
|
||||
# Add text (progress_callback signature: (msg, progress_ratio))
|
||||
def add_progress_callback(msg, progress_ratio):
|
||||
progress = 15 + int(progress_ratio * 40) # 15% - 55%
|
||||
task_manager.update_task(
|
||||
|
|
@ -444,7 +444,7 @@ def build_graph():
|
|||
progress_callback=add_progress_callback
|
||||
)
|
||||
|
||||
# 等待Zep处理完成(查询每个episode的processed状态)
|
||||
# Wait for Zep processing to complete (poll each episode's processed status)
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=t('progress.waitingZepProcess'),
|
||||
|
|
@ -461,7 +461,7 @@ def build_graph():
|
|||
|
||||
builder._wait_for_episodes(episode_uuids, wait_progress_callback)
|
||||
|
||||
# 获取图谱数据
|
||||
# Fetch graph data
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
message=t('progress.fetchingGraphData'),
|
||||
|
|
@ -469,15 +469,15 @@ def build_graph():
|
|||
)
|
||||
graph_data = builder.get_graph_data(graph_id)
|
||||
|
||||
# 更新项目状态
|
||||
# Update project status
|
||||
project.status = ProjectStatus.GRAPH_COMPLETED
|
||||
ProjectManager.save_project(project)
|
||||
|
||||
|
||||
node_count = graph_data.get("node_count", 0)
|
||||
edge_count = graph_data.get("edge_count", 0)
|
||||
build_logger.info(f"[{task_id}] 图谱构建完成: graph_id={graph_id}, 节点={node_count}, 边={edge_count}")
|
||||
|
||||
# 完成
|
||||
build_logger.info(f"[{task_id}] Graph build complete: graph_id={graph_id}, nodes={node_count}, edges={edge_count}")
|
||||
|
||||
# Complete
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.COMPLETED,
|
||||
|
|
@ -493,8 +493,8 @@ def build_graph():
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
# 更新项目状态为失败
|
||||
build_logger.error(f"[{task_id}] 图谱构建失败: {str(e)}")
|
||||
# Update project status to failed
|
||||
build_logger.error(f"[{task_id}] Graph build failed: {str(e)}")
|
||||
build_logger.debug(traceback.format_exc())
|
||||
|
||||
project.status = ProjectStatus.FAILED
|
||||
|
|
@ -508,7 +508,7 @@ def build_graph():
|
|||
error=traceback.format_exc()
|
||||
)
|
||||
|
||||
# 启动后台线程
|
||||
# Start background thread
|
||||
thread = threading.Thread(target=build_task, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
|
@ -529,12 +529,12 @@ def build_graph():
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 任务查询接口 ==============
|
||||
# ============== Task query endpoints ==============
|
||||
|
||||
@graph_bp.route('/task/<task_id>', methods=['GET'])
|
||||
def get_task(task_id: str):
|
||||
"""
|
||||
查询任务状态
|
||||
Query task status
|
||||
"""
|
||||
task = TaskManager().get_task(task_id)
|
||||
|
||||
|
|
@ -553,7 +553,7 @@ def get_task(task_id: str):
|
|||
@graph_bp.route('/tasks', methods=['GET'])
|
||||
def list_tasks():
|
||||
"""
|
||||
列出所有任务
|
||||
List all tasks
|
||||
"""
|
||||
tasks = TaskManager().list_tasks()
|
||||
|
||||
|
|
@ -564,12 +564,12 @@ def list_tasks():
|
|||
})
|
||||
|
||||
|
||||
# ============== 图谱数据接口 ==============
|
||||
# ============== Graph data endpoints ==============
|
||||
|
||||
@graph_bp.route('/data/<graph_id>', methods=['GET'])
|
||||
def get_graph_data(graph_id: str):
|
||||
"""
|
||||
获取图谱数据(节点和边)
|
||||
Get graph data (nodes and edges)
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
|
|
@ -597,7 +597,7 @@ def get_graph_data(graph_id: str):
|
|||
@graph_bp.route('/delete/<graph_id>', methods=['DELETE'])
|
||||
def delete_graph(graph_id: str):
|
||||
"""
|
||||
删除Zep图谱
|
||||
Delete a Zep graph
|
||||
"""
|
||||
try:
|
||||
if not Config.ZEP_API_KEY:
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Report API路由
|
||||
提供模拟报告生成、获取、对话等接口
|
||||
Report API routes
|
||||
Provides simulation report generation, retrieval, and chat endpoints
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -20,30 +20,30 @@ from ..utils.locale import t, get_locale, set_locale
|
|||
logger = get_logger('mirofish.api.report')
|
||||
|
||||
|
||||
# ============== 报告生成接口 ==============
|
||||
# ============== Report generation endpoints ==============
|
||||
|
||||
@report_bp.route('/generate', methods=['POST'])
|
||||
def generate_report():
|
||||
"""
|
||||
生成模拟分析报告(异步任务)
|
||||
|
||||
这是一个耗时操作,接口会立即返回task_id,
|
||||
使用 GET /api/report/generate/status 查询进度
|
||||
|
||||
请求(JSON):
|
||||
Generate a simulation analysis report (async task)
|
||||
|
||||
This is a long-running operation; the endpoint returns task_id immediately.
|
||||
Use GET /api/report/generate/status to poll progress.
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"force_regenerate": false // 可选,强制重新生成
|
||||
"simulation_id": "sim_xxxx", // required, simulation ID
|
||||
"force_regenerate": false // optional, force regeneration
|
||||
}
|
||||
|
||||
返回:
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"simulation_id": "sim_xxxx",
|
||||
"task_id": "task_xxxx",
|
||||
"status": "generating",
|
||||
"message": "报告生成任务已启动"
|
||||
"message": "Report generation task started"
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -59,17 +59,17 @@ def generate_report():
|
|||
|
||||
force_regenerate = data.get('force_regenerate', False)
|
||||
|
||||
# 获取模拟信息
|
||||
# Get simulation info
|
||||
manager = SimulationManager()
|
||||
state = manager.get_simulation(simulation_id)
|
||||
|
||||
|
||||
if not state:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.simulationNotFound', id=simulation_id)
|
||||
}), 404
|
||||
|
||||
# 检查是否已有报告
|
||||
# Check if a report already exists
|
||||
if not force_regenerate:
|
||||
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
||||
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
||||
|
|
@ -84,7 +84,7 @@ def generate_report():
|
|||
}
|
||||
})
|
||||
|
||||
# 获取项目信息
|
||||
# Get project info
|
||||
project = ProjectManager.get_project(state.project_id)
|
||||
if not project:
|
||||
return jsonify({
|
||||
|
|
@ -106,11 +106,11 @@ def generate_report():
|
|||
"error": t('api.missingSimRequirement')
|
||||
}), 400
|
||||
|
||||
# 提前生成 report_id,以便立即返回给前端
|
||||
# Pre-generate report_id so it can be returned immediately
|
||||
import uuid
|
||||
report_id = f"report_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
# 创建异步任务
|
||||
|
||||
# Create async task
|
||||
task_manager = TaskManager()
|
||||
task_id = task_manager.create_task(
|
||||
task_type="report_generate",
|
||||
|
|
@ -124,7 +124,7 @@ def generate_report():
|
|||
# Capture locale before spawning background thread
|
||||
current_locale = get_locale()
|
||||
|
||||
# 定义后台任务
|
||||
# Define background task
|
||||
def run_generate():
|
||||
set_locale(current_locale)
|
||||
try:
|
||||
|
|
@ -134,29 +134,29 @@ def generate_report():
|
|||
progress=0,
|
||||
message=t('api.initReportAgent')
|
||||
)
|
||||
|
||||
# 创建Report Agent
|
||||
|
||||
# Create Report Agent
|
||||
agent = ReportAgent(
|
||||
graph_id=graph_id,
|
||||
simulation_id=simulation_id,
|
||||
simulation_requirement=simulation_requirement
|
||||
)
|
||||
|
||||
# 进度回调
|
||||
# Progress callback
|
||||
def progress_callback(stage, progress, message):
|
||||
task_manager.update_task(
|
||||
task_id,
|
||||
progress=progress,
|
||||
message=f"[{stage}] {message}"
|
||||
)
|
||||
|
||||
# 生成报告(传入预先生成的 report_id)
|
||||
|
||||
# Generate report (pass pre-generated report_id)
|
||||
report = agent.generate_report(
|
||||
progress_callback=progress_callback,
|
||||
report_id=report_id
|
||||
)
|
||||
|
||||
# 保存报告
|
||||
# Save report
|
||||
ReportManager.save_report(report)
|
||||
|
||||
if report.status == ReportStatus.COMPLETED:
|
||||
|
|
@ -172,10 +172,10 @@ def generate_report():
|
|||
task_manager.fail_task(task_id, report.error or t('api.reportGenerateFailed'))
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"报告生成失败: {str(e)}")
|
||||
logger.error(f"Report generation failed: {str(e)}")
|
||||
task_manager.fail_task(task_id, str(e))
|
||||
|
||||
# 启动后台线程
|
||||
|
||||
# Start background thread
|
||||
thread = threading.Thread(target=run_generate, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
|
@ -192,7 +192,7 @@ def generate_report():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"启动报告生成任务失败: {str(e)}")
|
||||
logger.error(f"Failed to start report generation task: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -203,15 +203,15 @@ def generate_report():
|
|||
@report_bp.route('/generate/status', methods=['POST'])
|
||||
def get_generate_status():
|
||||
"""
|
||||
查询报告生成任务进度
|
||||
|
||||
请求(JSON):
|
||||
Query report generation task progress
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"task_id": "task_xxxx", // 可选,generate返回的task_id
|
||||
"simulation_id": "sim_xxxx" // 可选,模拟ID
|
||||
"task_id": "task_xxxx", // optional, task_id from generate
|
||||
"simulation_id": "sim_xxxx" // optional, simulation ID
|
||||
}
|
||||
|
||||
返回:
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -228,7 +228,7 @@ def get_generate_status():
|
|||
task_id = data.get('task_id')
|
||||
simulation_id = data.get('simulation_id')
|
||||
|
||||
# 如果提供了simulation_id,先检查是否已有完成的报告
|
||||
# If simulation_id is provided, check whether a completed report exists
|
||||
if simulation_id:
|
||||
existing_report = ReportManager.get_report_by_simulation(simulation_id)
|
||||
if existing_report and existing_report.status == ReportStatus.COMPLETED:
|
||||
|
|
@ -265,21 +265,21 @@ def get_generate_status():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"查询任务状态失败: {str(e)}")
|
||||
logger.error(f"Failed to query task status: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e)
|
||||
}), 500
|
||||
|
||||
|
||||
# ============== 报告获取接口 ==============
|
||||
# ============== Report retrieval endpoints ==============
|
||||
|
||||
@report_bp.route('/<report_id>', methods=['GET'])
|
||||
def get_report(report_id: str):
|
||||
"""
|
||||
获取报告详情
|
||||
|
||||
返回:
|
||||
Get report details
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -308,7 +308,7 @@ def get_report(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告失败: {str(e)}")
|
||||
logger.error(f"Failed to get report: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -319,9 +319,9 @@ def get_report(report_id: str):
|
|||
@report_bp.route('/by-simulation/<simulation_id>', methods=['GET'])
|
||||
def get_report_by_simulation(simulation_id: str):
|
||||
"""
|
||||
根据模拟ID获取报告
|
||||
|
||||
返回:
|
||||
Get report by simulation ID
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -347,7 +347,7 @@ def get_report_by_simulation(simulation_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告失败: {str(e)}")
|
||||
logger.error(f"Failed to get report: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -358,13 +358,13 @@ def get_report_by_simulation(simulation_id: str):
|
|||
@report_bp.route('/list', methods=['GET'])
|
||||
def list_reports():
|
||||
"""
|
||||
列出所有报告
|
||||
|
||||
Query参数:
|
||||
simulation_id: 按模拟ID过滤(可选)
|
||||
limit: 返回数量限制(默认50)
|
||||
|
||||
返回:
|
||||
List all reports
|
||||
|
||||
Query parameters:
|
||||
simulation_id: filter by simulation ID (optional)
|
||||
limit: result count limit (default 50)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": [...],
|
||||
|
|
@ -387,7 +387,7 @@ def list_reports():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"列出报告失败: {str(e)}")
|
||||
logger.error(f"Failed to list reports: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -398,9 +398,9 @@ def list_reports():
|
|||
@report_bp.route('/<report_id>/download', methods=['GET'])
|
||||
def download_report(report_id: str):
|
||||
"""
|
||||
下载报告(Markdown格式)
|
||||
|
||||
返回Markdown文件
|
||||
Download report (Markdown format)
|
||||
|
||||
Returns a Markdown file
|
||||
"""
|
||||
try:
|
||||
report = ReportManager.get_report(report_id)
|
||||
|
|
@ -414,7 +414,7 @@ def download_report(report_id: str):
|
|||
md_path = ReportManager._get_report_markdown_path(report_id)
|
||||
|
||||
if not os.path.exists(md_path):
|
||||
# 如果MD文件不存在,生成一个临时文件
|
||||
# If MD file doesn't exist, create a temporary file
|
||||
import tempfile
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.md', delete=False) as f:
|
||||
f.write(report.markdown_content)
|
||||
|
|
@ -433,7 +433,7 @@ def download_report(report_id: str):
|
|||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"下载报告失败: {str(e)}")
|
||||
logger.error(f"Failed to download report: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -443,7 +443,7 @@ def download_report(report_id: str):
|
|||
|
||||
@report_bp.route('/<report_id>', methods=['DELETE'])
|
||||
def delete_report(report_id: str):
|
||||
"""删除报告"""
|
||||
"""Delete a report"""
|
||||
try:
|
||||
success = ReportManager.delete_report(report_id)
|
||||
|
||||
|
|
@ -459,7 +459,7 @@ def delete_report(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"删除报告失败: {str(e)}")
|
||||
logger.error(f"Failed to delete report: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -467,32 +467,32 @@ def delete_report(report_id: str):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== Report Agent对话接口 ==============
|
||||
# ============== Report Agent chat endpoint ==============
|
||||
|
||||
@report_bp.route('/chat', methods=['POST'])
|
||||
def chat_with_report_agent():
|
||||
"""
|
||||
与Report Agent对话
|
||||
|
||||
Report Agent可以在对话中自主调用检索工具来回答问题
|
||||
|
||||
请求(JSON):
|
||||
Chat with the Report Agent
|
||||
|
||||
The Report Agent can autonomously call retrieval tools to answer questions.
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"simulation_id": "sim_xxxx", // 必填,模拟ID
|
||||
"message": "请解释一下舆情走向", // 必填,用户消息
|
||||
"chat_history": [ // 可选,对话历史
|
||||
"simulation_id": "sim_xxxx", // required, simulation ID
|
||||
"message": "Explain the trend...", // required, user message
|
||||
"chat_history": [ // optional, conversation history
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."}
|
||||
]
|
||||
}
|
||||
|
||||
返回:
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"response": "Agent回复...",
|
||||
"tool_calls": [调用的工具列表],
|
||||
"sources": [信息来源]
|
||||
"response": "Agent reply...",
|
||||
"tool_calls": [list of tools called],
|
||||
"sources": [information sources]
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -515,7 +515,7 @@ def chat_with_report_agent():
|
|||
"error": t('api.requireMessage')
|
||||
}), 400
|
||||
|
||||
# 获取模拟和项目信息
|
||||
# Get simulation and project info
|
||||
manager = SimulationManager()
|
||||
state = manager.get_simulation(simulation_id)
|
||||
|
||||
|
|
@ -541,7 +541,7 @@ def chat_with_report_agent():
|
|||
|
||||
simulation_requirement = project.simulation_requirement or ""
|
||||
|
||||
# 创建Agent并进行对话
|
||||
# Create agent and start chat
|
||||
agent = ReportAgent(
|
||||
graph_id=graph_id,
|
||||
simulation_id=simulation_id,
|
||||
|
|
@ -556,7 +556,7 @@ def chat_with_report_agent():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"对话失败: {str(e)}")
|
||||
logger.error(f"Chat failed: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -564,22 +564,22 @@ def chat_with_report_agent():
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 报告进度与分章节接口 ==============
|
||||
# ============== Report progress and section endpoints ==============
|
||||
|
||||
@report_bp.route('/<report_id>/progress', methods=['GET'])
|
||||
def get_report_progress(report_id: str):
|
||||
"""
|
||||
获取报告生成进度(实时)
|
||||
|
||||
返回:
|
||||
Get report generation progress (real-time)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"status": "generating",
|
||||
"progress": 45,
|
||||
"message": "正在生成章节: 关键发现",
|
||||
"current_section": "关键发现",
|
||||
"completed_sections": ["执行摘要", "模拟背景"],
|
||||
"message": "Generating section: Key Findings",
|
||||
"current_section": "Key Findings",
|
||||
"completed_sections": ["Executive Summary", "Simulation Background"],
|
||||
"updated_at": "2025-12-09T..."
|
||||
}
|
||||
}
|
||||
|
|
@ -599,7 +599,7 @@ def get_report_progress(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取报告进度失败: {str(e)}")
|
||||
logger.error(f"Failed to get report progress: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -610,11 +610,12 @@ def get_report_progress(report_id: str):
|
|||
@report_bp.route('/<report_id>/sections', methods=['GET'])
|
||||
def get_report_sections(report_id: str):
|
||||
"""
|
||||
获取已生成的章节列表(分章节输出)
|
||||
|
||||
前端可以轮询此接口获取已生成的章节内容,无需等待整个报告完成
|
||||
|
||||
返回:
|
||||
Get list of already-generated sections (section-by-section output)
|
||||
|
||||
The frontend can poll this endpoint to get section content as it is generated,
|
||||
without waiting for the full report to complete.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -623,7 +624,7 @@ def get_report_sections(report_id: str):
|
|||
{
|
||||
"filename": "section_01.md",
|
||||
"section_index": 1,
|
||||
"content": "## 执行摘要\\n\\n..."
|
||||
"content": "## Executive Summary\\n\\n..."
|
||||
},
|
||||
...
|
||||
],
|
||||
|
|
@ -635,7 +636,7 @@ def get_report_sections(report_id: str):
|
|||
try:
|
||||
sections = ReportManager.get_generated_sections(report_id)
|
||||
|
||||
# 获取报告状态
|
||||
# Get report status
|
||||
report = ReportManager.get_report(report_id)
|
||||
is_complete = report is not None and report.status == ReportStatus.COMPLETED
|
||||
|
||||
|
|
@ -650,7 +651,7 @@ def get_report_sections(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取章节列表失败: {str(e)}")
|
||||
logger.error(f"Failed to get section list: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -661,14 +662,14 @@ def get_report_sections(report_id: str):
|
|||
@report_bp.route('/<report_id>/section/<int:section_index>', methods=['GET'])
|
||||
def get_single_section(report_id: str, section_index: int):
|
||||
"""
|
||||
获取单个章节内容
|
||||
|
||||
返回:
|
||||
Get single section content
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"filename": "section_01.md",
|
||||
"content": "## 执行摘要\\n\\n..."
|
||||
"content": "## Executive Summary\\n\\n..."
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
|
@ -694,7 +695,7 @@ def get_single_section(report_id: str, section_index: int):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取章节内容失败: {str(e)}")
|
||||
logger.error(f"Failed to get section content: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -702,16 +703,16 @@ def get_single_section(report_id: str, section_index: int):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 报告状态检查接口 ==============
|
||||
# ============== Report status check endpoint ==============
|
||||
|
||||
@report_bp.route('/check/<simulation_id>', methods=['GET'])
|
||||
def check_report_status(simulation_id: str):
|
||||
"""
|
||||
检查模拟是否有报告,以及报告状态
|
||||
|
||||
用于前端判断是否解锁Interview功能
|
||||
|
||||
返回:
|
||||
Check whether a simulation has a report and its status
|
||||
|
||||
Used by the frontend to determine whether to unlock the Interview feature.
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -730,7 +731,7 @@ def check_report_status(simulation_id: str):
|
|||
report_status = report.status.value if report else None
|
||||
report_id = report.report_id if report else None
|
||||
|
||||
# 只有报告完成后才解锁interview
|
||||
# Interview is unlocked only after the report is complete
|
||||
interview_unlocked = has_report and report.status == ReportStatus.COMPLETED
|
||||
|
||||
return jsonify({
|
||||
|
|
@ -745,7 +746,7 @@ def check_report_status(simulation_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检查报告状态失败: {str(e)}")
|
||||
logger.error(f"Failed to check report status: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -753,22 +754,22 @@ def check_report_status(simulation_id: str):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== Agent 日志接口 ==============
|
||||
# ============== Agent log endpoints ==============
|
||||
|
||||
@report_bp.route('/<report_id>/agent-log', methods=['GET'])
|
||||
def get_agent_log(report_id: str):
|
||||
"""
|
||||
获取 Report Agent 的详细执行日志
|
||||
|
||||
实时获取报告生成过程中的每一步动作,包括:
|
||||
- 报告开始、规划开始/完成
|
||||
- 每个章节的开始、工具调用、LLM响应、完成
|
||||
- 报告完成或失败
|
||||
|
||||
Query参数:
|
||||
from_line: 从第几行开始读取(可选,默认0,用于增量获取)
|
||||
|
||||
返回:
|
||||
Get detailed execution log of the Report Agent
|
||||
|
||||
Retrieves step-by-step actions during report generation, including:
|
||||
- Report start, planning start/complete
|
||||
- Each section's start, tool calls, LLM response, completion
|
||||
- Report completion or failure
|
||||
|
||||
Query parameters:
|
||||
from_line: start reading from this line (optional, default 0, for incremental fetch)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -779,7 +780,7 @@ def get_agent_log(report_id: str):
|
|||
"report_id": "report_xxxx",
|
||||
"action": "tool_call",
|
||||
"stage": "generating",
|
||||
"section_title": "执行摘要",
|
||||
"section_title": "Executive Summary",
|
||||
"section_index": 1,
|
||||
"details": {
|
||||
"tool_name": "insight_forge",
|
||||
|
|
@ -806,7 +807,7 @@ def get_agent_log(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent日志失败: {str(e)}")
|
||||
logger.error(f"Failed to get Agent log: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -817,9 +818,9 @@ def get_agent_log(report_id: str):
|
|||
@report_bp.route('/<report_id>/agent-log/stream', methods=['GET'])
|
||||
def stream_agent_log(report_id: str):
|
||||
"""
|
||||
获取完整的 Agent 日志(一次性获取全部)
|
||||
|
||||
返回:
|
||||
Get the full Agent log (fetch all at once)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -840,7 +841,7 @@ def stream_agent_log(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取Agent日志失败: {str(e)}")
|
||||
logger.error(f"Failed to get Agent log: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -848,27 +849,27 @@ def stream_agent_log(report_id: str):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 控制台日志接口 ==============
|
||||
# ============== Console log endpoints ==============
|
||||
|
||||
@report_bp.route('/<report_id>/console-log', methods=['GET'])
|
||||
def get_console_log(report_id: str):
|
||||
"""
|
||||
获取 Report Agent 的控制台输出日志
|
||||
|
||||
实时获取报告生成过程中的控制台输出(INFO、WARNING等),
|
||||
这与 agent-log 接口返回的结构化 JSON 日志不同,
|
||||
是纯文本格式的控制台风格日志。
|
||||
|
||||
Query参数:
|
||||
from_line: 从第几行开始读取(可选,默认0,用于增量获取)
|
||||
|
||||
返回:
|
||||
Get the console output log of the Report Agent
|
||||
|
||||
Returns real-time console output (INFO, WARNING, etc.) during report generation.
|
||||
Unlike the agent-log endpoint which returns structured JSON logs,
|
||||
this returns plain-text console-style logs.
|
||||
|
||||
Query parameters:
|
||||
from_line: start reading from this line (optional, default 0, for incremental fetch)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
"logs": [
|
||||
"[19:46:14] INFO: 搜索完成: 找到 15 条相关事实",
|
||||
"[19:46:14] INFO: 图谱搜索: graph_id=xxx, query=...",
|
||||
"[19:46:14] INFO: Search complete: found 15 relevant facts",
|
||||
"[19:46:14] INFO: Graph search: graph_id=xxx, query=...",
|
||||
...
|
||||
],
|
||||
"total_lines": 100,
|
||||
|
|
@ -888,7 +889,7 @@ def get_console_log(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取控制台日志失败: {str(e)}")
|
||||
logger.error(f"Failed to get console log: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -899,9 +900,9 @@ def get_console_log(report_id: str):
|
|||
@report_bp.route('/<report_id>/console-log/stream', methods=['GET'])
|
||||
def stream_console_log(report_id: str):
|
||||
"""
|
||||
获取完整的控制台日志(一次性获取全部)
|
||||
|
||||
返回:
|
||||
Get the full console log (fetch all at once)
|
||||
|
||||
Returns:
|
||||
{
|
||||
"success": true,
|
||||
"data": {
|
||||
|
|
@ -922,7 +923,7 @@ def stream_console_log(report_id: str):
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取控制台日志失败: {str(e)}")
|
||||
logger.error(f"Failed to get console log: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -930,17 +931,17 @@ def stream_console_log(report_id: str):
|
|||
}), 500
|
||||
|
||||
|
||||
# ============== 工具调用接口(供调试使用)==============
|
||||
# ============== Tool call endpoints (for debugging) ==============
|
||||
|
||||
@report_bp.route('/tools/search', methods=['POST'])
|
||||
def search_graph_tool():
|
||||
"""
|
||||
图谱搜索工具接口(供调试使用)
|
||||
|
||||
请求(JSON):
|
||||
Graph search tool endpoint (for debugging)
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"graph_id": "mirofish_xxxx",
|
||||
"query": "搜索查询",
|
||||
"query": "search query",
|
||||
"limit": 10
|
||||
}
|
||||
"""
|
||||
|
|
@ -972,7 +973,7 @@ def search_graph_tool():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"图谱搜索失败: {str(e)}")
|
||||
logger.error(f"Graph search failed: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
@ -983,9 +984,9 @@ def search_graph_tool():
|
|||
@report_bp.route('/tools/statistics', methods=['POST'])
|
||||
def get_graph_statistics_tool():
|
||||
"""
|
||||
图谱统计工具接口(供调试使用)
|
||||
|
||||
请求(JSON):
|
||||
Graph statistics tool endpoint (for debugging)
|
||||
|
||||
Request (JSON):
|
||||
{
|
||||
"graph_id": "mirofish_xxxx"
|
||||
}
|
||||
|
|
@ -1012,7 +1013,7 @@ def get_graph_statistics_tool():
|
|||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取图谱统计失败: {str(e)}")
|
||||
logger.error(f"Failed to get graph statistics: {str(e)}")
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,55 +1,55 @@
|
|||
"""
|
||||
配置管理
|
||||
统一从项目根目录的 .env 文件加载配置
|
||||
Configuration management
|
||||
Loads config uniformly from the .env file at the project root
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载项目根目录的 .env 文件
|
||||
# 路径: MiroFish/.env (相对于 backend/app/config.py)
|
||||
# Load the .env file from the project root
|
||||
# Path: MiroFish/.env (relative to backend/app/config.py)
|
||||
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
|
||||
|
||||
if os.path.exists(project_root_env):
|
||||
load_dotenv(project_root_env, override=True)
|
||||
else:
|
||||
# 如果根目录没有 .env,尝试加载环境变量(用于生产环境)
|
||||
# If no root-level .env file found, load from environment variables (production)
|
||||
load_dotenv(override=True)
|
||||
|
||||
|
||||
class Config:
|
||||
"""Flask配置类"""
|
||||
|
||||
# Flask配置
|
||||
"""Flask configuration class"""
|
||||
|
||||
# Flask settings
|
||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
||||
DEMO_PASSWORD = os.environ.get('DEMO_PASSWORD', '')
|
||||
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
||||
|
||||
# JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式)
|
||||
|
||||
# JSON settings - disable ASCII escaping so non-ASCII chars are output directly (not as \uXXXX)
|
||||
JSON_AS_ASCII = False
|
||||
|
||||
# LLM配置(统一使用OpenAI格式)
|
||||
|
||||
# LLM settings (unified OpenAI-compatible format)
|
||||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||
|
||||
# Zep配置
|
||||
# Zep settings
|
||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY')
|
||||
|
||||
# 文件上传配置
|
||||
|
||||
# File upload settings
|
||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
|
||||
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
|
||||
|
||||
# 文本处理配置
|
||||
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
||||
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
||||
|
||||
# OASIS模拟配置
|
||||
# Text processing settings
|
||||
DEFAULT_CHUNK_SIZE = 500 # default chunk size
|
||||
DEFAULT_CHUNK_OVERLAP = 50 # default overlap size
|
||||
|
||||
# OASIS simulation settings
|
||||
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
||||
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
||||
|
||||
# OASIS平台可用动作配置
|
||||
# OASIS platform available actions
|
||||
OASIS_TWITTER_ACTIONS = [
|
||||
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
||||
]
|
||||
|
|
@ -59,18 +59,18 @@ class Config:
|
|||
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
||||
]
|
||||
|
||||
# Report Agent配置
|
||||
# Report Agent settings
|
||||
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
||||
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
||||
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
||||
|
||||
@classmethod
|
||||
def validate(cls):
|
||||
"""验证必要配置"""
|
||||
"""Validate required configuration"""
|
||||
errors = []
|
||||
if not cls.LLM_API_KEY:
|
||||
errors.append("LLM_API_KEY 未配置")
|
||||
errors.append("LLM_API_KEY is not configured")
|
||||
if not cls.ZEP_API_KEY:
|
||||
errors.append("ZEP_API_KEY 未配置")
|
||||
errors.append("ZEP_API_KEY is not configured")
|
||||
return errors
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
数据模型模块
|
||||
Data models module
|
||||
"""
|
||||
|
||||
from .task import TaskManager, TaskStatus
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
项目上下文管理
|
||||
用于在服务端持久化项目状态,避免前端在接口间传递大量数据
|
||||
Project context management
|
||||
Persists project state server-side so the frontend does not need to pass large amounts of data between endpoints.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -15,45 +15,45 @@ from ..config import Config
|
|||
|
||||
|
||||
class ProjectStatus(str, Enum):
|
||||
"""项目状态"""
|
||||
CREATED = "created" # 刚创建,文件已上传
|
||||
ONTOLOGY_GENERATED = "ontology_generated" # 本体已生成
|
||||
GRAPH_BUILDING = "graph_building" # 图谱构建中
|
||||
GRAPH_COMPLETED = "graph_completed" # 图谱构建完成
|
||||
FAILED = "failed" # 失败
|
||||
"""Project status"""
|
||||
CREATED = "created" # Just created; files uploaded
|
||||
ONTOLOGY_GENERATED = "ontology_generated" # Ontology generated
|
||||
GRAPH_BUILDING = "graph_building" # Graph building in progress
|
||||
GRAPH_COMPLETED = "graph_completed" # Graph build complete
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Project:
|
||||
"""项目数据模型"""
|
||||
"""Project data model"""
|
||||
project_id: str
|
||||
name: str
|
||||
status: ProjectStatus
|
||||
created_at: str
|
||||
updated_at: str
|
||||
|
||||
# 文件信息
|
||||
|
||||
# File info
|
||||
files: List[Dict[str, str]] = field(default_factory=list) # [{filename, path, size}]
|
||||
total_text_length: int = 0
|
||||
|
||||
# 本体信息(接口1生成后填充)
|
||||
|
||||
# Ontology info (populated after endpoint 1)
|
||||
ontology: Optional[Dict[str, Any]] = None
|
||||
analysis_summary: Optional[str] = None
|
||||
|
||||
# 图谱信息(接口2完成后填充)
|
||||
|
||||
# Graph info (populated after endpoint 2 completes)
|
||||
graph_id: Optional[str] = None
|
||||
graph_build_task_id: Optional[str] = None
|
||||
|
||||
# 配置
|
||||
|
||||
# Configuration
|
||||
simulation_requirement: Optional[str] = None
|
||||
chunk_size: int = 500
|
||||
chunk_overlap: int = 50
|
||||
|
||||
# 错误信息
|
||||
|
||||
# Error info
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"project_id": self.project_id,
|
||||
"name": self.name,
|
||||
|
|
@ -74,7 +74,7 @@ class Project:
|
|||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Project':
|
||||
"""从字典创建"""
|
||||
"""Create from dictionary"""
|
||||
status = data.get('status', 'created')
|
||||
if isinstance(status, str):
|
||||
status = ProjectStatus(status)
|
||||
|
|
@ -99,52 +99,52 @@ class Project:
|
|||
|
||||
|
||||
class ProjectManager:
|
||||
"""项目管理器 - 负责项目的持久化存储和检索"""
|
||||
|
||||
# 项目存储根目录
|
||||
"""Project manager - handles persistent storage and retrieval of projects"""
|
||||
|
||||
# Root directory for project storage
|
||||
PROJECTS_DIR = os.path.join(Config.UPLOAD_FOLDER, 'projects')
|
||||
|
||||
|
||||
@classmethod
|
||||
def _ensure_projects_dir(cls):
|
||||
"""确保项目目录存在"""
|
||||
"""Ensure the projects directory exists"""
|
||||
os.makedirs(cls.PROJECTS_DIR, exist_ok=True)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_project_dir(cls, project_id: str) -> str:
|
||||
"""获取项目目录路径"""
|
||||
"""Get project directory path"""
|
||||
return os.path.join(cls.PROJECTS_DIR, project_id)
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_project_meta_path(cls, project_id: str) -> str:
|
||||
"""获取项目元数据文件路径"""
|
||||
"""Get project metadata file path"""
|
||||
return os.path.join(cls._get_project_dir(project_id), 'project.json')
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_project_files_dir(cls, project_id: str) -> str:
|
||||
"""获取项目文件存储目录"""
|
||||
"""Get project files storage directory"""
|
||||
return os.path.join(cls._get_project_dir(project_id), 'files')
|
||||
|
||||
|
||||
@classmethod
|
||||
def _get_project_text_path(cls, project_id: str) -> str:
|
||||
"""获取项目提取文本存储路径"""
|
||||
"""Get path for storing the extracted project text"""
|
||||
return os.path.join(cls._get_project_dir(project_id), 'extracted_text.txt')
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_project(cls, name: str = "Unnamed Project") -> Project:
|
||||
"""
|
||||
创建新项目
|
||||
|
||||
Create a new project.
|
||||
|
||||
Args:
|
||||
name: 项目名称
|
||||
|
||||
name: project name
|
||||
|
||||
Returns:
|
||||
新创建的Project对象
|
||||
newly created Project object
|
||||
"""
|
||||
cls._ensure_projects_dir()
|
||||
|
||||
|
||||
project_id = f"proj_{uuid.uuid4().hex[:12]}"
|
||||
now = datetime.now().isoformat()
|
||||
|
||||
|
||||
project = Project(
|
||||
project_id=project_id,
|
||||
name=name,
|
||||
|
|
@ -152,21 +152,21 @@ class ProjectManager:
|
|||
created_at=now,
|
||||
updated_at=now
|
||||
)
|
||||
|
||||
# 创建项目目录结构
|
||||
|
||||
# Create project directory structure
|
||||
project_dir = cls._get_project_dir(project_id)
|
||||
files_dir = cls._get_project_files_dir(project_id)
|
||||
os.makedirs(project_dir, exist_ok=True)
|
||||
os.makedirs(files_dir, exist_ok=True)
|
||||
|
||||
# 保存项目元数据
|
||||
|
||||
# Save project metadata
|
||||
cls.save_project(project)
|
||||
|
||||
|
||||
return project
|
||||
|
||||
|
||||
@classmethod
|
||||
def save_project(cls, project: Project) -> None:
|
||||
"""保存项目元数据"""
|
||||
"""Save project metadata"""
|
||||
project.updated_at = datetime.now().isoformat()
|
||||
meta_path = cls._get_project_meta_path(project.project_id)
|
||||
|
||||
|
|
@ -176,13 +176,13 @@ class ProjectManager:
|
|||
@classmethod
|
||||
def get_project(cls, project_id: str) -> Optional[Project]:
|
||||
"""
|
||||
获取项目
|
||||
|
||||
Get a project.
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
|
||||
project_id: project ID
|
||||
|
||||
Returns:
|
||||
Project对象,如果不存在返回None
|
||||
Project object, or None if not found
|
||||
"""
|
||||
meta_path = cls._get_project_meta_path(project_id)
|
||||
|
||||
|
|
@ -197,23 +197,23 @@ class ProjectManager:
|
|||
@classmethod
|
||||
def list_projects(cls, limit: int = 50) -> List[Project]:
|
||||
"""
|
||||
列出所有项目
|
||||
|
||||
List all projects.
|
||||
|
||||
Args:
|
||||
limit: 返回数量限制
|
||||
|
||||
limit: result count limit
|
||||
|
||||
Returns:
|
||||
项目列表,按创建时间倒序
|
||||
list of projects sorted by creation time, descending
|
||||
"""
|
||||
cls._ensure_projects_dir()
|
||||
|
||||
|
||||
projects = []
|
||||
for project_id in os.listdir(cls.PROJECTS_DIR):
|
||||
project = cls.get_project(project_id)
|
||||
if project:
|
||||
projects.append(project)
|
||||
|
||||
# 按创建时间倒序排序
|
||||
|
||||
# Sort by creation time, descending
|
||||
projects.sort(key=lambda p: p.created_at, reverse=True)
|
||||
|
||||
return projects[:limit]
|
||||
|
|
@ -221,13 +221,13 @@ class ProjectManager:
|
|||
@classmethod
|
||||
def delete_project(cls, project_id: str) -> bool:
|
||||
"""
|
||||
删除项目及其所有文件
|
||||
|
||||
Delete a project and all its files.
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
|
||||
project_id: project ID
|
||||
|
||||
Returns:
|
||||
是否删除成功
|
||||
True if successfully deleted
|
||||
"""
|
||||
project_dir = cls._get_project_dir(project_id)
|
||||
|
||||
|
|
@ -240,28 +240,28 @@ class ProjectManager:
|
|||
@classmethod
|
||||
def save_file_to_project(cls, project_id: str, file_storage, original_filename: str) -> Dict[str, str]:
|
||||
"""
|
||||
保存上传的文件到项目目录
|
||||
|
||||
Save an uploaded file to the project directory.
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
file_storage: Flask的FileStorage对象
|
||||
original_filename: 原始文件名
|
||||
|
||||
project_id: project ID
|
||||
file_storage: Flask FileStorage object
|
||||
original_filename: original filename
|
||||
|
||||
Returns:
|
||||
文件信息字典 {filename, path, size}
|
||||
file info dict {filename, path, size}
|
||||
"""
|
||||
files_dir = cls._get_project_files_dir(project_id)
|
||||
os.makedirs(files_dir, exist_ok=True)
|
||||
|
||||
# 生成安全的文件名
|
||||
|
||||
# Generate a safe filename
|
||||
ext = os.path.splitext(original_filename)[1].lower()
|
||||
safe_filename = f"{uuid.uuid4().hex[:8]}{ext}"
|
||||
file_path = os.path.join(files_dir, safe_filename)
|
||||
|
||||
# 保存文件
|
||||
|
||||
# Save file
|
||||
file_storage.save(file_path)
|
||||
|
||||
# 获取文件大小
|
||||
|
||||
# Get file size
|
||||
file_size = os.path.getsize(file_path)
|
||||
|
||||
return {
|
||||
|
|
@ -273,14 +273,14 @@ class ProjectManager:
|
|||
|
||||
@classmethod
|
||||
def save_extracted_text(cls, project_id: str, text: str) -> None:
|
||||
"""保存提取的文本"""
|
||||
"""Save extracted text"""
|
||||
text_path = cls._get_project_text_path(project_id)
|
||||
with open(text_path, 'w', encoding='utf-8') as f:
|
||||
f.write(text)
|
||||
|
||||
@classmethod
|
||||
def get_extracted_text(cls, project_id: str) -> Optional[str]:
|
||||
"""获取提取的文本"""
|
||||
"""Get extracted text"""
|
||||
text_path = cls._get_project_text_path(project_id)
|
||||
|
||||
if not os.path.exists(text_path):
|
||||
|
|
@ -291,7 +291,7 @@ class ProjectManager:
|
|||
|
||||
@classmethod
|
||||
def get_project_files(cls, project_id: str) -> List[str]:
|
||||
"""获取项目的所有文件路径"""
|
||||
"""Get all file paths for a project"""
|
||||
files_dir = cls._get_project_files_dir(project_id)
|
||||
|
||||
if not os.path.exists(files_dir):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
任务状态管理
|
||||
用于跟踪长时间运行的任务(如图谱构建)
|
||||
Task state management
|
||||
Used to track long-running tasks (e.g. graph building).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
|
|
@ -14,30 +14,30 @@ from ..utils.locale import t
|
|||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""任务状态枚举"""
|
||||
PENDING = "pending" # 等待中
|
||||
PROCESSING = "processing" # 处理中
|
||||
COMPLETED = "completed" # 已完成
|
||||
FAILED = "failed" # 失败
|
||||
"""Task status enum"""
|
||||
PENDING = "pending" # Waiting
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""任务数据类"""
|
||||
"""Task data class"""
|
||||
task_id: str
|
||||
task_type: str
|
||||
status: TaskStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
progress: int = 0 # 总进度百分比 0-100
|
||||
message: str = "" # 状态消息
|
||||
result: Optional[Dict] = None # 任务结果
|
||||
error: Optional[str] = None # 错误信息
|
||||
metadata: Dict = field(default_factory=dict) # 额外元数据
|
||||
progress_detail: Dict = field(default_factory=dict) # 详细进度信息
|
||||
|
||||
progress: int = 0 # Total progress percentage 0-100
|
||||
message: str = "" # Status message
|
||||
result: Optional[Dict] = None # Task result
|
||||
error: Optional[str] = None # Error info
|
||||
metadata: Dict = field(default_factory=dict) # Extra metadata
|
||||
progress_detail: Dict = field(default_factory=dict) # Detailed progress info
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"task_id": self.task_id,
|
||||
"task_type": self.task_type,
|
||||
|
|
@ -55,15 +55,15 @@ class Task:
|
|||
|
||||
class TaskManager:
|
||||
"""
|
||||
任务管理器
|
||||
线程安全的任务状态管理
|
||||
Task manager
|
||||
Thread-safe task state management
|
||||
"""
|
||||
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
def __new__(cls):
|
||||
"""单例模式"""
|
||||
"""Singleton pattern"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
|
|
@ -74,14 +74,14 @@ class TaskManager:
|
|||
|
||||
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
创建新任务
|
||||
|
||||
Create a new task.
|
||||
|
||||
Args:
|
||||
task_type: 任务类型
|
||||
metadata: 额外元数据
|
||||
|
||||
task_type: task type
|
||||
metadata: extra metadata
|
||||
|
||||
Returns:
|
||||
任务ID
|
||||
task ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
now = datetime.now()
|
||||
|
|
@ -101,7 +101,7 @@ class TaskManager:
|
|||
return task_id
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Task]:
|
||||
"""获取任务"""
|
||||
"""Get a task"""
|
||||
with self._task_lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
|
|
@ -116,16 +116,16 @@ class TaskManager:
|
|||
progress_detail: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
更新任务状态
|
||||
|
||||
Update task status.
|
||||
|
||||
Args:
|
||||
task_id: 任务ID
|
||||
status: 新状态
|
||||
progress: 进度
|
||||
message: 消息
|
||||
result: 结果
|
||||
error: 错误信息
|
||||
progress_detail: 详细进度信息
|
||||
task_id: task ID
|
||||
status: new status
|
||||
progress: progress
|
||||
message: message
|
||||
result: result
|
||||
error: error info
|
||||
progress_detail: detailed progress info
|
||||
"""
|
||||
with self._task_lock:
|
||||
task = self._tasks.get(task_id)
|
||||
|
|
@ -145,7 +145,7 @@ class TaskManager:
|
|||
task.progress_detail = progress_detail
|
||||
|
||||
def complete_task(self, task_id: str, result: Dict):
|
||||
"""标记任务完成"""
|
||||
"""Mark task as complete"""
|
||||
self.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.COMPLETED,
|
||||
|
|
@ -155,7 +155,7 @@ class TaskManager:
|
|||
)
|
||||
|
||||
def fail_task(self, task_id: str, error: str):
|
||||
"""标记任务失败"""
|
||||
"""Mark task as failed"""
|
||||
self.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
|
|
@ -164,7 +164,7 @@ class TaskManager:
|
|||
)
|
||||
|
||||
def list_tasks(self, task_type: Optional[str] = None) -> list:
|
||||
"""列出任务"""
|
||||
"""List tasks"""
|
||||
with self._task_lock:
|
||||
tasks = list(self._tasks.values())
|
||||
if task_type:
|
||||
|
|
@ -172,7 +172,7 @@ class TaskManager:
|
|||
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
|
||||
|
||||
def cleanup_old_tasks(self, max_age_hours: int = 24):
|
||||
"""清理旧任务"""
|
||||
"""Clean up old tasks"""
|
||||
from datetime import timedelta
|
||||
cutoff = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
业务服务模块
|
||||
Business services module
|
||||
"""
|
||||
|
||||
from .ontology_generator import OntologyGenerator
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
图谱构建服务
|
||||
接口2:使用Zep API构建Standalone Graph
|
||||
Graph building service
|
||||
Endpoint 2: Build a Standalone Graph using the Zep API
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -22,7 +22,7 @@ from ..utils.locale import t, get_locale, set_locale
|
|||
|
||||
@dataclass
|
||||
class GraphInfo:
|
||||
"""图谱信息"""
|
||||
"""Graph info"""
|
||||
graph_id: str
|
||||
node_count: int
|
||||
edge_count: int
|
||||
|
|
@ -39,14 +39,14 @@ class GraphInfo:
|
|||
|
||||
class GraphBuilderService:
|
||||
"""
|
||||
图谱构建服务
|
||||
负责调用Zep API构建知识图谱
|
||||
Graph building service
|
||||
Responsible for calling the Zep API to build the knowledge graph.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
raise ValueError("ZEP_API_KEY is not configured")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
self.task_manager = TaskManager()
|
||||
|
|
@ -61,20 +61,20 @@ class GraphBuilderService:
|
|||
batch_size: int = 3
|
||||
) -> str:
|
||||
"""
|
||||
异步构建图谱
|
||||
|
||||
Build the graph asynchronously.
|
||||
|
||||
Args:
|
||||
text: 输入文本
|
||||
ontology: 本体定义(来自接口1的输出)
|
||||
graph_name: 图谱名称
|
||||
chunk_size: 文本块大小
|
||||
chunk_overlap: 块重叠大小
|
||||
batch_size: 每批发送的块数量
|
||||
|
||||
text: input text
|
||||
ontology: ontology definition (output from endpoint 1)
|
||||
graph_name: graph name
|
||||
chunk_size: text chunk size
|
||||
chunk_overlap: chunk overlap size
|
||||
batch_size: number of chunks per batch
|
||||
|
||||
Returns:
|
||||
任务ID
|
||||
task ID
|
||||
"""
|
||||
# 创建任务
|
||||
# Create task
|
||||
task_id = self.task_manager.create_task(
|
||||
task_type="graph_build",
|
||||
metadata={
|
||||
|
|
@ -87,7 +87,7 @@ class GraphBuilderService:
|
|||
# Capture locale before spawning background thread
|
||||
current_locale = get_locale()
|
||||
|
||||
# 在后台线程中执行构建
|
||||
# Run build in background thread
|
||||
thread = threading.Thread(
|
||||
target=self._build_graph_worker,
|
||||
args=(task_id, text, ontology, graph_name, chunk_size, chunk_overlap, batch_size, current_locale)
|
||||
|
|
@ -108,7 +108,7 @@ class GraphBuilderService:
|
|||
batch_size: int,
|
||||
locale: str = 'zh'
|
||||
):
|
||||
"""图谱构建工作线程"""
|
||||
"""Graph build worker thread"""
|
||||
set_locale(locale)
|
||||
try:
|
||||
self.task_manager.update_task(
|
||||
|
|
@ -118,7 +118,7 @@ class GraphBuilderService:
|
|||
message=t('progress.startBuildingGraph')
|
||||
)
|
||||
|
||||
# 1. 创建图谱
|
||||
# 1. Create graph
|
||||
graph_id = self.create_graph(graph_name)
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -126,7 +126,7 @@ class GraphBuilderService:
|
|||
message=t('progress.graphCreated', graphId=graph_id)
|
||||
)
|
||||
|
||||
# 2. 设置本体
|
||||
# 2. Set ontology
|
||||
self.set_ontology(graph_id, ontology)
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
|
|
@ -134,7 +134,7 @@ class GraphBuilderService:
|
|||
message=t('progress.ontologySet')
|
||||
)
|
||||
|
||||
# 3. 文本分块
|
||||
# 3. Split text into chunks
|
||||
chunks = TextProcessor.split_text(text, chunk_size, chunk_overlap)
|
||||
total_chunks = len(chunks)
|
||||
self.task_manager.update_task(
|
||||
|
|
@ -143,7 +143,7 @@ class GraphBuilderService:
|
|||
message=t('progress.textSplit', count=total_chunks)
|
||||
)
|
||||
|
||||
# 4. 分批发送数据
|
||||
# 4. Send data in batches
|
||||
episode_uuids = self.add_text_batches(
|
||||
graph_id, chunks, batch_size,
|
||||
lambda msg, prog: self.task_manager.update_task(
|
||||
|
|
@ -153,7 +153,7 @@ class GraphBuilderService:
|
|||
)
|
||||
)
|
||||
|
||||
# 5. 等待Zep处理完成
|
||||
# 5. Wait for Zep processing to complete
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=60,
|
||||
|
|
@ -169,7 +169,7 @@ class GraphBuilderService:
|
|||
)
|
||||
)
|
||||
|
||||
# 6. 获取图谱信息
|
||||
# 6. Fetch graph info
|
||||
self.task_manager.update_task(
|
||||
task_id,
|
||||
progress=90,
|
||||
|
|
@ -178,7 +178,7 @@ class GraphBuilderService:
|
|||
|
||||
graph_info = self._get_graph_info(graph_id)
|
||||
|
||||
# 完成
|
||||
# Complete
|
||||
self.task_manager.complete_task(task_id, {
|
||||
"graph_id": graph_id,
|
||||
"graph_info": graph_info.to_dict(),
|
||||
|
|
@ -191,7 +191,7 @@ class GraphBuilderService:
|
|||
self.task_manager.fail_task(task_id, error_msg)
|
||||
|
||||
def create_graph(self, name: str) -> str:
|
||||
"""创建Zep图谱(公开方法)"""
|
||||
"""Create a Zep graph (public method)"""
|
||||
graph_id = f"mirofish_{uuid.uuid4().hex[:16]}"
|
||||
|
||||
self.client.graph.create(
|
||||
|
|
@ -203,74 +203,74 @@ class GraphBuilderService:
|
|||
return graph_id
|
||||
|
||||
def set_ontology(self, graph_id: str, ontology: Dict[str, Any]):
|
||||
"""设置图谱本体(公开方法)"""
|
||||
"""Set graph ontology (public method)"""
|
||||
import warnings
|
||||
from typing import Optional
|
||||
from pydantic import Field
|
||||
from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel
|
||||
|
||||
# 抑制 Pydantic v2 关于 Field(default=None) 的警告
|
||||
# 这是 Zep SDK 要求的用法,警告来自动态类创建,可以安全忽略
|
||||
|
||||
# Suppress Pydantic v2 warnings about Field(default=None)
|
||||
# This is the usage required by the Zep SDK; warnings come from dynamic class creation and can be safely ignored
|
||||
warnings.filterwarnings('ignore', category=UserWarning, module='pydantic')
|
||||
|
||||
# Zep 保留名称,不能作为属性名
|
||||
|
||||
# Zep reserved names that cannot be used as attribute names
|
||||
RESERVED_NAMES = {'uuid', 'name', 'group_id', 'name_embedding', 'summary', 'created_at'}
|
||||
|
||||
|
||||
def safe_attr_name(attr_name: str) -> str:
|
||||
"""将保留名称转换为安全名称"""
|
||||
"""Convert reserved names to safe attribute names"""
|
||||
if attr_name.lower() in RESERVED_NAMES:
|
||||
return f"entity_{attr_name}"
|
||||
return attr_name
|
||||
|
||||
# 动态创建实体类型
|
||||
# Dynamically create entity types
|
||||
entity_types = {}
|
||||
for entity_def in ontology.get("entity_types", []):
|
||||
name = entity_def["name"]
|
||||
description = entity_def.get("description", f"A {name} entity.")
|
||||
|
||||
# 创建属性字典和类型注解(Pydantic v2 需要)
|
||||
|
||||
# Build attribute dict and type annotations (required by Pydantic v2)
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
|
||||
for attr_def in entity_def.get("attributes", []):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_name = safe_attr_name(attr_def["name"]) # Use safe name
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
# Zep API requires Field description — this is mandatory
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[EntityText] # 类型注解
|
||||
|
||||
annotations[attr_name] = Optional[EntityText] # Type annotation
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
# 动态创建类
|
||||
|
||||
# Dynamically create class
|
||||
entity_class = type(name, (EntityModel,), attrs)
|
||||
entity_class.__doc__ = description
|
||||
entity_types[name] = entity_class
|
||||
|
||||
# 动态创建边类型
|
||||
# Dynamically create edge types
|
||||
edge_definitions = {}
|
||||
for edge_def in ontology.get("edge_types", []):
|
||||
name = edge_def["name"]
|
||||
description = edge_def.get("description", f"A {name} relationship.")
|
||||
|
||||
# 创建属性字典和类型注解
|
||||
|
||||
# Build attribute dict and type annotations
|
||||
attrs = {"__doc__": description}
|
||||
annotations = {}
|
||||
|
||||
|
||||
for attr_def in edge_def.get("attributes", []):
|
||||
attr_name = safe_attr_name(attr_def["name"]) # 使用安全名称
|
||||
attr_name = safe_attr_name(attr_def["name"]) # Use safe name
|
||||
attr_desc = attr_def.get("description", attr_name)
|
||||
# Zep API 需要 Field 的 description,这是必需的
|
||||
# Zep API requires Field description — this is mandatory
|
||||
attrs[attr_name] = Field(description=attr_desc, default=None)
|
||||
annotations[attr_name] = Optional[str] # 边属性用str类型
|
||||
|
||||
annotations[attr_name] = Optional[str] # Edge attributes use str type
|
||||
|
||||
attrs["__annotations__"] = annotations
|
||||
|
||||
# 动态创建类
|
||||
|
||||
# Dynamically create class
|
||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
edge_class = type(class_name, (EdgeModel,), attrs)
|
||||
edge_class.__doc__ = description
|
||||
|
||||
# 构建source_targets
|
||||
# Build source_targets
|
||||
source_targets = []
|
||||
for st in edge_def.get("source_targets", []):
|
||||
source_targets.append(
|
||||
|
|
@ -283,7 +283,7 @@ class GraphBuilderService:
|
|||
if source_targets:
|
||||
edge_definitions[name] = (edge_class, source_targets)
|
||||
|
||||
# 调用Zep API设置本体
|
||||
# Call Zep API to set ontology
|
||||
if entity_types or edge_definitions:
|
||||
self.client.graph.set_ontology(
|
||||
graph_ids=[graph_id],
|
||||
|
|
@ -298,7 +298,7 @@ class GraphBuilderService:
|
|||
batch_size: int = 3,
|
||||
progress_callback: Optional[Callable] = None
|
||||
) -> List[str]:
|
||||
"""分批添加文本到图谱,返回所有 episode 的 uuid 列表"""
|
||||
"""Add text to the graph in batches; returns a list of all episode UUIDs"""
|
||||
episode_uuids = []
|
||||
total_chunks = len(chunks)
|
||||
|
||||
|
|
@ -314,27 +314,27 @@ class GraphBuilderService:
|
|||
progress
|
||||
)
|
||||
|
||||
# 构建episode数据
|
||||
# Build episode data
|
||||
episodes = [
|
||||
EpisodeData(data=chunk, type="text")
|
||||
for chunk in batch_chunks
|
||||
]
|
||||
|
||||
# 发送到Zep
|
||||
# Send to Zep
|
||||
try:
|
||||
batch_result = self.client.graph.add_batch(
|
||||
graph_id=graph_id,
|
||||
episodes=episodes
|
||||
)
|
||||
|
||||
# 收集返回的 episode uuid
|
||||
# Collect returned episode UUIDs
|
||||
if batch_result and isinstance(batch_result, list):
|
||||
for ep in batch_result:
|
||||
ep_uuid = getattr(ep, 'uuid_', None) or getattr(ep, 'uuid', None)
|
||||
if ep_uuid:
|
||||
episode_uuids.append(ep_uuid)
|
||||
|
||||
# 避免请求过快
|
||||
# Avoid sending requests too quickly
|
||||
time.sleep(1)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -350,7 +350,7 @@ class GraphBuilderService:
|
|||
progress_callback: Optional[Callable] = None,
|
||||
timeout: int = 600
|
||||
):
|
||||
"""等待所有 episode 处理完成(通过查询每个 episode 的 processed 状态)"""
|
||||
"""Wait for all episodes to finish processing (by polling each episode's processed status)"""
|
||||
if not episode_uuids:
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.noEpisodesWait'), 1.0)
|
||||
|
|
@ -373,42 +373,42 @@ class GraphBuilderService:
|
|||
)
|
||||
break
|
||||
|
||||
# 检查每个 episode 的处理状态
|
||||
# Check processing status of each episode
|
||||
for ep_uuid in list(pending_episodes):
|
||||
try:
|
||||
episode = self.client.graph.episode.get(uuid_=ep_uuid)
|
||||
is_processed = getattr(episode, 'processed', False)
|
||||
|
||||
|
||||
if is_processed:
|
||||
pending_episodes.remove(ep_uuid)
|
||||
completed_count += 1
|
||||
|
||||
|
||||
except Exception as e:
|
||||
# 忽略单个查询错误,继续
|
||||
# Ignore individual query errors and continue
|
||||
pass
|
||||
|
||||
|
||||
elapsed = int(time.time() - start_time)
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
t('progress.zepProcessing', completed=completed_count, total=total_episodes, pending=len(pending_episodes), elapsed=elapsed),
|
||||
completed_count / total_episodes if total_episodes > 0 else 0
|
||||
)
|
||||
|
||||
|
||||
if pending_episodes:
|
||||
time.sleep(3) # 每3秒检查一次
|
||||
time.sleep(3) # Check every 3 seconds
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(t('progress.processingComplete', completed=completed_count, total=total_episodes), 1.0)
|
||||
|
||||
def _get_graph_info(self, graph_id: str) -> GraphInfo:
|
||||
"""获取图谱信息"""
|
||||
# 获取节点(分页)
|
||||
"""Retrieve graph info"""
|
||||
# Fetch nodes (paginated)
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
# 获取边(分页)
|
||||
# Fetch edges (paginated)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 统计实体类型
|
||||
# Count entity types
|
||||
entity_types = set()
|
||||
for node in nodes:
|
||||
if node.labels:
|
||||
|
|
@ -425,25 +425,25 @@ class GraphBuilderService:
|
|||
|
||||
def get_graph_data(self, graph_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取完整图谱数据(包含详细信息)
|
||||
|
||||
Retrieve full graph data (with detailed information).
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
|
||||
graph_id: graph ID
|
||||
|
||||
Returns:
|
||||
包含nodes和edges的字典,包括时间信息、属性等详细数据
|
||||
Dictionary containing nodes and edges with timestamps, attributes, and other details
|
||||
"""
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
# 创建节点映射用于获取节点名称
|
||||
# Build node map for looking up node names
|
||||
node_map = {}
|
||||
for node in nodes:
|
||||
node_map[node.uuid_] = node.name or ""
|
||||
|
||||
|
||||
nodes_data = []
|
||||
for node in nodes:
|
||||
# 获取创建时间
|
||||
# Get creation timestamp
|
||||
created_at = getattr(node, 'created_at', None)
|
||||
if created_at:
|
||||
created_at = str(created_at)
|
||||
|
|
@ -459,20 +459,20 @@ class GraphBuilderService:
|
|||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
# 获取时间信息
|
||||
# Get timestamps
|
||||
created_at = getattr(edge, 'created_at', None)
|
||||
valid_at = getattr(edge, 'valid_at', None)
|
||||
invalid_at = getattr(edge, 'invalid_at', None)
|
||||
expired_at = getattr(edge, 'expired_at', None)
|
||||
|
||||
# 获取 episodes
|
||||
|
||||
# Get episodes
|
||||
episodes = getattr(edge, 'episodes', None) or getattr(edge, 'episode_ids', None)
|
||||
if episodes and not isinstance(episodes, list):
|
||||
episodes = [str(episodes)]
|
||||
elif episodes:
|
||||
episodes = [str(e) for e in episodes]
|
||||
|
||||
# 获取 fact_type
|
||||
|
||||
# Get fact_type
|
||||
fact_type = getattr(edge, 'fact_type', None) or edge.name or ""
|
||||
|
||||
edges_data.append({
|
||||
|
|
@ -501,6 +501,6 @@ class GraphBuilderService:
|
|||
}
|
||||
|
||||
def delete_graph(self, graph_id: str):
|
||||
"""删除图谱"""
|
||||
"""Delete graph"""
|
||||
self.client.graph.delete(graph_id=graph_id)
|
||||
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
本体生成服务
|
||||
接口1:分析文本内容,生成适合社会模拟的实体和关系类型定义
|
||||
Ontology generation service
|
||||
Endpoint 1: Analyze text content and generate entity and relationship type definitions suitable for social simulation.
|
||||
"""
|
||||
|
||||
import json
|
||||
|
|
@ -14,174 +14,174 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def _to_pascal_case(name: str) -> str:
|
||||
"""将任意格式的名称转换为 PascalCase(如 'works_for' -> 'WorksFor', 'person' -> 'Person')"""
|
||||
# 按非字母数字字符分割
|
||||
"""Convert a name in any format to PascalCase (e.g. 'works_for' -> 'WorksFor', 'person' -> 'Person')"""
|
||||
# Split on non-alphanumeric characters
|
||||
parts = re.split(r'[^a-zA-Z0-9]+', name)
|
||||
# 再按 camelCase 边界分割(如 'camelCase' -> ['camel', 'Case'])
|
||||
# Also split on camelCase boundaries (e.g. 'camelCase' -> ['camel', 'Case'])
|
||||
words = []
|
||||
for part in parts:
|
||||
words.extend(re.sub(r'([a-z])([A-Z])', r'\1_\2', part).split('_'))
|
||||
# 每个词首字母大写,过滤空串
|
||||
# Capitalize each word and filter empty strings
|
||||
result = ''.join(word.capitalize() for word in words if word)
|
||||
return result if result else 'Unknown'
|
||||
|
||||
|
||||
# 本体生成的系统提示词
|
||||
ONTOLOGY_SYSTEM_PROMPT = """你是一个专业的知识图谱本体设计专家。你的任务是分析给定的文本内容和模拟需求,设计适合**社交媒体舆论模拟**的实体类型和关系类型。
|
||||
# System prompt for ontology generation
|
||||
ONTOLOGY_SYSTEM_PROMPT = """You are a professional knowledge graph ontology design expert. Your task is to analyze the given text content and simulation requirements, and design entity types and relationship types suitable for **social media opinion simulation**.
|
||||
|
||||
**重要:你必须输出有效的JSON格式数据,不要输出任何其他内容。**
|
||||
**Important: You must output valid JSON format data, and nothing else.**
|
||||
|
||||
## 核心任务背景
|
||||
## Core Task Background
|
||||
|
||||
我们正在构建一个**社交媒体舆论模拟系统**。在这个系统中:
|
||||
- 每个实体都是一个可以在社交媒体上发声、互动、传播信息的"账号"或"主体"
|
||||
- 实体之间会相互影响、转发、评论、回应
|
||||
- 我们需要模拟舆论事件中各方的反应和信息传播路径
|
||||
We are building a **social media opinion simulation system**. In this system:
|
||||
- Every entity is an "account" or "subject" that can speak out, interact, and spread information on social media
|
||||
- Entities influence each other, repost, comment, and respond
|
||||
- We need to simulate each party's reaction and the information propagation path during opinion events
|
||||
|
||||
因此,**实体必须是现实中真实存在的、可以在社媒上发声和互动的主体**:
|
||||
Therefore, **entities must be real-world subjects that exist and can speak out and interact on social media**:
|
||||
|
||||
**可以是**:
|
||||
- 具体的个人(公众人物、当事人、意见领袖、专家学者、普通人)
|
||||
- 公司、企业(包括其官方账号)
|
||||
- 组织机构(大学、协会、NGO、工会等)
|
||||
- 政府部门、监管机构
|
||||
- 媒体机构(报纸、电视台、自媒体、网站)
|
||||
- 社交媒体平台本身
|
||||
- 特定群体代表(如校友会、粉丝团、维权群体等)
|
||||
**Can be**:
|
||||
- Specific individuals (public figures, persons involved, opinion leaders, experts and scholars, ordinary people)
|
||||
- Companies and enterprises (including their official accounts)
|
||||
- Organizations (universities, associations, NGOs, unions, etc.)
|
||||
- Government departments and regulatory agencies
|
||||
- Media organizations (newspapers, TV stations, self-media, websites)
|
||||
- Social media platforms themselves
|
||||
- Representatives of specific groups (e.g. alumni associations, fan clubs, rights-protection groups, etc.)
|
||||
|
||||
**不可以是**:
|
||||
- 抽象概念(如"舆论"、"情绪"、"趋势")
|
||||
- 主题/话题(如"学术诚信"、"教育改革")
|
||||
- 观点/态度(如"支持方"、"反对方")
|
||||
**Cannot be**:
|
||||
- Abstract concepts (e.g. "public opinion", "emotion", "trend")
|
||||
- Topics/themes (e.g. "academic integrity", "education reform")
|
||||
- Viewpoints/attitudes (e.g. "supporters", "opponents")
|
||||
|
||||
## 输出格式
|
||||
## Output Format
|
||||
|
||||
请输出JSON格式,包含以下结构:
|
||||
Please output JSON format with the following structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"entity_types": [
|
||||
{
|
||||
"name": "实体类型名称(英文,PascalCase)",
|
||||
"description": "简短描述(英文,不超过100字符)",
|
||||
"name": "Entity type name (English, PascalCase)",
|
||||
"description": "Brief description (English, max 100 characters)",
|
||||
"attributes": [
|
||||
{
|
||||
"name": "属性名(英文,snake_case)",
|
||||
"name": "Attribute name (English, snake_case)",
|
||||
"type": "text",
|
||||
"description": "属性描述"
|
||||
"description": "Attribute description"
|
||||
}
|
||||
],
|
||||
"examples": ["示例实体1", "示例实体2"]
|
||||
"examples": ["Example entity 1", "Example entity 2"]
|
||||
}
|
||||
],
|
||||
"edge_types": [
|
||||
{
|
||||
"name": "关系类型名称(英文,UPPER_SNAKE_CASE)",
|
||||
"description": "简短描述(英文,不超过100字符)",
|
||||
"name": "Relationship type name (English, UPPER_SNAKE_CASE)",
|
||||
"description": "Brief description (English, max 100 characters)",
|
||||
"source_targets": [
|
||||
{"source": "源实体类型", "target": "目标实体类型"}
|
||||
{"source": "Source entity type", "target": "Target entity type"}
|
||||
],
|
||||
"attributes": []
|
||||
}
|
||||
],
|
||||
"analysis_summary": "对文本内容的简要分析说明"
|
||||
"analysis_summary": "Brief analysis summary of the text content"
|
||||
}
|
||||
```
|
||||
|
||||
## 设计指南(极其重要!)
|
||||
## Design Guidelines (Extremely Important!)
|
||||
|
||||
### 1. 实体类型设计 - 必须严格遵守
|
||||
### 1. Entity Type Design — Must Be Strictly Followed
|
||||
|
||||
**数量要求:必须正好10个实体类型**
|
||||
**Quantity requirement: exactly 10 entity types**
|
||||
|
||||
**层次结构要求(必须同时包含具体类型和兜底类型)**:
|
||||
**Hierarchy requirement (must include both specific types and fallback types)**:
|
||||
|
||||
你的10个实体类型必须包含以下层次:
|
||||
Your 10 entity types must include the following levels:
|
||||
|
||||
A. **兜底类型(必须包含,放在列表最后2个)**:
|
||||
- `Person`: 任何自然人个体的兜底类型。当一个人不属于其他更具体的人物类型时,归入此类。
|
||||
- `Organization`: 任何组织机构的兜底类型。当一个组织不属于其他更具体的组织类型时,归入此类。
|
||||
A. **Fallback types (required, placed as the last 2 in the list)**:
|
||||
- `Person`: Fallback type for any individual person. Use this when a person does not fit any other more specific person type.
|
||||
- `Organization`: Fallback type for any organization. Use this when an organization does not fit any other more specific organization type.
|
||||
|
||||
B. **具体类型(8个,根据文本内容设计)**:
|
||||
- 针对文本中出现的主要角色,设计更具体的类型
|
||||
- 例如:如果文本涉及学术事件,可以有 `Student`, `Professor`, `University`
|
||||
- 例如:如果文本涉及商业事件,可以有 `Company`, `CEO`, `Employee`
|
||||
B. **Specific types (8 types, designed based on text content)**:
|
||||
- Design more specific types for the main roles that appear in the text
|
||||
- Example: if the text involves an academic event, you might have `Student`, `Professor`, `University`
|
||||
- Example: if the text involves a business event, you might have `Company`, `CEO`, `Employee`
|
||||
|
||||
**为什么需要兜底类型**:
|
||||
- 文本中会出现各种人物,如"中小学教师"、"路人甲"、"某位网友"
|
||||
- 如果没有专门的类型匹配,他们应该被归入 `Person`
|
||||
- 同理,小型组织、临时团体等应该归入 `Organization`
|
||||
**Why fallback types are needed**:
|
||||
- Various people appear in text, such as "primary and secondary school teachers", "passersby", "some netizen"
|
||||
- If there is no dedicated type to match them, they should fall into `Person`
|
||||
- Similarly, small organizations, ad hoc groups, etc. should fall into `Organization`
|
||||
|
||||
**具体类型的设计原则**:
|
||||
- 从文本中识别出高频出现或关键的角色类型
|
||||
- 每个具体类型应该有明确的边界,避免重叠
|
||||
- description 必须清晰说明这个类型和兜底类型的区别
|
||||
**Principles for designing specific types**:
|
||||
- Identify high-frequency or key role types from the text
|
||||
- Each specific type should have clear boundaries and avoid overlap
|
||||
- The description must clearly explain the difference between this type and the fallback types
|
||||
|
||||
### 2. 关系类型设计
|
||||
### 2. Relationship Type Design
|
||||
|
||||
- 数量:6-10个
|
||||
- 关系应该反映社媒互动中的真实联系
|
||||
- 确保关系的 source_targets 涵盖你定义的实体类型
|
||||
- Quantity: 6-10
|
||||
- Relationships should reflect real connections in social media interactions
|
||||
- Ensure the source_targets in relationships cover the entity types you have defined
|
||||
|
||||
### 3. 属性设计
|
||||
### 3. Attribute Design
|
||||
|
||||
- 每个实体类型1-3个关键属性
|
||||
- **注意**:属性名不能使用 `name`、`uuid`、`group_id`、`created_at`、`summary`(这些是系统保留字)
|
||||
- 推荐使用:`full_name`, `title`, `role`, `position`, `location`, `description` 等
|
||||
- 1-3 key attributes per entity type
|
||||
- **Note**: Attribute names must not use `name`, `uuid`, `group_id`, `created_at`, `summary` (these are system reserved words)
|
||||
- Recommended: `full_name`, `title`, `role`, `position`, `location`, `description`, etc.
|
||||
|
||||
## 实体类型参考
|
||||
## Entity Type Reference
|
||||
|
||||
**个人类(具体)**:
|
||||
- Student: 学生
|
||||
- Professor: 教授/学者
|
||||
- Journalist: 记者
|
||||
- Celebrity: 明星/网红
|
||||
- Executive: 高管
|
||||
- Official: 政府官员
|
||||
- Lawyer: 律师
|
||||
- Doctor: 医生
|
||||
**Individual types (specific)**:
|
||||
- Student: student
|
||||
- Professor: professor/scholar
|
||||
- Journalist: journalist
|
||||
- Celebrity: celebrity/influencer
|
||||
- Executive: corporate executive
|
||||
- Official: government official
|
||||
- Lawyer: lawyer
|
||||
- Doctor: doctor
|
||||
|
||||
**个人类(兜底)**:
|
||||
- Person: 任何自然人(不属于上述具体类型时使用)
|
||||
**Individual types (fallback)**:
|
||||
- Person: any individual (use when not fitting the specific types above)
|
||||
|
||||
**组织类(具体)**:
|
||||
- University: 高校
|
||||
- Company: 公司企业
|
||||
- GovernmentAgency: 政府机构
|
||||
- MediaOutlet: 媒体机构
|
||||
- Hospital: 医院
|
||||
- School: 中小学
|
||||
- NGO: 非政府组织
|
||||
**Organization types (specific)**:
|
||||
- University: university/college
|
||||
- Company: company/enterprise
|
||||
- GovernmentAgency: government agency
|
||||
- MediaOutlet: media organization
|
||||
- Hospital: hospital
|
||||
- School: primary/secondary school
|
||||
- NGO: non-governmental organization
|
||||
|
||||
**组织类(兜底)**:
|
||||
- Organization: 任何组织机构(不属于上述具体类型时使用)
|
||||
**Organization types (fallback)**:
|
||||
- Organization: any organization (use when not fitting the specific types above)
|
||||
|
||||
## 关系类型参考
|
||||
## Relationship Type Reference
|
||||
|
||||
- WORKS_FOR: 工作于
|
||||
- STUDIES_AT: 就读于
|
||||
- AFFILIATED_WITH: 隶属于
|
||||
- REPRESENTS: 代表
|
||||
- REGULATES: 监管
|
||||
- REPORTS_ON: 报道
|
||||
- COMMENTS_ON: 评论
|
||||
- RESPONDS_TO: 回应
|
||||
- SUPPORTS: 支持
|
||||
- OPPOSES: 反对
|
||||
- COLLABORATES_WITH: 合作
|
||||
- COMPETES_WITH: 竞争
|
||||
- WORKS_FOR: works for
|
||||
- STUDIES_AT: studies at
|
||||
- AFFILIATED_WITH: affiliated with
|
||||
- REPRESENTS: represents
|
||||
- REGULATES: regulates
|
||||
- REPORTS_ON: reports on
|
||||
- COMMENTS_ON: comments on
|
||||
- RESPONDS_TO: responds to
|
||||
- SUPPORTS: supports
|
||||
- OPPOSES: opposes
|
||||
- COLLABORATES_WITH: collaborates with
|
||||
- COMPETES_WITH: competes with
|
||||
"""
|
||||
|
||||
|
||||
class OntologyGenerator:
|
||||
"""
|
||||
本体生成器
|
||||
分析文本内容,生成实体和关系类型定义
|
||||
Ontology generator
|
||||
Analyzes text content and generates entity and relationship type definitions.
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, llm_client: Optional[LLMClient] = None):
|
||||
self.llm_client = llm_client or LLMClient()
|
||||
|
||||
|
||||
def generate(
|
||||
self,
|
||||
document_texts: List[str],
|
||||
|
|
@ -189,107 +189,112 @@ class OntologyGenerator:
|
|||
additional_context: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
生成本体定义
|
||||
|
||||
Generate ontology definition.
|
||||
|
||||
Args:
|
||||
document_texts: 文档文本列表
|
||||
simulation_requirement: 模拟需求描述
|
||||
additional_context: 额外上下文
|
||||
|
||||
document_texts: list of document texts
|
||||
simulation_requirement: simulation requirement description
|
||||
additional_context: additional context
|
||||
|
||||
Returns:
|
||||
本体定义(entity_types, edge_types等)
|
||||
Ontology definition (entity_types, edge_types, etc.)
|
||||
"""
|
||||
# 构建用户消息
|
||||
user_message = self._build_user_message(
|
||||
document_texts,
|
||||
simulation_requirement,
|
||||
additional_context
|
||||
)
|
||||
|
||||
lang_instruction = get_language_instruction()
|
||||
system_prompt = f"{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
|
||||
|
||||
# Build user message
|
||||
user_message = self._build_user_message(
|
||||
document_texts,
|
||||
simulation_requirement,
|
||||
additional_context,
|
||||
lang_instruction
|
||||
)
|
||||
|
||||
system_prompt = f"LANGUAGE INSTRUCTION (HIGHEST PRIORITY — MUST BE FOLLOWED): {lang_instruction} All description fields, analysis_summary, and examples MUST be written in this language.\n\n{ONTOLOGY_SYSTEM_PROMPT}\n\n{lang_instruction}\nIMPORTANT: Entity type names MUST be in English PascalCase (e.g., 'PersonEntity', 'MediaOrganization'). Relationship type names MUST be in English UPPER_SNAKE_CASE (e.g., 'WORKS_FOR'). Attribute names MUST be in English snake_case. Only description fields and analysis_summary should use the specified language above."
|
||||
messages = [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_message}
|
||||
]
|
||||
|
||||
# 调用LLM
|
||||
|
||||
# Call LLM
|
||||
result = self.llm_client.chat_json(
|
||||
messages=messages,
|
||||
temperature=0.3,
|
||||
max_tokens=4096
|
||||
)
|
||||
|
||||
# 验证和后处理
|
||||
|
||||
# Validate and post-process
|
||||
result = self._validate_and_process(result)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
# 传给 LLM 的文本最大长度(5万字)
|
||||
|
||||
# Maximum text length passed to LLM (50,000 characters)
|
||||
MAX_TEXT_LENGTH_FOR_LLM = 50000
|
||||
|
||||
|
||||
def _build_user_message(
|
||||
self,
|
||||
document_texts: List[str],
|
||||
simulation_requirement: str,
|
||||
additional_context: Optional[str]
|
||||
additional_context: Optional[str],
|
||||
lang_instruction: str = ""
|
||||
) -> str:
|
||||
"""构建用户消息"""
|
||||
|
||||
# 合并文本
|
||||
"""Build user message"""
|
||||
|
||||
# Merge texts
|
||||
combined_text = "\n\n---\n\n".join(document_texts)
|
||||
original_length = len(combined_text)
|
||||
|
||||
# 如果文本超过5万字,截断(仅影响传给LLM的内容,不影响图谱构建)
|
||||
|
||||
# If text exceeds 50,000 characters, truncate (only affects what is passed to LLM, not graph building)
|
||||
if len(combined_text) > self.MAX_TEXT_LENGTH_FOR_LLM:
|
||||
combined_text = combined_text[:self.MAX_TEXT_LENGTH_FOR_LLM]
|
||||
combined_text += f"\n\n...(原文共{original_length}字,已截取前{self.MAX_TEXT_LENGTH_FOR_LLM}字用于本体分析)..."
|
||||
|
||||
message = f"""## 模拟需求
|
||||
combined_text += f"\n\n...(text truncated at {self.MAX_TEXT_LENGTH_FOR_LLM} chars out of {original_length} total)..."
|
||||
|
||||
message = f"""## Simulation requirement
|
||||
|
||||
{simulation_requirement}
|
||||
|
||||
## 文档内容
|
||||
## Document content
|
||||
|
||||
{combined_text}
|
||||
"""
|
||||
|
||||
|
||||
if additional_context:
|
||||
message += f"""
|
||||
## 额外说明
|
||||
## Additional context
|
||||
|
||||
{additional_context}
|
||||
"""
|
||||
|
||||
message += """
|
||||
请根据以上内容,设计适合社会舆论模拟的实体类型和关系类型。
|
||||
|
||||
**必须遵守的规则**:
|
||||
1. 必须正好输出10个实体类型
|
||||
2. 最后2个必须是兜底类型:Person(个人兜底)和 Organization(组织兜底)
|
||||
3. 前8个是根据文本内容设计的具体类型
|
||||
4. 所有实体类型必须是现实中可以发声的主体,不能是抽象概念
|
||||
5. 属性名不能使用 name、uuid、group_id 等保留字,用 full_name、org_name 等替代
|
||||
message += f"""
|
||||
Based on the content above, design entity types and relationship types suitable for social opinion simulation.
|
||||
|
||||
**Mandatory rules**:
|
||||
1. Output exactly 10 entity types
|
||||
2. The last 2 must be fallback types: Person (individual fallback) and Organization (organization fallback)
|
||||
3. The first 8 are specific types designed from the document content
|
||||
4. All entity types must be real-world subjects capable of speaking out, not abstract concepts
|
||||
5. Attribute names must not use reserved words: name, uuid, group_id — use full_name, org_name, etc. instead
|
||||
|
||||
{lang_instruction}
|
||||
"""
|
||||
|
||||
|
||||
return message
|
||||
|
||||
|
||||
def _validate_and_process(self, result: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""验证和后处理结果"""
|
||||
|
||||
# 确保必要字段存在
|
||||
"""Validate and post-process the result"""
|
||||
|
||||
# Ensure required fields exist
|
||||
if "entity_types" not in result:
|
||||
result["entity_types"] = []
|
||||
if "edge_types" not in result:
|
||||
result["edge_types"] = []
|
||||
if "analysis_summary" not in result:
|
||||
result["analysis_summary"] = ""
|
||||
|
||||
# 验证实体类型
|
||||
# 记录原始名称到 PascalCase 的映射,用于后续修正 edge 的 source_targets 引用
|
||||
|
||||
# Validate entity types
|
||||
# Record mapping from original name to PascalCase for fixing edge source_targets references later
|
||||
entity_name_map = {}
|
||||
for entity in result["entity_types"]:
|
||||
# 强制将 entity name 转为 PascalCase(Zep API 要求)
|
||||
# Force entity name to PascalCase (required by Zep API)
|
||||
if "name" in entity:
|
||||
original_name = entity["name"]
|
||||
entity["name"] = _to_pascal_case(original_name)
|
||||
|
|
@ -300,19 +305,19 @@ class OntologyGenerator:
|
|||
entity["attributes"] = []
|
||||
if "examples" not in entity:
|
||||
entity["examples"] = []
|
||||
# 确保description不超过100字符
|
||||
# Ensure description does not exceed 100 characters
|
||||
if len(entity.get("description", "")) > 100:
|
||||
entity["description"] = entity["description"][:97] + "..."
|
||||
|
||||
# 验证关系类型
|
||||
|
||||
# Validate relationship types
|
||||
for edge in result["edge_types"]:
|
||||
# 强制将 edge name 转为 SCREAMING_SNAKE_CASE(Zep API 要求)
|
||||
# Force edge name to SCREAMING_SNAKE_CASE (required by Zep API)
|
||||
if "name" in edge:
|
||||
original_name = edge["name"]
|
||||
edge["name"] = original_name.upper()
|
||||
if edge["name"] != original_name:
|
||||
logger.warning(f"Edge type name '{original_name}' auto-converted to '{edge['name']}'")
|
||||
# 修正 source_targets 中的实体名称引用,与转换后的 PascalCase 保持一致
|
||||
# Fix entity name references in source_targets to match converted PascalCase names
|
||||
for st in edge.get("source_targets", []):
|
||||
if st.get("source") in entity_name_map:
|
||||
st["source"] = entity_name_map[st["source"]]
|
||||
|
|
@ -324,12 +329,12 @@ class OntologyGenerator:
|
|||
edge["attributes"] = []
|
||||
if len(edge.get("description", "")) > 100:
|
||||
edge["description"] = edge["description"][:97] + "..."
|
||||
|
||||
# Zep API 限制:最多 10 个自定义实体类型,最多 10 个自定义边类型
|
||||
|
||||
# Zep API limit: maximum 10 custom entity types and 10 custom edge types
|
||||
MAX_ENTITY_TYPES = 10
|
||||
MAX_EDGE_TYPES = 10
|
||||
|
||||
# 去重:按 name 去重,保留首次出现的
|
||||
# Deduplicate: keep first occurrence by name
|
||||
seen_names = set()
|
||||
deduped = []
|
||||
for entity in result["entity_types"]:
|
||||
|
|
@ -341,7 +346,7 @@ class OntologyGenerator:
|
|||
logger.warning(f"Duplicate entity type '{name}' removed during validation")
|
||||
result["entity_types"] = deduped
|
||||
|
||||
# 兜底类型定义
|
||||
# Fallback type definitions
|
||||
person_fallback = {
|
||||
"name": "Person",
|
||||
"description": "Any individual person not fitting other specific person types.",
|
||||
|
|
@ -351,7 +356,7 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["ordinary citizen", "anonymous netizen"]
|
||||
}
|
||||
|
||||
|
||||
organization_fallback = {
|
||||
"name": "Organization",
|
||||
"description": "Any organization not fitting other specific organization types.",
|
||||
|
|
@ -361,74 +366,74 @@ class OntologyGenerator:
|
|||
],
|
||||
"examples": ["small business", "community group"]
|
||||
}
|
||||
|
||||
# 检查是否已有兜底类型
|
||||
|
||||
# Check whether fallback types already exist
|
||||
entity_names = {e["name"] for e in result["entity_types"]}
|
||||
has_person = "Person" in entity_names
|
||||
has_organization = "Organization" in entity_names
|
||||
|
||||
# 需要添加的兜底类型
|
||||
|
||||
# Collect fallback types to add
|
||||
fallbacks_to_add = []
|
||||
if not has_person:
|
||||
fallbacks_to_add.append(person_fallback)
|
||||
if not has_organization:
|
||||
fallbacks_to_add.append(organization_fallback)
|
||||
|
||||
|
||||
if fallbacks_to_add:
|
||||
current_count = len(result["entity_types"])
|
||||
needed_slots = len(fallbacks_to_add)
|
||||
|
||||
# 如果添加后会超过 10 个,需要移除一些现有类型
|
||||
|
||||
# If adding them would exceed 10, remove some existing types
|
||||
if current_count + needed_slots > MAX_ENTITY_TYPES:
|
||||
# 计算需要移除多少个
|
||||
# Calculate how many to remove
|
||||
to_remove = current_count + needed_slots - MAX_ENTITY_TYPES
|
||||
# 从末尾移除(保留前面更重要的具体类型)
|
||||
# Remove from the end (preserve the more important specific types at the front)
|
||||
result["entity_types"] = result["entity_types"][:-to_remove]
|
||||
|
||||
# 添加兜底类型
|
||||
|
||||
# Add fallback types
|
||||
result["entity_types"].extend(fallbacks_to_add)
|
||||
|
||||
# 最终确保不超过限制(防御性编程)
|
||||
|
||||
# Final guard: ensure limits are not exceeded (defensive programming)
|
||||
if len(result["entity_types"]) > MAX_ENTITY_TYPES:
|
||||
result["entity_types"] = result["entity_types"][:MAX_ENTITY_TYPES]
|
||||
|
||||
|
||||
if len(result["edge_types"]) > MAX_EDGE_TYPES:
|
||||
result["edge_types"] = result["edge_types"][:MAX_EDGE_TYPES]
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def generate_python_code(self, ontology: Dict[str, Any]) -> str:
|
||||
"""
|
||||
将本体定义转换为Python代码(类似ontology.py)
|
||||
|
||||
Convert the ontology definition to Python code (similar to ontology.py).
|
||||
|
||||
Args:
|
||||
ontology: 本体定义
|
||||
|
||||
ontology: ontology definition
|
||||
|
||||
Returns:
|
||||
Python代码字符串
|
||||
Python code string
|
||||
"""
|
||||
code_lines = [
|
||||
'"""',
|
||||
'自定义实体类型定义',
|
||||
'由MiroFish自动生成,用于社会舆论模拟',
|
||||
'Custom entity type definitions',
|
||||
'Auto-generated by MiroFish for social opinion simulation',
|
||||
'"""',
|
||||
'',
|
||||
'from pydantic import Field',
|
||||
'from zep_cloud.external_clients.ontology import EntityModel, EntityText, EdgeModel',
|
||||
'',
|
||||
'',
|
||||
'# ============== 实体类型定义 ==============',
|
||||
'# ============== Entity type definitions ==============',
|
||||
'',
|
||||
]
|
||||
|
||||
# 生成实体类型
|
||||
|
||||
# Generate entity types
|
||||
for entity in ontology.get("entity_types", []):
|
||||
name = entity["name"]
|
||||
desc = entity.get("description", f"A {name} entity.")
|
||||
|
||||
|
||||
code_lines.append(f'class {name}(EntityModel):')
|
||||
code_lines.append(f' """{desc}"""')
|
||||
|
||||
|
||||
attrs = entity.get("attributes", [])
|
||||
if attrs:
|
||||
for attr in attrs:
|
||||
|
|
@ -440,23 +445,23 @@ class OntologyGenerator:
|
|||
code_lines.append(f' )')
|
||||
else:
|
||||
code_lines.append(' pass')
|
||||
|
||||
|
||||
code_lines.append('')
|
||||
code_lines.append('')
|
||||
|
||||
code_lines.append('# ============== 关系类型定义 ==============')
|
||||
|
||||
code_lines.append('# ============== Relationship type definitions ==============')
|
||||
code_lines.append('')
|
||||
|
||||
# 生成关系类型
|
||||
|
||||
# Generate relationship types
|
||||
for edge in ontology.get("edge_types", []):
|
||||
name = edge["name"]
|
||||
# 转换为PascalCase类名
|
||||
# Convert to PascalCase class name
|
||||
class_name = ''.join(word.capitalize() for word in name.split('_'))
|
||||
desc = edge.get("description", f"A {name} relationship.")
|
||||
|
||||
|
||||
code_lines.append(f'class {class_name}(EdgeModel):')
|
||||
code_lines.append(f' """{desc}"""')
|
||||
|
||||
|
||||
attrs = edge.get("attributes", [])
|
||||
if attrs:
|
||||
for attr in attrs:
|
||||
|
|
@ -468,12 +473,12 @@ class OntologyGenerator:
|
|||
code_lines.append(f' )')
|
||||
else:
|
||||
code_lines.append(' pass')
|
||||
|
||||
|
||||
code_lines.append('')
|
||||
code_lines.append('')
|
||||
|
||||
# 生成类型字典
|
||||
code_lines.append('# ============== 类型配置 ==============')
|
||||
|
||||
# Generate type dictionaries
|
||||
code_lines.append('# ============== Type configuration ==============')
|
||||
code_lines.append('')
|
||||
code_lines.append('ENTITY_TYPES = {')
|
||||
for entity in ontology.get("entity_types", []):
|
||||
|
|
@ -488,8 +493,8 @@ class OntologyGenerator:
|
|||
code_lines.append(f' "{name}": {class_name},')
|
||||
code_lines.append('}')
|
||||
code_lines.append('')
|
||||
|
||||
# 生成边的source_targets映射
|
||||
|
||||
# Generate edge source_targets mapping
|
||||
code_lines.append('EDGE_SOURCE_TARGETS = {')
|
||||
for edge in ontology.get("edge_types", []):
|
||||
name = edge["name"]
|
||||
|
|
@ -501,6 +506,5 @@ class OntologyGenerator:
|
|||
])
|
||||
code_lines.append(f' "{name}": [{st_list}],')
|
||||
code_lines.append('}')
|
||||
|
||||
return '\n'.join(code_lines)
|
||||
|
||||
return '\n'.join(code_lines)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
|
|
@ -1,11 +1,11 @@
|
|||
"""
|
||||
模拟IPC通信模块
|
||||
用于Flask后端和模拟脚本之间的进程间通信
|
||||
Simulation IPC communication module
|
||||
Used for inter-process communication between the Flask backend and simulation scripts.
|
||||
|
||||
通过文件系统实现简单的命令/响应模式:
|
||||
1. Flask写入命令到 commands/ 目录
|
||||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
||||
3. Flask轮询响应目录获取结果
|
||||
Implements a simple command/response pattern via the file system:
|
||||
1. Flask writes commands to the commands/ directory
|
||||
2. Simulation scripts poll the command directory, execute commands, and write responses to the responses/ directory
|
||||
3. Flask polls the response directory to get results
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -23,14 +23,14 @@ logger = get_logger('mirofish.simulation_ipc')
|
|||
|
||||
|
||||
class CommandType(str, Enum):
|
||||
"""命令类型"""
|
||||
INTERVIEW = "interview" # 单个Agent采访
|
||||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
||||
CLOSE_ENV = "close_env" # 关闭环境
|
||||
"""Command type"""
|
||||
INTERVIEW = "interview" # Single agent interview
|
||||
BATCH_INTERVIEW = "batch_interview" # Batch interview
|
||||
CLOSE_ENV = "close_env" # Close environment
|
||||
|
||||
|
||||
class CommandStatus(str, Enum):
|
||||
"""命令状态"""
|
||||
"""Command status"""
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
|
|
@ -39,12 +39,12 @@ class CommandStatus(str, Enum):
|
|||
|
||||
@dataclass
|
||||
class IPCCommand:
|
||||
"""IPC命令"""
|
||||
"""IPC command"""
|
||||
command_id: str
|
||||
command_type: CommandType
|
||||
args: Dict[str, Any]
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
|
|
@ -52,7 +52,7 @@ class IPCCommand:
|
|||
"args": self.args,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCCommand':
|
||||
return cls(
|
||||
|
|
@ -65,13 +65,13 @@ class IPCCommand:
|
|||
|
||||
@dataclass
|
||||
class IPCResponse:
|
||||
"""IPC响应"""
|
||||
"""IPC response"""
|
||||
command_id: str
|
||||
status: CommandStatus
|
||||
result: Optional[Dict[str, Any]] = None
|
||||
error: Optional[str] = None
|
||||
timestamp: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"command_id": self.command_id,
|
||||
|
|
@ -80,7 +80,7 @@ class IPCResponse:
|
|||
"error": self.error,
|
||||
"timestamp": self.timestamp
|
||||
}
|
||||
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'IPCResponse':
|
||||
return cls(
|
||||
|
|
@ -94,26 +94,26 @@ class IPCResponse:
|
|||
|
||||
class SimulationIPCClient:
|
||||
"""
|
||||
模拟IPC客户端(Flask端使用)
|
||||
|
||||
用于向模拟进程发送命令并等待响应
|
||||
Simulation IPC client (used by the Flask side)
|
||||
|
||||
Used to send commands to the simulation process and wait for responses
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC客户端
|
||||
|
||||
Initialize the IPC client
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
simulation_dir: simulation data directory
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
|
||||
def send_command(
|
||||
self,
|
||||
command_type: CommandType,
|
||||
|
|
@ -122,19 +122,19 @@ class SimulationIPCClient:
|
|||
poll_interval: float = 0.5
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送命令并等待响应
|
||||
|
||||
Send a command and wait for a response
|
||||
|
||||
Args:
|
||||
command_type: 命令类型
|
||||
args: 命令参数
|
||||
timeout: 超时时间(秒)
|
||||
poll_interval: 轮询间隔(秒)
|
||||
|
||||
command_type: command type
|
||||
args: command arguments
|
||||
timeout: timeout in seconds
|
||||
poll_interval: polling interval in seconds
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
|
||||
|
||||
Raises:
|
||||
TimeoutError: 等待响应超时
|
||||
TimeoutError: timed out waiting for a response
|
||||
"""
|
||||
command_id = str(uuid.uuid4())
|
||||
command = IPCCommand(
|
||||
|
|
@ -142,50 +142,50 @@ class SimulationIPCClient:
|
|||
command_type=command_type,
|
||||
args=args
|
||||
)
|
||||
|
||||
# 写入命令文件
|
||||
|
||||
# Write command file
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
with open(command_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
logger.info(f"发送IPC命令: {command_type.value}, command_id={command_id}")
|
||||
|
||||
# 等待响应
|
||||
|
||||
logger.info(f"Sending IPC command: {command_type.value}, command_id={command_id}")
|
||||
|
||||
# Wait for response
|
||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
while time.time() - start_time < timeout:
|
||||
if os.path.exists(response_file):
|
||||
try:
|
||||
with open(response_file, 'r', encoding='utf-8') as f:
|
||||
response_data = json.load(f)
|
||||
response = IPCResponse.from_dict(response_data)
|
||||
|
||||
# 清理命令和响应文件
|
||||
|
||||
# Clean up command and response files
|
||||
try:
|
||||
os.remove(command_file)
|
||||
os.remove(response_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
logger.info(f"收到IPC响应: command_id={command_id}, status={response.status.value}")
|
||||
|
||||
logger.info(f"Received IPC response: command_id={command_id}, status={response.status.value}")
|
||||
return response
|
||||
except (json.JSONDecodeError, KeyError) as e:
|
||||
logger.warning(f"解析响应失败: {e}")
|
||||
|
||||
logger.warning(f"Failed to parse response: {e}")
|
||||
|
||||
time.sleep(poll_interval)
|
||||
|
||||
# 超时
|
||||
logger.error(f"等待IPC响应超时: command_id={command_id}")
|
||||
|
||||
# 清理命令文件
|
||||
|
||||
# Timeout
|
||||
logger.error(f"Timed out waiting for IPC response: command_id={command_id}")
|
||||
|
||||
# Clean up command file
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
raise TimeoutError(f"等待命令响应超时 ({timeout}秒)")
|
||||
|
||||
|
||||
raise TimeoutError(f"Timed out waiting for command response ({timeout}s)")
|
||||
|
||||
def send_interview(
|
||||
self,
|
||||
agent_id: int,
|
||||
|
|
@ -194,19 +194,19 @@ class SimulationIPCClient:
|
|||
timeout: float = 60.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送单个Agent采访命令
|
||||
|
||||
Send a single agent interview command
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID
|
||||
prompt: 采访问题
|
||||
platform: 指定平台(可选)
|
||||
- "twitter": 只采访Twitter平台
|
||||
- "reddit": 只采访Reddit平台
|
||||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
||||
timeout: 超时时间
|
||||
|
||||
prompt: interview question
|
||||
platform: target platform (optional)
|
||||
- "twitter": interview only the Twitter platform
|
||||
- "reddit": interview only the Reddit platform
|
||||
- None: in dual-platform mode, interview both; in single-platform mode, interview that platform
|
||||
timeout: timeout in seconds
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含采访结果
|
||||
IPCResponse with interview result in the result field
|
||||
"""
|
||||
args = {
|
||||
"agent_id": agent_id,
|
||||
|
|
@ -214,13 +214,13 @@ class SimulationIPCClient:
|
|||
}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
def send_batch_interview(
|
||||
self,
|
||||
interviews: List[Dict[str, Any]],
|
||||
|
|
@ -228,36 +228,36 @@ class SimulationIPCClient:
|
|||
timeout: float = 120.0
|
||||
) -> IPCResponse:
|
||||
"""
|
||||
发送批量采访命令
|
||||
|
||||
Send a batch interview command
|
||||
|
||||
Args:
|
||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
||||
- "twitter": 默认只采访Twitter平台
|
||||
- "reddit": 默认只采访Reddit平台
|
||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
||||
timeout: 超时时间
|
||||
|
||||
interviews: list of interviews, each containing {"agent_id": int, "prompt": str, "platform": str (optional)}
|
||||
platform: default platform (optional; overridden per-item by each interview's platform)
|
||||
- "twitter": default to Twitter platform only
|
||||
- "reddit": default to Reddit platform only
|
||||
- None: in dual-platform mode, interview each agent on both platforms
|
||||
timeout: timeout in seconds
|
||||
|
||||
Returns:
|
||||
IPCResponse,result字段包含所有采访结果
|
||||
IPCResponse with all interview results in the result field
|
||||
"""
|
||||
args = {"interviews": interviews}
|
||||
if platform:
|
||||
args["platform"] = platform
|
||||
|
||||
|
||||
return self.send_command(
|
||||
command_type=CommandType.BATCH_INTERVIEW,
|
||||
args=args,
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||||
"""
|
||||
发送关闭环境命令
|
||||
|
||||
Send a close-environment command
|
||||
|
||||
Args:
|
||||
timeout: 超时时间
|
||||
|
||||
timeout: timeout in seconds
|
||||
|
||||
Returns:
|
||||
IPCResponse
|
||||
"""
|
||||
|
|
@ -266,17 +266,17 @@ class SimulationIPCClient:
|
|||
args={},
|
||||
timeout=timeout
|
||||
)
|
||||
|
||||
|
||||
def check_env_alive(self) -> bool:
|
||||
"""
|
||||
检查模拟环境是否存活
|
||||
|
||||
通过检查 env_status.json 文件来判断
|
||||
Check whether the simulation environment is alive
|
||||
|
||||
Determined by checking the env_status.json file
|
||||
"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
if not os.path.exists(status_file):
|
||||
return False
|
||||
|
||||
|
||||
try:
|
||||
with open(status_file, 'r', encoding='utf-8') as f:
|
||||
status = json.load(f)
|
||||
|
|
@ -287,106 +287,106 @@ class SimulationIPCClient:
|
|||
|
||||
class SimulationIPCServer:
|
||||
"""
|
||||
模拟IPC服务器(模拟脚本端使用)
|
||||
|
||||
轮询命令目录,执行命令并返回响应
|
||||
Simulation IPC server (used by the simulation script side)
|
||||
|
||||
Polls the command directory, executes commands, and returns responses
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, simulation_dir: str):
|
||||
"""
|
||||
初始化IPC服务器
|
||||
|
||||
Initialize the IPC server
|
||||
|
||||
Args:
|
||||
simulation_dir: 模拟数据目录
|
||||
simulation_dir: simulation data directory
|
||||
"""
|
||||
self.simulation_dir = simulation_dir
|
||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||
|
||||
# 确保目录存在
|
||||
|
||||
# Ensure directories exist
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
# 环境状态
|
||||
|
||||
# Environment status
|
||||
self._running = False
|
||||
|
||||
|
||||
def start(self):
|
||||
"""标记服务器为运行状态"""
|
||||
"""Mark the server as running"""
|
||||
self._running = True
|
||||
self._update_env_status("alive")
|
||||
|
||||
|
||||
def stop(self):
|
||||
"""标记服务器为停止状态"""
|
||||
"""Mark the server as stopped"""
|
||||
self._running = False
|
||||
self._update_env_status("stopped")
|
||||
|
||||
|
||||
def _update_env_status(self, status: str):
|
||||
"""更新环境状态文件"""
|
||||
"""Update the environment status file"""
|
||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||
with open(status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
"timestamp": datetime.now().isoformat()
|
||||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def poll_commands(self) -> Optional[IPCCommand]:
|
||||
"""
|
||||
轮询命令目录,返回第一个待处理的命令
|
||||
|
||||
Poll the command directory and return the first pending command
|
||||
|
||||
Returns:
|
||||
IPCCommand 或 None
|
||||
IPCCommand or None
|
||||
"""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 按时间排序获取命令文件
|
||||
|
||||
# Get command files sorted by modification time
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
filepath = os.path.join(self.commands_dir, filename)
|
||||
command_files.append((filepath, os.path.getmtime(filepath)))
|
||||
|
||||
|
||||
command_files.sort(key=lambda x: x[1])
|
||||
|
||||
|
||||
for filepath, _ in command_files:
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
return IPCCommand.from_dict(data)
|
||||
except (json.JSONDecodeError, KeyError, OSError) as e:
|
||||
logger.warning(f"读取命令文件失败: {filepath}, {e}")
|
||||
logger.warning(f"Failed to read command file: {filepath}, {e}")
|
||||
continue
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def send_response(self, response: IPCResponse):
|
||||
"""
|
||||
发送响应
|
||||
|
||||
Send a response
|
||||
|
||||
Args:
|
||||
response: IPC响应
|
||||
response: IPC response
|
||||
"""
|
||||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
|
||||
# Delete the command file
|
||||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||||
"""发送成功响应"""
|
||||
"""Send a success response"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.COMPLETED,
|
||||
result=result
|
||||
))
|
||||
|
||||
|
||||
def send_error(self, command_id: str, error: str):
|
||||
"""发送错误响应"""
|
||||
"""Send an error response"""
|
||||
self.send_response(IPCResponse(
|
||||
command_id=command_id,
|
||||
status=CommandStatus.FAILED,
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""
|
||||
OASIS模拟管理器
|
||||
管理Twitter和Reddit双平台并行模拟
|
||||
使用预设脚本 + LLM智能生成配置参数
|
||||
OASIS simulation manager
|
||||
Manages parallel simulation on both Twitter and Reddit platforms.
|
||||
Uses preset scripts with LLM-generated configuration parameters.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -23,60 +23,60 @@ logger = get_logger('mirofish.simulation')
|
|||
|
||||
|
||||
class SimulationStatus(str, Enum):
|
||||
"""模拟状态"""
|
||||
"""Simulation status"""
|
||||
CREATED = "created"
|
||||
PREPARING = "preparing"
|
||||
READY = "ready"
|
||||
RUNNING = "running"
|
||||
PAUSED = "paused"
|
||||
STOPPED = "stopped" # 模拟被手动停止
|
||||
COMPLETED = "completed" # 模拟自然完成
|
||||
STOPPED = "stopped" # Simulation manually stopped
|
||||
COMPLETED = "completed" # Simulation naturally completed
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class PlatformType(str, Enum):
|
||||
"""平台类型"""
|
||||
"""Platform type"""
|
||||
TWITTER = "twitter"
|
||||
REDDIT = "reddit"
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimulationState:
|
||||
"""模拟状态"""
|
||||
"""Simulation state"""
|
||||
simulation_id: str
|
||||
project_id: str
|
||||
graph_id: str
|
||||
|
||||
# 平台启用状态
|
||||
|
||||
# Platform enable flags
|
||||
enable_twitter: bool = True
|
||||
enable_reddit: bool = True
|
||||
|
||||
# 状态
|
||||
|
||||
# Status
|
||||
status: SimulationStatus = SimulationStatus.CREATED
|
||||
|
||||
# 准备阶段数据
|
||||
|
||||
# Preparation phase data
|
||||
entities_count: int = 0
|
||||
profiles_count: int = 0
|
||||
entity_types: List[str] = field(default_factory=list)
|
||||
|
||||
# 配置生成信息
|
||||
|
||||
# Config generation info
|
||||
config_generated: bool = False
|
||||
config_reasoning: str = ""
|
||||
|
||||
# 运行时数据
|
||||
|
||||
# Runtime data
|
||||
current_round: int = 0
|
||||
twitter_status: str = "not_started"
|
||||
reddit_status: str = "not_started"
|
||||
|
||||
# 时间戳
|
||||
|
||||
# Timestamps
|
||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||
|
||||
# 错误信息
|
||||
|
||||
# Error message
|
||||
error: Optional[str] = None
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""完整状态字典(内部使用)"""
|
||||
"""Full state dictionary (internal use)"""
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"project_id": self.project_id,
|
||||
|
|
@ -96,9 +96,9 @@ class SimulationState:
|
|||
"updated_at": self.updated_at,
|
||||
"error": self.error,
|
||||
}
|
||||
|
||||
|
||||
def to_simple_dict(self) -> Dict[str, Any]:
|
||||
"""简化状态字典(API返回使用)"""
|
||||
"""Simplified state dictionary (used for API responses)"""
|
||||
return {
|
||||
"simulation_id": self.simulation_id,
|
||||
"project_id": self.project_id,
|
||||
|
|
@ -114,60 +114,60 @@ class SimulationState:
|
|||
|
||||
class SimulationManager:
|
||||
"""
|
||||
模拟管理器
|
||||
|
||||
核心功能:
|
||||
1. 从Zep图谱读取实体并过滤
|
||||
2. 生成OASIS Agent Profile
|
||||
3. 使用LLM智能生成模拟配置参数
|
||||
4. 准备预设脚本所需的所有文件
|
||||
Simulation manager
|
||||
|
||||
Core functions:
|
||||
1. Read and filter entities from the Zep graph
|
||||
2. Generate OASIS Agent Profiles
|
||||
3. Use LLM to intelligently generate simulation configuration parameters
|
||||
4. Prepare all files required by the preset scripts
|
||||
"""
|
||||
|
||||
# 模拟数据存储目录
|
||||
|
||||
# Simulation data storage directory
|
||||
SIMULATION_DATA_DIR = os.path.join(
|
||||
os.path.dirname(__file__),
|
||||
os.path.dirname(__file__),
|
||||
'../../uploads/simulations'
|
||||
)
|
||||
|
||||
|
||||
def __init__(self):
|
||||
# 确保目录存在
|
||||
# Ensure directory exists
|
||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||
|
||||
# 内存中的模拟状态缓存
|
||||
|
||||
# In-memory simulation state cache
|
||||
self._simulations: Dict[str, SimulationState] = {}
|
||||
|
||||
|
||||
def _get_simulation_dir(self, simulation_id: str) -> str:
|
||||
"""获取模拟数据目录"""
|
||||
"""Get the simulation data directory"""
|
||||
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
||||
os.makedirs(sim_dir, exist_ok=True)
|
||||
return sim_dir
|
||||
|
||||
|
||||
def _save_simulation_state(self, state: SimulationState):
|
||||
"""保存模拟状态到文件"""
|
||||
"""Save simulation state to file"""
|
||||
sim_dir = self._get_simulation_dir(state.simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
|
||||
state.updated_at = datetime.now().isoformat()
|
||||
|
||||
|
||||
with open(state_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
self._simulations[state.simulation_id] = state
|
||||
|
||||
|
||||
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""从文件加载模拟状态"""
|
||||
"""Load simulation state from file"""
|
||||
if simulation_id in self._simulations:
|
||||
return self._simulations[simulation_id]
|
||||
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
state_file = os.path.join(sim_dir, "state.json")
|
||||
|
||||
|
||||
if not os.path.exists(state_file):
|
||||
return None
|
||||
|
||||
|
||||
with open(state_file, 'r', encoding='utf-8') as f:
|
||||
data = json.load(f)
|
||||
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=data.get("project_id", ""),
|
||||
|
|
@ -187,10 +187,10 @@ class SimulationManager:
|
|||
updated_at=data.get("updated_at", datetime.now().isoformat()),
|
||||
error=data.get("error"),
|
||||
)
|
||||
|
||||
|
||||
self._simulations[simulation_id] = state
|
||||
return state
|
||||
|
||||
|
||||
def create_simulation(
|
||||
self,
|
||||
project_id: str,
|
||||
|
|
@ -199,20 +199,20 @@ class SimulationManager:
|
|||
enable_reddit: bool = True,
|
||||
) -> SimulationState:
|
||||
"""
|
||||
创建新的模拟
|
||||
|
||||
Create a new simulation.
|
||||
|
||||
Args:
|
||||
project_id: 项目ID
|
||||
graph_id: Zep图谱ID
|
||||
enable_twitter: 是否启用Twitter模拟
|
||||
enable_reddit: 是否启用Reddit模拟
|
||||
|
||||
project_id: project ID
|
||||
graph_id: Zep graph ID
|
||||
enable_twitter: whether to enable Twitter simulation
|
||||
enable_reddit: whether to enable Reddit simulation
|
||||
|
||||
Returns:
|
||||
SimulationState
|
||||
"""
|
||||
import uuid
|
||||
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
|
||||
|
||||
|
||||
state = SimulationState(
|
||||
simulation_id=simulation_id,
|
||||
project_id=project_id,
|
||||
|
|
@ -221,12 +221,12 @@ class SimulationManager:
|
|||
enable_reddit=enable_reddit,
|
||||
status=SimulationStatus.CREATED,
|
||||
)
|
||||
|
||||
|
||||
self._save_simulation_state(state)
|
||||
logger.info(f"创建模拟: {simulation_id}, project={project_id}, graph={graph_id}")
|
||||
|
||||
logger.info(f"Simulation created: {simulation_id}, project={project_id}, graph={graph_id}")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
def prepare_simulation(
|
||||
self,
|
||||
simulation_id: str,
|
||||
|
|
@ -238,55 +238,55 @@ class SimulationManager:
|
|||
parallel_profile_count: int = 3
|
||||
) -> SimulationState:
|
||||
"""
|
||||
准备模拟环境(全程自动化)
|
||||
|
||||
步骤:
|
||||
1. 从Zep图谱读取并过滤实体
|
||||
2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行)
|
||||
3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等)
|
||||
4. 保存配置文件和Profile文件
|
||||
5. 复制预设脚本到模拟目录
|
||||
|
||||
Prepare the simulation environment (fully automated).
|
||||
|
||||
Steps:
|
||||
1. Read and filter entities from the Zep graph
|
||||
2. Generate an OASIS Agent Profile for each entity (optional LLM enhancement, supports parallelism)
|
||||
3. Use LLM to intelligently generate simulation configuration parameters (time, activity level, posting frequency, etc.)
|
||||
4. Save configuration files and profile files
|
||||
5. Copy preset scripts to the simulation directory
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
simulation_requirement: 模拟需求描述(用于LLM生成配置)
|
||||
document_text: 原始文档内容(用于LLM理解背景)
|
||||
defined_entity_types: 预定义的实体类型(可选)
|
||||
use_llm_for_profiles: 是否使用LLM生成详细人设
|
||||
progress_callback: 进度回调函数 (stage, progress, message)
|
||||
parallel_profile_count: 并行生成人设的数量,默认3
|
||||
|
||||
simulation_id: simulation ID
|
||||
simulation_requirement: simulation requirement description (used for LLM config generation)
|
||||
document_text: original document content (used for LLM background understanding)
|
||||
defined_entity_types: predefined entity types (optional)
|
||||
use_llm_for_profiles: whether to use LLM to generate detailed personas
|
||||
progress_callback: progress callback function (stage, progress, message)
|
||||
parallel_profile_count: number of profiles to generate in parallel, default 3
|
||||
|
||||
Returns:
|
||||
SimulationState
|
||||
"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
raise ValueError(f"Simulation not found: {simulation_id}")
|
||||
|
||||
try:
|
||||
state.status = SimulationStatus.PREPARING
|
||||
self._save_simulation_state(state)
|
||||
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
|
||||
# ========== 阶段1: 读取并过滤实体 ==========
|
||||
|
||||
# ========== Stage 1: Read and filter entities ==========
|
||||
if progress_callback:
|
||||
progress_callback("reading", 0, t('progress.connectingZepGraph'))
|
||||
|
||||
|
||||
reader = ZepEntityReader()
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback("reading", 30, t('progress.readingNodeData'))
|
||||
|
||||
|
||||
filtered = reader.filter_defined_entities(
|
||||
graph_id=state.graph_id,
|
||||
defined_entity_types=defined_entity_types,
|
||||
enrich_with_edges=True
|
||||
)
|
||||
|
||||
|
||||
state.entities_count = filtered.filtered_count
|
||||
state.entity_types = list(filtered.entity_types)
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"reading", 100,
|
||||
|
|
@ -294,16 +294,16 @@ class SimulationManager:
|
|||
current=filtered.filtered_count,
|
||||
total=filtered.filtered_count
|
||||
)
|
||||
|
||||
|
||||
if filtered.filtered_count == 0:
|
||||
state.status = SimulationStatus.FAILED
|
||||
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
|
||||
state.error = "No qualifying entities found. Please check that the graph was built correctly."
|
||||
self._save_simulation_state(state)
|
||||
return state
|
||||
|
||||
# ========== 阶段2: 生成Agent Profile ==========
|
||||
|
||||
# ========== Stage 2: Generate Agent Profiles ==========
|
||||
total_entities = len(filtered.entities)
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 0,
|
||||
|
|
@ -311,22 +311,22 @@ class SimulationManager:
|
|||
current=0,
|
||||
total=total_entities
|
||||
)
|
||||
|
||||
# 传入graph_id以启用Zep检索功能,获取更丰富的上下文
|
||||
|
||||
# Pass graph_id to enable Zep retrieval for richer context
|
||||
generator = OasisProfileGenerator(graph_id=state.graph_id)
|
||||
|
||||
|
||||
def profile_progress(current, total, msg):
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles",
|
||||
int(current / total * 100),
|
||||
"generating_profiles",
|
||||
int(current / total * 100),
|
||||
msg,
|
||||
current=current,
|
||||
total=total,
|
||||
item_name=msg
|
||||
)
|
||||
|
||||
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式)
|
||||
|
||||
# Set real-time save path (prefer Reddit JSON format)
|
||||
realtime_output_path = None
|
||||
realtime_platform = "reddit"
|
||||
if state.enable_reddit:
|
||||
|
|
@ -335,21 +335,21 @@ class SimulationManager:
|
|||
elif state.enable_twitter:
|
||||
realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
|
||||
realtime_platform = "twitter"
|
||||
|
||||
|
||||
profiles = generator.generate_profiles_from_entities(
|
||||
entities=filtered.entities,
|
||||
use_llm=use_llm_for_profiles,
|
||||
progress_callback=profile_progress,
|
||||
graph_id=state.graph_id, # 传入graph_id用于Zep检索
|
||||
parallel_count=parallel_profile_count, # 并行生成数量
|
||||
realtime_output_path=realtime_output_path, # 实时保存路径
|
||||
output_platform=realtime_platform # 输出格式
|
||||
graph_id=state.graph_id, # Pass graph_id for Zep retrieval
|
||||
parallel_count=parallel_profile_count, # Parallel generation count
|
||||
realtime_output_path=realtime_output_path, # Real-time save path
|
||||
output_platform=realtime_platform # Output format
|
||||
)
|
||||
|
||||
|
||||
state.profiles_count = len(profiles)
|
||||
|
||||
# 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
|
||||
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
|
||||
|
||||
# Save profile files (note: Twitter uses CSV format, Reddit uses JSON format)
|
||||
# Reddit has already been saved incrementally during generation; save once more to ensure completeness
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 95,
|
||||
|
|
@ -357,22 +357,22 @@ class SimulationManager:
|
|||
current=total_entities,
|
||||
total=total_entities
|
||||
)
|
||||
|
||||
|
||||
if state.enable_reddit:
|
||||
generator.save_profiles(
|
||||
profiles=profiles,
|
||||
file_path=os.path.join(sim_dir, "reddit_profiles.json"),
|
||||
platform="reddit"
|
||||
)
|
||||
|
||||
|
||||
if state.enable_twitter:
|
||||
# Twitter使用CSV格式!这是OASIS的要求
|
||||
# Twitter uses CSV format — this is a requirement of OASIS
|
||||
generator.save_profiles(
|
||||
profiles=profiles,
|
||||
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
||||
platform="twitter"
|
||||
)
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_profiles", 100,
|
||||
|
|
@ -380,8 +380,8 @@ class SimulationManager:
|
|||
current=len(profiles),
|
||||
total=len(profiles)
|
||||
)
|
||||
|
||||
# ========== 阶段3: LLM智能生成模拟配置 ==========
|
||||
|
||||
# ========== Stage 3: LLM intelligent simulation configuration generation ==========
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 0,
|
||||
|
|
@ -389,9 +389,9 @@ class SimulationManager:
|
|||
current=0,
|
||||
total=3
|
||||
)
|
||||
|
||||
|
||||
config_generator = SimulationConfigGenerator()
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 30,
|
||||
|
|
@ -399,7 +399,7 @@ class SimulationManager:
|
|||
current=1,
|
||||
total=3
|
||||
)
|
||||
|
||||
|
||||
sim_params = config_generator.generate_config(
|
||||
simulation_id=simulation_id,
|
||||
project_id=state.project_id,
|
||||
|
|
@ -410,7 +410,7 @@ class SimulationManager:
|
|||
enable_twitter=state.enable_twitter,
|
||||
enable_reddit=state.enable_reddit
|
||||
)
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 70,
|
||||
|
|
@ -418,15 +418,15 @@ class SimulationManager:
|
|||
current=2,
|
||||
total=3
|
||||
)
|
||||
|
||||
# 保存配置文件
|
||||
|
||||
# Save configuration file
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
with open(config_path, 'w', encoding='utf-8') as f:
|
||||
f.write(sim_params.to_json())
|
||||
|
||||
|
||||
state.config_generated = True
|
||||
state.config_reasoning = sim_params.generation_reasoning
|
||||
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(
|
||||
"generating_config", 100,
|
||||
|
|
@ -434,82 +434,82 @@ class SimulationManager:
|
|||
current=3,
|
||||
total=3
|
||||
)
|
||||
|
||||
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
|
||||
# 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本
|
||||
|
||||
# 更新状态
|
||||
|
||||
# Note: run scripts remain in backend/scripts/; they are not copied to the simulation directory.
|
||||
# When starting a simulation, simulation_runner runs scripts from the scripts/ directory.
|
||||
|
||||
# Update status
|
||||
state.status = SimulationStatus.READY
|
||||
self._save_simulation_state(state)
|
||||
|
||||
logger.info(f"模拟准备完成: {simulation_id}, "
|
||||
f"entities={state.entities_count}, profiles={state.profiles_count}")
|
||||
|
||||
|
||||
logger.info(f"Simulation preparation complete: {simulation_id}, "
|
||||
f"entities={state.entities_count}, profiles={state.profiles_count}")
|
||||
|
||||
return state
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"模拟准备失败: {simulation_id}, error={str(e)}")
|
||||
logger.error(f"Simulation preparation failed: {simulation_id}, error={str(e)}")
|
||||
import traceback
|
||||
logger.error(traceback.format_exc())
|
||||
state.status = SimulationStatus.FAILED
|
||||
state.error = str(e)
|
||||
self._save_simulation_state(state)
|
||||
raise
|
||||
|
||||
|
||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||
"""获取模拟状态"""
|
||||
"""Get simulation state"""
|
||||
return self._load_simulation_state(simulation_id)
|
||||
|
||||
|
||||
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
||||
"""列出所有模拟"""
|
||||
"""List all simulations"""
|
||||
simulations = []
|
||||
|
||||
|
||||
if os.path.exists(self.SIMULATION_DATA_DIR):
|
||||
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
||||
# 跳过隐藏文件(如 .DS_Store)和非目录文件
|
||||
# Skip hidden files (e.g. .DS_Store) and non-directory entries
|
||||
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
|
||||
if sim_id.startswith('.') or not os.path.isdir(sim_path):
|
||||
continue
|
||||
|
||||
|
||||
state = self._load_simulation_state(sim_id)
|
||||
if state:
|
||||
if project_id is None or state.project_id == project_id:
|
||||
simulations.append(state)
|
||||
|
||||
|
||||
return simulations
|
||||
|
||||
|
||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||
"""获取模拟的Agent Profile"""
|
||||
"""Get agent profiles for a simulation"""
|
||||
state = self._load_simulation_state(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||
|
||||
raise ValueError(f"Simulation not found: {simulation_id}")
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
|
||||
|
||||
|
||||
if not os.path.exists(profile_path):
|
||||
return []
|
||||
|
||||
|
||||
with open(profile_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
||||
"""获取模拟配置"""
|
||||
"""Get simulation configuration"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
|
||||
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
|
||||
|
||||
with open(config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
||||
"""获取运行说明"""
|
||||
"""Get run instructions"""
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||
|
||||
|
||||
return {
|
||||
"simulation_dir": sim_dir,
|
||||
"scripts_dir": scripts_dir,
|
||||
|
|
@ -520,10 +520,10 @@ class SimulationManager:
|
|||
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
|
||||
},
|
||||
"instructions": (
|
||||
f"1. 激活conda环境: conda activate MiroFish\n"
|
||||
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
|
||||
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
||||
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
||||
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
||||
f"1. Activate conda environment: conda activate MiroFish\n"
|
||||
f"2. Run simulation (scripts located at {scripts_dir}):\n"
|
||||
f" - Twitter only: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
|
||||
f" - Reddit only: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
|
||||
f" - Both platforms in parallel: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
|
||||
)
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
文本处理服务
|
||||
Text processing service
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
|
@ -7,11 +7,11 @@ from ..utils.file_parser import FileParser, split_text_into_chunks
|
|||
|
||||
|
||||
class TextProcessor:
|
||||
"""文本处理器"""
|
||||
|
||||
"""Text processor"""
|
||||
|
||||
@staticmethod
|
||||
def extract_from_files(file_paths: List[str]) -> str:
|
||||
"""从多个文件提取文本"""
|
||||
"""Extract text from multiple files"""
|
||||
return FileParser.extract_from_multiple(file_paths)
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -21,48 +21,48 @@ class TextProcessor:
|
|||
overlap: int = 50
|
||||
) -> List[str]:
|
||||
"""
|
||||
分割文本
|
||||
|
||||
Split text into chunks.
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
chunk_size: 块大小
|
||||
overlap: 重叠大小
|
||||
|
||||
text: raw text
|
||||
chunk_size: chunk size
|
||||
overlap: overlap size
|
||||
|
||||
Returns:
|
||||
文本块列表
|
||||
list of text chunks
|
||||
"""
|
||||
return split_text_into_chunks(text, chunk_size, overlap)
|
||||
|
||||
@staticmethod
|
||||
def preprocess_text(text: str) -> str:
|
||||
"""
|
||||
预处理文本
|
||||
- 移除多余空白
|
||||
- 标准化换行
|
||||
|
||||
Preprocess text:
|
||||
- Remove excess whitespace
|
||||
- Normalize line endings
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
|
||||
text: raw text
|
||||
|
||||
Returns:
|
||||
处理后的文本
|
||||
processed text
|
||||
"""
|
||||
import re
|
||||
|
||||
# 标准化换行
|
||||
|
||||
# Normalize line endings
|
||||
text = text.replace('\r\n', '\n').replace('\r', '\n')
|
||||
|
||||
# 移除连续空行(保留最多两个换行)
|
||||
|
||||
# Remove consecutive blank lines (keep at most two newlines)
|
||||
text = re.sub(r'\n{3,}', '\n\n', text)
|
||||
|
||||
# 移除行首行尾空白
|
||||
|
||||
# Strip leading/trailing whitespace from each line
|
||||
lines = [line.strip() for line in text.split('\n')]
|
||||
text = '\n'.join(lines)
|
||||
|
||||
|
||||
return text.strip()
|
||||
|
||||
|
||||
@staticmethod
|
||||
def get_text_stats(text: str) -> dict:
|
||||
"""获取文本统计信息"""
|
||||
"""Get text statistics"""
|
||||
return {
|
||||
"total_chars": len(text),
|
||||
"total_lines": text.count('\n') + 1,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Zep实体读取与过滤服务
|
||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
||||
Zep entity read and filter service
|
||||
Reads nodes from the Zep graph and filters out nodes that match predefined entity types
|
||||
"""
|
||||
|
||||
import time
|
||||
|
|
@ -15,23 +15,23 @@ from ..utils.zep_paging import fetch_all_nodes, fetch_all_edges
|
|||
|
||||
logger = get_logger('mirofish.zep_entity_reader')
|
||||
|
||||
# 用于泛型返回类型
|
||||
# Generic return type
|
||||
T = TypeVar('T')
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntityNode:
|
||||
"""实体节点数据结构"""
|
||||
"""Entity node data structure"""
|
||||
uuid: str
|
||||
name: str
|
||||
labels: List[str]
|
||||
summary: str
|
||||
attributes: Dict[str, Any]
|
||||
# 相关的边信息
|
||||
# Related edge info
|
||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||
# 相关的其他节点信息
|
||||
# Related node info
|
||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"uuid": self.uuid,
|
||||
|
|
@ -42,9 +42,9 @@ class EntityNode:
|
|||
"related_edges": self.related_edges,
|
||||
"related_nodes": self.related_nodes,
|
||||
}
|
||||
|
||||
|
||||
def get_entity_type(self) -> Optional[str]:
|
||||
"""获取实体类型(排除默认的Entity标签)"""
|
||||
"""Get entity type (excluding the default Entity label)"""
|
||||
for label in self.labels:
|
||||
if label not in ["Entity", "Node"]:
|
||||
return label
|
||||
|
|
@ -53,12 +53,12 @@ class EntityNode:
|
|||
|
||||
@dataclass
|
||||
class FilteredEntities:
|
||||
"""过滤后的实体集合"""
|
||||
"""Filtered entity collection"""
|
||||
entities: List[EntityNode]
|
||||
entity_types: Set[str]
|
||||
total_count: int
|
||||
filtered_count: int
|
||||
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"entities": [e.to_dict() for e in self.entities],
|
||||
|
|
@ -70,43 +70,43 @@ class FilteredEntities:
|
|||
|
||||
class ZepEntityReader:
|
||||
"""
|
||||
Zep实体读取与过滤服务
|
||||
|
||||
主要功能:
|
||||
1. 从Zep图谱读取所有节点
|
||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
||||
3. 获取每个实体的相关边和关联节点信息
|
||||
Zep entity read and filter service
|
||||
|
||||
Main features:
|
||||
1. Read all nodes from the Zep graph
|
||||
2. Filter out nodes matching predefined entity types (nodes with labels beyond just "Entity")
|
||||
3. Fetch related edges and associated node info for each entity
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None):
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY 未配置")
|
||||
|
||||
raise ValueError("ZEP_API_KEY is not configured")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
|
||||
def _call_with_retry(
|
||||
self,
|
||||
func: Callable[[], T],
|
||||
self,
|
||||
func: Callable[[], T],
|
||||
operation_name: str,
|
||||
max_retries: int = 3,
|
||||
initial_delay: float = 2.0
|
||||
) -> T:
|
||||
"""
|
||||
带重试机制的Zep API调用
|
||||
|
||||
Zep API call with retry logic
|
||||
|
||||
Args:
|
||||
func: 要执行的函数(无参数的lambda或callable)
|
||||
operation_name: 操作名称,用于日志
|
||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
||||
initial_delay: 初始延迟秒数
|
||||
|
||||
func: function to execute (a lambda or callable with no arguments)
|
||||
operation_name: operation name for logging
|
||||
max_retries: maximum number of retries (default 3, meaning up to 3 attempts total)
|
||||
initial_delay: initial delay in seconds
|
||||
|
||||
Returns:
|
||||
API调用结果
|
||||
API call result
|
||||
"""
|
||||
last_exception = None
|
||||
delay = initial_delay
|
||||
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func()
|
||||
|
|
@ -114,27 +114,27 @@ class ZepEntityReader:
|
|||
last_exception = e
|
||||
if attempt < max_retries - 1:
|
||||
logger.warning(
|
||||
f"Zep {operation_name} 第 {attempt + 1} 次尝试失败: {str(e)[:100]}, "
|
||||
f"{delay:.1f}秒后重试..."
|
||||
f"Zep {operation_name} attempt {attempt + 1} failed: {str(e)[:100]}, "
|
||||
f"retrying in {delay:.1f}s..."
|
||||
)
|
||||
time.sleep(delay)
|
||||
delay *= 2 # 指数退避
|
||||
delay *= 2 # Exponential backoff
|
||||
else:
|
||||
logger.error(f"Zep {operation_name} 在 {max_retries} 次尝试后仍失败: {str(e)}")
|
||||
|
||||
logger.error(f"Zep {operation_name} still failing after {max_retries} attempts: {str(e)}")
|
||||
|
||||
raise last_exception
|
||||
|
||||
|
||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有节点(分页获取)
|
||||
Get all nodes in the graph (paginated)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
graph_id: graph ID
|
||||
|
||||
Returns:
|
||||
节点列表
|
||||
Node list
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有节点...")
|
||||
logger.info(f"Fetching all nodes for graph {graph_id}...")
|
||||
|
||||
nodes = fetch_all_nodes(self.client, graph_id)
|
||||
|
||||
|
|
@ -148,20 +148,20 @@ class ZepEntityReader:
|
|||
"attributes": node.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(nodes_data)} 个节点")
|
||||
logger.info(f"Fetched {len(nodes_data)} nodes")
|
||||
return nodes_data
|
||||
|
||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取图谱的所有边(分页获取)
|
||||
Get all edges in the graph (paginated)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
graph_id: graph ID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
Edge list
|
||||
"""
|
||||
logger.info(f"获取图谱 {graph_id} 的所有边...")
|
||||
logger.info(f"Fetching all edges for graph {graph_id}...")
|
||||
|
||||
edges = fetch_all_edges(self.client, graph_id)
|
||||
|
||||
|
|
@ -176,26 +176,26 @@ class ZepEntityReader:
|
|||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
logger.info(f"共获取 {len(edges_data)} 条边")
|
||||
logger.info(f"Fetched {len(edges_data)} edges")
|
||||
return edges_data
|
||||
|
||||
|
||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
获取指定节点的所有相关边(带重试机制)
|
||||
|
||||
Get all edges related to the specified node (with retry logic)
|
||||
|
||||
Args:
|
||||
node_uuid: 节点UUID
|
||||
|
||||
node_uuid: node UUID
|
||||
|
||||
Returns:
|
||||
边列表
|
||||
Edge list
|
||||
"""
|
||||
try:
|
||||
# 使用重试机制调用Zep API
|
||||
# Call Zep API with retry
|
||||
edges = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||
operation_name=f"get node edges (node={node_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
||||
edges_data = []
|
||||
for edge in edges:
|
||||
edges_data.append({
|
||||
|
|
@ -206,60 +206,60 @@ class ZepEntityReader:
|
|||
"target_node_uuid": edge.target_node_uuid,
|
||||
"attributes": edge.attributes or {},
|
||||
})
|
||||
|
||||
|
||||
return edges_data
|
||||
except Exception as e:
|
||||
logger.warning(f"获取节点 {node_uuid} 的边失败: {str(e)}")
|
||||
logger.warning(f"Failed to get edges for node {node_uuid}: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
def filter_defined_entities(
|
||||
self,
|
||||
self,
|
||||
graph_id: str,
|
||||
defined_entity_types: Optional[List[str]] = None,
|
||||
enrich_with_edges: bool = True
|
||||
) -> FilteredEntities:
|
||||
"""
|
||||
筛选出符合预定义实体类型的节点
|
||||
|
||||
筛选逻辑:
|
||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
||||
|
||||
Filter out nodes that match predefined entity types
|
||||
|
||||
Filter logic:
|
||||
- If a node's Labels contain only "Entity", it does not match our predefined types; skip it
|
||||
- If a node's Labels contain labels other than "Entity" and "Node", it matches a predefined type; keep it
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
||||
|
||||
graph_id: graph ID
|
||||
defined_entity_types: list of predefined entity types (optional; if provided, only these types are kept)
|
||||
enrich_with_edges: whether to fetch related edge info for each entity
|
||||
|
||||
Returns:
|
||||
FilteredEntities: 过滤后的实体集合
|
||||
FilteredEntities: filtered entity collection
|
||||
"""
|
||||
logger.info(f"开始筛选图谱 {graph_id} 的实体...")
|
||||
|
||||
# 获取所有节点
|
||||
logger.info(f"Starting entity filtering for graph {graph_id}...")
|
||||
|
||||
# Get all nodes
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
total_count = len(all_nodes)
|
||||
|
||||
# 获取所有边(用于后续关联查找)
|
||||
|
||||
# Get all edges (for relation lookup)
|
||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||
|
||||
# 构建节点UUID到节点数据的映射
|
||||
|
||||
# Build UUID-to-node mapping
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 筛选符合条件的实体
|
||||
|
||||
# Filter matching entities
|
||||
filtered_entities = []
|
||||
entity_types_found = set()
|
||||
|
||||
|
||||
for node in all_nodes:
|
||||
labels = node.get("labels", [])
|
||||
|
||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
||||
|
||||
# Filter logic: Labels must contain at least one label other than "Entity" and "Node"
|
||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||
|
||||
|
||||
if not custom_labels:
|
||||
# 只有默认标签,跳过
|
||||
# Only default labels; skip
|
||||
continue
|
||||
|
||||
# 如果指定了预定义类型,检查是否匹配
|
||||
|
||||
# If predefined types are specified, check for a match
|
||||
if defined_entity_types:
|
||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||
if not matching_labels:
|
||||
|
|
@ -267,10 +267,10 @@ class ZepEntityReader:
|
|||
entity_type = matching_labels[0]
|
||||
else:
|
||||
entity_type = custom_labels[0]
|
||||
|
||||
|
||||
entity_types_found.add(entity_type)
|
||||
|
||||
# 创建实体节点对象
|
||||
|
||||
# Create entity node object
|
||||
entity = EntityNode(
|
||||
uuid=node["uuid"],
|
||||
name=node["name"],
|
||||
|
|
@ -278,12 +278,12 @@ class ZepEntityReader:
|
|||
summary=node["summary"],
|
||||
attributes=node["attributes"],
|
||||
)
|
||||
|
||||
# 获取相关边和节点
|
||||
|
||||
# Fetch related edges and nodes
|
||||
if enrich_with_edges:
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
|
||||
for edge in all_edges:
|
||||
if edge["source_node_uuid"] == node["uuid"]:
|
||||
related_edges.append({
|
||||
|
|
@ -301,10 +301,10 @@ class ZepEntityReader:
|
|||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
|
||||
entity.related_edges = related_edges
|
||||
|
||||
# 获取关联节点的基本信息
|
||||
|
||||
# Fetch basic info for related nodes
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
|
|
@ -315,57 +315,57 @@ class ZepEntityReader:
|
|||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
|
||||
entity.related_nodes = related_nodes
|
||||
|
||||
|
||||
filtered_entities.append(entity)
|
||||
|
||||
logger.info(f"筛选完成: 总节点 {total_count}, 符合条件 {len(filtered_entities)}, "
|
||||
f"实体类型: {entity_types_found}")
|
||||
|
||||
|
||||
logger.info(f"Filtering complete: total nodes {total_count}, matching {len(filtered_entities)}, "
|
||||
f"entity types: {entity_types_found}")
|
||||
|
||||
return FilteredEntities(
|
||||
entities=filtered_entities,
|
||||
entity_types=entity_types_found,
|
||||
total_count=total_count,
|
||||
filtered_count=len(filtered_entities),
|
||||
)
|
||||
|
||||
|
||||
def get_entity_with_context(
|
||||
self,
|
||||
graph_id: str,
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_uuid: str
|
||||
) -> Optional[EntityNode]:
|
||||
"""
|
||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
||||
|
||||
Get a single entity and its full context (edges and related nodes, with retry)
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_uuid: 实体UUID
|
||||
|
||||
graph_id: graph ID
|
||||
entity_uuid: entity UUID
|
||||
|
||||
Returns:
|
||||
EntityNode或None
|
||||
EntityNode or None
|
||||
"""
|
||||
try:
|
||||
# 使用重试机制获取节点
|
||||
# Get the node with retry
|
||||
node = self._call_with_retry(
|
||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||
operation_name=f"get node detail (uuid={entity_uuid[:8]}...)"
|
||||
)
|
||||
|
||||
|
||||
if not node:
|
||||
return None
|
||||
|
||||
# 获取节点的边
|
||||
|
||||
# Get the node's edges
|
||||
edges = self.get_node_edges(entity_uuid)
|
||||
|
||||
# 获取所有节点用于关联查找
|
||||
|
||||
# Get all nodes for relation lookup
|
||||
all_nodes = self.get_all_nodes(graph_id)
|
||||
node_map = {n["uuid"]: n for n in all_nodes}
|
||||
|
||||
# 处理相关边和节点
|
||||
|
||||
# Process related edges and nodes
|
||||
related_edges = []
|
||||
related_node_uuids = set()
|
||||
|
||||
|
||||
for edge in edges:
|
||||
if edge["source_node_uuid"] == entity_uuid:
|
||||
related_edges.append({
|
||||
|
|
@ -383,8 +383,8 @@ class ZepEntityReader:
|
|||
"source_node_uuid": edge["source_node_uuid"],
|
||||
})
|
||||
related_node_uuids.add(edge["source_node_uuid"])
|
||||
|
||||
# 获取关联节点信息
|
||||
|
||||
# Fetch related node info
|
||||
related_nodes = []
|
||||
for related_uuid in related_node_uuids:
|
||||
if related_uuid in node_map:
|
||||
|
|
@ -395,7 +395,7 @@ class ZepEntityReader:
|
|||
"labels": related_node["labels"],
|
||||
"summary": related_node.get("summary", ""),
|
||||
})
|
||||
|
||||
|
||||
return EntityNode(
|
||||
uuid=getattr(node, 'uuid_', None) or getattr(node, 'uuid', ''),
|
||||
name=node.name or "",
|
||||
|
|
@ -405,27 +405,27 @@ class ZepEntityReader:
|
|||
related_edges=related_edges,
|
||||
related_nodes=related_nodes,
|
||||
)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取实体 {entity_uuid} 失败: {str(e)}")
|
||||
logger.error(f"Failed to get entity {entity_uuid}: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
def get_entities_by_type(
|
||||
self,
|
||||
graph_id: str,
|
||||
self,
|
||||
graph_id: str,
|
||||
entity_type: str,
|
||||
enrich_with_edges: bool = True
|
||||
) -> List[EntityNode]:
|
||||
"""
|
||||
获取指定类型的所有实体
|
||||
|
||||
Get all entities of a specified type
|
||||
|
||||
Args:
|
||||
graph_id: 图谱ID
|
||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
||||
enrich_with_edges: 是否获取相关边信息
|
||||
|
||||
graph_id: graph ID
|
||||
entity_type: entity type (e.g. "Student", "PublicFigure")
|
||||
enrich_with_edges: whether to fetch related edge info
|
||||
|
||||
Returns:
|
||||
实体列表
|
||||
Entity list
|
||||
"""
|
||||
result = self.filter_defined_entities(
|
||||
graph_id=graph_id,
|
||||
|
|
@ -433,5 +433,3 @@ class ZepEntityReader:
|
|||
enrich_with_edges=enrich_with_edges
|
||||
)
|
||||
return result.entities
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Zep图谱记忆更新服务
|
||||
将模拟中的Agent活动动态更新到Zep图谱中
|
||||
Zep graph memory update service
|
||||
Dynamically updates agent activities from the simulation to the Zep graph
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -23,7 +23,7 @@ logger = get_logger('mirofish.zep_graph_memory_updater')
|
|||
|
||||
@dataclass
|
||||
class AgentActivity:
|
||||
"""Agent活动记录"""
|
||||
"""Agent activity record"""
|
||||
platform: str # twitter / reddit
|
||||
agent_id: int
|
||||
agent_name: str
|
||||
|
|
@ -31,15 +31,16 @@ class AgentActivity:
|
|||
action_args: Dict[str, Any]
|
||||
round_num: int
|
||||
timestamp: str
|
||||
|
||||
|
||||
def to_episode_text(self) -> str:
|
||||
"""
|
||||
将活动转换为可以发送给Zep的文本描述
|
||||
|
||||
采用自然语言描述格式,让Zep能够从中提取实体和关系
|
||||
不添加模拟相关的前缀,避免误导图谱更新
|
||||
Convert the activity to a text description suitable for sending to Zep
|
||||
|
||||
Uses a natural-language description format so Zep can extract entities and
|
||||
relationships. No simulation-specific prefix is added to avoid misleading
|
||||
graph updates.
|
||||
"""
|
||||
# 根据不同的动作类型生成不同的描述
|
||||
# Generate a description based on the action type
|
||||
action_descriptions = {
|
||||
"CREATE_POST": self._describe_create_post,
|
||||
"LIKE_POST": self._describe_like_post,
|
||||
|
|
@ -54,226 +55,227 @@ class AgentActivity:
|
|||
"SEARCH_USER": self._describe_search_user,
|
||||
"MUTE": self._describe_mute,
|
||||
}
|
||||
|
||||
|
||||
describe_func = action_descriptions.get(self.action_type, self._describe_generic)
|
||||
description = describe_func()
|
||||
|
||||
# 直接返回 "agent名称: 活动描述" 格式,不添加模拟前缀
|
||||
|
||||
# Return "agent_name: activity description" format without a simulation prefix
|
||||
return f"{self.agent_name}: {description}"
|
||||
|
||||
|
||||
def _describe_create_post(self) -> str:
|
||||
content = self.action_args.get("content", "")
|
||||
if content:
|
||||
return f"发布了一条帖子:「{content}」"
|
||||
return "发布了一条帖子"
|
||||
|
||||
return f'posted: "{content}"'
|
||||
return "created a post"
|
||||
|
||||
def _describe_like_post(self) -> str:
|
||||
"""点赞帖子 - 包含帖子原文和作者信息"""
|
||||
"""Like a post — includes post content and author info"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if post_content and post_author:
|
||||
return f"点赞了{post_author}的帖子:「{post_content}」"
|
||||
return f'liked {post_author}\'s post: "{post_content}"'
|
||||
elif post_content:
|
||||
return f"点赞了一条帖子:「{post_content}」"
|
||||
return f'liked a post: "{post_content}"'
|
||||
elif post_author:
|
||||
return f"点赞了{post_author}的一条帖子"
|
||||
return "点赞了一条帖子"
|
||||
|
||||
return f"liked a post by {post_author}"
|
||||
return "liked a post"
|
||||
|
||||
def _describe_dislike_post(self) -> str:
|
||||
"""踩帖子 - 包含帖子原文和作者信息"""
|
||||
"""Dislike a post — includes post content and author info"""
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if post_content and post_author:
|
||||
return f"踩了{post_author}的帖子:「{post_content}」"
|
||||
return f'disliked {post_author}\'s post: "{post_content}"'
|
||||
elif post_content:
|
||||
return f"踩了一条帖子:「{post_content}」"
|
||||
return f'disliked a post: "{post_content}"'
|
||||
elif post_author:
|
||||
return f"踩了{post_author}的一条帖子"
|
||||
return "踩了一条帖子"
|
||||
|
||||
return f"disliked a post by {post_author}"
|
||||
return "disliked a post"
|
||||
|
||||
def _describe_repost(self) -> str:
|
||||
"""转发帖子 - 包含原帖内容和作者信息"""
|
||||
"""Repost — includes original post content and author info"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
|
||||
|
||||
if original_content and original_author:
|
||||
return f"转发了{original_author}的帖子:「{original_content}」"
|
||||
return f'reposted {original_author}\'s post: "{original_content}"'
|
||||
elif original_content:
|
||||
return f"转发了一条帖子:「{original_content}」"
|
||||
return f'reposted: "{original_content}"'
|
||||
elif original_author:
|
||||
return f"转发了{original_author}的一条帖子"
|
||||
return "转发了一条帖子"
|
||||
|
||||
return f"reposted a post by {original_author}"
|
||||
return "reposted a post"
|
||||
|
||||
def _describe_quote_post(self) -> str:
|
||||
"""引用帖子 - 包含原帖内容、作者信息和引用评论"""
|
||||
"""Quote post — includes original post content, author info, and quote comment"""
|
||||
original_content = self.action_args.get("original_content", "")
|
||||
original_author = self.action_args.get("original_author_name", "")
|
||||
quote_content = self.action_args.get("quote_content", "") or self.action_args.get("content", "")
|
||||
|
||||
|
||||
base = ""
|
||||
if original_content and original_author:
|
||||
base = f"引用了{original_author}的帖子「{original_content}」"
|
||||
base = f'quoted {original_author}\'s post "{original_content}"'
|
||||
elif original_content:
|
||||
base = f"引用了一条帖子「{original_content}」"
|
||||
base = f'quoted a post: "{original_content}"'
|
||||
elif original_author:
|
||||
base = f"引用了{original_author}的一条帖子"
|
||||
base = f"quoted a post by {original_author}"
|
||||
else:
|
||||
base = "引用了一条帖子"
|
||||
|
||||
base = "quoted a post"
|
||||
|
||||
if quote_content:
|
||||
base += f",并评论道:「{quote_content}」"
|
||||
base += f' with comment: "{quote_content}"'
|
||||
return base
|
||||
|
||||
|
||||
def _describe_follow(self) -> str:
|
||||
"""关注用户 - 包含被关注用户的名称"""
|
||||
"""Follow a user — includes the followed user's name"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
|
||||
if target_user_name:
|
||||
return f"关注了用户「{target_user_name}」"
|
||||
return "关注了一个用户"
|
||||
|
||||
return f'followed user "{target_user_name}"'
|
||||
return "followed a user"
|
||||
|
||||
def _describe_create_comment(self) -> str:
|
||||
"""发表评论 - 包含评论内容和所评论的帖子信息"""
|
||||
"""Create a comment — includes comment content and the post being commented on"""
|
||||
content = self.action_args.get("content", "")
|
||||
post_content = self.action_args.get("post_content", "")
|
||||
post_author = self.action_args.get("post_author_name", "")
|
||||
|
||||
|
||||
if content:
|
||||
if post_content and post_author:
|
||||
return f"在{post_author}的帖子「{post_content}」下评论道:「{content}」"
|
||||
return f'commented on {post_author}\'s post "{post_content}": "{content}"'
|
||||
elif post_content:
|
||||
return f"在帖子「{post_content}」下评论道:「{content}」"
|
||||
return f'commented on post "{post_content}": "{content}"'
|
||||
elif post_author:
|
||||
return f"在{post_author}的帖子下评论道:「{content}」"
|
||||
return f"评论道:「{content}」"
|
||||
return "发表了评论"
|
||||
|
||||
return f'commented on {post_author}\'s post: "{content}"'
|
||||
return f'commented: "{content}"'
|
||||
return "posted a comment"
|
||||
|
||||
def _describe_like_comment(self) -> str:
|
||||
"""点赞评论 - 包含评论内容和作者信息"""
|
||||
"""Like a comment — includes comment content and author info"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"点赞了{comment_author}的评论:「{comment_content}」"
|
||||
return f'liked {comment_author}\'s comment: "{comment_content}"'
|
||||
elif comment_content:
|
||||
return f"点赞了一条评论:「{comment_content}」"
|
||||
return f'liked a comment: "{comment_content}"'
|
||||
elif comment_author:
|
||||
return f"点赞了{comment_author}的一条评论"
|
||||
return "点赞了一条评论"
|
||||
|
||||
return f"liked a comment by {comment_author}"
|
||||
return "liked a comment"
|
||||
|
||||
def _describe_dislike_comment(self) -> str:
|
||||
"""踩评论 - 包含评论内容和作者信息"""
|
||||
"""Dislike a comment — includes comment content and author info"""
|
||||
comment_content = self.action_args.get("comment_content", "")
|
||||
comment_author = self.action_args.get("comment_author_name", "")
|
||||
|
||||
|
||||
if comment_content and comment_author:
|
||||
return f"踩了{comment_author}的评论:「{comment_content}」"
|
||||
return f'disliked {comment_author}\'s comment: "{comment_content}"'
|
||||
elif comment_content:
|
||||
return f"踩了一条评论:「{comment_content}」"
|
||||
return f'disliked a comment: "{comment_content}"'
|
||||
elif comment_author:
|
||||
return f"踩了{comment_author}的一条评论"
|
||||
return "踩了一条评论"
|
||||
|
||||
return f"disliked a comment by {comment_author}"
|
||||
return "disliked a comment"
|
||||
|
||||
def _describe_search(self) -> str:
|
||||
"""搜索帖子 - 包含搜索关键词"""
|
||||
"""Search posts — includes search keyword"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("keyword", "")
|
||||
return f"搜索了「{query}」" if query else "进行了搜索"
|
||||
|
||||
return f'searched for "{query}"' if query else "performed a search"
|
||||
|
||||
def _describe_search_user(self) -> str:
|
||||
"""搜索用户 - 包含搜索关键词"""
|
||||
"""Search users — includes search keyword"""
|
||||
query = self.action_args.get("query", "") or self.action_args.get("username", "")
|
||||
return f"搜索了用户「{query}」" if query else "搜索了用户"
|
||||
|
||||
return f'searched for user "{query}"' if query else "searched for a user"
|
||||
|
||||
def _describe_mute(self) -> str:
|
||||
"""屏蔽用户 - 包含被屏蔽用户的名称"""
|
||||
"""Mute a user — includes the muted user's name"""
|
||||
target_user_name = self.action_args.get("target_user_name", "")
|
||||
|
||||
|
||||
if target_user_name:
|
||||
return f"屏蔽了用户「{target_user_name}」"
|
||||
return "屏蔽了一个用户"
|
||||
|
||||
return f'muted user "{target_user_name}"'
|
||||
return "muted a user"
|
||||
|
||||
def _describe_generic(self) -> str:
|
||||
# 对于未知的动作类型,生成通用描述
|
||||
return f"执行了{self.action_type}操作"
|
||||
# Generic description for unknown action types
|
||||
return f"performed action: {self.action_type}"
|
||||
|
||||
|
||||
class ZepGraphMemoryUpdater:
|
||||
"""
|
||||
Zep图谱记忆更新器
|
||||
|
||||
监控模拟的actions日志文件,将新的agent活动实时更新到Zep图谱中。
|
||||
按平台分组,每累积BATCH_SIZE条活动后批量发送到Zep。
|
||||
|
||||
所有有意义的行为都会被更新到Zep,action_args中会包含完整的上下文信息:
|
||||
- 点赞/踩的帖子原文
|
||||
- 转发/引用的帖子原文
|
||||
- 关注/屏蔽的用户名
|
||||
- 点赞/踩的评论原文
|
||||
Zep graph memory updater
|
||||
|
||||
Monitors the simulation's actions log file and updates new agent activities
|
||||
to the Zep graph in real time. Activities are grouped by platform; each platform
|
||||
batches up to BATCH_SIZE activities before sending them to Zep.
|
||||
|
||||
All meaningful actions are updated to Zep. action_args contains full context:
|
||||
- Original post content for likes/dislikes
|
||||
- Original post content for reposts/quotes
|
||||
- Usernames for follows/mutes
|
||||
- Original comment content for comment likes/dislikes
|
||||
"""
|
||||
|
||||
# 批量发送大小(每个平台累积多少条后发送)
|
||||
|
||||
# Batch send size (activities per platform before sending)
|
||||
BATCH_SIZE = 5
|
||||
|
||||
# 平台名称映射(用于控制台显示)
|
||||
|
||||
# Platform display names
|
||||
PLATFORM_DISPLAY_NAMES = {
|
||||
'twitter': '世界1',
|
||||
'reddit': '世界2',
|
||||
'twitter': 'World 1',
|
||||
'reddit': 'World 2',
|
||||
}
|
||||
|
||||
# 发送间隔(秒),避免请求过快
|
||||
|
||||
# Send interval (seconds) to avoid sending too fast
|
||||
SEND_INTERVAL = 0.5
|
||||
|
||||
# 重试配置
|
||||
|
||||
# Retry config
|
||||
MAX_RETRIES = 3
|
||||
RETRY_DELAY = 2 # 秒
|
||||
|
||||
RETRY_DELAY = 2 # seconds
|
||||
|
||||
def __init__(self, graph_id: str, api_key: Optional[str] = None):
|
||||
"""
|
||||
初始化更新器
|
||||
|
||||
Initialize the updater
|
||||
|
||||
Args:
|
||||
graph_id: Zep图谱ID
|
||||
api_key: Zep API Key(可选,默认从配置读取)
|
||||
graph_id: Zep graph ID
|
||||
api_key: Zep API key (optional; defaults to config value)
|
||||
"""
|
||||
self.graph_id = graph_id
|
||||
self.api_key = api_key or Config.ZEP_API_KEY
|
||||
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("ZEP_API_KEY未配置")
|
||||
|
||||
raise ValueError("ZEP_API_KEY is not configured")
|
||||
|
||||
self.client = Zep(api_key=self.api_key)
|
||||
|
||||
# 活动队列
|
||||
|
||||
# Activity queue
|
||||
self._activity_queue: Queue = Queue()
|
||||
|
||||
# 按平台分组的活动缓冲区(每个平台各自累积到BATCH_SIZE后批量发送)
|
||||
|
||||
# Per-platform activity buffers (each platform accumulates to BATCH_SIZE before batch sending)
|
||||
self._platform_buffers: Dict[str, List[AgentActivity]] = {
|
||||
'twitter': [],
|
||||
'reddit': [],
|
||||
}
|
||||
self._buffer_lock = threading.Lock()
|
||||
|
||||
# 控制标志
|
||||
|
||||
# Control flags
|
||||
self._running = False
|
||||
self._worker_thread: Optional[threading.Thread] = None
|
||||
|
||||
# 统计
|
||||
self._total_activities = 0 # 实际添加到队列的活动数
|
||||
self._total_sent = 0 # 成功发送到Zep的批次数
|
||||
self._total_items_sent = 0 # 成功发送到Zep的活动条数
|
||||
self._failed_count = 0 # 发送失败的批次数
|
||||
self._skipped_count = 0 # 被过滤跳过的活动数(DO_NOTHING)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater 初始化完成: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||
|
||||
|
||||
# Statistics
|
||||
self._total_activities = 0 # Activities added to queue
|
||||
self._total_sent = 0 # Batches successfully sent to Zep
|
||||
self._total_items_sent = 0 # Individual activities successfully sent to Zep
|
||||
self._failed_count = 0 # Batches that failed to send
|
||||
self._skipped_count = 0 # Activities filtered out (DO_NOTHING)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater initialized: graph_id={graph_id}, batch_size={self.BATCH_SIZE}")
|
||||
|
||||
def _get_platform_display_name(self, platform: str) -> str:
|
||||
"""获取平台的显示名称"""
|
||||
"""Get the display name for a platform"""
|
||||
return self.PLATFORM_DISPLAY_NAMES.get(platform.lower(), platform)
|
||||
|
||||
|
||||
def start(self):
|
||||
"""启动后台工作线程"""
|
||||
"""Start the background worker thread"""
|
||||
if self._running:
|
||||
return
|
||||
|
||||
|
|
@ -288,67 +290,67 @@ class ZepGraphMemoryUpdater:
|
|||
name=f"ZepMemoryUpdater-{self.graph_id[:8]}"
|
||||
)
|
||||
self._worker_thread.start()
|
||||
logger.info(f"ZepGraphMemoryUpdater 已启动: graph_id={self.graph_id}")
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater started: graph_id={self.graph_id}")
|
||||
|
||||
def stop(self):
|
||||
"""停止后台工作线程"""
|
||||
"""Stop the background worker thread"""
|
||||
self._running = False
|
||||
|
||||
# 发送剩余的活动
|
||||
|
||||
# Send remaining activities
|
||||
self._flush_remaining()
|
||||
|
||||
|
||||
if self._worker_thread and self._worker_thread.is_alive():
|
||||
self._worker_thread.join(timeout=10)
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater 已停止: graph_id={self.graph_id}, "
|
||||
|
||||
logger.info(f"ZepGraphMemoryUpdater stopped: graph_id={self.graph_id}, "
|
||||
f"total_activities={self._total_activities}, "
|
||||
f"batches_sent={self._total_sent}, "
|
||||
f"items_sent={self._total_items_sent}, "
|
||||
f"failed={self._failed_count}, "
|
||||
f"skipped={self._skipped_count}")
|
||||
|
||||
|
||||
def add_activity(self, activity: AgentActivity):
|
||||
"""
|
||||
添加一个agent活动到队列
|
||||
|
||||
所有有意义的行为都会被添加到队列,包括:
|
||||
- CREATE_POST(发帖)
|
||||
- CREATE_COMMENT(评论)
|
||||
- QUOTE_POST(引用帖子)
|
||||
- SEARCH_POSTS(搜索帖子)
|
||||
- SEARCH_USER(搜索用户)
|
||||
- LIKE_POST/DISLIKE_POST(点赞/踩帖子)
|
||||
- REPOST(转发)
|
||||
- FOLLOW(关注)
|
||||
- MUTE(屏蔽)
|
||||
- LIKE_COMMENT/DISLIKE_COMMENT(点赞/踩评论)
|
||||
|
||||
action_args中会包含完整的上下文信息(如帖子原文、用户名等)。
|
||||
|
||||
Add an agent activity to the queue
|
||||
|
||||
All meaningful actions are added to the queue, including:
|
||||
- CREATE_POST
|
||||
- CREATE_COMMENT
|
||||
- QUOTE_POST
|
||||
- SEARCH_POSTS
|
||||
- SEARCH_USER
|
||||
- LIKE_POST/DISLIKE_POST
|
||||
- REPOST
|
||||
- FOLLOW
|
||||
- MUTE
|
||||
- LIKE_COMMENT/DISLIKE_COMMENT
|
||||
|
||||
action_args contains full context (e.g. post content, usernames, etc.).
|
||||
|
||||
Args:
|
||||
activity: Agent活动记录
|
||||
activity: agent activity record
|
||||
"""
|
||||
# 跳过DO_NOTHING类型的活动
|
||||
# Skip DO_NOTHING activities
|
||||
if activity.action_type == "DO_NOTHING":
|
||||
self._skipped_count += 1
|
||||
return
|
||||
|
||||
|
||||
self._activity_queue.put(activity)
|
||||
self._total_activities += 1
|
||||
logger.debug(f"添加活动到Zep队列: {activity.agent_name} - {activity.action_type}")
|
||||
|
||||
logger.debug(f"Added activity to Zep queue: {activity.agent_name} - {activity.action_type}")
|
||||
|
||||
def add_activity_from_dict(self, data: Dict[str, Any], platform: str):
|
||||
"""
|
||||
从字典数据添加活动
|
||||
|
||||
Add an activity from a dictionary
|
||||
|
||||
Args:
|
||||
data: 从actions.jsonl解析的字典数据
|
||||
platform: 平台名称 (twitter/reddit)
|
||||
data: dict parsed from actions.jsonl
|
||||
platform: platform name (twitter/reddit)
|
||||
"""
|
||||
# 跳过事件类型的条目
|
||||
# Skip event-type entries
|
||||
if "event_type" in data:
|
||||
return
|
||||
|
||||
|
||||
activity = AgentActivity(
|
||||
platform=platform,
|
||||
agent_id=data.get("agent_id", 0),
|
||||
|
|
@ -358,57 +360,57 @@ class ZepGraphMemoryUpdater:
|
|||
round_num=data.get("round", 0),
|
||||
timestamp=data.get("timestamp", datetime.now().isoformat()),
|
||||
)
|
||||
|
||||
|
||||
self.add_activity(activity)
|
||||
|
||||
|
||||
def _worker_loop(self, locale: str = 'zh'):
|
||||
"""后台工作循环 - 按平台批量发送活动到Zep"""
|
||||
"""Background worker loop — batch-sends activities to Zep per platform"""
|
||||
set_locale(locale)
|
||||
while self._running or not self._activity_queue.empty():
|
||||
try:
|
||||
# 尝试从队列获取活动(超时1秒)
|
||||
# Try to get an activity from the queue (1 second timeout)
|
||||
try:
|
||||
activity = self._activity_queue.get(timeout=1)
|
||||
|
||||
# 将活动添加到对应平台的缓冲区
|
||||
|
||||
# Add activity to the corresponding platform buffer
|
||||
platform = activity.platform.lower()
|
||||
with self._buffer_lock:
|
||||
if platform not in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
self._platform_buffers[platform].append(activity)
|
||||
|
||||
# 检查该平台是否达到批量大小
|
||||
|
||||
# Check if this platform has reached the batch size
|
||||
if len(self._platform_buffers[platform]) >= self.BATCH_SIZE:
|
||||
batch = self._platform_buffers[platform][:self.BATCH_SIZE]
|
||||
self._platform_buffers[platform] = self._platform_buffers[platform][self.BATCH_SIZE:]
|
||||
# 释放锁后再发送
|
||||
# Release lock before sending
|
||||
self._send_batch_activities(batch, platform)
|
||||
# 发送间隔,避免请求过快
|
||||
# Throttle to avoid sending too fast
|
||||
time.sleep(self.SEND_INTERVAL)
|
||||
|
||||
|
||||
except Empty:
|
||||
pass
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"工作循环异常: {e}")
|
||||
logger.error(f"Worker loop exception: {e}")
|
||||
time.sleep(1)
|
||||
|
||||
|
||||
def _send_batch_activities(self, activities: List[AgentActivity], platform: str):
|
||||
"""
|
||||
批量发送活动到Zep图谱(合并为一条文本)
|
||||
|
||||
Batch-send activities to the Zep graph (merged into a single text block)
|
||||
|
||||
Args:
|
||||
activities: Agent活动列表
|
||||
platform: 平台名称
|
||||
activities: list of agent activities
|
||||
platform: platform name
|
||||
"""
|
||||
if not activities:
|
||||
return
|
||||
|
||||
# 将多条活动合并为一条文本,用换行分隔
|
||||
|
||||
# Merge multiple activities into a single text, separated by newlines
|
||||
episode_texts = [activity.to_episode_text() for activity in activities]
|
||||
combined_text = "\n".join(episode_texts)
|
||||
|
||||
# 带重试的发送
|
||||
|
||||
# Send with retry
|
||||
for attempt in range(self.MAX_RETRIES):
|
||||
try:
|
||||
self.client.graph.add(
|
||||
|
|
@ -416,25 +418,25 @@ class ZepGraphMemoryUpdater:
|
|||
type="text",
|
||||
data=combined_text
|
||||
)
|
||||
|
||||
|
||||
self._total_sent += 1
|
||||
self._total_items_sent += len(activities)
|
||||
display_name = self._get_platform_display_name(platform)
|
||||
logger.info(f"成功批量发送 {len(activities)} 条{display_name}活动到图谱 {self.graph_id}")
|
||||
logger.debug(f"批量内容预览: {combined_text[:200]}...")
|
||||
logger.info(f"Successfully sent batch of {len(activities)} {display_name} activities to graph {self.graph_id}")
|
||||
logger.debug(f"Batch content preview: {combined_text[:200]}...")
|
||||
return
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if attempt < self.MAX_RETRIES - 1:
|
||||
logger.warning(f"批量发送到Zep失败 (尝试 {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
logger.warning(f"Batch send to Zep failed (attempt {attempt + 1}/{self.MAX_RETRIES}): {e}")
|
||||
time.sleep(self.RETRY_DELAY * (attempt + 1))
|
||||
else:
|
||||
logger.error(f"批量发送到Zep失败,已重试{self.MAX_RETRIES}次: {e}")
|
||||
logger.error(f"Batch send to Zep failed after {self.MAX_RETRIES} attempts: {e}")
|
||||
self._failed_count += 1
|
||||
|
||||
|
||||
def _flush_remaining(self):
|
||||
"""发送队列和缓冲区中剩余的活动"""
|
||||
# 首先处理队列中剩余的活动,添加到缓冲区
|
||||
"""Send remaining activities from the queue and buffers"""
|
||||
# First, drain the queue into the buffers
|
||||
while not self._activity_queue.empty():
|
||||
try:
|
||||
activity = self._activity_queue.get_nowait()
|
||||
|
|
@ -445,110 +447,110 @@ class ZepGraphMemoryUpdater:
|
|||
self._platform_buffers[platform].append(activity)
|
||||
except Empty:
|
||||
break
|
||||
|
||||
# 然后发送各平台缓冲区中剩余的活动(即使不足BATCH_SIZE条)
|
||||
|
||||
# Then send remaining activities in each platform buffer (even if below BATCH_SIZE)
|
||||
with self._buffer_lock:
|
||||
for platform, buffer in self._platform_buffers.items():
|
||||
if buffer:
|
||||
display_name = self._get_platform_display_name(platform)
|
||||
logger.info(f"发送{display_name}平台剩余的 {len(buffer)} 条活动")
|
||||
logger.info(f"Sending {len(buffer)} remaining {display_name} platform activities")
|
||||
self._send_batch_activities(buffer, platform)
|
||||
# 清空所有缓冲区
|
||||
# Clear all buffers
|
||||
for platform in self._platform_buffers:
|
||||
self._platform_buffers[platform] = []
|
||||
|
||||
|
||||
def get_stats(self) -> Dict[str, Any]:
|
||||
"""获取统计信息"""
|
||||
"""Get statistics"""
|
||||
with self._buffer_lock:
|
||||
buffer_sizes = {p: len(b) for p, b in self._platform_buffers.items()}
|
||||
|
||||
|
||||
return {
|
||||
"graph_id": self.graph_id,
|
||||
"batch_size": self.BATCH_SIZE,
|
||||
"total_activities": self._total_activities, # 添加到队列的活动总数
|
||||
"batches_sent": self._total_sent, # 成功发送的批次数
|
||||
"items_sent": self._total_items_sent, # 成功发送的活动条数
|
||||
"failed_count": self._failed_count, # 发送失败的批次数
|
||||
"skipped_count": self._skipped_count, # 被过滤跳过的活动数(DO_NOTHING)
|
||||
"total_activities": self._total_activities, # Total activities added to queue
|
||||
"batches_sent": self._total_sent, # Batches successfully sent
|
||||
"items_sent": self._total_items_sent, # Individual activities successfully sent
|
||||
"failed_count": self._failed_count, # Batches that failed to send
|
||||
"skipped_count": self._skipped_count, # Activities filtered out (DO_NOTHING)
|
||||
"queue_size": self._activity_queue.qsize(),
|
||||
"buffer_sizes": buffer_sizes, # 各平台缓冲区大小
|
||||
"buffer_sizes": buffer_sizes, # Per-platform buffer sizes
|
||||
"running": self._running,
|
||||
}
|
||||
|
||||
|
||||
class ZepGraphMemoryManager:
|
||||
"""
|
||||
管理多个模拟的Zep图谱记忆更新器
|
||||
|
||||
每个模拟可以有自己的更新器实例
|
||||
Manages Zep graph memory updaters for multiple simulations
|
||||
|
||||
Each simulation can have its own updater instance
|
||||
"""
|
||||
|
||||
|
||||
_updaters: Dict[str, ZepGraphMemoryUpdater] = {}
|
||||
_lock = threading.Lock()
|
||||
|
||||
|
||||
@classmethod
|
||||
def create_updater(cls, simulation_id: str, graph_id: str) -> ZepGraphMemoryUpdater:
|
||||
"""
|
||||
为模拟创建图谱记忆更新器
|
||||
|
||||
Create a graph memory updater for a simulation
|
||||
|
||||
Args:
|
||||
simulation_id: 模拟ID
|
||||
graph_id: Zep图谱ID
|
||||
|
||||
simulation_id: simulation ID
|
||||
graph_id: Zep graph ID
|
||||
|
||||
Returns:
|
||||
ZepGraphMemoryUpdater实例
|
||||
ZepGraphMemoryUpdater instance
|
||||
"""
|
||||
with cls._lock:
|
||||
# 如果已存在,先停止旧的
|
||||
# If one already exists, stop it first
|
||||
if simulation_id in cls._updaters:
|
||||
cls._updaters[simulation_id].stop()
|
||||
|
||||
|
||||
updater = ZepGraphMemoryUpdater(graph_id)
|
||||
updater.start()
|
||||
cls._updaters[simulation_id] = updater
|
||||
|
||||
logger.info(f"创建图谱记忆更新器: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||
|
||||
logger.info(f"Created graph memory updater: simulation_id={simulation_id}, graph_id={graph_id}")
|
||||
return updater
|
||||
|
||||
|
||||
@classmethod
|
||||
def get_updater(cls, simulation_id: str) -> Optional[ZepGraphMemoryUpdater]:
|
||||
"""获取模拟的更新器"""
|
||||
"""Get the updater for a simulation"""
|
||||
return cls._updaters.get(simulation_id)
|
||||
|
||||
|
||||
@classmethod
|
||||
def stop_updater(cls, simulation_id: str):
|
||||
"""停止并移除模拟的更新器"""
|
||||
"""Stop and remove the updater for a simulation"""
|
||||
with cls._lock:
|
||||
if simulation_id in cls._updaters:
|
||||
cls._updaters[simulation_id].stop()
|
||||
del cls._updaters[simulation_id]
|
||||
logger.info(f"已停止图谱记忆更新器: simulation_id={simulation_id}")
|
||||
|
||||
# 防止 stop_all 重复调用的标志
|
||||
logger.info(f"Stopped graph memory updater: simulation_id={simulation_id}")
|
||||
|
||||
# Flag to prevent stop_all from being called more than once
|
||||
_stop_all_done = False
|
||||
|
||||
|
||||
@classmethod
|
||||
def stop_all(cls):
|
||||
"""停止所有更新器"""
|
||||
# 防止重复调用
|
||||
"""Stop all updaters"""
|
||||
# Prevent duplicate calls
|
||||
if cls._stop_all_done:
|
||||
return
|
||||
cls._stop_all_done = True
|
||||
|
||||
|
||||
with cls._lock:
|
||||
if cls._updaters:
|
||||
for simulation_id, updater in list(cls._updaters.items()):
|
||||
try:
|
||||
updater.stop()
|
||||
except Exception as e:
|
||||
logger.error(f"停止更新器失败: simulation_id={simulation_id}, error={e}")
|
||||
logger.error(f"Failed to stop updater: simulation_id={simulation_id}, error={e}")
|
||||
cls._updaters.clear()
|
||||
logger.info("已停止所有图谱记忆更新器")
|
||||
|
||||
logger.info("All graph memory updaters stopped")
|
||||
|
||||
@classmethod
|
||||
def get_all_stats(cls) -> Dict[str, Dict[str, Any]]:
|
||||
"""获取所有更新器的统计信息"""
|
||||
"""Get statistics for all updaters"""
|
||||
return {
|
||||
sim_id: updater.get_stats()
|
||||
sim_id: updater.get_stats()
|
||||
for sim_id, updater in cls._updaters.items()
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,5 +1,5 @@
|
|||
"""
|
||||
工具模块
|
||||
Utilities module
|
||||
"""
|
||||
|
||||
from .file_parser import FileParser
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
文件解析工具
|
||||
支持PDF、Markdown、TXT文件的文本提取
|
||||
File parsing utilities
|
||||
Supports text extraction from PDF, Markdown, and TXT files
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -10,29 +10,29 @@ from typing import List, Optional
|
|||
|
||||
def _read_text_with_fallback(file_path: str) -> str:
|
||||
"""
|
||||
读取文本文件,UTF-8失败时自动探测编码。
|
||||
|
||||
采用多级回退策略:
|
||||
1. 首先尝试 UTF-8 解码
|
||||
2. 使用 charset_normalizer 检测编码
|
||||
3. 回退到 chardet 检测编码
|
||||
4. 最终使用 UTF-8 + errors='replace' 兜底
|
||||
|
||||
Read a text file, automatically detecting encoding if UTF-8 fails.
|
||||
|
||||
Uses a multi-level fallback strategy:
|
||||
1. First attempts UTF-8 decoding
|
||||
2. Uses charset_normalizer to detect encoding
|
||||
3. Falls back to chardet for encoding detection
|
||||
4. Final fallback: UTF-8 with errors='replace'
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
解码后的文本内容
|
||||
Decoded text content
|
||||
"""
|
||||
data = Path(file_path).read_bytes()
|
||||
|
||||
# 首先尝试 UTF-8
|
||||
|
||||
# First attempt: UTF-8
|
||||
try:
|
||||
return data.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
pass
|
||||
|
||||
# 尝试使用 charset_normalizer 检测编码
|
||||
|
||||
# Attempt encoding detection with charset_normalizer
|
||||
encoding = None
|
||||
try:
|
||||
from charset_normalizer import from_bytes
|
||||
|
|
@ -41,8 +41,8 @@ def _read_text_with_fallback(file_path: str) -> str:
|
|||
encoding = best.encoding
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 回退到 chardet
|
||||
|
||||
# Fall back to chardet
|
||||
if not encoding:
|
||||
try:
|
||||
import chardet
|
||||
|
|
@ -50,140 +50,139 @@ def _read_text_with_fallback(file_path: str) -> str:
|
|||
encoding = result.get('encoding') if result else None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 最终兜底:使用 UTF-8 + replace
|
||||
|
||||
# Final fallback: UTF-8 with replace
|
||||
if not encoding:
|
||||
encoding = 'utf-8'
|
||||
|
||||
|
||||
return data.decode(encoding, errors='replace')
|
||||
|
||||
|
||||
class FileParser:
|
||||
"""文件解析器"""
|
||||
|
||||
"""File parser"""
|
||||
|
||||
SUPPORTED_EXTENSIONS = {'.pdf', '.md', '.markdown', '.txt'}
|
||||
|
||||
|
||||
@classmethod
|
||||
def extract_text(cls, file_path: str) -> str:
|
||||
"""
|
||||
从文件中提取文本
|
||||
|
||||
Extract text from a file
|
||||
|
||||
Args:
|
||||
file_path: 文件路径
|
||||
|
||||
file_path: Path to the file
|
||||
|
||||
Returns:
|
||||
提取的文本内容
|
||||
Extracted text content
|
||||
"""
|
||||
path = Path(file_path)
|
||||
|
||||
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"文件不存在: {file_path}")
|
||||
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
suffix = path.suffix.lower()
|
||||
|
||||
|
||||
if suffix not in cls.SUPPORTED_EXTENSIONS:
|
||||
raise ValueError(f"不支持的文件格式: {suffix}")
|
||||
|
||||
raise ValueError(f"Unsupported file format: {suffix}")
|
||||
|
||||
if suffix == '.pdf':
|
||||
return cls._extract_from_pdf(file_path)
|
||||
elif suffix in {'.md', '.markdown'}:
|
||||
return cls._extract_from_md(file_path)
|
||||
elif suffix == '.txt':
|
||||
return cls._extract_from_txt(file_path)
|
||||
|
||||
raise ValueError(f"无法处理的文件格式: {suffix}")
|
||||
|
||||
|
||||
raise ValueError(f"Cannot process file format: {suffix}")
|
||||
|
||||
@staticmethod
|
||||
def _extract_from_pdf(file_path: str) -> str:
|
||||
"""从PDF提取文本"""
|
||||
"""Extract text from a PDF file"""
|
||||
try:
|
||||
import fitz # PyMuPDF
|
||||
except ImportError:
|
||||
raise ImportError("需要安装PyMuPDF: pip install PyMuPDF")
|
||||
|
||||
raise ImportError("PyMuPDF is required: pip install PyMuPDF")
|
||||
|
||||
text_parts = []
|
||||
with fitz.open(file_path) as doc:
|
||||
for page in doc:
|
||||
text = page.get_text()
|
||||
if text.strip():
|
||||
text_parts.append(text)
|
||||
|
||||
|
||||
return "\n\n".join(text_parts)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_from_md(file_path: str) -> str:
|
||||
"""从Markdown提取文本,支持自动编码检测"""
|
||||
"""Extract text from a Markdown file with automatic encoding detection"""
|
||||
return _read_text_with_fallback(file_path)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _extract_from_txt(file_path: str) -> str:
|
||||
"""从TXT提取文本,支持自动编码检测"""
|
||||
"""Extract text from a TXT file with automatic encoding detection"""
|
||||
return _read_text_with_fallback(file_path)
|
||||
|
||||
|
||||
@classmethod
|
||||
def extract_from_multiple(cls, file_paths: List[str]) -> str:
|
||||
"""
|
||||
从多个文件提取文本并合并
|
||||
|
||||
Extract text from multiple files and merge the results
|
||||
|
||||
Args:
|
||||
file_paths: 文件路径列表
|
||||
|
||||
file_paths: List of file paths
|
||||
|
||||
Returns:
|
||||
合并后的文本
|
||||
Merged text content
|
||||
"""
|
||||
all_texts = []
|
||||
|
||||
|
||||
for i, file_path in enumerate(file_paths, 1):
|
||||
try:
|
||||
text = cls.extract_text(file_path)
|
||||
filename = Path(file_path).name
|
||||
all_texts.append(f"=== 文档 {i}: {filename} ===\n{text}")
|
||||
all_texts.append(f"=== Document {i}: {filename} ===\n{text}")
|
||||
except Exception as e:
|
||||
all_texts.append(f"=== 文档 {i}: {file_path} (提取失败: {str(e)}) ===")
|
||||
|
||||
all_texts.append(f"=== Document {i}: {file_path} (extraction failed: {str(e)}) ===")
|
||||
|
||||
return "\n\n".join(all_texts)
|
||||
|
||||
|
||||
def split_text_into_chunks(
|
||||
text: str,
|
||||
chunk_size: int = 500,
|
||||
text: str,
|
||||
chunk_size: int = 500,
|
||||
overlap: int = 50
|
||||
) -> List[str]:
|
||||
"""
|
||||
将文本分割成小块
|
||||
|
||||
Split text into smaller chunks
|
||||
|
||||
Args:
|
||||
text: 原始文本
|
||||
chunk_size: 每块的字符数
|
||||
overlap: 重叠字符数
|
||||
|
||||
text: Source text
|
||||
chunk_size: Number of characters per chunk
|
||||
overlap: Number of overlapping characters between chunks
|
||||
|
||||
Returns:
|
||||
文本块列表
|
||||
List of text chunks
|
||||
"""
|
||||
if len(text) <= chunk_size:
|
||||
return [text] if text.strip() else []
|
||||
|
||||
|
||||
chunks = []
|
||||
start = 0
|
||||
|
||||
|
||||
while start < len(text):
|
||||
end = start + chunk_size
|
||||
|
||||
# 尝试在句子边界处分割
|
||||
|
||||
# Try to split at sentence boundaries
|
||||
if end < len(text):
|
||||
# 查找最近的句子结束符
|
||||
# Find the nearest sentence-ending separator
|
||||
for sep in ['。', '!', '?', '.\n', '!\n', '?\n', '\n\n', '. ', '! ', '? ']:
|
||||
last_sep = text[start:end].rfind(sep)
|
||||
if last_sep != -1 and last_sep > chunk_size * 0.3:
|
||||
end = start + last_sep + len(sep)
|
||||
break
|
||||
|
||||
|
||||
chunk = text[start:end].strip()
|
||||
if chunk:
|
||||
chunks.append(chunk)
|
||||
|
||||
# 下一个块从重叠位置开始
|
||||
start = end - overlap if end < len(text) else len(text)
|
||||
|
||||
return chunks
|
||||
|
||||
# Next chunk starts at the overlap position
|
||||
start = end - overlap if end < len(text) else len(text)
|
||||
|
||||
return chunks
|
||||
|
|
|
|||
|
|
@ -1,19 +1,20 @@
|
|||
"""
|
||||
LLM客户端封装
|
||||
统一使用OpenAI格式调用
|
||||
LLM client wrapper
|
||||
Unified interface using the OpenAI-compatible API format
|
||||
"""
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any, List
|
||||
from urllib.parse import urlparse, parse_qs, urlunparse
|
||||
from openai import OpenAI
|
||||
|
||||
from ..config import Config
|
||||
|
||||
|
||||
class LLMClient:
|
||||
"""LLM客户端"""
|
||||
|
||||
"""LLM client"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: Optional[str] = None,
|
||||
|
|
@ -21,17 +22,32 @@ class LLMClient:
|
|||
model: Optional[str] = None
|
||||
):
|
||||
self.api_key = api_key or Config.LLM_API_KEY
|
||||
self.base_url = base_url or Config.LLM_BASE_URL
|
||||
raw_url = base_url or Config.LLM_BASE_URL
|
||||
self.model = model or Config.LLM_MODEL_NAME
|
||||
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("LLM_API_KEY 未配置")
|
||||
|
||||
raise ValueError("LLM_API_KEY is not configured")
|
||||
|
||||
# Azure Portal provides full endpoint URLs like:
|
||||
# https://<resource>.cognitiveservices.azure.com/openai/deployments/<model>/chat/completions?api-version=...
|
||||
# The OpenAI SDK expects a base_url and appends /chat/completions itself,
|
||||
# so we strip that suffix and extract api-version as a default query param.
|
||||
default_query: Dict[str, str] = {}
|
||||
if raw_url and '/chat/completions' in raw_url:
|
||||
parsed = urlparse(raw_url)
|
||||
qs = parse_qs(parsed.query)
|
||||
if 'api-version' in qs:
|
||||
default_query['api-version'] = qs['api-version'][0]
|
||||
clean_path = parsed.path.replace('/chat/completions', '').rstrip('/')
|
||||
raw_url = urlunparse(parsed._replace(path=clean_path, query=''))
|
||||
|
||||
self.base_url = raw_url
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url
|
||||
base_url=self.base_url,
|
||||
default_query=default_query if default_query else None
|
||||
)
|
||||
|
||||
|
||||
def chat(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
|
@ -40,33 +56,33 @@ class LLMClient:
|
|||
response_format: Optional[Dict] = None
|
||||
) -> str:
|
||||
"""
|
||||
发送聊天请求
|
||||
|
||||
Send a chat request
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
response_format: 响应格式(如JSON模式)
|
||||
|
||||
messages: List of messages
|
||||
temperature: Temperature parameter
|
||||
max_tokens: Maximum number of tokens
|
||||
response_format: Response format (e.g. JSON mode)
|
||||
|
||||
Returns:
|
||||
模型响应文本
|
||||
Model response text
|
||||
"""
|
||||
kwargs = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"max_completion_tokens": max_tokens,
|
||||
}
|
||||
|
||||
|
||||
if response_format:
|
||||
kwargs["response_format"] = response_format
|
||||
|
||||
|
||||
response = self.client.chat.completions.create(**kwargs)
|
||||
content = response.choices[0].message.content
|
||||
# 部分模型(如MiniMax M2.5)会在content中包含<think>思考内容,需要移除
|
||||
# Some models (e.g. MiniMax M2.5) include <think> reasoning content in the response; strip it out
|
||||
content = re.sub(r'<think>[\s\S]*?</think>', '', content).strip()
|
||||
return content
|
||||
|
||||
|
||||
def chat_json(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
|
|
@ -74,15 +90,15 @@ class LLMClient:
|
|||
max_tokens: int = 4096
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
发送聊天请求并返回JSON
|
||||
|
||||
Send a chat request and return parsed JSON
|
||||
|
||||
Args:
|
||||
messages: 消息列表
|
||||
temperature: 温度参数
|
||||
max_tokens: 最大token数
|
||||
|
||||
messages: List of messages
|
||||
temperature: Temperature parameter
|
||||
max_tokens: Maximum number of tokens
|
||||
|
||||
Returns:
|
||||
解析后的JSON对象
|
||||
Parsed JSON object
|
||||
"""
|
||||
response = self.chat(
|
||||
messages=messages,
|
||||
|
|
@ -90,7 +106,7 @@ class LLMClient:
|
|||
max_tokens=max_tokens,
|
||||
response_format={"type": "json_object"}
|
||||
)
|
||||
# 清理markdown代码块标记
|
||||
# Strip markdown code-block markers if present
|
||||
cleaned_response = response.strip()
|
||||
cleaned_response = re.sub(r'^```(?:json)?\s*\n?', '', cleaned_response, flags=re.IGNORECASE)
|
||||
cleaned_response = re.sub(r'\n?```\s*$', '', cleaned_response)
|
||||
|
|
@ -99,5 +115,4 @@ class LLMClient:
|
|||
try:
|
||||
return json.loads(cleaned_response)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"LLM返回的JSON格式无效: {cleaned_response}")
|
||||
|
||||
raise ValueError(f"Invalid JSON returned by LLM: {cleaned_response}")
|
||||
|
|
|
|||
|
|
@ -66,4 +66,4 @@ def t(key: str, **kwargs) -> str:
|
|||
def get_language_instruction() -> str:
|
||||
locale = get_locale()
|
||||
lang_config = _languages.get(locale, _languages.get('zh', {}))
|
||||
return lang_config.get('llmInstruction', '请使用中文回答。')
|
||||
return lang_config.get('llmInstruction', 'Please respond in Chinese.')
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
日志配置模块
|
||||
提供统一的日志管理,同时输出到控制台和文件
|
||||
Logging configuration module
|
||||
Provides unified log management, writing to both console and file
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -12,58 +12,58 @@ from logging.handlers import RotatingFileHandler
|
|||
|
||||
def _ensure_utf8_stdout():
|
||||
"""
|
||||
确保 stdout/stderr 使用 UTF-8 编码
|
||||
解决 Windows 控制台中文乱码问题
|
||||
Ensure stdout/stderr use UTF-8 encoding.
|
||||
Fixes garbled output in Windows consoles.
|
||||
"""
|
||||
if sys.platform == 'win32':
|
||||
# Windows 下重新配置标准输出为 UTF-8
|
||||
# Reconfigure standard streams to UTF-8 on Windows
|
||||
if hasattr(sys.stdout, 'reconfigure'):
|
||||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||||
if hasattr(sys.stderr, 'reconfigure'):
|
||||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||||
|
||||
|
||||
# 日志目录
|
||||
# Log directory
|
||||
LOG_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'logs')
|
||||
|
||||
|
||||
def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.Logger:
|
||||
"""
|
||||
设置日志器
|
||||
|
||||
Set up a logger
|
||||
|
||||
Args:
|
||||
name: 日志器名称
|
||||
level: 日志级别
|
||||
|
||||
name: Logger name
|
||||
level: Log level
|
||||
|
||||
Returns:
|
||||
配置好的日志器
|
||||
Configured logger instance
|
||||
"""
|
||||
# 确保日志目录存在
|
||||
# Ensure the log directory exists
|
||||
os.makedirs(LOG_DIR, exist_ok=True)
|
||||
|
||||
# 创建日志器
|
||||
|
||||
# Create logger
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(level)
|
||||
|
||||
# 阻止日志向上传播到根 logger,避免重复输出
|
||||
|
||||
# Prevent log records from propagating to the root logger to avoid duplicate output
|
||||
logger.propagate = False
|
||||
|
||||
# 如果已经有处理器,不重复添加
|
||||
|
||||
# Skip adding handlers if they already exist
|
||||
if logger.handlers:
|
||||
return logger
|
||||
|
||||
# 日志格式
|
||||
|
||||
# Log formatters
|
||||
detailed_formatter = logging.Formatter(
|
||||
'[%(asctime)s] %(levelname)s [%(name)s.%(funcName)s:%(lineno)d] %(message)s',
|
||||
datefmt='%Y-%m-%d %H:%M:%S'
|
||||
)
|
||||
|
||||
|
||||
simple_formatter = logging.Formatter(
|
||||
'[%(asctime)s] %(levelname)s: %(message)s',
|
||||
datefmt='%H:%M:%S'
|
||||
)
|
||||
|
||||
# 1. 文件处理器 - 详细日志(按日期命名,带轮转)
|
||||
|
||||
# 1. File handler — detailed logs (date-stamped filename with rotation)
|
||||
log_filename = datetime.now().strftime('%Y-%m-%d') + '.log'
|
||||
file_handler = RotatingFileHandler(
|
||||
os.path.join(LOG_DIR, log_filename),
|
||||
|
|
@ -73,30 +73,30 @@ def setup_logger(name: str = 'mirofish', level: int = logging.DEBUG) -> logging.
|
|||
)
|
||||
file_handler.setLevel(logging.DEBUG)
|
||||
file_handler.setFormatter(detailed_formatter)
|
||||
|
||||
# 2. 控制台处理器 - 简洁日志(INFO及以上)
|
||||
# 确保 Windows 下使用 UTF-8 编码,避免中文乱码
|
||||
|
||||
# 2. Console handler — concise logs (INFO and above)
|
||||
# Ensure UTF-8 encoding on Windows to avoid garbled output
|
||||
_ensure_utf8_stdout()
|
||||
console_handler = logging.StreamHandler(sys.stdout)
|
||||
console_handler.setLevel(logging.INFO)
|
||||
console_handler.setFormatter(simple_formatter)
|
||||
|
||||
# 添加处理器
|
||||
|
||||
# Register handlers
|
||||
logger.addHandler(file_handler)
|
||||
logger.addHandler(console_handler)
|
||||
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def get_logger(name: str = 'mirofish') -> logging.Logger:
|
||||
"""
|
||||
获取日志器(如果不存在则创建)
|
||||
|
||||
Get a logger, creating it if it does not exist
|
||||
|
||||
Args:
|
||||
name: 日志器名称
|
||||
|
||||
name: Logger name
|
||||
|
||||
Returns:
|
||||
日志器实例
|
||||
Logger instance
|
||||
"""
|
||||
logger = logging.getLogger(name)
|
||||
if not logger.handlers:
|
||||
|
|
@ -104,11 +104,11 @@ def get_logger(name: str = 'mirofish') -> logging.Logger:
|
|||
return logger
|
||||
|
||||
|
||||
# 创建默认日志器
|
||||
# Create default logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# 便捷方法
|
||||
# Convenience functions
|
||||
def debug(msg, *args, **kwargs):
|
||||
logger.debug(msg, *args, **kwargs)
|
||||
|
||||
|
|
@ -123,4 +123,3 @@ def error(msg, *args, **kwargs):
|
|||
|
||||
def critical(msg, *args, **kwargs):
|
||||
logger.critical(msg, *args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
API调用重试机制
|
||||
用于处理LLM等外部API调用的重试逻辑
|
||||
API call retry mechanism
|
||||
Handles retry logic for external API calls such as LLM services
|
||||
"""
|
||||
|
||||
import time
|
||||
|
|
@ -22,17 +22,17 @@ def retry_with_backoff(
|
|||
on_retry: Optional[Callable[[Exception, int], None]] = None
|
||||
):
|
||||
"""
|
||||
带指数退避的重试装饰器
|
||||
|
||||
Retry decorator with exponential backoff
|
||||
|
||||
Args:
|
||||
max_retries: 最大重试次数
|
||||
initial_delay: 初始延迟(秒)
|
||||
max_delay: 最大延迟(秒)
|
||||
backoff_factor: 退避因子
|
||||
jitter: 是否添加随机抖动
|
||||
exceptions: 需要重试的异常类型
|
||||
on_retry: 重试时的回调函数 (exception, retry_count)
|
||||
|
||||
max_retries: Maximum number of retries
|
||||
initial_delay: Initial delay in seconds
|
||||
max_delay: Maximum delay in seconds
|
||||
backoff_factor: Backoff multiplier
|
||||
jitter: Whether to add random jitter
|
||||
exceptions: Exception types that should trigger a retry
|
||||
on_retry: Callback invoked on each retry (exception, retry_count)
|
||||
|
||||
Usage:
|
||||
@retry_with_backoff(max_retries=3)
|
||||
def call_llm_api():
|
||||
|
|
@ -43,36 +43,36 @@ def retry_with_backoff(
|
|||
def wrapper(*args, **kwargs) -> Any:
|
||||
last_exception = None
|
||||
delay = initial_delay
|
||||
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
|
||||
if attempt == max_retries:
|
||||
logger.error(f"函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
|
||||
logger.error(f"Function {func.__name__} failed after {max_retries} retries: {str(e)}")
|
||||
raise
|
||||
|
||||
# 计算延迟
|
||||
|
||||
# Calculate delay
|
||||
current_delay = min(delay, max_delay)
|
||||
if jitter:
|
||||
current_delay = current_delay * (0.5 + random.random())
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
|
||||
f"{current_delay:.1f}秒后重试..."
|
||||
f"Function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
|
||||
f"retrying in {current_delay:.1f}s..."
|
||||
)
|
||||
|
||||
|
||||
if on_retry:
|
||||
on_retry(e, attempt + 1)
|
||||
|
||||
|
||||
time.sleep(current_delay)
|
||||
delay *= backoff_factor
|
||||
|
||||
|
||||
raise last_exception
|
||||
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
|
@ -87,53 +87,53 @@ def retry_with_backoff_async(
|
|||
on_retry: Optional[Callable[[Exception, int], None]] = None
|
||||
):
|
||||
"""
|
||||
异步版本的重试装饰器
|
||||
Async version of the retry decorator
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args, **kwargs) -> Any:
|
||||
last_exception = None
|
||||
delay = initial_delay
|
||||
|
||||
|
||||
for attempt in range(max_retries + 1):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
|
||||
if attempt == max_retries:
|
||||
logger.error(f"异步函数 {func.__name__} 在 {max_retries} 次重试后仍失败: {str(e)}")
|
||||
logger.error(f"Async function {func.__name__} failed after {max_retries} retries: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
current_delay = min(delay, max_delay)
|
||||
if jitter:
|
||||
current_delay = current_delay * (0.5 + random.random())
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"异步函数 {func.__name__} 第 {attempt + 1} 次尝试失败: {str(e)}, "
|
||||
f"{current_delay:.1f}秒后重试..."
|
||||
f"Async function {func.__name__} attempt {attempt + 1} failed: {str(e)}, "
|
||||
f"retrying in {current_delay:.1f}s..."
|
||||
)
|
||||
|
||||
|
||||
if on_retry:
|
||||
on_retry(e, attempt + 1)
|
||||
|
||||
|
||||
await asyncio.sleep(current_delay)
|
||||
delay *= backoff_factor
|
||||
|
||||
|
||||
raise last_exception
|
||||
|
||||
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
|
||||
class RetryableAPIClient:
|
||||
"""
|
||||
可重试的API客户端封装
|
||||
Retryable API client wrapper
|
||||
"""
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retries: int = 3,
|
||||
|
|
@ -145,7 +145,7 @@ class RetryableAPIClient:
|
|||
self.initial_delay = initial_delay
|
||||
self.max_delay = max_delay
|
||||
self.backoff_factor = backoff_factor
|
||||
|
||||
|
||||
def call_with_retry(
|
||||
self,
|
||||
func: Callable,
|
||||
|
|
@ -154,44 +154,44 @@ class RetryableAPIClient:
|
|||
**kwargs
|
||||
) -> Any:
|
||||
"""
|
||||
执行函数调用并在失败时重试
|
||||
|
||||
Execute a function call and retry on failure
|
||||
|
||||
Args:
|
||||
func: 要调用的函数
|
||||
*args: 函数参数
|
||||
exceptions: 需要重试的异常类型
|
||||
**kwargs: 函数关键字参数
|
||||
|
||||
func: Function to call
|
||||
*args: Positional arguments for the function
|
||||
exceptions: Exception types that should trigger a retry
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
函数返回值
|
||||
Return value of the function
|
||||
"""
|
||||
last_exception = None
|
||||
delay = self.initial_delay
|
||||
|
||||
|
||||
for attempt in range(self.max_retries + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
|
||||
|
||||
except exceptions as e:
|
||||
last_exception = e
|
||||
|
||||
|
||||
if attempt == self.max_retries:
|
||||
logger.error(f"API调用在 {self.max_retries} 次重试后仍失败: {str(e)}")
|
||||
logger.error(f"API call failed after {self.max_retries} retries: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
current_delay = min(delay, self.max_delay)
|
||||
current_delay = current_delay * (0.5 + random.random())
|
||||
|
||||
|
||||
logger.warning(
|
||||
f"API调用第 {attempt + 1} 次尝试失败: {str(e)}, "
|
||||
f"{current_delay:.1f}秒后重试..."
|
||||
f"API call attempt {attempt + 1} failed: {str(e)}, "
|
||||
f"retrying in {current_delay:.1f}s..."
|
||||
)
|
||||
|
||||
|
||||
time.sleep(current_delay)
|
||||
delay *= self.backoff_factor
|
||||
|
||||
|
||||
raise last_exception
|
||||
|
||||
|
||||
def call_batch_with_retry(
|
||||
self,
|
||||
items: list,
|
||||
|
|
@ -200,20 +200,20 @@ class RetryableAPIClient:
|
|||
continue_on_failure: bool = True
|
||||
) -> Tuple[list, list]:
|
||||
"""
|
||||
批量调用并对每个失败项单独重试
|
||||
|
||||
Process a batch of items, retrying individually on failure
|
||||
|
||||
Args:
|
||||
items: 要处理的项目列表
|
||||
process_func: 处理函数,接收单个item作为参数
|
||||
exceptions: 需要重试的异常类型
|
||||
continue_on_failure: 单项失败后是否继续处理其他项
|
||||
|
||||
items: List of items to process
|
||||
process_func: Processing function that accepts a single item
|
||||
exceptions: Exception types that should trigger a retry
|
||||
continue_on_failure: Whether to continue processing remaining items after a failure
|
||||
|
||||
Returns:
|
||||
(成功结果列表, 失败项列表)
|
||||
(list of successful results, list of failed items)
|
||||
"""
|
||||
results = []
|
||||
failures = []
|
||||
|
||||
|
||||
for idx, item in enumerate(items):
|
||||
try:
|
||||
result = self.call_with_retry(
|
||||
|
|
@ -222,17 +222,16 @@ class RetryableAPIClient:
|
|||
exceptions=exceptions
|
||||
)
|
||||
results.append(result)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"处理第 {idx + 1} 项失败: {str(e)}")
|
||||
logger.error(f"Failed to process item {idx + 1}: {str(e)}")
|
||||
failures.append({
|
||||
"index": idx,
|
||||
"item": item,
|
||||
"error": str(e)
|
||||
})
|
||||
|
||||
|
||||
if not continue_on_failure:
|
||||
raise
|
||||
|
||||
return results, failures
|
||||
|
||||
return results, failures
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
"""Zep Graph 分页读取工具。
|
||||
"""Zep Graph paginated fetch utilities.
|
||||
|
||||
Zep 的 node/edge 列表接口使用 UUID cursor 分页,
|
||||
本模块封装自动翻页逻辑(含单页重试),对调用方透明地返回完整列表。
|
||||
Zep's node/edge list endpoints use UUID-cursor-based pagination.
|
||||
This module wraps the auto-pagination logic (with per-page retries) and
|
||||
returns the full result list transparently to the caller.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
|
@ -31,7 +32,7 @@ def _fetch_page_with_retry(
|
|||
page_description: str = "page",
|
||||
**kwargs: Any,
|
||||
) -> list[Any]:
|
||||
"""单页请求,失败时指数退避重试。仅重试网络/IO类瞬态错误。"""
|
||||
"""Fetch a single page with exponential-backoff retry on transient network/IO errors."""
|
||||
if max_retries < 1:
|
||||
raise ValueError("max_retries must be >= 1")
|
||||
|
||||
|
|
@ -64,7 +65,7 @@ def fetch_all_nodes(
|
|||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
||||
) -> list[Any]:
|
||||
"""分页获取图谱节点,最多返回 max_items 条(默认 2000)。每页请求自带重试。"""
|
||||
"""Fetch all graph nodes with pagination, returning at most max_items (default 2000). Each page request includes retries."""
|
||||
all_nodes: list[Any] = []
|
||||
cursor: str | None = None
|
||||
page_num = 0
|
||||
|
|
@ -109,7 +110,7 @@ def fetch_all_edges(
|
|||
max_retries: int = _DEFAULT_MAX_RETRIES,
|
||||
retry_delay: float = _DEFAULT_RETRY_DELAY,
|
||||
) -> list[Any]:
|
||||
"""分页获取图谱所有边,返回完整列表。每页请求自带重试。"""
|
||||
"""Fetch all graph edges with pagination, returning the complete list. Each page request includes retries."""
|
||||
all_edges: list[Any] = []
|
||||
cursor: str | None = None
|
||||
page_num = 0
|
||||
|
|
|
|||
|
|
@ -1435,7 +1435,6 @@
|
|||
"resolved": "https://registry.npmjs.org/d3-selection/-/d3-selection-3.0.0.tgz",
|
||||
"integrity": "sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==",
|
||||
"license": "ISC",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
}
|
||||
|
|
@ -1913,7 +1912,6 @@
|
|||
"integrity": "sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
|
|
@ -2053,7 +2051,6 @@
|
|||
"integrity": "sha512-ITcnkFeR3+fI8P1wMgItjGrR10170d8auB4EpMLPqmx6uxElH3a/hHGQabSHKdqd4FXWO1nFIp9rRn7JQ34ACQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"esbuild": "^0.25.0",
|
||||
"fdir": "^6.5.0",
|
||||
|
|
@ -2128,7 +2125,6 @@
|
|||
"resolved": "https://registry.npmjs.org/vue/-/vue-3.5.25.tgz",
|
||||
"integrity": "sha512-YLVdgv2K13WJ6n+kD5owehKtEXwdwXuj2TTyJMsO7pSeKw2bfRNZGjhB7YzrpbMYj5b5QsUebHpOqR3R3ziy/g==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@vue/compiler-dom": "3.5.25",
|
||||
"@vue/compiler-sfc": "3.5.25",
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import authState, { clearToken } from '../store/auth'
|
|||
|
||||
// 创建axios实例
|
||||
const service = axios.create({
|
||||
baseURL: import.meta.env.VITE_API_BASE_URL || 'http://localhost:5001',
|
||||
baseURL: import.meta.env.VITE_API_BASE_URL || '',
|
||||
timeout: 300000, // 5分钟超时(本体生成可能需要较长时间)
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ for (const path in localeFiles) {
|
|||
}
|
||||
}
|
||||
|
||||
const savedLocale = localStorage.getItem('locale') || 'zh'
|
||||
const savedLocale = localStorage.getItem('locale') || 'ca'
|
||||
|
||||
const i18n = createI18n({
|
||||
legacy: false,
|
||||
|
|
|
|||
Loading…
Reference in New Issue