"""OASIS Reddit simulation preset script. This script reads parameters from a config file and runs the simulation end-to-end automatically. Features: - After the simulation finishes, the environment stays alive and enters a command-wait mode. - Accepts Interview commands over IPC. - Supports single-agent and batch interviews. - Supports a remote close-environment command. Usage: python run_reddit_simulation.py --config /path/to/simulation_config.json python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done """ import argparse import asyncio import json import logging import os import random import signal import sys import sqlite3 from datetime import datetime from typing import Dict, Any, List, Optional # Globals used by the signal handler. _shutdown_event = None _cleanup_done = False # Add project paths to sys.path so sibling modules import correctly. _scripts_dir = os.path.dirname(os.path.abspath(__file__)) _backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..')) _project_root = os.path.abspath(os.path.join(_backend_dir, '..')) sys.path.insert(0, _scripts_dir) sys.path.insert(0, _backend_dir) # Load the .env file from the project root (contains LLM_API_KEY and related settings). from dotenv import load_dotenv _env_file = os.path.join(_project_root, '.env') if os.path.exists(_env_file): load_dotenv(_env_file) else: _backend_env = os.path.join(_backend_dir, '.env') if os.path.exists(_backend_env): load_dotenv(_backend_env) import re class UnicodeFormatter(logging.Formatter): """Custom log formatter that converts Unicode escape sequences into readable characters.""" UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})') def format(self, record): result = super().format(record) def replace_unicode(match): try: return chr(int(match.group(1), 16)) except (ValueError, OverflowError): return match.group(0) return self.UNICODE_ESCAPE_PATTERN.sub(replace_unicode, result) class MaxTokensWarningFilter(logging.Filter): """Suppress camel-ai's max_tokens warning (we intentionally leave max_tokens unset and let the model decide).""" def filter(self, record): if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage(): return False return True # Install the filter at module import time so it takes effect before any camel code runs. logging.getLogger().addFilter(MaxTokensWarningFilter()) def setup_oasis_logging(log_dir: str): """Configure OASIS logging with fixed log file names.""" os.makedirs(log_dir, exist_ok=True) # Remove stale log files from previous runs so the new run starts clean. for f in os.listdir(log_dir): old_log = os.path.join(log_dir, f) if os.path.isfile(old_log) and f.endswith('.log'): try: os.remove(old_log) except OSError: pass formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s") loggers_config = { "social.agent": os.path.join(log_dir, "social.agent.log"), "social.twitter": os.path.join(log_dir, "social.twitter.log"), "social.rec": os.path.join(log_dir, "social.rec.log"), "oasis.env": os.path.join(log_dir, "oasis.env.log"), "table": os.path.join(log_dir, "table.log"), } for logger_name, log_file in loggers_config.items(): logger = logging.getLogger(logger_name) logger.setLevel(logging.DEBUG) logger.handlers.clear() file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w') file_handler.setLevel(logging.DEBUG) file_handler.setFormatter(formatter) logger.addHandler(file_handler) logger.propagate = False try: from camel.models import ModelFactory from camel.types import ModelPlatformType import oasis from oasis import ( ActionType, LLMAction, ManualAction, generate_reddit_agent_graph ) except ImportError as e: print(f"错误: 缺少依赖 {e}") print("请先安装: pip install oasis-ai camel-ai") sys.exit(1) # IPC-related constants. IPC_COMMANDS_DIR = "ipc_commands" IPC_RESPONSES_DIR = "ipc_responses" ENV_STATUS_FILE = "env_status.json" class CommandType: """Command type constants.""" INTERVIEW = "interview" BATCH_INTERVIEW = "batch_interview" CLOSE_ENV = "close_env" class IPCHandler: """IPC command handler.""" def __init__(self, simulation_dir: str, env, agent_graph): self.simulation_dir = simulation_dir self.env = env self.agent_graph = agent_graph self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR) self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR) self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE) self._running = True os.makedirs(self.commands_dir, exist_ok=True) os.makedirs(self.responses_dir, exist_ok=True) def update_status(self, status: str): """Update the environment status file.""" with open(self.status_file, 'w', encoding='utf-8') as f: json.dump({ "status": status, "timestamp": datetime.now().isoformat() }, f, ensure_ascii=False, indent=2) def poll_command(self) -> Optional[Dict[str, Any]]: """Poll for pending IPC commands.""" if not os.path.exists(self.commands_dir): return None # Collect command files sorted by modification time so older commands are handled first. command_files = [] for filename in os.listdir(self.commands_dir): if filename.endswith('.json'): filepath = os.path.join(self.commands_dir, filename) command_files.append((filepath, os.path.getmtime(filepath))) command_files.sort(key=lambda x: x[1]) for filepath, _ in command_files: try: with open(filepath, 'r', encoding='utf-8') as f: return json.load(f) except (json.JSONDecodeError, OSError): continue return None def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None): """Send an IPC response for a command.""" response = { "command_id": command_id, "status": status, "result": result, "error": error, "timestamp": datetime.now().isoformat() } response_file = os.path.join(self.responses_dir, f"{command_id}.json") with open(response_file, 'w', encoding='utf-8') as f: json.dump(response, f, ensure_ascii=False, indent=2) # Remove the command file once a response has been written so it isn't re-processed. command_file = os.path.join(self.commands_dir, f"{command_id}.json") try: os.remove(command_file) except OSError: pass async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool: """Handle a single-agent interview command. Returns: True on success, False on failure. """ try: agent = self.agent_graph.get_agent(agent_id) interview_action = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) actions = {agent: interview_action} await self.env.step(actions) # Read the interview answer back from the simulation database. result = self._get_interview_result(agent_id) self.send_response(command_id, "completed", result=result) print(f" Interview完成: agent_id={agent_id}") return True except Exception as e: error_msg = str(e) print(f" Interview失败: agent_id={agent_id}, error={error_msg}") self.send_response(command_id, "failed", error=error_msg) return False async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool: """Handle a batch interview command. Args: interviews: [{"agent_id": int, "prompt": str}, ...] """ try: actions = {} agent_prompts = {} # Track which prompt was sent to each agent so results can be paired back. for interview in interviews: agent_id = interview.get("agent_id") prompt = interview.get("prompt", "") try: agent = self.agent_graph.get_agent(agent_id) actions[agent] = ManualAction( action_type=ActionType.INTERVIEW, action_args={"prompt": prompt} ) agent_prompts[agent_id] = prompt except Exception as e: print(f" 警告: 无法获取Agent {agent_id}: {e}") if not actions: self.send_response(command_id, "failed", error="没有有效的Agent") return False await self.env.step(actions) results = {} for agent_id in agent_prompts.keys(): result = self._get_interview_result(agent_id) results[agent_id] = result self.send_response(command_id, "completed", result={ "interviews_count": len(results), "results": results }) print(f" 批量Interview完成: {len(results)} 个Agent") return True except Exception as e: error_msg = str(e) print(f" 批量Interview失败: {error_msg}") self.send_response(command_id, "failed", error=error_msg) return False def _get_interview_result(self, agent_id: int) -> Dict[str, Any]: """Fetch the most recent interview result for an agent from the database.""" db_path = os.path.join(self.simulation_dir, "reddit_simulation.db") result = { "agent_id": agent_id, "response": None, "timestamp": None } if not os.path.exists(db_path): return result try: conn = sqlite3.connect(db_path) cursor = conn.cursor() # Query the most recent interview row for this agent. cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = ? AND user_id = ? ORDER BY created_at DESC LIMIT 1 """, (ActionType.INTERVIEW.value, agent_id)) row = cursor.fetchone() if row: user_id, info_json, created_at = row try: info = json.loads(info_json) if info_json else {} result["response"] = info.get("response", info) result["timestamp"] = created_at except json.JSONDecodeError: result["response"] = info_json conn.close() except Exception as e: print(f" 读取Interview结果失败: {e}") return result async def process_commands(self) -> bool: """Process all pending IPC commands. Returns: True to keep running, False if the loop should exit. """ command = self.poll_command() if not command: return True command_id = command.get("command_id") command_type = command.get("command_type") args = command.get("args", {}) print(f"\n收到IPC命令: {command_type}, id={command_id}") if command_type == CommandType.INTERVIEW: await self.handle_interview( command_id, args.get("agent_id", 0), args.get("prompt", "") ) return True elif command_type == CommandType.BATCH_INTERVIEW: await self.handle_batch_interview( command_id, args.get("interviews", []) ) return True elif command_type == CommandType.CLOSE_ENV: print("收到关闭环境命令") self.send_response(command_id, "completed", result={"message": "环境即将关闭"}) return False else: self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}") return True class RedditSimulationRunner: """Reddit simulation runner.""" # Available Reddit actions (INTERVIEW is excluded because it can only be triggered via ManualAction). AVAILABLE_ACTIONS = [ ActionType.LIKE_POST, ActionType.DISLIKE_POST, ActionType.CREATE_POST, ActionType.CREATE_COMMENT, ActionType.LIKE_COMMENT, ActionType.DISLIKE_COMMENT, ActionType.SEARCH_POSTS, ActionType.SEARCH_USER, ActionType.TREND, ActionType.REFRESH, ActionType.DO_NOTHING, ActionType.FOLLOW, ActionType.MUTE, ] def __init__(self, config_path: str, wait_for_commands: bool = True): """Initialize the simulation runner. Args: config_path: Path to the configuration file (simulation_config.json). wait_for_commands: Whether to wait for commands after the simulation finishes (default True). """ self.config_path = config_path self.config = self._load_config() self.simulation_dir = os.path.dirname(config_path) self.wait_for_commands = wait_for_commands self.env = None self.agent_graph = None self.ipc_handler = None def _load_config(self) -> Dict[str, Any]: """Load the configuration file.""" with open(self.config_path, 'r', encoding='utf-8') as f: return json.load(f) def _get_profile_path(self) -> str: """Return the path to the agent profiles file.""" return os.path.join(self.simulation_dir, "reddit_profiles.json") def _get_db_path(self) -> str: """Return the path to the simulation database.""" return os.path.join(self.simulation_dir, "reddit_simulation.db") def _create_model(self): """Create the LLM model. Configuration is sourced from the project-root ``.env`` file (highest priority): - LLM_API_KEY: API key. - LLM_BASE_URL: API base URL. - LLM_MODEL_NAME: Model name. """ # Prefer values from .env over the per-simulation config. llm_api_key = os.environ.get("LLM_API_KEY", "") llm_base_url = os.environ.get("LLM_BASE_URL", "") llm_model = os.environ.get("LLM_MODEL_NAME", "") # Fall back to the simulation config file if .env did not specify a model. if not llm_model: llm_model = self.config.get("llm_model", "gpt-4o-mini") # Export the env vars camel-ai expects. if llm_api_key: os.environ["OPENAI_API_KEY"] = llm_api_key if not os.environ.get("OPENAI_API_KEY"): raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY") if llm_base_url: os.environ["OPENAI_API_BASE_URL"] = llm_base_url print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...") return ModelFactory.create( model_platform=ModelPlatformType.OPENAI, model_type=llm_model, ) def _get_active_agents_for_round( self, env, current_hour: int, round_num: int ) -> List: """Decide which agents are active for the current round, based on time of day and config.""" time_config = self.config.get("time_config", {}) agent_configs = self.config.get("agent_configs", []) base_min = time_config.get("agents_per_hour_min", 5) base_max = time_config.get("agents_per_hour_max", 20) peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]) off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5]) if current_hour in peak_hours: multiplier = time_config.get("peak_activity_multiplier", 1.5) elif current_hour in off_peak_hours: multiplier = time_config.get("off_peak_activity_multiplier", 0.3) else: multiplier = 1.0 target_count = int(random.uniform(base_min, base_max) * multiplier) candidates = [] for cfg in agent_configs: agent_id = cfg.get("agent_id", 0) active_hours = cfg.get("active_hours", list(range(8, 23))) activity_level = cfg.get("activity_level", 0.5) if current_hour not in active_hours: continue if random.random() < activity_level: candidates.append(agent_id) selected_ids = random.sample( candidates, min(target_count, len(candidates)) ) if candidates else [] active_agents = [] for agent_id in selected_ids: try: agent = env.agent_graph.get_agent(agent_id) active_agents.append((agent_id, agent)) except Exception: pass return active_agents async def run(self, max_rounds: int = None): """Run the Reddit simulation. Args: max_rounds: Optional cap on the number of simulation rounds (used to truncate overly long runs). """ print("=" * 60) print("OASIS Reddit模拟") print(f"配置文件: {self.config_path}") print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}") print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}") print("=" * 60) time_config = self.config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = (total_hours * 60) // minutes_per_round # Truncate if a max_rounds cap was supplied. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})") print(f"\n模拟参数:") print(f" - 总模拟时长: {total_hours}小时") print(f" - 每轮时间: {minutes_per_round}分钟") print(f" - 总轮数: {total_rounds}") if max_rounds: print(f" - 最大轮数限制: {max_rounds}") print(f" - Agent数量: {len(self.config.get('agent_configs', []))}") print("\n初始化LLM模型...") model = self._create_model() print("加载Agent Profile...") profile_path = self._get_profile_path() if not os.path.exists(profile_path): print(f"错误: Profile文件不存在: {profile_path}") return self.agent_graph = await generate_reddit_agent_graph( profile_path=profile_path, model=model, available_actions=self.AVAILABLE_ACTIONS, ) db_path = self._get_db_path() if os.path.exists(db_path): os.remove(db_path) print(f"已删除旧数据库: {db_path}") print("创建OASIS环境...") self.env = oasis.make( agent_graph=self.agent_graph, platform=oasis.DefaultPlatformType.REDDIT, database_path=db_path, semaphore=30, # Cap concurrent LLM requests to avoid overloading the API. ) await self.env.reset() print("环境初始化完成\n") self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph) self.ipc_handler.update_status("running") # Apply the configured initial events (seed posts) before starting the main loop. event_config = self.config.get("event_config", {}) initial_posts = event_config.get("initial_posts", []) if initial_posts: print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...") initial_actions = {} for post in initial_posts: agent_id = post.get("poster_agent_id", 0) content = post.get("content", "") try: agent = self.env.agent_graph.get_agent(agent_id) if agent in initial_actions: if not isinstance(initial_actions[agent], list): initial_actions[agent] = [initial_actions[agent]] initial_actions[agent].append(ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} )) else: initial_actions[agent] = ManualAction( action_type=ActionType.CREATE_POST, action_args={"content": content} ) except Exception as e: print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}") if initial_actions: await self.env.step(initial_actions) print(f" 已发布 {len(initial_actions)} 条初始帖子") # Main simulation loop. print("\n开始模拟循环...") start_time = datetime.now() for round_num in range(total_rounds): simulated_minutes = round_num * minutes_per_round simulated_hour = (simulated_minutes // 60) % 24 simulated_day = simulated_minutes // (60 * 24) + 1 active_agents = self._get_active_agents_for_round( self.env, simulated_hour, round_num ) if not active_agents: continue actions = { agent: LLMAction() for _, agent in active_agents } await self.env.step(actions) if (round_num + 1) % 10 == 0 or round_num == 0: elapsed = (datetime.now() - start_time).total_seconds() progress = (round_num + 1) / total_rounds * 100 print(f" [Day {simulated_day}, {simulated_hour:02d}:00] " f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) " f"- {len(active_agents)} agents active " f"- elapsed: {elapsed:.1f}s") total_elapsed = (datetime.now() - start_time).total_seconds() print(f"\n模拟循环完成!") print(f" - 总耗时: {total_elapsed:.1f}秒") print(f" - 数据库: {db_path}") # Optionally enter command-wait mode. if self.wait_for_commands: print("\n" + "=" * 60) print("进入等待命令模式 - 环境保持运行") print("支持的命令: interview, batch_interview, close_env") print("=" * 60) self.ipc_handler.update_status("alive") # Command-wait loop driven by the global _shutdown_event. try: while not _shutdown_event.is_set(): should_continue = await self.ipc_handler.process_commands() if not should_continue: break try: await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5) break # Shutdown signal received. except asyncio.TimeoutError: pass except KeyboardInterrupt: print("\n收到中断信号") except asyncio.CancelledError: print("\n任务被取消") except Exception as e: print(f"\n命令处理出错: {e}") print("\n关闭环境...") self.ipc_handler.update_status("stopped") await self.env.close() print("环境已关闭") print("=" * 60) async def main(): parser = argparse.ArgumentParser(description='OASIS Reddit模拟') parser.add_argument( '--config', type=str, required=True, help='配置文件路径 (simulation_config.json)' ) parser.add_argument( '--max-rounds', type=int, default=None, help='最大模拟轮数(可选,用于截断过长的模拟)' ) parser.add_argument( '--no-wait', action='store_true', default=False, help='模拟完成后立即关闭环境,不进入等待命令模式' ) args = parser.parse_args() # Create the shutdown event lazily here so it is bound to the running asyncio loop. global _shutdown_event _shutdown_event = asyncio.Event() if not os.path.exists(args.config): print(f"错误: 配置文件不存在: {args.config}") sys.exit(1) # Initialize log config with fixed filenames; old logs are cleared inside setup_oasis_logging. simulation_dir = os.path.dirname(args.config) or "." setup_oasis_logging(os.path.join(simulation_dir, "log")) runner = RedditSimulationRunner( config_path=args.config, wait_for_commands=not args.no_wait ) await runner.run(max_rounds=args.max_rounds) def setup_signal_handlers(): """Install signal handlers so SIGTERM/SIGINT trigger a graceful exit. This gives the program a chance to clean up resources (close the database, the OASIS environment, etc.). """ def signal_handler(signum, frame): global _cleanup_done sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT" print(f"\n收到 {sig_name} 信号,正在退出...") if not _cleanup_done: _cleanup_done = True if _shutdown_event: _shutdown_event.set() else: # Force exit only on a repeat signal so the user can still hard-kill if cleanup hangs. print("强制退出...") sys.exit(1) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) if __name__ == "__main__": setup_signal_handlers() try: asyncio.run(main()) except KeyboardInterrupt: print("\n程序被中断") except SystemExit: pass finally: print("模拟进程已退出")