docs(i18n): translate chinese docstrings/comments in backend/scripts
This commit is contained in:
parent
8189c08166
commit
5815ed28d2
File diff suppressed because it is too large
Load Diff
|
|
@ -1,16 +1,16 @@
|
|||
"""
|
||||
OASIS Reddit模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
"""OASIS Reddit simulation preset script.
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
This script reads parameters from a config file and runs the simulation end-to-end automatically.
|
||||
|
||||
使用方式:
|
||||
Features:
|
||||
- After the simulation finishes, the environment stays alive and enters a command-wait mode.
|
||||
- Accepts Interview commands over IPC.
|
||||
- Supports single-agent and batch interviews.
|
||||
- Supports a remote close-environment command.
|
||||
|
||||
Usage:
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
python run_reddit_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -25,18 +25,18 @@ import sqlite3
|
|||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 全局变量:用于信号处理
|
||||
# Globals used by the signal handler.
|
||||
_shutdown_event = None
|
||||
_cleanup_done = False
|
||||
|
||||
# 添加项目路径
|
||||
# Add project paths to sys.path so sibling modules import correctly.
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
|
||||
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
|
||||
sys.path.insert(0, _scripts_dir)
|
||||
sys.path.insert(0, _backend_dir)
|
||||
|
||||
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置)
|
||||
# Load the .env file from the project root (contains LLM_API_KEY and related settings).
|
||||
from dotenv import load_dotenv
|
||||
_env_file = os.path.join(_project_root, '.env')
|
||||
if os.path.exists(_env_file):
|
||||
|
|
@ -51,7 +51,7 @@ import re
|
|||
|
||||
|
||||
class UnicodeFormatter(logging.Formatter):
|
||||
"""自定义格式化器,将 Unicode 转义序列转换为可读字符"""
|
||||
"""Custom log formatter that converts Unicode escape sequences into readable characters."""
|
||||
|
||||
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
|
||||
|
||||
|
|
@ -68,24 +68,23 @@ class UnicodeFormatter(logging.Formatter):
|
|||
|
||||
|
||||
class MaxTokensWarningFilter(logging.Filter):
|
||||
"""过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)"""
|
||||
"""Suppress camel-ai's max_tokens warning (we intentionally leave max_tokens unset and let the model decide)."""
|
||||
|
||||
def filter(self, record):
|
||||
# 过滤掉包含 max_tokens 警告的日志
|
||||
if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效
|
||||
# Install the filter at module import time so it takes effect before any camel code runs.
|
||||
logging.getLogger().addFilter(MaxTokensWarningFilter())
|
||||
|
||||
|
||||
def setup_oasis_logging(log_dir: str):
|
||||
"""配置 OASIS 的日志,使用固定名称的日志文件"""
|
||||
"""Configure OASIS logging with fixed log file names."""
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 清理旧的日志文件
|
||||
# Remove stale log files from previous runs so the new run starts clean.
|
||||
for f in os.listdir(log_dir):
|
||||
old_log = os.path.join(log_dir, f)
|
||||
if os.path.isfile(old_log) and f.endswith('.log'):
|
||||
|
|
@ -131,20 +130,20 @@ except ImportError as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
# IPC-related constants.
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
"""Command type constants."""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
"""IPC command handler."""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
|
|
@ -155,12 +154,11 @@ class IPCHandler:
|
|||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
"""Update the environment status file."""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
|
|
@ -168,11 +166,11 @@ class IPCHandler:
|
|||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
"""Poll for pending IPC commands."""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
# Collect command files sorted by modification time so older commands are handled first.
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
|
|
@ -191,7 +189,7 @@ class IPCHandler:
|
|||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
"""Send an IPC response for a command."""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
|
|
@ -204,7 +202,7 @@ class IPCHandler:
|
|||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
# Remove the command file once a response has been written so it isn't re-processed.
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
|
|
@ -212,27 +210,23 @@ class IPCHandler:
|
|||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
"""Handle a single-agent interview command.
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
True on success, False on failure.
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
# Read the interview answer back from the simulation database.
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
|
|
@ -246,16 +240,14 @@ class IPCHandler:
|
|||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
"""Handle a batch interview command.
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
agent_prompts = {} # Track which prompt was sent to each agent so results can be paired back.
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
|
|
@ -275,10 +267,8 @@ class IPCHandler:
|
|||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
|
@ -298,7 +288,7 @@ class IPCHandler:
|
|||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
"""Fetch the most recent interview result for an agent from the database."""
|
||||
db_path = os.path.join(self.simulation_dir, "reddit_simulation.db")
|
||||
|
||||
result = {
|
||||
|
|
@ -314,7 +304,7 @@ class IPCHandler:
|
|||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
# Query the most recent interview row for this agent.
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
|
|
@ -341,11 +331,10 @@ class IPCHandler:
|
|||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
"""Process all pending IPC commands.
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
True to keep running, False if the loop should exit.
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
|
|
@ -383,9 +372,9 @@ class IPCHandler:
|
|||
|
||||
|
||||
class RedditSimulationRunner:
|
||||
"""Reddit模拟运行器"""
|
||||
"""Reddit simulation runner."""
|
||||
|
||||
# Reddit可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
# Available Reddit actions (INTERVIEW is excluded because it can only be triggered via ManualAction).
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.LIKE_POST,
|
||||
ActionType.DISLIKE_POST,
|
||||
|
|
@ -403,12 +392,11 @@ class RedditSimulationRunner:
|
|||
]
|
||||
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
"""Initialize the simulation runner.
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
config_path: Path to the configuration file (simulation_config.json).
|
||||
wait_for_commands: Whether to wait for commands after the simulation finishes (default True).
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
|
|
@ -419,37 +407,36 @@ class RedditSimulationRunner:
|
|||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
"""Load the configuration file."""
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def _get_profile_path(self) -> str:
|
||||
"""获取Profile文件路径"""
|
||||
"""Return the path to the agent profiles file."""
|
||||
return os.path.join(self.simulation_dir, "reddit_profiles.json")
|
||||
|
||||
def _get_db_path(self) -> str:
|
||||
"""获取数据库路径"""
|
||||
"""Return the path to the simulation database."""
|
||||
return os.path.join(self.simulation_dir, "reddit_simulation.db")
|
||||
|
||||
def _create_model(self):
|
||||
"""
|
||||
创建LLM模型
|
||||
"""Create the LLM model.
|
||||
|
||||
统一使用项目根目录 .env 文件中的配置(优先级最高):
|
||||
- LLM_API_KEY: API密钥
|
||||
- LLM_BASE_URL: API基础URL
|
||||
- LLM_MODEL_NAME: 模型名称
|
||||
Configuration is sourced from the project-root ``.env`` file (highest priority):
|
||||
- LLM_API_KEY: API key.
|
||||
- LLM_BASE_URL: API base URL.
|
||||
- LLM_MODEL_NAME: Model name.
|
||||
"""
|
||||
# 优先从 .env 读取配置
|
||||
# Prefer values from .env over the per-simulation config.
|
||||
llm_api_key = os.environ.get("LLM_API_KEY", "")
|
||||
llm_base_url = os.environ.get("LLM_BASE_URL", "")
|
||||
llm_model = os.environ.get("LLM_MODEL_NAME", "")
|
||||
|
||||
# 如果 .env 中没有,则使用 config 作为备用
|
||||
# Fall back to the simulation config file if .env did not specify a model.
|
||||
if not llm_model:
|
||||
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
||||
|
||||
# 设置 camel-ai 所需的环境变量
|
||||
# Export the env vars camel-ai expects.
|
||||
if llm_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = llm_api_key
|
||||
|
||||
|
|
@ -472,9 +459,7 @@ class RedditSimulationRunner:
|
|||
current_hour: int,
|
||||
round_num: int
|
||||
) -> List:
|
||||
"""
|
||||
根据时间和配置决定本轮激活哪些Agent
|
||||
"""
|
||||
"""Decide which agents are active for the current round, based on time of day and config."""
|
||||
time_config = self.config.get("time_config", {})
|
||||
agent_configs = self.config.get("agent_configs", [])
|
||||
|
||||
|
|
@ -521,10 +506,10 @@ class RedditSimulationRunner:
|
|||
return active_agents
|
||||
|
||||
async def run(self, max_rounds: int = None):
|
||||
"""运行Reddit模拟
|
||||
"""Run the Reddit simulation.
|
||||
|
||||
Args:
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
max_rounds: Optional cap on the number of simulation rounds (used to truncate overly long runs).
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("OASIS Reddit模拟")
|
||||
|
|
@ -538,7 +523,7 @@ class RedditSimulationRunner:
|
|||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
# Truncate if a max_rounds cap was supplied.
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
|
|
@ -578,17 +563,16 @@ class RedditSimulationRunner:
|
|||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.REDDIT,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
semaphore=30, # Cap concurrent LLM requests to avoid overloading the API.
|
||||
)
|
||||
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
# Apply the configured initial events (seed posts) before starting the main loop.
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
|
||||
|
|
@ -619,7 +603,7 @@ class RedditSimulationRunner:
|
|||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
# Main simulation loop.
|
||||
print("\n开始模拟循环...")
|
||||
start_time = datetime.now()
|
||||
|
||||
|
|
@ -655,7 +639,7 @@ class RedditSimulationRunner:
|
|||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
# Optionally enter command-wait mode.
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
|
|
@ -664,7 +648,7 @@ class RedditSimulationRunner:
|
|||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环(使用全局 _shutdown_event)
|
||||
# Command-wait loop driven by the global _shutdown_event.
|
||||
try:
|
||||
while not _shutdown_event.is_set():
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
|
|
@ -672,7 +656,7 @@ class RedditSimulationRunner:
|
|||
break
|
||||
try:
|
||||
await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
|
||||
break # 收到退出信号
|
||||
break # Shutdown signal received.
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -684,7 +668,6 @@ class RedditSimulationRunner:
|
|||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
|
|
@ -715,7 +698,7 @@ async def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 在 main 函数开始时创建 shutdown 事件
|
||||
# Create the shutdown event lazily here so it is bound to the running asyncio loop.
|
||||
global _shutdown_event
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
|
|
@ -723,7 +706,7 @@ async def main():
|
|||
print(f"错误: 配置文件不存在: {args.config}")
|
||||
sys.exit(1)
|
||||
|
||||
# 初始化日志配置(使用固定文件名,清理旧日志)
|
||||
# Initialize log config with fixed filenames; old logs are cleared inside setup_oasis_logging.
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
|
|
@ -735,9 +718,9 @@ async def main():
|
|||
|
||||
|
||||
def setup_signal_handlers():
|
||||
"""
|
||||
设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出
|
||||
让程序有机会正常清理资源(关闭数据库、环境等)
|
||||
"""Install signal handlers so SIGTERM/SIGINT trigger a graceful exit.
|
||||
|
||||
This gives the program a chance to clean up resources (close the database, the OASIS environment, etc.).
|
||||
"""
|
||||
def signal_handler(signum, frame):
|
||||
global _cleanup_done
|
||||
|
|
@ -748,7 +731,7 @@ def setup_signal_handlers():
|
|||
if _shutdown_event:
|
||||
_shutdown_event.set()
|
||||
else:
|
||||
# 重复收到信号才强制退出
|
||||
# Force exit only on a repeat signal so the user can still hard-kill if cleanup hangs.
|
||||
print("强制退出...")
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,16 +1,18 @@
|
|||
"""
|
||||
OASIS Twitter模拟预设脚本
|
||||
此脚本读取配置文件中的参数来执行模拟,实现全程自动化
|
||||
OASIS Twitter simulation preset script.
|
||||
|
||||
功能特性:
|
||||
- 完成模拟后不立即关闭环境,进入等待命令模式
|
||||
- 支持通过IPC接收Interview命令
|
||||
- 支持单个Agent采访和批量采访
|
||||
- 支持远程关闭环境命令
|
||||
This script reads parameters from a config file to run a fully automated simulation.
|
||||
|
||||
使用方式:
|
||||
Features:
|
||||
- Does not close the environment immediately when the simulation finishes; enters
|
||||
command-wait mode instead.
|
||||
- Receives Interview commands over IPC.
|
||||
- Supports both single-agent and batch interviews.
|
||||
- Supports a remote close-environment command.
|
||||
|
||||
Usage:
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # 完成后立即关闭
|
||||
python run_twitter_simulation.py --config /path/to/simulation_config.json --no-wait # close immediately when done
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -25,18 +27,17 @@ import sqlite3
|
|||
from datetime import datetime
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
||||
# 全局变量:用于信号处理
|
||||
# Globals used by the signal handler.
|
||||
_shutdown_event = None
|
||||
_cleanup_done = False
|
||||
|
||||
# 添加项目路径
|
||||
_scripts_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
_backend_dir = os.path.abspath(os.path.join(_scripts_dir, '..'))
|
||||
_project_root = os.path.abspath(os.path.join(_backend_dir, '..'))
|
||||
sys.path.insert(0, _scripts_dir)
|
||||
sys.path.insert(0, _backend_dir)
|
||||
|
||||
# 加载项目根目录的 .env 文件(包含 LLM_API_KEY 等配置)
|
||||
# Load the project-root .env (it carries LLM_API_KEY and friends).
|
||||
from dotenv import load_dotenv
|
||||
_env_file = os.path.join(_project_root, '.env')
|
||||
if os.path.exists(_env_file):
|
||||
|
|
@ -51,7 +52,7 @@ import re
|
|||
|
||||
|
||||
class UnicodeFormatter(logging.Formatter):
|
||||
"""自定义格式化器,将 Unicode 转义序列转换为可读字符"""
|
||||
"""Custom formatter that turns Unicode escape sequences into readable characters."""
|
||||
|
||||
UNICODE_ESCAPE_PATTERN = re.compile(r'\\u([0-9a-fA-F]{4})')
|
||||
|
||||
|
|
@ -68,24 +69,23 @@ class UnicodeFormatter(logging.Formatter):
|
|||
|
||||
|
||||
class MaxTokensWarningFilter(logging.Filter):
|
||||
"""过滤掉 camel-ai 关于 max_tokens 的警告(我们故意不设置 max_tokens,让模型自行决定)"""
|
||||
"""Suppress camel-ai's max_tokens warning — we intentionally leave it unset and let the model decide."""
|
||||
|
||||
def filter(self, record):
|
||||
# 过滤掉包含 max_tokens 警告的日志
|
||||
if "max_tokens" in record.getMessage() and "Invalid or missing" in record.getMessage():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
# 在模块加载时立即添加过滤器,确保在 camel 代码执行前生效
|
||||
# Install the filter at import time so it is active before any camel code runs.
|
||||
logging.getLogger().addFilter(MaxTokensWarningFilter())
|
||||
|
||||
|
||||
def setup_oasis_logging(log_dir: str):
|
||||
"""配置 OASIS 的日志,使用固定名称的日志文件"""
|
||||
"""Configure OASIS logging with fixed log filenames."""
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
|
||||
# 清理旧的日志文件
|
||||
# Wipe stale log files from previous runs.
|
||||
for f in os.listdir(log_dir):
|
||||
old_log = os.path.join(log_dir, f)
|
||||
if os.path.isfile(old_log) and f.endswith('.log'):
|
||||
|
|
@ -131,20 +131,20 @@ except ImportError as e:
|
|||
sys.exit(1)
|
||||
|
||||
|
||||
# IPC相关常量
|
||||
# IPC-related constants.
|
||||
IPC_COMMANDS_DIR = "ipc_commands"
|
||||
IPC_RESPONSES_DIR = "ipc_responses"
|
||||
ENV_STATUS_FILE = "env_status.json"
|
||||
|
||||
class CommandType:
|
||||
"""命令类型常量"""
|
||||
"""Command type constants."""
|
||||
INTERVIEW = "interview"
|
||||
BATCH_INTERVIEW = "batch_interview"
|
||||
CLOSE_ENV = "close_env"
|
||||
|
||||
|
||||
class IPCHandler:
|
||||
"""IPC命令处理器"""
|
||||
"""Handles IPC commands directed at the running simulation."""
|
||||
|
||||
def __init__(self, simulation_dir: str, env, agent_graph):
|
||||
self.simulation_dir = simulation_dir
|
||||
|
|
@ -155,12 +155,11 @@ class IPCHandler:
|
|||
self.status_file = os.path.join(simulation_dir, ENV_STATUS_FILE)
|
||||
self._running = True
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(self.commands_dir, exist_ok=True)
|
||||
os.makedirs(self.responses_dir, exist_ok=True)
|
||||
|
||||
def update_status(self, status: str):
|
||||
"""更新环境状态"""
|
||||
"""Write the current environment status to the status file."""
|
||||
with open(self.status_file, 'w', encoding='utf-8') as f:
|
||||
json.dump({
|
||||
"status": status,
|
||||
|
|
@ -168,11 +167,11 @@ class IPCHandler:
|
|||
}, f, ensure_ascii=False, indent=2)
|
||||
|
||||
def poll_command(self) -> Optional[Dict[str, Any]]:
|
||||
"""轮询获取待处理命令"""
|
||||
"""Poll for the next pending command."""
|
||||
if not os.path.exists(self.commands_dir):
|
||||
return None
|
||||
|
||||
# 获取命令文件(按时间排序)
|
||||
# Collect command files ordered by mtime.
|
||||
command_files = []
|
||||
for filename in os.listdir(self.commands_dir):
|
||||
if filename.endswith('.json'):
|
||||
|
|
@ -191,7 +190,7 @@ class IPCHandler:
|
|||
return None
|
||||
|
||||
def send_response(self, command_id: str, status: str, result: Dict = None, error: str = None):
|
||||
"""发送响应"""
|
||||
"""Send a response for a processed command."""
|
||||
response = {
|
||||
"command_id": command_id,
|
||||
"status": status,
|
||||
|
|
@ -204,7 +203,7 @@ class IPCHandler:
|
|||
with open(response_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(response, f, ensure_ascii=False, indent=2)
|
||||
|
||||
# 删除命令文件
|
||||
# Remove the command file once a response has been written.
|
||||
command_file = os.path.join(self.commands_dir, f"{command_id}.json")
|
||||
try:
|
||||
os.remove(command_file)
|
||||
|
|
@ -212,27 +211,23 @@ class IPCHandler:
|
|||
pass
|
||||
|
||||
async def handle_interview(self, command_id: str, agent_id: int, prompt: str) -> bool:
|
||||
"""
|
||||
处理单个Agent采访命令
|
||||
"""Handle a single-agent interview command.
|
||||
|
||||
Returns:
|
||||
True 表示成功,False 表示失败
|
||||
True on success, False on failure.
|
||||
"""
|
||||
try:
|
||||
# 获取Agent
|
||||
agent = self.agent_graph.get_agent(agent_id)
|
||||
|
||||
# 创建Interview动作
|
||||
interview_action = ManualAction(
|
||||
action_type=ActionType.INTERVIEW,
|
||||
action_args={"prompt": prompt}
|
||||
)
|
||||
|
||||
# 执行Interview
|
||||
actions = {agent: interview_action}
|
||||
await self.env.step(actions)
|
||||
|
||||
# 从数据库获取结果
|
||||
# Pull the resulting transcript from the simulation database.
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
||||
self.send_response(command_id, "completed", result=result)
|
||||
|
|
@ -246,16 +241,14 @@ class IPCHandler:
|
|||
return False
|
||||
|
||||
async def handle_batch_interview(self, command_id: str, interviews: List[Dict]) -> bool:
|
||||
"""
|
||||
处理批量采访命令
|
||||
"""Handle a batch interview command.
|
||||
|
||||
Args:
|
||||
interviews: [{"agent_id": int, "prompt": str}, ...]
|
||||
"""
|
||||
try:
|
||||
# 构建动作字典
|
||||
actions = {}
|
||||
agent_prompts = {} # 记录每个agent的prompt
|
||||
agent_prompts = {} # Track the prompt issued to each agent for later result lookup.
|
||||
|
||||
for interview in interviews:
|
||||
agent_id = interview.get("agent_id")
|
||||
|
|
@ -275,10 +268,9 @@ class IPCHandler:
|
|||
self.send_response(command_id, "failed", error="没有有效的Agent")
|
||||
return False
|
||||
|
||||
# 执行批量Interview
|
||||
await self.env.step(actions)
|
||||
|
||||
# 获取所有结果
|
||||
# Collect the per-agent interview results.
|
||||
results = {}
|
||||
for agent_id in agent_prompts.keys():
|
||||
result = self._get_interview_result(agent_id)
|
||||
|
|
@ -298,7 +290,7 @@ class IPCHandler:
|
|||
return False
|
||||
|
||||
def _get_interview_result(self, agent_id: int) -> Dict[str, Any]:
|
||||
"""从数据库获取最新的Interview结果"""
|
||||
"""Fetch the most recent interview result for an agent from the database."""
|
||||
db_path = os.path.join(self.simulation_dir, "twitter_simulation.db")
|
||||
|
||||
result = {
|
||||
|
|
@ -314,7 +306,7 @@ class IPCHandler:
|
|||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# 查询最新的Interview记录
|
||||
# Pull the most recent INTERVIEW trace row for this agent.
|
||||
cursor.execute("""
|
||||
SELECT user_id, info, created_at
|
||||
FROM trace
|
||||
|
|
@ -341,11 +333,10 @@ class IPCHandler:
|
|||
return result
|
||||
|
||||
async def process_commands(self) -> bool:
|
||||
"""
|
||||
处理所有待处理命令
|
||||
"""Process pending commands.
|
||||
|
||||
Returns:
|
||||
True 表示继续运行,False 表示应该退出
|
||||
True if the run loop should continue, False if it should exit.
|
||||
"""
|
||||
command = self.poll_command()
|
||||
if not command:
|
||||
|
|
@ -383,9 +374,9 @@ class IPCHandler:
|
|||
|
||||
|
||||
class TwitterSimulationRunner:
|
||||
"""Twitter模拟运行器"""
|
||||
"""Drives a single Twitter simulation run."""
|
||||
|
||||
# Twitter可用动作(不包含INTERVIEW,INTERVIEW只能通过ManualAction手动触发)
|
||||
# Available Twitter actions. INTERVIEW is intentionally excluded — it can only be triggered via ManualAction.
|
||||
AVAILABLE_ACTIONS = [
|
||||
ActionType.CREATE_POST,
|
||||
ActionType.LIKE_POST,
|
||||
|
|
@ -396,12 +387,11 @@ class TwitterSimulationRunner:
|
|||
]
|
||||
|
||||
def __init__(self, config_path: str, wait_for_commands: bool = True):
|
||||
"""
|
||||
初始化模拟运行器
|
||||
"""Initialize the simulation runner.
|
||||
|
||||
Args:
|
||||
config_path: 配置文件路径 (simulation_config.json)
|
||||
wait_for_commands: 模拟完成后是否等待命令(默认True)
|
||||
config_path: Path to the config file (simulation_config.json).
|
||||
wait_for_commands: Whether to wait for IPC commands after the simulation completes (default True).
|
||||
"""
|
||||
self.config_path = config_path
|
||||
self.config = self._load_config()
|
||||
|
|
@ -412,37 +402,36 @@ class TwitterSimulationRunner:
|
|||
self.ipc_handler = None
|
||||
|
||||
def _load_config(self) -> Dict[str, Any]:
|
||||
"""加载配置文件"""
|
||||
"""Load the simulation config file."""
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
return json.load(f)
|
||||
|
||||
def _get_profile_path(self) -> str:
|
||||
"""获取Profile文件路径(OASIS Twitter使用CSV格式)"""
|
||||
"""Return the agent profile path (OASIS Twitter expects CSV)."""
|
||||
return os.path.join(self.simulation_dir, "twitter_profiles.csv")
|
||||
|
||||
def _get_db_path(self) -> str:
|
||||
"""获取数据库路径"""
|
||||
"""Return the simulation SQLite database path."""
|
||||
return os.path.join(self.simulation_dir, "twitter_simulation.db")
|
||||
|
||||
def _create_model(self):
|
||||
"""
|
||||
创建LLM模型
|
||||
"""Create the LLM model.
|
||||
|
||||
统一使用项目根目录 .env 文件中的配置(优先级最高):
|
||||
- LLM_API_KEY: API密钥
|
||||
- LLM_BASE_URL: API基础URL
|
||||
- LLM_MODEL_NAME: 模型名称
|
||||
Uses the project-root .env file (highest precedence):
|
||||
- LLM_API_KEY: API key
|
||||
- LLM_BASE_URL: API base URL
|
||||
- LLM_MODEL_NAME: model name
|
||||
"""
|
||||
# 优先从 .env 读取配置
|
||||
# Prefer values from .env.
|
||||
llm_api_key = os.environ.get("LLM_API_KEY", "")
|
||||
llm_base_url = os.environ.get("LLM_BASE_URL", "")
|
||||
llm_model = os.environ.get("LLM_MODEL_NAME", "")
|
||||
|
||||
# 如果 .env 中没有,则使用 config 作为备用
|
||||
# Fall back to the simulation config if .env did not provide a model name.
|
||||
if not llm_model:
|
||||
llm_model = self.config.get("llm_model", "gpt-4o-mini")
|
||||
|
||||
# 设置 camel-ai 所需的环境变量
|
||||
# camel-ai reads OPENAI_API_KEY from the environment.
|
||||
if llm_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = llm_api_key
|
||||
|
||||
|
|
@ -465,25 +454,24 @@ class TwitterSimulationRunner:
|
|||
current_hour: int,
|
||||
round_num: int
|
||||
) -> List:
|
||||
"""
|
||||
根据时间和配置决定本轮激活哪些Agent
|
||||
"""Decide which agents activate this round, based on time and config.
|
||||
|
||||
Args:
|
||||
env: OASIS环境
|
||||
current_hour: 当前模拟小时(0-23)
|
||||
round_num: 当前轮数
|
||||
env: The OASIS environment.
|
||||
current_hour: Current simulated hour (0-23).
|
||||
round_num: Current round number.
|
||||
|
||||
Returns:
|
||||
激活的Agent列表
|
||||
The list of agents activated this round.
|
||||
"""
|
||||
time_config = self.config.get("time_config", {})
|
||||
agent_configs = self.config.get("agent_configs", [])
|
||||
|
||||
# 基础激活数量
|
||||
# Base activation count per round.
|
||||
base_min = time_config.get("agents_per_hour_min", 5)
|
||||
base_max = time_config.get("agents_per_hour_max", 20)
|
||||
|
||||
# 根据时段调整
|
||||
# Adjust by time-of-day (peak vs. off-peak hours).
|
||||
peak_hours = time_config.get("peak_hours", [9, 10, 11, 14, 15, 20, 21, 22])
|
||||
off_peak_hours = time_config.get("off_peak_hours", [0, 1, 2, 3, 4, 5])
|
||||
|
||||
|
|
@ -496,28 +484,26 @@ class TwitterSimulationRunner:
|
|||
|
||||
target_count = int(random.uniform(base_min, base_max) * multiplier)
|
||||
|
||||
# 根据每个Agent的配置计算激活概率
|
||||
# Compute activation probability for each configured agent.
|
||||
candidates = []
|
||||
for cfg in agent_configs:
|
||||
agent_id = cfg.get("agent_id", 0)
|
||||
active_hours = cfg.get("active_hours", list(range(8, 23)))
|
||||
activity_level = cfg.get("activity_level", 0.5)
|
||||
|
||||
# 检查是否在活跃时间
|
||||
if current_hour not in active_hours:
|
||||
continue
|
||||
|
||||
# 根据活跃度计算概率
|
||||
if random.random() < activity_level:
|
||||
candidates.append(agent_id)
|
||||
|
||||
# 随机选择
|
||||
# Pick a random subset of the eligible candidates.
|
||||
selected_ids = random.sample(
|
||||
candidates,
|
||||
min(target_count, len(candidates))
|
||||
) if candidates else []
|
||||
|
||||
# 转换为Agent对象
|
||||
# Resolve IDs to Agent objects.
|
||||
active_agents = []
|
||||
for agent_id in selected_ids:
|
||||
try:
|
||||
|
|
@ -529,10 +515,10 @@ class TwitterSimulationRunner:
|
|||
return active_agents
|
||||
|
||||
async def run(self, max_rounds: int = None):
|
||||
"""运行Twitter模拟
|
||||
"""Run the Twitter simulation.
|
||||
|
||||
Args:
|
||||
max_rounds: 最大模拟轮数(可选,用于截断过长的模拟)
|
||||
max_rounds: Optional cap on the number of rounds, used to truncate overly long simulations.
|
||||
"""
|
||||
print("=" * 60)
|
||||
print("OASIS Twitter模拟")
|
||||
|
|
@ -541,15 +527,13 @@ class TwitterSimulationRunner:
|
|||
print(f"等待命令模式: {'启用' if self.wait_for_commands else '禁用'}")
|
||||
print("=" * 60)
|
||||
|
||||
# 加载时间配置
|
||||
time_config = self.config.get("time_config", {})
|
||||
total_hours = time_config.get("total_simulation_hours", 72)
|
||||
minutes_per_round = time_config.get("minutes_per_round", 30)
|
||||
|
||||
# 计算总轮数
|
||||
total_rounds = (total_hours * 60) // minutes_per_round
|
||||
|
||||
# 如果指定了最大轮数,则截断
|
||||
# Truncate to max_rounds when one was supplied.
|
||||
if max_rounds is not None and max_rounds > 0:
|
||||
original_rounds = total_rounds
|
||||
total_rounds = min(total_rounds, max_rounds)
|
||||
|
|
@ -564,11 +548,10 @@ class TwitterSimulationRunner:
|
|||
print(f" - 最大轮数限制: {max_rounds}")
|
||||
print(f" - Agent数量: {len(self.config.get('agent_configs', []))}")
|
||||
|
||||
# 创建模型
|
||||
print("\n初始化LLM模型...")
|
||||
model = self._create_model()
|
||||
|
||||
# 加载Agent图
|
||||
# Load the agent graph from the profile CSV.
|
||||
print("加载Agent Profile...")
|
||||
profile_path = self._get_profile_path()
|
||||
if not os.path.exists(profile_path):
|
||||
|
|
@ -581,29 +564,27 @@ class TwitterSimulationRunner:
|
|||
available_actions=self.AVAILABLE_ACTIONS,
|
||||
)
|
||||
|
||||
# 数据库路径
|
||||
# Reset the simulation database for a clean run.
|
||||
db_path = self._get_db_path()
|
||||
if os.path.exists(db_path):
|
||||
os.remove(db_path)
|
||||
print(f"已删除旧数据库: {db_path}")
|
||||
|
||||
# 创建环境
|
||||
print("创建OASIS环境...")
|
||||
self.env = oasis.make(
|
||||
agent_graph=self.agent_graph,
|
||||
platform=oasis.DefaultPlatformType.TWITTER,
|
||||
database_path=db_path,
|
||||
semaphore=30, # 限制最大并发 LLM 请求数,防止 API 过载
|
||||
semaphore=30, # Cap concurrent LLM requests to avoid API overload.
|
||||
)
|
||||
|
||||
await self.env.reset()
|
||||
print("环境初始化完成\n")
|
||||
|
||||
# 初始化IPC处理器
|
||||
self.ipc_handler = IPCHandler(self.simulation_dir, self.env, self.agent_graph)
|
||||
self.ipc_handler.update_status("running")
|
||||
|
||||
# 执行初始事件
|
||||
# Run the initial seeded events (kickoff posts).
|
||||
event_config = self.config.get("event_config", {})
|
||||
initial_posts = event_config.get("initial_posts", [])
|
||||
|
||||
|
|
@ -626,17 +607,16 @@ class TwitterSimulationRunner:
|
|||
await self.env.step(initial_actions)
|
||||
print(f" 已发布 {len(initial_actions)} 条初始帖子")
|
||||
|
||||
# 主模拟循环
|
||||
# Main simulation loop.
|
||||
print("\n开始模拟循环...")
|
||||
start_time = datetime.now()
|
||||
|
||||
for round_num in range(total_rounds):
|
||||
# 计算当前模拟时间
|
||||
# Map round number to simulated wall-clock time.
|
||||
simulated_minutes = round_num * minutes_per_round
|
||||
simulated_hour = (simulated_minutes // 60) % 24
|
||||
simulated_day = simulated_minutes // (60 * 24) + 1
|
||||
|
||||
# 获取本轮激活的Agent
|
||||
active_agents = self._get_active_agents_for_round(
|
||||
self.env, simulated_hour, round_num
|
||||
)
|
||||
|
|
@ -644,16 +624,14 @@ class TwitterSimulationRunner:
|
|||
if not active_agents:
|
||||
continue
|
||||
|
||||
# 构建动作
|
||||
actions = {
|
||||
agent: LLMAction()
|
||||
for _, agent in active_agents
|
||||
}
|
||||
|
||||
# 执行动作
|
||||
await self.env.step(actions)
|
||||
|
||||
# 打印进度
|
||||
# Periodic progress log.
|
||||
if (round_num + 1) % 10 == 0 or round_num == 0:
|
||||
elapsed = (datetime.now() - start_time).total_seconds()
|
||||
progress = (round_num + 1) / total_rounds * 100
|
||||
|
|
@ -667,7 +645,7 @@ class TwitterSimulationRunner:
|
|||
print(f" - 总耗时: {total_elapsed:.1f}秒")
|
||||
print(f" - 数据库: {db_path}")
|
||||
|
||||
# 是否进入等待命令模式
|
||||
# Optionally enter command-wait mode.
|
||||
if self.wait_for_commands:
|
||||
print("\n" + "=" * 60)
|
||||
print("进入等待命令模式 - 环境保持运行")
|
||||
|
|
@ -676,7 +654,7 @@ class TwitterSimulationRunner:
|
|||
|
||||
self.ipc_handler.update_status("alive")
|
||||
|
||||
# 等待命令循环(使用全局 _shutdown_event)
|
||||
# Command-wait loop, driven by the global _shutdown_event.
|
||||
try:
|
||||
while not _shutdown_event.is_set():
|
||||
should_continue = await self.ipc_handler.process_commands()
|
||||
|
|
@ -684,7 +662,7 @@ class TwitterSimulationRunner:
|
|||
break
|
||||
try:
|
||||
await asyncio.wait_for(_shutdown_event.wait(), timeout=0.5)
|
||||
break # 收到退出信号
|
||||
break # Shutdown signal received.
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
except KeyboardInterrupt:
|
||||
|
|
@ -696,7 +674,6 @@ class TwitterSimulationRunner:
|
|||
|
||||
print("\n关闭环境...")
|
||||
|
||||
# 关闭环境
|
||||
self.ipc_handler.update_status("stopped")
|
||||
await self.env.close()
|
||||
|
||||
|
|
@ -727,7 +704,7 @@ async def main():
|
|||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 在 main 函数开始时创建 shutdown 事件
|
||||
# Create the shutdown event inside the running event loop.
|
||||
global _shutdown_event
|
||||
_shutdown_event = asyncio.Event()
|
||||
|
||||
|
|
@ -735,7 +712,7 @@ async def main():
|
|||
print(f"错误: 配置文件不存在: {args.config}")
|
||||
sys.exit(1)
|
||||
|
||||
# 初始化日志配置(使用固定文件名,清理旧日志)
|
||||
# Initialize logging with fixed filenames; old logs are wiped.
|
||||
simulation_dir = os.path.dirname(args.config) or "."
|
||||
setup_oasis_logging(os.path.join(simulation_dir, "log"))
|
||||
|
||||
|
|
@ -747,9 +724,11 @@ async def main():
|
|||
|
||||
|
||||
def setup_signal_handlers():
|
||||
"""
|
||||
设置信号处理器,确保收到 SIGTERM/SIGINT 时能够正确退出
|
||||
让程序有机会正常清理资源(关闭数据库、环境等)
|
||||
"""Install signal handlers so SIGTERM/SIGINT trigger an orderly shutdown.
|
||||
|
||||
The handler gives the program a chance to clean up resources properly
|
||||
(closing the database, the OASIS environment, etc.) on the first signal,
|
||||
and only force-exits on a repeated signal.
|
||||
"""
|
||||
def signal_handler(signum, frame):
|
||||
global _cleanup_done
|
||||
|
|
@ -760,7 +739,7 @@ def setup_signal_handlers():
|
|||
if _shutdown_event:
|
||||
_shutdown_event.set()
|
||||
else:
|
||||
# 重复收到信号才强制退出
|
||||
# Force exit only on a repeat signal.
|
||||
print("强制退出...")
|
||||
sys.exit(1)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue