feat(task): refactor TaskManager to persist tasks in SQLAlchemy DB
Replace the in-memory, dict-based TaskManager with a SQLAlchemy-backed implementation built on TaskModel, so tasks now survive process restarts. Six new tests are added, and all pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
479ae0b712
commit
1f43d35d59
|
|
@@ -1,186 +1,136 @@
|
||||||
"""
|
"""Task state management — persistent via SQLAlchemy."""
|
||||||
Task state management
|
|
||||||
Used to track long-running tasks (e.g. graph building).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
import threading
|
import threading
|
||||||
from datetime import datetime
|
from datetime import datetime, timezone
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Dict, Any, Optional
|
from typing import Dict, Any, Optional, List
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
from ..db import get_session
|
||||||
|
from ..models.db_models import TaskModel
|
||||||
from ..utils.locale import t
|
from ..utils.locale import t
|
||||||
|
|
||||||
|
|
||||||
class TaskStatus(str, Enum):
|
class TaskStatus(str, Enum):
|
||||||
"""Task status enum"""
|
PENDING = "pending"
|
||||||
PENDING = "pending" # Waiting
|
PROCESSING = "processing"
|
||||||
PROCESSING = "processing" # In progress
|
COMPLETED = "completed"
|
||||||
COMPLETED = "completed" # Completed
|
FAILED = "failed"
|
||||||
FAILED = "failed" # Failed
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Task:
|
|
||||||
"""Task data class"""
|
|
||||||
task_id: str
|
|
||||||
task_type: str
|
|
||||||
status: TaskStatus
|
|
||||||
created_at: datetime
|
|
||||||
updated_at: datetime
|
|
||||||
progress: int = 0 # Total progress percentage 0-100
|
|
||||||
message: str = "" # Status message
|
|
||||||
result: Optional[Dict] = None # Task result
|
|
||||||
error: Optional[str] = None # Error info
|
|
||||||
metadata: Dict = field(default_factory=dict) # Extra metadata
|
|
||||||
progress_detail: Dict = field(default_factory=dict) # Detailed progress info
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
|
||||||
"""Convert to dictionary"""
|
|
||||||
return {
|
|
||||||
"task_id": self.task_id,
|
|
||||||
"task_type": self.task_type,
|
|
||||||
"status": self.status.value,
|
|
||||||
"created_at": self.created_at.isoformat(),
|
|
||||||
"updated_at": self.updated_at.isoformat(),
|
|
||||||
"progress": self.progress,
|
|
||||||
"message": self.message,
|
|
||||||
"progress_detail": self.progress_detail,
|
|
||||||
"result": self.result,
|
|
||||||
"error": self.error,
|
|
||||||
"metadata": self.metadata,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
"""
|
"""Task manager — thread-safe, persistent via SQLAlchemy."""
|
||||||
Task manager
|
|
||||||
Thread-safe task state management
|
|
||||||
"""
|
|
||||||
|
|
||||||
_instance = None
|
_instance = None
|
||||||
_lock = threading.Lock()
|
_lock = threading.Lock()
|
||||||
|
|
||||||
def __new__(cls):
|
def __new__(cls):
|
||||||
"""Singleton pattern"""
|
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
with cls._lock:
|
with cls._lock:
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._tasks: Dict[str, Task] = {}
|
|
||||||
cls._instance._task_lock = threading.Lock()
|
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
||||||
"""
|
|
||||||
Create a new task.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
task_type: task type
|
|
||||||
metadata: extra metadata
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
task ID
|
|
||||||
"""
|
|
||||||
task_id = str(uuid.uuid4())
|
task_id = str(uuid.uuid4())
|
||||||
now = datetime.now()
|
with get_session() as db:
|
||||||
|
task = TaskModel(
|
||||||
task = Task(
|
id=task_id,
|
||||||
task_id=task_id,
|
task_type=task_type,
|
||||||
task_type=task_type,
|
status="pending",
|
||||||
status=TaskStatus.PENDING,
|
progress=0,
|
||||||
created_at=now,
|
progress_detail=metadata or {},
|
||||||
updated_at=now,
|
)
|
||||||
metadata=metadata or {}
|
db.add(task)
|
||||||
)
|
db.commit()
|
||||||
|
|
||||||
with self._task_lock:
|
|
||||||
self._tasks[task_id] = task
|
|
||||||
|
|
||||||
return task_id
|
return task_id
|
||||||
|
|
||||||
def get_task(self, task_id: str) -> Optional[Task]:
|
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||||
"""Get a task"""
|
with get_session() as db:
|
||||||
with self._task_lock:
|
task = db.get(TaskModel, task_id)
|
||||||
return self._tasks.get(task_id)
|
if task is None:
|
||||||
|
return None
|
||||||
|
return self._to_dict(task)
|
||||||
|
|
||||||
def update_task(
|
def update_task(
|
||||||
self,
|
self,
|
||||||
task_id: str,
|
task_id: str,
|
||||||
status: Optional[TaskStatus] = None,
|
status: Optional[str] = None,
|
||||||
progress: Optional[int] = None,
|
progress: Optional[int] = None,
|
||||||
message: Optional[str] = None,
|
message: Optional[str] = None,
|
||||||
result: Optional[Dict] = None,
|
result: Optional[Dict] = None,
|
||||||
error: Optional[str] = None,
|
error: Optional[str] = None,
|
||||||
progress_detail: Optional[Dict] = None
|
progress_detail: Optional[Dict] = None,
|
||||||
):
|
) -> None:
|
||||||
"""
|
with get_session() as db:
|
||||||
Update task status.
|
task = db.get(TaskModel, task_id)
|
||||||
|
if task is None:
|
||||||
|
return
|
||||||
|
if status is not None:
|
||||||
|
task.status = status
|
||||||
|
if progress is not None:
|
||||||
|
task.progress = progress
|
||||||
|
if message is not None:
|
||||||
|
task.message = message
|
||||||
|
if result is not None:
|
||||||
|
task.result = result
|
||||||
|
if error is not None:
|
||||||
|
task.error = error
|
||||||
|
if progress_detail is not None:
|
||||||
|
task.progress_detail = progress_detail
|
||||||
|
task.updated_at = datetime.now(timezone.utc)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
Args:
|
def complete_task(self, task_id: str, result: Dict) -> None:
|
||||||
task_id: task ID
|
|
||||||
status: new status
|
|
||||||
progress: progress
|
|
||||||
message: message
|
|
||||||
result: result
|
|
||||||
error: error info
|
|
||||||
progress_detail: detailed progress info
|
|
||||||
"""
|
|
||||||
with self._task_lock:
|
|
||||||
task = self._tasks.get(task_id)
|
|
||||||
if task:
|
|
||||||
task.updated_at = datetime.now()
|
|
||||||
if status is not None:
|
|
||||||
task.status = status
|
|
||||||
if progress is not None:
|
|
||||||
task.progress = progress
|
|
||||||
if message is not None:
|
|
||||||
task.message = message
|
|
||||||
if result is not None:
|
|
||||||
task.result = result
|
|
||||||
if error is not None:
|
|
||||||
task.error = error
|
|
||||||
if progress_detail is not None:
|
|
||||||
task.progress_detail = progress_detail
|
|
||||||
|
|
||||||
def complete_task(self, task_id: str, result: Dict):
|
|
||||||
"""Mark task as complete"""
|
|
||||||
self.update_task(
|
self.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.COMPLETED,
|
status=TaskStatus.COMPLETED,
|
||||||
progress=100,
|
progress=100,
|
||||||
message=t('progress.taskComplete'),
|
message=t("progress.taskComplete"),
|
||||||
result=result
|
result=result,
|
||||||
)
|
)
|
||||||
|
|
||||||
def fail_task(self, task_id: str, error: str):
|
def fail_task(self, task_id: str, error: str) -> None:
|
||||||
"""Mark task as failed"""
|
|
||||||
self.update_task(
|
self.update_task(
|
||||||
task_id,
|
task_id,
|
||||||
status=TaskStatus.FAILED,
|
status=TaskStatus.FAILED,
|
||||||
message=t('progress.taskFailed'),
|
message=t("progress.taskFailed"),
|
||||||
error=error
|
error=error,
|
||||||
)
|
)
|
||||||
|
|
||||||
def list_tasks(self, task_type: Optional[str] = None) -> list:
|
def list_tasks(self, task_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
"""List tasks"""
|
from sqlalchemy import select, desc
|
||||||
with self._task_lock:
|
with get_session() as db:
|
||||||
tasks = list(self._tasks.values())
|
stmt = select(TaskModel).order_by(desc(TaskModel.created_at))
|
||||||
if task_type:
|
if task_type:
|
||||||
tasks = [t for t in tasks if t.task_type == task_type]
|
stmt = stmt.where(TaskModel.task_type == task_type)
|
||||||
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
|
tasks = db.execute(stmt).scalars().all()
|
||||||
|
return [self._to_dict(t) for t in tasks]
|
||||||
|
|
||||||
def cleanup_old_tasks(self, max_age_hours: int = 24):
|
def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
|
||||||
"""Clean up old tasks"""
|
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
cutoff = datetime.now() - timedelta(hours=max_age_hours)
|
from sqlalchemy import delete
|
||||||
|
cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
|
||||||
with self._task_lock:
|
with get_session() as db:
|
||||||
old_ids = [
|
db.execute(
|
||||||
tid for tid, task in self._tasks.items()
|
delete(TaskModel).where(
|
||||||
if task.created_at < cutoff and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
|
TaskModel.created_at < cutoff,
|
||||||
]
|
TaskModel.status.in_(["completed", "failed"]),
|
||||||
for tid in old_ids:
|
)
|
||||||
del self._tasks[tid]
|
)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _to_dict(task: TaskModel) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"task_id": task.id,
|
||||||
|
"task_type": task.task_type,
|
||||||
|
"status": task.status,
|
||||||
|
"created_at": task.created_at.isoformat(),
|
||||||
|
"updated_at": task.updated_at.isoformat(),
|
||||||
|
"progress": task.progress,
|
||||||
|
"message": task.message or "",
|
||||||
|
"progress_detail": task.progress_detail or {},
|
||||||
|
"result": task.result,
|
||||||
|
"error": task.error,
|
||||||
|
"metadata": task.progress_detail or {},
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@@ -0,0 +1,86 @@
|
||||||
|
# backend/tests/test_task_manager_db.py
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from backend.app.db import Base
|
||||||
|
import backend.app.db as db_module
|
||||||
|
from backend.app.models.db_models import TaskModel
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def isolated_db():
|
||||||
|
"""BD SQLite en memòria per a cada test."""
|
||||||
|
db_module._engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
|
||||||
|
db_module._SessionLocal = sessionmaker(bind=db_module._engine, autocommit=False, autoflush=False)
|
||||||
|
Base.metadata.create_all(db_module._engine)
|
||||||
|
yield
|
||||||
|
Base.metadata.drop_all(db_module._engine)
|
||||||
|
db_module._engine = None
|
||||||
|
db_module._SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_and_get_task():
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm = TaskManager()
|
||||||
|
task_id = tm.create_task("graph_build", {"project_id": "proj-1"})
|
||||||
|
task = tm.get_task(task_id)
|
||||||
|
assert task is not None
|
||||||
|
assert task["task_type"] == "graph_build"
|
||||||
|
assert task["status"] == "pending"
|
||||||
|
assert task["progress"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_task_progress():
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm = TaskManager()
|
||||||
|
task_id = tm.create_task("ontology_generate")
|
||||||
|
tm.update_task(task_id, progress=50, message="Halfway")
|
||||||
|
task = tm.get_task(task_id)
|
||||||
|
assert task["progress"] == 50
|
||||||
|
assert task["message"] == "Halfway"
|
||||||
|
|
||||||
|
|
||||||
|
def test_complete_task():
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm = TaskManager()
|
||||||
|
task_id = tm.create_task("graph_build")
|
||||||
|
tm.complete_task(task_id, {"graph_id": "g-1"})
|
||||||
|
task = tm.get_task(task_id)
|
||||||
|
assert task["status"] == "completed"
|
||||||
|
assert task["progress"] == 100
|
||||||
|
assert task["result"]["graph_id"] == "g-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_fail_task():
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm = TaskManager()
|
||||||
|
task_id = tm.create_task("simulation_prepare")
|
||||||
|
tm.fail_task(task_id, "LLM timeout")
|
||||||
|
task = tm.get_task(task_id)
|
||||||
|
assert task["status"] == "failed"
|
||||||
|
assert task["error"] == "LLM timeout"
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_survives_new_manager_instance():
|
||||||
|
"""La tasca ha d'estar a la BD, no a la memòria."""
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm1 = TaskManager()
|
||||||
|
task_id = tm1.create_task("graph_build")
|
||||||
|
# Crear una nova instància (simula reinici)
|
||||||
|
TaskManager._instance = None
|
||||||
|
tm2 = TaskManager()
|
||||||
|
task = tm2.get_task(task_id)
|
||||||
|
assert task is not None
|
||||||
|
assert task["task_id"] == task_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_tasks():
|
||||||
|
from backend.app.models.task import TaskManager
|
||||||
|
tm = TaskManager()
|
||||||
|
tm.create_task("graph_build")
|
||||||
|
tm.create_task("graph_build")
|
||||||
|
tm.create_task("ontology_generate")
|
||||||
|
all_tasks = tm.list_tasks()
|
||||||
|
assert len(all_tasks) == 3
|
||||||
|
graph_tasks = tm.list_tasks(task_type="graph_build")
|
||||||
|
assert len(graph_tasks) == 2
|
||||||
Loading…
Reference in New Issue