"""Task state management — persistent via SQLAlchemy."""
import threading
import uuid
from datetime import datetime, timedelta, timezone
from enum import Enum
from typing import Any, Dict, List, Optional

from sqlalchemy import delete, desc, select

from ..db import get_session
from ..models.db_models import TaskModel
from ..utils.locale import t


class TaskStatus(str, Enum):
    """Task lifecycle states; each value is the string written to the status column."""

    PENDING = "pending"
    PROCESSING = "processing"
    COMPLETED = "completed"
    FAILED = "failed"


class TaskManager:
    """Singleton facade that reads and writes task state through SQLAlchemy.

    No task data is held in memory: every method opens a session with
    ``get_session()`` and works on ``TaskModel`` rows. Read methods return
    plain dictionaries (see ``_to_dict``), not ORM objects.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        # Double-checked locking: the unlocked test keeps the common path
        # cheap, the locked re-test lets only one thread create the instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    @staticmethod
    def _status_value(status: str) -> str:
        """Return the plain string for a status given as ``TaskStatus`` or ``str``.

        ``str()`` of a ``(str, Enum)`` member is ``"TaskStatus.X"``, not its
        value, so members are unwrapped before they are bound to a column.
        """
        if isinstance(status, TaskStatus):
            return status.value
        return str(status)

    def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
        """Insert a new task in the ``pending`` state and return its ID.

        Args:
            task_type: Label stored in the row's ``task_type`` column.
            metadata: Optional extra data for the task.

        Returns:
            The generated task ID (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        with get_session() as db:
            # NOTE(review): metadata is written to the ``progress_detail``
            # column, so a later update_task(progress_detail=...) replaces it.
            # Confirm whether TaskModel should get a dedicated metadata column.
            db.add(
                TaskModel(
                    id=task_id,
                    task_type=task_type,
                    status=TaskStatus.PENDING.value,
                    progress=0,
                    progress_detail=metadata or {},
                )
            )
            db.commit()
        return task_id

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the task as a dictionary, or ``None`` if the ID is unknown."""
        with get_session() as db:
            task = db.get(TaskModel, task_id)
            return None if task is None else self._to_dict(task)

    def update_task(
        self,
        task_id: str,
        status: Optional[str] = None,
        progress: Optional[int] = None,
        message: Optional[str] = None,
        result: Optional[Dict] = None,
        error: Optional[str] = None,
        progress_detail: Optional[Dict] = None,
    ) -> None:
        """Update the given fields of a task and refresh ``updated_at``.

        Arguments left as ``None`` are not modified. An unknown ``task_id`` is
        ignored silently.

        Args:
            task_id: ID of the task to update.
            status: New status, as a ``TaskStatus`` member or its string value.
            progress: New progress value.
            message: New status message.
            result: New result payload.
            error: New error text.
            progress_detail: Replacement for the stored progress detail.
        """
        with get_session() as db:
            task = db.get(TaskModel, task_id)
            if task is None:
                return
            if status is not None:
                task.status = self._status_value(status)
            if progress is not None:
                task.progress = progress
            if message is not None:
                task.message = message
            if result is not None:
                task.result = result
            if error is not None:
                task.error = error
            if progress_detail is not None:
                task.progress_detail = progress_detail
            task.updated_at = datetime.now(timezone.utc)
            db.commit()

    def complete_task(self, task_id: str, result: Dict) -> None:
        """Mark the task completed at 100% progress and store ``result``."""
        self.update_task(
            task_id,
            status=TaskStatus.COMPLETED,
            progress=100,
            message=t("progress.taskComplete"),
            result=result,
        )

    def fail_task(self, task_id: str, error: str) -> None:
        """Mark the task failed and store the error text."""
        self.update_task(
            task_id,
            status=TaskStatus.FAILED,
            message=t("progress.taskFailed"),
            error=error,
        )

    def list_tasks(self, task_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return tasks as dictionaries, newest ``created_at`` first.

        Args:
            task_type: When truthy, only tasks of this type are returned.
        """
        stmt = select(TaskModel).order_by(desc(TaskModel.created_at))
        if task_type:
            stmt = stmt.where(TaskModel.task_type == task_type)
        with get_session() as db:
            rows = db.execute(stmt).scalars().all()
            # Convert inside the session, while the rows are still attached.
            return [self._to_dict(row) for row in rows]

    def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
        """Delete completed or failed tasks created more than ``max_age_hours`` ago."""
        cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
        finished = [TaskStatus.COMPLETED.value, TaskStatus.FAILED.value]
        with get_session() as db:
            db.execute(
                delete(TaskModel).where(
                    TaskModel.created_at < cutoff,
                    TaskModel.status.in_(finished),
                )
            )
            db.commit()

    @staticmethod
    def _to_dict(task: TaskModel) -> Dict[str, Any]:
        """Serialise a ``TaskModel`` row into the dictionary returned to callers.

        Timestamps become ISO-8601 strings, or ``None`` when the column is unset.
        """

        def _iso(value: Optional[datetime]) -> Optional[str]:
            return value.isoformat() if value is not None else None

        detail = task.progress_detail or {}
        return {
            "task_id": task.id,
            "task_type": task.task_type,
            "status": task.status,
            "created_at": _iso(task.created_at),
            "updated_at": _iso(task.updated_at),
            "progress": task.progress,
            "message": task.message or "",
            "progress_detail": detail,
            "result": task.result,
            "error": task.error,
            # Mirrors progress_detail: no separate metadata value is read here.
            "metadata": detail,
        }
# backend/tests/test_task_manager_db.py
"""Tests for the database-backed TaskManager."""
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

import backend.app.db as db_module
from backend.app.db import Base
# Not referenced below; presumably imported so the task table is registered
# on Base.metadata before create_all() runs — confirm before removing.
from backend.app.models.db_models import TaskModel  # noqa: F401
from backend.app.models.task import TaskManager


@pytest.fixture(autouse=True)
def isolated_db():
    """Give every test its own in-memory SQLite database."""
    previous_engine = getattr(db_module, "_engine", None)
    previous_session_factory = getattr(db_module, "_SessionLocal", None)

    engine = create_engine(
        "sqlite:///:memory:", connect_args={"check_same_thread": False}
    )
    db_module._engine = engine
    db_module._SessionLocal = sessionmaker(
        bind=engine, autocommit=False, autoflush=False
    )
    Base.metadata.create_all(engine)
    yield
    Base.metadata.drop_all(engine)
    # Release pooled connections and put back whatever was configured before,
    # so this fixture leaves no state behind for other test modules.
    engine.dispose()
    db_module._engine = previous_engine
    db_module._SessionLocal = previous_session_factory


def test_create_and_get_task():
    tm = TaskManager()
    task_id = tm.create_task("graph_build", {"project_id": "proj-1"})
    task = tm.get_task(task_id)
    assert task is not None
    assert task["task_type"] == "graph_build"
    assert task["status"] == "pending"
    assert task["progress"] == 0
    assert task["metadata"] == {"project_id": "proj-1"}


def test_get_unknown_task_returns_none():
    assert TaskManager().get_task("does-not-exist") is None


def test_update_unknown_task_is_noop():
    tm = TaskManager()
    tm.update_task("does-not-exist", progress=10)
    assert tm.list_tasks() == []


def test_update_task_progress():
    tm = TaskManager()
    task_id = tm.create_task("ontology_generate")
    tm.update_task(task_id, progress=50, message="Halfway")
    task = tm.get_task(task_id)
    assert task["progress"] == 50
    assert task["message"] == "Halfway"


def test_complete_task():
    tm = TaskManager()
    task_id = tm.create_task("graph_build")
    tm.complete_task(task_id, {"graph_id": "g-1"})
    task = tm.get_task(task_id)
    assert task["status"] == "completed"
    assert task["progress"] == 100
    assert task["result"]["graph_id"] == "g-1"


def test_fail_task():
    tm = TaskManager()
    task_id = tm.create_task("simulation_prepare")
    tm.fail_task(task_id, "LLM timeout")
    task = tm.get_task(task_id)
    assert task["status"] == "failed"
    assert task["error"] == "LLM timeout"


def test_task_survives_new_manager_instance():
    """The task must live in the database, not in memory."""
    tm1 = TaskManager()
    task_id = tm1.create_task("graph_build")
    # Drop the singleton to force a new instance (simulates a restart).
    TaskManager._instance = None
    tm2 = TaskManager()
    task = tm2.get_task(task_id)
    assert task is not None
    assert task["task_id"] == task_id


def test_list_tasks():
    tm = TaskManager()
    tm.create_task("graph_build")
    tm.create_task("graph_build")
    tm.create_task("ontology_generate")
    all_tasks = tm.list_tasks()
    assert len(all_tasks) == 3
    graph_tasks = tm.list_tasks(task_type="graph_build")
    assert len(graph_tasks) == 2