MicroFish/backend/app/services/simulation_runner.py

1738 lines
67 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
OASIS simulation runner.
Runs the simulation in the background, records each agent's actions, and supports real-time status monitoring.
"""
import os
import sys
import json
import time
import asyncio
import threading
import subprocess
import signal
import atexit
from typing import Dict, Any, List, Optional, Union
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from queue import Queue
from ..config import Config
from ..utils.logger import get_logger
from ..utils.locale import get_locale, set_locale, t
from .zep_graph_memory_updater import ZepGraphMemoryManager
from .simulation_ipc import SimulationIPCClient, CommandType, IPCResponse
# Module-level logger shared by everything in this file.
logger = get_logger('mirofish.simulation_runner')
# Tracks whether the cleanup handler has been registered (guards against double registration in Flask reloader).
_cleanup_registered = False
# True when running on Windows; selects the taskkill-based branch in SimulationRunner._terminate_process.
IS_WINDOWS = sys.platform == 'win32'
class RunnerStatus(str, Enum):
    """Lifecycle states of a simulation runner.

    The ``str`` mix-in makes every member compare equal to its plain string
    value, and ``RunnerStatus("running")`` looks a member up by that value.
    """

    # Not started yet / being launched.
    IDLE = "idle"
    STARTING = "starting"

    # Active states.
    RUNNING = "running"
    PAUSED = "paused"

    # Shutdown requested / finished shutting down.
    STOPPING = "stopping"
    STOPPED = "stopped"

    # Terminal outcomes of a run.
    COMPLETED = "completed"
    FAILED = "failed"
@dataclass
class AgentAction:
    """One action performed by an agent, as read from an action log."""

    round_num: int
    timestamp: str
    platform: str  # twitter / reddit
    agent_id: int
    agent_name: str
    action_type: str  # CREATE_POST, LIKE_POST, etc.
    action_args: Dict[str, Any] = field(default_factory=dict)
    result: Optional[str] = None
    success: bool = True

    def to_dict(self) -> Dict[str, Any]:
        """Return the action as a plain dict keyed by field name.

        Values are handed out as-is (``action_args`` is not copied).
        """
        exported_fields = (
            "round_num",
            "timestamp",
            "platform",
            "agent_id",
            "agent_name",
            "action_type",
            "action_args",
            "result",
            "success",
        )
        return {name: getattr(self, name) for name in exported_fields}
@dataclass
class RoundSummary:
    """Aggregate figures and recorded actions for one simulation round."""

    round_num: int
    start_time: str
    end_time: Optional[str] = None
    simulated_hour: int = 0
    twitter_actions: int = 0
    reddit_actions: int = 0
    active_agents: List[int] = field(default_factory=list)
    actions: List[AgentAction] = field(default_factory=list)

    def to_dict(self) -> Dict[str, Any]:
        """Return the summary as a plain dict.

        Scalar fields are copied by name; the recorded actions are expanded
        through their own ``to_dict`` and accompanied by their count.
        """
        scalar_fields = (
            "round_num",
            "start_time",
            "end_time",
            "simulated_hour",
            "twitter_actions",
            "reddit_actions",
            "active_agents",
        )
        summary: Dict[str, Any] = {name: getattr(self, name) for name in scalar_fields}
        summary["actions_count"] = len(self.actions)
        summary["actions"] = [entry.to_dict() for entry in self.actions]
        return summary
@dataclass
class SimulationRunState:
    """Mutable runtime snapshot of one simulation run."""

    simulation_id: str
    runner_status: RunnerStatus = RunnerStatus.IDLE

    # Overall progress counters.
    current_round: int = 0
    total_rounds: int = 0
    simulated_hours: int = 0
    total_simulation_hours: int = 0

    # Per-platform round and simulated-time counters (used when both platforms run in parallel).
    twitter_current_round: int = 0
    reddit_current_round: int = 0
    twitter_simulated_hours: int = 0
    reddit_simulated_hours: int = 0

    # Per-platform activity flags and action totals.
    twitter_running: bool = False
    reddit_running: bool = False
    twitter_actions_count: int = 0
    reddit_actions_count: int = 0

    # Per-platform completion flags, set when a simulation_end event is observed in actions.jsonl.
    twitter_completed: bool = False
    reddit_completed: bool = False

    rounds: List[RoundSummary] = field(default_factory=list)

    # Newest-first buffer of recent actions, surfaced to the frontend for the live feed.
    recent_actions: List[AgentAction] = field(default_factory=list)
    max_recent_actions: int = 50

    # Timestamps are ISO-8601 strings produced by datetime.now().isoformat().
    started_at: Optional[str] = None
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
    completed_at: Optional[str] = None

    error: Optional[str] = None

    # Main subprocess PID, captured so the process can later be stopped.
    process_pid: Optional[int] = None

    def add_action(self, action: AgentAction):
        """Record ``action`` at the head of the recent buffer and bump the counters.

        The buffer is cut back to ``max_recent_actions`` entries. Any platform
        other than ``"twitter"`` is counted as reddit. ``updated_at`` is refreshed.
        """
        buffer = self.recent_actions
        buffer.insert(0, action)
        if len(buffer) > self.max_recent_actions:
            self.recent_actions = buffer[:self.max_recent_actions]

        counter_name = (
            "twitter_actions_count" if action.platform == "twitter" else "reddit_actions_count"
        )
        setattr(self, counter_name, getattr(self, counter_name) + 1)

        self.updated_at = datetime.now().isoformat()

    def to_dict(self) -> Dict[str, Any]:
        """Return the state as a plain dict, without the recent-actions buffer.

        Adds two derived entries: ``progress_percent`` (current round over total
        rounds, one decimal, with a zero total treated as 1) and
        ``total_actions_count`` (sum of both platform counters).
        """
        progress = round(self.current_round / max(self.total_rounds, 1) * 100, 1)
        combined_actions = self.twitter_actions_count + self.reddit_actions_count

        snapshot: Dict[str, Any] = {
            "simulation_id": self.simulation_id,
            "runner_status": self.runner_status.value,
            "current_round": self.current_round,
            "total_rounds": self.total_rounds,
            "simulated_hours": self.simulated_hours,
            "total_simulation_hours": self.total_simulation_hours,
            "progress_percent": progress,
        }
        # Per-platform round and simulated-time counters.
        snapshot.update({
            "twitter_current_round": self.twitter_current_round,
            "reddit_current_round": self.reddit_current_round,
            "twitter_simulated_hours": self.twitter_simulated_hours,
            "reddit_simulated_hours": self.reddit_simulated_hours,
            "twitter_running": self.twitter_running,
            "reddit_running": self.reddit_running,
            "twitter_completed": self.twitter_completed,
            "reddit_completed": self.reddit_completed,
            "twitter_actions_count": self.twitter_actions_count,
            "reddit_actions_count": self.reddit_actions_count,
            "total_actions_count": combined_actions,
        })
        snapshot.update({
            "started_at": self.started_at,
            "updated_at": self.updated_at,
            "completed_at": self.completed_at,
            "error": self.error,
            "process_pid": self.process_pid,
        })
        return snapshot

    def to_detail_dict(self) -> Dict[str, Any]:
        """Return ``to_dict()`` extended with the recent actions and the round count."""
        return {
            **self.to_dict(),
            "recent_actions": [entry.to_dict() for entry in self.recent_actions],
            "rounds_count": len(self.rounds),
        }
class SimulationRunner:
    """
    Simulation runner.

    Responsibilities:
    1. Run the OASIS simulation in a background subprocess.
    2. Parse the run logs and record each agent's actions.
    3. Provide real-time status query interfaces.
    4. Support pause/stop/resume operations.

    All state is kept in class-level dictionaries keyed by simulation_id and
    every method visible here is a classmethod, so the class is used directly
    rather than instantiated.
    """

    # Root directory holding one sub-directory per simulation (config, run_state.json, logs).
    RUN_STATE_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )
    # Directory containing the run_*_simulation.py entry scripts launched as subprocesses.
    SCRIPTS_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../scripts'
    )

    # In-memory caches of runtime state, processes, queues, monitor threads, and log file handles.
    # simulation_id -> cached state (filled by get_run_state / _save_run_state).
    _run_states: Dict[str, SimulationRunState] = {}
    # simulation_id -> live subprocess handle (set by start_simulation, removed by the monitor thread).
    _processes: Dict[str, subprocess.Popen] = {}
    # simulation_id -> queue created by start_simulation (removed by the monitor thread).
    _action_queues: Dict[str, Queue] = {}
    # simulation_id -> log-tailing monitor thread.
    _monitor_threads: Dict[str, threading.Thread] = {}
    # simulation_id -> open handle of simulation.log; closed by the monitor thread on exit.
    _stdout_files: Dict[str, Any] = {}
    # simulation_id -> stderr handle; start_simulation stores None because stderr is merged into stdout.
    _stderr_files: Dict[str, Any] = {}

    # Graph-memory-update flag per simulation_id.
    _graph_memory_enabled: Dict[str, bool] = {}
@classmethod
def get_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
    """Return the run state for ``simulation_id``, or ``None`` if there is none.

    The in-memory cache is consulted first; on a miss the on-disk snapshot is
    loaded and, when found, stored in the cache for subsequent calls.
    """
    try:
        return cls._run_states[simulation_id]
    except KeyError:
        pass

    loaded = cls._load_run_state(simulation_id)
    if loaded:
        cls._run_states[simulation_id] = loaded
    return loaded
@classmethod
def _load_run_state(cls, simulation_id: str) -> Optional[SimulationRunState]:
    """Rebuild a run state from ``<RUN_STATE_DIR>/<simulation_id>/run_state.json``.

    Missing keys fall back to the same defaults a fresh state would have.
    The recent-actions buffer is restored; the ``rounds`` list is not.

    Returns:
        The restored state, or ``None`` when the snapshot file does not exist
        or cannot be read/parsed (the failure is logged).
    """
    snapshot_path = os.path.join(cls.RUN_STATE_DIR, simulation_id, "run_state.json")
    if not os.path.exists(snapshot_path):
        return None

    # Snapshot keys grouped by the default used when a key is absent.
    zero_default_keys = (
        "current_round",
        "total_rounds",
        "simulated_hours",
        "total_simulation_hours",
        "twitter_current_round",
        "reddit_current_round",
        "twitter_simulated_hours",
        "reddit_simulated_hours",
        "twitter_actions_count",
        "reddit_actions_count",
    )
    false_default_keys = (
        "twitter_running",
        "reddit_running",
        "twitter_completed",
        "reddit_completed",
    )
    none_default_keys = ("started_at", "completed_at", "error", "process_pid")

    try:
        with open(snapshot_path, 'r', encoding='utf-8') as fh:
            data = json.load(fh)

        restored = {key: data.get(key, 0) for key in zero_default_keys}
        restored.update((key, data.get(key, False)) for key in false_default_keys)
        restored.update((key, data.get(key)) for key in none_default_keys)

        state = SimulationRunState(
            simulation_id=simulation_id,
            runner_status=RunnerStatus(data.get("runner_status", "idle")),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            **restored,
        )

        # Restore the recent-actions buffer in its stored order.
        for entry in data.get("recent_actions", []):
            state.recent_actions.append(
                AgentAction(
                    round_num=entry.get("round_num", 0),
                    timestamp=entry.get("timestamp", ""),
                    platform=entry.get("platform", ""),
                    agent_id=entry.get("agent_id", 0),
                    agent_name=entry.get("agent_name", ""),
                    action_type=entry.get("action_type", ""),
                    action_args=entry.get("action_args", {}),
                    result=entry.get("result"),
                    success=entry.get("success", True),
                )
            )
        return state
    except Exception as e:
        # Any read/parse/validation problem is reported and treated as "no state".
        logger.error(t("log.simulation_runner.m001", str=str(e)))
        return None
@classmethod
def _save_run_state(cls, state: SimulationRunState):
    """Write ``state`` to its ``run_state.json`` snapshot and refresh the cache.

    The simulation directory is created when missing. The snapshot is the
    state's detail dict (including the recent-actions buffer), written as
    indented UTF-8 JSON.
    """
    target_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
    os.makedirs(target_dir, exist_ok=True)
    snapshot_path = os.path.join(target_dir, "run_state.json")

    snapshot = state.to_detail_dict()
    with open(snapshot_path, 'w', encoding='utf-8') as fh:
        json.dump(snapshot, fh, ensure_ascii=False, indent=2)

    cls._run_states[state.simulation_id] = state
@classmethod
def start_simulation(
    cls,
    simulation_id: str,
    platform: str = "parallel",  # twitter / reddit / parallel
    max_rounds: int = None,  # Optional cap on simulation rounds (truncates overly long runs).
    enable_graph_memory_update: bool = False,  # Whether to push activity into the Zep graph.
    graph_id: str = None  # Zep graph ID (required when graph-memory updates are enabled).
) -> SimulationRunState:
    """
    Start the simulation in a background subprocess and begin monitoring it.

    All request validation happens before anything is persisted, so a
    rejected request never leaves a STARTING snapshot behind that would block
    later start attempts.

    Args:
        simulation_id: Simulation ID.
        platform: Platform to run (twitter/reddit/parallel). Any other value
            selects the parallel script.
        max_rounds: Optional cap on simulation rounds (truncates overly long runs).
        enable_graph_memory_update: Whether to push agent activity to the Zep graph in real time.
        graph_id: Zep graph ID (required when graph-memory updates are enabled).

    Returns:
        SimulationRunState

    Raises:
        ValueError: If the simulation is already starting/running, its
            configuration file is missing, graph-memory updates are requested
            without ``graph_id``, or the entry script does not exist.
        Exception: Anything raised while spawning the subprocess is re-raised
            after the state has been marked FAILED.
    """
    # Refuse to start a duplicate run for the same simulation_id.
    existing = cls.get_run_state(simulation_id)
    if existing and existing.runner_status in [RunnerStatus.RUNNING, RunnerStatus.STARTING]:
        raise ValueError(f"模拟已在运行中: {simulation_id}")

    # Load the simulation configuration written during preparation.
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    config_path = os.path.join(sim_dir, "simulation_config.json")
    if not os.path.exists(config_path):
        raise ValueError("模拟配置不存在,请先调用 /prepare 接口")
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)

    # Compute total rounds from time-window settings.
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = int(total_hours * 60 / minutes_per_round)

    # If a cap was provided, clamp total_rounds.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            logger.info(t("log.simulation_runner.m002", original_rounds=original_rounds, total_rounds=total_rounds, max_rounds=max_rounds))

    # Validate the remaining inputs BEFORE persisting a STARTING state; a
    # STARTING snapshot left behind by a rejected request would make every
    # later call fail the duplicate-run check above.
    if enable_graph_memory_update and not graph_id:
        raise ValueError("启用图谱记忆更新时必须提供 graph_id")

    # Pick the entry script (lives in backend/scripts/) based on the requested platform.
    if platform == "twitter":
        script_name, runs_twitter, runs_reddit = "run_twitter_simulation.py", True, False
    elif platform == "reddit":
        script_name, runs_twitter, runs_reddit = "run_reddit_simulation.py", False, True
    else:
        script_name, runs_twitter, runs_reddit = "run_parallel_simulation.py", True, True
    script_path = os.path.join(cls.SCRIPTS_DIR, script_name)
    if not os.path.exists(script_path):
        raise ValueError(f"脚本不存在: {script_path}")

    state = SimulationRunState(
        simulation_id=simulation_id,
        runner_status=RunnerStatus.STARTING,
        total_rounds=total_rounds,
        total_simulation_hours=total_hours,
        started_at=datetime.now().isoformat(),
    )
    cls._save_run_state(state)

    # Spin up a graph-memory updater if requested. A failure here is logged
    # and the run continues without graph-memory updates.
    if enable_graph_memory_update:
        try:
            ZepGraphMemoryManager.create_updater(simulation_id, graph_id)
            cls._graph_memory_enabled[simulation_id] = True
            logger.info(t("log.simulation_runner.m003", simulation_id=simulation_id, graph_id=graph_id))
        except Exception as e:
            logger.error(t("log.simulation_runner.m004", e=e))
            cls._graph_memory_enabled[simulation_id] = False
    else:
        cls._graph_memory_enabled[simulation_id] = False

    state.twitter_running = runs_twitter
    state.reddit_running = runs_reddit

    cls._action_queues[simulation_id] = Queue()

    main_log_file = None
    process = None
    try:
        # Log layout written by the subprocess:
        #   twitter/actions.jsonl - Twitter action log
        #   reddit/actions.jsonl  - Reddit action log
        #   simulation.log        - main-process log
        cmd = [
            sys.executable,
            script_path,
            "--config", config_path,
        ]
        if max_rounds is not None and max_rounds > 0:
            cmd.extend(["--max-rounds", str(max_rounds)])

        # Redirect stdout/stderr to a file so a full pipe buffer cannot block the subprocess.
        main_log_path = os.path.join(sim_dir, "simulation.log")
        main_log_file = open(main_log_path, 'w', encoding='utf-8')

        # Force UTF-8 in the child so third-party libs (e.g. OASIS) that open files without an
        # explicit encoding work correctly on Windows.
        env = os.environ.copy()
        env['PYTHONUTF8'] = '1'
        env['PYTHONIOENCODING'] = 'utf-8'

        # cwd is the simulation directory so generated artifacts (databases, etc.) land there.
        # start_new_session=True creates a fresh process group so os.killpg can terminate the
        # entire tree on shutdown.
        process = subprocess.Popen(
            cmd,
            cwd=sim_dir,
            stdout=main_log_file,
            stderr=subprocess.STDOUT,
            text=True,
            encoding='utf-8',
            bufsize=1,
            env=env,
            start_new_session=True,
        )

        # Retain the log file handle so it can be closed after the subprocess exits.
        cls._stdout_files[simulation_id] = main_log_file
        cls._stderr_files[simulation_id] = None

        state.process_pid = process.pid
        state.runner_status = RunnerStatus.RUNNING
        cls._processes[simulation_id] = process
        cls._save_run_state(state)

        # Capture locale before spawning monitor thread
        current_locale = get_locale()

        # Spawn the log-tailing monitor thread.
        monitor_thread = threading.Thread(
            target=cls._monitor_simulation,
            args=(simulation_id, current_locale),
            daemon=True
        )
        monitor_thread.start()
        cls._monitor_threads[simulation_id] = monitor_thread
        logger.info(t("log.simulation_runner.m005", simulation_id=simulation_id, process=process.pid, platform=platform))
    except Exception as e:
        if process is None:
            # The child never started, so the monitor thread that normally
            # releases these resources will never run: release them here.
            if main_log_file is not None:
                try:
                    main_log_file.close()
                except OSError:
                    pass
            cls._action_queues.pop(simulation_id, None)
            if cls._graph_memory_enabled.pop(simulation_id, False):
                try:
                    ZepGraphMemoryManager.stop_updater(simulation_id)
                except Exception as stop_error:
                    logger.error(t("log.simulation_runner.m010", e=stop_error))
            state.twitter_running = False
            state.reddit_running = False
        state.runner_status = RunnerStatus.FAILED
        state.error = str(e)
        cls._save_run_state(state)
        raise

    return state
@classmethod
def _monitor_simulation(cls, simulation_id: str, locale: str = 'zh'):
    """Monitor the simulation process and tail its per-platform action logs.

    Intended to run in a background thread. While the subprocess is alive the
    two ``actions.jsonl`` files are polled every two seconds and the state is
    saved after each poll. When the process exits, the final status is
    derived from its exit code, unless a stop was requested through
    ``stop_simulation``: that method kills the child, so the resulting
    non-zero exit code is expected and must not be reported as a failure.

    Args:
        simulation_id: Simulation ID.
        locale: Locale to activate for this thread's log messages.
    """
    set_locale(locale)

    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    twitter_actions_log = os.path.join(sim_dir, "twitter", "actions.jsonl")
    reddit_actions_log = os.path.join(sim_dir, "reddit", "actions.jsonl")

    process = cls._processes.get(simulation_id)
    state = cls.get_run_state(simulation_id)
    if not process or not state:
        return

    twitter_position = 0
    reddit_position = 0

    try:
        while process.poll() is None:
            if os.path.exists(twitter_actions_log):
                twitter_position = cls._read_action_log(
                    twitter_actions_log, twitter_position, state, "twitter"
                )
            if os.path.exists(reddit_actions_log):
                reddit_position = cls._read_action_log(
                    reddit_actions_log, reddit_position, state, "reddit"
                )
            cls._save_run_state(state)
            time.sleep(2)

        # Drain any log lines written between the last poll and the process exit.
        if os.path.exists(twitter_actions_log):
            cls._read_action_log(twitter_actions_log, twitter_position, state, "twitter")
        if os.path.exists(reddit_actions_log):
            cls._read_action_log(reddit_actions_log, reddit_position, state, "reddit")

        exit_code = process.returncode

        # stop_simulation sets STOPPING before terminating the child and
        # STOPPED afterwards; in both cases it owns the final status, so the
        # exit code of a deliberately killed process is not interpreted here.
        stop_requested = state.runner_status in (RunnerStatus.STOPPING, RunnerStatus.STOPPED)

        if not stop_requested:
            if exit_code == 0:
                state.runner_status = RunnerStatus.COMPLETED
                state.completed_at = datetime.now().isoformat()
                logger.info(t("log.simulation_runner.m006", simulation_id=simulation_id))
            else:
                state.runner_status = RunnerStatus.FAILED
                # Pull the tail of the main log so the failure context is surfaced in state.error.
                main_log_path = os.path.join(sim_dir, "simulation.log")
                error_info = ""
                try:
                    if os.path.exists(main_log_path):
                        with open(main_log_path, 'r', encoding='utf-8') as f:
                            error_info = f.read()[-2000:]  # keep only the last 2000 chars
                except Exception:
                    # Best effort only: the exit code alone is still reported below.
                    pass
                state.error = f"进程退出码: {exit_code}, 错误: {error_info}"
                logger.error(t("log.simulation_runner.m007", simulation_id=simulation_id, state=state.error))

        state.twitter_running = False
        state.reddit_running = False
        cls._save_run_state(state)
    except Exception as e:
        logger.error(t("log.simulation_runner.m008", simulation_id=simulation_id, str=str(e)))
        state.runner_status = RunnerStatus.FAILED
        state.error = str(e)
        cls._save_run_state(state)
    finally:
        # Tear down the graph-memory updater, if we started one.
        if cls._graph_memory_enabled.get(simulation_id, False):
            try:
                ZepGraphMemoryManager.stop_updater(simulation_id)
                logger.info(t("log.simulation_runner.m009", simulation_id=simulation_id))
            except Exception as e:
                logger.error(t("log.simulation_runner.m010", e=e))
            cls._graph_memory_enabled.pop(simulation_id, None)

        cls._processes.pop(simulation_id, None)
        cls._action_queues.pop(simulation_id, None)

        # Close the retained log file handles (the stderr slot may hold None).
        for retained in (cls._stdout_files, cls._stderr_files):
            handle = retained.pop(simulation_id, None)
            if handle:
                try:
                    handle.close()
                except Exception:
                    pass
@classmethod
def _read_action_log(
    cls,
    log_path: str,
    position: int,
    state: SimulationRunState,
    platform: str
) -> int:
    """
    Read new entries from a per-platform action log and apply them to ``state``.

    The file is read in binary mode so ``position`` is a true byte offset.
    Each newline-terminated line is consumed exactly once, even when it is
    malformed or when applying it fails, so one bad record can never cause
    earlier records to be re-read and double-counted. A trailing fragment
    without a newline is consumed only if it already parses as JSON;
    otherwise it is assumed to be a line the writer has not finished yet and
    is left for the next call.

    Args:
        log_path: Path to the action-log file.
        position: Byte offset where the previous read finished.
        state: Run-state object to mutate.
        platform: Platform name (twitter/reddit).

    Returns:
        New byte offset after this read (``position`` itself if the file
        could not be read).
    """
    graph_updater = None
    if cls._graph_memory_enabled.get(state.simulation_id, False):
        graph_updater = ZepGraphMemoryManager.get_updater(state.simulation_id)

    try:
        with open(log_path, 'rb') as f:
            f.seek(position)
            pending = f.read()
    except Exception as e:
        logger.warning(t("log.simulation_runner.m014", log_path=log_path, e=e))
        return position

    consumed = 0
    while consumed < len(pending):
        newline_at = pending.find(b"\n", consumed)
        terminated = newline_at != -1
        line_end = newline_at + 1 if terminated else len(pending)
        raw_line = pending[consumed:line_end].strip()

        if raw_line:
            try:
                record = json.loads(raw_line.decode('utf-8'))
            except ValueError:
                # ValueError covers both json.JSONDecodeError and UnicodeDecodeError.
                if not terminated:
                    # Half-written line: retry once the writer has finished it.
                    break
                record = None

            # Only JSON objects are meaningful records; anything else is skipped.
            if isinstance(record, dict):
                try:
                    cls._apply_log_record(record, state, platform, graph_updater)
                except Exception as e:
                    # Report and move on; the offset still advances past this line.
                    logger.warning(t("log.simulation_runner.m014", log_path=log_path, e=e))

        consumed = line_end

    return position + consumed

@classmethod
def _apply_log_record(
    cls,
    action_data: Dict[str, Any],
    state: SimulationRunState,
    platform: str,
    graph_updater: Any
) -> None:
    """Apply one parsed action-log record (event or agent action) to ``state``."""
    # Event records (simulation_start/end, round_end, ...) are routed here.
    if "event_type" in action_data:
        event_type = action_data.get("event_type")

        # simulation_end means the platform finished its run.
        if event_type == "simulation_end":
            if platform == "twitter":
                state.twitter_completed = True
                state.twitter_running = False
                logger.info(t("log.simulation_runner.m011", state=state.simulation_id, action_data=action_data.get('total_rounds'), action_data_2=action_data.get('total_actions')))
            elif platform == "reddit":
                state.reddit_completed = True
                state.reddit_running = False
                logger.info(t("log.simulation_runner.m012", state=state.simulation_id, action_data=action_data.get('total_rounds'), action_data_2=action_data.get('total_actions')))

            # Mark the run as completed once every enabled platform has reported
            # simulation_end. Single-platform runs only need that one.
            if cls._check_all_platforms_completed(state):
                state.runner_status = RunnerStatus.COMPLETED
                state.completed_at = datetime.now().isoformat()
                logger.info(t("log.simulation_runner.m013", state=state.simulation_id))

        # Round counters come from round_end events.
        elif event_type == "round_end":
            round_num = action_data.get("round", 0)
            simulated_hours = action_data.get("simulated_hours", 0)

            if platform == "twitter":
                if round_num > state.twitter_current_round:
                    state.twitter_current_round = round_num
                state.twitter_simulated_hours = simulated_hours
            elif platform == "reddit":
                if round_num > state.reddit_current_round:
                    state.reddit_current_round = round_num
                state.reddit_simulated_hours = simulated_hours

            # Overall counters track the max across enabled platforms.
            if round_num > state.current_round:
                state.current_round = round_num
            state.simulated_hours = max(state.twitter_simulated_hours, state.reddit_simulated_hours)

        # Events never produce an AgentAction.
        return

    action = AgentAction(
        round_num=action_data.get("round", 0),
        timestamp=action_data.get("timestamp", datetime.now().isoformat()),
        platform=platform,
        agent_id=action_data.get("agent_id", 0),
        agent_name=action_data.get("agent_name", ""),
        action_type=action_data.get("action_type", ""),
        action_args=action_data.get("action_args", {}),
        result=action_data.get("result"),
        success=action_data.get("success", True),
    )
    state.add_action(action)

    if action.round_num and action.round_num > state.current_round:
        state.current_round = action.round_num

    # Forward the activity to the Zep graph when the updater is enabled.
    if graph_updater:
        graph_updater.add_activity_from_dict(action_data, platform)
@classmethod
def _check_all_platforms_completed(cls, state: SimulationRunState) -> bool:
    """
    Tell whether every enabled platform has finished its simulation.

    A platform counts as enabled when its ``actions.jsonl`` file exists in
    the simulation directory.

    Returns:
        True when at least one platform is enabled and none of the enabled
        platforms is still incomplete; False otherwise.
    """
    base_dir = os.path.join(cls.RUN_STATE_DIR, state.simulation_id)
    completion_flags = (
        ("twitter", state.twitter_completed),
        ("reddit", state.reddit_completed),
    )

    enabled_platforms = 0
    for platform_name, finished in completion_flags:
        # File presence is our enabled-platform signal.
        if not os.path.exists(os.path.join(base_dir, platform_name, "actions.jsonl")):
            continue
        if not finished:
            return False
        enabled_platforms += 1

    # With no enabled platform there is nothing that could have completed.
    return enabled_platforms > 0
@classmethod
def _terminate_process(cls, process: subprocess.Popen, simulation_id: str, timeout: int = 10):
    """
    Terminate a process together with its children, on Unix or Windows.

    Both branches try a graceful request first and escalate to a forced kill
    when the process is still alive after ``timeout`` seconds.

    Args:
        process: Process to terminate.
        simulation_id: Simulation ID (used for log messages).
        timeout: Seconds to wait for graceful exit before escalating.

    Raises:
        ProcessLookupError: On Unix, propagated from ``os.getpgid``/``os.killpg``
            if the process no longer exists.
        subprocess.TimeoutExpired: On Unix, if the process is still alive five
            seconds after SIGKILL.
    """
    if not IS_WINDOWS:
        # Unix: signal the whole process group. The child is launched with
        # start_new_session=True, which gives it a process group of its own.
        group_id = os.getpgid(process.pid)
        logger.info(t("log.simulation_runner.m018", simulation_id=simulation_id, pgid=group_id))

        # SIGTERM first to allow graceful shutdown.
        os.killpg(group_id, signal.SIGTERM)
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Escalate to SIGKILL on timeout.
            logger.warning(t("log.simulation_runner.m019", simulation_id=simulation_id))
            os.killpg(group_id, signal.SIGKILL)
            process.wait(timeout=5)
        return

    # Windows: taskkill /T tears down the whole process tree, /F escalates to a hard kill.
    logger.info(t("log.simulation_runner.m015", simulation_id=simulation_id, process=process.pid))
    graceful_cmd = ['taskkill', '/PID', str(process.pid), '/T']
    forced_cmd = ['taskkill', '/F', '/PID', str(process.pid), '/T']
    try:
        # Graceful termination first.
        subprocess.run(graceful_cmd, capture_output=True, timeout=5)
        try:
            process.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            # Force kill the tree.
            logger.warning(t("log.simulation_runner.m016", simulation_id=simulation_id))
            subprocess.run(forced_cmd, capture_output=True, timeout=5)
            process.wait(timeout=5)
    except Exception as e:
        # taskkill itself failed (or the forced kill timed out): fall back to the Popen handle.
        logger.warning(t("log.simulation_runner.m017", e=e))
        process.terminate()
        try:
            process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            process.kill()
@classmethod
def stop_simulation(cls, simulation_id: str) -> SimulationRunState:
    """Stop a running or paused simulation and record it as STOPPED.

    The state goes through STOPPING while the subprocess is being terminated
    and is saved as STOPPED afterwards; the graph-memory updater, when one is
    active for this simulation, is shut down as well.

    Raises:
        ValueError: If no state exists for ``simulation_id`` or the simulation
            is neither RUNNING nor PAUSED.
    """
    state = cls.get_run_state(simulation_id)
    if not state:
        raise ValueError(f"模拟不存在: {simulation_id}")

    stoppable_statuses = (RunnerStatus.RUNNING, RunnerStatus.PAUSED)
    if state.runner_status not in stoppable_statuses:
        raise ValueError(f"模拟未在运行: {simulation_id}, status={state.runner_status}")

    state.runner_status = RunnerStatus.STOPPING
    cls._save_run_state(state)

    child = cls._processes.get(simulation_id)
    if child and child.poll() is None:
        try:
            cls._terminate_process(child, simulation_id)
        except ProcessLookupError:
            # The process has already exited.
            pass
        except Exception as e:
            logger.error(t("log.simulation_runner.m020", simulation_id=simulation_id, e=e))
            # Fall back to direct termination on the Popen handle.
            try:
                child.terminate()
                child.wait(timeout=5)
            except Exception:
                child.kill()

    state.runner_status = RunnerStatus.STOPPED
    state.twitter_running = False
    state.reddit_running = False
    state.completed_at = datetime.now().isoformat()
    cls._save_run_state(state)

    # Tear down the graph-memory updater, if any.
    updater_active = cls._graph_memory_enabled.get(simulation_id, False)
    if updater_active:
        try:
            ZepGraphMemoryManager.stop_updater(simulation_id)
            logger.info(t("log.simulation_runner.m021", simulation_id=simulation_id))
        except Exception as e:
            logger.error(t("log.simulation_runner.m022", e=e))
        cls._graph_memory_enabled.pop(simulation_id, None)

    logger.info(t("log.simulation_runner.m023", simulation_id=simulation_id))
    return state
@classmethod
def _read_actions_from_file(
    cls,
    file_path: str,
    default_platform: Optional[str] = None,
    platform_filter: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List[AgentAction]:
    """
    Read agent actions from a single action-log (JSON Lines) file.

    Blank lines, lines that are not valid JSON, JSON values that are not
    objects, event records (those carrying ``event_type``) and records
    without an ``agent_id`` are all skipped.

    Args:
        file_path: Path to the action-log file.
        default_platform: Platform to assume when a record has no `platform` field.
        platform_filter: Optional platform filter.
        agent_id: Optional agent-id filter.
        round_num: Optional round-number filter.

    Returns:
        Matching actions in file order; an empty list if the file does not exist.
    """
    if not os.path.exists(file_path):
        return []

    actions = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            line = line.strip()
            if not line:
                continue

            try:
                data = json.loads(line)
            except json.JSONDecodeError:
                continue

            # A line such as `null` or `42` is valid JSON but not a record;
            # without this guard the membership tests below raise TypeError.
            if not isinstance(data, dict):
                continue

            # Skip event records (simulation_start, round_start, round_end, ...).
            if "event_type" in data:
                continue
            # Skip records without an agent_id (non-agent actions).
            if "agent_id" not in data:
                continue

            # Prefer the record's own platform; fall back to the default for legacy entries.
            record_platform = data.get("platform") or default_platform or ""

            if platform_filter and record_platform != platform_filter:
                continue
            if agent_id is not None and data.get("agent_id") != agent_id:
                continue
            if round_num is not None and data.get("round") != round_num:
                continue

            actions.append(AgentAction(
                round_num=data.get("round", 0),
                timestamp=data.get("timestamp", ""),
                platform=record_platform,
                agent_id=data.get("agent_id", 0),
                agent_name=data.get("agent_name", ""),
                action_type=data.get("action_type", ""),
                action_args=data.get("action_args", {}),
                result=data.get("result"),
                success=data.get("success", True),
            ))

    return actions
@classmethod
def get_all_actions(
    cls,
    simulation_id: str,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List[AgentAction]:
    """
    Collect the full, unpaginated action history of a simulation.

    The per-platform logs (``twitter/actions.jsonl``, ``reddit/actions.jsonl``)
    are read first; when they yield nothing, the legacy single
    ``actions.jsonl`` in the simulation directory is tried instead.

    Args:
        simulation_id: Simulation ID.
        platform: Optional platform filter (twitter/reddit).
        agent_id: Optional agent filter.
        round_num: Optional round filter.

    Returns:
        Full action list, sorted by timestamp with newest first.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    filters = {
        "platform_filter": platform,
        "agent_id": agent_id,
        "round_num": round_num,
    }

    collected = []
    for platform_name in ("twitter", "reddit"):
        # A platform filter restricts which per-platform file is opened at all.
        if platform and platform != platform_name:
            continue
        # The directory name supplies the platform for records that lack one.
        collected.extend(cls._read_actions_from_file(
            os.path.join(sim_dir, platform_name, "actions.jsonl"),
            default_platform=platform_name,
            **filters
        ))

    if not collected:
        # Legacy layout: a single file whose records carry their own platform field.
        collected = cls._read_actions_from_file(
            os.path.join(sim_dir, "actions.jsonl"),
            default_platform=None,
            **filters
        )

    # Newest-first by timestamp.
    collected.sort(key=lambda entry: entry.timestamp, reverse=True)
    return collected
