1693 lines
62 KiB
Python
1693 lines
62 KiB
Python
"""OASIS dual-platform parallel simulation preset script.
|
|
|
|
Runs Twitter and Reddit simulations simultaneously, reading the same config file.
|
|
|
|
Features:
|
|
- Dual-platform (Twitter + Reddit) parallel simulation
|
|
- Keeps environments alive after the simulation finishes and enters wait-for-command mode
|
|
- Receives Interview commands via IPC
|
|
- Supports single-agent and batch interviews
|
|
- Supports a remote close-environment command
|
|
|
|
Usage:
|
|
python run_parallel_simulation.py --config simulation_config.json
|
|
python run_parallel_simulation.py --config simulation_config.json --no-wait # close immediately when done
|
|
python run_parallel_simulation.py --config simulation_config.json --twitter-only
|
|
python run_parallel_simulation.py --config simulation_config.json --reddit-only
|
|
|
|
Log layout:
|
|
sim_xxx/
|
|
├── twitter/
|
|
│ └── actions.jsonl # Twitter platform action log
|
|
├── reddit/
|
|
│ └── actions.jsonl # Reddit platform action log
|
|
├── simulation.log # main simulation process log
|
|
└── run_state.json # run state (used by API queries)
|
|
"""
|
|
|
|
# ============================================================
|
|
# Fix the Windows encoding issue by forcing UTF-8 before any import.
|
|
# This works around the OASIS third-party library opening files without
|
|
# specifying an encoding.
|
|
# ============================================================
|
|
import sys
|
|
import os
|
|
|
|
if sys.platform == 'win32':
    # Request UTF-8 mode via the environment. setdefault() keeps any value
    # the user already exported.
    os.environ.setdefault('PYTHONUTF8', '1')
    os.environ.setdefault('PYTHONIOENCODING', 'utf-8')

    # Switch the console streams to UTF-8 so non-ASCII output is not garbled.
    # Streams that cannot be reconfigured are left untouched.
    for _console_stream in (sys.stdout, sys.stderr):
        if hasattr(_console_stream, 'reconfigure'):
            _console_stream.reconfigure(encoding='utf-8', errors='replace')
    del _console_stream

    # The environment variables above are only honoured at interpreter
    # startup, so the built-in open() is additionally wrapped to default
    # text-mode files to UTF-8.
    import builtins
    _original_open = builtins.open

    def _utf8_open(file, mode='r', buffering=-1, encoding=None, errors=None,
                   newline=None, closefd=True, opener=None):
        """Drop-in replacement for ``open()`` that defaults text files to UTF-8.

        Exists for third-party code (such as OASIS) that opens files without
        passing ``encoding``. Binary-mode calls and calls that name an
        encoding are forwarded unchanged.
        """
        needs_default_encoding = encoding is None and 'b' not in mode
        if needs_default_encoding:
            encoding = 'utf-8'
        return _original_open(file, mode, buffering, encoding, errors,
                              newline, closefd, opener)

    builtins.open = _utf8_open
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import multiprocessing
|
|
import random
|
|
import signal
|
|
import sqlite3
|
|
import warnings
|
|
from datetime import datetime
|
|
from typing import Dict, Any, List, Optional, Tuple
|
|
|
|
|
|
# Globals used by the signal handlers.
|
|
_shutdown_event = None
|
|
_cleanup_done = False
|
|
|
|
# Add the backend directory to sys.path. The script always lives in
|
|
# backend/scripts/.
|
|
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
|
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
|
|
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
|
|
sys.path.insert(0, _scripts_dir)
|
|
sys.path.insert(0, _backend_dir)
|
|
|
|
# Load the .env from the project root (contains LLM_API_KEY etc.).
|
|
from dotenv import load_dotenv
|
|
_env_file = os.path.join(_project_root, '.env')
|
|
if os.path.exists(_env_file):
|
|
load_dotenv(_env_file)
|
|
print(f"已加载环境配置: {_env_file}")
|
|
else:
|
|
# Fall back to backend/.env.
|
|
_backend_env = os.path.join(_backend_dir, '.env')
|
|
if os.path.exists(_backend_env):
|
|
load_dotenv(_backend_env)
|
|
print(f"已加载环境配置: {_backend_env}")
|
|
|
|
|
|
class MaxTokensWarningFilter(logging.Filter):
    """Logging filter that drops camel-ai's max_tokens warning.

    ``max_tokens`` is deliberately left unset so the model chooses its own
    limit, which makes the resulting warning pure noise.

    NOTE(review): a filter attached to a logger is only consulted for records
    logged through that logger, not for records propagated from child
    loggers -- confirm the installation point actually sees the camel records.
    """

    def filter(self, record):
        """Return ``False`` for the max_tokens warning and ``True`` for anything else."""
        message = record.getMessage()
        is_max_tokens_noise = "max_tokens" in message and "Invalid or missing" in message
        return not is_max_tokens_noise
|
|
|
|
|
|
# Install the filter at import time so it is active before any camel code runs.
|
|
logging.getLogger().addFilter(MaxTokensWarningFilter())
|
|
|
|
|
|
def disable_oasis_logging():
    """Silence the chatty loggers used by the OASIS library.

    OASIS logs every agent observation and action; this script records
    actions through its own action logger instead. Each listed logger is
    raised to ``CRITICAL``, stripped of its handlers and detached from its
    parent so nothing propagates upwards.
    """
    noisy_logger_names = (
        "social.agent",
        "social.twitter",
        "social.rec",
        "oasis.env",
        "table",
    )

    for name in noisy_logger_names:
        noisy = logging.getLogger(name)
        noisy.setLevel(logging.CRITICAL)  # only severe errors survive
        noisy.handlers.clear()
        noisy.propagate = False
|
|
|
|
|
|
def init_logging_for_simulation(simulation_dir: str):
    """Prepare logging for one simulation run.

    Silences the OASIS loggers and removes a ``log`` directory left inside
    ``simulation_dir`` by an earlier run, if there is one.

    Args:
        simulation_dir: path to the simulation directory.
    """
    import shutil

    disable_oasis_logging()

    stale_log_dir = os.path.join(simulation_dir, "log")
    if not os.path.exists(stale_log_dir):
        return
    # Best effort: removal errors are ignored on purpose.
    shutil.rmtree(stale_log_dir, ignore_errors=True)
|
|
|
|
|
|
from action_logger import SimulationLogManager, PlatformActionLogger
|
|
|
|
try:
|
|
from camel.models import ModelFactory
|
|
from camel.types import ModelPlatformType
|
|
import oasis
|
|
from oasis import (
|
|
ActionType,
|
|
LLMAction,
|
|
ManualAction,
|
|
generate_twitter_agent_graph,
|
|
generate_reddit_agent_graph
|
|
)
|
|
except ImportError as e:
|
|
print(f"错误: 缺少依赖 {e}")
|
|
print("请先安装: pip install oasis-ai camel-ai")
|
|
sys.exit(1)
|
|
|
|
|
|
# Twitter actions available to agents. INTERVIEW is excluded because it can only
|
|
# be triggered manually via ManualAction.
|
|
TWITTER_ACTIONS = [
|
|
ActionType.CREATE_POST,
|
|
ActionType.LIKE_POST,
|
|
ActionType.REPOST,
|
|
ActionType.FOLLOW,
|
|
ActionType.DO_NOTHING,
|
|
ActionType.QUOTE_POST,
|
|
]
|
|
|
|
# Reddit actions available to agents. INTERVIEW is excluded because it can only
|
|
# be triggered manually via ManualAction.
|
|
REDDIT_ACTIONS = [
|
|
ActionType.LIKE_POST,
|
|
ActionType.DISLIKE_POST,
|
|
ActionType.CREATE_POST,
|
|
ActionType.CREATE_COMMENT,
|
|
ActionType.LIKE_COMMENT,
|
|
ActionType.DISLIKE_COMMENT,
|
|
ActionType.SEARCH_POSTS,
|
|
ActionType.SEARCH_USER,
|
|
ActionType.TREND,
|
|
ActionType.REFRESH,
|
|
ActionType.DO_NOTHING,
|
|
ActionType.FOLLOW,
|
|
ActionType.MUTE,
|
|
]
|
|
|
|
|
|
# IPC-related constants.
|
|
IPC_COMMANDS_DIR = "ipc_commands"
|
|
IPC_RESPONSES_DIR = "ipc_responses"
|
|
ENV_STATUS_FILE = "env_status.json"
|
|
|
|
class CommandType:
    """String values of the ``command_type`` field accepted over IPC."""

    # Interview a single agent.
    INTERVIEW = "interview"
    # Interview several agents in one request.
    BATCH_INTERVIEW = "batch_interview"
    # Ask the process to close its environments and exit.
    CLOSE_ENV = "close_env"
