# MicroFish/backend/scripts/run_parallel_simulation.py
#
# 1693 lines
# 62 KiB
# Python

"""OASIS dual-platform parallel simulation preset script.

Runs the Twitter and Reddit simulations simultaneously; both read the same
config file.

Features:
    - Dual-platform (Twitter + Reddit) parallel simulation
    - Keeps the environments alive after the simulation finishes and enters
      wait-for-command mode
    - Receives Interview commands via IPC
    - Supports single-agent and batch interviews
    - Supports a remote close-environment command

Usage:
    python run_parallel_simulation.py --config simulation_config.json
    python run_parallel_simulation.py --config simulation_config.json --no-wait  # close immediately when done
    python run_parallel_simulation.py --config simulation_config.json --twitter-only
    python run_parallel_simulation.py --config simulation_config.json --reddit-only

Log layout:
    sim_xxx/
    ├── twitter/
    │   └── actions.jsonl    # Twitter platform action log
    ├── reddit/
    │   └── actions.jsonl    # Reddit platform action log
    ├── simulation.log       # main simulation process log
    └── run_state.json       # run state (used by API queries)
"""
# ============================================================
# Fix the Windows encoding issue by forcing UTF-8 before any import.
# This works around the OASIS third-party library opening files without
# specifying an encoding.
# ============================================================
import sys
import os
if sys.platform == 'win32':
    # Request UTF-8 mode through the environment. ``setdefault`` keeps any
    # value the user already exported. These variables are read at interpreter
    # startup, so they cannot change the already-running interpreter; the
    # open() patch below covers the current process.
    os.environ.setdefault('PYTHONUTF8', '1')
    os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
    # Reconfigure stdout/stderr to UTF-8 to avoid mojibake in the console.
    # ``errors='replace'`` substitutes unencodable characters instead of
    # raising. The hasattr() guard skips streams without ``reconfigure``.
    if hasattr(sys.stdout, 'reconfigure'):
        sys.stdout.reconfigure(encoding='utf-8', errors='replace')
    if hasattr(sys.stderr, 'reconfigure'):
        sys.stderr.reconfigure(encoding='utf-8', errors='replace')
    # Force the default encoding used by open(). The env-var approach above
    # only works when set at interpreter startup, so we additionally
    # monkey-patch the built-in open().
    import builtins
    _original_open = builtins.open

    def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None,
                   newline=None, closefd=True, opener=None):
        """Wrap open() so text-mode calls default to UTF-8.

        Fixes third-party libraries (such as OASIS) that open files without
        specifying an encoding. The signature mirrors the built-in ``open``
        and every argument is forwarded to the saved original.
        """
        # Only override when the caller is using text mode and did not request
        # an explicit encoding; binary mode must not receive an encoding.
        if encoding is None and 'b' not in mode:
            encoding = 'utf-8'
        return _original_open(file, mode, buffering, encoding, errors,
                              newline, closefd, opener)

    # Install the wrapper process-wide; it stays in place for the rest of the run.
    builtins.open = _utf8_open
import argparse
import asyncio
import json
import logging
import multiprocessing
import random
import signal
import sqlite3
import warnings
from datetime import datetime
from typing import Dict, Any, List, Optional, Tuple
# Globals used by the signal handlers.
_shutdown_event = None
_cleanup_done = False
# Add the backend directory to sys.path. The script always lives in
# backend/scripts/, so the backend dir is one level up and the project root
# two levels up from this file.
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
# Both inserts go to position 0, so the backend dir ends up ahead of the
# scripts dir in the import search order.
sys.path.insert(0, _scripts_dir)
sys.path.insert(0, _backend_dir)
# Load the .env from the project root (contains LLM_API_KEY etc.).
from dotenv import load_dotenv
_env_file = os.path.join(_project_root, '.env')
if os.path.exists(_env_file):
    load_dotenv(_env_file)
    print(f"已加载环境配置: {_env_file}")
else:
    # Fall back to backend/.env. If neither file exists nothing is loaded and
    # nothing is printed; the process environment is used as-is.
    _backend_env = os.path.join(_backend_dir, '.env')
    if os.path.exists(_backend_env):
        load_dotenv(_backend_env)
        print(f"已加载环境配置: {_backend_env}")
class MaxTokensWarningFilter(logging.Filter):
    """Suppress camel-ai max_tokens warnings.

    We intentionally leave max_tokens unset so the model decides; the warning
    is noise. A record is dropped only when its formatted message contains
    both ``"max_tokens"`` and ``"Invalid or missing"``.
    """

    def filter(self, record):
        """Return ``False`` to drop the record, ``True`` to keep it.

        Args:
            record: the ``logging.LogRecord`` being considered.
        """
        try:
            # Format once; getMessage() re-applies ``msg % args`` on each call.
            message = record.getMessage()
        except Exception:
            # A record whose msg/args do not match cannot be formatted. Let it
            # through so the handler reports the problem the usual way,
            # instead of raising out of the caller's logging call.
            return True
        return not ("max_tokens" in message and "Invalid or missing" in message)


# Install the filter at import time so it is active before any camel code runs.
# NOTE(review): filters attached to a logger are not consulted for records that
# propagate up from descendant loggers, so this only affects records logged
# directly on the root logger. Attach the filter to the emitting logger or to
# the handlers if camel logs through a named logger; confirm.
logging.getLogger().addFilter(MaxTokensWarningFilter())
def disable_oasis_logging():
    """Silence the verbose loggers used by the OASIS library.

    OASIS logs every agent observation and action, which is extremely noisy;
    this script records actions through its own action_logger instead. Each
    listed logger is raised to CRITICAL, stripped of its handlers, and
    detached from its parent so nothing reaches the root handlers.
    """
    for noisy_name in ("social.agent", "social.twitter", "social.rec",
                       "oasis.env", "table"):
        noisy = logging.getLogger(noisy_name)
        noisy.propagate = False
        noisy.handlers.clear()
        # Only severe errors are kept.
        noisy.setLevel(logging.CRITICAL)
def init_logging_for_simulation(simulation_dir: str):
    """Prepare logging for one simulation run.

    Silences the OASIS loggers and deletes a leftover ``log`` directory inside
    the simulation directory, if one exists.

    Args:
        simulation_dir: path to the simulation directory.
    """
    import shutil

    disable_oasis_logging()

    stale_log_dir = os.path.join(simulation_dir, "log")
    if not os.path.exists(stale_log_dir):
        return
    # Best effort: removal failures are ignored.
    shutil.rmtree(stale_log_dir, ignore_errors=True)
