feat(task): refactor TaskManager to persist tasks in SQLAlchemy DB

Replace in-memory dict-based TaskManager with a SQLAlchemy-backed implementation
using TaskModel. Tasks now survive process restarts. 6 new tests added and passing.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Ubuntu 2026-05-03 00:10:16 +00:00
parent 479ae0b712
commit 1f43d35d59
2 changed files with 180 additions and 144 deletions

View File

@ -1,186 +1,136 @@
"""
Task state management
Used to track long-running tasks (e.g. graph building).
"""
"""Task state management — persistent via SQLAlchemy."""
import uuid
import threading
from datetime import datetime
from datetime import datetime, timezone
from enum import Enum
from typing import Dict, Any, Optional
from dataclasses import dataclass, field
from typing import Dict, Any, Optional, List
from ..db import get_session
from ..models.db_models import TaskModel
from ..utils.locale import t
class TaskStatus(str, Enum):
    """Lifecycle states of a task.

    The class mixes in ``str``, so every member compares equal to its
    lowercase string value.
    """

    # Waiting
    PENDING = "pending"
    # In progress
    PROCESSING = "processing"
    # Completed
    COMPLETED = "completed"
    # Failed
    FAILED = "failed"
@dataclass
class Task:
    """Record describing one long-running task and its current state."""

    task_id: str
    task_type: str
    status: TaskStatus
    created_at: datetime
    updated_at: datetime
    # Total progress percentage, 0-100.
    progress: int = 0
    # Status message.
    message: str = ""
    # Task result.
    result: Optional[Dict] = None
    # Error info.
    error: Optional[str] = None
    # Extra metadata.
    metadata: Dict = field(default_factory=dict)
    # Detailed progress info.
    progress_detail: Dict = field(default_factory=dict)

    def to_dict(self) -> Dict[str, Any]:
        """Return the task as a plain dictionary.

        The status is emitted as its ``.value`` and both timestamps as
        ISO-8601 strings; every other field is passed through unchanged.
        """
        created_iso = self.created_at.isoformat()
        updated_iso = self.updated_at.isoformat()
        return dict(
            task_id=self.task_id,
            task_type=self.task_type,
            status=self.status.value,
            created_at=created_iso,
            updated_at=updated_iso,
            progress=self.progress,
            message=self.message,
            progress_detail=self.progress_detail,
            result=self.result,
            error=self.error,
            metadata=self.metadata,
        )