|
|
|
|
|
|
class ParallelIPCHandler:
    """Dual-platform IPC command handler.

    Commands are JSON files dropped into ``<simulation_dir>/ipc_commands``;
    each one is answered with a JSON file of the same name in
    ``<simulation_dir>/ipc_responses``. The handler holds the Twitter and
    Reddit environments (either may be ``None``) and runs Interview commands
    against them.
    """

    def __init__(
        self,
        simulation_dir: str,
        twitter_env=None,
        twitter_agent_graph=None,
        reddit_env=None,
        reddit_agent_graph=None
    ):
        self.simulation_dir = simulation_dir
        self.twitter_env = twitter_env
        self.twitter_agent_graph = twitter_agent_graph
        self.reddit_env = reddit_env
        self.reddit_agent_graph = reddit_agent_graph

        self.commands_dir = os.path.join(simulation_dir, IPC_COMMANDS_DIR)
        self.responses_dir = os.path.join(simulation_dir, IPC_RESPONSES_DIR)
        self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)

        os.makedirs(self.commands_dir, exist_ok=True)
        os.makedirs(self.responses_dir, exist_ok=True)

    def update_status(self, status: str):
        """Write the environment status and per-platform availability to the status file."""
        with open(self.status_file, 'w', encoding='utf-8') as f:
            json.dump({
                "status": status,
                "twitter_available": self.twitter_env is not None,
                "reddit_available": self.reddit_env is not None,
                "timestamp": datetime.now().isoformat()
            }, f, ensure_ascii=False, indent=2)

    def poll_command(self) -> Optional[Dict[str, Any]]:
        """Return the oldest readable pending command, or ``None`` if there is none.

        Files that disappear while polling, cannot be read, are not valid
        UTF-8 JSON, or do not contain a JSON object are skipped.
        """
        if not os.path.exists(self.commands_dir):
            return None

        try:
            filenames = os.listdir(self.commands_dir)
        except OSError:
            return None

        # Collect command files sorted by mtime so older commands run first.
        command_files = []
        for filename in filenames:
            if not filename.endswith('.json'):
                continue
            filepath = os.path.join(self.commands_dir, filename)
            try:
                command_files.append((filepath, os.path.getmtime(filepath)))
            except OSError:
                # The file vanished between listdir() and getmtime().
                continue

        command_files.sort(key=lambda item: item[1])

        for filepath, _ in command_files:
            try:
                with open(filepath, 'r', encoding='utf-8') as f:
                    command = json.load(f)
            except (ValueError, OSError):
                # ValueError covers both JSONDecodeError and UnicodeDecodeError.
                continue
            # Callers use dict methods on the command, so reject other JSON values.
            if isinstance(command, dict):
                return command

        return None

    def send_response(self, command_id: str, status: str, result: Optional[Dict] = None, error: Optional[str] = None):
        """Write the response for ``command_id`` and remove the matching command file.

        Args:
            command_id: command identifier; also the base name of both files.
            status: response status string (e.g. "completed" or "failed").
            result: optional result payload.
            error: optional error message.
        """
        response = {
            "command_id": command_id,
            "status": status,
            "result": result,
            "error": error,
            "timestamp": datetime.now().isoformat()
        }

        response_file = os.path.join(self.responses_dir, f"{command_id}.json")
        with open(response_file, 'w', encoding='utf-8') as f:
            json.dump(response, f, ensure_ascii=False, indent=2)

        # Remove the original command file once a response is recorded.
        command_file = os.path.join(self.commands_dir, f"{command_id}.json")
        try:
            os.remove(command_file)
        except OSError:
            pass

    def _get_env_and_graph(self, platform: str):
        """Return the environment and agent graph for the given platform.

        Args:
            platform: platform name ("twitter" or "reddit").

        Returns:
            Tuple ``(env, agent_graph, platform_name)`` or ``(None, None, None)``
            when the platform is unknown or its environment is not set.
        """
        if platform == "twitter" and self.twitter_env:
            return self.twitter_env, self.twitter_agent_graph, "twitter"
        if platform == "reddit" and self.reddit_env:
            return self.reddit_env, self.reddit_agent_graph, "reddit"
        return None, None, None

    async def _interview_single_platform(self, agent_id: int, prompt: str, platform: str) -> Dict[str, Any]:
        """Run an Interview on a single platform.

        Returns:
            A dict with the interview result, or a dict containing an ``error`` key.
        """
        env, agent_graph, actual_platform = self._get_env_and_graph(platform)

        if not env or not agent_graph:
            return {"platform": platform, "error": f"{platform}平台不可用"}

        try:
            agent = agent_graph.get_agent(agent_id)
            interview_action = ManualAction(
                action_type=ActionType.INTERVIEW,
                action_args={"prompt": prompt}
            )
            await env.step({agent: interview_action})

            result = self._get_interview_result(agent_id, actual_platform)
            result["platform"] = actual_platform
            return result

        except Exception as e:
            return {"platform": platform, "error": str(e)}

    async def handle_interview(self, command_id: str, agent_id: int, prompt: str, platform: str = None) -> bool:
        """Handle a single-agent interview command.

        Args:
            command_id: command identifier.
            agent_id: agent identifier.
            prompt: interview prompt.
            platform: optional platform selector.
                - "twitter": interview on Twitter only.
                - "reddit": interview on Reddit only.
                - anything else (including ``None``): interview on every
                  available platform and return a merged result.

        Returns:
            ``True`` on success, ``False`` on failure.
        """
        # If a specific platform was requested, only interview on that platform.
        if platform in ("twitter", "reddit"):
            result = await self._interview_single_platform(agent_id, prompt, platform)

            if "error" in result:
                self.send_response(command_id, "failed", error=result["error"])
                print(f" Interview失败: agent_id={agent_id}, platform={platform}, error={result['error']}")
                return False

            self.send_response(command_id, "completed", result=result)
            print(f" Interview完成: agent_id={agent_id}, platform={platform}")
            return True

        # No platform specified: interview on both platforms simultaneously.
        if not self.twitter_env and not self.reddit_env:
            self.send_response(command_id, "failed", error="没有可用的模拟环境")
            return False

        platforms_to_interview = []
        if self.twitter_env:
            platforms_to_interview.append("twitter")
        if self.reddit_env:
            platforms_to_interview.append("reddit")

        # Run the platform interviews in parallel.
        platform_results = await asyncio.gather(*[
            self._interview_single_platform(agent_id, prompt, name)
            for name in platforms_to_interview
        ])

        results = {
            "agent_id": agent_id,
            "prompt": prompt,
            "platforms": dict(zip(platforms_to_interview, platform_results))
        }
        success_count = sum(1 for r in platform_results if "error" not in r)

        if success_count > 0:
            self.send_response(command_id, "completed", result=results)
            print(f" Interview完成: agent_id={agent_id}, 成功平台数={success_count}/{len(platforms_to_interview)}")
            return True

        errors = [f"{p}: {r.get('error', '未知错误')}" for p, r in results["platforms"].items()]
        self.send_response(command_id, "failed", error="; ".join(errors))
        print(f" Interview失败: agent_id={agent_id}, 所有平台都失败")
        return False

    async def _batch_interview_on_platform(
        self,
        platform: str,
        env,
        agent_graph,
        interviews: List[Dict]
    ) -> Dict[str, Dict[str, Any]]:
        """Interview every entry of ``interviews`` on one platform in a single env step.

        Returns:
            Mapping ``"<platform>_<agent_id>" -> result``. Entries whose agent
            could not be resolved carry an ``error`` key instead of whatever
            interview row the database happens to hold for that id. The
            mapping is empty when the step itself fails.
        """
        label = platform.capitalize()  # "Twitter" / "Reddit" in log messages
        collected: Dict[str, Dict[str, Any]] = {}

        try:
            actions = {}
            lookup_errors: Dict[int, str] = {}  # interview index -> error text
            for index, interview in enumerate(interviews):
                agent_id = interview.get("agent_id")
                prompt = interview.get("prompt", "")
                try:
                    agent = agent_graph.get_agent(agent_id)
                    actions[agent] = ManualAction(
                        action_type=ActionType.INTERVIEW,
                        action_args={"prompt": prompt}
                    )
                except Exception as e:
                    print(f" 警告: 无法获取{label} Agent {agent_id}: {e}")
                    lookup_errors[index] = str(e)

            if actions:
                await env.step(actions)

            for index, interview in enumerate(interviews):
                agent_id = interview.get("agent_id")
                if index in lookup_errors:
                    # This agent was never interviewed, so do not report a DB row for it.
                    result = {
                        "agent_id": agent_id,
                        "response": None,
                        "timestamp": None,
                        "error": lookup_errors[index],
                    }
                else:
                    result = self._get_interview_result(agent_id, platform)
                result["platform"] = platform
                collected[f"{platform}_{agent_id}"] = result
        except Exception as e:
            print(f" {label}批量Interview失败: {e}")

        return collected

    async def handle_batch_interview(self, command_id: str, interviews: List[Dict], platform: str = None) -> bool:
        """Handle a batch-interview command.

        Args:
            command_id: command identifier.
            interviews: ``[{"agent_id": int, "prompt": str, "platform": str(optional)}, ...]``.
            platform: default platform (can be overridden per interview entry).
                - "twitter": interview on Twitter only.
                - "reddit": interview on Reddit only.
                - ``None``: interview every agent on both platforms.

        Returns:
            ``True`` when at least one result entry was produced, else ``False``.
        """
        # Bucket interviews by target platform.
        twitter_interviews = []
        reddit_interviews = []
        both_platforms_interviews = []  # entries that need both platforms

        for interview in interviews:
            item_platform = interview.get("platform", platform)
            if item_platform == "twitter":
                twitter_interviews.append(interview)
            elif item_platform == "reddit":
                reddit_interviews.append(interview)
            else:
                both_platforms_interviews.append(interview)

        # Fan the both-platform entries out into the per-platform buckets.
        if both_platforms_interviews:
            if self.twitter_env:
                twitter_interviews.extend(both_platforms_interviews)
            if self.reddit_env:
                reddit_interviews.extend(both_platforms_interviews)

        results = {}

        if twitter_interviews and self.twitter_env:
            results.update(await self._batch_interview_on_platform(
                "twitter", self.twitter_env, self.twitter_agent_graph, twitter_interviews
            ))

        if reddit_interviews and self.reddit_env:
            results.update(await self._batch_interview_on_platform(
                "reddit", self.reddit_env, self.reddit_agent_graph, reddit_interviews
            ))

        if results:
            self.send_response(command_id, "completed", result={
                "interviews_count": len(results),
                "results": results
            })
            print(f" 批量Interview完成: {len(results)} 个Agent")
            return True

        self.send_response(command_id, "failed", error="没有成功的采访")
        return False

    def _get_interview_result(self, agent_id: int, platform: str) -> Dict[str, Any]:
        """Read the latest Interview result for an agent from the platform database.

        Returns:
            Dict with ``agent_id``, ``response`` and ``timestamp``. The last
            two stay ``None`` when the database or a matching row is missing,
            or when reading fails.
        """
        db_path = os.path.join(self.simulation_dir, f"{platform}_simulation.db")

        result = {
            "agent_id": agent_id,
            "response": None,
            "timestamp": None
        }

        if not os.path.exists(db_path):
            return result

        conn = None
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()

            # Newest row by insertion order. ``created_at`` can tie, so order
            # by ``rowid`` instead.
            cursor.execute("""
                SELECT user_id, info, created_at
                FROM trace
                WHERE action = ? AND user_id = ?
                ORDER BY rowid DESC
                LIMIT 1
            """, (ActionType.INTERVIEW.value, agent_id))

            row = cursor.fetchone()
            if row:
                _, info_json, created_at = row
                try:
                    info = json.loads(info_json) if info_json else {}
                except json.JSONDecodeError:
                    # Not JSON: hand back the raw text.
                    result["response"] = info_json
                else:
                    if isinstance(info, dict):
                        result["response"] = info.get("response", info)
                    else:
                        result["response"] = info
                    result["timestamp"] = created_at

        except Exception as e:
            print(f" 读取Interview结果失败: {e}")
        finally:
            # Close the connection on the error path too.
            if conn is not None:
                conn.close()

        return result

    async def process_commands(self) -> bool:
        """Process the next pending command, if any.

        Returns:
            ``True`` to keep running, ``False`` if the process should exit.
        """
        command = self.poll_command()
        if not command:
            return True

        command_id = command.get("command_id")
        command_type = command.get("command_type")
        # ``or {}`` also covers an explicit ``"args": null``.
        args = command.get("args") or {}

        print(f"\n收到IPC命令: {command_type}, id={command_id}")

        if command_type == CommandType.INTERVIEW:
            await self.handle_interview(
                command_id,
                args.get("agent_id", 0),
                args.get("prompt", ""),
                args.get("platform")
            )
            return True

        if command_type == CommandType.BATCH_INTERVIEW:
            await self.handle_batch_interview(
                command_id,
                args.get("interviews", []),
                args.get("platform")
            )
            return True

        if command_type == CommandType.CLOSE_ENV:
            print("收到关闭环境命令")
            self.send_response(command_id, "completed", result={"message": "环境即将关闭"})
            return False

        self.send_response(command_id, "failed", error=f"未知命令类型: {command_type}")
        return True