from action_logger import SimulationLogManager, PlatformActionLogger
try:
from camel.models import ModelFactory
from camel.types import ModelPlatformType
import oasis
from oasis import (
ActionType,
LLMAction,
ManualAction,
generate_twitter_agent_graph,
generate_reddit_agent_graph
)
except ImportError as e:
print(f"错误: 缺少依赖 {e}")
print("请先安装: pip install oasis-ai camel-ai")
sys.exit(1)
# Twitter actions available to agents. INTERVIEW is excluded because it can only
# be triggered manually via ManualAction.
TWITTER_ACTIONS = [
    ActionType.CREATE_POST,
    ActionType.LIKE_POST,
    ActionType.REPOST,
    ActionType.FOLLOW,
    ActionType.DO_NOTHING,
    ActionType.QUOTE_POST,
]
# Reddit actions available to agents. INTERVIEW is excluded because it can only
# be triggered manually via ManualAction.
REDDIT_ACTIONS = [
    ActionType.LIKE_POST,
    ActionType.DISLIKE_POST,
    ActionType.CREATE_POST,
    ActionType.CREATE_COMMENT,
    ActionType.LIKE_COMMENT,
    ActionType.DISLIKE_COMMENT,
    ActionType.SEARCH_POSTS,
    ActionType.SEARCH_USER,
    ActionType.TREND,
    ActionType.REFRESH,
    ActionType.DO_NOTHING,
    ActionType.FOLLOW,
    ActionType.MUTE,
]
# IPC-related constants. The two directory names and the status file name are
# joined onto the simulation directory by ParallelIPCHandler.
IPC_COMMANDS_DIR = "ipc_commands"    # incoming command files
IPC_RESPONSES_DIR = "ipc_responses"  # outgoing response files
ENV_STATUS_FILE = "env_status.json"  # written by ParallelIPCHandler.update_status
class CommandType:
    """Names of the IPC commands; compared against a command's ``command_type``."""

    # Interview a single agent.
    INTERVIEW = "interview"
    # Interview several agents in one command.
    BATCH_INTERVIEW = "batch_interview"
    # Ask the process to stop waiting for commands and shut the environments down.
    CLOSE_ENV = "close_env"
