feat(task): refactor TaskManager to persist tasks in SQLAlchemy DB

Replace in-memory dict-based TaskManager with a SQLAlchemy-backed implementation
using TaskModel. Tasks now survive process restarts. 6 new tests added and passing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-05-03 00:10:16 +00:00
parent 479ae0b712
commit 1f43d35d59
2 changed files with 180 additions and 144 deletions

View File

@ -1,186 +1,136 @@
""" """Task state management — persistent via SQLAlchemy."""
Task state management
Used to track long-running tasks (e.g. graph building).
"""
import uuid import uuid
import threading import threading
from datetime import datetime from datetime import datetime, timezone
from enum import Enum from enum import Enum
from typing import Dict, Any, Optional from typing import Dict, Any, Optional, List
from dataclasses import dataclass, field
from ..db import get_session
from ..models.db_models import TaskModel
from ..utils.locale import t from ..utils.locale import t
class TaskStatus(str, Enum): class TaskStatus(str, Enum):
"""Task status enum""" PENDING = "pending"
PENDING = "pending" # Waiting PROCESSING = "processing"
PROCESSING = "processing" # In progress COMPLETED = "completed"
COMPLETED = "completed" # Completed FAILED = "failed"
FAILED = "failed" # Failed
@dataclass
class Task:
"""Task data class"""
task_id: str
task_type: str
status: TaskStatus
created_at: datetime
updated_at: datetime
progress: int = 0 # Total progress percentage 0-100
message: str = "" # Status message
result: Optional[Dict] = None # Task result
error: Optional[str] = None # Error info
metadata: Dict = field(default_factory=dict) # Extra metadata
progress_detail: Dict = field(default_factory=dict) # Detailed progress info
def to_dict(self) -> Dict[str, Any]:
"""Convert to dictionary"""
return {
"task_id": self.task_id,
"task_type": self.task_type,
"status": self.status.value,
"created_at": self.created_at.isoformat(),
"updated_at": self.updated_at.isoformat(),
"progress": self.progress,
"message": self.message,
"progress_detail": self.progress_detail,
"result": self.result,
"error": self.error,
"metadata": self.metadata,
}
class TaskManager: class TaskManager:
""" """Task manager — thread-safe, persistent via SQLAlchemy."""
Task manager
Thread-safe task state management
"""
_instance = None _instance = None
_lock = threading.Lock() _lock = threading.Lock()
def __new__(cls): def __new__(cls):
"""Singleton pattern"""
if cls._instance is None: if cls._instance is None:
with cls._lock: with cls._lock:
if cls._instance is None: if cls._instance is None:
cls._instance = super().__new__(cls) cls._instance = super().__new__(cls)
cls._instance._tasks: Dict[str, Task] = {}
cls._instance._task_lock = threading.Lock()
return cls._instance return cls._instance
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str: def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
"""
Create a new task.
Args:
task_type: task type
metadata: extra metadata
Returns:
task ID
"""
task_id = str(uuid.uuid4()) task_id = str(uuid.uuid4())
now = datetime.now() with get_session() as db:
task = TaskModel(
task = Task( id=task_id,
task_id=task_id, task_type=task_type,
task_type=task_type, status="pending",
status=TaskStatus.PENDING, progress=0,
created_at=now, progress_detail=metadata or {},
updated_at=now, )
metadata=metadata or {} db.add(task)
) db.commit()
with self._task_lock:
self._tasks[task_id] = task
return task_id return task_id
def get_task(self, task_id: str) -> Optional[Task]: def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
"""Get a task""" with get_session() as db:
with self._task_lock: task = db.get(TaskModel, task_id)
return self._tasks.get(task_id) if task is None:
return None
return self._to_dict(task)
def update_task( def update_task(
self, self,
task_id: str, task_id: str,
status: Optional[TaskStatus] = None, status: Optional[str] = None,
progress: Optional[int] = None, progress: Optional[int] = None,
message: Optional[str] = None, message: Optional[str] = None,
result: Optional[Dict] = None, result: Optional[Dict] = None,
error: Optional[str] = None, error: Optional[str] = None,
progress_detail: Optional[Dict] = None progress_detail: Optional[Dict] = None,
): ) -> None:
""" with get_session() as db:
Update task status. task = db.get(TaskModel, task_id)
if task is None:
return
if status is not None:
task.status = status
if progress is not None:
task.progress = progress
if message is not None:
task.message = message
if result is not None:
task.result = result
if error is not None:
task.error = error
if progress_detail is not None:
task.progress_detail = progress_detail
task.updated_at = datetime.now(timezone.utc)
db.commit()
Args: def complete_task(self, task_id: str, result: Dict) -> None:
task_id: task ID
status: new status
progress: progress
message: message
result: result
error: error info
progress_detail: detailed progress info
"""
with self._task_lock:
task = self._tasks.get(task_id)
if task:
task.updated_at = datetime.now()
if status is not None:
task.status = status
if progress is not None:
task.progress = progress
if message is not None:
task.message = message
if result is not None:
task.result = result
if error is not None:
task.error = error
if progress_detail is not None:
task.progress_detail = progress_detail
def complete_task(self, task_id: str, result: Dict):
"""Mark task as complete"""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.COMPLETED, status=TaskStatus.COMPLETED,
progress=100, progress=100,
message=t('progress.taskComplete'), message=t("progress.taskComplete"),
result=result result=result,
) )
def fail_task(self, task_id: str, error: str): def fail_task(self, task_id: str, error: str) -> None:
"""Mark task as failed"""
self.update_task( self.update_task(
task_id, task_id,
status=TaskStatus.FAILED, status=TaskStatus.FAILED,
message=t('progress.taskFailed'), message=t("progress.taskFailed"),
error=error error=error,
) )
def list_tasks(self, task_type: Optional[str] = None) -> list: def list_tasks(self, task_type: Optional[str] = None) -> List[Dict[str, Any]]:
"""List tasks""" from sqlalchemy import select, desc
with self._task_lock: with get_session() as db:
tasks = list(self._tasks.values()) stmt = select(TaskModel).order_by(desc(TaskModel.created_at))
if task_type: if task_type:
tasks = [t for t in tasks if t.task_type == task_type] stmt = stmt.where(TaskModel.task_type == task_type)
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)] tasks = db.execute(stmt).scalars().all()
return [self._to_dict(t) for t in tasks]
def cleanup_old_tasks(self, max_age_hours: int = 24): def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
"""Clean up old tasks"""
from datetime import timedelta from datetime import timedelta
cutoff = datetime.now() - timedelta(hours=max_age_hours) from sqlalchemy import delete
cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
with self._task_lock: with get_session() as db:
old_ids = [ db.execute(
tid for tid, task in self._tasks.items() delete(TaskModel).where(
if task.created_at < cutoff and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED] TaskModel.created_at < cutoff,
] TaskModel.status.in_(["completed", "failed"]),
for tid in old_ids: )
del self._tasks[tid] )
db.commit()
@staticmethod
def _to_dict(task: TaskModel) -> Dict[str, Any]:
return {
"task_id": task.id,
"task_type": task.task_type,
"status": task.status,
"created_at": task.created_at.isoformat(),
"updated_at": task.updated_at.isoformat(),
"progress": task.progress,
"message": task.message or "",
"progress_detail": task.progress_detail or {},
"result": task.result,
"error": task.error,
"metadata": task.progress_detail or {},
}