|
|
|
|
|
|
def load_config(config_path: str) -> Dict[str, Any]:
    """Read the UTF-8 JSON file at ``config_path`` and return its parsed content."""
    with open(config_path, 'r', encoding='utf-8') as config_file:
        raw_text = config_file.read()
    return json.loads(raw_text)
|
|
|
|
|
|
# Non-core action types to filter out: they provide little analytical value.
|
|
FILTERED_ACTIONS = {'refresh', 'sign_up'}
|
|
|
|
# Action-type mapping (database name -> canonical name).
|
|
ACTION_TYPE_MAP = {
|
|
'create_post': 'CREATE_POST',
|
|
'like_post': 'LIKE_POST',
|
|
'dislike_post': 'DISLIKE_POST',
|
|
'repost': 'REPOST',
|
|
'quote_post': 'QUOTE_POST',
|
|
'follow': 'FOLLOW',
|
|
'mute': 'MUTE',
|
|
'create_comment': 'CREATE_COMMENT',
|
|
'like_comment': 'LIKE_COMMENT',
|
|
'dislike_comment': 'DISLIKE_COMMENT',
|
|
'search_posts': 'SEARCH_POSTS',
|
|
'search_user': 'SEARCH_USER',
|
|
'trend': 'TREND',
|
|
'do_nothing': 'DO_NOTHING',
|
|
'interview': 'INTERVIEW',
|
|
}
|
|
|
|
|
|
def get_agent_names_from_config(config: Dict[str, Any]) -> Dict[int, str]:
    """Map each configured ``agent_id`` to its ``entity_name``.

    The names let actions.jsonl show the real entity instead of a
    placeholder such as ``Agent_0``.

    Args:
        config: contents of ``simulation_config.json``.

    Returns:
        Mapping from agent id to entity name. Entries without an
        ``agent_id`` are skipped; entries without an ``entity_name`` key
        fall back to ``Agent_<id>``.
    """
    names_by_id = {}
    for entry in config.get("agent_configs", []):
        agent_id = entry.get("agent_id")
        if agent_id is None:
            continue
        names_by_id[agent_id] = entry.get("entity_name", f"Agent_{agent_id}")
    return names_by_id
