MicroFish/backend/scripts/run_twitter_simulation.py

760 lines
27 KiB
Python

"""
OASIS Twitter simulation preset script.
This script reads parameters from a config file to run a fully automated simulation.
Features:
- Does not close the environment immediately when the simulation finishes; enters
command-wait mode instead.
- Receives Interview commands over IPC.
- Supports both single-agent and batch interviews.
- Supports a remote close-environment command.
Usage:
python run_twitter_simulation.py --config /path/to/simulation_config.json
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done
"""
import argparse
import asyncio
import json
import logging
import os
import random
import signal
import sys
import sqlite3
from datetime import datetime
from typing import Dict, Any, List, Optional
# Globals used by the signal handler.
# _shutdown_event: replaced by an asyncio.Event in main(); the signal handler
#   calls .set() on it (when it exists) to request an orderly shutdown.
# _cleanup_done: becomes True on the first SIGINT/SIGTERM so that a second
#   signal takes the force-exit path instead.
_shutdown_event = None
_cleanup_done = False
# Resolve the scripts/, backend/ and project-root directories from this file's location.
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
# Put scripts/ and backend/ at the front of sys.path so their modules are
# importable; backend/ is inserted last and therefore ends up first.
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)
# Load the project-root .env (it carries LLM_API_KEY and friends). backend/.env
# is only consulted when the project-root file does not exist.
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
    load_dotenv(_env_file)
else:
    _backend_env = os.path.join(_backend_dir, '.env')
    if os.path.exists(_backend_env):
        load_dotenv(_backend_env)
import re
class UnicodeFormatter(logging.Formatter):
    r"""Formatter that turns \uXXXX escape sequences into readable characters.

    Log messages that embed ``json.dumps`` output (default ``ensure_ascii=True``)
    contain escapes such as ``\u4f60``. They are decoded after the normal
    formatting step.

    Characters outside the Basic Multilingual Plane are escaped by JSON as a
    UTF-16 surrogate pair (``\ud83d\ude00``). Such a pair is combined into one
    code point. Decoding each half separately would produce lone surrogates,
    which cannot be encoded to UTF-8 and make the UTF-8 file handler reject
    the whole record. A lone (unpaired) surrogate escape is left untouched for
    the same reason.
    """

    # Kept for backward compatibility: matches a single \uXXXX escape.
    UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
    # A high+low surrogate escape pair (groups 1-2), or any single escape (group 3).
    # The pair alternative is tried first, so pairs are never split.
    _ESCAPE_OR_PAIR_PATTERN = re.compile(
        r'\\u([dD][89abAB][0-9a-fA-F]{2})\\u([dD][c-fC-F][0-9a-fA-F]{2})'
        r'|\\u([0-9a-fA-F]{4})'
    )

    @staticmethod
    def _decode_escape(match):
        """Return the character for one matched escape (or surrogate pair)."""
        high, low, single = match.groups()
        if high is not None:
            code_point = 0x10000 + ((int(high, 16) - 0xD800) << 10) + (int(low, 16) - 0xDC00)
            return chr(code_point)
        code_unit = int(single, 16)
        if 0xD800 <= code_unit <= 0xDFFF:
            # Unpaired surrogate: keep the escape text so the record stays encodable.
            return match.group(0)
        return chr(code_unit)

    def format(self, record):
        result = super().format(record)
        return self._ESCAPE_OR_PAIR_PATTERN.sub(self._decode_escape, result)
class MaxTokensWarningFilter(logging.Filter):
    """Reject camel-ai's "Invalid or missing ... max_tokens" warning records.

    max_tokens is deliberately left unset so the model chooses the output
    length, which makes that warning expected noise. Every other record passes.
    """

    def filter(self, record):
        text = record.getMessage()
        is_max_tokens_warning = "max_tokens" in text and "Invalid or missing" in text
        return not is_max_tokens_warning