View File

@ -0,0 +1,86 @@
# backend/tests/test_task_manager_db.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from backend.app.db import Base
import backend.app.db as db_module
from backend.app.models.db_models import TaskModel
@pytest.fixture(autouse=True)
def isolated_db():
"""BD SQLite en memòria per a cada test."""
db_module._engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
db_module._SessionLocal = sessionmaker(bind=db_module._engine, autocommit=False, autoflush=False)
Base.metadata.create_all(db_module._engine)
yield
Base.metadata.drop_all(db_module._engine)
db_module._engine = None
db_module._SessionLocal = None
def test_create_and_get_task():
from backend.app.models.task import TaskManager
tm = TaskManager()
task_id = tm.create_task("graph_build", {"project_id": "proj-1"})
task = tm.get_task(task_id)
assert task is not None
assert task["task_type"] == "graph_build"
assert task["status"] == "pending"
assert task["progress"] == 0
def test_update_task_progress():
from backend.app.models.task import TaskManager
tm = TaskManager()
task_id = tm.create_task("ontology_generate")
tm.update_task(task_id, progress=50, message="Halfway")
task = tm.get_task(task_id)
assert task["progress"] == 50
assert task["message"] == "Halfway"
def test_complete_task():
from backend.app.models.task import TaskManager
tm = TaskManager()
task_id = tm.create_task("graph_build")
tm.complete_task(task_id, {"graph_id": "g-1"})
task = tm.get_task(task_id)
assert task["status"] == "completed"
assert task["progress"] == 100
assert task["result"]["graph_id"] == "g-1"
def test_fail_task():
from backend.app.models.task import TaskManager
tm = TaskManager()
task_id = tm.create_task("simulation_prepare")
tm.fail_task(task_id, "LLM timeout")
task = tm.get_task(task_id)
assert task["status"] == "failed"
assert task["error"] == "LLM timeout"
def test_task_survives_new_manager_instance():
"""La tasca ha d'estar a la BD, no a la memòria."""
from backend.app.models.task import TaskManager
tm1 = TaskManager()
task_id = tm1.create_task("graph_build")
# Crear una nova instància (simula reinici)
TaskManager._instance = None
tm2 = TaskManager()
task = tm2.get_task(task_id)
assert task is not None
assert task["task_id"] == task_id
def test_list_tasks():
from backend.app.models.task import TaskManager
tm = TaskManager()
tm.create_task("graph_build")
tm.create_task("graph_build")
tm.create_task("ontology_generate")
all_tasks = tm.list_tasks()
assert len(all_tasks) == 3
graph_tasks = tm.list_tasks(task_type="graph_build")
assert len(graph_tasks) == 2