|
|
|
|
|
|
def fetch_new_actions_from_db(
    db_path: str,
    last_rowid: int,
    agent_names: Dict[int, str]
) -> Tuple[List[Dict[str, Any]], int]:
    """Fetch new action rows from the database and enrich them with context.

    Args:
        db_path: path to the database file.
        last_rowid: highest rowid processed previously. ``rowid`` is tracked
            rather than ``created_at`` because the two platforms use
            different ``created_at`` formats.
        agent_names: ``agent_id -> agent_name`` mapping.

    Returns:
        Tuple ``(actions_list, new_last_rowid)``.
        - ``actions_list``: action records, each containing ``agent_id``,
          ``agent_name``, ``action_type``, and ``action_args`` (with context).
        - ``new_last_rowid``: the new highest rowid seen.

        A missing database yields ``([], last_rowid)``. On a read error the
        error is printed and whatever was collected so far is returned.
    """
    actions = []
    new_last_rowid = last_rowid

    if not os.path.exists(db_path):
        return actions, new_last_rowid

    # Keys copied verbatim from the trace ``info`` payload; content is kept
    # in full (no truncation).
    # NOTE(review): ``user_id`` / ``target_id`` are included because the MUTE
    # branch of ``_enrich_action_context`` reads them. Confirm which key
    # OASIS actually writes for mute actions.
    passthrough_keys = (
        'content',
        'post_id',
        'comment_id',
        'quoted_id',
        'new_post_id',
        'follow_id',
        'query',
        'like_id',
        'dislike_id',
        'user_id',
        'target_id',
    )

    conn = None
    try:
        conn = sqlite3.connect(db_path)
        cursor = conn.cursor()

        # Track processed rows by ``rowid`` to avoid the cross-platform
        # ``created_at`` mismatch (Twitter stores integers, Reddit stores
        # datetime strings).
        cursor.execute("""
            SELECT rowid, user_id, action, info
            FROM trace
            WHERE rowid > ?
            ORDER BY rowid ASC
        """, (last_rowid,))

        for rowid, user_id, action, info_json in cursor.fetchall():
            new_last_rowid = rowid

            if action in FILTERED_ACTIONS:
                continue

            try:
                action_args = json.loads(info_json) if info_json else {}
            except json.JSONDecodeError:
                action_args = {}
            # A non-object payload (list, string, number) carries no usable
            # fields and must not abort the whole fetch.
            if not isinstance(action_args, dict):
                action_args = {}

            simplified_args = {
                key: action_args[key]
                for key in passthrough_keys
                if key in action_args
            }

            action_type = ACTION_TYPE_MAP.get(action, action.upper())

            # Enrich with context such as post content and author name.
            _enrich_action_context(cursor, action_type, simplified_args, agent_names)

            actions.append({
                'agent_id': user_id,
                'agent_name': agent_names.get(user_id, f'Agent_{user_id}'),
                'action_type': action_type,
                'action_args': simplified_args,
            })

    except Exception as e:
        print(f"读取数据库动作失败: {e}")
    finally:
        # Close the connection on the error path too.
        if conn is not None:
            conn.close()

    return actions, new_last_rowid