@classmethod
def get_actions(
    cls,
    simulation_id: str,
    limit: int = 100,
    offset: int = 0,
    platform: Optional[str] = None,
    agent_id: Optional[int] = None,
    round_num: Optional[int] = None
) -> List[AgentAction]:
    """
    Return one page of the action history.

    The filtered, newest-first history is built by ``get_all_actions`` and
    then sliced.

    Args:
        simulation_id: Simulation ID.
        limit: Maximum number of actions to return.
        offset: Offset into the sorted result list.
        platform: Optional platform filter.
        agent_id: Optional agent filter.
        round_num: Optional round filter.

    Returns:
        The actions at positions ``offset`` to ``offset + limit`` (exclusive).
    """
    history = cls.get_all_actions(
        simulation_id=simulation_id,
        platform=platform,
        agent_id=agent_id,
        round_num=round_num
    )
    page = slice(offset, offset + limit)
    return history[page]
@classmethod
def get_timeline(
    cls,
    simulation_id: str,
    start_round: int = 0,
    end_round: Optional[int] = None
) -> List[Dict[str, Any]]:
    """
    Return a per-round timeline summary for the simulation.

    The summary is built from the complete action history, so long runs are
    not truncated. ``first_action_time`` / ``last_action_time`` are the
    smallest / largest timestamp strings seen in the round, independent of
    the order in which actions are visited.

    Args:
        simulation_id: Simulation ID.
        start_round: First round to include (inclusive).
        end_round: Last round to include (inclusive); None means no upper bound.

    Returns:
        One summary entry per round, ordered by round number.
    """
    actions = cls.get_all_actions(simulation_id)

    # Group actions by round.
    rounds: Dict[int, Dict[str, Any]] = {}
    for action in actions:
        round_num = action.round_num
        if round_num < start_round:
            continue
        if end_round is not None and round_num > end_round:
            continue

        r = rounds.get(round_num)
        if r is None:
            r = rounds[round_num] = {
                "round_num": round_num,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "active_agents": set(),
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        # Anything that is not twitter is counted as reddit.
        if action.platform == "twitter":
            r["twitter_actions"] += 1
        else:
            r["reddit_actions"] += 1

        r["active_agents"].add(action.agent_id)
        r["action_types"][action.action_type] = r["action_types"].get(action.action_type, 0) + 1

        # The history arrives newest-first, so the bounds are tracked by
        # comparison rather than by visiting order.
        if action.timestamp < r["first_action_time"]:
            r["first_action_time"] = action.timestamp
        if action.timestamp > r["last_action_time"]:
            r["last_action_time"] = action.timestamp

    # Materialise into a list sorted by round number.
    result = []
    for round_num in sorted(rounds):
        r = rounds[round_num]
        result.append({
            "round_num": round_num,
            "twitter_actions": r["twitter_actions"],
            "reddit_actions": r["reddit_actions"],
            "total_actions": r["twitter_actions"] + r["reddit_actions"],
            "active_agents_count": len(r["active_agents"]),
            "active_agents": list(r["active_agents"]),
            "action_types": r["action_types"],
            "first_action_time": r["first_action_time"],
            "last_action_time": r["last_action_time"],
        })

    return result
@classmethod
def get_agent_stats(cls, simulation_id: str) -> List[Dict[str, Any]]:
    """
    Return per-agent statistics for the simulation.

    The statistics are computed over the complete action history, so long
    runs are not truncated. ``first_action_time`` / ``last_action_time`` are
    the smallest / largest timestamp strings recorded for the agent,
    independent of the order in which actions are visited. ``agent_name`` is
    taken from the first action encountered for that agent.

    Returns:
        Per-agent statistics, sorted by total action count (descending).
    """
    actions = cls.get_all_actions(simulation_id)

    agent_stats: Dict[int, Dict[str, Any]] = {}
    for action in actions:
        stats = agent_stats.get(action.agent_id)
        if stats is None:
            stats = agent_stats[action.agent_id] = {
                "agent_id": action.agent_id,
                "agent_name": action.agent_name,
                "total_actions": 0,
                "twitter_actions": 0,
                "reddit_actions": 0,
                "action_types": {},
                "first_action_time": action.timestamp,
                "last_action_time": action.timestamp,
            }

        stats["total_actions"] += 1

        # Anything that is not twitter is counted as reddit.
        if action.platform == "twitter":
            stats["twitter_actions"] += 1
        else:
            stats["reddit_actions"] += 1

        stats["action_types"][action.action_type] = stats["action_types"].get(action.action_type, 0) + 1

        # The history arrives newest-first, so the bounds are tracked by
        # comparison rather than by visiting order.
        if action.timestamp < stats["first_action_time"]:
            stats["first_action_time"] = action.timestamp
        if action.timestamp > stats["last_action_time"]:
            stats["last_action_time"] = action.timestamp

    return sorted(agent_stats.values(), key=lambda entry: entry["total_actions"], reverse=True)
@classmethod
def cleanup_simulation_logs(cls, simulation_id: str) -> Dict[str, Any]:
    """
    Clean up the simulation's run logs so the simulation can be force-restarted.

    Deletes the following files (paths relative to the simulation directory):
        - run_state.json
        - simulation.log
        - stdout.log / stderr.log
        - twitter_simulation.db / reddit_simulation.db (simulation databases)
        - env_status.json (environment status)
        - twitter/actions.jsonl
        - reddit/actions.jsonl

    Nothing else in the directory is touched, so simulation_config.json and
    the profile files are preserved. The in-memory run state cached for the
    simulation is dropped as well.

    Args:
        simulation_id: Simulation ID.

    Returns:
        Dict with ``success`` (True when no deletion failed), ``cleaned_files``
        (relative names that were removed) and ``errors`` (list of messages,
        or None when there were none). When the simulation directory does not
        exist the dict additionally carries an explanatory ``message``.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        # Same keys as the regular result so callers can read them blindly.
        return {
            "success": True,
            "message": "模拟目录不存在,无需清理",
            "cleaned_files": [],
            "errors": None,
        }

    # Relative names use "/" as separator; they double as the labels reported
    # in ``cleaned_files`` and in error messages.
    targets = [
        "run_state.json",
        "simulation.log",
        "stdout.log",
        "stderr.log",
        "twitter_simulation.db",   # Twitter platform database.
        "reddit_simulation.db",    # Reddit platform database.
        "env_status.json",         # Environment-status file.
        "twitter/actions.jsonl",   # Per-platform action logs.
        "reddit/actions.jsonl",
    ]

    cleaned_files = []
    errors = []
    for relative_name in targets:
        file_path = os.path.join(sim_dir, *relative_name.split("/"))
        try:
            # Remove directly instead of exists()+remove() so a file that
            # vanishes between the two calls is not reported as a failure.
            os.remove(file_path)
        except (FileNotFoundError, NotADirectoryError):
            continue
        except Exception as e:
            errors.append(f"删除 {relative_name} 失败: {str(e)}")
        else:
            cleaned_files.append(relative_name)

    # Drop the in-memory run state for this simulation.
    cls._run_states.pop(simulation_id, None)

    logger.info(t("log.simulation_runner.m024", simulation_id=simulation_id, cleaned_files=cleaned_files))

    return {
        "success": len(errors) == 0,
        "cleaned_files": cleaned_files,
        "errors": errors if errors else None
    }
# Guard flag: cleanup_all_simulations sets it to True on its first call and returns
# immediately on every later call, so the cleanup runs at most once per process.
_cleanup_done = False
@classmethod
def cleanup_all_simulations(cls):
    """
    Terminate every simulation subprocess that is still running.

    Meant to run at server shutdown so no child process is left behind. The
    call is one-shot: the ``_cleanup_done`` flag makes every later call a no-op.

    For each live process this terminates it, marks its run state as stopped
    and flips the ``status`` field of the simulation's ``state.json`` to
    ``"stopped"``. Afterwards retained stdout/stderr handles are closed and
    the in-memory bookkeeping is emptied.
    """
    if cls._cleanup_done:
        return
    cls._cleanup_done = True

    # Nothing tracked: leave silently, without any "shutting down" log line.
    if not (cls._processes or cls._graph_memory_enabled):
        return

    logger.info(t("log.simulation_runner.m025"))

    # Graph-memory updaters go first (stop_all does its own logging).
    try:
        ZepGraphMemoryManager.stop_all()
    except Exception as e:
        logger.error(t("log.simulation_runner.m026", e=e))
    cls._graph_memory_enabled.clear()

    # Iterate over a snapshot; the live map is cleared further down.
    for simulation_id, process in list(cls._processes.items()):
        try:
            if process.poll() is not None:
                # Already exited; nothing to terminate or record.
                continue

            logger.info(t("log.simulation_runner.m027", simulation_id=simulation_id, process=process.pid))

            try:
                cls._terminate_process(process, simulation_id, timeout=5)
            except (ProcessLookupError, OSError):
                # The process may already be gone; fall back to direct termination.
                try:
                    process.terminate()
                    process.wait(timeout=3)
                except Exception:
                    process.kill()

            # Persist the stopped status in run_state.json for external readers.
            run_state = cls.get_run_state(simulation_id)
            if run_state:
                run_state.runner_status = RunnerStatus.STOPPED
                run_state.twitter_running = False
                run_state.reddit_running = False
                run_state.completed_at = datetime.now().isoformat()
                run_state.error = "服务器关闭,模拟被终止"
                cls._save_run_state(run_state)

            # Mirror the stopped status into the project-level state.json.
            try:
                state_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "state.json")
                logger.info(t("log.simulation_runner.m028", state_file=state_file))
                if not os.path.exists(state_file):
                    logger.warning(t("log.simulation_runner.m030", state_file=state_file))
                else:
                    with open(state_file, 'r', encoding='utf-8') as f:
                        state_data = json.load(f)
                    state_data['status'] = 'stopped'
                    state_data['updated_at'] = datetime.now().isoformat()
                    with open(state_file, 'w', encoding='utf-8') as f:
                        json.dump(state_data, f, indent=2, ensure_ascii=False)
                    logger.info(t("log.simulation_runner.m029", simulation_id=simulation_id))
            except Exception as state_err:
                logger.warning(t("log.simulation_runner.m031", simulation_id=simulation_id, state_err=state_err))

        except Exception as e:
            logger.error(t("log.simulation_runner.m032", simulation_id=simulation_id, e=e))

    # Close retained log handles: stdout map first, then stderr. Failures
    # while closing are ignored on purpose (best effort during shutdown).
    for handle_map in (cls._stdout_files, cls._stderr_files):
        for file_handle in list(handle_map.values()):
            try:
                if file_handle:
                    file_handle.close()
            except Exception:
                pass
        handle_map.clear()

    # Drop in-memory bookkeeping.
    cls._processes.clear()
    cls._action_queues.clear()

    logger.info(t("log.simulation_runner.m033"))
@classmethod
def register_cleanup(cls):
    """
    Install the shutdown hooks that tear down all simulation subprocesses.

    Intended to be called once at Flask application startup. Registers
    ``cleanup_all_simulations`` with ``atexit`` and installs a handler for
    SIGTERM, SIGINT and (where available) SIGHUP that runs the cleanup and
    then chains to the previously installed handler.

    The module-level ``_cleanup_registered`` flag makes repeated calls no-ops.
    """
    global _cleanup_registered

    if _cleanup_registered:
        return

    # In Flask debug mode the reloader spawns a child process that actually runs the app
    # (signaled by WERKZEUG_RUN_MAIN=true). Outside debug mode the variable is unset and
    # the hooks must still be registered.
    werkzeug_run_main = os.environ.get('WERKZEUG_RUN_MAIN')
    in_reloader_child = werkzeug_run_main == 'true'
    debug_mode = os.environ.get('FLASK_DEBUG') == '1' or werkzeug_run_main is not None

    if debug_mode and not in_reloader_child:
        # Debug-mode parent process: mark as done so it never retries, but install nothing.
        _cleanup_registered = True
        return

    # Remember the handlers currently installed so the new handler can chain to them.
    original_sigint = signal.getsignal(signal.SIGINT)
    original_sigterm = signal.getsignal(signal.SIGTERM)
    # SIGHUP exists only on Unix (macOS/Linux); Windows does not have it.
    has_sighup = hasattr(signal, 'SIGHUP')
    original_sighup = signal.getsignal(signal.SIGHUP) if has_sighup else None

    def cleanup_handler(signum=None, frame=None):
        """Clean up simulations, then hand the signal to the previous handler."""
        # Only log when there is actually something to clean up.
        if cls._processes or cls._graph_memory_enabled:
            logger.info(t("log.simulation_runner.m034", signum=signum))
        cls.cleanup_all_simulations()

        if has_sighup and signum == signal.SIGHUP:
            # SIGHUP is sent when the terminal is closed.
            if callable(original_sighup):
                original_sighup(signum, frame)
            else:
                # No callable predecessor: exit cleanly.
                sys.exit(0)
            return

        if signum == signal.SIGINT:
            previous = original_sigint
        elif signum == signal.SIGTERM:
            previous = original_sigterm
        else:
            previous = None

        if not callable(previous):
            # Non-callable predecessor (e.g. SIG_DFL) or an unexpected signal:
            # interrupt the main flow so the process can exit.
            raise KeyboardInterrupt
        previous(signum, frame)

    # atexit acts as the fallback when no signal handler fires.
    atexit.register(cls.cleanup_all_simulations)

    # SIGTERM (default `kill`), SIGINT (Ctrl+C), then SIGHUP (terminal close, Unix only).
    handled_signals = [signal.SIGTERM, signal.SIGINT]
    if has_sighup:
        handled_signals.append(signal.SIGHUP)

    try:
        for signal_number in handled_signals:
            signal.signal(signal_number, cleanup_handler)
    except ValueError:
        # signal.signal only works from the main thread; rely on atexit instead.
        logger.warning(t("log.simulation_runner.m035"))

    _cleanup_registered = True
@classmethod
def get_running_simulations(cls) -> List[str]:
    """Return the IDs of all simulations whose subprocess has not exited yet."""
    # poll() returns None while the subprocess is still alive.
    return [
        sim_id
        for sim_id, process in cls._processes.items()
        if process.poll() is None
    ]
# ============== Interview feature ==============
@classmethod
def check_env_alive(cls, simulation_id: str) -> bool:
    """
    Tell whether the simulation environment is alive (able to take interview commands).

    Args:
        simulation_id: Simulation ID.

    Returns:
        False when the simulation directory does not exist; otherwise whatever
        the IPC client's own liveness check reports.
    """
    env_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if os.path.exists(env_dir):
        return SimulationIPCClient(env_dir).check_env_alive()
    return False
@classmethod
def get_env_status_detail(cls, simulation_id: str) -> Dict[str, Any]:
    """
    Return detailed status info for the simulation environment.

    Reads ``env_status.json`` from the simulation directory. Whenever the file
    is missing, unreadable, not valid JSON/UTF-8, or does not contain a JSON
    object, the default "stopped" status is returned instead of raising.

    Args:
        simulation_id: Simulation ID.

    Returns:
        Status dict containing status, twitter_available, reddit_available, timestamp.
    """
    status_file = os.path.join(cls.RUN_STATE_DIR, simulation_id, "env_status.json")

    default_status = {
        "status": "stopped",
        "twitter_available": False,
        "reddit_available": False,
        "timestamp": None
    }

    try:
        with open(status_file, 'r', encoding='utf-8') as f:
            status = json.load(f)
    except (ValueError, OSError):
        # OSError covers a missing/unreadable file; ValueError covers both
        # json.JSONDecodeError and UnicodeDecodeError.
        return default_status

    if not isinstance(status, dict):
        # Valid JSON, but not an object (e.g. a list or a bare string).
        return default_status

    return {
        "status": status.get("status", "stopped"),
        "twitter_available": status.get("twitter_available", False),
        "reddit_available": status.get("reddit_available", False),
        "timestamp": status.get("timestamp")
    }
@classmethod
def interview_agent(
    cls,
    simulation_id: str,
    agent_id: int,
    prompt: str,
    platform: str = None,
    timeout: float = 60.0
) -> Dict[str, Any]:
    """
    Send one interview question to one agent and return the outcome.

    Args:
        simulation_id: Simulation ID.
        agent_id: Agent ID.
        prompt: Interview question.
        platform: Optional platform selector, passed through to the IPC client.
            - "twitter": only interview the agent on Twitter.
            - "reddit": only interview the agent on Reddit.
            - None: in dual-platform runs, interview both platforms and return a merged result.
        timeout: Timeout in seconds.

    Returns:
        Dict with ``success``, ``agent_id``, ``prompt`` and ``timestamp``, plus
        ``result`` on success or ``error`` on failure.

    Raises:
        ValueError: Simulation does not exist or its environment is not running.
        TimeoutError: Timed out waiting for the response.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    client = SimulationIPCClient(sim_dir)
    if not client.check_env_alive():
        raise ValueError(f"模拟环境未运行或已关闭无法执行Interview: {simulation_id}")

    logger.info(t("log.simulation_runner.m036", simulation_id=simulation_id, agent_id=agent_id, platform=platform))

    reply = client.send_interview(
        agent_id=agent_id,
        prompt=prompt,
        platform=platform,
        timeout=timeout
    )

    # Shared fields first; the outcome-specific field is slotted in before the
    # timestamp so the key order stays the same as before.
    succeeded = reply.status.value == "completed"
    outcome = {"success": succeeded, "agent_id": agent_id, "prompt": prompt}
    if succeeded:
        outcome["result"] = reply.result
    else:
        outcome["error"] = reply.error
    outcome["timestamp"] = reply.timestamp
    return outcome
@classmethod
def interview_agents_batch(
    cls,
    simulation_id: str,
    interviews: List[Dict[str, Any]],
    platform: str = None,
    timeout: float = 120.0
) -> Dict[str, Any]:
    """
    Send a batch of interview questions in a single IPC request.

    Args:
        simulation_id: Simulation ID.
        interviews: Interview list; each entry is {"agent_id": int, "prompt": str, "platform": str (optional)}.
        platform: Optional default platform (overridden per-interview by an entry's own `platform`).
            - "twitter": default to interviewing only Twitter.
            - "reddit": default to interviewing only Reddit.
            - None: in dual-platform runs, interview every agent on both platforms.
        timeout: Timeout in seconds.

    Returns:
        Dict with ``success``, ``interviews_count`` and ``timestamp``, plus
        ``result`` on success or ``error`` on failure.

    Raises:
        ValueError: Simulation does not exist or its environment is not running.
        TimeoutError: Timed out waiting for the response.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    client = SimulationIPCClient(sim_dir)
    if not client.check_env_alive():
        raise ValueError(f"模拟环境未运行或已关闭无法执行Interview: {simulation_id}")

    logger.info(t("log.simulation_runner.m037", simulation_id=simulation_id, len=len(interviews), platform=platform))

    reply = client.send_batch_interview(
        interviews=interviews,
        platform=platform,
        timeout=timeout
    )

    # Shared fields first; the outcome-specific field goes before the timestamp.
    succeeded = reply.status.value == "completed"
    outcome = {"success": succeeded, "interviews_count": len(interviews)}
    if succeeded:
        outcome["result"] = reply.result
    else:
        outcome["error"] = reply.error
    outcome["timestamp"] = reply.timestamp
    return outcome
@classmethod
def interview_all_agents(
    cls,
    simulation_id: str,
    prompt: str,
    platform: str = None,
    timeout: float = 180.0
) -> Dict[str, Any]:
    """
    Ask every agent of the simulation the same question (global interview).

    The agent list is taken from the ``agent_configs`` entries of the
    simulation's ``simulation_config.json``; entries without an ``agent_id``
    are skipped. The actual sending is delegated to ``interview_agents_batch``.

    Args:
        simulation_id: Simulation ID.
        prompt: Interview question used for every agent.
        platform: Optional platform selector.
            - "twitter": only interview Twitter.
            - "reddit": only interview Reddit.
            - None: in dual-platform runs, interview every agent on both platforms.
        timeout: Timeout in seconds.

    Returns:
        Global interview result dict (the batch-interview result).

    Raises:
        ValueError: The simulation directory or its config file is missing,
            or the config lists no agents.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    # The agent roster lives in the simulation config.
    config_path = os.path.join(sim_dir, "simulation_config.json")
    if not os.path.exists(config_path):
        raise ValueError(f"模拟配置不存在: {simulation_id}")

    with open(config_path, 'r', encoding='utf-8') as f:
        agent_configs = json.load(f).get("agent_configs", [])

    if not agent_configs:
        raise ValueError(f"模拟配置中没有Agent: {simulation_id}")

    # One interview entry per configured agent that carries an id.
    configured_ids = (entry.get("agent_id") for entry in agent_configs)
    interviews = [
        {"agent_id": configured_id, "prompt": prompt}
        for configured_id in configured_ids
        if configured_id is not None
    ]

    logger.info(t("log.simulation_runner.m038", simulation_id=simulation_id, len=len(interviews), platform=platform))

    return cls.interview_agents_batch(
        simulation_id=simulation_id,
        interviews=interviews,
        platform=platform,
        timeout=timeout
    )
@classmethod
def close_simulation_env(
    cls,
    simulation_id: str,
    timeout: float = 30.0
) -> Dict[str, Any]:
    """
    Ask the simulation environment to close (the subprocess itself is not stopped here).

    Sends a close-environment command over IPC so the simulation leaves its
    wait-for-command mode gracefully.

    Args:
        simulation_id: Simulation ID.
        timeout: Timeout in seconds.

    Returns:
        Operation-result dict. An environment that is already closed, and a
        timeout while waiting for the reply, are both reported as success.

    Raises:
        ValueError: The simulation directory does not exist.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)
    if not os.path.exists(sim_dir):
        raise ValueError(f"模拟不存在: {simulation_id}")

    client = SimulationIPCClient(sim_dir)
    if not client.check_env_alive():
        return {
            "success": True,
            "message": "环境已经关闭"
        }

    logger.info(t("log.simulation_runner.m039", simulation_id=simulation_id))

    try:
        reply = client.send_close_env(timeout=timeout)
        outcome = {
            "success": reply.status.value == "completed",
            "message": "环境关闭命令已发送",
            "result": reply.result,
            "timestamp": reply.timestamp
        }
    except TimeoutError:
        # Timing out can simply mean the environment is already shutting down.
        outcome = {
            "success": True,
            "message": "环境关闭命令已发送(等待响应超时,环境可能正在关闭)"
        }
    return outcome
@classmethod
def _get_interview_history_from_db(
    cls,
    db_path: str,
    platform_name: str,
    agent_id: Optional[int] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    Read the interview history from a single per-platform database.

    Selects the rows of the ``trace`` table whose ``action`` is ``'interview'``
    (optionally restricted to one ``user_id``), newest first by ``created_at``.
    The ``info`` column is decoded as JSON; a payload that cannot be decoded,
    or that does not decode to a JSON object, is wrapped as ``{"raw": <payload>}``.

    Args:
        db_path: Path of the SQLite database file.
        platform_name: Platform label copied into every returned record.
        agent_id: Optional agent filter; None returns every agent's interviews.
        limit: Maximum number of rows to read.

    Returns:
        List of dicts with agent_id, response, prompt, timestamp and platform.
        Empty when the database file does not exist. Database errors are
        logged and whatever was collected up to that point is returned.
    """
    import sqlite3

    if not os.path.exists(db_path):
        return []

    # One statement for both variants; values are always bound as parameters.
    query = "SELECT user_id, info, created_at FROM trace WHERE action = 'interview'"
    params: List[Any] = []
    if agent_id is not None:
        query += " AND user_id = ?"
        params.append(agent_id)
    query += " ORDER BY created_at DESC LIMIT ?"
    params.append(limit)

    results: List[Dict[str, Any]] = []
    try:
        conn = sqlite3.connect(db_path)
        try:
            for user_id, info_json, created_at in conn.execute(query, params):
                try:
                    info = json.loads(info_json) if info_json else {}
                except (ValueError, TypeError):
                    info = {"raw": info_json}
                if not isinstance(info, dict):
                    # Valid JSON but not an object: keep the payload rather
                    # than failing on .get() and losing the remaining rows.
                    info = {"raw": info_json}

                results.append({
                    "agent_id": user_id,
                    "response": info.get("response", info),
                    "prompt": info.get("prompt", ""),
                    "timestamp": created_at,
                    "platform": platform_name
                })
        finally:
            # Always release the connection, including when the query fails
            # (e.g. the trace table does not exist yet).
            conn.close()
    except Exception as e:
        logger.error(t("log.simulation_runner.m040", platform_name=platform_name, e=e))

    return results
@classmethod
def get_interview_history(
    cls,
    simulation_id: str,
    platform: str = None,
    agent_id: Optional[int] = None,
    limit: int = 100
) -> List[Dict[str, Any]]:
    """
    Return the interview history (read from the per-platform databases).

    Args:
        simulation_id: Simulation ID.
        platform: Platform selector (reddit/twitter/None).
            - "reddit": only return Reddit history.
            - "twitter": only return Twitter history.
            - None (or any other value): return history from both platforms.
        agent_id: Optional agent-id filter; if set, only that agent's history is returned.
        limit: Max number of records read per platform; also the cap on the
            merged result when both platforms are queried.

    Returns:
        Interview-history list, newest first. Records without a timestamp sort last.
    """
    sim_dir = os.path.join(cls.RUN_STATE_DIR, simulation_id)

    # Decide which platform databases to query.
    if platform in ("reddit", "twitter"):
        platforms = [platform]
    else:
        # No (recognised) platform specified: query both.
        platforms = ["twitter", "reddit"]

    results: List[Dict[str, Any]] = []
    for p in platforms:
        results.extend(cls._get_interview_history_from_db(
            db_path=os.path.join(sim_dir, f"{p}_simulation.db"),
            platform_name=p,
            agent_id=agent_id,
            limit=limit
        ))

    # Newest-first by timestamp. A NULL created_at arrives as None, which
    # cannot be compared with str, so map it to "" (sorts last when reversed).
    results.sort(key=lambda x: x.get("timestamp") or "", reverse=True)

    # When multiple platforms were queried, cap the merged result size.
    if len(platforms) > 1 and len(results) > limit:
        results = results[:limit]

    return results