PENDING = "pending"
PROCESSING = "processing"
COMPLETED = "completed"
FAILED = "failed"
class TaskManager:
    """Singleton task manager that persists task state through SQLAlchemy.

    Each public method opens its own session with ``get_session()``; methods
    that modify data commit before returning. Tasks are handed back to
    callers as plain dictionaries (see ``_to_dict``), never as ORM objects.
    """

    _instance = None
    _lock = threading.Lock()

    def __new__(cls):
        """Return the shared instance, creating it on first use."""
        # Check, take the lock, then check again so that concurrent first
        # calls cannot each build their own instance.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
        """Insert a new pending task and return its ID.

        Args:
            task_type: Task type label.
            metadata: Optional extra data for the task.

        Returns:
            The generated task ID (a UUID4 string).
        """
        task_id = str(uuid.uuid4())
        with get_session() as db:
            db.add(
                TaskModel(
                    id=task_id,
                    task_type=task_type,
                    status=TaskStatus.PENDING.value,
                    progress=0,
                    # NOTE(review): metadata is written to the progress_detail
                    # column, so a later update_task(progress_detail=...)
                    # replaces it. Confirm whether TaskModel should carry a
                    # dedicated metadata column.
                    progress_detail=metadata or {},
                )
            )
            db.commit()
        return task_id

    def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
        """Return the task with ``task_id`` as a dict, or ``None`` if absent."""
        with get_session() as db:
            task = db.get(TaskModel, task_id)
            if task is None:
                return None
            return self._to_dict(task)

    def update_task(
        self,
        task_id: str,
        status: Optional[str] = None,
        progress: Optional[int] = None,
        message: Optional[str] = None,
        result: Optional[Dict] = None,
        error: Optional[str] = None,
        progress_detail: Optional[Dict] = None,
    ) -> None:
        """Update the given fields of a task and refresh ``updated_at``.

        Arguments left as ``None`` keep their stored value. An unknown
        ``task_id`` is ignored without raising.

        Args:
            task_id: ID of the task to update.
            status: New status; a plain string or a ``TaskStatus`` member.
            progress: New progress value.
            message: New status message.
            result: New task result.
            error: New error info.
            progress_detail: New detailed progress info.
        """
        with get_session() as db:
            task = db.get(TaskModel, task_id)
            if task is None:
                return
            if status is not None:
                # Store the plain string so the column holds the same kind of
                # value that create_task and cleanup_old_tasks work with.
                task.status = self._status_value(status)
            if progress is not None:
                task.progress = progress
            if message is not None:
                task.message = message
            if result is not None:
                task.result = result
            if error is not None:
                task.error = error
            if progress_detail is not None:
                task.progress_detail = progress_detail
            task.updated_at = datetime.now(timezone.utc)
            db.commit()

    def complete_task(self, task_id: str, result: Dict) -> None:
        """Mark a task as completed at 100% progress and store ``result``."""
        self.update_task(
            task_id,
            status=TaskStatus.COMPLETED,
            progress=100,
            message=t("progress.taskComplete"),
            result=result,
        )

    def fail_task(self, task_id: str, error: str) -> None:
        """Mark a task as failed and store ``error``."""
        self.update_task(
            task_id,
            status=TaskStatus.FAILED,
            message=t("progress.taskFailed"),
            error=error,
        )

    def list_tasks(self, task_type: Optional[str] = None) -> List[Dict[str, Any]]:
        """Return tasks as dicts, ordered by ``created_at`` descending.

        Args:
            task_type: When truthy, only tasks of this type are returned.
        """
        from sqlalchemy import select, desc

        with get_session() as db:
            stmt = select(TaskModel).order_by(desc(TaskModel.created_at))
            if task_type:
                stmt = stmt.where(TaskModel.task_type == task_type)
            rows = db.execute(stmt).scalars().all()
            # "row" rather than "t": the latter is the imported locale helper.
            return [self._to_dict(row) for row in rows]

    def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
        """Delete completed or failed tasks created before the cutoff.

        Args:
            max_age_hours: Age in hours, measured from the current UTC time,
                beyond which finished tasks are deleted.
        """
        from datetime import timedelta
        from sqlalchemy import delete

        cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
        finished = [TaskStatus.COMPLETED.value, TaskStatus.FAILED.value]
        with get_session() as db:
            db.execute(
                delete(TaskModel).where(
                    TaskModel.created_at < cutoff,
                    TaskModel.status.in_(finished),
                )
            )
            db.commit()

    @staticmethod
    def _status_value(status: Any) -> Any:
        """Return ``status.value`` for enum members, else ``status`` unchanged."""
        return status.value if isinstance(status, Enum) else status

    @staticmethod
    def _to_dict(task: TaskModel) -> Dict[str, Any]:
        """Convert a task row into the dictionary returned to callers.

        Timestamps become ISO-8601 strings (``None`` when the attribute is
        unset), a missing message becomes ``""`` and a missing
        ``progress_detail`` becomes ``{}``. The ``metadata`` key mirrors
        ``progress_detail`` because that is where ``create_task`` stores it.
        """

        def _iso(value):
            # A row without a timestamp must not crash serialisation.
            return value.isoformat() if value is not None else None

        return {
            "task_id": task.id,
            "task_type": task.task_type,
            "status": task.status,
            "created_at": _iso(task.created_at),
            "updated_at": _iso(task.updated_at),
            "progress": task.progress,
            "message": task.message or "",
            "progress_detail": task.progress_detail or {},
            "result": task.result,
            "error": task.error,
            "metadata": task.progress_detail or {},
        }