class ParallelIPCHandler:
    """Dual-platform IPC command handler.

    Manages both platform environments and processes Interview commands.
    Commands are JSON files read from ``<simulation_dir>/ipc_commands``; each
    one is answered with ``<command_id>.json`` in
    ``<simulation_dir>/ipc_responses``, after which the command file is deleted.
    """

    def __init__(
        self,
        simulation_dir: str,
        twitter_env=None,
        twitter_agent_graph=None,
        reddit_env=None,
        reddit_agent_graph=None
    ):
        """Store the environments and make sure the IPC directories exist.

        Args:
            simulation_dir: simulation directory holding the IPC folders,
                the status file and the per-platform databases.
            twitter_env: Twitter environment, or ``None`` when not running.
            twitter_agent_graph: agent graph that belongs to ``twitter_env``.
            reddit_env: Reddit environment, or ``None`` when not running.
            reddit_agent_graph: agent graph that belongs to ``reddit_env``.
        """
        self.simulation_dir = simulation_dir
        self.twitter_env = twitter_env
        self.twitter_agent_graph = twitter_agent_graph
        self.reddit_env = reddit_env
        self.reddit_agent_graph = reddit_agent_graph
        self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
        self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
        self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
        # ``(command_id, path)`` of the command most recently returned by
        # poll_command(); lets send_response() delete the file that was
        # actually read even when its name is not ``<command_id>.json``.
        self._last_polled = None
        os.makedirs(self.commands_dir, exist_ok=True)
        os.makedirs(self.responses_dir, exist_ok=True)

    def update_status(self, status: str):
        """Write the environment status file.

        Args:
            status: status label to record.
        """
        with open(self.status_file, 'w', encoding='utf-8') as f:
            json.dump({
                "status": status,
                "twitter_available": self.twitter_env is not None,
                "reddit_available": self.reddit_env is not None,
                "timestamp": datetime.now().isoformat()
            }, f, ensure_ascii=False, indent=2)

    def poll_command(self) -> Optional[Dict[str, Any]]:
        """Return the oldest readable pending command, or ``None``.

        Files that cannot be read or parsed yet (for example because they are
        still being written) are skipped and retried on a later poll. A
        command without a ``command_id`` is given the file name (without
        ``.json``) as its id so that a response can still be addressed.
        """
        try:
            filenames = os.listdir(self.commands_dir)
        except OSError:
            return None

        # Collect command files with their mtime so older commands run first.
        candidates = []
        for filename in filenames:
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.commands_dir, filename)
            try:
                candidates.append((filepath, os.path.getmtime(filepath)))
            except OSError:
                # The file vanished between listdir() and the stat call.
                continue
        candidates.sort(key=lambda item: item[1])

        for filepath, _ in candidates:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    command = json.load(f)
            except (ValueError, OSError):
                # ValueError covers JSONDecodeError and UnicodeDecodeError.
                continue
            if not isinstance(command, dict):
                # Valid JSON but not a command object; nothing to dispatch.
                continue
            if command.get("command_id") is None:
                command["command_id"] = os.path.splitext(os.path.basename(filepath))[0]
            self._last_polled = (command["command_id"], filepath)
            return command
        return None

    def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
        """Write the response file for a command and delete the command file.

        Args:
            command_id: identifier of the command being answered.
            status: outcome label, e.g. ``"completed"`` or ``"failed"``.
            result: optional result payload.
            error: optional error description.
        """
        response = {
            "command_id": command_id,
            "status": status,
            "result": result,
            "error": error,
            "timestamp": datetime.now().isoformat()
        }
        response_file = os.path.join(self.responses_dir, f"{command_id}.json")
        with open(response_file, 'w', encoding='utf-8') as f:
            json.dump(response, f, ensure_ascii=False, indent=2)

        # Remove the original command file once a response is recorded. Besides
        # the conventional ``<command_id>.json`` also remove the file this
        # command was actually polled from; otherwise a command stored under a
        # different name would be picked up again on every poll.
        stale_files = {os.path.join(self.commands_dir, f"{command_id}.json")}
        polled = getattr(self, "_last_polled", None)
        if polled is not None and polled[0] == command_id:
            stale_files.add(polled[1])
            self._last_polled = None
        for stale_file in stale_files:
            try:
                os.remove(stale_file)
            except OSError:
                pass

    def _get_env_and_graph(self, platform: str):
        """Return the environment and agent graph for the given platform.

        Args:
            platform: platform name ("twitter" or "reddit").

        Returns:
            Tuple ``(env, agent_graph, platform_name)`` or ``(None, None, None)``
            when the platform is unavailable.
        """
        if platform == "twitter" and self.twitter_env:
            return self.twitter_env, self.twitter_agent_graph, "twitter"
        if platform == "reddit" and self.reddit_env:
            return self.reddit_env, self.reddit_agent_graph, "reddit"
        return None, None, None

    async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]:
        """Run an Interview on a single platform.

        Args:
            agent_id: agent identifier.
            prompt: interview prompt.
            platform: "twitter" or "reddit".

        Returns:
            A dict with the interview result, or a dict containing an ``error`` key.
        """
        env, agent_graph, actual_platform = self._get_env_and_graph(platform)
        if not env or not agent_graph:
            return {"platform": platform, "error": f"{platform}平台不可用"}
        try:
            agent = agent_graph.get_agent(agent_id)
            interview_action = ManualAction(
                action_type=ActionType.INTERVIEW,
                action_args={"prompt": prompt}
            )
            await env.step({agent: interview_action})
            # The answer is read back from the platform database.
            result = self._get_interview_result(agent_id, actual_platform)
            result["platform"] = actual_platform
            return result
        except Exception as e:
            return {"platform": platform, "error": str(e)}

    async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool:
        """Handle a single-agent interview command.

        Args:
            command_id: command identifier.
            agent_id: agent identifier.
            prompt: interview prompt.
            platform: optional platform selector.
                - "twitter": interview on Twitter only.
                - "reddit": interview on Reddit only.
                - ``None``: interview on both platforms and return a merged result.

        Returns:
            ``True`` on success, ``False`` on failure.
        """
        # If a specific platform was requested, only interview on that platform.
        if platform in ("twitter", "reddit"):
            result = await self._interview_single_platform(agent_id, prompt, platform)
            if "error" in result:
                self.send_response(command_id, "failed", error=result["error"])
                print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}")
                return False
            self.send_response(command_id, "completed", result=result)
            print(f" Interview完成: agent_id={agent_id}, platform={platform}")
            return True

        # No platform specified: interview on both platforms simultaneously.
        if not self.twitter_env and not self.reddit_env:
            self.send_response(command_id, "failed", error="没有可用的模拟环境")
            return False

        results = {
            "agent_id": agent_id,
            "prompt": prompt,
            "platforms": {}
        }
        platforms_to_interview = []
        if self.twitter_env:
            platforms_to_interview.append("twitter")
        if self.reddit_env:
            platforms_to_interview.append("reddit")

        # Run the platform interviews in parallel. The coroutine never raises:
        # it reports failures through an ``error`` key instead.
        platform_results = await asyncio.gather(*(
            self._interview_single_platform(agent_id, prompt, name)
            for name in platforms_to_interview
        ))

        success_count = 0
        for platform_name, platform_result in zip(platforms_to_interview, platform_results):
            results["platforms"][platform_name] = platform_result
            if "error" not in platform_result:
                success_count += 1

        # One successful platform is enough for the command to count as completed.
        if success_count > 0:
            self.send_response(command_id, "completed", result=results)
            print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}")
            return True

        errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()]
        self.send_response(command_id, "failed", error="; ".join(errors))
        print(f" Interview失败: agent_id={agent_id}, 所有平台都失败")
        return False

    async def _batch_interview_on_platform(self, platform: str, interviews: List[Dict], results: Dict[str, Any]) -> None:
        """Run one platform's share of a batch interview and record the answers.

        Args:
            platform: "twitter" or "reddit".
            interviews: interview entries to run on this platform.
            results: output dict; one ``"<platform>_<agent_id>"`` entry is
                added per interview (mutated in place).
        """
        env, agent_graph, _ = self._get_env_and_graph(platform)
        if not interviews or not env:
            return
        label = platform.capitalize()  # "Twitter" / "Reddit" in log messages
        try:
            actions = {}
            for interview in interviews:
                agent_id = interview.get("agent_id")
                prompt = interview.get("prompt", "")
                try:
                    agent = agent_graph.get_agent(agent_id)
                    actions[agent] = ManualAction(
                        action_type=ActionType.INTERVIEW,
                        action_args={"prompt": prompt}
                    )
                except Exception as e:
                    print(f" 警告: 无法获取{label} Agent {agent_id}: {e}")
            if actions:
                await env.step(actions)
            for interview in interviews:
                agent_id = interview.get("agent_id")
                result = self._get_interview_result(agent_id, platform)
                result["platform"] = platform
                results[f"{platform}_{agent_id}"] = result
        except Exception as e:
            # A failure on one platform must not abort the other platform.
            print(f" {label}批量Interview失败: {e}")

    async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool:
        """Handle a batch-interview command.

        Args:
            command_id: command identifier.
            interviews: ``[{"agent_id": int, "prompt": str, "platform": str(optional)}, ...]``.
            platform: default platform (can be overridden per interview entry).
                - "twitter": interview on Twitter only.
                - "reddit": interview on Reddit only.
                - ``None``: interview every agent on both platforms.

        Returns:
            ``True`` when at least one result was recorded, ``False`` otherwise.
        """
        # Bucket interviews by target platform.
        twitter_interviews = []
        reddit_interviews = []
        both_platforms_interviews = []  # entries that need both platforms
        for interview in interviews:
            item_platform = interview.get("platform", platform)
            if item_platform == "twitter":
                twitter_interviews.append(interview)
            elif item_platform == "reddit":
                reddit_interviews.append(interview)
            else:
                # No platform specified: interview on both.
                both_platforms_interviews.append(interview)

        # Fan the both-platform entries out into the buckets of the platforms
        # that are actually running.
        if both_platforms_interviews:
            if self.twitter_env:
                twitter_interviews.extend(both_platforms_interviews)
            if self.reddit_env:
                reddit_interviews.extend(both_platforms_interviews)

        results = {}
        await self._batch_interview_on_platform("twitter", twitter_interviews, results)
        await self._batch_interview_on_platform("reddit", reddit_interviews, results)

        if results:
            self.send_response(command_id, "completed", result={
                "interviews_count": len(results),
                "results": results
            })
            print(f" 批量Interview完成: {len(results)} 个Agent")
            return True
        self.send_response(command_id, "failed", error="没有成功的采访")
        return False

    def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]:
        """Read the latest Interview result for an agent from the database.

        Args:
            agent_id: agent identifier (matched against ``trace.user_id``).
            platform: platform name; selects ``<platform>_simulation.db``.

        Returns:
            Dict with ``agent_id``, ``response`` and ``timestamp``. The last
            two stay ``None`` when no row is found or the lookup fails.
        """
        db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db")
        result = {
            "agent_id": agent_id,
            "response": None,
            "timestamp": None
        }
        if not os.path.exists(db_path):
            return result

        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            # Look up the most recent Interview row for this agent. ``rowid``
            # breaks ties between rows sharing the same ``created_at``.
            cursor.execute("""
                SELECT user_id, info, created_at
                FROM trace
                WHERE action = ? AND user_id = ?
                ORDER BY created_at DESC, rowid DESC
                LIMIT 1
            """, (ActionType.INTERVIEW.value, agent_id))
            row = cursor.fetchone()
            if row:
                _, info_json, created_at = row
                try:
                    info = json.loads(info_json) if info_json else {}
                    result["response"] = info.get("response", info)
                    result["timestamp"] = created_at
                except json.JSONDecodeError:
                    # Not JSON: hand the raw text back.
                    result["response"] = info_json
        except Exception as e:
            print(f" 读取Interview结果失败: {e}")
        finally:
            # Close on the error path as well so the connection is not leaked.
            if conn is not None:
                conn.close()
        return result

    async def process_commands(self) -> bool:
        """Process at most one pending command.

        Returns:
            ``True`` to keep running, ``False`` if the process should exit.
        """
        command = self.poll_command()
        if not command:
            return True

        command_id = command.get("command_id")
        command_type = command.get("command_type")
        args = command.get("args")
        if not isinstance(args, dict):
            # Missing, ``null`` or malformed args: fall back to the defaults.
            args = {}
        print(f"\n收到IPC命令: {command_type}, id={command_id}")

        if command_type == CommandType.INTERVIEW:
            await self.handle_interview(
                command_id,
                args.get("agent_id", 0),
                args.get("prompt", ""),
                args.get("platform")
            )
            return True
        if command_type == CommandType.BATCH_INTERVIEW:
            await self.handle_batch_interview(
                command_id,
                args.get("interviews", []),
                args.get("platform")
            )
            return True
        if command_type == CommandType.CLOSE_ENV:
            print("收到关闭环境命令")
            self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
            return False

        self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
        return True
