feat(simulation): add DELETE agent, POST generate-config, PATCH config endpoints
- DELETE /<sim_id>/agent/<user_id>: removes agent from reddit_profiles.json (atomic write, guards against running/completed status) - POST /<sim_id>/generate-config: transitions profiles_ready→configuring→ready, runs LLM config generation in background thread, returns task_id - PATCH /<sim_id>/config: merges time/platform config fields into simulation_config.json (atomic write) - Corresponding SimulationManager methods: delete_agent_profile(), patch_simulation_config() - 7 tests all passing (3 original + 4 new) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
c0356e706a
commit
83cf890c83
|
|
@ -2728,3 +2728,148 @@ def patch_agent(simulation_id: str, user_id: int):
|
|||
except Exception as e:
|
||||
logger.error(f"patch_agent failed: {e}")
|
||||
return jsonify({"success": False, "error": str(e), "traceback": traceback.format_exc()}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/<simulation_id>/agent/<int:user_id>', methods=['DELETE'])
|
||||
def delete_agent(simulation_id: str, user_id: int):
|
||||
"""Remove an agent from the simulation (Fase A only)."""
|
||||
try:
|
||||
manager = SimulationManager()
|
||||
try:
|
||||
manager.delete_agent_profile(simulation_id, user_id)
|
||||
except ValueError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 404
|
||||
except PermissionError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 403
|
||||
except LookupError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 404
|
||||
|
||||
return jsonify({"success": True, "data": {"deleted_user_id": user_id}})
|
||||
except Exception as e:
|
||||
logger.error(f"delete_agent failed: {e}")
|
||||
return jsonify({"success": False, "error": str(e), "traceback": traceback.format_exc()}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/<simulation_id>/generate-config', methods=['POST'])
|
||||
def generate_config_endpoint(simulation_id: str):
|
||||
"""
|
||||
Transition from Fase A to Fase B.
|
||||
Requires status=profiles_ready. Changes to configuring, starts async config generation.
|
||||
Returns task_id for polling.
|
||||
"""
|
||||
import threading
|
||||
from ..models.task import TaskManager, TaskStatus
|
||||
from ..services.simulation_config_generator import SimulationConfigGenerator
|
||||
|
||||
try:
|
||||
manager = SimulationManager()
|
||||
state = manager.get_simulation(simulation_id)
|
||||
|
||||
if not state:
|
||||
return jsonify({"success": False, "error": t('api.simulationNotFound', id=simulation_id)}), 404
|
||||
|
||||
if state.status != SimulationStatus.PROFILES_READY:
|
||||
return jsonify({
|
||||
"success": False,
|
||||
"error": t('api.requireProfilesReady', status=state.status.value)
|
||||
}), 400
|
||||
|
||||
project = ProjectManager.get_project(state.project_id)
|
||||
if not project:
|
||||
return jsonify({"success": False, "error": t('api.projectNotFound', id=state.project_id)}), 404
|
||||
|
||||
simulation_requirement = project.get("simulation_requirement") or ""
|
||||
document_text = ProjectManager.get_extracted_text(state.project_id, get_storage()) or ""
|
||||
|
||||
task_manager = TaskManager()
|
||||
task_id = task_manager.create_task(
|
||||
task_type="generate_config",
|
||||
metadata={"simulation_id": simulation_id}
|
||||
)
|
||||
|
||||
state.status = SimulationStatus.CONFIGURING
|
||||
manager._save_simulation_state(state)
|
||||
|
||||
current_locale = get_locale()
|
||||
|
||||
def run_generate_config():
|
||||
set_locale(current_locale)
|
||||
try:
|
||||
task_manager.update_task(task_id, status=TaskStatus.PROCESSING, progress=0,
|
||||
message=t('progress.generatingSimConfig'))
|
||||
|
||||
sim_dir = manager._get_simulation_dir(simulation_id)
|
||||
profiles_file = os.path.join(sim_dir, "reddit_profiles.json")
|
||||
with open(profiles_file, 'r', encoding='utf-8') as f:
|
||||
profiles = json.load(f)
|
||||
|
||||
from ..services.zep_entity_reader import ZepEntityReader
|
||||
entity_nodes = []
|
||||
reader = ZepEntityReader()
|
||||
for p in profiles:
|
||||
uuid_ = p.get("source_entity_uuid")
|
||||
if uuid_:
|
||||
try:
|
||||
entity = reader.get_entity_with_context(state.graph_id, uuid_)
|
||||
if entity:
|
||||
entity_nodes.append(entity)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
gen = SimulationConfigGenerator(graph_id=state.graph_id)
|
||||
params = gen.generate_simulation_parameters(
|
||||
simulation_requirement=simulation_requirement,
|
||||
document_text=document_text,
|
||||
entities=entity_nodes,
|
||||
)
|
||||
|
||||
config_data = params.to_dict() if hasattr(params, 'to_dict') else {}
|
||||
config_file = os.path.join(sim_dir, "simulation_config.json")
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config_data, f, ensure_ascii=False, indent=2)
|
||||
|
||||
state2 = manager.get_simulation(simulation_id)
|
||||
if state2:
|
||||
state2.status = SimulationStatus.READY
|
||||
state2.config_generated = True
|
||||
manager._save_simulation_state(state2)
|
||||
|
||||
task_manager.complete_task(task_id, result={"status": "prepared"})
|
||||
except Exception as e:
|
||||
logger.error(f"generate_config background failed: {e}")
|
||||
task_manager.fail_task(task_id, str(e))
|
||||
state2 = manager.get_simulation(simulation_id)
|
||||
if state2:
|
||||
state2.status = SimulationStatus.PROFILES_READY
|
||||
manager._save_simulation_state(state2)
|
||||
|
||||
threading.Thread(target=run_generate_config, daemon=True).start()
|
||||
|
||||
return jsonify({"success": True, "data": {"simulation_id": simulation_id, "task_id": task_id}})
|
||||
except Exception as e:
|
||||
logger.error(f"generate_config endpoint error: {e}")
|
||||
return jsonify({"success": False, "error": str(e), "traceback": traceback.format_exc()}), 500
|
||||
|
||||
|
||||
@simulation_bp.route('/<simulation_id>/config', methods=['PATCH'])
|
||||
def patch_simulation_config_endpoint(simulation_id: str):
|
||||
"""Update simulation global config parameters (Fase B)."""
|
||||
try:
|
||||
fields = request.get_json() or {}
|
||||
if not fields:
|
||||
return jsonify({"success": False, "error": t('api.requireFields')}), 400
|
||||
|
||||
manager = SimulationManager()
|
||||
try:
|
||||
updated = manager.patch_simulation_config(simulation_id, fields)
|
||||
except ValueError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 404
|
||||
except PermissionError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 403
|
||||
except FileNotFoundError as e:
|
||||
return jsonify({"success": False, "error": str(e)}), 404
|
||||
|
||||
return jsonify({"success": True, "data": updated})
|
||||
except Exception as e:
|
||||
logger.error(f"patch_simulation_config failed: {e}")
|
||||
return jsonify({"success": False, "error": str(e), "traceback": traceback.format_exc()}), 500
|
||||
|
|
|
|||
|
|
@ -593,3 +593,94 @@ class SimulationManager:
|
|||
raise
|
||||
|
||||
return target
|
||||
|
||||
def delete_agent_profile(self, simulation_id: str, user_id: int) -> None:
|
||||
"""
|
||||
Remove an agent from reddit_profiles.json.
|
||||
Raises ValueError if simulation not found.
|
||||
Raises PermissionError if status is running or completed.
|
||||
Raises LookupError if agent not found.
|
||||
Atomic write.
|
||||
"""
|
||||
state = self.get_simulation(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"Simulation {simulation_id} not found")
|
||||
|
||||
immutable = {SimulationStatus.RUNNING, SimulationStatus.COMPLETED}
|
||||
if state.status in immutable:
|
||||
raise PermissionError(f"Cannot delete agent while simulation is {state.status.value}")
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
profiles_file = os.path.join(sim_dir, "reddit_profiles.json")
|
||||
backup_file = profiles_file + ".bak"
|
||||
|
||||
with open(profiles_file, 'r', encoding='utf-8') as f:
|
||||
profiles = json.load(f)
|
||||
|
||||
original_len = len(profiles)
|
||||
profiles = [p for p in profiles if p.get("user_id") != user_id]
|
||||
if len(profiles) == original_len:
|
||||
raise LookupError(f"Agent user_id={user_id} not found")
|
||||
|
||||
shutil.copy2(profiles_file, backup_file)
|
||||
try:
|
||||
with open(profiles_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(profiles, f, ensure_ascii=False, indent=2)
|
||||
os.remove(backup_file)
|
||||
except Exception:
|
||||
shutil.copy2(backup_file, profiles_file)
|
||||
os.remove(backup_file)
|
||||
raise
|
||||
|
||||
def patch_simulation_config(self, simulation_id: str, fields: dict) -> dict:
|
||||
"""
|
||||
Update global simulation config parameters (Fase B).
|
||||
Supported top-level: total_simulation_hours, minutes_per_round, agents_per_hour_min,
|
||||
agents_per_hour_max, following_probability, recsys_type, twitter_config (dict merged),
|
||||
reddit_config (dict merged).
|
||||
Atomic write.
|
||||
"""
|
||||
state = self.get_simulation(simulation_id)
|
||||
if not state:
|
||||
raise ValueError(f"Simulation {simulation_id} not found")
|
||||
|
||||
immutable = {SimulationStatus.RUNNING, SimulationStatus.COMPLETED}
|
||||
if state.status in immutable:
|
||||
raise PermissionError(f"Cannot edit config while simulation is {state.status.value}")
|
||||
|
||||
sim_dir = self._get_simulation_dir(simulation_id)
|
||||
config_file = os.path.join(sim_dir, "simulation_config.json")
|
||||
backup_file = config_file + ".bak"
|
||||
|
||||
if not os.path.exists(config_file):
|
||||
raise FileNotFoundError("simulation_config.json not found")
|
||||
|
||||
with open(config_file, 'r', encoding='utf-8') as f:
|
||||
config = json.load(f)
|
||||
|
||||
time_fields = {"total_simulation_hours", "minutes_per_round",
|
||||
"agents_per_hour_min", "agents_per_hour_max"}
|
||||
time_config = config.setdefault("time_config", {})
|
||||
for k in time_fields:
|
||||
if k in fields:
|
||||
time_config[k] = fields[k]
|
||||
|
||||
for k in ("following_probability", "recsys_type"):
|
||||
if k in fields:
|
||||
config[k] = fields[k]
|
||||
|
||||
for nested in ("twitter_config", "reddit_config"):
|
||||
if nested in fields and isinstance(fields[nested], dict):
|
||||
config.setdefault(nested, {}).update(fields[nested])
|
||||
|
||||
shutil.copy2(config_file, backup_file)
|
||||
try:
|
||||
with open(config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config, f, ensure_ascii=False, indent=2)
|
||||
os.remove(backup_file)
|
||||
except Exception:
|
||||
shutil.copy2(backup_file, config_file)
|
||||
os.remove(backup_file)
|
||||
raise
|
||||
|
||||
return config
|
||||
|
|
|
|||
|
|
@ -84,3 +84,87 @@ def test_patch_agent_not_found(client, sim_with_profiles):
|
|||
sim_id = sim_with_profiles
|
||||
resp = client.patch(f"/api/simulation/{sim_id}/agent/99", json={"bio": "x"})
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_delete_agent_removes_from_profiles(client, sim_with_profiles, tmp_path):
|
||||
sim_id = sim_with_profiles
|
||||
resp = client.delete(f"/api/simulation/{sim_id}/agent/1")
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["success"] is True
|
||||
# Verify file on disk
|
||||
from pathlib import Path
|
||||
import json as _json
|
||||
import os as _os
|
||||
sim_dir = Path(_os.environ.get("OASIS_SIMULATION_DATA_DIR", str(tmp_path))) / sim_id
|
||||
# Use the monkeypatched path from SimulationManager
|
||||
from backend.app.services.simulation_manager import SimulationManager
|
||||
sim_manager_dir = Path(SimulationManager.SIMULATION_DATA_DIR) / sim_id
|
||||
profiles = _json.loads((sim_manager_dir / "reddit_profiles.json").read_text())
|
||||
assert all(p["user_id"] != 1 for p in profiles)
|
||||
assert len(profiles) == 1
|
||||
|
||||
|
||||
def test_delete_agent_not_found(client, sim_with_profiles):
|
||||
sim_id = sim_with_profiles
|
||||
resp = client.delete(f"/api/simulation/{sim_id}/agent/99")
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_generate_config_returns_task_id(client, sim_with_profiles, monkeypatch):
|
||||
sim_id = sim_with_profiles
|
||||
# Mock SimulationConfigGenerator to avoid real LLM calls
|
||||
# We just need the endpoint to accept the request and return task_id
|
||||
# The background thread will fail quickly but that's OK for this test
|
||||
import backend.app.services.simulation_config_generator as scg_module
|
||||
from backend.app.models.project import ProjectManager
|
||||
|
||||
class FakeParams:
|
||||
def to_dict(self):
|
||||
return {"time_config": {"total_simulation_hours": 24}}
|
||||
|
||||
monkeypatch.setattr(scg_module, "SimulationConfigGenerator",
|
||||
lambda **kwargs: type('G', (), {
|
||||
'generate_simulation_parameters': lambda self, **kw: FakeParams()
|
||||
})())
|
||||
|
||||
monkeypatch.setattr(ProjectManager, "get_project",
|
||||
staticmethod(lambda pid: {"project_id": pid, "simulation_requirement": "test"}))
|
||||
|
||||
resp = client.post(f"/api/simulation/{sim_id}/generate-config", json={})
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["success"] is True
|
||||
assert "task_id" in data["data"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sim_prepared(tmp_path, monkeypatch):
|
||||
"""Creates a simulation with status=ready and a simulation_config.json"""
|
||||
from backend.app.services.simulation_manager import SimulationManager
|
||||
monkeypatch.setattr(SimulationManager, 'SIMULATION_DATA_DIR', str(tmp_path))
|
||||
sim_id = "sim_prepared001"
|
||||
sim_dir = tmp_path / sim_id
|
||||
sim_dir.mkdir()
|
||||
state = {
|
||||
"simulation_id": sim_id, "project_id": "p", "graph_id": "g",
|
||||
"status": "ready", "entities_count": 1, "profiles_count": 1,
|
||||
"entity_types": [], "config_generated": True, "config_reasoning": "",
|
||||
"current_round": 0, "twitter_status": "not_started", "reddit_status": "not_started",
|
||||
"created_at": "2026-01-01T00:00:00", "updated_at": "2026-01-01T00:00:00",
|
||||
"error": None, "parent_simulation_id": None, "graph_id_simulation": None,
|
||||
"enable_twitter": True, "enable_reddit": True,
|
||||
}
|
||||
(sim_dir / "state.json").write_text(json.dumps(state))
|
||||
config = {"time_config": {"total_simulation_hours": 24, "minutes_per_round": 60}, "agent_configs": []}
|
||||
(sim_dir / "simulation_config.json").write_text(json.dumps(config))
|
||||
return sim_id
|
||||
|
||||
|
||||
def test_patch_config_updates_total_hours(client, sim_prepared):
|
||||
sim_id = sim_prepared
|
||||
resp = client.patch(f"/api/simulation/{sim_id}/config", json={"total_simulation_hours": 48})
|
||||
assert resp.status_code == 200
|
||||
data = resp.get_json()
|
||||
assert data["success"] is True
|
||||
assert data["data"]["time_config"]["total_simulation_hours"] == 48
|
||||
|
|
|
|||
Loading…
Reference in New Issue