diff --git a/backend/scripts/run_parallel_simulation.py b/backend/scripts/run_parallel_simulation.py index 2a627ffd..9dd3d8b9 100644 --- a/backend/scripts/run_parallel_simulation.py +++ b/backend/scripts/run_parallel_simulation.py @@ -1,67 +1,70 @@ -""" -OASIS 双平台并行模拟预设脚本 -同时运行Twitter和Reddit模拟,读取相同的配置文件 +"""OASIS dual-platform parallel simulation preset script. -功能特性: -- 双平台(Twitter + Reddit)并行模拟 -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 +Runs Twitter and Reddit simulations simultaneously, reading the same config file. -使用方式: +Features: +- Dual-platform (Twitter + Reddit) parallel simulation +- Keeps environments alive after the simulation finishes and enters wait-for-command mode +- Receives Interview commands via IPC +- Supports single-agent and batch interviews +- Supports a remote close-environment command + +Usage: python run_parallel_simulation.py --config simulation_config.json - python run_parallel_simulation.py --config simulation_config.json --no-wait # 完成后立即关闭 + python run_parallel_simulation.py --config simulation_config.json --no-wait # close immediately when done python run_parallel_simulation.py --config simulation_config.json --twitter-only python run_parallel_simulation.py --config simulation_config.json --reddit-only -日志结构: +Log layout: sim_xxx/ ├── twitter/ - │ └── actions.jsonl # Twitter 平台动作日志 + │ └── actions.jsonl # Twitter platform action log ├── reddit/ - │ └── actions.jsonl # Reddit 平台动作日志 - ├── simulation.log # 主模拟进程日志 - └── run_state.json # 运行状态(API 查询用) + │ └── actions.jsonl # Reddit platform action log + ├── simulation.log # main simulation process log + └── run_state.json # run state (used by API queries) """ # ============================================================ -# 解决 Windows 编码问题:在所有 import 之前设置 UTF-8 编码 -# 这是为了修复 OASIS 第三方库读取文件时未指定编码的问题 +# Fix the Windows encoding issue by forcing UTF-8 before any import. +# This works around the OASIS third-party library opening files without +# specifying an encoding. # ============================================================ import sys import os if sys.platform == 'win32': - # 设置 Python 默认 I/O 编码为 UTF-8 - # 这会影响所有未指定编码的 open() 调用 + # Set Python's default I/O encoding to UTF-8 so every open() call without + # an explicit encoding picks it up. os.environ.setdefault('PYTHONUTF8', '1') os.environ.setdefault('PYTHONIOENCODING', 'utf-8') - - # 重新配置标准输出流为 UTF-8(解决控制台中文乱码) + + # Reconfigure stdout/stderr to UTF-8 to avoid mojibake in the console. if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') if hasattr(sys.stderr, 'reconfigure'): sys.stderr.reconfigure(encoding='utf-8', errors='replace') - - # 强制设置默认编码(影响 open() 函数的默认编码) - # 注意:这需要在 Python 启动时就设置,运行时设置可能不生效 - # 所以我们还需要 monkey-patch 内置的 open 函数 + + # Force the default encoding used by open(). The env-var approach above + # only works when set at interpreter startup, so we additionally + # monkey-patch the built-in open(). import builtins _original_open = builtins.open - - def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None, + + def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None, newline=None, closefd=True, opener=None): + """Wrap open() so text-mode calls default to UTF-8. + + Fixes third-party libraries (such as OASIS) that open files without + specifying an encoding. """ - 包装 open() 函数,对于文本模式默认使用 UTF-8 编码 - 这可以修复第三方库(如 OASIS)读取文件时未指定编码的问题 - """ - # 只对文本模式(非二进制)且未指定编码的情况设置默认编码 + # Only override when the caller is using text mode and did not request + # an explicit encoding. if encoding is None and 'b' not in mode: encoding = 'utf-8' - return _original_open(file, mode, buffering, encoding, errors, + return _original_open(file, mode, buffering, encoding, errors, newline, closefd, opener) - + builtins.open = _utf8_open import argparse @@ -77,26 +80,26 @@ from datetime import datetime from typing import Dict, Any, List, Optional, Tuple -# 全局变量:用于信号处理 +# Globals used by the signal handlers. _shutdown_event = None _cleanup_done = False -# 添加 backend 目录到路径 -# 脚本固定位于 backend/scripts/ 目录 +# Add the backend directory to sys.path. The script always lives in +# backend/scripts/. _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +# Load the .env from the project root (contains LLM_API_KEY etc.). from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): load_dotenv(_env_file) print(f"已加载环境配置: {_env_file}") else: - # 尝试加载 backend/.env + # Fall back to backend/.env. _backend_env = os.path.join(_backend_dir, '.env') if os.path.exists(_backend_env): load_dotenv(_backend_env) @@ -104,51 +107,51 @@ else: class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" - + """Suppress camel-ai max_tokens warnings. + + We intentionally leave max_tokens unset so the model decides; the warning is noise. + """ + def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 +# Install the filter at import time so it is active before any camel code runs. logging.getLogger().addFilter(MaxTokensWarningFilter()) def disable_oasis_logging(): + """Disable verbose OASIS library logging. + + OASIS logs every agent observation and action which is extremely noisy; we + rely on our own action_logger instead. """ - 禁用 OASIS 库的详细日志输出 - OASIS 的日志太冗余(记录每个 agent 的观察和动作),我们使用自己的 action_logger - """ - # 禁用 OASIS 的所有日志器 oasis_loggers = [ "social.agent", - "social.twitter", + "social.twitter", "social.rec", "oasis.env", "table", ] - + for logger_name in oasis_loggers: logger = logging.getLogger(logger_name) - logger.setLevel(logging.CRITICAL) # 只记录严重错误 + logger.setLevel(logging.CRITICAL) # only keep severe errors logger.handlers.clear() logger.propagate = False def init_logging_for_simulation(simulation_dir: str): - """ - 初始化模拟的日志配置 - + """Initialize logging for a simulation run. + Args: - simulation_dir: 模拟目录路径 + simulation_dir: path to the simulation directory. """ - # 禁用 OASIS 的详细日志 disable_oasis_logging() - - # 清理旧的 log 目录(如果存在) + + # Clean up any pre-existing log directory. old_log_dir = os.path.join(simulation_dir, "log") if os.path.exists(old_log_dir): import shutil @@ -174,7 +177,8 @@ except ImportError as e: sys.exit(1) -# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) +# Twitter actions available to agents. INTERVIEW is excluded because it can only +# be triggered manually via ManualAction. TWITTER_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -184,7 +188,8 @@ TWITTER_ACTIONS = [ ActionType.QUOTE_POST, ] -# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) +# Reddit actions available to agents. INTERVIEW is excluded because it can only +# be triggered manually via ManualAction. REDDIT_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -202,23 +207,22 @@ REDDIT_ACTIONS = [ ] -# IPC相关常量 +# IPC-related constants. IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command type constants.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class ParallelIPCHandler: - """ - 双平台IPC命令处理器 - - 管理两个平台的环境,处理Interview命令 + """Dual-platform IPC command handler. + + Manages both platform environments and processes Interview commands. """ def __init__( @@ -238,13 +242,12 @@ class ParallelIPCHandler: self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) - - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) - + def update_status(self, status: str): - """更新环境状态""" + """Update the recorded environment status.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -254,11 +257,11 @@ class ParallelIPCHandler: }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll for the next pending command.""" if not os.path.exists(self.commands_dir): return None - - # 获取命令文件(按时间排序) + + # Collect command files sorted by mtime so older commands run first. command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -277,7 +280,7 @@ class ParallelIPCHandler: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send a response for a previously dispatched command.""" response = { "command_id": command_id, "status": status, @@ -289,8 +292,8 @@ class ParallelIPCHandler: response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - - # 删除命令文件 + + # Remove the original command file once a response is recorded. command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -298,14 +301,14 @@ class ParallelIPCHandler: pass def _get_env_and_graph(self, platform: str): - """ - 获取指定平台的环境和agent_graph - + """Return the environment and agent graph for the given platform. + Args: - platform: 平台名称 ("twitter" 或 "reddit") - + platform: platform name ("twitter" or "reddit"). + Returns: - (env, agent_graph, platform_name) 或 (None, None, None) + Tuple ``(env, agent_graph, platform_name)`` or ``(None, None, None)`` + when the platform is unavailable. """ if platform == "twitter" and self.twitter_env: return self.twitter_env, self.twitter_agent_graph, "twitter" @@ -315,11 +318,10 @@ class ParallelIPCHandler: return None, None, None async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]: - """ - 在单个平台上执行Interview - + """Run an Interview on a single platform. + Returns: - 包含结果的字典,或包含error的字典 + A dict with the interview result, or a dict containing an ``error`` key. """ env, agent_graph, actual_platform = self._get_env_and_graph(platform) @@ -343,22 +345,21 @@ class ParallelIPCHandler: return {"platform": platform, "error": str(e)} async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool: - """ - 处理单个Agent采访命令 - + """Handle a single-agent interview command. + Args: - command_id: 命令ID - agent_id: Agent ID - prompt: 采访问题 - platform: 指定平台(可选) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None/不指定: 同时采访两个平台,返回整合结果 - + command_id: command identifier. + agent_id: agent identifier. + prompt: interview prompt. + platform: optional platform selector. + - "twitter": interview on Twitter only. + - "reddit": interview on Reddit only. + - ``None``: interview on both platforms and return a merged result. + Returns: - True 表示成功,False 表示失败 + ``True`` on success, ``False`` on failure. """ - # 如果指定了平台,只采访该平台 + # If a specific platform was requested, only interview on that platform. if platform in ("twitter", "reddit"): result = await self._interview_single_platform(agent_id, prompt, platform) @@ -371,7 +372,7 @@ class ParallelIPCHandler: print(f" Interview完成: agent_id={agent_id}, platform={platform}") return True - # 未指定平台:同时采访两个平台 + # No platform specified: interview on both platforms simultaneously. if not self.twitter_env and not self.reddit_env: self.send_response(command_id, "failed", error="没有可用的模拟环境") return False @@ -383,7 +384,7 @@ class ParallelIPCHandler: } success_count = 0 - # 并行采访两个平台 + # Run the two platform interviews in parallel. tasks = [] platforms_to_interview = [] @@ -394,8 +395,7 @@ class ParallelIPCHandler: if self.reddit_env: tasks.append(self._interview_single_platform(agent_id, prompt, "reddit")) platforms_to_interview.append("reddit") - - # 并行执行 + platform_results = await asyncio.gather(*tasks) for platform_name, platform_result in zip(platforms_to_interview, platform_results): @@ -414,22 +414,21 @@ class ParallelIPCHandler: return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool: - """ - 处理批量采访命令 - + """Handle a batch-interview command. + Args: - command_id: 命令ID - interviews: [{"agent_id": int, "prompt": str, "platform": str(optional)}, ...] - platform: 默认平台(可被每个interview项覆盖) - - "twitter": 只采访Twitter平台 - - "reddit": 只采访Reddit平台 - - None/不指定: 每个Agent同时采访两个平台 + command_id: command identifier. + interviews: ``[{"agent_id": int, "prompt": str, "platform": str(optional)}, ...]``. + platform: default platform (can be overridden per interview entry). + - "twitter": interview on Twitter only. + - "reddit": interview on Reddit only. + - ``None``: interview every agent on both platforms. """ - # 按平台分组 + # Bucket interviews by target platform. twitter_interviews = [] reddit_interviews = [] - both_platforms_interviews = [] # 需要同时采访两个平台的 - + both_platforms_interviews = [] # entries that need both platforms + for interview in interviews: item_platform = interview.get("platform", platform) if item_platform == "twitter": @@ -437,10 +436,10 @@ class ParallelIPCHandler: elif item_platform == "reddit": reddit_interviews.append(interview) else: - # 未指定平台:两个平台都采访 + # No platform specified: interview on both. both_platforms_interviews.append(interview) - - # 把 both_platforms_interviews 拆分到两个平台 + + # Fan the both-platform entries out into the per-platform buckets. if both_platforms_interviews: if self.twitter_env: twitter_interviews.extend(both_platforms_interviews) @@ -448,8 +447,8 @@ class ParallelIPCHandler: reddit_interviews.extend(both_platforms_interviews) results = {} - - # 处理Twitter平台的采访 + + # Run the Twitter-side interviews. if twitter_interviews and self.twitter_env: try: twitter_actions = {} @@ -476,7 +475,7 @@ class ParallelIPCHandler: except Exception as e: print(f" Twitter批量Interview失败: {e}") - # 处理Reddit平台的采访 + # Run the Reddit-side interviews. if reddit_interviews and self.reddit_env: try: reddit_actions = {} @@ -515,7 +514,7 @@ class ParallelIPCHandler: return False def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Read the latest Interview result for an agent from the database.""" db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db") result = { @@ -530,8 +529,8 @@ class ParallelIPCHandler: try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - - # 查询最新的Interview记录 + + # Look up the most recent Interview row for this agent. cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -558,11 +557,10 @@ class ParallelIPCHandler: return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - + """Process all pending commands. + Returns: - True 表示继续运行,False 表示应该退出 + ``True`` to keep running, ``False`` if the process should exit. """ command = self.poll_command() if not command: @@ -602,15 +600,15 @@ class ParallelIPCHandler: def load_config(config_path: str) -> Dict[str, Any]: - """加载配置文件""" + """Load a JSON config file from disk.""" with open(config_path, 'r', encoding='utf-8') as f: return json.load(f) -# 需要过滤掉的非核心动作类型(这些动作对分析价值较低) +# Non-core action types to filter out: they provide little analytical value. FILTERED_ACTIONS = {'refresh', 'sign_up'} -# 动作类型映射表(数据库中的名称 -> 标准名称) +# Action-type mapping (database name -> canonical name). ACTION_TYPE_MAP = { 'create_post': 'CREATE_POST', 'like_post': 'LIKE_POST', @@ -631,16 +629,16 @@ ACTION_TYPE_MAP = { def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]: - """ - 从 simulation_config 中获取 agent_id -> entity_name 的映射 - - 这样可以在 actions.jsonl 中显示真实的实体名称,而不是 "Agent_0" 这样的代号 - + """Build an ``agent_id -> entity_name`` map from the simulation config. + + Using the entity name lets actions.jsonl display the real entity rather + than placeholder labels like ``Agent_0``. + Args: - config: simulation_config.json 的内容 - + config: contents of ``simulation_config.json``. + Returns: - agent_id -> entity_name 的映射字典 + Mapping from agent id to entity name. """ agent_names = {} agent_configs = config.get("agent_configs", []) @@ -659,18 +657,20 @@ def fetch_new_actions_from_db( last_rowid: int, agent_names: Dict[int, str] ) -> Tuple[List[Dict[str, Any]], int]: - """ - 从数据库中获取新的动作记录,并补充完整的上下文信息 - + """Fetch new action rows from the database and enrich them with context. + Args: - db_path: 数据库文件路径 - last_rowid: 上次读取的最大 rowid 值(使用 rowid 而不是 created_at,因为不同平台的 created_at 格式不同) - agent_names: agent_id -> agent_name 映射 - + db_path: path to the database file. + last_rowid: highest rowid processed previously. We track ``rowid`` + rather than ``created_at`` because the two platforms use different + ``created_at`` formats. + agent_names: ``agent_id -> agent_name`` mapping. + Returns: - (actions_list, new_last_rowid) - - actions_list: 动作列表,每个元素包含 agent_id, agent_name, action_type, action_args(含上下文信息) - - new_last_rowid: 新的最大 rowid 值 + Tuple ``(actions_list, new_last_rowid)``. + - ``actions_list``: action records, each containing ``agent_id``, + ``agent_name``, ``action_type``, and ``action_args`` (with context). + - ``new_last_rowid``: the new highest rowid seen. """ actions = [] new_last_rowid = last_rowid @@ -681,9 +681,10 @@ def fetch_new_actions_from_db( try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - - # 使用 rowid 来追踪已处理的记录(rowid 是 SQLite 的内置自增字段) - # 这样可以避免 created_at 格式差异问题(Twitter 用整数,Reddit 用日期时间字符串) + + # Use ``rowid`` to track processed rows. ``rowid`` is SQLite's built-in + # auto-increment column and avoids the cross-platform ``created_at`` + # format mismatch (Twitter stores integers, Reddit stores datetime strings). cursor.execute(""" SELECT rowid, user_id, action, info FROM trace @@ -692,20 +693,17 @@ def fetch_new_actions_from_db( """, (last_rowid,)) for rowid, user_id, action, info_json in cursor.fetchall(): - # 更新最大 rowid new_last_rowid = rowid - - # 过滤非核心动作 + if action in FILTERED_ACTIONS: continue - - # 解析动作参数 + try: action_args = json.loads(info_json) if info_json else {} except json.JSONDecodeError: action_args = {} - - # 精简 action_args,只保留关键字段(保留完整内容,不截断) + + # Slim ``action_args`` down to the key fields. Content is kept in full (no truncation). simplified_args = {} if 'content' in action_args: simplified_args['content'] = action_args['content'] @@ -726,10 +724,9 @@ def fetch_new_actions_from_db( if 'dislike_id' in action_args: simplified_args['dislike_id'] = action_args['dislike_id'] - # 转换动作类型名称 action_type = ACTION_TYPE_MAP.get(action, action.upper()) - - # 补充上下文信息(帖子内容、用户名等) + + # Enrich with context such as post content and author name. _enrich_action_context(cursor, action_type, simplified_args, agent_names) actions.append({ @@ -752,17 +749,16 @@ def _enrich_action_context( action_args: Dict[str, Any], agent_names: Dict[int, str] ) -> None: - """ - 为动作补充上下文信息(帖子内容、用户名等) - + """Enrich an action's args with context such as post content and author name. + Args: - cursor: 数据库游标 - action_type: 动作类型 - action_args: 动作参数(会被修改) - agent_names: agent_id -> agent_name 映射 + cursor: database cursor. + action_type: action type. + action_args: action args (mutated in place). + agent_names: ``agent_id -> agent_name`` mapping. """ try: - # 点赞/踩帖子:补充帖子内容和作者 + # Like/dislike post: include the post content and author name. if action_type in ('LIKE_POST', 'DISLIKE_POST'): post_id = action_args.get('post_id') if post_id: @@ -771,11 +767,11 @@ def _enrich_action_context( action_args['post_content'] = post_info.get('content', '') action_args['post_author_name'] = post_info.get('author_name', '') - # 转发帖子:补充原帖内容和作者 + # Repost: include the original post content and author name. elif action_type == 'REPOST': new_post_id = action_args.get('new_post_id') if new_post_id: - # 转发帖子的 original_post_id 指向原帖 + # On a repost row, ``original_post_id`` points at the original post. cursor.execute(""" SELECT original_post_id FROM post WHERE post_id = ? """, (new_post_id,)) @@ -787,18 +783,18 @@ def _enrich_action_context( action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') - # 引用帖子:补充原帖内容、作者和引用评论 + # Quote post: include the original post content, author name, and quote comment. elif action_type == 'QUOTE_POST': quoted_id = action_args.get('quoted_id') new_post_id = action_args.get('new_post_id') - + if quoted_id: original_info = _get_post_info(cursor, quoted_id, agent_names) if original_info: action_args['original_content'] = original_info.get('content', '') action_args['original_author_name'] = original_info.get('author_name', '') - - # 获取引用帖子的评论内容(quote_content) + + # Read the quote comment (``quote_content``). if new_post_id: cursor.execute(""" SELECT quote_content FROM post WHERE post_id = ? @@ -807,11 +803,11 @@ def _enrich_action_context( if row and row[0]: action_args['quote_content'] = row[0] - # 关注用户:补充被关注用户的名称 + # Follow: include the followee's display name. elif action_type == 'FOLLOW': follow_id = action_args.get('follow_id') if follow_id: - # 从 follow 表获取 followee_id + # Look up ``followee_id`` from the ``follow`` table. cursor.execute(""" SELECT followee_id FROM follow WHERE follow_id = ? """, (follow_id,)) @@ -822,16 +818,16 @@ def _enrich_action_context( if target_name: action_args['target_user_name'] = target_name - # 屏蔽用户:补充被屏蔽用户的名称 + # Mute: include the muted user's display name. elif action_type == 'MUTE': - # 从 action_args 中获取 user_id 或 target_id + # Read ``user_id`` or ``target_id`` from action_args. target_id = action_args.get('user_id') or action_args.get('target_id') if target_id: target_name = _get_user_name(cursor, target_id, agent_names) if target_name: action_args['target_user_name'] = target_name - # 点赞/踩评论:补充评论内容和作者 + # Like/dislike comment: include the comment content and author name. elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'): comment_id = action_args.get('comment_id') if comment_id: @@ -840,7 +836,7 @@ def _enrich_action_context( action_args['comment_content'] = comment_info.get('content', '') action_args['comment_author_name'] = comment_info.get('author_name', '') - # 发表评论:补充所评论的帖子信息 + # Create comment: include the parent post's content and author name. elif action_type == 'CREATE_COMMENT': post_id = action_args.get('post_id') if post_id: @@ -850,7 +846,7 @@ def _enrich_action_context( action_args['post_author_name'] = post_info.get('author_name', '') except Exception as e: - # 补充上下文失败不影响主流程 + # Failing to enrich context must not break the main flow. print(f"补充动作上下文失败: {e}") @@ -859,16 +855,15 @@ def _get_post_info( post_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: - """ - 获取帖子信息 - + """Look up post info. + Args: - cursor: 数据库游标 - post_id: 帖子ID - agent_names: agent_id -> agent_name 映射 - + cursor: database cursor. + post_id: post identifier. + agent_names: ``agent_id -> agent_name`` mapping. + Returns: - 包含 content 和 author_name 的字典,或 None + Dict with ``content`` and ``author_name``, or ``None`` when not found. """ try: cursor.execute(""" @@ -882,18 +877,18 @@ def _get_post_info( content = row[0] or '' user_id = row[1] agent_id = row[2] - - # 优先使用 agent_names 中的名称 + + # Prefer the entity_name supplied via agent_names. author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: - # 从 user 表获取名称 + # Fall back to the user table. cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: author_name = user_row[0] or user_row[1] or '' - + return {'content': content, 'author_name': author_name} except Exception: pass @@ -905,16 +900,15 @@ def _get_user_name( user_id: int, agent_names: Dict[int, str] ) -> Optional[str]: - """ - 获取用户名称 - + """Look up a user's display name. + Args: - cursor: 数据库游标 - user_id: 用户ID - agent_names: agent_id -> agent_name 映射 - + cursor: database cursor. + user_id: user identifier. + agent_names: ``agent_id -> agent_name`` mapping. + Returns: - 用户名称,或 None + Display name, or ``None`` when the user cannot be found. """ try: cursor.execute(""" @@ -925,8 +919,8 @@ def _get_user_name( agent_id = row[0] name = row[1] user_name = row[2] - - # 优先使用 agent_names 中的名称 + + # Prefer the entity_name supplied via agent_names. if agent_id is not None and agent_id in agent_names: return agent_names[agent_id] return name or user_name or '' @@ -940,16 +934,15 @@ def _get_comment_info( comment_id: int, agent_names: Dict[int, str] ) -> Optional[Dict[str, str]]: - """ - 获取评论信息 - + """Look up comment info. + Args: - cursor: 数据库游标 - comment_id: 评论ID - agent_names: agent_id -> agent_name 映射 - + cursor: database cursor. + comment_id: comment identifier. + agent_names: ``agent_id -> agent_name`` mapping. + Returns: - 包含 content 和 author_name 的字典,或 None + Dict with ``content`` and ``author_name``, or ``None`` when not found. """ try: cursor.execute(""" @@ -963,18 +956,18 @@ def _get_comment_info( content = row[0] or '' user_id = row[1] agent_id = row[2] - - # 优先使用 agent_names 中的名称 + + # Prefer the entity_name supplied via agent_names. author_name = '' if agent_id is not None and agent_id in agent_names: author_name = agent_names[agent_id] elif user_id: - # 从 user 表获取名称 + # Fall back to the user table. cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,)) user_row = cursor.fetchone() if user_row: author_name = user_row[0] or user_row[1] or '' - + return {'content': content, 'author_name': author_name} except Exception: pass @@ -982,44 +975,44 @@ def _get_comment_info( def create_model(config: Dict[str, Any], use_boost: bool = False): - """ - 创建LLM模型 - - 支持双 LLM 配置,用于并行模拟时提速: - - 通用配置:LLM_API_KEY, LLM_BASE_URL, LLM_MODEL_NAME - - 加速配置(可选):LLM_BOOST_API_KEY, LLM_BOOST_BASE_URL, LLM_BOOST_MODEL_NAME - - 如果配置了加速 LLM,并行模拟时可以让不同平台使用不同的 API 服务商,提高并发能力。 - + """Create the LLM model used by the simulation. + + Two LLM configurations are supported, which lets parallel simulations run faster: + - default: ``LLM_API_KEY``, ``LLM_BASE_URL``, ``LLM_MODEL_NAME``. + - boost (optional): ``LLM_BOOST_API_KEY``, ``LLM_BOOST_BASE_URL``, ``LLM_BOOST_MODEL_NAME``. + + When a boost LLM is configured, the two platforms can target different API + providers, increasing overall concurrency. + Args: - config: 模拟配置字典 - use_boost: 是否使用加速 LLM 配置(如果可用) + config: simulation config dict. + use_boost: whether to use the boost LLM config when available. """ - # 检查是否有加速配置 + # Inspect the boost configuration. boost_api_key = os.environ.get("LLM_BOOST_API_KEY", "") boost_base_url = os.environ.get("LLM_BOOST_BASE_URL", "") boost_model = os.environ.get("LLM_BOOST_MODEL_NAME", "") has_boost_config = bool(boost_api_key) - - # 根据参数和配置情况选择使用哪个 LLM + + # Choose which LLM to use based on the request and what is configured. if use_boost and has_boost_config: - # 使用加速配置 + # Use the boost configuration. llm_api_key = boost_api_key llm_base_url = boost_base_url llm_model = boost_model or os.environ.get("LLM_MODEL_NAME", "") config_label = "[加速LLM]" else: - # 使用通用配置 + # Use the default configuration. llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") config_label = "[通用LLM]" - - # 如果 .env 中没有模型名,则使用 config 作为备用 + + # Fall back to the model name in the config when .env does not provide one. if not llm_model: llm_model = config.get("llm_model", "gpt-4o-mini") - - # 设置 camel-ai 所需的环境变量 + + # Populate the env vars camel-ai expects. if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key @@ -1043,7 +1036,7 @@ def get_active_agents_for_round( current_hour: int, round_num: int ) -> List: - """根据时间和配置决定本轮激活哪些Agent""" + """Decide which agents are active in this round based on time and config.""" time_config = config.get("time_config", {}) agent_configs = config.get("agent_configs", []) @@ -1091,7 +1084,7 @@ def get_active_agents_for_round( class PlatformSimulation: - """平台模拟结果容器""" + """Container for the result of a platform simulation.""" def __init__(self): self.env = None self.agent_graph = None @@ -1105,17 +1098,17 @@ async def run_twitter_simulation( main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None ) -> PlatformSimulation: - """运行Twitter模拟 - + """Run the Twitter simulation. + Args: - config: 模拟配置 - simulation_dir: 模拟目录 - action_logger: 动作日志记录器 - main_logger: 主日志管理器 - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - + config: simulation config. + simulation_dir: simulation directory. + action_logger: action logger. + main_logger: main log manager. + max_rounds: optional cap on the number of rounds, used to truncate long runs. + Returns: - PlatformSimulation: 包含env和agent_graph的结果对象 + PlatformSimulation containing the env and agent_graph. """ result = PlatformSimulation() @@ -1125,11 +1118,11 @@ async def run_twitter_simulation( print(f"[Twitter] {msg}") log_info("初始化...") - - # Twitter 使用通用 LLM 配置 + + # Twitter uses the default LLM config. model = create_model(config, use_boost=False) - - # OASIS Twitter使用CSV格式 + + # OASIS Twitter expects a CSV profile file. profile_path = os.path.join(simulation_dir, "twitter_profiles.csv") if not os.path.exists(profile_path): log_info(f"错误: Profile文件不存在: {profile_path}") @@ -1141,13 +1134,13 @@ async def run_twitter_simulation( available_actions=TWITTER_ACTIONS, ) - # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) + # Pull real agent names from the config (use entity_name rather than the default Agent_X). agent_names = get_agent_names_from_config(config) - # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 + # If the config does not list a particular agent, fall back to OASIS's default name. for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') - + db_path = os.path.join(simulation_dir, "twitter_simulation.db") if os.path.exists(db_path): os.remove(db_path) @@ -1156,7 +1149,7 @@ async def run_twitter_simulation( agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, # cap concurrent LLM requests to avoid overloading the API ) await result.env.reset() @@ -1166,13 +1159,13 @@ async def run_twitter_simulation( action_logger.log_simulation_start(config) total_actions = 0 - last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异) - - # 执行初始事件 + last_rowid = 0 # last processed db row; using rowid avoids created_at format differences + + # Run the initial events. event_config = config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) - - # 记录 round 0 开始(初始事件阶段) + + # Mark the start of round 0 (the initial-events phase). if action_logger: action_logger.log_round_start(0, 0) # round 0, simulated_hour 0 @@ -1206,17 +1199,17 @@ async def run_twitter_simulation( await result.env.step(initial_actions) log_info(f"已发布 {len(initial_actions)} 条初始帖子") - # 记录 round 0 结束 + # Mark the end of round 0. if action_logger: action_logger.log_round_end(0, initial_action_count) - - # 主模拟循环 + + # Main simulation loop. time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - - # 如果指定了最大轮数,则截断 + + # Truncate when a max round count was supplied. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) @@ -1226,7 +1219,7 @@ async def run_twitter_simulation( start_time = datetime.now() for round_num in range(total_rounds): - # 检查是否收到退出信号 + # Bail out if a shutdown signal was received. if _shutdown_event and _shutdown_event.is_set(): if main_logger: main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟") @@ -1240,12 +1233,12 @@ async def run_twitter_simulation( result.env, config, simulated_hour, round_num ) - # 无论是否有活跃agent,都记录round开始 + # Always log round-start, even when no agents are active. if action_logger: action_logger.log_round_start(round_num + 1, simulated_hour) - + if not active_agents: - # 没有活跃agent时也记录round结束(actions_count=0) + # Still emit round-end (with actions_count=0) so the log stays consistent. if action_logger: action_logger.log_round_end(round_num + 1, 0) continue @@ -1253,7 +1246,7 @@ async def run_twitter_simulation( actions = {agent: LLMAction() for _, agent in active_agents} await result.env.step(actions) - # 从数据库获取实际执行的动作并记录 + # Pull the actually-executed actions from the database and log them. actual_actions, last_rowid = fetch_new_actions_from_db( db_path, last_rowid, agent_names ) @@ -1278,7 +1271,7 @@ async def run_twitter_simulation( progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - # 注意:不关闭环境,保留给Interview使用 + # Note: do NOT close the env here; we keep it alive for Interview commands. if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) @@ -1297,17 +1290,17 @@ async def run_reddit_simulation( main_logger: Optional[SimulationLogManager] = None, max_rounds: Optional[int] = None ) -> PlatformSimulation: - """运行Reddit模拟 - + """Run the Reddit simulation. + Args: - config: 模拟配置 - simulation_dir: 模拟目录 - action_logger: 动作日志记录器 - main_logger: 主日志管理器 - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) - + config: simulation config. + simulation_dir: simulation directory. + action_logger: action logger. + main_logger: main log manager. + max_rounds: optional cap on the number of rounds, used to truncate long runs. + Returns: - PlatformSimulation: 包含env和agent_graph的结果对象 + PlatformSimulation containing the env and agent_graph. """ result = PlatformSimulation() @@ -1318,7 +1311,7 @@ async def run_reddit_simulation( log_info("初始化...") - # Reddit 使用加速 LLM 配置(如果有的话,否则回退到通用配置) + # Reddit uses the boost LLM config when available, falling back to the default. model = create_model(config, use_boost=True) profile_path = os.path.join(simulation_dir, "reddit_profiles.json") @@ -1332,13 +1325,13 @@ async def run_reddit_simulation( available_actions=REDDIT_ACTIONS, ) - # 从配置文件获取 Agent 真实名称映射(使用 entity_name 而非默认的 Agent_X) + # Pull real agent names from the config (use entity_name rather than the default Agent_X). agent_names = get_agent_names_from_config(config) - # 如果配置中没有某个 agent,则使用 OASIS 的默认名称 + # If the config does not list a particular agent, fall back to OASIS's default name. for agent_id, agent in result.agent_graph.get_agents(): if agent_id not in agent_names: agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}') - + db_path = os.path.join(simulation_dir, "reddit_simulation.db") if os.path.exists(db_path): os.remove(db_path) @@ -1347,7 +1340,7 @@ async def run_reddit_simulation( agent_graph=result.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, # cap concurrent LLM requests to avoid overloading the API ) await result.env.reset() @@ -1357,13 +1350,13 @@ async def run_reddit_simulation( action_logger.log_simulation_start(config) total_actions = 0 - last_rowid = 0 # 跟踪数据库中最后处理的行号(使用 rowid 避免 created_at 格式差异) - - # 执行初始事件 + last_rowid = 0 # last processed db row; using rowid avoids created_at format differences + + # Run the initial events. event_config = config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) - - # 记录 round 0 开始(初始事件阶段) + + # Mark the start of round 0 (the initial-events phase). if action_logger: action_logger.log_round_start(0, 0) # round 0, simulated_hour 0 @@ -1405,17 +1398,17 @@ async def run_reddit_simulation( await result.env.step(initial_actions) log_info(f"已发布 {len(initial_actions)} 条初始帖子") - # 记录 round 0 结束 + # Mark the end of round 0. if action_logger: action_logger.log_round_end(0, initial_action_count) - - # 主模拟循环 + + # Main simulation loop. time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - - # 如果指定了最大轮数,则截断 + + # Truncate when a max round count was supplied. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) @@ -1425,7 +1418,7 @@ async def run_reddit_simulation( start_time = datetime.now() for round_num in range(total_rounds): - # 检查是否收到退出信号 + # Bail out if a shutdown signal was received. if _shutdown_event and _shutdown_event.is_set(): if main_logger: main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟") @@ -1439,12 +1432,12 @@ async def run_reddit_simulation( result.env, config, simulated_hour, round_num ) - # 无论是否有活跃agent,都记录round开始 + # Always log round-start, even when no agents are active. if action_logger: action_logger.log_round_start(round_num + 1, simulated_hour) - + if not active_agents: - # 没有活跃agent时也记录round结束(actions_count=0) + # Still emit round-end (with actions_count=0) so the log stays consistent. if action_logger: action_logger.log_round_end(round_num + 1, 0) continue @@ -1452,7 +1445,7 @@ async def run_reddit_simulation( actions = {agent: LLMAction() for _, agent in active_agents} await result.env.step(actions) - # 从数据库获取实际执行的动作并记录 + # Pull the actually-executed actions from the database and log them. actual_actions, last_rowid = fetch_new_actions_from_db( db_path, last_rowid, agent_names ) @@ -1477,7 +1470,7 @@ async def run_reddit_simulation( progress = (round_num + 1) / total_rounds * 100 log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)") - # 注意:不关闭环境,保留给Interview使用 + # Note: do NOT close the env here; we keep it alive for Interview commands. if action_logger: action_logger.log_simulation_end(total_rounds, total_actions) @@ -1522,7 +1515,8 @@ async def main(): args = parser.parse_args() - # 在 main 函数开始时创建 shutdown 事件,确保整个程序都能响应退出信号 + # Create the shutdown event at the start of main() so the whole program + # can respond to exit signals. global _shutdown_event _shutdown_event = asyncio.Event() @@ -1534,10 +1528,10 @@ async def main(): simulation_dir = os.path.dirname(args.config) or "." wait_for_commands = not args.no_wait - # 初始化日志配置(禁用 OASIS 日志,清理旧文件) + # Initialize logging (disable OASIS logs, clean up stale files). init_logging_for_simulation(simulation_dir) - - # 创建日志管理器 + + # Create the log manager. log_manager = SimulationLogManager(simulation_dir) twitter_logger = log_manager.get_twitter_logger() reddit_logger = log_manager.get_reddit_logger() @@ -1572,7 +1566,7 @@ async def main(): start_time = datetime.now() - # 存储两个平台的模拟结果 + # Holds the result for each platform simulation. twitter_result: Optional[PlatformSimulation] = None reddit_result: Optional[PlatformSimulation] = None @@ -1581,7 +1575,7 @@ async def main(): elif args.reddit_only: reddit_result = await run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds) else: - # 并行运行(每个平台使用独立的日志记录器) + # Run both platforms in parallel; each platform uses its own logger. results = await asyncio.gather( run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds), run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds), @@ -1592,7 +1586,7 @@ async def main(): log_manager.info("=" * 60) log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒") - # 是否进入等待命令模式 + # Enter wait-for-command mode if requested. if wait_for_commands: log_manager.info("") log_manager.info("=" * 60) @@ -1600,7 +1594,7 @@ async def main(): log_manager.info("支持的命令: interview, batch_interview, close_env") log_manager.info("=" * 60) - # 创建IPC处理器 + # Create the IPC handler. ipc_handler = ParallelIPCHandler( simulation_dir=simulation_dir, twitter_env=twitter_result.env if twitter_result else None, @@ -1610,18 +1604,18 @@ async def main(): ) ipc_handler.update_status("alive") - # 等待命令循环(使用全局 _shutdown_event) + # Command-wait loop (driven by the global ``_shutdown_event``). try: while not _shutdown_event.is_set(): should_continue = await ipc_handler.process_commands() if not should_continue: break - # 使用 wait_for 替代 sleep,这样可以响应 shutdown_event + # Use ``wait_for`` instead of ``sleep`` so the loop reacts to shutdown_event. try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break # shutdown signal received except asyncio.TimeoutError: - pass # 超时继续循环 + pass # timed out, continue looping except KeyboardInterrupt: print("\n收到中断信号") except asyncio.CancelledError: @@ -1632,7 +1626,7 @@ async def main(): log_manager.info("\n关闭环境...") ipc_handler.update_status("stopped") - # 关闭环境 + # Close the environments. if twitter_result and twitter_result.env: await twitter_result.env.close() log_manager.info("[Twitter] 环境已关闭") @@ -1651,14 +1645,13 @@ async def main(): def setup_signal_handlers(loop=None): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - - 持久化模拟场景:模拟完成后不退出,等待 interview 命令 - 当收到终止信号时,需要: - 1. 通知 asyncio 循环退出等待 - 2. 让程序有机会正常清理资源(关闭数据库、环境等) - 3. 然后才退出 + """Install signal handlers that exit cleanly on SIGTERM/SIGINT. + + Persistent-simulation flow: the process keeps running after the simulation + finishes so it can serve interview commands. On a termination signal we: + 1. Tell the asyncio loop to stop waiting. + 2. Give the program a chance to clean up (close databases, envs, ...). + 3. Then exit. """ def signal_handler(signum, frame): global _cleanup_done @@ -1667,12 +1660,12 @@ def setup_signal_handlers(loop=None): if not _cleanup_done: _cleanup_done = True - # 设置事件通知 asyncio 循环退出(让循环有机会清理资源) + # Notify the asyncio loop to exit so it can clean up resources. if _shutdown_event: _shutdown_event.set() - - # 不要直接 sys.exit(),让 asyncio 循环正常退出并清理资源 - # 如果是重复收到信号,才强制退出 + + # Avoid sys.exit() on the first signal: let the asyncio loop exit cleanly. + # Only force-exit if a second signal comes in. else: print("强制退出...") sys.exit(1) @@ -1690,7 +1683,7 @@ if __name__ == "__main__": except SystemExit: pass finally: - # 清理 multiprocessing 资源跟踪器(防止退出时的警告) + # Clean up the multiprocessing resource tracker to avoid exit warnings. try: from multiprocessing import resource_tracker resource_tracker._resource_tracker._stop() diff --git a/backend/scripts/run_reddit_simulation.py b/backend/scripts/run_reddit_simulation.py index 14907cbd..d3adc560 100644 --- a/backend/scripts/run_reddit_simulation.py +++ b/backend/scripts/run_reddit_simulation.py @@ -1,16 +1,16 @@ -""" -OASIS Reddit模拟预设脚本 -此脚本读取配置文件中的参数来执行模拟,实现全程自动化 +"""OASIS Reddit simulation preset script. -功能特性: -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 +This script reads parameters from a config file and runs the simulation end-to-end automatically. -使用方式: +Features: +- After the simulation finishes, the environment stays alive and enters a command-wait mode. +- Accepts Interview commands over IPC. +- Supports single-agent and batch interviews. +- Supports a remote close-environment command. + +Usage: python run_reddit_simulation.py --config /path/to/simulation_config.json - python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 + python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done """ import argparse @@ -25,18 +25,18 @@ import sqlite3 from datetime import datetime from typing import Dict, Any, List, Optional -# 全局变量:用于信号处理 +# Globals used by the signal handler. _shutdown_event = None _cleanup_done = False -# 添加项目路径 +# Add project paths to sys.path so sibling modules import correctly. _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +# Load the .env file from the project root (contains LLM_API_KEY and related settings). from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): @@ -51,7 +51,7 @@ import re class UnicodeFormatter(logging.Formatter): - """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + """Custom log formatter that converts Unicode escape sequences into readable characters.""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') @@ -68,24 +68,23 @@ class UnicodeFormatter(logging.Formatter): class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" - + """Suppress camel-ai's max_tokens warning (we intentionally leave max_tokens unset and let the model decide).""" + def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 +# Install the filter at module import time so it takes effect before any camel code runs. logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): - """配置 OASIS 的日志,使用固定名称的日志文件""" + """Configure OASIS logging with fixed log file names.""" os.makedirs(log_dir, exist_ok=True) - - # 清理旧的日志文件 + + # Remove stale log files from previous runs so the new run starts clean. for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): @@ -131,20 +130,20 @@ except ImportError as e: sys.exit(1) -# IPC相关常量 +# IPC-related constants. IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command type constants.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: - """IPC命令处理器""" + """IPC command handler.""" def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir @@ -154,13 +153,12 @@ class IPCHandler: self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True - - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): - """更新环境状态""" + """Update the environment status file.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -168,11 +166,11 @@ class IPCHandler: }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll for pending IPC commands.""" if not os.path.exists(self.commands_dir): return None - - # 获取命令文件(按时间排序) + + # Collect command files sorted by modification time so older commands are handled first. command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -191,7 +189,7 @@ class IPCHandler: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send an IPC response for a command.""" response = { "command_id": command_id, "status": status, @@ -203,8 +201,8 @@ class IPCHandler: response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - - # 删除命令文件 + + # Remove the command file once a response has been written so it isn't re-processed. command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -212,29 +210,25 @@ class IPCHandler: pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: - """ - 处理单个Agent采访命令 - + """Handle a single-agent interview command. + Returns: - True 表示成功,False 表示失败 + True on success, False on failure. """ try: - # 获取Agent agent = self.agent_graph.get_agent(agent_id) - - # 创建Interview动作 + interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) - - # 执行Interview + actions = {agent: interview_action} await self.env.step(actions) - - # 从数据库获取结果 + + # Read the interview answer back from the simulation database. result = self._get_interview_result(agent_id) - + self.send_response(command_id, "completed", result=result) print(f" Interview完成: agent_id={agent_id}") return True @@ -246,17 +240,15 @@ class IPCHandler: return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: - """ - 处理批量采访命令 - + """Handle a batch interview command. + Args: interviews: [{"agent_id": int, "prompt": str}, ...] """ try: - # 构建动作字典 actions = {} - agent_prompts = {} # 记录每个agent的prompt - + agent_prompts = {} # Track which prompt was sent to each agent so results can be paired back. + for interview in interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") @@ -274,11 +266,9 @@ class IPCHandler: if not actions: self.send_response(command_id, "failed", error="没有有效的Agent") return False - - # 执行批量Interview + await self.env.step(actions) - - # 获取所有结果 + results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) @@ -298,7 +288,7 @@ class IPCHandler: return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Fetch the most recent interview result for an agent from the database.""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") result = { @@ -313,8 +303,8 @@ class IPCHandler: try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - - # 查询最新的Interview记录 + + # Query the most recent interview row for this agent. cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -341,11 +331,10 @@ class IPCHandler: return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - + """Process all pending IPC commands. + Returns: - True 表示继续运行,False 表示应该退出 + True to keep running, False if the loop should exit. """ command = self.poll_command() if not command: @@ -383,9 +372,9 @@ class IPCHandler: class RedditSimulationRunner: - """Reddit模拟运行器""" - - # Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) + """Reddit simulation runner.""" + + # Available Reddit actions (INTERVIEW is excluded because it can only be triggered via ManualAction). AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, @@ -403,12 +392,11 @@ class RedditSimulationRunner: ] def __init__(self, config_path: str, wait_for_commands: bool = True): - """ - 初始化模拟运行器 - + """Initialize the simulation runner. + Args: - config_path: 配置文件路径 (simulation_config.json) - wait_for_commands: 模拟完成后是否等待命令(默认True) + config_path: Path to the configuration file (simulation_config.json). + wait_for_commands: Whether to wait for commands after the simulation finishes (default True). """ self.config_path = config_path self.config = self._load_config() @@ -419,37 +407,36 @@ class RedditSimulationRunner: self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: - """加载配置文件""" + """Load the configuration file.""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) - + def _get_profile_path(self) -> str: - """获取Profile文件路径""" + """Return the path to the agent profiles file.""" return os.path.join(self.simulation_dir, "reddit_profiles.json") - + def _get_db_path(self) -> str: - """获取数据库路径""" + """Return the path to the simulation database.""" return os.path.join(self.simulation_dir, "reddit_simulation.db") - + def _create_model(self): + """Create the LLM model. + + Configuration is sourced from the project-root ``.env`` file (highest priority): + - LLM_API_KEY: API key. + - LLM_BASE_URL: API base URL. + - LLM_MODEL_NAME: Model name. """ - 创建LLM模型 - - 统一使用项目根目录 .env 文件中的配置(优先级最高): - - LLM_API_KEY: API密钥 - - LLM_BASE_URL: API基础URL - - LLM_MODEL_NAME: 模型名称 - """ - # 优先从 .env 读取配置 + # Prefer values from .env over the per-simulation config. llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - - # 如果 .env 中没有,则使用 config 作为备用 + + # Fall back to the simulation config file if .env did not specify a model. if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") - - # 设置 camel-ai 所需的环境变量 + + # Export the env vars camel-ai expects. if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key @@ -472,9 +459,7 @@ class RedditSimulationRunner: current_hour: int, round_num: int ) -> List: - """ - 根据时间和配置决定本轮激活哪些Agent - """ + """Decide which agents are active for the current round, based on time of day and config.""" time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) @@ -521,10 +506,10 @@ class RedditSimulationRunner: return active_agents async def run(self, max_rounds: int = None): - """运行Reddit模拟 - + """Run the Reddit simulation. + Args: - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) + max_rounds: Optional cap on the number of simulation rounds (used to truncate overly long runs). """ print("=" * 60) print("OASIS Reddit模拟") @@ -538,7 +523,7 @@ class RedditSimulationRunner: minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round - # 如果指定了最大轮数,则截断 + # Truncate if a max_rounds cap was supplied. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) @@ -578,17 +563,16 @@ class RedditSimulationRunner: agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, # Cap concurrent LLM requests to avoid overloading the API. ) await self.env.reset() print("环境初始化完成\n") - # 初始化IPC处理器 self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") - - # 执行初始事件 + + # Apply the configured initial events (seed posts) before starting the main loop. event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) @@ -619,7 +603,7 @@ class RedditSimulationRunner: await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") - # 主模拟循环 + # Main simulation loop. print("\n开始模拟循环...") start_time = datetime.now() @@ -655,7 +639,7 @@ class RedditSimulationRunner: print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") - # 是否进入等待命令模式 + # Optionally enter command-wait mode. if self.wait_for_commands: print("\n" + "=" * 60) print("进入等待命令模式 - 环境保持运行") @@ -664,7 +648,7 @@ class RedditSimulationRunner: self.ipc_handler.update_status("alive") - # 等待命令循环(使用全局 _shutdown_event) + # Command-wait loop driven by the global _shutdown_event. try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() @@ -672,7 +656,7 @@ class RedditSimulationRunner: break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break # Shutdown signal received. except asyncio.TimeoutError: pass except KeyboardInterrupt: @@ -683,8 +667,7 @@ class RedditSimulationRunner: print(f"\n命令处理出错: {e}") print("\n关闭环境...") - - # 关闭环境 + self.ipc_handler.update_status("stopped") await self.env.close() @@ -715,7 +698,7 @@ async def main(): args = parser.parse_args() - # 在 main 函数开始时创建 shutdown 事件 + # Create the shutdown event lazily here so it is bound to the running asyncio loop. global _shutdown_event _shutdown_event = asyncio.Event() @@ -723,7 +706,7 @@ async def main(): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) - # 初始化日志配置(使用固定文件名,清理旧日志) + # Initialize log config with fixed filenames; old logs are cleared inside setup_oasis_logging. simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) @@ -735,9 +718,9 @@ async def main(): def setup_signal_handlers(): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - 让程序有机会正常清理资源(关闭数据库、环境等) + """Install signal handlers so SIGTERM/SIGINT trigger a graceful exit. + + This gives the program a chance to clean up resources (close the database, the OASIS environment, etc.). """ def signal_handler(signum, frame): global _cleanup_done @@ -748,7 +731,7 @@ def setup_signal_handlers(): if _shutdown_event: _shutdown_event.set() else: - # 重复收到信号才强制退出 + # Force exit only on a repeat signal so the user can still hard-kill if cleanup hangs. print("强制退出...") sys.exit(1) diff --git a/backend/scripts/run_twitter_simulation.py b/backend/scripts/run_twitter_simulation.py index caab9e9d..4e96e06b 100644 --- a/backend/scripts/run_twitter_simulation.py +++ b/backend/scripts/run_twitter_simulation.py @@ -1,16 +1,18 @@ """ -OASIS Twitter模拟预设脚本 -此脚本读取配置文件中的参数来执行模拟,实现全程自动化 +OASIS Twitter simulation preset script. -功能特性: -- 完成模拟后不立即关闭环境,进入等待命令模式 -- 支持通过IPC接收Interview命令 -- 支持单个Agent采访和批量采访 -- 支持远程关闭环境命令 +This script reads parameters from a config file to run a fully automated simulation. -使用方式: +Features: +- Does not close the environment immediately when the simulation finishes; enters + command-wait mode instead. +- Receives Interview commands over IPC. +- Supports both single-agent and batch interviews. +- Supports a remote close-environment command. + +Usage: python run_twitter_simulation.py --config /path/to/simulation_config.json - python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭 + python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done """ import argparse @@ -25,18 +27,17 @@ import sqlite3 from datetime import datetime from typing import Dict, Any, List, Optional -# 全局变量:用于信号处理 +# Globals used by the signal handler. _shutdown_event = None _cleanup_done = False -# 添加项目路径 _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) -# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置) +# Load the project-root .env (it carries LLM_API_KEY and friends). from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): @@ -51,7 +52,7 @@ import re class UnicodeFormatter(logging.Formatter): - """自定义格式化器,将 Unicode 转义序列转换为可读字符""" + """Custom formatter that turns Unicode escape sequences into readable characters.""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') @@ -68,24 +69,23 @@ class UnicodeFormatter(logging.Formatter): class MaxTokensWarningFilter(logging.Filter): - """过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)""" - + """Suppress camel-ai's max_tokens warning — we intentionally leave it unset and let the model decide.""" + def filter(self, record): - # 过滤掉包含 max_tokens 警告的日志 if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True -# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效 +# Install the filter at import time so it is active before any camel code runs. logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): - """配置 OASIS 的日志,使用固定名称的日志文件""" + """Configure OASIS logging with fixed log filenames.""" os.makedirs(log_dir, exist_ok=True) - - # 清理旧的日志文件 + + # Wipe stale log files from previous runs. for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): @@ -131,21 +131,21 @@ except ImportError as e: sys.exit(1) -# IPC相关常量 +# IPC-related constants. IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: - """命令类型常量""" + """Command type constants.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: - """IPC命令处理器""" - + """Handles IPC commands directed at the running simulation.""" + def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir self.env = env @@ -154,13 +154,12 @@ class IPCHandler: self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True - - # 确保目录存在 + os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) - + def update_status(self, status: str): - """更新环境状态""" + """Write the current environment status to the status file.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, @@ -168,11 +167,11 @@ class IPCHandler: }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: - """轮询获取待处理命令""" + """Poll for the next pending command.""" if not os.path.exists(self.commands_dir): return None - - # 获取命令文件(按时间排序) + + # Collect command files ordered by mtime. command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): @@ -191,7 +190,7 @@ class IPCHandler: return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): - """发送响应""" + """Send a response for a processed command.""" response = { "command_id": command_id, "status": status, @@ -203,8 +202,8 @@ class IPCHandler: response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) - - # 删除命令文件 + + # Remove the command file once a response has been written. command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) @@ -212,27 +211,23 @@ class IPCHandler: pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: - """ - 处理单个Agent采访命令 - + """Handle a single-agent interview command. + Returns: - True 表示成功,False 表示失败 + True on success, False on failure. """ try: - # 获取Agent agent = self.agent_graph.get_agent(agent_id) - - # 创建Interview动作 + interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) - - # 执行Interview + actions = {agent: interview_action} await self.env.step(actions) - - # 从数据库获取结果 + + # Pull the resulting transcript from the simulation database. result = self._get_interview_result(agent_id) self.send_response(command_id, "completed", result=result) @@ -246,17 +241,15 @@ class IPCHandler: return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: - """ - 处理批量采访命令 - + """Handle a batch interview command. + Args: interviews: [{"agent_id": int, "prompt": str}, ...] """ try: - # 构建动作字典 actions = {} - agent_prompts = {} # 记录每个agent的prompt - + agent_prompts = {} # Track the prompt issued to each agent for later result lookup. + for interview in interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") @@ -274,11 +267,10 @@ class IPCHandler: if not actions: self.send_response(command_id, "failed", error="没有有效的Agent") return False - - # 执行批量Interview + await self.env.step(actions) - - # 获取所有结果 + + # Collect the per-agent interview results. results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) @@ -298,7 +290,7 @@ class IPCHandler: return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: - """从数据库获取最新的Interview结果""" + """Fetch the most recent interview result for an agent from the database.""" db_path = os.path.join(self.simulation_dir, "twitter_simulation.db") result = { @@ -313,8 +305,8 @@ class IPCHandler: try: conn = sqlite3.connect(db_path) cursor = conn.cursor() - - # 查询最新的Interview记录 + + # Pull the most recent INTERVIEW trace row for this agent. cursor.execute(""" SELECT user_id, info, created_at FROM trace @@ -341,11 +333,10 @@ class IPCHandler: return result async def process_commands(self) -> bool: - """ - 处理所有待处理命令 - + """Process pending commands. + Returns: - True 表示继续运行,False 表示应该退出 + True if the run loop should continue, False if it should exit. """ command = self.poll_command() if not command: @@ -383,9 +374,9 @@ class IPCHandler: class TwitterSimulationRunner: - """Twitter模拟运行器""" - - # Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发) + """Drives a single Twitter simulation run.""" + + # Available Twitter actions. INTERVIEW is intentionally excluded — it can only be triggered via ManualAction. AVAILABLE_ACTIONS = [ ActionType.CREATE_POST, ActionType.LIKE_POST, @@ -396,12 +387,11 @@ class TwitterSimulationRunner: ] def __init__(self, config_path: str, wait_for_commands: bool = True): - """ - 初始化模拟运行器 - + """Initialize the simulation runner. + Args: - config_path: 配置文件路径 (simulation_config.json) - wait_for_commands: 模拟完成后是否等待命令(默认True) + config_path: Path to the config file (simulation_config.json). + wait_for_commands: Whether to wait for IPC commands after the simulation completes (default True). """ self.config_path = config_path self.config = self._load_config() @@ -412,37 +402,36 @@ class TwitterSimulationRunner: self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: - """加载配置文件""" + """Load the simulation config file.""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) - + def _get_profile_path(self) -> str: - """获取Profile文件路径(OASIS Twitter使用CSV格式)""" + """Return the agent profile path (OASIS Twitter expects CSV).""" return os.path.join(self.simulation_dir, "twitter_profiles.csv") - + def _get_db_path(self) -> str: - """获取数据库路径""" + """Return the simulation SQLite database path.""" return os.path.join(self.simulation_dir, "twitter_simulation.db") - + def _create_model(self): + """Create the LLM model. + + Uses the project-root .env file (highest precedence): + - LLM_API_KEY: API key + - LLM_BASE_URL: API base URL + - LLM_MODEL_NAME: model name """ - 创建LLM模型 - - 统一使用项目根目录 .env 文件中的配置(优先级最高): - - LLM_API_KEY: API密钥 - - LLM_BASE_URL: API基础URL - - LLM_MODEL_NAME: 模型名称 - """ - # 优先从 .env 读取配置 + # Prefer values from .env. llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") - - # 如果 .env 中没有,则使用 config 作为备用 + + # Fall back to the simulation config if .env did not provide a model name. if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") - - # 设置 camel-ai 所需的环境变量 + + # camel-ai reads OPENAI_API_KEY from the environment. if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key @@ -465,25 +454,24 @@ class TwitterSimulationRunner: current_hour: int, round_num: int ) -> List: - """ - 根据时间和配置决定本轮激活哪些Agent - + """Decide which agents activate this round, based on time and config. + Args: - env: OASIS环境 - current_hour: 当前模拟小时(0-23) - round_num: 当前轮数 - + env: The OASIS environment. + current_hour: Current simulated hour (0-23). + round_num: Current round number. + Returns: - 激活的Agent列表 + The list of agents activated this round. """ time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) - - # 基础激活数量 + + # Base activation count per round. base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) - - # 根据时段调整 + + # Adjust by time-of-day (peak vs. off-peak hours). peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) @@ -495,29 +483,27 @@ class TwitterSimulationRunner: multiplier = 1.0 target_count = int(random.uniform(base_min, base_max) * multiplier) - - # 根据每个Agent的配置计算激活概率 + + # Compute activation probability for each configured agent. candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) - - # 检查是否在活跃时间 + if current_hour not in active_hours: continue - - # 根据活跃度计算概率 + if random.random() < activity_level: candidates.append(agent_id) - - # 随机选择 + + # Pick a random subset of the eligible candidates. selected_ids = random.sample( - candidates, + candidates, min(target_count, len(candidates)) ) if candidates else [] - - # 转换为Agent对象 + + # Resolve IDs to Agent objects. active_agents = [] for agent_id in selected_ids: try: @@ -529,10 +515,10 @@ class TwitterSimulationRunner: return active_agents async def run(self, max_rounds: int = None): - """运行Twitter模拟 - + """Run the Twitter simulation. + Args: - max_rounds: 最大模拟轮数(可选,用于截断过长的模拟) + max_rounds: Optional cap on the number of rounds, used to truncate overly long simulations. """ print("=" * 60) print("OASIS Twitter模拟") @@ -540,16 +526,14 @@ class TwitterSimulationRunner: print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) - - # 加载时间配置 + time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) - - # 计算总轮数 + total_rounds = (total_hours * 60) // minutes_per_round - - # 如果指定了最大轮数,则截断 + + # Truncate to max_rounds when one was supplied. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) @@ -563,12 +547,11 @@ class TwitterSimulationRunner: if max_rounds: print(f" - 最大轮数限制: {max_rounds}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") - - # 创建模型 + print("\n初始化LLM模型...") model = self._create_model() - - # 加载Agent图 + + # Load the agent graph from the profile CSV. print("加载Agent Profile...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): @@ -581,29 +564,27 @@ class TwitterSimulationRunner: available_actions=self.AVAILABLE_ACTIONS, ) - # 数据库路径 + # Reset the simulation database for a clean run. db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) print(f"已删除旧数据库: {db_path}") - - # 创建环境 + print("创建OASIS环境...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.TWITTER, database_path=db_path, - semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载 + semaphore=30, # Cap concurrent LLM requests to avoid API overload. ) await self.env.reset() print("环境初始化完成\n") - - # 初始化IPC处理器 + self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") - - # 执行初始事件 + + # Run the initial seeded events (kickoff posts). event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) @@ -625,35 +606,32 @@ class TwitterSimulationRunner: if initial_actions: await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") - - # 主模拟循环 + + # Main simulation loop. print("\n开始模拟循环...") start_time = datetime.now() - + for round_num in range(total_rounds): - # 计算当前模拟时间 + # Map round number to simulated wall-clock time. simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 - - # 获取本轮激活的Agent + active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) - + if not active_agents: continue - - # 构建动作 + actions = { agent: LLMAction() for _, agent in active_agents } - - # 执行动作 + await self.env.step(actions) - - # 打印进度 + + # Periodic progress log. if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / total_rounds * 100 @@ -667,7 +645,7 @@ class TwitterSimulationRunner: print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") - # 是否进入等待命令模式 + # Optionally enter command-wait mode. if self.wait_for_commands: print("\n" + "=" * 60) print("进入等待命令模式 - 环境保持运行") @@ -675,8 +653,8 @@ class TwitterSimulationRunner: print("=" * 60) self.ipc_handler.update_status("alive") - - # 等待命令循环(使用全局 _shutdown_event) + + # Command-wait loop, driven by the global _shutdown_event. try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() @@ -684,7 +662,7 @@ class TwitterSimulationRunner: break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) - break # 收到退出信号 + break # Shutdown signal received. except asyncio.TimeoutError: pass except KeyboardInterrupt: @@ -695,8 +673,7 @@ class TwitterSimulationRunner: print(f"\n命令处理出错: {e}") print("\n关闭环境...") - - # 关闭环境 + self.ipc_handler.update_status("stopped") await self.env.close() @@ -726,16 +703,16 @@ async def main(): ) args = parser.parse_args() - - # 在 main 函数开始时创建 shutdown 事件 + + # Create the shutdown event inside the running event loop. global _shutdown_event _shutdown_event = asyncio.Event() - + if not os.path.exists(args.config): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) - - # 初始化日志配置(使用固定文件名,清理旧日志) + + # Initialize logging with fixed filenames; old logs are wiped. simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) @@ -747,9 +724,11 @@ async def main(): def setup_signal_handlers(): - """ - 设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出 - 让程序有机会正常清理资源(关闭数据库、环境等) + """Install signal handlers so SIGTERM/SIGINT trigger an orderly shutdown. + + The handler gives the program a chance to clean up resources properly + (closing the database, the OASIS environment, etc.) on the first signal, + and only force-exits on a repeated signal. """ def signal_handler(signum, frame): global _cleanup_done @@ -760,7 +739,7 @@ def setup_signal_handlers(): if _shutdown_event: _shutdown_event.set() else: - # 重复收到信号才强制退出 + # Force exit only on a repeat signal. print("强制退出...") sys.exit(1)