View File

@ -0,0 +1,86 @@
# backend/tests/test_task_manager_db.py
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from backend.app.db import Base
import backend.app.db as db_module
from backend.app.models.db_models import TaskModel
@pytest.fixture(autouse=True)
def isolated_db():
    """Give every test its own in-memory SQLite database.

    Points the module-level engine and session factory of ``backend.app.db``
    at a fresh in-memory database and creates all tables. After the test the
    tables are dropped, the engine is disposed and both module attributes
    are reset to ``None``.
    """
    engine = create_engine(
        "sqlite:///:memory:",
        connect_args={"check_same_thread": False},
    )
    db_module._engine = engine
    db_module._SessionLocal = sessionmaker(bind=engine, autocommit=False, autoflush=False)
    Base.metadata.create_all(engine)
    yield
    try:
        Base.metadata.drop_all(engine)
    finally:
        # Release pooled connections and reset the module state even when
        # dropping the tables fails, so one test cannot leak into the next.
        engine.dispose()
        db_module._engine = None
        db_module._SessionLocal = None
def test_create_and_get_task():
    """A created task can be read back with its initial state and metadata."""
    from backend.app.models.task import TaskManager

    tm = TaskManager()
    task_id = tm.create_task("graph_build", {"project_id": "proj-1"})

    task = tm.get_task(task_id)
    assert task is not None
    assert task["task_id"] == task_id
    assert task["task_type"] == "graph_build"
    assert task["status"] == "pending"
    assert task["progress"] == 0
    # The metadata passed to create_task must survive the round trip.
    assert task["metadata"] == {"project_id": "proj-1"}
def test_update_task_progress():
    """update_task stores a new progress value and message."""
    from backend.app.models.task import TaskManager

    manager = TaskManager()
    new_id = manager.create_task("ontology_generate")
    manager.update_task(new_id, progress=50, message="Halfway")

    stored = manager.get_task(new_id)
    assert stored["progress"] == 50
    assert stored["message"] == "Halfway"
def test_complete_task():
    """complete_task sets status, full progress and the result payload."""
    from backend.app.models.task import TaskManager

    manager = TaskManager()
    new_id = manager.create_task("graph_build")
    manager.complete_task(new_id, {"graph_id": "g-1"})

    stored = manager.get_task(new_id)
    assert stored["status"] == "completed"
    assert stored["progress"] == 100
    assert stored["result"]["graph_id"] == "g-1"
def test_fail_task():
    """fail_task sets the failed status and records the error text."""
    from backend.app.models.task import TaskManager

    manager = TaskManager()
    new_id = manager.create_task("simulation_prepare")
    manager.fail_task(new_id, "LLM timeout")

    stored = manager.get_task(new_id)
    assert stored["status"] == "failed"
    assert stored["error"] == "LLM timeout"
def test_task_survives_new_manager_instance():
    """The task must live in the database, not in memory."""
    from backend.app.models.task import TaskManager

    first_manager = TaskManager()
    created_id = first_manager.create_task("graph_build")

    # Create a new instance (simulates a restart).
    TaskManager._instance = None
    second_manager = TaskManager()

    reloaded = second_manager.get_task(created_id)
    assert reloaded is not None
    assert reloaded["task_id"] == created_id
def test_list_tasks():
    """list_tasks returns every task, or only those of the requested type."""
    from backend.app.models.task import TaskManager

    tm = TaskManager()
    tm.create_task("graph_build")
    tm.create_task("graph_build")
    tm.create_task("ontology_generate")

    all_tasks = tm.list_tasks()
    assert len(all_tasks) == 3

    graph_tasks = tm.list_tasks(task_type="graph_build")
    assert len(graph_tasks) == 2
    # Counting alone would also pass if the filter returned the wrong rows.
    assert {task["task_type"] for task in graph_tasks} == {"graph_build"}

    # A type with no tasks yields an empty list.
    assert tm.list_tasks(task_type="does_not_exist") == []