# Install the filter at import time so it is active before any camel code runs.
# NOTE(review): a filter attached to a *logger* is only consulted for records
# logged directly on that logger. Records created by descendant loggers
# (e.g. a "camel.*" logger) and propagated to the root are not passed through
# the root logger's filters. Confirm that camel emits this warning on the root
# logger. Otherwise attach the filter to the emitting logger or to the handlers.
logging.getLogger().addFilter(MaxTokensWarningFilter())
def setup_oasis_logging(log_dir: str):
    """Configure OASIS logging with fixed log filenames.

    Steps:
    1. Detach *and close* every handler already attached to the OASIS
       loggers (e.g. from an earlier call). Merely clearing the handler list
       leaves those files open, which leaks file descriptors. On Windows the
       open handles would also make the deletions in step 2 fail.
    2. Delete every ``*.log`` file directly inside ``log_dir`` (best effort).
    3. Give each logger exactly one UTF-8 ``FileHandler`` (truncating mode,
       DEBUG level, ``UnicodeFormatter``) and stop propagation to the root
       logger so the records only land in their own file.

    Args:
        log_dir: Directory for the log files; created if it does not exist.
    """
    loggers_config = {
        "social.agent": os.path.join(log_dir, "social.agent.log"),
        "social.twitter": os.path.join(log_dir, "social.twitter.log"),
        "social.rec": os.path.join(log_dir, "social.rec.log"),
        "oasis.env": os.path.join(log_dir, "oasis.env.log"),
        "table": os.path.join(log_dir, "table.log"),
    }

    # Step 1: release handlers (and their open files) before touching the files.
    for logger_name in loggers_config:
        logger = logging.getLogger(logger_name)
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

    # Step 2: wipe stale log files from previous runs.
    os.makedirs(log_dir, exist_ok=True)
    for entry in os.listdir(log_dir):
        old_log = os.path.join(log_dir, entry)
        if os.path.isfile(old_log) and entry.endswith('.log'):
            try:
                os.remove(old_log)
            except OSError:
                pass  # Best effort: an undeletable file must not abort setup.

    # Step 3: attach one fresh file handler per logger.
    formatter = UnicodeFormatter("%(levelname)s - %(asctime)s - %(name)s - %(message)s")
    for logger_name, log_file in loggers_config.items():
        logger = logging.getLogger(logger_name)
        logger.setLevel(logging.DEBUG)
        file_handler = logging.FileHandler(log_file, encoding='utf-8', mode='w')
        file_handler.setLevel(logging.DEBUG)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)
        logger.propagate = False
# Third-party simulation dependencies (camel-ai and OASIS). If any of them is
# missing, print an install hint and abort with exit status 1.
try:
    from camel.models import ModelFactory
    from camel.types import ModelPlatformType
    import oasis
    from oasis import (
        ActionType,
        LLMAction,
        ManualAction,
        generate_twitter_agent_graph
    )
except ImportError as e:
    print(f"错误: 缺少依赖 {e}")
    print("请先安装: pip install oasis-ai camel-ai")
    sys.exit(1)
# IPC-related constants: names relative to the simulation directory
# (IPCHandler joins them onto its simulation_dir).
IPC_COMMANDS_DIR = "ipc_commands"    # JSON command files are polled from here
IPC_RESPONSES_DIR = "ipc_responses"  # <command_id>.json responses are written here
ENV_STATUS_FILE = "env_status.json"  # current environment status + timestamp
class CommandType:
    """String identifiers of the IPC command types the runner understands.

    INTERVIEW: interview a single agent.
    BATCH_INTERVIEW: interview several agents in one environment step.
    CLOSE_ENV: leave command-wait mode and close the environment.
    """

    INTERVIEW, BATCH_INTERVIEW, CLOSE_ENV = "interview", "batch_interview", "close_env"
