docs(i18n): translate Chinese docstrings/comments in backend root, api init, simulation_ipc, simulation_manager, zep_entity_reader
This commit is contained in:
parent
e3f7defefc
commit
e1019d91cb
|
|
@ -1,12 +1,10 @@
|
||||||
"""
|
"""MiroFish backend Flask application factory."""
|
||||||
MiroFish Backend - Flask应用工厂
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
# 抑制 multiprocessing resource_tracker 的警告(来自第三方库如 transformers)
|
# Silence multiprocessing.resource_tracker warnings emitted by some third-party
|
||||||
# 需要在所有其他导入之前设置
|
# libraries (e.g. transformers); must run before those modules are imported.
|
||||||
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
|
warnings.filterwarnings("ignore", message=".*resource_tracker.*")
|
||||||
|
|
||||||
from flask import Flask, request
|
from flask import Flask, request
|
||||||
|
|
@ -18,19 +16,21 @@ from .utils.locale import t
|
||||||
|
|
||||||
|
|
||||||
def create_app(config_class=Config):
|
def create_app(config_class=Config):
|
||||||
"""Flask应用工厂函数"""
|
"""Flask application factory."""
|
||||||
app = Flask(__name__)
|
app = Flask(__name__)
|
||||||
app.config.from_object(config_class)
|
app.config.from_object(config_class)
|
||||||
|
|
||||||
# 设置JSON编码:确保中文直接显示(而不是 \uXXXX 格式)
|
# Configure JSON encoding so non-ASCII characters render literally
|
||||||
# Flask >= 2.3 使用 app.json.ensure_ascii,旧版本使用 JSON_AS_ASCII 配置
|
# rather than as \uXXXX escape sequences. Flask >= 2.3 exposes
|
||||||
|
# ``app.json.ensure_ascii``; older versions use ``JSON_AS_ASCII``.
|
||||||
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
if hasattr(app, 'json') and hasattr(app.json, 'ensure_ascii'):
|
||||||
app.json.ensure_ascii = False
|
app.json.ensure_ascii = False
|
||||||
|
|
||||||
# 设置日志
|
# Configure logging.
|
||||||
logger = setup_logger('mirofish')
|
logger = setup_logger('mirofish')
|
||||||
|
|
||||||
# 只在 reloader 子进程中打印启动信息(避免 debug 模式下打印两次)
|
# Only print startup banners in the reloader child process to avoid
|
||||||
|
# double-printing in debug mode.
|
||||||
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
|
is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true'
|
||||||
debug_mode = app.config.get('DEBUG', False)
|
debug_mode = app.config.get('DEBUG', False)
|
||||||
should_log_startup = not debug_mode or is_reloader_process
|
should_log_startup = not debug_mode or is_reloader_process
|
||||||
|
|
@ -40,16 +40,17 @@ def create_app(config_class=Config):
|
||||||
logger.info(t("log.bootstrap.m001"))
|
logger.info(t("log.bootstrap.m001"))
|
||||||
logger.info("=" * 50)
|
logger.info("=" * 50)
|
||||||
|
|
||||||
# 启用CORS
|
# Enable CORS.
|
||||||
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
CORS(app, resources={r"/api/*": {"origins": "*"}})
|
||||||
|
|
||||||
# 注册模拟进程清理函数(确保服务器关闭时终止所有模拟进程)
|
# Register simulation-process cleanup so all child processes are torn down
|
||||||
|
# when the Flask server shuts down.
|
||||||
from .services.simulation_runner import SimulationRunner
|
from .services.simulation_runner import SimulationRunner
|
||||||
SimulationRunner.register_cleanup()
|
SimulationRunner.register_cleanup()
|
||||||
if should_log_startup:
|
if should_log_startup:
|
||||||
logger.info(t("log.bootstrap.m002"))
|
logger.info(t("log.bootstrap.m002"))
|
||||||
|
|
||||||
# 请求日志中间件
|
# Request-logging middleware.
|
||||||
@app.before_request
|
@app.before_request
|
||||||
def log_request():
|
def log_request():
|
||||||
logger = get_logger('mirofish.request')
|
logger = get_logger('mirofish.request')
|
||||||
|
|
@ -63,13 +64,13 @@ def create_app(config_class=Config):
|
||||||
logger.debug(t("log.bootstrap.m005", response=response.status_code))
|
logger.debug(t("log.bootstrap.m005", response=response.status_code))
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# 注册蓝图
|
# Register API blueprints.
|
||||||
from .api import graph_bp, simulation_bp, report_bp
|
from .api import graph_bp, simulation_bp, report_bp
|
||||||
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
app.register_blueprint(graph_bp, url_prefix='/api/graph')
|
||||||
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
app.register_blueprint(simulation_bp, url_prefix='/api/simulation')
|
||||||
app.register_blueprint(report_bp, url_prefix='/api/report')
|
app.register_blueprint(report_bp, url_prefix='/api/report')
|
||||||
|
|
||||||
# 健康检查
|
# Health-check endpoint.
|
||||||
@app.route('/health')
|
@app.route('/health')
|
||||||
def health():
|
def health():
|
||||||
return {'status': 'ok', 'service': 'MiroFish Backend'}
|
return {'status': 'ok', 'service': 'MiroFish Backend'}
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
"""
|
"""API blueprints package."""
|
||||||
API路由模块
|
|
||||||
"""
|
|
||||||
|
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,38 +1,40 @@
|
||||||
"""
|
"""Configuration management.
|
||||||
配置管理
|
|
||||||
统一从项目根目录的 .env 文件加载配置
|
Loads configuration values from the project-root ``.env`` file.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
# 加载项目根目录的 .env 文件
|
# Load the project-root .env file.
|
||||||
# 路径: MiroFish/.env (相对于 backend/app/config.py)
|
# Path: MiroFish/.env (relative to backend/app/config.py).
|
||||||
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
|
project_root_env = os.path.join(os.path.dirname(__file__), '../../.env')
|
||||||
|
|
||||||
if os.path.exists(project_root_env):
|
if os.path.exists(project_root_env):
|
||||||
load_dotenv(project_root_env, override=True)
|
load_dotenv(project_root_env, override=True)
|
||||||
else:
|
else:
|
||||||
# 如果根目录没有 .env,尝试加载环境变量(用于生产环境)
|
# If the project root has no .env, let python-dotenv search its default
|
||||||
|
# locations; if none is found, only the existing process environment applies (used in production deployments).
|
||||||
load_dotenv(override=True)
|
load_dotenv(override=True)
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Flask配置类"""
|
"""Flask configuration class."""
|
||||||
|
|
||||||
# Flask配置
|
# Flask settings.
|
||||||
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
SECRET_KEY = os.environ.get('SECRET_KEY', 'mirofish-secret-key')
|
||||||
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
DEBUG = os.environ.get('FLASK_DEBUG', 'True').lower() == 'true'
|
||||||
|
|
||||||
# JSON配置 - 禁用ASCII转义,让中文直接显示(而不是 \uXXXX 格式)
|
# JSON settings: disable ASCII escaping so non-ASCII output renders literally
|
||||||
|
# rather than as \uXXXX escape sequences.
|
||||||
JSON_AS_ASCII = False
|
JSON_AS_ASCII = False
|
||||||
|
|
||||||
# LLM配置(统一使用OpenAI格式)
|
# LLM settings (called via the OpenAI-compatible API surface).
|
||||||
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
LLM_API_KEY = os.environ.get('LLM_API_KEY')
|
||||||
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
LLM_BASE_URL = os.environ.get('LLM_BASE_URL', 'https://api.openai.com/v1')
|
||||||
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
LLM_MODEL_NAME = os.environ.get('LLM_MODEL_NAME', 'gpt-4o-mini')
|
||||||
|
|
||||||
# Neo4j + Graphiti配置(替代 Zep Cloud)
|
# Neo4j + Graphiti settings (replacement for Zep Cloud).
|
||||||
NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
NEO4J_URI = os.environ.get('NEO4J_URI', 'bolt://localhost:7687')
|
||||||
NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j')
|
NEO4J_USER = os.environ.get('NEO4J_USER', 'neo4j')
|
||||||
NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD', 'mirofish123')
|
NEO4J_PASSWORD = os.environ.get('NEO4J_PASSWORD', 'mirofish123')
|
||||||
|
|
@ -50,23 +52,23 @@ class Config:
|
||||||
EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY')
|
EMBEDDING_API_KEY = os.environ.get('EMBEDDING_API_KEY')
|
||||||
EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL')
|
EMBEDDING_BASE_URL = os.environ.get('EMBEDDING_BASE_URL')
|
||||||
|
|
||||||
# Zep配置(保留兼容性,已废弃)
|
# Zep settings (kept for backwards compatibility; deprecated).
|
||||||
ZEP_API_KEY = os.environ.get('ZEP_API_KEY', '')
|
ZEP_API_KEY = os.environ.get('ZEP_API_KEY', '')
|
||||||
|
|
||||||
# 文件上传配置
|
# File upload settings.
|
||||||
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
MAX_CONTENT_LENGTH = 50 * 1024 * 1024 # 50MB
|
||||||
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
|
UPLOAD_FOLDER = os.path.join(os.path.dirname(__file__), '../uploads')
|
||||||
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
|
ALLOWED_EXTENSIONS = {'pdf', 'md', 'txt', 'markdown'}
|
||||||
|
|
||||||
# 文本处理配置
|
# Text processing settings.
|
||||||
DEFAULT_CHUNK_SIZE = 500 # 默认切块大小
|
DEFAULT_CHUNK_SIZE = 500 # default chunk size in characters
|
||||||
DEFAULT_CHUNK_OVERLAP = 50 # 默认重叠大小
|
DEFAULT_CHUNK_OVERLAP = 50 # default overlap in characters
|
||||||
|
|
||||||
# OASIS模拟配置
|
# OASIS simulation settings.
|
||||||
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
OASIS_DEFAULT_MAX_ROUNDS = int(os.environ.get('OASIS_DEFAULT_MAX_ROUNDS', '10'))
|
||||||
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
OASIS_SIMULATION_DATA_DIR = os.path.join(os.path.dirname(__file__), '../uploads/simulations')
|
||||||
|
|
||||||
# OASIS平台可用动作配置
|
# OASIS per-platform allowed action lists.
|
||||||
OASIS_TWITTER_ACTIONS = [
|
OASIS_TWITTER_ACTIONS = [
|
||||||
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
'CREATE_POST', 'LIKE_POST', 'REPOST', 'FOLLOW', 'DO_NOTHING', 'QUOTE_POST'
|
||||||
]
|
]
|
||||||
|
|
@ -76,14 +78,14 @@ class Config:
|
||||||
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
'TREND', 'REFRESH', 'DO_NOTHING', 'FOLLOW', 'MUTE'
|
||||||
]
|
]
|
||||||
|
|
||||||
# Report Agent配置
|
# Report agent settings.
|
||||||
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
REPORT_AGENT_MAX_TOOL_CALLS = int(os.environ.get('REPORT_AGENT_MAX_TOOL_CALLS', '5'))
|
||||||
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
REPORT_AGENT_MAX_REFLECTION_ROUNDS = int(os.environ.get('REPORT_AGENT_MAX_REFLECTION_ROUNDS', '2'))
|
||||||
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
REPORT_AGENT_TEMPERATURE = float(os.environ.get('REPORT_AGENT_TEMPERATURE', '0.5'))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate(cls):
|
def validate(cls):
|
||||||
"""验证必要配置"""
|
"""Validate that required configuration values are present."""
|
||||||
errors = []
|
errors = []
|
||||||
if not cls.LLM_API_KEY:
|
if not cls.LLM_API_KEY:
|
||||||
errors.append("LLM_API_KEY 未配置")
|
errors.append("LLM_API_KEY 未配置")
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,12 @@
|
||||||
"""
|
"""Simulation IPC module.
|
||||||
模拟IPC通信模块
|
|
||||||
用于Flask后端和模拟脚本之间的进程间通信
|
|
||||||
|
|
||||||
通过文件系统实现简单的命令/响应模式:
|
Inter-process communication between the Flask backend and the simulation
|
||||||
1. Flask写入命令到 commands/ 目录
|
subprocess. Implements a simple file-system command/response pattern:
|
||||||
2. 模拟脚本轮询命令目录,执行命令并写入响应到 responses/ 目录
|
|
||||||
3. Flask轮询响应目录获取结果
|
1. Flask writes commands into ``commands/``.
|
||||||
|
2. The simulation script polls for commands, executes them, and writes
|
||||||
|
responses into ``responses/``.
|
||||||
|
3. Flask polls the responses directory for results.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -24,14 +25,14 @@ logger = get_logger('mirofish.simulation_ipc')
|
||||||
|
|
||||||
|
|
||||||
class CommandType(str, Enum):
|
class CommandType(str, Enum):
|
||||||
"""命令类型"""
|
"""IPC command types."""
|
||||||
INTERVIEW = "interview" # 单个Agent采访
|
INTERVIEW = "interview" # interview a single agent
|
||||||
BATCH_INTERVIEW = "batch_interview" # 批量采访
|
BATCH_INTERVIEW = "batch_interview" # interview multiple agents at once
|
||||||
CLOSE_ENV = "close_env" # 关闭环境
|
CLOSE_ENV = "close_env" # tear down the environment
|
||||||
|
|
||||||
|
|
||||||
class CommandStatus(str, Enum):
|
class CommandStatus(str, Enum):
|
||||||
"""命令状态"""
|
"""IPC command status."""
|
||||||
PENDING = "pending"
|
PENDING = "pending"
|
||||||
PROCESSING = "processing"
|
PROCESSING = "processing"
|
||||||
COMPLETED = "completed"
|
COMPLETED = "completed"
|
||||||
|
|
@ -40,7 +41,7 @@ class CommandStatus(str, Enum):
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IPCCommand:
|
class IPCCommand:
|
||||||
"""IPC命令"""
|
"""A command sent over the IPC channel."""
|
||||||
command_id: str
|
command_id: str
|
||||||
command_type: CommandType
|
command_type: CommandType
|
||||||
args: Dict[str, Any]
|
args: Dict[str, Any]
|
||||||
|
|
@ -66,7 +67,7 @@ class IPCCommand:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class IPCResponse:
|
class IPCResponse:
|
||||||
"""IPC响应"""
|
"""A response returned over the IPC channel."""
|
||||||
command_id: str
|
command_id: str
|
||||||
status: CommandStatus
|
status: CommandStatus
|
||||||
result: Optional[Dict[str, Any]] = None
|
result: Optional[Dict[str, Any]] = None
|
||||||
|
|
@ -94,24 +95,22 @@ class IPCResponse:
|
||||||
|
|
||||||
|
|
||||||
class SimulationIPCClient:
|
class SimulationIPCClient:
|
||||||
"""
|
"""IPC client used by the Flask side.
|
||||||
模拟IPC客户端(Flask端使用)
|
|
||||||
|
|
||||||
用于向模拟进程发送命令并等待响应
|
Sends commands to the simulation process and waits for responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, simulation_dir: str):
|
def __init__(self, simulation_dir: str):
|
||||||
"""
|
"""Initialize the IPC client.
|
||||||
初始化IPC客户端
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_dir: 模拟数据目录
|
simulation_dir: Directory holding the simulation's IPC files.
|
||||||
"""
|
"""
|
||||||
self.simulation_dir = simulation_dir
|
self.simulation_dir = simulation_dir
|
||||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||||
|
|
||||||
# 确保目录存在
|
# Ensure both directories exist before use.
|
||||||
os.makedirs(self.commands_dir, exist_ok=True)
|
os.makedirs(self.commands_dir, exist_ok=True)
|
||||||
os.makedirs(self.responses_dir, exist_ok=True)
|
os.makedirs(self.responses_dir, exist_ok=True)
|
||||||
|
|
||||||
|
|
@ -122,20 +121,19 @@ class SimulationIPCClient:
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
poll_interval: float = 0.5
|
poll_interval: float = 0.5
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""Send a command and wait for the response.
|
||||||
发送命令并等待响应
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
command_type: 命令类型
|
command_type: Command type to send.
|
||||||
args: 命令参数
|
args: Command arguments.
|
||||||
timeout: 超时时间(秒)
|
timeout: Timeout in seconds.
|
||||||
poll_interval: 轮询间隔(秒)
|
poll_interval: Polling interval in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse
|
The ``IPCResponse``.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
TimeoutError: 等待响应超时
|
TimeoutError: When no response arrives before ``timeout``.
|
||||||
"""
|
"""
|
||||||
command_id = str(uuid.uuid4())
|
command_id = str(uuid.uuid4())
|
||||||
command = IPCCommand(
|
command = IPCCommand(
|
||||||
|
|
@ -144,14 +142,14 @@ class SimulationIPCClient:
|
||||||
args=args
|
args=args
|
||||||
)
|
)
|
||||||
|
|
||||||
# 写入命令文件
|
# Write the command file.
|
||||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||||
with open(command_file, 'w', encoding='utf-8') as f:
|
with open(command_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
json.dump(command.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
logger.info(t("log.simulation_ipc.m001", command_type=command_type.value, command_id=command_id))
|
logger.info(t("log.simulation_ipc.m001", command_type=command_type.value, command_id=command_id))
|
||||||
|
|
||||||
# 等待响应
|
# Poll for the response file.
|
||||||
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
response_file = os.path.join(self.responses_dir, f"{command_id}.json")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
|
|
@ -162,7 +160,7 @@ class SimulationIPCClient:
|
||||||
response_data = json.load(f)
|
response_data = json.load(f)
|
||||||
response = IPCResponse.from_dict(response_data)
|
response = IPCResponse.from_dict(response_data)
|
||||||
|
|
||||||
# 清理命令和响应文件
|
# Clean up command and response files after successful read.
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
os.remove(response_file)
|
os.remove(response_file)
|
||||||
|
|
@ -176,10 +174,10 @@ class SimulationIPCClient:
|
||||||
|
|
||||||
time.sleep(poll_interval)
|
time.sleep(poll_interval)
|
||||||
|
|
||||||
# 超时
|
# Timed out waiting for the response.
|
||||||
logger.error(t("log.simulation_ipc.m004", command_id=command_id))
|
logger.error(t("log.simulation_ipc.m004", command_id=command_id))
|
||||||
|
|
||||||
# 清理命令文件
|
# Clean up the unanswered command file.
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
@ -194,20 +192,19 @@ class SimulationIPCClient:
|
||||||
platform: str = None,
|
platform: str = None,
|
||||||
timeout: float = 60.0
|
timeout: float = 60.0
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""Send a single-agent interview command.
|
||||||
发送单个Agent采访命令
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
agent_id: Agent ID
|
agent_id: Agent id to interview.
|
||||||
prompt: 采访问题
|
prompt: Interview question.
|
||||||
platform: 指定平台(可选)
|
platform: Optional platform selector.
|
||||||
- "twitter": 只采访Twitter平台
|
- ``"twitter"``: interview only on Twitter.
|
||||||
- "reddit": 只采访Reddit平台
|
- ``"reddit"``: interview only on Reddit.
|
||||||
- None: 双平台模拟时同时采访两个平台,单平台模拟时采访该平台
|
- ``None``: interview on both platforms in a dual-platform simulation; otherwise on the single active platform.
|
||||||
timeout: 超时时间
|
timeout: Timeout in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse,result字段包含采访结果
|
``IPCResponse`` whose ``result`` carries the interview response.
|
||||||
"""
|
"""
|
||||||
args = {
|
args = {
|
||||||
"agent_id": agent_id,
|
"agent_id": agent_id,
|
||||||
|
|
@ -228,19 +225,18 @@ class SimulationIPCClient:
|
||||||
platform: str = None,
|
platform: str = None,
|
||||||
timeout: float = 120.0
|
timeout: float = 120.0
|
||||||
) -> IPCResponse:
|
) -> IPCResponse:
|
||||||
"""
|
"""Send a batched interview command.
|
||||||
发送批量采访命令
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
interviews: 采访列表,每个元素包含 {"agent_id": int, "prompt": str, "platform": str(可选)}
|
interviews: List of dicts shaped ``{"agent_id": int, "prompt": str, "platform": str}``, where ``platform`` is optional.
|
||||||
platform: 默认平台(可选,会被每个采访项的platform覆盖)
|
platform: Default platform; per-item ``platform`` overrides this.
|
||||||
- "twitter": 默认只采访Twitter平台
|
- ``"twitter"``: default to Twitter.
|
||||||
- "reddit": 默认只采访Reddit平台
|
- ``"reddit"``: default to Reddit.
|
||||||
- None: 双平台模拟时每个Agent同时采访两个平台
|
- ``None``: in a dual-platform simulation, interview each agent on both platforms.
|
||||||
timeout: 超时时间
|
timeout: Timeout in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse,result字段包含所有采访结果
|
``IPCResponse`` whose ``result`` carries every interview response.
|
||||||
"""
|
"""
|
||||||
args = {"interviews": interviews}
|
args = {"interviews": interviews}
|
||||||
if platform:
|
if platform:
|
||||||
|
|
@ -253,14 +249,13 @@ class SimulationIPCClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
def send_close_env(self, timeout: float = 30.0) -> IPCResponse:
|
||||||
"""
|
"""Send a tear-down-environment command.
|
||||||
发送关闭环境命令
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
timeout: 超时时间
|
timeout: Timeout in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCResponse
|
``IPCResponse``.
|
||||||
"""
|
"""
|
||||||
return self.send_command(
|
return self.send_command(
|
||||||
command_type=CommandType.CLOSE_ENV,
|
command_type=CommandType.CLOSE_ENV,
|
||||||
|
|
@ -269,10 +264,9 @@ class SimulationIPCClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
def check_env_alive(self) -> bool:
|
def check_env_alive(self) -> bool:
|
||||||
"""
|
"""Return ``True`` if the simulation environment reports as alive.
|
||||||
检查模拟环境是否存活
|
|
||||||
|
|
||||||
通过检查 env_status.json 文件来判断
|
Reads ``env_status.json`` written by the IPC server side.
|
||||||
"""
|
"""
|
||||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||||
if not os.path.exists(status_file):
|
if not os.path.exists(status_file):
|
||||||
|
|
@ -287,42 +281,40 @@ class SimulationIPCClient:
|
||||||
|
|
||||||
|
|
||||||
class SimulationIPCServer:
|
class SimulationIPCServer:
|
||||||
"""
|
"""IPC server used by the simulation script.
|
||||||
模拟IPC服务器(模拟脚本端使用)
|
|
||||||
|
|
||||||
轮询命令目录,执行命令并返回响应
|
Polls the commands directory, executes commands, and writes responses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, simulation_dir: str):
|
def __init__(self, simulation_dir: str):
|
||||||
"""
|
"""Initialize the IPC server.
|
||||||
初始化IPC服务器
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_dir: 模拟数据目录
|
simulation_dir: Directory holding the simulation's IPC files.
|
||||||
"""
|
"""
|
||||||
self.simulation_dir = simulation_dir
|
self.simulation_dir = simulation_dir
|
||||||
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
self.commands_dir = os.path.join(simulation_dir, "ipc_commands")
|
||||||
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
self.responses_dir = os.path.join(simulation_dir, "ipc_responses")
|
||||||
|
|
||||||
# 确保目录存在
|
# Ensure both directories exist before use.
|
||||||
os.makedirs(self.commands_dir, exist_ok=True)
|
os.makedirs(self.commands_dir, exist_ok=True)
|
||||||
os.makedirs(self.responses_dir, exist_ok=True)
|
os.makedirs(self.responses_dir, exist_ok=True)
|
||||||
|
|
||||||
# 环境状态
|
# Server-running flag.
|
||||||
self._running = False
|
self._running = False
|
||||||
|
|
||||||
def start(self):
|
def start(self):
|
||||||
"""标记服务器为运行状态"""
|
"""Mark the server as alive and persist the state."""
|
||||||
self._running = True
|
self._running = True
|
||||||
self._update_env_status("alive")
|
self._update_env_status("alive")
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""标记服务器为停止状态"""
|
"""Mark the server as stopped and persist the state."""
|
||||||
self._running = False
|
self._running = False
|
||||||
self._update_env_status("stopped")
|
self._update_env_status("stopped")
|
||||||
|
|
||||||
def _update_env_status(self, status: str):
|
def _update_env_status(self, status: str):
|
||||||
"""更新环境状态文件"""
|
"""Update the persistent environment-status file."""
|
||||||
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
status_file = os.path.join(self.simulation_dir, "env_status.json")
|
||||||
with open(status_file, 'w', encoding='utf-8') as f:
|
with open(status_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump({
|
json.dump({
|
||||||
|
|
@ -331,16 +323,15 @@ class SimulationIPCServer:
|
||||||
}, f, ensure_ascii=False, indent=2)
|
}, f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
def poll_commands(self) -> Optional[IPCCommand]:
|
def poll_commands(self) -> Optional[IPCCommand]:
|
||||||
"""
|
"""Poll the commands directory and return the next pending command.
|
||||||
轮询命令目录,返回第一个待处理的命令
|
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
IPCCommand 或 None
|
``IPCCommand`` or ``None`` if no pending commands remain.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.commands_dir):
|
if not os.path.exists(self.commands_dir):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 按时间排序获取命令文件
|
# Sort by mtime so we process commands in arrival order.
|
||||||
command_files = []
|
command_files = []
|
||||||
for filename in os.listdir(self.commands_dir):
|
for filename in os.listdir(self.commands_dir):
|
||||||
if filename.endswith('.json'):
|
if filename.endswith('.json'):
|
||||||
|
|
@ -361,17 +352,16 @@ class SimulationIPCServer:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def send_response(self, response: IPCResponse):
|
def send_response(self, response: IPCResponse):
|
||||||
"""
|
"""Write a response file.
|
||||||
发送响应
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
response: IPC响应
|
response: The response to send.
|
||||||
"""
|
"""
|
||||||
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
response_file = os.path.join(self.responses_dir, f"{response.command_id}.json")
|
||||||
with open(response_file, 'w', encoding='utf-8') as f:
|
with open(response_file, 'w', encoding='utf-8') as f:
|
||||||
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
json.dump(response.to_dict(), f, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 删除命令文件
|
# Delete the matching command file.
|
||||||
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
command_file = os.path.join(self.commands_dir, f"{response.command_id}.json")
|
||||||
try:
|
try:
|
||||||
os.remove(command_file)
|
os.remove(command_file)
|
||||||
|
|
@ -379,7 +369,7 @@ class SimulationIPCServer:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def send_success(self, command_id: str, result: Dict[str, Any]):
|
def send_success(self, command_id: str, result: Dict[str, Any]):
|
||||||
"""发送成功响应"""
|
"""Send a success response."""
|
||||||
self.send_response(IPCResponse(
|
self.send_response(IPCResponse(
|
||||||
command_id=command_id,
|
command_id=command_id,
|
||||||
status=CommandStatus.COMPLETED,
|
status=CommandStatus.COMPLETED,
|
||||||
|
|
@ -387,7 +377,7 @@ class SimulationIPCServer:
|
||||||
))
|
))
|
||||||
|
|
||||||
def send_error(self, command_id: str, error: str):
|
def send_error(self, command_id: str, error: str):
|
||||||
"""发送错误响应"""
|
"""Send a failure response."""
|
||||||
self.send_response(IPCResponse(
|
self.send_response(IPCResponse(
|
||||||
command_id=command_id,
|
command_id=command_id,
|
||||||
status=CommandStatus.FAILED,
|
status=CommandStatus.FAILED,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""
|
"""OASIS simulation manager.
|
||||||
OASIS模拟管理器
|
|
||||||
管理Twitter和Reddit双平台并行模拟
|
Drives parallel Twitter + Reddit simulations using preset scripts plus
|
||||||
使用预设脚本 + LLM智能生成配置参数
|
LLM-generated configuration parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
@ -23,60 +23,60 @@ logger = get_logger('mirofish.simulation')
|
||||||
|
|
||||||
|
|
||||||
class SimulationStatus(str, Enum):
|
class SimulationStatus(str, Enum):
|
||||||
"""模拟状态"""
|
"""Simulation lifecycle status."""
|
||||||
CREATED = "created"
|
CREATED = "created"
|
||||||
PREPARING = "preparing"
|
PREPARING = "preparing"
|
||||||
READY = "ready"
|
READY = "ready"
|
||||||
RUNNING = "running"
|
RUNNING = "running"
|
||||||
PAUSED = "paused"
|
PAUSED = "paused"
|
||||||
STOPPED = "stopped" # 模拟被手动停止
|
STOPPED = "stopped" # manually stopped
|
||||||
COMPLETED = "completed" # 模拟自然完成
|
COMPLETED = "completed" # finished naturally
|
||||||
FAILED = "failed"
|
FAILED = "failed"
|
||||||
|
|
||||||
|
|
||||||
class PlatformType(str, Enum):
|
class PlatformType(str, Enum):
|
||||||
"""平台类型"""
|
"""Simulated platform types."""
|
||||||
TWITTER = "twitter"
|
TWITTER = "twitter"
|
||||||
REDDIT = "reddit"
|
REDDIT = "reddit"
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SimulationState:
|
class SimulationState:
|
||||||
"""模拟状态"""
|
"""In-memory + persisted state for a single simulation."""
|
||||||
simulation_id: str
|
simulation_id: str
|
||||||
project_id: str
|
project_id: str
|
||||||
graph_id: str
|
graph_id: str
|
||||||
|
|
||||||
# 平台启用状态
|
# Per-platform enable flags.
|
||||||
enable_twitter: bool = True
|
enable_twitter: bool = True
|
||||||
enable_reddit: bool = True
|
enable_reddit: bool = True
|
||||||
|
|
||||||
# 状态
|
# Lifecycle status.
|
||||||
status: SimulationStatus = SimulationStatus.CREATED
|
status: SimulationStatus = SimulationStatus.CREATED
|
||||||
|
|
||||||
# 准备阶段数据
|
# Counters captured during the prepare phase.
|
||||||
entities_count: int = 0
|
entities_count: int = 0
|
||||||
profiles_count: int = 0
|
profiles_count: int = 0
|
||||||
entity_types: List[str] = field(default_factory=list)
|
entity_types: List[str] = field(default_factory=list)
|
||||||
|
|
||||||
# 配置生成信息
|
# Information about the auto-generated config.
|
||||||
config_generated: bool = False
|
config_generated: bool = False
|
||||||
config_reasoning: str = ""
|
config_reasoning: str = ""
|
||||||
|
|
||||||
# 运行时数据
|
# Runtime data.
|
||||||
current_round: int = 0
|
current_round: int = 0
|
||||||
twitter_status: str = "not_started"
|
twitter_status: str = "not_started"
|
||||||
reddit_status: str = "not_started"
|
reddit_status: str = "not_started"
|
||||||
|
|
||||||
# 时间戳
|
# Timestamps.
|
||||||
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
|
||||||
|
|
||||||
# 错误信息
|
# Error message when status == FAILED.
|
||||||
error: Optional[str] = None
|
error: Optional[str] = None
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""完整状态字典(内部使用)"""
|
"""Full state dict (used for persistence and internal callers)."""
|
||||||
return {
|
return {
|
||||||
"simulation_id": self.simulation_id,
|
"simulation_id": self.simulation_id,
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
|
|
@ -98,7 +98,7 @@ class SimulationState:
|
||||||
}
|
}
|
||||||
|
|
||||||
def to_simple_dict(self) -> Dict[str, Any]:
|
def to_simple_dict(self) -> Dict[str, Any]:
|
||||||
"""简化状态字典(API返回使用)"""
|
"""Simplified state dict (used for API responses)."""
|
||||||
return {
|
return {
|
||||||
"simulation_id": self.simulation_id,
|
"simulation_id": self.simulation_id,
|
||||||
"project_id": self.project_id,
|
"project_id": self.project_id,
|
||||||
|
|
@ -113,37 +113,36 @@ class SimulationState:
|
||||||
|
|
||||||
|
|
||||||
class SimulationManager:
|
class SimulationManager:
|
||||||
"""
|
"""Simulation manager.
|
||||||
模拟管理器
|
|
||||||
|
|
||||||
核心功能:
|
Core responsibilities:
|
||||||
1. 从Zep图谱读取实体并过滤
|
1. Read entities from the Zep graph and filter to the configured types.
|
||||||
2. 生成OASIS Agent Profile
|
2. Generate OASIS agent profiles per entity.
|
||||||
3. 使用LLM智能生成模拟配置参数
|
3. Use the LLM to generate simulation configuration parameters.
|
||||||
4. 准备预设脚本所需的所有文件
|
4. Materialize the files the preset scripts expect.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 模拟数据存储目录
|
# Root directory for persisted simulation data.
|
||||||
SIMULATION_DATA_DIR = os.path.join(
|
SIMULATION_DATA_DIR = os.path.join(
|
||||||
os.path.dirname(__file__),
|
os.path.dirname(__file__),
|
||||||
'../../uploads/simulations'
|
'../../uploads/simulations'
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
# 确保目录存在
|
# Ensure the simulation data directory exists.
|
||||||
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
|
||||||
|
|
||||||
# 内存中的模拟状态缓存
|
# In-memory cache of simulation state objects.
|
||||||
self._simulations: Dict[str, SimulationState] = {}
|
self._simulations: Dict[str, SimulationState] = {}
|
||||||
|
|
||||||
def _get_simulation_dir(self, simulation_id: str) -> str:
|
def _get_simulation_dir(self, simulation_id: str) -> str:
|
||||||
"""获取模拟数据目录"""
|
"""Return the on-disk directory for a simulation, creating if missing."""
|
||||||
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
|
||||||
os.makedirs(sim_dir, exist_ok=True)
|
os.makedirs(sim_dir, exist_ok=True)
|
||||||
return sim_dir
|
return sim_dir
|
||||||
|
|
||||||
def _save_simulation_state(self, state: SimulationState):
|
def _save_simulation_state(self, state: SimulationState):
|
||||||
"""保存模拟状态到文件"""
|
"""Persist a simulation state to disk and update the cache."""
|
||||||
sim_dir = self._get_simulation_dir(state.simulation_id)
|
sim_dir = self._get_simulation_dir(state.simulation_id)
|
||||||
state_file = os.path.join(sim_dir, "state.json")
|
state_file = os.path.join(sim_dir, "state.json")
|
||||||
|
|
||||||
|
|
@ -155,7 +154,7 @@ class SimulationManager:
|
||||||
self._simulations[state.simulation_id] = state
|
self._simulations[state.simulation_id] = state
|
||||||
|
|
||||||
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
|
||||||
"""从文件加载模拟状态"""
|
"""Load a simulation state from disk (or cache) by id."""
|
||||||
if simulation_id in self._simulations:
|
if simulation_id in self._simulations:
|
||||||
return self._simulations[simulation_id]
|
return self._simulations[simulation_id]
|
||||||
|
|
||||||
|
|
@ -198,17 +197,16 @@ class SimulationManager:
|
||||||
enable_twitter: bool = True,
|
enable_twitter: bool = True,
|
||||||
enable_reddit: bool = True,
|
enable_reddit: bool = True,
|
||||||
) -> SimulationState:
|
) -> SimulationState:
|
||||||
"""
|
"""Create a new simulation in the ``CREATED`` state.
|
||||||
创建新的模拟
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
project_id: 项目ID
|
project_id: Owning project id.
|
||||||
graph_id: Zep图谱ID
|
graph_id: Source Zep graph id.
|
||||||
enable_twitter: 是否启用Twitter模拟
|
enable_twitter: When ``True``, the Twitter simulation runs.
|
||||||
enable_reddit: 是否启用Reddit模拟
|
enable_reddit: When ``True``, the Reddit simulation runs.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SimulationState
|
The created ``SimulationState``.
|
||||||
"""
|
"""
|
||||||
import uuid
|
import uuid
|
||||||
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
|
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
|
||||||
|
|
@ -237,27 +235,26 @@ class SimulationManager:
|
||||||
progress_callback: Optional[callable] = None,
|
progress_callback: Optional[callable] = None,
|
||||||
parallel_profile_count: int = 3
|
parallel_profile_count: int = 3
|
||||||
) -> SimulationState:
|
) -> SimulationState:
|
||||||
"""
|
"""Prepare the simulation environment end-to-end.
|
||||||
准备模拟环境(全程自动化)
|
|
||||||
|
|
||||||
步骤:
|
Steps:
|
||||||
1. 从Zep图谱读取并过滤实体
|
1. Read and filter entities from the graph.
|
||||||
2. 为每个实体生成OASIS Agent Profile(可选LLM增强,支持并行)
|
2. Generate OASIS agent profiles (optional LLM enrichment, parallel-capable).
|
||||||
3. 使用LLM智能生成模拟配置参数(时间、活跃度、发言频率等)
|
3. Use the LLM to produce simulation parameters (timing, activity, posting frequency).
|
||||||
4. 保存配置文件和Profile文件
|
4. Save the configuration and profile files.
|
||||||
5. 复制预设脚本到模拟目录
|
5. Leave the runtime scripts in backend/scripts/ (they are no longer copied into the simulation directory).
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
simulation_id: 模拟ID
|
simulation_id: Simulation id.
|
||||||
simulation_requirement: 模拟需求描述(用于LLM生成配置)
|
simulation_requirement: Free-text description of the simulation goal.
|
||||||
document_text: 原始文档内容(用于LLM理解背景)
|
document_text: Raw source document text passed to the LLM for context.
|
||||||
defined_entity_types: 预定义的实体类型(可选)
|
defined_entity_types: Optional list of allowed entity types.
|
||||||
use_llm_for_profiles: 是否使用LLM生成详细人设
|
use_llm_for_profiles: When ``True``, enrich profiles via the LLM.
|
||||||
progress_callback: 进度回调函数 (stage, progress, message)
|
progress_callback: Optional callback ``(stage, progress, message, **extras)``.
|
||||||
parallel_profile_count: 并行生成人设的数量,默认3
|
parallel_profile_count: Number of profile generations to run in parallel.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
SimulationState
|
The updated ``SimulationState``.
|
||||||
"""
|
"""
|
||||||
state = self._load_simulation_state(simulation_id)
|
state = self._load_simulation_state(simulation_id)
|
||||||
if not state:
|
if not state:
|
||||||
|
|
@ -269,7 +266,7 @@ class SimulationManager:
|
||||||
|
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
|
|
||||||
# ========== 阶段1: 读取并过滤实体 ==========
|
# ========== Stage 1: read and filter entities ==========
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback("reading", 0, t('progress.connectingZepGraph'))
|
progress_callback("reading", 0, t('progress.connectingZepGraph'))
|
||||||
|
|
||||||
|
|
@ -301,7 +298,7 @@ class SimulationManager:
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
# ========== 阶段2: 生成Agent Profile ==========
|
# ========== Stage 2: generate agent profiles ==========
|
||||||
total_entities = len(filtered.entities)
|
total_entities = len(filtered.entities)
|
||||||
|
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
|
|
@ -312,7 +309,7 @@ class SimulationManager:
|
||||||
total=total_entities
|
total=total_entities
|
||||||
)
|
)
|
||||||
|
|
||||||
# 传入graph_id以启用Zep检索功能,获取更丰富的上下文
|
# Pass the graph_id so the generator can use Zep retrieval for richer context.
|
||||||
generator = OasisProfileGenerator(graph_id=state.graph_id)
|
generator = OasisProfileGenerator(graph_id=state.graph_id)
|
||||||
|
|
||||||
def profile_progress(current, total, msg):
|
def profile_progress(current, total, msg):
|
||||||
|
|
@ -326,7 +323,7 @@ class SimulationManager:
|
||||||
item_name=msg
|
item_name=msg
|
||||||
)
|
)
|
||||||
|
|
||||||
# 设置实时保存的文件路径(优先使用 Reddit JSON 格式)
|
# Configure the realtime save target (prefer Reddit JSON if Reddit is enabled).
|
||||||
realtime_output_path = None
|
realtime_output_path = None
|
||||||
realtime_platform = "reddit"
|
realtime_platform = "reddit"
|
||||||
if state.enable_reddit:
|
if state.enable_reddit:
|
||||||
|
|
@ -340,16 +337,16 @@ class SimulationManager:
|
||||||
entities=filtered.entities,
|
entities=filtered.entities,
|
||||||
use_llm=use_llm_for_profiles,
|
use_llm=use_llm_for_profiles,
|
||||||
progress_callback=profile_progress,
|
progress_callback=profile_progress,
|
||||||
graph_id=state.graph_id, # 传入graph_id用于Zep检索
|
graph_id=state.graph_id, # used for Zep retrieval enrichment
|
||||||
parallel_count=parallel_profile_count, # 并行生成数量
|
parallel_count=parallel_profile_count,
|
||||||
realtime_output_path=realtime_output_path, # 实时保存路径
|
realtime_output_path=realtime_output_path,
|
||||||
output_platform=realtime_platform # 输出格式
|
output_platform=realtime_platform
|
||||||
)
|
)
|
||||||
|
|
||||||
state.profiles_count = len(profiles)
|
state.profiles_count = len(profiles)
|
||||||
|
|
||||||
# 保存Profile文件(注意:Twitter使用CSV格式,Reddit使用JSON格式)
|
# Save profile files. Reddit also writes JSON during generation; this is
|
||||||
# Reddit 已经在生成过程中实时保存了,这里再保存一次确保完整性
|
# a final consistency write. Twitter requires CSV per OASIS conventions.
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(
|
progress_callback(
|
||||||
"generating_profiles", 95,
|
"generating_profiles", 95,
|
||||||
|
|
@ -366,7 +363,7 @@ class SimulationManager:
|
||||||
)
|
)
|
||||||
|
|
||||||
if state.enable_twitter:
|
if state.enable_twitter:
|
||||||
# Twitter使用CSV格式!这是OASIS的要求
|
# Twitter uses CSV format — required by OASIS.
|
||||||
generator.save_profiles(
|
generator.save_profiles(
|
||||||
profiles=profiles,
|
profiles=profiles,
|
||||||
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
|
||||||
|
|
@ -381,7 +378,7 @@ class SimulationManager:
|
||||||
total=len(profiles)
|
total=len(profiles)
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========== 阶段3: LLM智能生成模拟配置 ==========
|
# ========== Stage 3: LLM-driven simulation config ==========
|
||||||
if progress_callback:
|
if progress_callback:
|
||||||
progress_callback(
|
progress_callback(
|
||||||
"generating_config", 0,
|
"generating_config", 0,
|
||||||
|
|
@ -419,7 +416,7 @@ class SimulationManager:
|
||||||
total=3
|
total=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 保存配置文件
|
# Save the configuration file.
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
with open(config_path, 'w', encoding='utf-8') as f:
|
with open(config_path, 'w', encoding='utf-8') as f:
|
||||||
f.write(sim_params.to_json())
|
f.write(sim_params.to_json())
|
||||||
|
|
@ -435,10 +432,9 @@ class SimulationManager:
|
||||||
total=3
|
total=3
|
||||||
)
|
)
|
||||||
|
|
||||||
# 注意:运行脚本保留在 backend/scripts/ 目录,不再复制到模拟目录
|
# The runtime scripts now live under backend/scripts/; we no longer copy
|
||||||
# 启动模拟时,simulation_runner 会从 scripts/ 目录运行脚本
|
# them per-simulation. simulation_runner invokes them in place.
|
||||||
|
|
||||||
# 更新状态
|
|
||||||
state.status = SimulationStatus.READY
|
state.status = SimulationStatus.READY
|
||||||
self._save_simulation_state(state)
|
self._save_simulation_state(state)
|
||||||
|
|
||||||
|
|
@ -456,16 +452,16 @@ class SimulationManager:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
|
||||||
"""获取模拟状态"""
|
"""Return the simulation's state, or ``None`` if unknown."""
|
||||||
return self._load_simulation_state(simulation_id)
|
return self._load_simulation_state(simulation_id)
|
||||||
|
|
||||||
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
|
||||||
"""列出所有模拟"""
|
"""List all simulations, optionally filtered by ``project_id``."""
|
||||||
simulations = []
|
simulations = []
|
||||||
|
|
||||||
if os.path.exists(self.SIMULATION_DATA_DIR):
|
if os.path.exists(self.SIMULATION_DATA_DIR):
|
||||||
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
|
||||||
# 跳过隐藏文件(如 .DS_Store)和非目录文件
|
# Skip dotfiles (e.g. .DS_Store) and non-directories.
|
||||||
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
|
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
|
||||||
if sim_id.startswith('.') or not os.path.isdir(sim_path):
|
if sim_id.startswith('.') or not os.path.isdir(sim_path):
|
||||||
continue
|
continue
|
||||||
|
|
@ -478,7 +474,7 @@ class SimulationManager:
|
||||||
return simulations
|
return simulations
|
||||||
|
|
||||||
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
|
||||||
"""获取模拟的Agent Profile"""
|
"""Return the persisted agent profiles for a platform."""
|
||||||
state = self._load_simulation_state(simulation_id)
|
state = self._load_simulation_state(simulation_id)
|
||||||
if not state:
|
if not state:
|
||||||
raise ValueError(f"模拟不存在: {simulation_id}")
|
raise ValueError(f"模拟不存在: {simulation_id}")
|
||||||
|
|
@ -493,7 +489,7 @@ class SimulationManager:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""获取模拟配置"""
|
"""Return the persisted simulation config dict, or ``None`` if absent."""
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
|
|
||||||
|
|
@ -504,7 +500,7 @@ class SimulationManager:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
|
|
||||||
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
|
||||||
"""获取运行说明"""
|
"""Return shell commands and instructions to launch the simulation manually."""
|
||||||
sim_dir = self._get_simulation_dir(simulation_id)
|
sim_dir = self._get_simulation_dir(simulation_id)
|
||||||
config_path = os.path.join(sim_dir, "simulation_config.json")
|
config_path = os.path.join(sim_dir, "simulation_config.json")
|
||||||
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
"""
|
"""Zep entity reader and filter service.
|
||||||
Zep实体读取与过滤服务
|
|
||||||
从Zep图谱中读取节点,筛选出符合预定义实体类型的节点
|
Reads nodes from a Zep graph and filters down to those that match a
|
||||||
|
predefined ontology of entity types.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
|
@ -16,21 +17,21 @@ from ..utils.locale import t
|
||||||
|
|
||||||
logger = get_logger('mirofish.zep_entity_reader')
|
logger = get_logger('mirofish.zep_entity_reader')
|
||||||
|
|
||||||
# 用于泛型返回类型
|
# Generic return-type variable.
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EntityNode:
|
class EntityNode:
|
||||||
"""实体节点数据结构"""
|
"""In-memory representation of an entity node from the graph."""
|
||||||
uuid: str
|
uuid: str
|
||||||
name: str
|
name: str
|
||||||
labels: List[str]
|
labels: List[str]
|
||||||
summary: str
|
summary: str
|
||||||
attributes: Dict[str, Any]
|
attributes: Dict[str, Any]
|
||||||
# 相关的边信息
|
# Edges connected to this entity.
|
||||||
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
related_edges: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
# 相关的其他节点信息
|
# Other nodes connected through related edges.
|
||||||
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
related_nodes: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
|
@ -45,7 +46,7 @@ class EntityNode:
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_entity_type(self) -> Optional[str]:
|
def get_entity_type(self) -> Optional[str]:
|
||||||
"""获取实体类型(排除默认的Entity标签)"""
|
"""Return the first non-default label, or ``None`` if only defaults are present."""
|
||||||
for label in self.labels:
|
for label in self.labels:
|
||||||
if label not in ["Entity", "Node"]:
|
if label not in ["Entity", "Node"]:
|
||||||
return label
|
return label
|
||||||
|
|
@ -54,7 +55,7 @@ class EntityNode:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FilteredEntities:
|
class FilteredEntities:
|
||||||
"""过滤后的实体集合"""
|
"""Result of a filter pass over the graph: matching entities + counts."""
|
||||||
entities: List[EntityNode]
|
entities: List[EntityNode]
|
||||||
entity_types: Set[str]
|
entity_types: Set[str]
|
||||||
total_count: int
|
total_count: int
|
||||||
|
|
@ -70,13 +71,12 @@ class FilteredEntities:
|
||||||
|
|
||||||
|
|
||||||
class ZepEntityReader:
|
class ZepEntityReader:
|
||||||
"""
|
"""Read entities from a Zep graph and filter to ontology-defined types.
|
||||||
Zep实体读取与过滤服务
|
|
||||||
|
|
||||||
主要功能:
|
Capabilities:
|
||||||
1. 从Zep图谱读取所有节点
|
1. Read all nodes from the graph.
|
||||||
2. 筛选出符合预定义实体类型的节点(Labels不只是Entity的节点)
|
2. Keep nodes whose labels include something other than the default ``Entity`` and ``Node`` labels.
|
||||||
3. 获取每个实体的相关边和关联节点信息
|
3. Optionally enrich each entity with its connected edges and neighboring nodes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, api_key: Optional[str] = None):
|
def __init__(self, api_key: Optional[str] = None):
|
||||||
|
|
@ -89,17 +89,16 @@ class ZepEntityReader:
|
||||||
max_retries: int = 3,
|
max_retries: int = 3,
|
||||||
initial_delay: float = 2.0
|
initial_delay: float = 2.0
|
||||||
) -> T:
|
) -> T:
|
||||||
"""
|
"""Call a Zep API function with retry on failure.
|
||||||
带重试机制的Zep API调用
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
func: 要执行的函数(无参数的lambda或callable)
|
func: A zero-argument callable performing the request.
|
||||||
operation_name: 操作名称,用于日志
|
operation_name: Operation label used in log output.
|
||||||
max_retries: 最大重试次数(默认3次,即最多尝试3次)
|
max_retries: Maximum number of attempts (default 3 — i.e. up to 3 tries total).
|
||||||
initial_delay: 初始延迟秒数
|
initial_delay: Initial delay between retries in seconds.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
API调用结果
|
The return value of ``func``.
|
||||||
"""
|
"""
|
||||||
last_exception = None
|
last_exception = None
|
||||||
delay = initial_delay
|
delay = initial_delay
|
||||||
|
|
@ -114,21 +113,20 @@ class ZepEntityReader:
|
||||||
t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt + 1, str=str(e)[:100], delay=delay)
|
t("log.zep_entity_reader.m001", operation_name=operation_name, attempt=attempt + 1, str=str(e)[:100], delay=delay)
|
||||||
)
|
)
|
||||||
time.sleep(delay)
|
time.sleep(delay)
|
||||||
delay *= 2 # 指数退避
|
delay *= 2 # exponential backoff
|
||||||
else:
|
else:
|
||||||
logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=max_retries, str=str(e)))
|
logger.error(t("log.zep_entity_reader.m002", operation_name=operation_name, max_retries=max_retries, str=str(e)))
|
||||||
|
|
||||||
raise last_exception
|
raise last_exception
|
||||||
|
|
||||||
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_nodes(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""Return every node in the graph (paginated under the hood).
|
||||||
获取图谱的所有节点(分页获取)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: Graph identifier.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
节点列表
|
A list of node dicts.
|
||||||
"""
|
"""
|
||||||
logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))
|
logger.info(t("log.zep_entity_reader.m003", graph_id=graph_id))
|
||||||
|
|
||||||
|
|
@ -148,14 +146,13 @@ class ZepEntityReader:
|
||||||
return nodes_data
|
return nodes_data
|
||||||
|
|
||||||
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
def get_all_edges(self, graph_id: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""Return every edge in the graph (paginated under the hood).
|
||||||
获取图谱的所有边(分页获取)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: Graph identifier.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
边列表
|
A list of edge dicts.
|
||||||
"""
|
"""
|
||||||
logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))
|
logger.info(t("log.zep_entity_reader.m005", graph_id=graph_id))
|
||||||
|
|
||||||
|
|
@ -176,17 +173,16 @@ class ZepEntityReader:
|
||||||
return edges_data
|
return edges_data
|
||||||
|
|
||||||
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
def get_node_edges(self, node_uuid: str) -> List[Dict[str, Any]]:
|
||||||
"""
|
"""Return every edge connected to the given node (with retry).
|
||||||
获取指定节点的所有相关边(带重试机制)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
node_uuid: 节点UUID
|
node_uuid: Node UUID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
边列表
|
A list of edge dicts.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用重试机制调用Zep API
|
# Wrap the API call in retry logic.
|
||||||
edges = self._call_with_retry(
|
edges = self._call_with_retry(
|
||||||
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
func=lambda: self.client.graph.node.get_entity_edges(node_uuid=node_uuid),
|
||||||
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
operation_name=f"获取节点边(node={node_uuid[:8]}...)"
|
||||||
|
|
@ -214,20 +210,19 @@ class ZepEntityReader:
|
||||||
defined_entity_types: Optional[List[str]] = None,
|
defined_entity_types: Optional[List[str]] = None,
|
||||||
enrich_with_edges: bool = True
|
enrich_with_edges: bool = True
|
||||||
) -> FilteredEntities:
|
) -> FilteredEntities:
|
||||||
"""
|
"""Filter nodes down to entities matching the predefined ontology types.
|
||||||
筛选出符合预定义实体类型的节点
|
|
||||||
|
|
||||||
筛选逻辑:
|
Filtering rules:
|
||||||
- 如果节点的Labels只有一个"Entity",说明这个实体不符合我们预定义的类型,跳过
|
- Skip nodes that carry only the default ``Entity``/``Node`` labels (uncategorized).
|
||||||
- 如果节点的Labels包含除"Entity"和"Node"之外的标签,说明符合预定义类型,保留
|
- Keep nodes whose labels include anything other than ``Entity`` and ``Node``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: Graph identifier.
|
||||||
defined_entity_types: 预定义的实体类型列表(可选,如果提供则只保留这些类型)
|
defined_entity_types: Optional allow-list; when provided, only matching types are kept.
|
||||||
enrich_with_edges: 是否获取每个实体的相关边信息
|
enrich_with_edges: When ``True``, populate related_edges and related_nodes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
FilteredEntities: 过滤后的实体集合
|
A ``FilteredEntities`` summary.
|
||||||
"""
|
"""
|
||||||
logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))
|
logger.info(t("log.zep_entity_reader.m008", graph_id=graph_id))
|
||||||
|
|
||||||
|
|
@ -243,7 +238,7 @@ class ZepEntityReader:
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# 获取所有节点
|
# Read every node from the graph.
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
total_count = len(all_nodes)
|
total_count = len(all_nodes)
|
||||||
|
|
||||||
|
|
@ -259,27 +254,27 @@ class ZepEntityReader:
|
||||||
if entity_type != "Entity":
|
if entity_type != "Entity":
|
||||||
node["labels"] = [entity_type] + labels
|
node["labels"] = [entity_type] + labels
|
||||||
|
|
||||||
# 获取所有边(用于后续关联查找)
|
# Read every edge so we can enrich entities later.
|
||||||
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
all_edges = self.get_all_edges(graph_id) if enrich_with_edges else []
|
||||||
|
|
||||||
# 构建节点UUID到节点数据的映射
|
# uuid -> node-data map for fast lookup.
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 筛选符合条件的实体
|
# Filter to entities that match the criteria.
|
||||||
filtered_entities = []
|
filtered_entities = []
|
||||||
entity_types_found = set()
|
entity_types_found = set()
|
||||||
|
|
||||||
for node in all_nodes:
|
for node in all_nodes:
|
||||||
labels = node.get("labels", [])
|
labels = node.get("labels", [])
|
||||||
|
|
||||||
# 筛选逻辑:Labels必须包含除"Entity"和"Node"之外的标签
|
# Filtering rule: labels must contain something other than the defaults.
|
||||||
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
custom_labels = [l for l in labels if l not in ["Entity", "Node"]]
|
||||||
|
|
||||||
if not custom_labels:
|
if not custom_labels:
|
||||||
# 只有默认标签,跳过
|
# Only default labels — skip.
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 如果指定了预定义类型,检查是否匹配
|
# When a predefined-type list is supplied, require a match against it.
|
||||||
if defined_entity_types:
|
if defined_entity_types:
|
||||||
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
matching_labels = [l for l in custom_labels if l in defined_entity_types]
|
||||||
if not matching_labels:
|
if not matching_labels:
|
||||||
|
|
@ -290,7 +285,6 @@ class ZepEntityReader:
|
||||||
|
|
||||||
entity_types_found.add(entity_type)
|
entity_types_found.add(entity_type)
|
||||||
|
|
||||||
# 创建实体节点对象
|
|
||||||
entity = EntityNode(
|
entity = EntityNode(
|
||||||
uuid=node["uuid"],
|
uuid=node["uuid"],
|
||||||
name=node["name"],
|
name=node["name"],
|
||||||
|
|
@ -299,7 +293,7 @@ class ZepEntityReader:
|
||||||
attributes=node["attributes"],
|
attributes=node["attributes"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取相关边和节点
|
# Enrich with related edges and neighboring nodes.
|
||||||
if enrich_with_edges:
|
if enrich_with_edges:
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids = set()
|
||||||
|
|
@ -324,7 +318,7 @@ class ZepEntityReader:
|
||||||
|
|
||||||
entity.related_edges = related_edges
|
entity.related_edges = related_edges
|
||||||
|
|
||||||
# 获取关联节点的基本信息
|
# Populate basic info for each neighboring node.
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid in node_map:
|
||||||
|
|
@ -354,18 +348,17 @@ class ZepEntityReader:
|
||||||
graph_id: str,
|
graph_id: str,
|
||||||
entity_uuid: str
|
entity_uuid: str
|
||||||
) -> Optional[EntityNode]:
|
) -> Optional[EntityNode]:
|
||||||
"""
|
"""Fetch a single entity with its full context (edges + neighbors), with retry.
|
||||||
获取单个实体及其完整上下文(边和关联节点,带重试机制)
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: Graph identifier.
|
||||||
entity_uuid: 实体UUID
|
entity_uuid: Entity UUID.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
EntityNode或None
|
``EntityNode`` or ``None`` if not found.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用重试机制获取节点
|
# Fetch the node with retry.
|
||||||
node = self._call_with_retry(
|
node = self._call_with_retry(
|
||||||
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
func=lambda: self.client.graph.node.get(uuid_=entity_uuid),
|
||||||
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
operation_name=f"获取节点详情(uuid={entity_uuid[:8]}...)"
|
||||||
|
|
@ -374,14 +367,14 @@ class ZepEntityReader:
|
||||||
if not node:
|
if not node:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# 获取节点的边
|
# Edges connected to this node.
|
||||||
edges = self.get_node_edges(entity_uuid)
|
edges = self.get_node_edges(entity_uuid)
|
||||||
|
|
||||||
# 获取所有节点用于关联查找
|
# All graph nodes, used for neighbor lookup.
|
||||||
all_nodes = self.get_all_nodes(graph_id)
|
all_nodes = self.get_all_nodes(graph_id)
|
||||||
node_map = {n["uuid"]: n for n in all_nodes}
|
node_map = {n["uuid"]: n for n in all_nodes}
|
||||||
|
|
||||||
# 处理相关边和节点
|
# Collect related edges and neighboring uuids.
|
||||||
related_edges = []
|
related_edges = []
|
||||||
related_node_uuids = set()
|
related_node_uuids = set()
|
||||||
|
|
||||||
|
|
@ -403,7 +396,7 @@ class ZepEntityReader:
|
||||||
})
|
})
|
||||||
related_node_uuids.add(edge["source_node_uuid"])
|
related_node_uuids.add(edge["source_node_uuid"])
|
||||||
|
|
||||||
# 获取关联节点信息
|
# Populate basic info for each neighboring node.
|
||||||
related_nodes = []
|
related_nodes = []
|
||||||
for related_uuid in related_node_uuids:
|
for related_uuid in related_node_uuids:
|
||||||
if related_uuid in node_map:
|
if related_uuid in node_map:
|
||||||
|
|
@ -435,16 +428,15 @@ class ZepEntityReader:
|
||||||
entity_type: str,
|
entity_type: str,
|
||||||
enrich_with_edges: bool = True
|
enrich_with_edges: bool = True
|
||||||
) -> List[EntityNode]:
|
) -> List[EntityNode]:
|
||||||
"""
|
"""Return every entity matching the given type.
|
||||||
获取指定类型的所有实体
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph_id: 图谱ID
|
graph_id: Graph identifier.
|
||||||
entity_type: 实体类型(如 "Student", "PublicFigure" 等)
|
entity_type: Entity type label (e.g. ``Student``, ``PublicFigure``).
|
||||||
enrich_with_edges: 是否获取相关边信息
|
enrich_with_edges: When ``True``, populate related edges/nodes.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
实体列表
|
A list of matching ``EntityNode`` instances.
|
||||||
"""
|
"""
|
||||||
result = self.filter_defined_entities(
|
result = self.filter_defined_entities(
|
||||||
graph_id=graph_id,
|
graph_id=graph_id,
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,20 @@
|
||||||
"""
|
"""MiroFish backend entry point."""
|
||||||
MiroFish Backend 启动入口
|
|
||||||
"""
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
# 解决 Windows 控制台中文乱码问题:在所有导入之前设置 UTF-8 编码
|
# Force UTF-8 on the Windows console before importing anything that might write to
|
||||||
|
# stdout/stderr; otherwise non-ASCII characters render as mojibake.
|
||||||
if sys.platform == 'win32':
|
if sys.platform == 'win32':
|
||||||
# 设置环境变量确保 Python 使用 UTF-8
|
# Make sure Python itself uses UTF-8.
|
||||||
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
|
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
|
||||||
# 重新配置标准输出流为 UTF-8
|
# Reconfigure the standard streams to UTF-8.
|
||||||
if hasattr(sys.stdout, 'reconfigure'):
|
if hasattr(sys.stdout, 'reconfigure'):
|
||||||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||||||
if hasattr(sys.stderr, 'reconfigure'):
|
if hasattr(sys.stderr, 'reconfigure'):
|
||||||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||||||
|
|
||||||
# 添加项目根目录到路径
|
# Add the project root to sys.path so the ``app`` package resolves.
|
||||||
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
|
||||||
|
|
||||||
from app import create_app
|
from app import create_app
|
||||||
|
|
@ -23,8 +22,7 @@ from app.config import Config
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
"""主函数"""
|
"""Validate configuration and start the Flask development server."""
|
||||||
# 验证配置
|
|
||||||
errors = Config.validate()
|
errors = Config.validate()
|
||||||
if errors:
|
if errors:
|
||||||
print("配置错误:")
|
print("配置错误:")
|
||||||
|
|
@ -33,18 +31,15 @@ def main():
|
||||||
print("\n请检查 .env 文件中的配置")
|
print("\n请检查 .env 文件中的配置")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 创建应用
|
|
||||||
app = create_app()
|
app = create_app()
|
||||||
|
|
||||||
# 获取运行配置
|
# Resolve runtime host/port from the environment.
|
||||||
host = os.environ.get('FLASK_HOST', '0.0.0.0')
|
host = os.environ.get('FLASK_HOST', '0.0.0.0')
|
||||||
port = int(os.environ.get('FLASK_PORT', 5001))
|
port = int(os.environ.get('FLASK_PORT', 5001))
|
||||||
debug = Config.DEBUG
|
debug = Config.DEBUG
|
||||||
|
|
||||||
# 启动服务
|
|
||||||
app.run(host=host, port=port, debug=debug, threaded=True)
|
app.run(host=host, port=port, debug=debug, threaded=True)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue