525 lines
20 KiB
Python
525 lines
20 KiB
Python
"""OASIS simulation manager.
|
|
|
|
Drives parallel Twitter + Reddit simulations using preset scripts plus
|
|
LLM-generated configuration parameters.
|
|
"""
|
|
|
|
import os
import json
import shutil
from typing import Any, Callable, Dict, List, Optional
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum

from ..config import Config
from ..utils.logger import get_logger
from .zep_entity_reader import ZepEntityReader, FilteredEntities
from .oasis_profile_generator import OasisProfileGenerator, OasisAgentProfile
from .simulation_config_generator import SimulationConfigGenerator, SimulationParameters
from ..utils.locale import t
|
|
|
|
# Module-level logger, obtained from the project's get_logger helper under the
# name 'mirofish.simulation'; used by SimulationManager for lifecycle logging.
logger = get_logger('mirofish.simulation')
|
|
|
|
|
|
class SimulationStatus(str, Enum):
    """Lifecycle states a simulation can be in.

    The enum mixes in ``str``, so every member compares equal to its
    lowercase string value (``SimulationStatus.READY == "ready"``), and
    ``SimulationStatus("ready")`` maps a stored value back to its member.
    """

    # Setup phase.
    CREATED = "created"
    PREPARING = "preparing"
    READY = "ready"

    # Execution phase.
    RUNNING = "running"
    PAUSED = "paused"

    # Outcomes.
    STOPPED = "stopped"  # halted by an explicit (manual) stop
    COMPLETED = "completed"  # ran to its natural end
    FAILED = "failed"  # aborted with an error
|
|
|
|
|
|
class PlatformType(str, Enum):
    """Social platforms a simulation can target.

    Members are ``str`` subclasses, so ``PlatformType.REDDIT == "reddit"``
    and ``PlatformType("twitter")`` returns the matching member.
    """

    TWITTER = "twitter"
    REDDIT = "reddit"
|
|
|
|
|
|
@dataclass
class SimulationState:
    """In-memory + persisted state for a single simulation.

    ``to_dict`` produces the full record that gets persisted; ``to_simple_dict``
    produces a reduced view for API responses. Both return fresh containers, so
    mutating a returned dict (including its ``entity_types`` list) never alters
    this state object.
    """

    simulation_id: str
    project_id: str
    graph_id: str

    # Per-platform enable flags.
    enable_twitter: bool = True
    enable_reddit: bool = True

    # Lifecycle status.
    status: SimulationStatus = SimulationStatus.CREATED

    # Counters captured during the prepare phase.
    entities_count: int = 0
    profiles_count: int = 0
    entity_types: List[str] = field(default_factory=list)

    # Information about the auto-generated config.
    config_generated: bool = False
    config_reasoning: str = ""

    # Runtime data.
    current_round: int = 0
    twitter_status: str = "not_started"
    reddit_status: str = "not_started"

    # Timestamps (ISO-8601 strings from datetime.now(), i.e. local naive time).
    created_at: str = field(default_factory=lambda: datetime.now().isoformat())
    updated_at: str = field(default_factory=lambda: datetime.now().isoformat())

    # Error message when status == FAILED.
    error: Optional[str] = None

    def to_dict(self) -> Dict[str, Any]:
        """Return the full state as a JSON-serializable dict.

        Used for persistence and internal callers. ``status`` is emitted as its
        string value; ``entity_types`` is copied so the caller cannot mutate
        this object's list through the returned dict.
        """
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "enable_twitter": self.enable_twitter,
            "enable_reddit": self.enable_reddit,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": list(self.entity_types),
            "config_generated": self.config_generated,
            "config_reasoning": self.config_reasoning,
            "current_round": self.current_round,
            "twitter_status": self.twitter_status,
            "reddit_status": self.reddit_status,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
            "error": self.error,
        }

    def to_simple_dict(self) -> Dict[str, Any]:
        """Return a reduced state dict for API responses.

        Like ``to_dict``, the ``entity_types`` list is copied rather than shared.
        """
        return {
            "simulation_id": self.simulation_id,
            "project_id": self.project_id,
            "graph_id": self.graph_id,
            "status": self.status.value,
            "entities_count": self.entities_count,
            "profiles_count": self.profiles_count,
            "entity_types": list(self.entity_types),
            "config_generated": self.config_generated,
            "error": self.error,
        }
|
|
|
|
|
|
class SimulationManager:
    """Simulation manager.

    Core responsibilities:
    1. Read entities from the Zep graph and filter to the configured types.
    2. Generate OASIS agent profiles per entity.
    3. Use the LLM to generate simulation configuration parameters.
    4. Materialize the profile/config files the runtime scripts expect.

    Each simulation lives in ``SIMULATION_DATA_DIR/<simulation_id>/``. Its
    state is persisted to ``state.json`` in that directory and cached in memory
    per manager instance.
    """

    # Root directory for persisted simulation data.
    SIMULATION_DATA_DIR = os.path.join(
        os.path.dirname(__file__),
        '../../uploads/simulations'
    )

    # File names used inside each simulation directory.
    _STATE_FILE = "state.json"
    _CONFIG_FILE = "simulation_config.json"

    def __init__(self):
        # Ensure the simulation data directory exists.
        os.makedirs(self.SIMULATION_DATA_DIR, exist_ok=True)

        # In-memory cache of simulation state objects, keyed by simulation id.
        self._simulations: Dict[str, SimulationState] = {}

    @staticmethod
    def _is_valid_simulation_id(simulation_id: str) -> bool:
        """Return True if ``simulation_id`` is a plain, single directory name.

        Rejects non-strings, empty ids, ``.``/``..``, absolute paths and
        anything containing a path separator, so a (possibly user-supplied) id
        can never resolve outside ``SIMULATION_DATA_DIR``.
        """
        return (
            isinstance(simulation_id, str)
            and simulation_id not in ("", ".", "..")
            and os.path.basename(simulation_id) == simulation_id
        )

    def _get_simulation_dir(self, simulation_id: str, create: bool = True) -> str:
        """Return the on-disk directory for a simulation.

        Args:
            simulation_id: Simulation id; must be a plain directory name.
            create: When True (the default), create the directory if it is
                missing. Read-only lookups pass False so probing an unknown id
                does not leave an empty directory behind.

        Returns:
            The path ``SIMULATION_DATA_DIR/<simulation_id>``.

        Raises:
            ValueError: If ``simulation_id`` is not a plain directory name.
        """
        if not self._is_valid_simulation_id(simulation_id):
            raise ValueError(f"Invalid simulation id: {simulation_id!r}")
        sim_dir = os.path.join(self.SIMULATION_DATA_DIR, simulation_id)
        if create:
            os.makedirs(sim_dir, exist_ok=True)
        return sim_dir

    def _save_simulation_state(self, state: SimulationState):
        """Persist a simulation state to disk and update the cache.

        Refreshes ``state.updated_at`` first. The JSON is written to a
        temporary sibling file and moved over ``state.json`` with
        ``os.replace``, so a crash mid-write cannot leave a truncated state file.
        """
        import threading

        sim_dir = self._get_simulation_dir(state.simulation_id)
        state_file = os.path.join(sim_dir, self._STATE_FILE)

        state.updated_at = datetime.now().isoformat()

        # Unique per process/thread so concurrent saves never share a temp file.
        tmp_file = f"{state_file}.{os.getpid()}.{threading.get_ident()}.tmp"
        try:
            with open(tmp_file, 'w', encoding='utf-8') as f:
                json.dump(state.to_dict(), f, ensure_ascii=False, indent=2)
            os.replace(tmp_file, state_file)
        except BaseException:
            # Best-effort removal of the partial temp file; re-raise the real error.
            try:
                os.remove(tmp_file)
            except OSError:
                pass
            raise

        self._simulations[state.simulation_id] = state

    def _load_simulation_state(self, simulation_id: str) -> Optional[SimulationState]:
        """Load a simulation state by id, from the cache or from ``state.json``.

        Returns:
            The state, or ``None`` if the id is invalid or has no state file.
            Unlike the creating path, this never creates a directory.

        Raises:
            OSError / ValueError: If the state file exists but cannot be read
                or parsed (including an unknown ``status`` value).
        """
        cached = self._simulations.get(simulation_id)
        if cached is not None:
            return cached

        if not self._is_valid_simulation_id(simulation_id):
            return None

        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        state_file = os.path.join(sim_dir, self._STATE_FILE)

        if not os.path.exists(state_file):
            return None

        with open(state_file, 'r', encoding='utf-8') as f:
            data = json.load(f)

        state = SimulationState(
            simulation_id=simulation_id,
            project_id=data.get("project_id", ""),
            graph_id=data.get("graph_id", ""),
            enable_twitter=data.get("enable_twitter", True),
            enable_reddit=data.get("enable_reddit", True),
            status=SimulationStatus(data.get("status", "created")),
            entities_count=data.get("entities_count", 0),
            profiles_count=data.get("profiles_count", 0),
            entity_types=data.get("entity_types", []),
            config_generated=data.get("config_generated", False),
            config_reasoning=data.get("config_reasoning", ""),
            current_round=data.get("current_round", 0),
            twitter_status=data.get("twitter_status", "not_started"),
            reddit_status=data.get("reddit_status", "not_started"),
            created_at=data.get("created_at", datetime.now().isoformat()),
            updated_at=data.get("updated_at", datetime.now().isoformat()),
            error=data.get("error"),
        )

        self._simulations[simulation_id] = state
        return state

    def create_simulation(
        self,
        project_id: str,
        graph_id: str,
        enable_twitter: bool = True,
        enable_reddit: bool = True,
    ) -> SimulationState:
        """Create a new simulation in the ``CREATED`` state.

        Args:
            project_id: Owning project id.
            graph_id: Source Zep graph id.
            enable_twitter: When ``True``, the Twitter simulation runs.
            enable_reddit: When ``True``, the Reddit simulation runs.

        Returns:
            The created (and already persisted) ``SimulationState``.
        """
        import uuid
        simulation_id = f"sim_{uuid.uuid4().hex[:12]}"

        state = SimulationState(
            simulation_id=simulation_id,
            project_id=project_id,
            graph_id=graph_id,
            enable_twitter=enable_twitter,
            enable_reddit=enable_reddit,
            status=SimulationStatus.CREATED,
        )

        self._save_simulation_state(state)
        logger.info(t("log.simulation_manager.m001", simulation_id=simulation_id, project_id=project_id, graph_id=graph_id))

        return state

    def prepare_simulation(
        self,
        simulation_id: str,
        simulation_requirement: str,
        document_text: str,
        defined_entity_types: Optional[List[str]] = None,
        use_llm_for_profiles: bool = True,
        progress_callback: Optional[Callable[..., Any]] = None,
        parallel_profile_count: int = 3
    ) -> SimulationState:
        """Prepare the simulation environment end-to-end.

        Steps:
            1. Read and filter entities from the graph.
            2. Generate OASIS agent profiles (optional LLM enrichment, parallel-capable).
            3. Use the LLM to produce simulation parameters (timing, activity, posting frequency).
            4. Save the configuration and profile files into the simulation directory.

        The runtime scripts are not copied; they are run in place from
        ``backend/scripts/``.

        Args:
            simulation_id: Simulation id.
            simulation_requirement: Free-text description of the simulation goal.
            document_text: Raw source document text passed to the LLM for context.
            defined_entity_types: Optional list of allowed entity types.
            use_llm_for_profiles: When ``True``, enrich profiles via the LLM.
            progress_callback: Optional callback ``(stage, progress, message, **extras)``.
            parallel_profile_count: Number of profile generations to run in parallel.

        Returns:
            The updated ``SimulationState``: ``READY`` on success, or ``FAILED``
            (without raising) when no matching entities were found.

        Raises:
            ValueError: If the simulation does not exist.
            Exception: Any error from a stage is re-raised after the state is
                marked ``FAILED`` with the error message.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        def report(stage, progress, message, **extras):
            # Forward a progress update only if the caller supplied a callback.
            if progress_callback:
                progress_callback(stage, progress, message, **extras)

        try:
            state.status = SimulationStatus.PREPARING
            self._save_simulation_state(state)

            sim_dir = self._get_simulation_dir(simulation_id)

            # ========== Stage 1: read and filter entities ==========
            report("reading", 0, t('progress.connectingZepGraph'))

            reader = ZepEntityReader()

            report("reading", 30, t('progress.readingNodeData'))

            filtered = reader.filter_defined_entities(
                graph_id=state.graph_id,
                defined_entity_types=defined_entity_types,
                enrich_with_edges=True
            )

            state.entities_count = filtered.filtered_count
            state.entity_types = list(filtered.entity_types)

            report(
                "reading", 100,
                t('progress.readingComplete', count=filtered.filtered_count),
                current=filtered.filtered_count,
                total=filtered.filtered_count
            )

            if filtered.filtered_count == 0:
                state.status = SimulationStatus.FAILED
                state.error = "没有找到符合条件的实体,请检查图谱是否正确构建"
                self._save_simulation_state(state)
                return state

            # ========== Stage 2: generate agent profiles ==========
            total_entities = len(filtered.entities)

            report(
                "generating_profiles", 0,
                t('progress.startGenerating'),
                current=0,
                total=total_entities
            )

            # Pass the graph_id so the generator can use Zep retrieval for richer context.
            generator = OasisProfileGenerator(graph_id=state.graph_id)

            def profile_progress(current, total, msg):
                # Guard against a zero total so a progress tick can never
                # raise ZeroDivisionError and abort profile generation.
                percent = int(current / total * 100) if total else 0
                report(
                    "generating_profiles",
                    percent,
                    msg,
                    current=current,
                    total=total,
                    item_name=msg
                )

            # Realtime save target: prefer Reddit JSON, else Twitter CSV, else none.
            if state.enable_reddit:
                realtime_output_path = os.path.join(sim_dir, "reddit_profiles.json")
                realtime_platform = "reddit"
            elif state.enable_twitter:
                realtime_output_path = os.path.join(sim_dir, "twitter_profiles.csv")
                realtime_platform = "twitter"
            else:
                realtime_output_path = None
                realtime_platform = "reddit"

            profiles = generator.generate_profiles_from_entities(
                entities=filtered.entities,
                use_llm=use_llm_for_profiles,
                progress_callback=profile_progress,
                graph_id=state.graph_id,  # used for Zep retrieval enrichment
                parallel_count=parallel_profile_count,
                realtime_output_path=realtime_output_path,
                output_platform=realtime_platform
            )

            state.profiles_count = len(profiles)

            # Save profile files. Reddit also writes JSON during generation; this
            # is a final consistency write. Twitter requires CSV per OASIS conventions.
            report(
                "generating_profiles", 95,
                t('progress.savingProfiles'),
                current=total_entities,
                total=total_entities
            )

            if state.enable_reddit:
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "reddit_profiles.json"),
                    platform="reddit"
                )

            if state.enable_twitter:
                # Twitter uses CSV format — required by OASIS.
                generator.save_profiles(
                    profiles=profiles,
                    file_path=os.path.join(sim_dir, "twitter_profiles.csv"),
                    platform="twitter"
                )

            report(
                "generating_profiles", 100,
                t('progress.profilesComplete', count=len(profiles)),
                current=len(profiles),
                total=len(profiles)
            )

            # ========== Stage 3: LLM-driven simulation config ==========
            report(
                "generating_config", 0,
                t('progress.analyzingRequirements'),
                current=0,
                total=3
            )

            config_generator = SimulationConfigGenerator()

            report(
                "generating_config", 30,
                t('progress.callingLLMConfig'),
                current=1,
                total=3
            )

            sim_params = config_generator.generate_config(
                simulation_id=simulation_id,
                project_id=state.project_id,
                graph_id=state.graph_id,
                simulation_requirement=simulation_requirement,
                document_text=document_text,
                entities=filtered.entities,
                enable_twitter=state.enable_twitter,
                enable_reddit=state.enable_reddit
            )

            report(
                "generating_config", 70,
                t('progress.savingConfigFiles'),
                current=2,
                total=3
            )

            # Save the configuration file.
            config_path = os.path.join(sim_dir, self._CONFIG_FILE)
            with open(config_path, 'w', encoding='utf-8') as f:
                f.write(sim_params.to_json())

            state.config_generated = True
            state.config_reasoning = sim_params.generation_reasoning

            report(
                "generating_config", 100,
                t('progress.configComplete'),
                current=3,
                total=3
            )

            # The runtime scripts live under backend/scripts/ and are invoked in
            # place by simulation_runner; nothing is copied per simulation.

            state.status = SimulationStatus.READY
            self._save_simulation_state(state)

            logger.info(t("log.simulation_manager.m002", simulation_id=simulation_id, state=state.entities_count, state_2=state.profiles_count))

            return state

        except Exception as e:
            logger.error(t("log.simulation_manager.m003", simulation_id=simulation_id, str=str(e)))
            import traceback
            logger.error(traceback.format_exc())
            state.status = SimulationStatus.FAILED
            state.error = str(e)
            self._save_simulation_state(state)
            raise

    def get_simulation(self, simulation_id: str) -> Optional[SimulationState]:
        """Return the simulation's state, or ``None`` if unknown or invalid."""
        return self._load_simulation_state(simulation_id)

    def list_simulations(self, project_id: Optional[str] = None) -> List[SimulationState]:
        """List all simulations, optionally filtered by ``project_id``.

        Entries whose ``state.json`` cannot be read or parsed are skipped with a
        warning instead of aborting the whole listing.
        """
        simulations = []

        if not os.path.isdir(self.SIMULATION_DATA_DIR):
            return simulations

        for sim_id in os.listdir(self.SIMULATION_DATA_DIR):
            # Skip dotfiles (e.g. .DS_Store) and non-directories.
            sim_path = os.path.join(self.SIMULATION_DATA_DIR, sim_id)
            if sim_id.startswith('.') or not os.path.isdir(sim_path):
                continue

            try:
                state = self._load_simulation_state(sim_id)
            except (OSError, ValueError, AttributeError) as e:
                # One corrupt/unreadable state file must not hide every other simulation.
                logger.warning(f"Skipping simulation {sim_id}: unreadable state ({e})")
                continue

            if state and (project_id is None or state.project_id == project_id):
                simulations.append(state)

        return simulations

    def get_profiles(self, simulation_id: str, platform: str = "reddit") -> List[Dict[str, Any]]:
        """Return the persisted agent profiles for a platform.

        Reads ``<platform>_profiles.json`` when present. Twitter profiles are
        saved by ``prepare_simulation`` as ``twitter_profiles.csv``, so for
        ``platform="twitter"`` the CSV is read as a fallback, one dict per row.
        CSV values are strings; the column set is whatever the profile
        generator wrote, which is not visible from this class.

        Args:
            simulation_id: Simulation id.
            platform: ``"reddit"`` or ``"twitter"`` (a ``PlatformType`` also works).

        Returns:
            The list of profiles, or ``[]`` for an unknown platform or missing file.

        Raises:
            ValueError: If the simulation does not exist.
        """
        state = self._load_simulation_state(simulation_id)
        if not state:
            raise ValueError(f"模拟不存在: {simulation_id}")

        platform_name = platform.value if isinstance(platform, PlatformType) else platform
        # Only known platform names may be interpolated into a file name.
        if platform_name not in [p.value for p in PlatformType]:
            return []

        sim_dir = self._get_simulation_dir(simulation_id, create=False)

        json_path = os.path.join(sim_dir, f"{platform_name}_profiles.json")
        if os.path.exists(json_path):
            with open(json_path, 'r', encoding='utf-8') as f:
                return json.load(f)

        if platform_name == PlatformType.TWITTER.value:
            csv_path = os.path.join(sim_dir, "twitter_profiles.csv")
            if os.path.exists(csv_path):
                import csv
                # utf-8-sig transparently drops a BOM if the writer emitted one.
                with open(csv_path, 'r', encoding='utf-8-sig', newline='') as f:
                    return [dict(row) for row in csv.DictReader(f)]

        return []

    def get_simulation_config(self, simulation_id: str) -> Optional[Dict[str, Any]]:
        """Return the persisted simulation config dict, or ``None`` if absent."""
        if not self._is_valid_simulation_id(simulation_id):
            return None

        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        config_path = os.path.join(sim_dir, self._CONFIG_FILE)

        if not os.path.exists(config_path):
            return None

        with open(config_path, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_run_instructions(self, simulation_id: str) -> Dict[str, Any]:
        """Return shell commands and instructions to launch the simulation manually.

        Returns:
            A dict with ``simulation_dir``, ``scripts_dir`` and ``config_file``
            paths, a nested ``commands`` dict keyed by ``twitter``/``reddit``/
            ``parallel``, and a human-readable ``instructions`` string.

        Raises:
            ValueError: If ``simulation_id`` is not a plain directory name.
        """
        sim_dir = self._get_simulation_dir(simulation_id, create=False)
        config_path = os.path.join(sim_dir, self._CONFIG_FILE)
        scripts_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../scripts'))

        return {
            "simulation_dir": sim_dir,
            "scripts_dir": scripts_dir,
            "config_file": config_path,
            "commands": {
                "twitter": f"python {scripts_dir}/run_twitter_simulation.py --config {config_path}",
                "reddit": f"python {scripts_dir}/run_reddit_simulation.py --config {config_path}",
                "parallel": f"python {scripts_dir}/run_parallel_simulation.py --config {config_path}",
            },
            "instructions": (
                f"1. 激活conda环境: conda activate MiroFish\n"
                f"2. 运行模拟 (脚本位于 {scripts_dir}):\n"
                f"   - 单独运行Twitter: python {scripts_dir}/run_twitter_simulation.py --config {config_path}\n"
                f"   - 单独运行Reddit: python {scripts_dir}/run_reddit_simulation.py --config {config_path}\n"
                f"   - 并行运行双平台: python {scripts_dir}/run_parallel_simulation.py --config {config_path}"
            )
        }
|