|
|
|
|
|
|
def _enrich_action_context(
|
|
cursor,
|
|
action_type: str,
|
|
action_args: Dict[str, Any],
|
|
agent_names: Dict[int, str]
|
|
) -> None:
|
|
"""Enrich an action's args with context such as post content and author name.
|
|
|
|
Args:
|
|
cursor: database cursor.
|
|
action_type: action type.
|
|
action_args: action args (mutated in place).
|
|
agent_names: ``agent_id -> agent_name`` mapping.
|
|
"""
|
|
try:
|
|
# Like/dislike post: include the post content and author name.
|
|
if action_type in ('LIKE_POST', 'DISLIKE_POST'):
|
|
post_id = action_args.get('post_id')
|
|
if post_id:
|
|
post_info = _get_post_info(cursor, post_id, agent_names)
|
|
if post_info:
|
|
action_args['post_content'] = post_info.get('content', '')
|
|
action_args['post_author_name'] = post_info.get('author_name', '')
|
|
|
|
# Repost: include the original post content and author name.
|
|
elif action_type == 'REPOST':
|
|
new_post_id = action_args.get('new_post_id')
|
|
if new_post_id:
|
|
# On a repost row, ``original_post_id`` points at the original post.
|
|
cursor.execute("""
|
|
SELECT original_post_id FROM post WHERE post_id = ?
|
|
""", (new_post_id,))
|
|
row = cursor.fetchone()
|
|
if row and row[0]:
|
|
original_post_id = row[0]
|
|
original_info = _get_post_info(cursor, original_post_id, agent_names)
|
|
if original_info:
|
|
action_args['original_content'] = original_info.get('content', '')
|
|
action_args['original_author_name'] = original_info.get('author_name', '')
|
|
|
|
# Quote post: include the original post content, author name, and quote comment.
|
|
elif action_type == 'QUOTE_POST':
|
|
quoted_id = action_args.get('quoted_id')
|
|
new_post_id = action_args.get('new_post_id')
|
|
|
|
if quoted_id:
|
|
original_info = _get_post_info(cursor, quoted_id, agent_names)
|
|
if original_info:
|
|
action_args['original_content'] = original_info.get('content', '')
|
|
action_args['original_author_name'] = original_info.get('author_name', '')
|
|
|
|
# Read the quote comment (``quote_content``).
|
|
if new_post_id:
|
|
cursor.execute("""
|
|
SELECT quote_content FROM post WHERE post_id = ?
|
|
""", (new_post_id,))
|
|
row = cursor.fetchone()
|
|
if row and row[0]:
|
|
action_args['quote_content'] = row[0]
|
|
|
|
# Follow: include the followee's display name.
|
|
elif action_type == 'FOLLOW':
|
|
follow_id = action_args.get('follow_id')
|
|
if follow_id:
|
|
# Look up ``followee_id`` from the ``follow`` table.
|
|
cursor.execute("""
|
|
SELECT followee_id FROM follow WHERE follow_id = ?
|
|
""", (follow_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
followee_id = row[0]
|
|
target_name = _get_user_name(cursor, followee_id, agent_names)
|
|
if target_name:
|
|
action_args['target_user_name'] = target_name
|
|
|
|
# Mute: include the muted user's display name.
|
|
elif action_type == 'MUTE':
|
|
# Read ``user_id`` or ``target_id`` from action_args.
|
|
target_id = action_args.get('user_id') or action_args.get('target_id')
|
|
if target_id:
|
|
target_name = _get_user_name(cursor, target_id, agent_names)
|
|
if target_name:
|
|
action_args['target_user_name'] = target_name
|
|
|
|
# Like/dislike comment: include the comment content and author name.
|
|
elif action_type in ('LIKE_COMMENT', 'DISLIKE_COMMENT'):
|
|
comment_id = action_args.get('comment_id')
|
|
if comment_id:
|
|
comment_info = _get_comment_info(cursor, comment_id, agent_names)
|
|
if comment_info:
|
|
action_args['comment_content'] = comment_info.get('content', '')
|
|
action_args['comment_author_name'] = comment_info.get('author_name', '')
|
|
|
|
# Create comment: include the parent post's content and author name.
|
|
elif action_type == 'CREATE_COMMENT':
|
|
post_id = action_args.get('post_id')
|
|
if post_id:
|
|
post_info = _get_post_info(cursor, post_id, agent_names)
|
|
if post_info:
|
|
action_args['post_content'] = post_info.get('content', '')
|
|
action_args['post_author_name'] = post_info.get('author_name', '')
|
|
|
|
except Exception as e:
|
|
# Failing to enrich context must not break the main flow.
|
|
print(f"补充动作上下文失败: {e}")
|
|
|
|
|
|
def _get_post_info(
|
|
cursor,
|
|
post_id: int,
|
|
agent_names: Dict[int, str]
|
|
) -> Optional[Dict[str, str]]:
|
|
"""Look up post info.
|
|
|
|
Args:
|
|
cursor: database cursor.
|
|
post_id: post identifier.
|
|
agent_names: ``agent_id -> agent_name`` mapping.
|
|
|
|
Returns:
|
|
Dict with ``content`` and ``author_name``, or ``None`` when not found.
|
|
"""
|
|
try:
|
|
cursor.execute("""
|
|
SELECT p.content, p.user_id, u.agent_id
|
|
FROM post p
|
|
LEFT JOIN user u ON p.user_id = u.user_id
|
|
WHERE p.post_id = ?
|
|
""", (post_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
content = row[0] or ''
|
|
user_id = row[1]
|
|
agent_id = row[2]
|
|
|
|
# Prefer the entity_name supplied via agent_names.
|
|
author_name = ''
|
|
if agent_id is not None and agent_id in agent_names:
|
|
author_name = agent_names[agent_id]
|
|
elif user_id:
|
|
# Fall back to the user table.
|
|
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
|
|
user_row = cursor.fetchone()
|
|
if user_row:
|
|
author_name = user_row[0] or user_row[1] or ''
|
|
|
|
return {'content': content, 'author_name': author_name}
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def _get_user_name(
|
|
cursor,
|
|
user_id: int,
|
|
agent_names: Dict[int, str]
|
|
) -> Optional[str]:
|
|
"""Look up a user's display name.
|
|
|
|
Args:
|
|
cursor: database cursor.
|
|
user_id: user identifier.
|
|
agent_names: ``agent_id -> agent_name`` mapping.
|
|
|
|
Returns:
|
|
Display name, or ``None`` when the user cannot be found.
|
|
"""
|
|
try:
|
|
cursor.execute("""
|
|
SELECT agent_id, name, user_name FROM user WHERE user_id = ?
|
|
""", (user_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
agent_id = row[0]
|
|
name = row[1]
|
|
user_name = row[2]
|
|
|
|
# Prefer the entity_name supplied via agent_names.
|
|
if agent_id is not None and agent_id in agent_names:
|
|
return agent_names[agent_id]
|
|
return name or user_name or ''
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def _get_comment_info(
|
|
cursor,
|
|
comment_id: int,
|
|
agent_names: Dict[int, str]
|
|
) -> Optional[Dict[str, str]]:
|
|
"""Look up comment info.
|
|
|
|
Args:
|
|
cursor: database cursor.
|
|
comment_id: comment identifier.
|
|
agent_names: ``agent_id -> agent_name`` mapping.
|
|
|
|
Returns:
|
|
Dict with ``content`` and ``author_name``, or ``None`` when not found.
|
|
"""
|
|
try:
|
|
cursor.execute("""
|
|
SELECT c.content, c.user_id, u.agent_id
|
|
FROM comment c
|
|
LEFT JOIN user u ON c.user_id = u.user_id
|
|
WHERE c.comment_id = ?
|
|
""", (comment_id,))
|
|
row = cursor.fetchone()
|
|
if row:
|
|
content = row[0] or ''
|
|
user_id = row[1]
|
|
agent_id = row[2]
|
|
|
|
# Prefer the entity_name supplied via agent_names.
|
|
author_name = ''
|
|
if agent_id is not None and agent_id in agent_names:
|
|
author_name = agent_names[agent_id]
|
|
elif user_id:
|
|
# Fall back to the user table.
|
|
cursor.execute("SELECT name, user_name FROM user WHERE user_id = ?", (user_id,))
|
|
user_row = cursor.fetchone()
|
|
if user_row:
|
|
author_name = user_row[0] or user_row[1] or ''
|
|
|
|
return {'content': content, 'author_name': author_name}
|
|
except Exception:
|
|
pass
|
|
return None
|
|
|
|
|
|
def create_model(config: Dict[str, Any], use_boost: bool = False):
    """Build the LLM model object that drives the simulated agents.

    Settings are read from environment variables, in one of two flavours:

    - default: ``LLM_API_KEY``, ``LLM_BASE_URL``, ``LLM_MODEL_NAME``.
    - boost (optional): ``LLM_BOOST_API_KEY``, ``LLM_BOOST_BASE_URL``,
      ``LLM_BOOST_MODEL_NAME``.

    The boost flavour is chosen only when ``use_boost`` is true and
    ``LLM_BOOST_API_KEY`` is non-empty; a missing boost model name falls back
    to ``LLM_MODEL_NAME``. If no model name is found in the environment at
    all, ``config["llm_model"]`` is used, defaulting to ``"gpt-4o-mini"``.

    Side effects: writes ``OPENAI_API_KEY`` (and ``OPENAI_API_BASE_URL`` when
    a base URL is configured) into ``os.environ`` and prints one summary line.

    Args:
        config: simulation config dict; only ``llm_model`` is read here.
        use_boost: whether to prefer the boost settings when they exist.

    Returns:
        Whatever ``ModelFactory.create`` returns for the OPENAI platform type.

    Raises:
        ValueError: when no API key ends up in ``OPENAI_API_KEY``.
    """
    read_env = os.environ.get

    # The boost key alone decides whether a boost setup exists; URL and model
    # name are optional extras.
    boost_key = read_env("LLM_BOOST_API_KEY", "")
    boost_url = read_env("LLM_BOOST_BASE_URL", "")
    boost_model_name = read_env("LLM_BOOST_MODEL_NAME", "")

    if use_boost and boost_key:
        api_key, base_url = boost_key, boost_url
        model_name = boost_model_name or read_env("LLM_MODEL_NAME", "")
        config_label = "[加速LLM]"
    else:
        api_key = read_env("LLM_API_KEY", "")
        base_url = read_env("LLM_BASE_URL", "")
        model_name = read_env("LLM_MODEL_NAME", "")
        config_label = "[通用LLM]"

    # The environment wins; the config file is only the last resort.
    model_name = model_name or config.get("llm_model", "gpt-4o-mini")

    # Publish the credentials under the OPENAI_* variable names.
    if api_key:
        os.environ["OPENAI_API_KEY"] = api_key

    if not os.environ.get("OPENAI_API_KEY"):
        raise ValueError("缺少 API Key 配置,请在项目根目录 .env 文件中设置 LLM_API_KEY")

    # NOTE(review): an empty base URL leaves any previously exported
    # OPENAI_API_BASE_URL untouched, so a later call can inherit the URL set
    # by an earlier one - confirm this is intended.
    if base_url:
        os.environ["OPENAI_API_BASE_URL"] = base_url

    shown_url = base_url[:40] if base_url else '默认'
    print(f"{config_label} model={model_name}, base_url={shown_url}...")

    return ModelFactory.create(
        model_platform=ModelPlatformType.OPENAI,
        model_type=model_name,
    )
|
|
|
|
|
|
def get_active_agents_for_round(
    env,
    config: Dict[str, Any],
    current_hour: int,
    round_num: int
) -> List:
    """Pick the agents that act in this round.

    Selection happens in three steps:

    1. A target head-count is drawn uniformly between
       ``agents_per_hour_min`` and ``agents_per_hour_max`` and scaled by the
       peak / off-peak multiplier that applies to ``current_hour``.
    2. Every agent config whose ``active_hours`` contain ``current_hour``
       becomes a candidate with probability ``activity_level``.
    3. Up to the target number of candidates is sampled without replacement
       and resolved to agent objects through ``env.agent_graph.get_agent``.

    Args:
        env: simulation environment exposing ``agent_graph.get_agent(id)``.
        config: simulation config; ``time_config`` and ``agent_configs`` are read.
        current_hour: simulated hour of day, compared against the hour lists.
        round_num: round index. Currently unused; kept for the existing
            call signature.

    Returns:
        List of ``(agent_id, agent)`` tuples. Ids that ``get_agent`` cannot
        resolve are skipped.
    """
    time_config = config.get("time_config", {})
    agent_configs = config.get("agent_configs", [])

    base_min = time_config.get("agents_per_hour_min", 5)
    base_max = time_config.get("agents_per_hour_max", 20)

    peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
    off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])

    if current_hour in peak_hours:
        multiplier = time_config.get("peak_activity_multiplier", 1.5)
    elif current_hour in off_peak_hours:
        multiplier = time_config.get("off_peak_activity_multiplier", 0.3)
    else:
        multiplier = 1.0

    target_count = int(random.uniform(base_min, base_max) * multiplier)

    candidates = []
    for cfg in agent_configs:
        if current_hour not in cfg.get("active_hours", list(range(8, 23))):
            continue
        # One random draw per agent that is awake at this hour.
        if random.random() < cfg.get("activity_level", 0.5):
            candidates.append(cfg.get("agent_id", 0))

    # Clamp into [0, len(candidates)]: random.sample raises ValueError for a
    # negative size, which a negative bound or multiplier in the config would
    # otherwise trigger.
    sample_size = max(0, min(target_count, len(candidates)))
    selected_ids = random.sample(candidates, sample_size) if candidates else []

    active_agents = []
    for agent_id in selected_ids:
        try:
            agent = env.agent_graph.get_agent(agent_id)
        except Exception:
            # An id that cannot be resolved is skipped rather than aborting
            # the round.
            continue
        active_agents.append((agent_id, agent))

    return active_agents
|
|
|
|
|
|
class PlatformSimulation:
    """Mutable holder for what a single platform run produced.

    A fresh instance has nothing attached; the run functions fill in the
    fields as they progress.
    """

    def __init__(self):
        # Environment and agent graph are attached later by the run function.
        self.env = self.agent_graph = None
        # Running count of actions recorded for this platform.
        self.total_actions = 0
|
|
|
|
|
|
async def run_twitter_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[PlatformActionLogger] = None,
    main_logger: Optional[SimulationLogManager] = None,
    max_rounds: Optional[int] = None
) -> PlatformSimulation:
    """Run the Twitter simulation.

    Builds the agent graph from ``twitter_profiles.csv``, recreates
    ``twitter_simulation.db``, publishes the configured initial posts as
    round 0 and then steps the environment once per simulated round. The
    environment is left open on return so it can serve later commands.

    Args:
        config: simulation config.
        simulation_dir: directory holding the profile file and the database.
        action_logger: optional per-platform action logger.
        main_logger: optional main log manager.
        max_rounds: optional cap on the number of rounds, used to truncate long runs.

    Returns:
        PlatformSimulation containing the env and agent_graph. When the
        profile file is missing, both are left as ``None``.
    """
    result = PlatformSimulation()

    def log_info(msg):
        if main_logger:
            main_logger.info(f"[Twitter] {msg}")
        print(f"[Twitter] {msg}")

    log_info("初始化...")

    # Twitter uses the default LLM config.
    model = create_model(config, use_boost=False)

    # OASIS Twitter expects a CSV profile file.
    profile_path = os.path.join(simulation_dir, "twitter_profiles.csv")
    if not os.path.exists(profile_path):
        log_info(f"错误: Profile文件不存在: {profile_path}")
        return result

    result.agent_graph = await generate_twitter_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=TWITTER_ACTIONS,
    )

    # Pull real agent names from the config (entity_name rather than the
    # default Agent_X), then fill the gaps with the name the agent carries.
    agent_names = get_agent_names_from_config(config)
    for agent_id, agent in result.agent_graph.get_agents():
        if agent_id not in agent_names:
            agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Always start from a fresh database file.
    db_path = os.path.join(simulation_dir, "twitter_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    result.env = oasis.make(
        agent_graph=result.agent_graph,
        platform=oasis.DefaultPlatformType.TWITTER,
        database_path=db_path,
        semaphore=30,  # cap concurrent LLM requests to avoid overloading the API
    )

    await result.env.reset()
    log_info("环境已启动")

    if action_logger:
        action_logger.log_simulation_start(config)

    total_actions = 0
    last_rowid = 0  # last processed db row; using rowid avoids created_at format differences

    # ---- Round 0: initial events -------------------------------------
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if action_logger:
        action_logger.log_round_start(0, 0)  # round 0, simulated_hour 0

    initial_action_count = 0
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = result.env.agent_graph.get_agent(agent_id)
                action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # An agent may own several initial posts. Collect them in a
                # list instead of letting a later post overwrite an earlier
                # one (same handling as the Reddit simulation).
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(action)
                else:
                    initial_actions[agent] = action

                # Count every queued post, whether or not a logger is attached.
                total_actions += 1
                initial_action_count += 1

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content}
                    )
            except Exception:
                # A post that cannot be queued (e.g. unknown agent id) is
                # skipped so the remaining posts still go out.
                pass

        if initial_actions:
            await result.env.step(initial_actions)
            # Report posts, not dict entries: one agent can hold several.
            log_info(f"已发布 {initial_action_count} 条初始帖子")

    if action_logger:
        action_logger.log_round_end(0, initial_action_count)

    # ---- Main simulation loop ----------------------------------------
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    # Truncate when a max round count was supplied.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    start_time = datetime.now()

    for round_num in range(total_rounds):
        # Bail out if a shutdown signal was received.
        if _shutdown_event and _shutdown_event.is_set():
            if main_logger:
                main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟")
            break

        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            result.env, config, simulated_hour, round_num
        )

        # Always log round-start, even when no agents are active.
        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour)

        if not active_agents:
            # Still emit round-end (with actions_count=0) so the log stays consistent.
            if action_logger:
                action_logger.log_round_end(round_num + 1, 0)
            continue

        actions = {agent: LLMAction() for _, agent in active_agents}
        await result.env.step(actions)

        # Pull the actually-executed actions from the database and log them.
        actual_actions, last_rowid = fetch_new_actions_from_db(
            db_path, last_rowid, agent_names
        )

        round_action_count = 0
        for action_data in actual_actions:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    agent_id=action_data['agent_id'],
                    agent_name=action_data['agent_name'],
                    action_type=action_data['action_type'],
                    action_args=action_data['action_args']
                )
            total_actions += 1
            round_action_count += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, round_action_count)

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    # Note: do NOT close the env here; we keep it alive for Interview commands.

    if action_logger:
        action_logger.log_simulation_end(total_rounds, total_actions)

    result.total_actions = total_actions
    elapsed = (datetime.now() - start_time).total_seconds()
    log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")

    return result