def load_config(config_path: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON file at ``config_path`` and return the parsed content."""
    with open(config_path, 'r', encoding='utf-8') as config_file:
        raw_text = config_file.read()
    return json.loads(raw_text)
# Actions dropped when reading the trace table: they carry little analytical value.
FILTERED_ACTIONS = {'refresh', 'sign_up'}
# Database action name -> canonical name. Every canonical name is the
# upper-cased database name, so the table is derived from the list of names.
ACTION_TYPE_MAP = {
    db_name: db_name.upper()
    for db_name in (
        'create_post',
        'like_post',
        'dislike_post',
        'repost',
        'quote_post',
        'follow',
        'mute',
        'create_comment',
        'like_comment',
        'dislike_comment',
        'search_posts',
        'search_user',
        'trend',
        'do_nothing',
        'interview',
    )
}
def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]:
    """Map each configured agent id to its entity name.

    The entity name is what actions.jsonl shows instead of a placeholder such
    as ``Agent_0``.

    Args:
        config: contents of ``simulation_config.json``.

    Returns:
        ``agent_id -> entity_name``. Entries without an ``agent_id`` are
        skipped; entries without an ``entity_name`` get ``Agent_<id>``.
    """
    names_by_id = {}
    for entry in config.get("agent_configs", []):
        ident = entry.get("agent_id")
        if ident is None:
            continue
        names_by_id[ident] = entry.get("entity_name", f"Agent_{ident}")
    return names_by_id
def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """Fetch new action rows from the database and enrich them with context.

    Args:
        db_path: path to the database file.
        last_rowid: highest rowid processed previously. We track ``rowid``
            rather than ``created_at`` because the two platforms use different
            ``created_at`` formats.
        agent_names: ``agent_id -> agent_name`` mapping.

    Returns:
        Tuple ``(actions_list, new_last_rowid)``.
        - ``actions_list``: action records, each containing ``agent_id``,
          ``agent_name``, ``action_type``, and ``action_args`` (with context).
        - ``new_last_rowid``: the new highest rowid seen.
        A missing database yields ``([], last_rowid)``. Read errors are
        printed and whatever was collected before the error is returned.
    """
    actions = []
    new_last_rowid = last_rowid
    if not os.path.exists(db_path):
        return actions, new_last_rowid

    # Fields copied from the stored ``info`` JSON into the slimmed-down args.
    # ``user_id``/``target_id`` are kept because the MUTE branch of
    # _enrich_action_context() looks the muted user up through them.
    kept_keys = (
        'content', 'post_id', 'comment_id', 'quoted_id', 'new_post_id',
        'follow_id', 'query', 'like_id', 'dislike_id',
        'user_id', 'target_id',
    )

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()
        # Use ``rowid`` to track processed rows; it avoids the cross-platform
        # ``created_at`` format mismatch (Twitter stores integers, Reddit
        # stores datetime strings).
        cursor.execute("""
            SELECT rowid, user_id, action, info
            FROM trace
            WHERE rowid > ?
            ORDER BY rowid ASC
        """, (last_rowid,))
        for rowid, user_id, action, info_json in cursor.fetchall():
            # Advance the watermark first so filtered rows are not re-read.
            new_last_rowid = rowid
            if action in FILTERED_ACTIONS:
                continue
            try:
                action_args = json.loads(info_json) if info_json else {}
            except json.JSONDecodeError:
                action_args = {}
            if not isinstance(action_args, dict):
                # Valid JSON that is not an object carries no named fields.
                action_args = {}
            # Slim ``action_args`` down to the key fields. Content is kept in
            # full (no truncation).
            simplified_args = {
                key: action_args[key] for key in kept_keys if key in action_args
            }
            action_type = ACTION_TYPE_MAP.get(action, action.upper())
            # Enrich with context such as post content and author name.
            _enrich_action_context(cursor, action_type, simplified_args, agent_names)
            actions.append({
                'agent_id': user_id,
                'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                'action_type': action_type,
                'action_args': simplified_args,
            })
    except Exception as e:
        print(f"读取数据库动作失败: {e}")
    finally:
        # Close on the error path as well so the connection is not leaked.
        if conn is not None:
            conn.close()
    return actions, new_last_rowid
def _enrich_action_context(
cursor,
action_type: str,
action_args: Dict[str, Any],
agent_names: Dict[int, str]
) -> None:
"""Enrich an action's args with context such as post content and author name.
Args:
cursor: database cursor.
action_type: action type.
action_args: action args (mutated in place).
agent_names: ``agent_id -> agent_name`` mapping.
"""
try:
# Like/dislike post: include the post content and author name.
if action_type in ('LIKE_POST', 'DISLIKE_POST'):
post_id = action_args.get('post_id')
if post_id:
post_info = _get_post_info(cursor, post_id, agent_names)
if post_info:
action_args['post_content'] = post_info.get('content', '')
action_args['post_author_name'] = post_info.get('author_name', '')
# Repost: include the original post content and author name.
elif action_type == 'REPOST':
new_post_id = action_args.get('new_post_id')
if new_post_id:
# On a repost row, ``original_post_id`` points at the original post.
cursor.execute("""
SELECT original_post_id FROM post WHERE post_id = ?
""", (new_post_id,))
row = cursor.fetchone()
if row and row[0]:
original_post_id = row[0]
original_info = _get_post_info(cursor, original_post_id, agent_names)
if original_info:
action_args['original_content'] = original_info.get('content', '')
action_args['original_author_name'] = original_info.get('author_name', '')
# Quote post: include the original post content, author name, and quote comment.
elif action_type == 'QUOTE_POST':
quoted_id = action_args.get('quoted_id')
new_post_id = action_args.get('new_post_id')
if quoted_id:
original_info = _get_post_info(cursor, quoted_id, agent_names)
if original_info:
action_args['original_content'] = original_info.get('content', '')
action_args['original_author_name'] = original_info.get('author_name', '')
# Read the quote comment (``quote_content``).
if new_post_id:
cursor.execute("""
SELECT quote_content FROM post WHERE post_id = ?
""", (new_post_id,))
row = cursor.fetchone()
if row and row[0]:
action_args['quote_content'] = row[0]
# Follow: include the followee's display name.
elif action_type == 'FOLLOW':
follow_id = action_args.get('follow_id')
if follow_id:
# Look up ``followee_id`` from the ``follow`` table.
cursor.execute("""
SELECT followee_id FROM follow WHERE follow_id = ?
""", (follow_id,))
row = cursor.fetchone()
if row:
followee_id = row[0]
target_name = _get_user_name(cursor, followee_id, agent_names)
if target_name:
action_args['target_user_name'] = target_name
# Mute: include the muted user's display name.
elif action_type == 'MUTE':
# Read ``user_id`` or ``target_id`` from action_args.
target_id = action_args.get('user_id') or action_args.get('target_id')
if target_id:
target_name = _get_user_name(cursor, target_id, agent_names)
if target_name:
action_args['target_user_name'] = target_name
# Like/dislike comment: include the comment content and author name.
elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'):
comment_id = action_args.get('comment_id')
if comment_id:
comment_info = _get_comment_info(cursor, comment_id, agent_names)
if comment_info:
action_args['comment_content'] = comment_info.get('content', '')
action_args['comment_author_name'] = comment_info.get('author_name', '')
# Create comment: include the parent post's content and author name.
elif action_type == 'CREATE_COMMENT':
post_id = action_args.get('post_id')
if post_id:
post_info = _get_post_info(cursor, post_id, agent_names)
if post_info:
action_args['post_content'] = post_info.get('content', '')
action_args['post_author_name'] = post_info.get('author_name', '')
except Exception as e:
# Failing to enrich context must not break the main flow.
print(f"补充动作上下文失败: {e}")
def _get_post_info(
cursor,
post_id: int,
agent_names: Dict[int, str]
) -> Optional[Dict[str, str]]:
"""Look up post info.
Args:
cursor: database cursor.
post_id: post identifier.
agent_names: ``agent_id -> agent_name`` mapping.
Returns:
Dict with ``content`` and ``author_name``, or ``None`` when not found.
"""
try:
cursor.execute("""
SELECT p.content, p.user_id, u.agent_id
FROM post p
LEFT JOIN user u ON p.user_id = u.user_id
WHERE p.post_id = ?
""", (post_id,))
row = cursor.fetchone()
if row:
content = row[0] or ''
user_id = row[1]
agent_id = row[2]
# Prefer the entity_name supplied via agent_names.
author_name = ''
if agent_id is not None and agent_id in agent_names:
author_name = agent_names[agent_id]
elif user_id:
# Fall back to the user table.
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
user_row = cursor.fetchone()
if user_row:
author_name = user_row[0] or user_row[1] or ''
return {'content': content, 'author_name': author_name}
except Exception:
pass
return None
def _get_user_name(
cursor,
user_id: int,
agent_names: Dict[int, str]
) -> Optional[str]:
"""Look up a user's display name.
Args:
cursor: database cursor.
user_id: user identifier.
agent_names: ``agent_id -> agent_name`` mapping.
Returns:
Display name, or ``None`` when the user cannot be found.
"""
try:
cursor.execute("""
SELECT agent_id, name, user_name FROM user WHERE user_id = ?
""", (user_id,))
row = cursor.fetchone()
if row:
agent_id = row[0]
name = row[1]
user_name = row[2]
# Prefer the entity_name supplied via agent_names.
if agent_id is not None and agent_id in agent_names:
return agent_names[agent_id]
return name or user_name or ''
except Exception:
pass
return None
def _get_comment_info(
cursor,
comment_id: int,
agent_names: Dict[int, str]
) -> Optional[Dict[str, str]]:
"""Look up comment info.
Args:
cursor: database cursor.
comment_id: comment identifier.
agent_names: ``agent_id -> agent_name`` mapping.
Returns:
Dict with ``content`` and ``author_name``, or ``None`` when not found.
"""
try:
cursor.execute("""
SELECT c.content, c.user_id, u.agent_id
FROM comment c
LEFT JOIN user u ON c.user_id = u.user_id
WHERE c.comment_id = ?
""", (comment_id,))
row = cursor.fetchone()
if row:
content = row[0] or ''
user_id = row[1]
agent_id = row[2]
# Prefer the entity_name supplied via agent_names.
author_name = ''
if agent_id is not None and agent_id in agent_names:
author_name = agent_names[agent_id]
elif user_id:
# Fall back to the user table.
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
user_row = cursor.fetchone()
if user_row:
author_name = user_row[0] or user_row[1] or ''
return {'content': content, 'author_name': author_name}
except Exception:
pass
return None
def create_model(config: Dict[str, Any], use_boost: bool = False):
    """Build the LLM model object used by a platform simulation.

    Two credential sets can be configured through environment variables:
    - default: ``LLM_API_KEY``, ``LLM_BASE_URL``, ``LLM_MODEL_NAME``.
    - boost (optional): ``LLM_BOOST_API_KEY``, ``LLM_BOOST_BASE_URL``,
      ``LLM_BOOST_MODEL_NAME``.
    With a boost set configured, the two platforms can target different API
    providers, which raises overall concurrency.

    The chosen key and base URL are exported as ``OPENAI_API_KEY`` and
    ``OPENAI_API_BASE_URL`` for camel-ai.
    NOTE(review): these are process-wide; an empty base URL leaves whatever
    an earlier call exported in place. Confirm this is intended.

    Args:
        config: simulation config dict; ``llm_model`` is the fallback model name.
        use_boost: whether to use the boost LLM config when available.

    Raises:
        ValueError: when no API key is available.
    """
    environ = os.environ
    boost_key = environ.get("LLM_BOOST_API_KEY", "")

    # The boost set is only used when requested AND a boost key exists.
    if use_boost and bool(boost_key):
        llm_api_key = boost_key
        llm_base_url = environ.get("LLM_BOOST_BASE_URL", "")
        # A boost set without its own model name reuses the default one.
        llm_model = environ.get("LLM_BOOST_MODEL_NAME", "") or environ.get("LLM_MODEL_NAME", "")
        config_label = "[加速LLM]"
    else:
        llm_api_key = environ.get("LLM_API_KEY", "")
        llm_base_url = environ.get("LLM_BASE_URL", "")
        llm_model = environ.get("LLM_MODEL_NAME", "")
        config_label = "[通用LLM]"

    # Fall back to the model name in the config when .env does not provide one.
    llm_model = llm_model or config.get("llm_model", "gpt-4o-mini")

    # Populate the env vars camel-ai expects.
    if llm_api_key:
        environ["OPENAI_API_KEY"] = llm_api_key
    if not environ.get("OPENAI_API_KEY"):
        raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")
    if llm_base_url:
        environ["OPENAI_API_BASE_URL"] = llm_base_url

    # Only the first 40 characters of the URL are shown.
    shown_url = llm_base_url[:40] if llm_base_url else '默认'
    print(f"{config_label} model={llm_model}, base_url={shown_url}...")

    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=llm_model,
    )
def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """Pick the agents that act in this round.

    A target head-count is drawn from the configured per-hour range and
    scaled by the peak/off-peak multiplier for ``current_hour``. Agents whose
    ``active_hours`` include the hour pass a random check against their
    ``activity_level``; the target number is then sampled from those.

    Args:
        env: simulation environment; agents are resolved via ``env.agent_graph``.
        config: simulation config (``time_config`` and ``agent_configs`` are read).
        current_hour: simulated hour of day.
        round_num: round number (not used by the selection).

    Returns:
        List of ``(agent_id, agent)`` tuples. Ids the agent graph cannot
        resolve are skipped.
    """
    timing = config.get("time_config", {})
    per_hour_min = timing.get("agents_per_hour_min", 5)
    per_hour_max = timing.get("agents_per_hour_max", 20)

    if current_hour in timing.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22]):
        scale = timing.get("peak_activity_multiplier", 1.5)
    elif current_hour in timing.get("off_peak_hours", [0, 1, 2, 3, 4, 5]):
        scale = timing.get("off_peak_activity_multiplier", 0.3)
    else:
        scale = 1.0
    quota = int(random.uniform(per_hour_min, per_hour_max) * scale)

    # The activity roll is only made for agents active at this hour
    # (short-circuit ``and``), one roll per such agent in config order.
    eligible_ids = [
        profile.get("agent_id", 0)
        for profile in config.get("agent_configs", [])
        if current_hour in profile.get("active_hours", list(range(8, 23)))
        and random.random() < profile.get("activity_level", 0.5)
    ]

    chosen_ids = []
    if eligible_ids:
        chosen_ids = random.sample(eligible_ids, min(quota, len(eligible_ids)))

    picked = []
    for chosen_id in chosen_ids:
        try:
            member = env.agent_graph.get_agent(chosen_id)
        except Exception:
            continue
        picked.append((chosen_id, member))
    return picked
class PlatformSimulation:
    """Holder for what a platform simulation run hands back to its caller."""

    def __init__(self):
        # Nothing has been created or counted yet: both handles start as
        # None and the action counter at zero.
        self.total_actions: int = 0
        self.agent_graph = None
        self.env = None
async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[PlatformActionLogger] = None,
    main_logger: Optional[SimulationLogManager] = None,
    max_rounds: Optional[int] = None
) -> PlatformSimulation:
    """Run the Twitter simulation and leave its environment open.

    Builds the agent graph from ``twitter_profiles.csv``, recreates
    ``twitter_simulation.db``, publishes the configured initial posts as
    round 0 and then steps the environment once per simulated round. The
    actions executed in each round are read back from the database and
    forwarded to ``action_logger``.

    Args:
        config: simulation config. ``event_config.initial_posts`` and
            ``time_config`` (``total_simulation_hours``,
            ``minutes_per_round``) are read here; the dict is also passed to
            the helpers that build the model, resolve agent names and pick
            the active agents.
        simulation_dir: directory that holds the profile file and receives
            the database file.
        action_logger: optional per-platform action logger.
        main_logger: optional main log manager; progress messages are
            printed to stdout as well.
        max_rounds: optional positive cap on the number of rounds. ``None``
            or a non-positive value leaves the configured round count as is.

    Returns:
        PlatformSimulation with ``env``, ``agent_graph`` and
        ``total_actions`` filled in. If the profile file is missing, the
        result is returned empty (``env`` stays ``None``).
    """
    result = PlatformSimulation()

    def log_info(msg):
        if main_logger:
            main_logger.info(f"[Twitter] {msg}")
        print(f"[Twitter] {msg}")

    log_info("初始化...")

    # Twitter uses the default LLM config.
    model = create_model(config, use_boost=False)

    # OASIS Twitter expects a CSV profile file.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        log_info(f"错误: Profile文件不存在: {profile_path}")
        return result

    result.agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )

    # Prefer the names from the config; for agents the config does not
    # list, fall back to the agent's own ``name`` attribute or ``Agent_<id>``.
    agent_names = get_agent_names_from_config(config)
    for agent_id, agent in result.agent_graph.get_agents():
        if agent_id not in agent_names:
            agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Always start from a fresh database file.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    result.env = oasis.make(
        agent_graph=result.agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
        semaphore=30,  # cap concurrent LLM requests to avoid overloading the API
    )
    await result.env.reset()
    log_info("环境已启动")

    if action_logger:
        action_logger.log_simulation_start(config)

    total_actions = 0
    last_rowid = 0  # last processed db row; using rowid avoids created_at format differences

    # ---- Round 0: initial events -------------------------------------
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if action_logger:
        action_logger.log_round_start(0, 0)  # round 0, simulated_hour 0

    initial_action_count = 0
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = result.env.agent_graph.get_agent(agent_id)
                manual_action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # Several initial posts may share one poster. Collect them
                # in a list (same handling as the Reddit simulation) instead
                # of letting the later post overwrite the earlier one.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(manual_action)
                else:
                    initial_actions[agent] = manual_action

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content}
                    )
                # Counted whether or not an action logger is attached.
                total_actions += 1
                initial_action_count += 1
            except Exception as exc:
                # Best effort: one bad post must not abort the run, but the
                # reason it was dropped should be visible in the log.
                log_info(f"初始帖子发布失败 (agent_id={agent_id}): {exc}")

        if initial_actions:
            await result.env.step(initial_actions)
            # Report posts, not distinct posters.
            log_info(f"已发布 {initial_action_count} 条初始帖子")

    if action_logger:
        action_logger.log_round_end(0, initial_action_count)

    # ---- Main simulation loop ----------------------------------------
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    # Truncate when a max round count was supplied.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    start_time = datetime.now()

    for round_num in range(total_rounds):
        # Bail out if a shutdown signal was received.
        if _shutdown_event and _shutdown_event.is_set():
            if main_logger:
                main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟")
            break

        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            result.env, config, simulated_hour, round_num
        )

        # Always log round-start, even when no agents are active.
        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour)

        if not active_agents:
            # Still emit round-end (with actions_count=0) so the log stays consistent.
            if action_logger:
                action_logger.log_round_end(round_num + 1, 0)
            continue

        actions = {agent: LLMAction() for _, agent in active_agents}
        await result.env.step(actions)

        # Pull the actually-executed actions from the database and log them.
        actual_actions, last_rowid = fetch_new_actions_from_db(
            db_path, last_rowid, agent_names
        )

        round_action_count = 0
        for action_data in actual_actions:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    agent_id=action_data['agent_id'],
                    agent_name=action_data['agent_name'],
                    action_type=action_data['action_type'],
                    action_args=action_data['action_args']
                )
            total_actions += 1
            round_action_count += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, round_action_count)

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    # Note: do NOT close the env here; we keep it alive for Interview commands.
    if action_logger:
        action_logger.log_simulation_end(total_rounds, total_actions)

    result.total_actions = total_actions
    elapsed = (datetime.now() - start_time).total_seconds()
    log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
    return result
async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[PlatformActionLogger] = None,
    main_logger: Optional[SimulationLogManager] = None,
    max_rounds: Optional[int] = None
) -> PlatformSimulation:
    """Run the Reddit simulation and leave its environment open.

    Builds the agent graph from ``reddit_profiles.json``, recreates
    ``reddit_simulation.db``, publishes the configured initial posts as
    round 0 and then steps the environment once per simulated round. The
    actions executed in each round are read back from the database and
    forwarded to ``action_logger``.

    Args:
        config: simulation config. ``event_config.initial_posts`` and
            ``time_config`` (``total_simulation_hours``,
            ``minutes_per_round``) are read here; the dict is also passed to
            the helpers that build the model, resolve agent names and pick
            the active agents.
        simulation_dir: directory that holds the profile file and receives
            the database file.
        action_logger: optional per-platform action logger.
        main_logger: optional main log manager; progress messages are
            printed to stdout as well.
        max_rounds: optional positive cap on the number of rounds. ``None``
            or a non-positive value leaves the configured round count as is.

    Returns:
        PlatformSimulation with ``env``, ``agent_graph`` and
        ``total_actions`` filled in. If the profile file is missing, the
        result is returned empty (``env`` stays ``None``).
    """
    result = PlatformSimulation()

    def log_info(msg):
        if main_logger:
            main_logger.info(f"[Reddit] {msg}")
        print(f"[Reddit] {msg}")

    log_info("初始化...")

    # Reddit asks for the boost LLM config (``use_boost=True``).
    model = create_model(config, use_boost=True)

    profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
    if not os.path.exists(profile_path):
        log_info(f"错误: Profile文件不存在: {profile_path}")
        return result

    result.agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=REDDIT_ACTIONS,
    )

    # Prefer the names from the config; for agents the config does not
    # list, fall back to the agent's own ``name`` attribute or ``Agent_<id>``.
    agent_names = get_agent_names_from_config(config)
    for agent_id, agent in result.agent_graph.get_agents():
        if agent_id not in agent_names:
            agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Always start from a fresh database file.
    db_path = os.path.join(simulation_dir, "reddit_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    result.env = oasis.make(
        agent_graph=result.agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
        semaphore=30,  # cap concurrent LLM requests to avoid overloading the API
    )
    await result.env.reset()
    log_info("环境已启动")

    if action_logger:
        action_logger.log_simulation_start(config)

    total_actions = 0
    last_rowid = 0  # last processed db row; using rowid avoids created_at format differences

    # ---- Round 0: initial events -------------------------------------
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if action_logger:
        action_logger.log_round_start(0, 0)  # round 0, simulated_hour 0

    initial_action_count = 0
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = result.env.agent_graph.get_agent(agent_id)
                manual_action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # Several initial posts may share one poster: promote the
                # entry to a list so none of them is lost.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(manual_action)
                else:
                    initial_actions[agent] = manual_action

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content}
                    )
                # Counted whether or not an action logger is attached.
                total_actions += 1
                initial_action_count += 1
            except Exception as exc:
                # Best effort: one bad post must not abort the run, but the
                # reason it was dropped should be visible in the log.
                log_info(f"初始帖子发布失败 (agent_id={agent_id}): {exc}")

        if initial_actions:
            await result.env.step(initial_actions)
            # ``len(initial_actions)`` counts distinct posters; report the
            # number of posts actually queued instead.
            log_info(f"已发布 {initial_action_count} 条初始帖子")

    if action_logger:
        action_logger.log_round_end(0, initial_action_count)

    # ---- Main simulation loop ----------------------------------------
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    # Truncate when a max round count was supplied.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    start_time = datetime.now()

    for round_num in range(total_rounds):
        # Bail out if a shutdown signal was received.
        if _shutdown_event and _shutdown_event.is_set():
            if main_logger:
                main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟")
            break

        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            result.env, config, simulated_hour, round_num
        )

        # Always log round-start, even when no agents are active.
        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour)

        if not active_agents:
            # Still emit round-end (with actions_count=0) so the log stays consistent.
            if action_logger:
                action_logger.log_round_end(round_num + 1, 0)
            continue

        actions = {agent: LLMAction() for _, agent in active_agents}
        await result.env.step(actions)

        # Pull the actually-executed actions from the database and log them.
        actual_actions, last_rowid = fetch_new_actions_from_db(
            db_path, last_rowid, agent_names
        )

        round_action_count = 0
        for action_data in actual_actions:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    agent_id=action_data['agent_id'],
                    agent_name=action_data['agent_name'],
                    action_type=action_data['action_type'],
                    action_args=action_data['action_args']
                )
            total_actions += 1
            round_action_count += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, round_action_count)

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    # Note: do NOT close the env here; we keep it alive for Interview commands.
    if action_logger:
        action_logger.log_simulation_end(total_rounds, total_actions)

    result.total_actions = total_actions
    elapsed = (datetime.now() - start_time).total_seconds()
    log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")
    return result
async def main():
    """Entry coroutine: parse the CLI, run the simulation(s), serve commands, clean up.

    Steps, in order:
    1. Parse ``--config``, ``--twitter-only``, ``--reddit-only``,
       ``--max-rounds`` and ``--no-wait``.
    2. Create the global shutdown event and load the config (the process
       exits with status 1 when the config file does not exist).
    3. Run the Twitter and/or Reddit simulation; without a ``*-only`` flag
       both run concurrently via ``asyncio.gather``.
    4. Unless ``--no-wait`` was given, poll for IPC commands until the
       handler asks to stop or the shutdown event is set.
    5. Close every environment that was created.
    """
    parser = argparse.ArgumentParser(description='OASIS双平台并行模拟')
    parser.add_argument(
        '--config',
        type=str,
        required=True,
        help='配置文件路径 (simulation_config.json)'
    )
    parser.add_argument(
        '--twitter-only',
        action='store_true',
        help='只运行Twitter模拟'
    )
    parser.add_argument(
        '--reddit-only',
        action='store_true',
        help='只运行Reddit模拟'
    )
    parser.add_argument(
        '--max-rounds',
        type=int,
        default=None,
        help='最大模拟轮数(可选,用于截断过长的模拟)'
    )
    parser.add_argument(
        '--no-wait',
        action='store_true',
        default=False,
        help='模拟完成后立即关闭环境,不进入等待命令模式'
    )
    args = parser.parse_args()

    # Create the shutdown event at the start of main() so the whole program
    # can respond to exit signals.
    global _shutdown_event
    _shutdown_event = asyncio.Event()

    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    config = load_config(args.config)
    simulation_dir = os.path.dirname(args.config) or "."
    wait_for_commands = not args.no_wait

    # Initialize logging (disable OASIS logs, clean up stale files).
    init_logging_for_simulation(simulation_dir)

    # Create the log manager and the per-platform action loggers.
    log_manager = SimulationLogManager(simulation_dir)
    twitter_logger = log_manager.get_twitter_logger()
    reddit_logger = log_manager.get_reddit_logger()

    log_manager.info("=" * 60)
    log_manager.info("OASIS 双平台并行模拟")
    log_manager.info(f"配置文件: {args.config}")
    log_manager.info(f"模拟ID: {config.get('simulation_id', 'unknown')}")
    log_manager.info(f"等待命令模式: {'启用' if wait_for_commands else '禁用'}")
    log_manager.info("=" * 60)

    time_config = config.get("time_config", {})
    total_hours = time_config.get('total_simulation_hours', 72)
    minutes_per_round = time_config.get('minutes_per_round', 30)
    config_total_rounds = (total_hours * 60) // minutes_per_round

    log_manager.info("模拟参数:")
    log_manager.info(f" - 总模拟时长: {total_hours}小时")
    log_manager.info(f" - 每轮时间: {minutes_per_round}分钟")
    log_manager.info(f" - 配置总轮数: {config_total_rounds}")
    # The simulation functions only honour a positive cap, so only announce
    # a limit under the same condition (a negative value is ignored there).
    if args.max_rounds is not None and args.max_rounds > 0:
        log_manager.info(f" - 最大轮数限制: {args.max_rounds}")
        if args.max_rounds < config_total_rounds:
            log_manager.info(f" - 实际执行轮数: {args.max_rounds} (已截断)")
    log_manager.info(f" - Agent数量: {len(config.get('agent_configs', []))}")
    log_manager.info("日志结构:")
    log_manager.info(" - 主日志: simulation.log")
    log_manager.info(" - Twitter动作: twitter/actions.jsonl")
    log_manager.info(" - Reddit动作: reddit/actions.jsonl")
    log_manager.info("=" * 60)

    start_time = datetime.now()

    # Holds the result for each platform simulation.
    twitter_result: Optional[PlatformSimulation] = None
    reddit_result: Optional[PlatformSimulation] = None

    if args.twitter_only:
        twitter_result = await run_twitter_simulation(
            config, simulation_dir, twitter_logger, log_manager, args.max_rounds
        )
    elif args.reddit_only:
        reddit_result = await run_reddit_simulation(
            config, simulation_dir, reddit_logger, log_manager, args.max_rounds
        )
    else:
        # Run both platforms in parallel; each platform uses its own logger.
        twitter_result, reddit_result = await asyncio.gather(
            run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
            run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
        )

    total_elapsed = (datetime.now() - start_time).total_seconds()
    log_manager.info("=" * 60)
    log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}")

    # Enter wait-for-command mode if requested.
    if wait_for_commands:
        log_manager.info("")
        log_manager.info("=" * 60)
        log_manager.info("进入等待命令模式 - 环境保持运行")
        log_manager.info("支持的命令: interview, batch_interview, close_env")
        log_manager.info("=" * 60)

        # Create the IPC handler.
        ipc_handler = ParallelIPCHandler(
            simulation_dir=simulation_dir,
            twitter_env=twitter_result.env if twitter_result else None,
            twitter_agent_graph=twitter_result.agent_graph if twitter_result else None,
            reddit_env=reddit_result.env if reddit_result else None,
            reddit_agent_graph=reddit_result.agent_graph if reddit_result else None
        )
        ipc_handler.update_status("alive")

        # Command-wait loop (driven by the global ``_shutdown_event``).
        try:
            while not _shutdown_event.is_set():
                should_continue = await ipc_handler.process_commands()
                if not should_continue:
                    break
                # Use ``wait_for`` instead of ``sleep`` so the loop reacts to
                # the shutdown event within the 0.5 s poll interval.
                try:
                    await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
                    break  # shutdown signal received
                except asyncio.TimeoutError:
                    pass  # timed out, continue looping
        except KeyboardInterrupt:
            print("\n收到中断信号")
        except asyncio.CancelledError:
            print("\n任务被取消")
        except Exception as e:
            print(f"\n命令处理出错: {e}")

        log_manager.info("\n关闭环境...")
        ipc_handler.update_status("stopped")

    # Close the environments. A failure while closing one platform must not
    # leave the other one open, so every close is attempted and the first
    # error is re-raised once both have been tried.
    close_error: Optional[Exception] = None
    for label, platform_result in (("Twitter", twitter_result), ("Reddit", reddit_result)):
        if not (platform_result and platform_result.env):
            continue
        try:
            await platform_result.env.close()
        except Exception as exc:
            log_manager.info(f"[{label}] 关闭环境失败: {exc}")
            if close_error is None:
                close_error = exc
        else:
            log_manager.info(f"[{label}] 环境已关闭")
    if close_error is not None:
        raise close_error

    log_manager.info("=" * 60)
    log_manager.info("全部完成!")
    log_manager.info("日志文件:")
    log_manager.info(f" - {os.path.join(simulation_dir, 'simulation.log')}")
    log_manager.info(f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}")
    log_manager.info(f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}")
    log_manager.info("=" * 60)
def setup_signal_handlers(loop=None):
    """Register one shared handler for SIGTERM and SIGINT.

    The first signal only requests a graceful stop: it flips the module-level
    ``_cleanup_done`` flag and, when a shutdown event exists, sets it so the
    asyncio code can wind down and release its resources. Any further signal
    terminates the process immediately with exit status 1.

    Args:
        loop: accepted for interface compatibility; not used here.
    """
    def signal_handler(signum, frame):
        global _cleanup_done
        if signum == signal.SIGTERM:
            sig_name = "SIGTERM"
        else:
            sig_name = "SIGINT"
        print(f"\n收到 {sig_name} 信号,正在退出...")

        # A repeated signal means the graceful path was already requested:
        # stop waiting and force the exit.
        if _cleanup_done:
            print("强制退出...")
            sys.exit(1)

        # First signal: no sys.exit() here, so the asyncio loop can finish
        # cleanly once it notices the event.
        _cleanup_done = True
        if _shutdown_event:
            _shutdown_event.set()

    for handled_signal in (signal.SIGTERM, signal.SIGINT):
        signal.signal(handled_signal, signal_handler)
if __name__ == "__main__":
    # Script entry point: install the signal handlers, run the async main()
    # and always perform the final cleanup before the process ends.
    setup_signal_handlers()
    exit_status = None
    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        print("\n程序被中断")
    except SystemExit as exc:
        # Remember the requested status instead of discarding it; otherwise
        # a sys.exit(1) (missing config file, forced exit on a repeated
        # signal) would end the process with status 0.
        exit_status = exc.code
    finally:
        # Clean up the multiprocessing resource tracker to avoid exit warnings.
        # This relies on a private API, so any failure is ignored on purpose.
        try:
            from multiprocessing import resource_tracker
            resource_tracker._resource_tracker._stop()
        except Exception:
            pass
        print("模拟进程已退出")
    # Re-issue a non-zero/non-empty exit status once cleanup has finished.
    if exit_status:
        sys.exit(exit_status)