class IPCHandler:
    """Handles IPC commands directed at the running simulation.

    The protocol is file based, rooted at ``simulation_dir``:

    - clients drop JSON command files into ``IPC_COMMANDS_DIR``; the oldest
      one (by modification time, filename as tie-breaker) is handled first;
    - the reply goes to ``IPC_RESPONSES_DIR/<command_id>.json`` and the
      command file ``<command_id>.json`` is then deleted;
    - ``ENV_STATUS_FILE`` holds the current environment status.

    JSON outputs are written atomically (temp file + ``os.replace``), so a
    client polling for them never reads a half-written file.
    """

    def __init__(self, simulation_dir: str, env, agent_graph):
        self.simulation_dir = simulation_dir
        self.env = env
        self.agent_graph = agent_graph
        self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
        self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
        self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
        self._running = True
        os.makedirs(self.commands_dir, exist_ok=True)
        os.makedirs(self.responses_dir, exist_ok=True)

    @staticmethod
    def _write_json_atomic(path: str, payload: Dict[str, Any]) -> None:
        """Write ``payload`` as pretty-printed UTF-8 JSON to ``path`` atomically.

        The data goes to ``<path>.tmp`` first and is moved into place with
        ``os.replace``. Readers therefore see either the old file or the
        complete new one. The temp file is removed if writing fails.
        """
        tmp_path = f"{path}.tmp"
        try:
            with open(tmp_path, 'w', encoding='utf-8') as f:
                json.dump(payload, f, ensure_ascii=False, indent=2)
            os.replace(tmp_path, path)
        except Exception:
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def update_status(self, status: str):
        """Write the current environment status (plus a timestamp) to the status file."""
        self._write_json_atomic(self.status_file, {
            "status": status,
            "timestamp": datetime.now().isoformat(),
        })

    def poll_command(self) -> Optional[Dict[str, Any]]:
        """Return the oldest pending, well-formed command, or None.

        Command files are skipped (and left in place) when any of these holds:
        - they disappear between listing and reading;
        - they cannot be read or decoded as UTF-8;
        - they are not valid JSON;
        - they do not contain a JSON object.
        Skipping instead of raising keeps one bad file from stopping the
        command-wait loop.

        A command without ``command_id`` gets its file's stem as the id. Then
        ``send_response`` deletes the right file, and the command is not
        handled again on every poll.
        """
        try:
            filenames = os.listdir(self.commands_dir)
        except OSError:
            return None

        command_files = []
        for filename in filenames:
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.commands_dir, filename)
            try:
                mtime = os.path.getmtime(filepath)
            except OSError:
                # Removed between listdir() and here (e.g. already answered).
                continue
            command_files.append((mtime, filename, filepath))
        command_files.sort()

        for _, filename, filepath in command_files:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    command = json.load(f)
            except (ValueError, OSError):
                # ValueError covers JSONDecodeError (incl. a half-written file)
                # and UnicodeDecodeError.
                continue
            if not isinstance(command, dict):
                continue
            if command.get("command_id") is None:
                command["command_id"] = os.path.splitext(filename)[0]
            return command
        return None

    def send_response(self, command_id: str, status: str, result: Optional[Dict] = None,
                      error: Optional[str] = None):
        """Write the response for a processed command and delete its command file.

        Args:
            command_id: Id of the handled command; names both files.
            status: "completed" or "failed".
            result: Optional JSON-serializable result payload.
            error: Optional error message.
        """
        response = {
            "command_id": command_id,
            "status": status,
            "result": result,
            "error": error,
            "timestamp": datetime.now().isoformat()
        }
        response_file = os.path.join(self.responses_dir, f"{command_id}.json")
        self._write_json_atomic(response_file, response)

        # Remove the command file once a response has been written.
        command_file = os.path.join(self.commands_dir, f"{command_id}.json")
        try:
            os.remove(command_file)
        except OSError:
            pass

    async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
        """Interview one agent and send the response.

        The interview runs as one ``env.step`` with an INTERVIEW ManualAction.
        The answer is then read back from the simulation database.

        Returns:
            True on success. False on failure, in which case a "failed"
            response carrying the error text has been sent.
        """
        try:
            agent = self.agent_graph.get_agent(agent_id)
            interview_action = ManualAction(
                action_type=ActionType.INTERVIEW,
                action_args={"prompt": prompt}
            )
            await self.env.step({agent: interview_action})

            result = self._get_interview_result(agent_id)
            self.send_response(command_id, "completed", result=result)
            print(f" Interview完成: agent_id={agent_id}")
            return True
        except Exception as e:
            error_msg = str(e)
            print(f" Interview失败: agent_id={agent_id}, error={error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False

    async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
        """Interview several agents in a single environment step.

        Agents that cannot be resolved are skipped with a warning. If none
        can be resolved, a "failed" response is sent.

        Args:
            command_id: Id of the command being answered.
            interviews: [{"agent_id": int, "prompt": str}, ...]

        Returns:
            True on success, False on failure.
        """
        try:
            actions = {}
            agent_prompts = {}  # agent_id -> prompt, used for the result lookup below.
            for interview in interviews:
                agent_id = interview.get("agent_id")
                prompt = interview.get("prompt", "")
                try:
                    agent = self.agent_graph.get_agent(agent_id)
                    actions[agent] = ManualAction(
                        action_type=ActionType.INTERVIEW,
                        action_args={"prompt": prompt}
                    )
                    agent_prompts[agent_id] = prompt
                except Exception as e:
                    print(f" 警告: 无法获取Agent {agent_id}: {e}")

            if not actions:
                self.send_response(command_id, "failed", error="没有有效的Agent")
                return False

            await self.env.step(actions)

            results = {
                agent_id: self._get_interview_result(agent_id)
                for agent_id in agent_prompts
            }
            self.send_response(command_id, "completed", result={
                "interviews_count": len(results),
                "results": results
            })
            print(f" 批量Interview完成: {len(results)} 个Agent")
            return True
        except Exception as e:
            error_msg = str(e)
            print(f" 批量Interview失败: {error_msg}")
            self.send_response(command_id, "failed", error=error_msg)
            return False

    def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
        """Fetch the most recent interview result for an agent from the database.

        Returns:
            ``{"agent_id", "response", "timestamp"}``.
            - ``response`` is the ``"response"`` field of the trace row's JSON
              ``info``. It falls back to the whole decoded value when that key
              is missing or the value is not an object, and to the raw text
              when ``info`` is not valid JSON.
            - ``timestamp`` is the row's ``created_at``.
            - Both stay None when the database or a matching row is missing,
              or when the query fails.
        """
        result = {
            "agent_id": agent_id,
            "response": None,
            "timestamp": None
        }
        db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
        if not os.path.exists(db_path):
            return result

        try:
            conn = sqlite3.connect(db_path)
            try:
                # Most recent INTERVIEW trace row for this agent.
                row = conn.execute("""
                    SELECT user_id, info, created_at
                    FROM trace
                    WHERE action = ? AND user_id = ?
                    ORDER BY created_at DESC
                    LIMIT 1
                """, (ActionType.INTERVIEW.value, agent_id)).fetchone()
            finally:
                conn.close()  # Always release the connection, even if the query fails.
        except Exception as e:
            print(f" 读取Interview结果失败: {e}")
            return result

        if row is None:
            return result
        _, info_json, created_at = row
        result["timestamp"] = created_at
        try:
            info = json.loads(info_json) if info_json else {}
        except (json.JSONDecodeError, TypeError):
            result["response"] = info_json
            return result
        result["response"] = info.get("response", info) if isinstance(info, dict) else info
        return result

    async def process_commands(self) -> bool:
        """Handle at most one pending command.

        Returns:
            True if the run loop should continue, False if it should exit
            (a close_env command was received).
        """
        command = self.poll_command()
        if not command:
            return True

        command_id = command.get("command_id")
        command_type = command.get("command_type")
        args = command.get("args")
        if not isinstance(args, dict):
            args = {}  # Tolerate a missing or null "args" field.

        print(f"\n收到IPC命令: {command_type}, id={command_id}")

        if command_type == CommandType.INTERVIEW:
            await self.handle_interview(
                command_id,
                args.get("agent_id", 0),
                args.get("prompt", "")
            )
            return True

        if command_type == CommandType.BATCH_INTERVIEW:
            await self.handle_batch_interview(
                command_id,
                args.get("interviews", [])
            )
            return True

        if command_type == CommandType.CLOSE_ENV:
            print("收到关闭环境命令")
            self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
            return False

        self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
        return True
class TwitterSimulationRunner:
    """Drives a single Twitter simulation run.

    A run does the following:
    - reads ``simulation_config.json``;
    - builds the agent graph from ``twitter_profiles.csv`` in the same
      directory;
    - recreates ``twitter_simulation.db`` and executes the configured rounds;
    - unless ``wait_for_commands`` is False, keeps the environment alive to
      serve IPC commands.

    Once the environment exists, it is always closed and the status file is
    set to "stopped" on the way out. This includes the cases where a round
    raises or a shutdown signal arrives.
    """

    # Available Twitter actions. INTERVIEW is intentionally excluded — it can
    # only be triggered via ManualAction.
    AVAILABLE_ACTIONS = [
        ActionType.CREATE_POST,
        ActionType.LIKE_POST,
        ActionType.REPOST,
        ActionType.FOLLOW,
        ActionType.DO_NOTHING,
        ActionType.QUOTE_POST,
    ]

    def __init__(self, config_path: str, wait_for_commands: bool = True):
        """Initialize the simulation runner.

        Args:
            config_path: Path to the config file (simulation_config.json).
            wait_for_commands: Whether to wait for IPC commands after the
                simulation completes (default True).
        """
        self.config_path = config_path
        self.config = self._load_config()
        self.simulation_dir = os.path.dirname(config_path)
        self.wait_for_commands = wait_for_commands
        self.env = None
        self.agent_graph = None
        self.ipc_handler = None

    def _load_config(self) -> Dict[str, Any]:
        """Load the simulation config file (UTF-8 JSON)."""
        with open(self.config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def _get_profile_path(self) -> str:
        """Return the agent profile path (OASIS Twitter expects CSV)."""
        return os.path.join(self.simulation_dir, "twitter_profiles.csv")

    def _get_db_path(self) -> str:
        """Return the simulation SQLite database path."""
        return os.path.join(self.simulation_dir, "twitter_simulation.db")

    def _create_model(self):
        """Create the LLM model.

        Settings come from the process environment, into which the project
        .env was loaded at import time. They take precedence over the config:

        - LLM_API_KEY: API key, exported as OPENAI_API_KEY for camel-ai (an
          OPENAI_API_KEY that is already set is used when LLM_API_KEY is empty).
        - LLM_BASE_URL: API base URL, exported as OPENAI_API_BASE_URL.
        - LLM_MODEL_NAME: model name. Falls back to the config's ``llm_model``,
          then to "gpt-4o-mini".

        Raises:
            ValueError: if no API key is available.
        """
        llm_api_key = os.environ.get("LLM_API_KEY", "")
        llm_base_url = os.environ.get("LLM_BASE_URL", "")
        llm_model = os.environ.get("LLM_MODEL_NAME", "")

        # Fall back to the simulation config if the environment has no model name.
        if not llm_model:
            llm_model = self.config.get("llm_model", "gpt-4o-mini")

        # camel-ai reads OPENAI_API_KEY from the environment.
        if llm_api_key:
            os.environ["OPENAI_API_KEY"] = llm_api_key
        if not os.environ.get("OPENAI_API_KEY"):
            raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")

        if llm_base_url:
            os.environ["OPENAI_API_BASE_URL"] = llm_base_url

        print(f"LLM配置: model={llm_model}, base_url={llm_base_url[:40] if llm_base_url else '默认'}...")

        return ModelFactory.create(
            model_platform=ModelPlatformType.OPENAI,
            model_type=llm_model,
        )

    def _get_active_agents_for_round(
        self,
        env,
        current_hour: int,
        round_num: int
    ) -> List:
        """Decide which agents activate this round, based on time and config.

        Selection works as follows:
        1. A target count is drawn uniformly from
           [agents_per_hour_min, agents_per_hour_max] and scaled by the peak
           or off-peak multiplier.
        2. Each configured agent whose ``active_hours`` contain
           ``current_hour`` becomes a candidate with probability
           ``activity_level``.
        3. Up to the target count of candidates is sampled.
        4. IDs the agent graph cannot resolve are silently dropped.

        Args:
            env: The OASIS environment.
            current_hour: Current simulated hour (0-23).
            round_num: Current round number (currently unused).

        Returns:
            A list of ``(agent_id, agent)`` tuples activated this round.
        """
        time_config = self.config.get("time_config", {})
        agent_configs = self.config.get("agent_configs", [])

        # Base activation count per round.
        base_min = time_config.get("agents_per_hour_min", 5)
        base_max = time_config.get("agents_per_hour_max", 20)

        # Adjust by time-of-day (peak vs. off-peak hours).
        peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
        off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])
        if current_hour in peak_hours:
            multiplier = time_config.get("peak_activity_multiplier", 1.5)
        elif current_hour in off_peak_hours:
            multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
        else:
            multiplier = 1.0

        target_count = int(random.uniform(base_min, base_max) * multiplier)

        # Roll each eligible agent's activation against its activity level.
        candidates = []
        for cfg in agent_configs:
            agent_id = cfg.get("agent_id", 0)
            active_hours = cfg.get("active_hours", list(range(8, 23)))
            activity_level = cfg.get("activity_level", 0.5)
            if current_hour not in active_hours:
                continue
            if random.random() < activity_level:
                candidates.append(agent_id)

        # Pick a random subset of the eligible candidates.
        selected_ids = random.sample(
            candidates,
            min(target_count, len(candidates))
        ) if candidates else []

        # Resolve IDs to Agent objects.
        active_agents = []
        for agent_id in selected_ids:
            try:
                agent = env.agent_graph.get_agent(agent_id)
                active_agents.append((agent_id, agent))
            except Exception:
                pass
        return active_agents

    @staticmethod
    def _shutdown_requested() -> bool:
        """True once the module-level shutdown event exists and has been set."""
        return _shutdown_event is not None and _shutdown_event.is_set()

    async def run(self, max_rounds: int = None):
        """Run the Twitter simulation.

        Args:
            max_rounds: Optional cap on the number of rounds, used to truncate
                overly long simulations. Ignored unless positive.
        """
        print("=" * 60)
        print("OASIS Twitter模拟")
        print(f"配置文件: {self.config_path}")
        print(f"模拟ID: {self.config.get('simulation_id', 'unknown')}")
        print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
        print("=" * 60)

        time_config = self.config.get("time_config", {})
        total_hours = time_config.get("total_simulation_hours", 72)
        minutes_per_round = time_config.get("minutes_per_round", 30)
        total_rounds = (total_hours * 60) // minutes_per_round

        # Truncate to max_rounds when one was supplied.
        if max_rounds is not None and max_rounds > 0:
            original_rounds = total_rounds
            total_rounds = min(total_rounds, max_rounds)
            if total_rounds < original_rounds:
                print(f"\n轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

        print("\n模拟参数:")
        print(f" - 总模拟时长: {total_hours}小时")
        print(f" - 每轮时间: {minutes_per_round}分钟")
        print(f" - 总轮数: {total_rounds}")
        if max_rounds:
            print(f" - 最大轮数限制: {max_rounds}")
        print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")

        print("\n初始化LLM模型...")
        model = self._create_model()

        # Load the agent graph from the profile CSV.
        print("加载Agent Profile...")
        profile_path = self._get_profile_path()
        if not os.path.exists(profile_path):
            print(f"错误: Profile文件不存在: {profile_path}")
            return

        self.agent_graph = await generate_twitter_agent_graph(
            profile_path=profile_path,
            model=model,
            available_actions=self.AVAILABLE_ACTIONS,
        )

        # Reset the simulation database for a clean run.
        db_path = self._get_db_path()
        if os.path.exists(db_path):
            os.remove(db_path)
            print(f"已删除旧数据库: {db_path}")

        print("创建OASIS环境...")
        self.env = oasis.make(
            agent_graph=self.agent_graph,
            platform=oasis.DefaultPlatformType.TWITTER,
            database_path=db_path,
            semaphore=30,  # Cap concurrent LLM requests to avoid API overload.
        )

        # From here on the environment exists: always close it, even if a
        # round raises, so resources are released and clients see "stopped".
        try:
            await self.env.reset()
            print("环境初始化完成\n")

            self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
            self.ipc_handler.update_status("running")

            await self._post_initial_posts()
            await self._run_rounds(total_rounds, minutes_per_round, db_path)

            # Skip command-wait mode when a shutdown was already requested.
            if self.wait_for_commands and not self._shutdown_requested():
                await self._serve_ipc_commands()
        finally:
            print("\n关闭环境...")
            try:
                if self.ipc_handler is not None:
                    self.ipc_handler.update_status("stopped")
            finally:
                await self.env.close()
            print("环境已关闭")
            print("=" * 60)

    async def _post_initial_posts(self):
        """Publish the configured kickoff posts (``event_config.initial_posts``).

        ``env.step`` takes at most one action per agent. Posts are therefore
        split into successive batches in which each agent appears once, so an
        agent that seeds several posts publishes all of them. A single dict
        would silently keep only that agent's last post.
        """
        event_config = self.config.get("event_config", {})
        initial_posts = event_config.get("initial_posts", [])
        if not initial_posts:
            return

        print(f"执行初始事件 ({len(initial_posts)}条初始帖子)...")
        batches: List[Dict[Any, Any]] = []
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = self.env.agent_graph.get_agent(agent_id)
                action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
            except Exception as e:
                print(f" 警告: 无法为Agent {agent_id}创建初始帖子: {e}")
                continue
            # Put the post into the first batch that has no action for this agent yet.
            for batch in batches:
                if agent not in batch:
                    batch[agent] = action
                    break
            else:
                batches.append({agent: action})

        posted = 0
        for batch in batches:
            await self.env.step(batch)
            posted += len(batch)
        if posted:
            print(f" 已发布 {posted} 条初始帖子")

    async def _run_rounds(self, total_rounds: int, minutes_per_round: int, db_path: str):
        """Run the main simulation loop and print a summary.

        The shutdown event is checked before every round, so SIGINT/SIGTERM
        stops the simulation after the round in progress. Before this check,
        the remaining rounds kept running.
        """
        print("\n开始模拟循环...")
        start_time = datetime.now()

        for round_num in range(total_rounds):
            if self._shutdown_requested():
                print(f"\n收到关闭信号,模拟在第 {round_num + 1} 轮前停止")
                break

            # Map round number to simulated wall-clock time.
            simulated_minutes = round_num * minutes_per_round
            simulated_hour = (simulated_minutes // 60) % 24
            simulated_day = simulated_minutes // (60 * 24) + 1

            active_agents = self._get_active_agents_for_round(
                self.env, simulated_hour, round_num
            )
            if not active_agents:
                continue

            actions = {agent: LLMAction() for _, agent in active_agents}
            await self.env.step(actions)

            # Periodic progress log.
            if (round_num + 1) % 10 == 0 or round_num == 0:
                elapsed = (datetime.now() - start_time).total_seconds()
                progress = (round_num + 1) / total_rounds * 100
                print(f" [Day {simulated_day}, {simulated_hour:02d}:00] "
                      f"Round {round_num + 1}/{total_rounds} ({progress:.1f}%) "
                      f"- {len(active_agents)} agents active "
                      f"- elapsed: {elapsed:.1f}s")

        total_elapsed = (datetime.now() - start_time).total_seconds()
        print("\n模拟循环完成!")
        print(f" - 总耗时: {total_elapsed:.1f}")
        print(f" - 数据库: {db_path}")

    async def _serve_ipc_commands(self):
        """Keep the environment alive and handle IPC commands.

        Exits on a close_env command, when the shutdown event is set, or on
        an unexpected error. Without a shutdown event (e.g. the runner is
        used outside ``main``), commands are polled every 0.5 s until
        close_env.
        """
        print("\n" + "=" * 60)
        print("进入等待命令模式 - 环境保持运行")
        print("支持的命令: interview, batch_interview, close_env")
        print("=" * 60)

        self.ipc_handler.update_status("alive")

        try:
            while not self._shutdown_requested():
                should_continue = await self.ipc_handler.process_commands()
                if not should_continue:
                    break
                if _shutdown_event is None:
                    await asyncio.sleep(0.5)
                    continue
                try:
                    await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
                    break  # Shutdown signal received.
                except asyncio.TimeoutError:
                    pass
        except KeyboardInterrupt:
            print("\n收到中断信号")
        except asyncio.CancelledError:
            print("\n任务被取消")
        except Exception as e:
            print(f"\n命令处理出错: {e}")
async def main():
    """Command-line entry point: parse arguments, set up logging, run one simulation.

    Calls ``sys.exit(1)`` when the config file does not exist. Logs go to the
    ``log`` directory next to the config file.
    """
    global _shutdown_event

    parser = argparse.ArgumentParser(description='OASIS Twitter模拟')
    parser.add_argument('--config', type=str, required=True,
                        help='配置文件路径 (simulation_config.json)')
    parser.add_argument('--max-rounds', type=int, default=None,
                        help='最大模拟轮数(可选,用于截断过长的模拟)')
    parser.add_argument('--no-wait', action='store_true', default=False,
                        help='模拟完成后立即关闭环境,不进入等待命令模式')
    cli = parser.parse_args()

    # Create the shutdown event inside the running event loop; the signal
    # handler sets it to request an orderly shutdown.
    _shutdown_event = asyncio.Event()

    config_path = cli.config
    if not os.path.exists(config_path):
        print(f"错误: 配置文件不存在: {config_path}")
        sys.exit(1)

    # Fixed log filenames; logs from earlier runs are wiped.
    sim_dir = os.path.dirname(config_path) or "."
    setup_oasis_logging(os.path.join(sim_dir, "log"))

    runner = TwitterSimulationRunner(
        config_path=config_path,
        wait_for_commands=not cli.no_wait,
    )
    await runner.run(max_rounds=cli.max_rounds)
def setup_signal_handlers():
    """Route SIGTERM and SIGINT to one handler that asks for an orderly shutdown.

    On the first signal the handler marks cleanup as started and sets the
    module-level shutdown event (if it has been created), so the running
    simulation can close the database, the OASIS environment, etc. A repeated
    signal force-exits with status 1.
    """
    def _on_signal(signum, frame):
        global _cleanup_done
        sig_name = "SIGTERM" if signum == signal.SIGTERM else "SIGINT"
        print(f"\n收到 {sig_name} 信号,正在退出...")
        if _cleanup_done:
            # A second signal means the orderly shutdown is taking too long.
            print("强制退出...")
            sys.exit(1)
        _cleanup_done = True
        if _shutdown_event:
            _shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, _on_signal)
if __name__ == "__main__":
    setup_signal_handlers()
    # SystemExit is deliberately NOT caught. Swallowing it would turn
    # sys.exit(1) into exit status 0 and hide the failure from whatever
    # launched this process. That sys.exit(1) can come from main() (missing
    # config) or from the signal handler's force-exit path. SystemExit prints
    # no traceback, so letting it propagate only preserves the exit code.
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n程序被中断")
    finally:
        print("模拟进程已退出")