|
|
|
|
|
|
async def run_reddit_simulation(
    config: Dict[str, Any],
    simulation_dir: str,
    action_logger: Optional[PlatformActionLogger] = None,
    main_logger: Optional[SimulationLogManager] = None,
    max_rounds: Optional[int] = None
) -> PlatformSimulation:
    """Run the Reddit simulation.

    Builds the agent graph from ``reddit_profiles.json``, recreates
    ``reddit_simulation.db``, publishes the configured initial posts as
    round 0 and then steps the environment once per simulated round. The
    environment is left open on return so it can serve later commands.

    Args:
        config: simulation config.
        simulation_dir: directory holding the profile file and the database.
        action_logger: optional per-platform action logger.
        main_logger: optional main log manager.
        max_rounds: optional cap on the number of rounds, used to truncate long runs.

    Returns:
        PlatformSimulation containing the env and agent_graph. When the
        profile file is missing, both are left as ``None``.
    """
    result = PlatformSimulation()

    def log_info(msg):
        if main_logger:
            main_logger.info(f"[Reddit] {msg}")
        print(f"[Reddit] {msg}")

    log_info("初始化...")

    # Reddit uses the boost LLM config when available, falling back to the default.
    model = create_model(config, use_boost=True)

    profile_path = os.path.join(simulation_dir, "reddit_profiles.json")
    if not os.path.exists(profile_path):
        log_info(f"错误: Profile文件不存在: {profile_path}")
        return result

    result.agent_graph = await generate_reddit_agent_graph(
        profile_path=profile_path,
        model=model,
        available_actions=REDDIT_ACTIONS,
    )

    # Pull real agent names from the config (entity_name rather than the
    # default Agent_X), then fill the gaps with the name the agent carries.
    agent_names = get_agent_names_from_config(config)
    for agent_id, agent in result.agent_graph.get_agents():
        if agent_id not in agent_names:
            agent_names[agent_id] = getattr(agent, 'name', f'Agent_{agent_id}')

    # Always start from a fresh database file.
    db_path = os.path.join(simulation_dir, "reddit_simulation.db")
    if os.path.exists(db_path):
        os.remove(db_path)

    result.env = oasis.make(
        agent_graph=result.agent_graph,
        platform=oasis.DefaultPlatformType.REDDIT,
        database_path=db_path,
        semaphore=30,  # cap concurrent LLM requests to avoid overloading the API
    )

    await result.env.reset()
    log_info("环境已启动")

    if action_logger:
        action_logger.log_simulation_start(config)

    total_actions = 0
    last_rowid = 0  # last processed db row; using rowid avoids created_at format differences

    # ---- Round 0: initial events -------------------------------------
    event_config = config.get("event_config", {})
    initial_posts = event_config.get("initial_posts", [])

    if action_logger:
        action_logger.log_round_start(0, 0)  # round 0, simulated_hour 0

    initial_action_count = 0
    if initial_posts:
        initial_actions = {}
        for post in initial_posts:
            agent_id = post.get("poster_agent_id", 0)
            content = post.get("content", "")
            try:
                agent = result.env.agent_graph.get_agent(agent_id)
                action = ManualAction(
                    action_type=ActionType.CREATE_POST,
                    action_args={"content": content}
                )
                # An agent may own several initial posts: the first is stored
                # bare, later ones turn the entry into a list.
                if agent in initial_actions:
                    if not isinstance(initial_actions[agent], list):
                        initial_actions[agent] = [initial_actions[agent]]
                    initial_actions[agent].append(action)
                else:
                    initial_actions[agent] = action

                # Count every queued post, whether or not a logger is attached.
                total_actions += 1
                initial_action_count += 1

                if action_logger:
                    action_logger.log_action(
                        round_num=0,
                        agent_id=agent_id,
                        agent_name=agent_names.get(agent_id, f"Agent_{agent_id}"),
                        action_type="CREATE_POST",
                        action_args={"content": content}
                    )
            except Exception:
                # A post that cannot be queued (e.g. unknown agent id) is
                # skipped so the remaining posts still go out.
                pass

        if initial_actions:
            await result.env.step(initial_actions)
            # Report posts, not dict entries: len(initial_actions) counts
            # agents and under-reports when one agent holds several posts.
            log_info(f"已发布 {initial_action_count} 条初始帖子")

    if action_logger:
        action_logger.log_round_end(0, initial_action_count)

    # ---- Main simulation loop ----------------------------------------
    time_config = config.get("time_config", {})
    total_hours = time_config.get("total_simulation_hours", 72)
    minutes_per_round = time_config.get("minutes_per_round", 30)
    total_rounds = (total_hours * 60) // minutes_per_round

    # Truncate when a max round count was supplied.
    if max_rounds is not None and max_rounds > 0:
        original_rounds = total_rounds
        total_rounds = min(total_rounds, max_rounds)
        if total_rounds < original_rounds:
            log_info(f"轮数已截断: {original_rounds} -> {total_rounds} (max_rounds={max_rounds})")

    start_time = datetime.now()

    for round_num in range(total_rounds):
        # Bail out if a shutdown signal was received.
        if _shutdown_event and _shutdown_event.is_set():
            if main_logger:
                main_logger.info(f"收到退出信号,在第 {round_num + 1} 轮停止模拟")
            break

        simulated_minutes = round_num * minutes_per_round
        simulated_hour = (simulated_minutes // 60) % 24
        simulated_day = simulated_minutes // (60 * 24) + 1

        active_agents = get_active_agents_for_round(
            result.env, config, simulated_hour, round_num
        )

        # Always log round-start, even when no agents are active.
        if action_logger:
            action_logger.log_round_start(round_num + 1, simulated_hour)

        if not active_agents:
            # Still emit round-end (with actions_count=0) so the log stays consistent.
            if action_logger:
                action_logger.log_round_end(round_num + 1, 0)
            continue

        actions = {agent: LLMAction() for _, agent in active_agents}
        await result.env.step(actions)

        # Pull the actually-executed actions from the database and log them.
        actual_actions, last_rowid = fetch_new_actions_from_db(
            db_path, last_rowid, agent_names
        )

        round_action_count = 0
        for action_data in actual_actions:
            if action_logger:
                action_logger.log_action(
                    round_num=round_num + 1,
                    agent_id=action_data['agent_id'],
                    agent_name=action_data['agent_name'],
                    action_type=action_data['action_type'],
                    action_args=action_data['action_args']
                )
            total_actions += 1
            round_action_count += 1

        if action_logger:
            action_logger.log_round_end(round_num + 1, round_action_count)

        if (round_num + 1) % 20 == 0:
            progress = (round_num + 1) / total_rounds * 100
            log_info(f"Day {simulated_day}, {simulated_hour:02d}:00 - Round {round_num + 1}/{total_rounds} ({progress:.1f}%)")

    # Note: do NOT close the env here; we keep it alive for Interview commands.

    if action_logger:
        action_logger.log_simulation_end(total_rounds, total_actions)

    result.total_actions = total_actions
    elapsed = (datetime.now() - start_time).total_seconds()
    log_info(f"模拟循环完成! 耗时: {elapsed:.1f}秒, 总动作: {total_actions}")

    return result
|
|
|
|
|
|
async def main():
    """Command-line entry point for the dual-platform simulation.

    Parses the CLI flags, loads the config, runs the Twitter and/or Reddit
    simulation, optionally serves IPC commands until a shutdown is requested,
    and finally closes whichever environments were created.
    """
    global _shutdown_event

    parser = argparse.ArgumentParser(description='OASIS双平台并行模拟')
    parser.add_argument('--config', type=str, required=True,
                        help='配置文件路径 (simulation_config.json)')
    parser.add_argument('--twitter-only', action='store_true',
                        help='只运行Twitter模拟')
    parser.add_argument('--reddit-only', action='store_true',
                        help='只运行Reddit模拟')
    parser.add_argument('--max-rounds', type=int, default=None,
                        help='最大模拟轮数(可选,用于截断过长的模拟)')
    parser.add_argument('--no-wait', action='store_true', default=False,
                        help='模拟完成后立即关闭环境,不进入等待命令模式')
    args = parser.parse_args()

    # Created up front so every later stage can react to an exit signal.
    _shutdown_event = asyncio.Event()

    if not os.path.exists(args.config):
        print(f"错误: 配置文件不存在: {args.config}")
        sys.exit(1)

    config = load_config(args.config)
    simulation_dir = os.path.dirname(args.config) or "."
    wait_for_commands = not args.no_wait

    # Initialize logging (disable OASIS logs, clean up stale files).
    init_logging_for_simulation(simulation_dir)

    # One manager for the main log, plus one action logger per platform.
    log_manager = SimulationLogManager(simulation_dir)
    twitter_logger = log_manager.get_twitter_logger()
    reddit_logger = log_manager.get_reddit_logger()

    banner = "=" * 60

    for line in (
        banner,
        "OASIS 双平台并行模拟",
        f"配置文件: {args.config}",
        f"模拟ID: {config.get('simulation_id', 'unknown')}",
        f"等待命令模式: {'启用' if wait_for_commands else '禁用'}",
        banner,
    ):
        log_manager.info(line)

    time_config = config.get("time_config", {})
    total_hours = time_config.get('total_simulation_hours', 72)
    minutes_per_round = time_config.get('minutes_per_round', 30)
    config_total_rounds = (total_hours * 60) // minutes_per_round

    parameter_lines = [
        "模拟参数:",
        f" - 总模拟时长: {total_hours}小时",
        f" - 每轮时间: {minutes_per_round}分钟",
        f" - 配置总轮数: {config_total_rounds}",
    ]
    if args.max_rounds:
        parameter_lines.append(f" - 最大轮数限制: {args.max_rounds}")
        if args.max_rounds < config_total_rounds:
            parameter_lines.append(f" - 实际执行轮数: {args.max_rounds} (已截断)")
    parameter_lines.append(f" - Agent数量: {len(config.get('agent_configs', []))}")
    for line in parameter_lines:
        log_manager.info(line)

    for line in (
        "日志结构:",
        " - 主日志: simulation.log",
        " - Twitter动作: twitter/actions.jsonl",
        " - Reddit动作: reddit/actions.jsonl",
        banner,
    ):
        log_manager.info(line)

    start_time = datetime.now()

    # Per-platform results; a platform that is not run stays None.
    twitter_result: Optional[PlatformSimulation] = None
    reddit_result: Optional[PlatformSimulation] = None

    if args.twitter_only:
        twitter_result = await run_twitter_simulation(
            config, simulation_dir, twitter_logger, log_manager, args.max_rounds
        )
    elif args.reddit_only:
        reddit_result = await run_reddit_simulation(
            config, simulation_dir, reddit_logger, log_manager, args.max_rounds
        )
    else:
        # Run both platforms in parallel; each platform uses its own logger.
        twitter_result, reddit_result = await asyncio.gather(
            run_twitter_simulation(config, simulation_dir, twitter_logger, log_manager, args.max_rounds),
            run_reddit_simulation(config, simulation_dir, reddit_logger, log_manager, args.max_rounds),
        )

    total_elapsed = (datetime.now() - start_time).total_seconds()
    log_manager.info(banner)
    log_manager.info(f"模拟循环完成! 总耗时: {total_elapsed:.1f}秒")

    # Enter wait-for-command mode if requested.
    if wait_for_commands:
        for line in (
            "",
            banner,
            "进入等待命令模式 - 环境保持运行",
            "支持的命令: interview, batch_interview, close_env",
            banner,
        ):
            log_manager.info(line)

        ipc_handler = ParallelIPCHandler(
            simulation_dir=simulation_dir,
            twitter_env=twitter_result.env if twitter_result else None,
            twitter_agent_graph=twitter_result.agent_graph if twitter_result else None,
            reddit_env=reddit_result.env if reddit_result else None,
            reddit_agent_graph=reddit_result.agent_graph if reddit_result else None
        )
        ipc_handler.update_status("alive")

        # Command-wait loop (driven by the global ``_shutdown_event``).
        try:
            while not _shutdown_event.is_set():
                if not await ipc_handler.process_commands():
                    break
                # Waiting on the event (instead of sleeping) lets the loop
                # wake up as soon as a shutdown is requested.
                try:
                    await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
                except asyncio.TimeoutError:
                    continue  # nothing happened, poll for commands again
                break  # shutdown signal received
        except KeyboardInterrupt:
            print("\n收到中断信号")
        except asyncio.CancelledError:
            print("\n任务被取消")
        except Exception as e:
            print(f"\n命令处理出错: {e}")

        log_manager.info("\n关闭环境...")
        ipc_handler.update_status("stopped")

    # Close whichever environments were actually created.
    for platform_result, closed_message in (
        (twitter_result, "[Twitter] 环境已关闭"),
        (reddit_result, "[Reddit] 环境已关闭"),
    ):
        if platform_result and platform_result.env:
            await platform_result.env.close()
            log_manager.info(closed_message)

    for line in (
        banner,
        "全部完成!",
        "日志文件:",
        f" - {os.path.join(simulation_dir, 'simulation.log')}",
        f" - {os.path.join(simulation_dir, 'twitter', 'actions.jsonl')}",
        f" - {os.path.join(simulation_dir, 'reddit', 'actions.jsonl')}",
        banner,
    ):
        log_manager.info(line)
|
|
|
|
|
|
def setup_signal_handlers(loop=None):
    """Register one shared handler for SIGTERM and SIGINT.

    The first signal only requests a shutdown: it marks ``_cleanup_done`` and
    sets the global ``_shutdown_event`` (when one exists) so the asyncio side
    can stop waiting and clean up by itself. A further signal forces the
    process out through ``sys.exit(1)``.

    Args:
        loop: accepted for call compatibility; not used by this function.
    """
    def signal_handler(signum, frame):
        global _cleanup_done

        if signum == signal.SIGTERM:
            sig_name = "SIGTERM"
        else:
            sig_name = "SIGINT"
        print(f"\n收到 {sig_name} 信号,正在退出...")

        if _cleanup_done:
            # A shutdown was already requested; stop waiting for it.
            print("强制退出...")
            sys.exit(1)

        # First signal: no sys.exit() here, so the running code gets the
        # chance to leave its loops and release resources.
        _cleanup_done = True
        if _shutdown_event:
            _shutdown_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        signal.signal(sig, signal_handler)
|
|
|
|
|
|
if __name__ == "__main__":
    # Handlers go in before the event loop starts so an early signal is
    # already routed through the shutdown logic.
    setup_signal_handlers()
    try:
        asyncio.run(main())
    except SystemExit:
        # An explicit exit was requested; nothing more to report.
        pass
    except KeyboardInterrupt:
        print("\n程序被中断")
    finally:
        # Stop the multiprocessing resource tracker to avoid warnings at
        # exit. This touches a private API, so any failure is ignored.
        try:
            import multiprocessing.resource_tracker as resource_tracker
            resource_tracker._resource_tracker._stop()
        except Exception:
            pass
        print("模拟进程已退出")
|