MicroFish/backend/app/services/simulation_manager.py

525 lines
20 KiB
Python

"""OASIS simulation manager.
Drives parallel Twitter + Reddit simulations using preset scripts plus
LLM-generated configuration parameters.
"""
import os
import json
import shutil
from typing import Dict, Any, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import ZepEntityReader, FilteredEntities
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
from ..utils.locale import t
logger = get_logger('mirofish.simulation')
class SimulationStatus(str, Enum):
"""Simulation lifecycle status."""
CREATED = "created"
PREPARING = "preparing"
READY = "ready"
RUNNING = "running"
PAUSED = "paused"
STOPPED = "stopped" # manually stopped
COMPLETED = "completed" # finished naturally
FAILED = "failed"
class PlatformType(str, Enum):
"""Simulated platform types."""
TWITTER = "twitter"
REDDIT = "reddit"
@dataclass
class SimulationState:
"""In-memory + persisted state for a single simulation."""
simulation_id: str
project_id: str
graph_id: str
# Per-platform enable flags.
enable_twitter: bool = True
enable_reddit: bool = True
# Lifecycle status.
status: SimulationStatus = SimulationStatus.CREATED
# Counters captured during the prepare phase.
entities_count: int = 0
profiles_count: int = 0
entity_types: List[str] = field(default_factory=list)
# Information about the auto-generated config.
config_generated: bool = False
config_reasoning: str = ""
# Runtime data.
current_round: int = 0
twitter_status: str = "not_started"
reddit_status: str = "not_started"
# Timestamps.
created_at: str = field(default_factory=lambda: datetime.now().isoformat())
updated_at: str = field(default_factory=lambda: datetime.now().isoformat())
# Error message when status == FAILED.
error: Optional[str] = None
def to_dict(self) -> Dict[str, Any]:
"""Full state dict (used for persistence and internal callers)."""
return {
"simulation_id": self.simulation_id,
"project_id": self.project_id,
"graph_id": self.graph_id,
"enable_twitter": self.enable_twitter,
"enable_reddit": self.enable_reddit,
"status": self.status.value,
"entities_count": self.entities_count,
"profiles_count": self.profiles_count,
"entity_types": self.entity_types,
"config_generated": self.config_generated,
"config_reasoning": self.config_reasoning,
"current_round": self.current_round,
"twitter_status": self.twitter_status,
"reddit_status": self.reddit_status,
"created_at": self.created_at,
"updated_at": self.updated_at,
"error": self.error,
}
def to_simple_dict(self) -> Dict[str, Any]:
"""Simplified state dict (used for API responses)."""
return {
"simulation_id": self.simulation_id,
"project_id": self.project_id,
"graph_id": self.graph_id,
"status": self.status.value,
"entities_count": self.entities_count,
"profiles_count": self.profiles_count,
"entity_types": self.entity_types,
"config_generated": self.config_generated,
"error": self.error,
}
class SimulationManager:
"""Simulation manager.
Core responsibilities:
1. Read entities from the Zep graph and filter to the configured types.
2. Generate OASIS agent profiles per entity.
3. Use the LLM to generate simulation configuration parameters.
4. Materialize the files the preset scripts expect.
"""
# Root directory for persisted simulation data.
SIMULATION_DATA_DIR = os.path.join(
os.path.dirname(__file__),
'../../uploads/simulations'
)
def __init__(self):
# Ensure the simulation data directory exists.
os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)
# In-memory cache of simulation state objects.
self._simulations: Dict[str, SimulationState] = {}
def _get_simulation_dir(self, simulation_id: str) -> str:
"""Return the on-disk directory for a simulation, creating if missing."""
sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
os.makedirs(sim_dir, exist_ok=True)
return sim_dir
def _save_simulation_state(self, state: SimulationState):
"""Persist a simulation state to disk and update the cache."""
sim_dir = self._get_simulation_dir(state.simulation_id)
state_file = os.path.join(sim_dir, "state.json")
state.updated_at = datetime.now().isoformat()
with open(state_file, 'w', encoding='utf-8') as f:
json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
self._simulations[state.simulation_id] = state
def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
"""Load a simulation state from disk (or cache) by id."""
if simulation_id in self._simulations:
return self._simulations[simulation_id]
sim_dir = self._get_simulation_dir(simulation_id)
state_file = os.path.join(sim_dir, "state.json")
if not os.path.exists(state_file):
return None
with open(state_file, 'r', encoding='utf-8') as f:
data = json.load(f)
state = SimulationState(
simulation_id=simulation_id,
project_id=data.get("project_id", ""),
graph_id=data.get("graph_id", ""),
enable_twitter=data.get("enable_twitter", True),
enable_reddit=data.get("enable_reddit", True),
status=SimulationStatus(data.get("status", "created")),
entities_count=data.get("entities_count", 0),
profiles_count=data.get("profiles_count", 0),
entity_types=data.get("entity_types", []),
config_generated=data.get("config_generated", False),
config_reasoning=data.get("config_reasoning", ""),
current_round=data.get("current_round", 0),
twitter_status=data.get("twitter_status", "not_started"),
reddit_status=data.get("reddit_status", "not_started"),
created_at=data.get("created_at", datetime.now().isoformat()),
updated_at=data.get("updated_at", datetime.now().isoformat()),
error=data.get("error"),
)
self._simulations[simulation_id] = state
return state
def create_simulation(
self,
project_id: str,
graph_id: str,
enable_twitter: bool = True,
enable_reddit: bool = True,
) -> SimulationState:
"""Create a new simulation in the ``CREATED`` state.
Args:
project_id: Owning project id.
graph_id: Source Zep graph id.
enable_twitter: When ``True``, the Twitter simulation runs.
enable_reddit: When ``True``, the Reddit simulation runs.
Returns:
The created ``SimulationState``.
"""
import uuid
simulation_id = f"sim_{uuid.uuid4().hex[:12]}"
state = SimulationState(
simulation_id=simulation_id,
project_id=project_id,
graph_id=graph_id,
enable_twitter=enable_twitter,
enable_reddit=enable_reddit,
status=SimulationStatus.CREATED,
)
self._save_simulation_state(state)
logger.info(t("log.simulation_manager.m001", simulation_id=simulation_id, project_id=project_id, graph_id=graph_id))
return state
def prepare_simulation(
self,
simulation_id: str,
simulation_requirement: str,
document_text: str,
defined_entity_types: Optional[List[str]] = None,
use_llm_for_profiles: bool = True,
progress_callback: Optional[callable] = None,
parallel_profile_count: int = 3
) -> SimulationState:
"""Prepare the simulation environment end-to-end.
Steps:
1. Read and filter entities from the graph.
2. Generate OASIS agent profiles (optional LLM enrichment, parallel-capable).
3. Use the LLM to produce simulation parameters (timing, activity, posting frequency).
4. Save the configuration and profile files.
5. Copy preset scripts into the simulation directory.
Args:
simulation_id: Simulation id.
simulation_requirement: Free-text description of the simulation goal.
document_text: Raw source document text passed to the LLM for context.
defined_entity_types: Optional list of allowed entity types.
use_llm_for_profiles: When ``True``, enrich profiles via the LLM.
progress_callback: Optional callback ``(stage, progress, message, **extras)``.
parallel_profile_count: Number of profile generations to run in parallel.
Returns:
The updated ``SimulationState``.
"""
state = self._load_simulation_state(simulation_id)
if not state:
raise ValueError(f"模拟不存在: {simulation_id}")
try:
state.status = SimulationStatus.PREPARING
self._save_simulation_state(state)
sim_dir = self._get_simulation_dir(simulation_id)
# ========== Stage 1: read and filter entities ==========
if progress_callback:
progress_callback("reading", 0, t('progress.connectingZepGraph'))
reader = ZepEntityReader()
if progress_callback:
progress_callback("reading", 30, t('progress.readingNodeData'))
filtered = reader.filter_defined_entities(
graph_id=state.graph_id,
defined_entity_types=defined_entity_types,
enrich_with_edges=True
)
state.entities_count = filtered.filtered_count
state.entity_types = list(filtered.entity_types)
if progress_callback:
progress_callback(
"reading", 100,
t('progress.readingComplete', count=filtered.filtered_count),
current=filtered.filtered_count,
total=filtered.filtered_count
)
if filtered.filtered_count == 0:
state.status = SimulationStatus.FAILED
state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
self._save_simulation_state(state)
return state
# ========== Stage 2: generate agent profiles ==========
total_entities = len(filtered.entities)
if progress_callback:
progress_callback(
"generating_profiles", 0,
t('progress.startGenerating'),
current=0,
total=total_entities
)
# Pass the graph_id so the generator can use Zep retrieval for richer context.
generator = OasisProfileGenerator(graph_id=state.graph_id)
def profile_progress(current, total, msg):
if progress_callback:
progress_callback(
"generating_profiles",
int(current / total * 100),
msg,
current=current,
total=total,
item_name=msg
)
# Configure the realtime save target (prefer Reddit JSON if Reddit is enabled).
realtime_output_path = None
realtime_platform = "reddit"
if state.enable_reddit:
realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
realtime_platform = "reddit"
elif state.enable_twitter:
realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
realtime_platform = "twitter"
profiles = generator.generate_profiles_from_entities(
entities=filtered.entities,
use_llm=use_llm_for_profiles,
progress_callback=profile_progress,
graph_id=state.graph_id, # used for Zep retrieval enrichment
parallel_count=parallel_profile_count,
realtime_output_path=realtime_output_path,
output_platform=realtime_platform
)
state.profiles_count = len(profiles)
# Save profile files. Reddit also writes JSON during generation; this is
# a final consistency write. Twitter requires CSV per OASIS conventions.
if progress_callback:
progress_callback(
"generating_profiles", 95,
t('progress.savingProfiles'),
current=total_entities,
total=total_entities
)
if state.enable_reddit:
generator.save_profiles(
profiles=profiles,
file_path=os.path.join(sim_dir, "reddit_profiles.json"),
platform="reddit"
)
if state.enable_twitter:
# Twitter uses CSV format — required by OASIS.
generator.save_profiles(
profiles=profiles,
file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
platform="twitter"
)
if progress_callback:
progress_callback(
"generating_profiles", 100,
t('progress.profilesComplete', count=len(profiles)),
current=len(profiles),
total=len(profiles)
)
# ========== Stage 3: LLM-driven simulation config ==========
if progress_callback:
progress_callback(
"generating_config", 0,
t('progress.analyzingRequirements'),
current=0,
total=3
)
config_generator = SimulationConfigGenerator()
if progress_callback:
progress_callback(
"generating_config", 30,
t('progress.callingLLMConfig'),
current=1,
total=3
)
sim_params = config_generator.generate_config(
simulation_id=simulation_id,
project_id=state.project_id,
graph_id=state.graph_id,
simulation_requirement=simulation_requirement,
document_text=document_text,
entities=filtered.entities,
enable_twitter=state.enable_twitter,
enable_reddit=state.enable_reddit
)
if progress_callback:
progress_callback(
"generating_config", 70,
t('progress.savingConfigFiles'),
current=2,
total=3
)
# Save the configuration file.
config_path = os.path.join(sim_dir, "simulation_config.json")
with open(config_path, 'w', encoding='utf-8') as f:
f.write(sim_params.to_json())
state.config_generated = True
state.config_reasoning = sim_params.generation_reasoning
if progress_callback:
progress_callback(
"generating_config", 100,
t('progress.configComplete'),
current=3,
total=3
)
# The runtime scripts now live under backend/scripts/; we no longer copy
# them per-simulation. simulation_runner invokes them in place.
state.status = SimulationStatus.READY
self._save_simulation_state(state)
logger.info(t("log.simulation_manager.m002", simulation_id=simulation_id, state=state.entities_count, state_2=state.profiles_count))
return state
except Exception as e:
logger.error(t("log.simulation_manager.m003", simulation_id=simulation_id, str=str(e)))
import traceback
logger.error(traceback.format_exc())
state.status = SimulationStatus.FAILED
state.error = str(e)
self._save_simulation_state(state)
raise
def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
"""Return the simulation's state, or ``None`` if unknown."""
return self._load_simulation_state(simulation_id)
def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
"""List all simulations, optionally filtered by ``project_id``."""
simulations = []
if os.path.exists(self.SIMULATION_DATA_DIR):
for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
# Skip dotfiles (e.g. .DS_Store) and non-directories.
sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
if sim_id.startswith('.') or not os.path.isdir(sim_path):
continue
state = self._load_simulation_state(sim_id)
if state:
if project_id is None or state.project_id == project_id:
simulations.append(state)
return simulations
def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
"""Return the persisted agent profiles for a platform."""
state = self._load_simulation_state(simulation_id)
if not state:
raise ValueError(f"模拟不存在: {simulation_id}")
sim_dir = self._get_simulation_dir(simulation_id)
profile_path = os.path.join(sim_dir, f"{platform}_profiles.json")
if not os.path.exists(profile_path):
return []
with open(profile_path, 'r', encoding='utf-8') as f:
return json.load(f)
def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
"""Return the persisted simulation config dict, or ``None`` if absent."""
sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json")
if not os.path.exists(config_path):
return None
with open(config_path, 'r', encoding='utf-8') as f:
return json.load(f)
def get_run_instructions(self, simulation_id: str) -> Dict[str, str]:
"""Return shell commands and instructions to launch the simulation manually."""
sim_dir = self._get_simulation_dir(simulation_id)
config_path = os.path.join(sim_dir, "simulation_config.json")
scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))
return {
"simulation_dir": sim_dir,
"scripts_dir": scripts_dir,
"config_file": config_path,
"commands": {
"twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
"reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
"parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
},
"instructions": (
f"1. 激活conda环境: conda activate MiroFish\n"
f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
f" - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
f" - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
f" - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
)
}