""" OASIS simulation runner. Runs the simulation in the background, records each agent's actions, and supports real-time status monitoring. """ import os import sys import json import time import asyncio import threading import subprocess import signal import atexit from typing import Dict, Any, List, Optional, Union from dataclasses import dataclass, field from datetime import datetime from enum import Enum from queue import Queue from ..config import Config from ..utils.logger import get_logger from ..utils.locale import get_locale, set_locale, t from .zep_graph_memory_updater import ZepGraphMemoryManager from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse logger = get_logger('mirofish.simulation_runner') # Tracks whether the cleanup handler has been registered (guards against double registration in Flask reloader). _cleanup_registered = False IS_WINDOWS = sys.platform == 'win32' class RunnerStatus(str, Enum): """Runner lifecycle states.""" IDLE = "idle" STARTING = "starting" RUNNING = "running" PAUSED = "paused" STOPPING = "stopping" STOPPED = "stopped" COMPLETED = "completed" FAILED = "failed" @dataclass class AgentAction: """A single recorded agent action.""" round_num: int timestamp: str platform: str # twitter / reddit agent_id: int agent_name: str action_type: str # CREATE_POST, LIKE_POST, etc. action_args: Dict[str, Any] = field(default_factory=dict) result: Optional[str] = None success: bool = True def to_dict(self) -> Dict[str, Any]: return { "round_num": self.round_num, "timestamp": self.timestamp, "platform": self.platform, "agent_id": self.agent_id, "agent_name": self.agent_name, "action_type": self.action_type, "action_args": self.action_args, "result": self.result, "success": self.success, } @dataclass class RoundSummary: """Per-round summary statistics.""" round_num: int start_time: str end_time: Optional[str] = None simulated_hour: int = 0 twitter_actions: int = 0 reddit_actions: int = 0 active_agents: List[int] = field(default_factory=list) actions: List[AgentAction] = field(default_factory=list) def to_dict(self) -> Dict[str, Any]: return { "round_num": self.round_num, "start_time": self.start_time, "end_time": self.end_time, "simulated_hour": self.simulated_hour, "twitter_actions": self.twitter_actions, "reddit_actions": self.reddit_actions, "active_agents": self.active_agents, "actions_count": len(self.actions), "actions": [a.to_dict() for a in self.actions], } @dataclass class SimulationRunState: """Live runtime state for a simulation.""" simulation_id: str runner_status: RunnerStatus = RunnerStatus.IDLE current_round: int = 0 total_rounds: int = 0 simulated_hours: int = 0 total_simulation_hours: int = 0 # Per-platform round and simulated-time counters (used when both platforms run in parallel). twitter_current_round: int = 0 reddit_current_round: int = 0 twitter_simulated_hours: int = 0 reddit_simulated_hours: int = 0 twitter_running: bool = False reddit_running: bool = False twitter_actions_count: int = 0 reddit_actions_count: int = 0 # Per-platform completion flags, set when a simulation_end event is observed in actions.jsonl. twitter_completed: bool = False reddit_completed: bool = False rounds: List[RoundSummary] = field(default_factory=list) # Recent actions buffer; surfaced to the frontend for the live feed. recent_actions: List[AgentAction] = field(default_factory=list) max_recent_actions: int = 50 started_at: Optional[str] = None updated_at: str = field(default_factory=lambda: datetime.now().isoformat()) completed_at: Optional[str] = None error: Optional[str] = None # Main subprocess PID — captured so the process can later be stopped. process_pid: Optional[int] = None def add_action(self, action: AgentAction): """Prepend an action to the recent-actions buffer and update counters.""" self.recent_actions.insert(0, action) if len(self.recent_actions) > self.max_recent_actions: self.recent_actions = self.recent_actions[:self.max_recent_actions] if action.platform == "twitter": self.twitter_actions_count += 1 else: self.reddit_actions_count += 1 self.updated_at = datetime.now().isoformat() def to_dict(self) -> Dict[str, Any]: return { "simulation_id": self.simulation_id, "runner_status": self.runner_status.value, "current_round": self.current_round, "total_rounds": self.total_rounds, "simulated_hours": self.simulated_hours, "total_simulation_hours": self.total_simulation_hours, "progress_percent": round(self.current_round / max(self.total_rounds, 1) * 100, 1), # Per-platform round and simulated-time counters. "twitter_current_round": self.twitter_current_round, "reddit_current_round": self.reddit_current_round, "twitter_simulated_hours": self.twitter_simulated_hours, "reddit_simulated_hours": self.reddit_simulated_hours, "twitter_running": self.twitter_running, "reddit_running": self.reddit_running, "twitter_completed": self.twitter_completed, "reddit_completed": self.reddit_completed, "twitter_actions_count": self.twitter_actions_count, "reddit_actions_count": self.reddit_actions_count, "total_actions_count": self.twitter_actions_count + self.reddit_actions_count, "started_at": self.started_at, "updated_at": self.updated_at, "completed_at": self.completed_at, "error": self.error, "process_pid": self.process_pid, } def to_detail_dict(self) -> Dict[str, Any]: """Return the dict form of the state including recent actions.""" result = self.to_dict() result["recent_actions"] = [a.to_dict() for a in self.recent_actions] result["rounds_count"] = len(self.rounds) return result class SimulationRunner: """ Simulation runner. Responsibilities: 1. Run the OASIS simulation in a background subprocess. 2. Parse the run logs and record each agent's actions. 3. Provide real-time status query interfaces. 4. Support pause/stop/resume operations. """ RUN_STATE_DIR = os.path.join( os.path.dirname(__file__), '../../uploads/simulations' ) SCRIPTS_DIR = os.path.join( os.path.dirname(__file__), '../../scripts' ) # In-memory caches of runtime state, processes, queues, monitor threads, and log file handles. _run_states: Dict[str, SimulationRunState] = {} _processes: Dict[str, subprocess.Popen] = {} _action_queues: Dict[str, Queue] = {} _monitor_threads: Dict[str, threading.Thread] = {} _stdout_files: Dict[str, Any] = {} _stderr_files: Dict[str, Any] = {} # Graph-memory-update flag per simulation_id. _graph_memory_enabled: Dict[str, bool] = {} @classmethod def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: """Return the cached run state, falling back to disk if not loaded yet.""" if simulation_id in cls._run_states: return cls._run_states[simulation_id] state = cls._load_run_state(simulation_id) if state: cls._run_states[simulation_id] = state return state @classmethod def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]: """Load run state from the on-disk JSON snapshot.""" state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json") if not os.path.exists(state_file): return None try: with open(state_file, 'r', encoding='utf-8') as f: data = json.load(f) state = SimulationRunState( simulation_id=simulation_id, runner_status=RunnerStatus(data.get("runner_status", "idle")), current_round=data.get("current_round", 0), total_rounds=data.get("total_rounds", 0), simulated_hours=data.get("simulated_hours", 0), total_simulation_hours=data.get("total_simulation_hours", 0), # Per-platform round and simulated-time counters. twitter_current_round=data.get("twitter_current_round", 0), reddit_current_round=data.get("reddit_current_round", 0), twitter_simulated_hours=data.get("twitter_simulated_hours", 0), reddit_simulated_hours=data.get("reddit_simulated_hours", 0), twitter_running=data.get("twitter_running", False), reddit_running=data.get("reddit_running", False), twitter_completed=data.get("twitter_completed", False), reddit_completed=data.get("reddit_completed", False), twitter_actions_count=data.get("twitter_actions_count", 0), reddit_actions_count=data.get("reddit_actions_count", 0), started_at=data.get("started_at"), updated_at=data.get("updated_at", datetime.now().isoformat()), completed_at=data.get("completed_at"), error=data.get("error"), process_pid=data.get("process_pid"), ) # Restore the recent-actions buffer. actions_data = data.get("recent_actions", []) for a in actions_data: state.recent_actions.append(AgentAction( round_num=a.get("round_num", 0), timestamp=a.get("timestamp", ""), platform=a.get("platform", ""), agent_id=a.get("agent_id", 0), agent_name=a.get("agent_name", ""), action_type=a.get("action_type", ""), action_args=a.get("action_args", {}), result=a.get("result"), success=a.get("success", True), )) return state except Exception as e: logger.error(t("log.simulation_runner.m001", str=str(e))) return None @classmethod def _save_run_state(cls, state: SimulationRunState): """Persist the run state to its JSON snapshot file.""" sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) os.makedirs(sim_dir, exist_ok=True) state_file = os.path.join(sim_dir, "run_state.json") data = state.to_detail_dict() with open(state_file, 'w', encoding='utf-8') as f: json.dump(data, f, ensure_ascii=False, indent=2) cls._run_states[state.simulation_id] = state @classmethod def start_simulation( cls, simulation_id: str, platform: str = "parallel", # twitter / reddit / parallel max_rounds: int = None, # Optional cap on simulation rounds (truncates overly long runs). enable_graph_memory_update: bool = False, # Whether to push activity into the Zep graph. graph_id: str = None # Zep graph ID (required when graph-memory updates are enabled). ) -> SimulationRunState: """ Start the simulation. Args: simulation_id: Simulation ID. platform: Platform to run (twitter/reddit/parallel). max_rounds: Optional cap on simulation rounds (truncates overly long runs). enable_graph_memory_update: Whether to push agent activity to the Zep graph in real time. graph_id: Zep graph ID (required when graph-memory updates are enabled). Returns: SimulationRunState """ # Refuse to start a duplicate run for the same simulation_id. existing = cls.get_run_state(simulation_id) if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]: raise ValueError(f"模拟已在运行中: {simulation_id}") # Load the simulation configuration written during preparation. sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): raise ValueError(f"模拟配置不存在,请先调用 /prepare 接口") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) # Compute total rounds from time-window settings. time_config = config.get("time_config", {}) total_hours = time_config.get("total_simulation_hours", 72) minutes_per_round = time_config.get("minutes_per_round", 30) total_rounds = int(total_hours * 60 / minutes_per_round) # If a cap was provided, clamp total_rounds. if max_rounds is not None and max_rounds > 0: original_rounds = total_rounds total_rounds = min(total_rounds, max_rounds) if total_rounds < original_rounds: logger.info(t("log.simulation_runner.m002", original_rounds=original_rounds, total_rounds=total_rounds, max_rounds=max_rounds)) state = SimulationRunState( simulation_id=simulation_id, runner_status=RunnerStatus.STARTING, total_rounds=total_rounds, total_simulation_hours=total_hours, started_at=datetime.now().isoformat(), ) cls._save_run_state(state) # Spin up a graph-memory updater if requested. if enable_graph_memory_update: if not graph_id: raise ValueError("启用图谱记忆更新时必须提供 graph_id") try: ZepGraphMemoryManager.create_updater(simulation_id, graph_id) cls._graph_memory_enabled[simulation_id] = True logger.info(t("log.simulation_runner.m003", simulation_id=simulation_id, graph_id=graph_id)) except Exception as e: logger.error(t("log.simulation_runner.m004", e=e)) cls._graph_memory_enabled[simulation_id] = False else: cls._graph_memory_enabled[simulation_id] = False # Pick the entry script (lives in backend/scripts/) based on the requested platform. if platform == "twitter": script_name = "run_twitter_simulation.py" state.twitter_running = True elif platform == "reddit": script_name = "run_reddit_simulation.py" state.reddit_running = True else: script_name = "run_parallel_simulation.py" state.twitter_running = True state.reddit_running = True script_path = os.path.join(cls.SCRIPTS_DIR, script_name) if not os.path.exists(script_path): raise ValueError(f"脚本不存在: {script_path}") action_queue = Queue() cls._action_queues[simulation_id] = action_queue try: # Log layout written by the subprocess: # twitter/actions.jsonl - Twitter action log # reddit/actions.jsonl - Reddit action log # simulation.log - main-process log cmd = [ sys.executable, script_path, "--config", config_path, ] if max_rounds is not None and max_rounds > 0: cmd.extend(["--max-rounds", str(max_rounds)]) # Redirect stdout/stderr to a file so a full pipe buffer cannot block the subprocess. main_log_path = os.path.join(sim_dir, "simulation.log") main_log_file = open(main_log_path, 'w', encoding='utf-8') # Force UTF-8 in the child so third-party libs (e.g. OASIS) that open files without an # explicit encoding work correctly on Windows. env = os.environ.copy() env['PYTHONUTF8'] = '1' env['PYTHONIOENCODING'] = 'utf-8' # cwd is the simulation directory so generated artifacts (databases, etc.) land there. # start_new_session=True creates a fresh process group so os.killpg can terminate the # entire tree on shutdown. process = subprocess.Popen( cmd, cwd=sim_dir, stdout=main_log_file, stderr=subprocess.STDOUT, text=True, encoding='utf-8', bufsize=1, env=env, start_new_session=True, ) # Retain the log file handle so it can be closed after the subprocess exits. cls._stdout_files[simulation_id] = main_log_file cls._stderr_files[simulation_id] = None state.process_pid = process.pid state.runner_status = RunnerStatus.RUNNING cls._processes[simulation_id] = process cls._save_run_state(state) # Capture locale before spawning monitor thread current_locale = get_locale() # Spawn the log-tailing monitor thread. monitor_thread = threading.Thread( target=cls._monitor_simulation, args=(simulation_id, current_locale), daemon=True ) monitor_thread.start() cls._monitor_threads[simulation_id] = monitor_thread logger.info(t("log.simulation_runner.m005", simulation_id=simulation_id, process=process.pid, platform=platform)) except Exception as e: state.runner_status = RunnerStatus.FAILED state.error = str(e) cls._save_run_state(state) raise return state @classmethod def _monitor_simulation(cls, simulation_id: str, locale: str = 'zh'): """Monitor the simulation process and tail its per-platform action logs.""" set_locale(locale) sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") process = cls._processes.get(simulation_id) state = cls.get_run_state(simulation_id) if not process or not state: return twitter_position = 0 reddit_position = 0 try: while process.poll() is None: if os.path.exists(twitter_actions_log): twitter_position = cls._read_action_log( twitter_actions_log, twitter_position, state, "twitter" ) if os.path.exists(reddit_actions_log): reddit_position = cls._read_action_log( reddit_actions_log, reddit_position, state, "reddit" ) cls._save_run_state(state) time.sleep(2) # Drain any log lines written between the last poll and the process exit. if os.path.exists(twitter_actions_log): cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter") if os.path.exists(reddit_actions_log): cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit") exit_code = process.returncode if exit_code == 0: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() logger.info(t("log.simulation_runner.m006", simulation_id=simulation_id)) else: state.runner_status = RunnerStatus.FAILED # Pull the tail of the main log so the failure context is surfaced in state.error. main_log_path = os.path.join(sim_dir, "simulation.log") error_info = "" try: if os.path.exists(main_log_path): with open(main_log_path, 'r', encoding='utf-8') as f: error_info = f.read()[-2000:] # keep only the last 2000 chars except Exception: pass state.error = f"进程退出码: {exit_code}, 错误: {error_info}" logger.error(t("log.simulation_runner.m007", simulation_id=simulation_id, state=state.error)) state.twitter_running = False state.reddit_running = False cls._save_run_state(state) except Exception as e: logger.error(t("log.simulation_runner.m008", simulation_id=simulation_id, str=str(e))) state.runner_status = RunnerStatus.FAILED state.error = str(e) cls._save_run_state(state) finally: # Tear down the graph-memory updater, if we started one. if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) logger.info(t("log.simulation_runner.m009", simulation_id=simulation_id)) except Exception as e: logger.error(t("log.simulation_runner.m010", e=e)) cls._graph_memory_enabled.pop(simulation_id, None) cls._processes.pop(simulation_id, None) cls._action_queues.pop(simulation_id, None) # Close the retained log file handles. if simulation_id in cls._stdout_files: try: cls._stdout_files[simulation_id].close() except Exception: pass cls._stdout_files.pop(simulation_id, None) if simulation_id in cls._stderr_files and cls._stderr_files[simulation_id]: try: cls._stderr_files[simulation_id].close() except Exception: pass cls._stderr_files.pop(simulation_id, None) @classmethod def _read_action_log( cls, log_path: str, position: int, state: SimulationRunState, platform: str ) -> int: """ Read new entries from a per-platform action log. Args: log_path: Path to the action-log file. position: Byte offset where the previous read finished. state: Run-state object to mutate. platform: Platform name (twitter/reddit). Returns: New byte offset after this read. """ graph_memory_enabled = cls._graph_memory_enabled.get(state.simulation_id, False) graph_updater = None if graph_memory_enabled: graph_updater = ZepGraphMemoryManager.get_updater(state.simulation_id) try: with open(log_path, 'r', encoding='utf-8') as f: f.seek(position) for line in f: line = line.strip() if line: try: action_data = json.loads(line) # Event records (simulation_start/end, round_end, ...) are routed here. if "event_type" in action_data: event_type = action_data.get("event_type") # simulation_end means the platform finished its run. if event_type == "simulation_end": if platform == "twitter": state.twitter_completed = True state.twitter_running = False logger.info(t("log.simulation_runner.m011", state=state.simulation_id, action_data=action_data.get('total_rounds'), action_data_2=action_data.get('total_actions'))) elif platform == "reddit": state.reddit_completed = True state.reddit_running = False logger.info(t("log.simulation_runner.m012", state=state.simulation_id, action_data=action_data.get('total_rounds'), action_data_2=action_data.get('total_actions'))) # Mark the run as completed once every enabled platform has reported # simulation_end. Single-platform runs only need that one. all_completed = cls._check_all_platforms_completed(state) if all_completed: state.runner_status = RunnerStatus.COMPLETED state.completed_at = datetime.now().isoformat() logger.info(t("log.simulation_runner.m013", state=state.simulation_id)) # Round counters come from round_end events. elif event_type == "round_end": round_num = action_data.get("round", 0) simulated_hours = action_data.get("simulated_hours", 0) if platform == "twitter": if round_num > state.twitter_current_round: state.twitter_current_round = round_num state.twitter_simulated_hours = simulated_hours elif platform == "reddit": if round_num > state.reddit_current_round: state.reddit_current_round = round_num state.reddit_simulated_hours = simulated_hours # Overall counters track the max across enabled platforms. if round_num > state.current_round: state.current_round = round_num state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours) continue action = AgentAction( round_num=action_data.get("round", 0), timestamp=action_data.get("timestamp", datetime.now().isoformat()), platform=platform, agent_id=action_data.get("agent_id", 0), agent_name=action_data.get("agent_name", ""), action_type=action_data.get("action_type", ""), action_args=action_data.get("action_args", {}), result=action_data.get("result"), success=action_data.get("success", True), ) state.add_action(action) if action.round_num and action.round_num > state.current_round: state.current_round = action.round_num # Forward the activity to the Zep graph when the updater is enabled. if graph_updater: graph_updater.add_activity_from_dict(action_data, platform) except json.JSONDecodeError: pass return f.tell() except Exception as e: logger.warning(t("log.simulation_runner.m014", log_path=log_path, e=e)) return position @classmethod def _check_all_platforms_completed(cls, state: SimulationRunState) -> bool: """ Return whether every enabled platform has completed its simulation. A platform counts as enabled when its corresponding actions.jsonl file exists on disk. Returns: True if all enabled platforms have completed. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id) twitter_log = os.path.join(sim_dir, "twitter", "actions.jsonl") reddit_log = os.path.join(sim_dir, "reddit", "actions.jsonl") # File presence is our enabled-platform signal. twitter_enabled = os.path.exists(twitter_log) reddit_enabled = os.path.exists(reddit_log) if twitter_enabled and not state.twitter_completed: return False if reddit_enabled and not state.reddit_completed: return False # At least one platform must be enabled (and, by the checks above, completed). return twitter_enabled or reddit_enabled @classmethod def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeout: int = 10): """ Terminate a process and its subprocesses in a cross-platform way. Args: process: Process to terminate. simulation_id: Simulation ID (used for log messages). timeout: Seconds to wait for graceful exit before escalating. """ if IS_WINDOWS: # Windows: taskkill /T tears down the whole process tree, /F escalates to a hard kill. logger.info(t("log.simulation_runner.m015", simulation_id=simulation_id, process=process.pid)) try: # Graceful termination first. subprocess.run( ['taskkill', '/PID', str(process.pid), '/T'], capture_output=True, timeout=5 ) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: # Force kill the tree. logger.warning(t("log.simulation_runner.m016", simulation_id=simulation_id)) subprocess.run( ['taskkill', '/F', '/PID', str(process.pid), '/T'], capture_output=True, timeout=5 ) process.wait(timeout=5) except Exception as e: logger.warning(t("log.simulation_runner.m017", e=e)) process.terminate() try: process.wait(timeout=5) except subprocess.TimeoutExpired: process.kill() else: # Unix: kill the entire process group. # Because the subprocess was started with start_new_session=True the pgid equals the PID. pgid = os.getpgid(process.pid) logger.info(t("log.simulation_runner.m018", simulation_id=simulation_id, pgid=pgid)) # SIGTERM first to allow graceful shutdown. os.killpg(pgid, signal.SIGTERM) try: process.wait(timeout=timeout) except subprocess.TimeoutExpired: # Escalate to SIGKILL on timeout. logger.warning(t("log.simulation_runner.m019", simulation_id=simulation_id)) os.killpg(pgid, signal.SIGKILL) process.wait(timeout=5) @classmethod def stop_simulation(cls, simulation_id: str) -> SimulationRunState: """Stop the simulation subprocess and update its state.""" state = cls.get_run_state(simulation_id) if not state: raise ValueError(f"模拟不存在: {simulation_id}") if state.runner_status not in [RunnerStatus.RUNNING, RunnerStatus.PAUSED]: raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}") state.runner_status = RunnerStatus.STOPPING cls._save_run_state(state) process = cls._processes.get(simulation_id) if process and process.poll() is None: try: cls._terminate_process(process, simulation_id) except ProcessLookupError: # The process has already exited. pass except Exception as e: logger.error(t("log.simulation_runner.m020", simulation_id=simulation_id, e=e)) # Fall back to direct termination on the Popen handle. try: process.terminate() process.wait(timeout=5) except Exception: process.kill() state.runner_status = RunnerStatus.STOPPED state.twitter_running = False state.reddit_running = False state.completed_at = datetime.now().isoformat() cls._save_run_state(state) # Tear down the graph-memory updater, if any. if cls._graph_memory_enabled.get(simulation_id, False): try: ZepGraphMemoryManager.stop_updater(simulation_id) logger.info(t("log.simulation_runner.m021", simulation_id=simulation_id)) except Exception as e: logger.error(t("log.simulation_runner.m022", e=e)) cls._graph_memory_enabled.pop(simulation_id, None) logger.info(t("log.simulation_runner.m023", simulation_id=simulation_id)) return state @classmethod def _read_actions_from_file( cls, file_path: str, default_platform: Optional[str] = None, platform_filter: Optional[str] = None, agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: """ Read actions from a single action-log file. Args: file_path: Path to the action-log file. default_platform: Platform to assume when a record has no `platform` field. platform_filter: Optional platform filter. agent_id: Optional agent-id filter. round_num: Optional round-number filter. """ if not os.path.exists(file_path): return [] actions = [] with open(file_path, 'r', encoding='utf-8') as f: for line in f: line = line.strip() if not line: continue try: data = json.loads(line) # Skip event records (simulation_start, round_start, round_end, ...). if "event_type" in data: continue # Skip records without an agent_id (non-agent actions). if "agent_id" not in data: continue # Prefer the record's own platform; fall back to the default for legacy entries. record_platform = data.get("platform") or default_platform or "" if platform_filter and record_platform != platform_filter: continue if agent_id is not None and data.get("agent_id") != agent_id: continue if round_num is not None and data.get("round") != round_num: continue actions.append(AgentAction( round_num=data.get("round", 0), timestamp=data.get("timestamp", ""), platform=record_platform, agent_id=data.get("agent_id", 0), agent_name=data.get("agent_name", ""), action_type=data.get("action_type", ""), action_args=data.get("action_args", {}), result=data.get("result"), success=data.get("success", True), )) except json.JSONDecodeError: continue return actions @classmethod def get_all_actions( cls, simulation_id: str, platform: Optional[str] = None, agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: """ Return the complete action history across all platforms (no pagination). Args: simulation_id: Simulation ID. platform: Optional platform filter (twitter/reddit). agent_id: Optional agent filter. round_num: Optional round filter. Returns: Full action list, sorted by timestamp with newest first. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) actions = [] # Twitter action log: derive platform from the file path. twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl") if not platform or platform == "twitter": actions.extend(cls._read_actions_from_file( twitter_actions_log, default_platform="twitter", platform_filter=platform, agent_id=agent_id, round_num=round_num )) # Reddit action log: derive platform from the file path. reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl") if not platform or platform == "reddit": actions.extend(cls._read_actions_from_file( reddit_actions_log, default_platform="reddit", platform_filter=platform, agent_id=agent_id, round_num=round_num )) # Fall back to the legacy single-file layout if no per-platform files exist. if not actions: actions_log = os.path.join(sim_dir, "actions.jsonl") actions = cls._read_actions_from_file( actions_log, default_platform=None, # Legacy files carry their own platform field. platform_filter=platform, agent_id=agent_id, round_num=round_num ) # Newest-first by timestamp. actions.sort(key=lambda x: x.timestamp, reverse=True) return actions @classmethod def get_actions( cls, simulation_id: str, limit: int = 100, offset: int = 0, platform: Optional[str] = None, agent_id: Optional[int] = None, round_num: Optional[int] = None ) -> List[AgentAction]: """ Return action history with pagination. Args: simulation_id: Simulation ID. limit: Maximum number of actions to return. offset: Offset into the sorted result list. platform: Optional platform filter. agent_id: Optional agent filter. round_num: Optional round filter. Returns: A page of actions. """ actions = cls.get_all_actions( simulation_id=simulation_id, platform=platform, agent_id=agent_id, round_num=round_num ) return actions[offset:offset + limit] @classmethod def get_timeline( cls, simulation_id: str, start_round: int = 0, end_round: Optional[int] = None ) -> List[Dict[str, Any]]: """ Return a per-round timeline summary for the simulation. Args: simulation_id: Simulation ID. start_round: First round to include (inclusive). end_round: Last round to include (inclusive); None means no upper bound. Returns: One summary entry per round. """ actions = cls.get_actions(simulation_id, limit=10000) # Group actions by round. rounds: Dict[int, Dict[str, Any]] = {} for action in actions: round_num = action.round_num if round_num < start_round: continue if end_round is not None and round_num > end_round: continue if round_num not in rounds: rounds[round_num] = { "round_num": round_num, "twitter_actions": 0, "reddit_actions": 0, "active_agents": set(), "action_types": {}, "first_action_time": action.timestamp, "last_action_time": action.timestamp, } r = rounds[round_num] if action.platform == "twitter": r["twitter_actions"] += 1 else: r["reddit_actions"] += 1 r["active_agents"].add(action.agent_id) r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1 r["last_action_time"] = action.timestamp # Materialise into a sorted list. result = [] for round_num in sorted(rounds.keys()): r = rounds[round_num] result.append({ "round_num": round_num, "twitter_actions": r["twitter_actions"], "reddit_actions": r["reddit_actions"], "total_actions": r["twitter_actions"] + r["reddit_actions"], "active_agents_count": len(r["active_agents"]), "active_agents": list(r["active_agents"]), "action_types": r["action_types"], "first_action_time": r["first_action_time"], "last_action_time": r["last_action_time"], }) return result @classmethod def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]: """ Return per-agent statistics for the simulation. Returns: Per-agent statistics, sorted by total action count (descending). """ actions = cls.get_actions(simulation_id, limit=10000) agent_stats: Dict[int, Dict[str, Any]] = {} for action in actions: agent_id = action.agent_id if agent_id not in agent_stats: agent_stats[agent_id] = { "agent_id": agent_id, "agent_name": action.agent_name, "total_actions": 0, "twitter_actions": 0, "reddit_actions": 0, "action_types": {}, "first_action_time": action.timestamp, "last_action_time": action.timestamp, } stats = agent_stats[agent_id] stats["total_actions"] += 1 if action.platform == "twitter": stats["twitter_actions"] += 1 else: stats["reddit_actions"] += 1 stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1 stats["last_action_time"] = action.timestamp result = sorted(agent_stats.values(), key=lambda x: x["total_actions"], reverse=True) return result @classmethod def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]: """ Clean up the simulation's run logs so the simulation can be force-restarted. Deletes the following files: - run_state.json - twitter/actions.jsonl - reddit/actions.jsonl - simulation.log - stdout.log / stderr.log - twitter_simulation.db (simulation database) - reddit_simulation.db (simulation database) - env_status.json (environment status) Note: simulation_config.json and the profile files are preserved. Args: simulation_id: Simulation ID. Returns: Cleanup result info. """ import shutil sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): return {"success": True, "message": "模拟目录不存在,无需清理"} cleaned_files = [] errors = [] # Files to delete (includes per-platform databases). files_to_delete = [ "run_state.json", "simulation.log", "stdout.log", "stderr.log", "twitter_simulation.db", # Twitter platform database. "reddit_simulation.db", # Reddit platform database. "env_status.json", # Environment-status file. ] # Per-platform directories whose action logs should be cleaned. dirs_to_clean = ["twitter", "reddit"] for filename in files_to_delete: file_path = os.path.join(sim_dir, filename) if os.path.exists(file_path): try: os.remove(file_path) cleaned_files.append(filename) except Exception as e: errors.append(f"删除 {filename} 失败: {str(e)}") # Clean per-platform action logs. for dir_name in dirs_to_clean: dir_path = os.path.join(sim_dir, dir_name) if os.path.exists(dir_path): actions_file = os.path.join(dir_path, "actions.jsonl") if os.path.exists(actions_file): try: os.remove(actions_file) cleaned_files.append(f"{dir_name}/actions.jsonl") except Exception as e: errors.append(f"删除 {dir_name}/actions.jsonl 失败: {str(e)}") # Drop the in-memory run state for this simulation. if simulation_id in cls._run_states: del cls._run_states[simulation_id] logger.info(t("log.simulation_runner.m024", simulation_id=simulation_id, cleaned_files=cleaned_files)) return { "success": len(errors) == 0, "cleaned_files": cleaned_files, "errors": errors if errors else None } # Guard so cleanup_all_simulations only runs once per process lifetime. _cleanup_done = False @classmethod def cleanup_all_simulations(cls): """ Clean up every running simulation subprocess. Invoked at server shutdown to guarantee no child processes leak. """ if cls._cleanup_done: return cls._cleanup_done = True # Skip the "shutting down" log entirely if there's nothing to clean up. has_processes = bool(cls._processes) has_updaters = bool(cls._graph_memory_enabled) if not has_processes and not has_updaters: return logger.info(t("log.simulation_runner.m025")) # Stop graph-memory updaters first (stop_all logs internally). try: ZepGraphMemoryManager.stop_all() except Exception as e: logger.error(t("log.simulation_runner.m026", e=e)) cls._graph_memory_enabled.clear() # Snapshot the process map so we can mutate it during iteration. processes = list(cls._processes.items()) for simulation_id, process in processes: try: if process.poll() is None: logger.info(t("log.simulation_runner.m027", simulation_id=simulation_id, process=process.pid)) try: cls._terminate_process(process, simulation_id, timeout=5) except (ProcessLookupError, OSError): # The process may already be gone; fall back to direct termination. try: process.terminate() process.wait(timeout=3) except Exception: process.kill() # Update run_state.json so external readers see the stopped status. state = cls.get_run_state(simulation_id) if state: state.runner_status = RunnerStatus.STOPPED state.twitter_running = False state.reddit_running = False state.completed_at = datetime.now().isoformat() state.error = "服务器关闭,模拟被终止" cls._save_run_state(state) # Also flip the project-level state.json status to "stopped". try: sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) state_file = os.path.join(sim_dir, "state.json") logger.info(t("log.simulation_runner.m028", state_file=state_file)) if os.path.exists(state_file): with open(state_file, 'r', encoding='utf-8') as f: state_data = json.load(f) state_data['status'] = 'stopped' state_data['updated_at'] = datetime.now().isoformat() with open(state_file, 'w', encoding='utf-8') as f: json.dump(state_data, f, indent=2, ensure_ascii=False) logger.info(t("log.simulation_runner.m029", simulation_id=simulation_id)) else: logger.warning(t("log.simulation_runner.m030", state_file=state_file)) except Exception as state_err: logger.warning(t("log.simulation_runner.m031", simulation_id=simulation_id, state_err=state_err)) except Exception as e: logger.error(t("log.simulation_runner.m032", simulation_id=simulation_id, e=e)) # Close any retained log file handles. for simulation_id, file_handle in list(cls._stdout_files.items()): try: if file_handle: file_handle.close() except Exception: pass cls._stdout_files.clear() for simulation_id, file_handle in list(cls._stderr_files.items()): try: if file_handle: file_handle.close() except Exception: pass cls._stderr_files.clear() # Drop in-memory bookkeeping. cls._processes.clear() cls._action_queues.clear() logger.info(t("log.simulation_runner.m033")) @classmethod def register_cleanup(cls): """ Register the shutdown cleanup hook. Called at Flask application startup so that all simulation subprocesses are torn down when the server stops. """ global _cleanup_registered if _cleanup_registered: return # In Flask debug mode the reloader spawns a child process that actually runs the app # (signaled by WERKZEUG_RUN_MAIN=true). Outside debug mode that variable is unset and we # still want to register the cleanup hook. is_reloader_process = os.environ.get('WERKZEUG_RUN_MAIN') == 'true' is_debug_mode = os.environ.get('FLASK_DEBUG') == '1' or os.environ.get('WERKZEUG_RUN_MAIN') is not None # Debug mode: only register inside the reloader child. Non-debug: always register. if is_debug_mode and not is_reloader_process: _cleanup_registered = True # Prevent the parent process from retrying. return # Capture the previously installed signal handlers so we can chain to them. original_sigint = signal.getsignal(signal.SIGINT) original_sigterm = signal.getsignal(signal.SIGTERM) # SIGHUP exists only on Unix (macOS/Linux); Windows does not have it. original_sighup = None has_sighup = hasattr(signal, 'SIGHUP') if has_sighup: original_sighup = signal.getsignal(signal.SIGHUP) def cleanup_handler(signum=None, frame=None): """Signal handler that cleans up simulations before delegating to the original handler.""" # Only log when there is actually something to clean up. if cls._processes or cls._graph_memory_enabled: logger.info(t("log.simulation_runner.m034", signum=signum)) cls.cleanup_all_simulations() # Chain to the original handler so Flask exits normally. if signum == signal.SIGINT and callable(original_sigint): original_sigint(signum, frame) elif signum == signal.SIGTERM and callable(original_sigterm): original_sigterm(signum, frame) elif has_sighup and signum == signal.SIGHUP: # SIGHUP is sent when the terminal is closed. if callable(original_sighup): original_sighup(signum, frame) else: # Default behavior: exit cleanly. sys.exit(0) else: # If the original handler is not callable (e.g. SIG_DFL), use the default behavior. raise KeyboardInterrupt # Register the atexit handler as a fallback. atexit.register(cls.cleanup_all_simulations) # Register signal handlers (only valid from the main thread). try: # SIGTERM: default signal sent by `kill`. signal.signal(signal.SIGTERM, cleanup_handler) # SIGINT: Ctrl+C signal.signal(signal.SIGINT, cleanup_handler) # SIGHUP: terminal close (Unix only). if has_sighup: signal.signal(signal.SIGHUP, cleanup_handler) except ValueError: # Not the main thread — fall back to the atexit hook. logger.warning(t("log.simulation_runner.m035")) _cleanup_registered = True @classmethod def get_running_simulations(cls) -> List[str]: """Return a list of every simulation ID with a live subprocess.""" running = [] for sim_id, process in cls._processes.items(): if process.poll() is None: running.append(sim_id) return running # ============== Interview feature ============== @classmethod def check_env_alive(cls, simulation_id: str) -> bool: """ Check whether the simulation environment is alive and able to receive interview commands. Args: simulation_id: Simulation ID. Returns: True if the environment is alive, False if it has shut down. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): return False ipc_client = SimulationIPCClient(sim_dir) return ipc_client.check_env_alive() @classmethod def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]: """ Return detailed status info for the simulation environment. Args: simulation_id: Simulation ID. Returns: Status dict containing status, twitter_available, reddit_available, timestamp. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) status_file = os.path.join(sim_dir, "env_status.json") default_status = { "status": "stopped", "twitter_available": False, "reddit_available": False, "timestamp": None } if not os.path.exists(status_file): return default_status try: with open(status_file, 'r', encoding='utf-8') as f: status = json.load(f) return { "status": status.get("status", "stopped"), "twitter_available": status.get("twitter_available", False), "reddit_available": status.get("reddit_available", False), "timestamp": status.get("timestamp") } except (json.JSONDecodeError, OSError): return default_status @classmethod def interview_agent( cls, simulation_id: str, agent_id: int, prompt: str, platform: str = None, timeout: float = 60.0 ) -> Dict[str, Any]: """ Interview a single agent. Args: simulation_id: Simulation ID. agent_id: Agent ID. prompt: Interview question. platform: Optional platform selector. - "twitter": only interview the agent on Twitter. - "reddit": only interview the agent on Reddit. - None: in dual-platform runs, interview both platforms and return a merged result. timeout: Timeout in seconds. Returns: Interview result dict. Raises: ValueError: Simulation does not exist or its environment is not running. TimeoutError: Timed out waiting for the response. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): raise ValueError(f"模拟不存在: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") logger.info(t("log.simulation_runner.m036", simulation_id=simulation_id, agent_id=agent_id, platform=platform)) response = ipc_client.send_interview( agent_id=agent_id, prompt=prompt, platform=platform, timeout=timeout ) if response.status.value == "completed": return { "success": True, "agent_id": agent_id, "prompt": prompt, "result": response.result, "timestamp": response.timestamp } else: return { "success": False, "agent_id": agent_id, "prompt": prompt, "error": response.error, "timestamp": response.timestamp } @classmethod def interview_agents_batch( cls, simulation_id: str, interviews: List[Dict[str, Any]], platform: str = None, timeout: float = 120.0 ) -> Dict[str, Any]: """ Interview multiple agents in batch. Args: simulation_id: Simulation ID. interviews: Interview list; each entry is {"agent_id": int, "prompt": str, "platform": str (optional)}. platform: Optional default platform (overridden per-interview by an entry's own `platform`). - "twitter": default to interviewing only Twitter. - "reddit": default to interviewing only Reddit. - None: in dual-platform runs, interview every agent on both platforms. timeout: Timeout in seconds. Returns: Batch interview result dict. Raises: ValueError: Simulation does not exist or its environment is not running. TimeoutError: Timed out waiting for the response. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): raise ValueError(f"模拟不存在: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): raise ValueError(f"模拟环境未运行或已关闭,无法执行Interview: {simulation_id}") logger.info(t("log.simulation_runner.m037", simulation_id=simulation_id, len=len(interviews), platform=platform)) response = ipc_client.send_batch_interview( interviews=interviews, platform=platform, timeout=timeout ) if response.status.value == "completed": return { "success": True, "interviews_count": len(interviews), "result": response.result, "timestamp": response.timestamp } else: return { "success": False, "interviews_count": len(interviews), "error": response.error, "timestamp": response.timestamp } @classmethod def interview_all_agents( cls, simulation_id: str, prompt: str, platform: str = None, timeout: float = 180.0 ) -> Dict[str, Any]: """ Interview every agent in the simulation (global interview). Sends the same prompt to every agent in the simulation. Args: simulation_id: Simulation ID. prompt: Interview question used for every agent. platform: Optional platform selector. - "twitter": only interview Twitter. - "reddit": only interview Reddit. - None: in dual-platform runs, interview every agent on both platforms. timeout: Timeout in seconds. Returns: Global interview result dict. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): raise ValueError(f"模拟不存在: {simulation_id}") # Read every agent from the simulation config. config_path = os.path.join(sim_dir, "simulation_config.json") if not os.path.exists(config_path): raise ValueError(f"模拟配置不存在: {simulation_id}") with open(config_path, 'r', encoding='utf-8') as f: config = json.load(f) agent_configs = config.get("agent_configs", []) if not agent_configs: raise ValueError(f"模拟配置中没有Agent: {simulation_id}") # Build the batch-interview payload. interviews = [] for agent_config in agent_configs: agent_id = agent_config.get("agent_id") if agent_id is not None: interviews.append({ "agent_id": agent_id, "prompt": prompt }) logger.info(t("log.simulation_runner.m038", simulation_id=simulation_id, len=len(interviews), platform=platform)) return cls.interview_agents_batch( simulation_id=simulation_id, interviews=interviews, platform=platform, timeout=timeout ) @classmethod def close_simulation_env( cls, simulation_id: str, timeout: float = 30.0 ) -> Dict[str, Any]: """ Close the simulation environment (does not stop the simulation subprocess). Sends a close-environment command to the simulation so it exits its wait-for-command mode gracefully. Args: simulation_id: Simulation ID. timeout: Timeout in seconds. Returns: Operation-result dict. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) if not os.path.exists(sim_dir): raise ValueError(f"模拟不存在: {simulation_id}") ipc_client = SimulationIPCClient(sim_dir) if not ipc_client.check_env_alive(): return { "success": True, "message": "环境已经关闭" } logger.info(t("log.simulation_runner.m039", simulation_id=simulation_id)) try: response = ipc_client.send_close_env(timeout=timeout) return { "success": response.status.value == "completed", "message": "环境关闭命令已发送", "result": response.result, "timestamp": response.timestamp } except TimeoutError: # Timing out can simply mean the environment is already shutting down. return { "success": True, "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)" } @classmethod def _get_interview_history_from_db( cls, db_path: str, platform_name: str, agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: """Read the interview history from a single per-platform database.""" import sqlite3 if not os.path.exists(db_path): return [] results = [] try: conn = sqlite3.connect(db_path) cursor = conn.cursor() if agent_id is not None: cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = 'interview' AND user_id = ? ORDER BY created_at DESC LIMIT ? """, (agent_id, limit)) else: cursor.execute(""" SELECT user_id, info, created_at FROM trace WHERE action = 'interview' ORDER BY created_at DESC LIMIT ? """, (limit,)) for user_id, info_json, created_at in cursor.fetchall(): try: info = json.loads(info_json) if info_json else {} except json.JSONDecodeError: info = {"raw": info_json} results.append({ "agent_id": user_id, "response": info.get("response", info), "prompt": info.get("prompt", ""), "timestamp": created_at, "platform": platform_name }) conn.close() except Exception as e: logger.error(t("log.simulation_runner.m040", platform_name=platform_name, e=e)) return results @classmethod def get_interview_history( cls, simulation_id: str, platform: str = None, agent_id: Optional[int] = None, limit: int = 100 ) -> List[Dict[str, Any]]: """ Return the interview history (read from the per-platform databases). Args: simulation_id: Simulation ID. platform: Platform selector (reddit/twitter/None). - "reddit": only return Reddit history. - "twitter": only return Twitter history. - None: return history from both platforms. agent_id: Optional agent-id filter; if set, only that agent's history is returned. limit: Max number of records per platform. Returns: Interview-history list. """ sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id) results = [] # Decide which platform databases to query. if platform in ("reddit", "twitter"): platforms = [platform] else: # No platform specified: query both. platforms = ["twitter", "reddit"] for p in platforms: db_path = os.path.join(sim_dir, f"{p}_simulation.db") platform_results = cls._get_interview_history_from_db( db_path=db_path, platform_name=p, agent_id=agent_id, limit=limit ) results.extend(platform_results) # Newest-first by timestamp. results.sort(key=lambda x: x.get("timestamp", ""), reverse=True) # When multiple platforms were queried, cap the merged result size. if len(platforms) > 1 and len(results) > limit: results = results[:limit] return results