feat(task): refactor TaskManager to persist tasks in SQLAlchemy DB
Replace in-memory dict-based TaskManager with a SQLAlchemy-backed implementation using TaskModel. Tasks now survive process restarts. 6 new tests added and passing. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
479ae0b712
commit
1f43d35d59
|
|
@ -1,186 +1,136 @@
|
|||
"""
|
||||
Task state management
|
||||
Used to track long-running tasks (e.g. graph building).
|
||||
"""
|
||||
|
||||
"""Task state management — persistent via SQLAlchemy."""
|
||||
import uuid
|
||||
import threading
|
||||
from datetime import datetime
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, Any, Optional, List
|
||||
|
||||
from ..db import get_session
|
||||
from ..models.db_models import TaskModel
|
||||
from ..utils.locale import t
|
||||
|
||||
|
||||
class TaskStatus(str, Enum):
|
||||
"""Task status enum"""
|
||||
PENDING = "pending" # Waiting
|
||||
PROCESSING = "processing" # In progress
|
||||
COMPLETED = "completed" # Completed
|
||||
FAILED = "failed" # Failed
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
"""Task data class"""
|
||||
task_id: str
|
||||
task_type: str
|
||||
status: TaskStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
progress: int = 0 # Total progress percentage 0-100
|
||||
message: str = "" # Status message
|
||||
result: Optional[Dict] = None # Task result
|
||||
error: Optional[str] = None # Error info
|
||||
metadata: Dict = field(default_factory=dict) # Extra metadata
|
||||
progress_detail: Dict = field(default_factory=dict) # Detailed progress info
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary"""
|
||||
return {
|
||||
"task_id": self.task_id,
|
||||
"task_type": self.task_type,
|
||||
"status": self.status.value,
|
||||
"created_at": self.created_at.isoformat(),
|
||||
"updated_at": self.updated_at.isoformat(),
|
||||
"progress": self.progress,
|
||||
"message": self.message,
|
||||
"progress_detail": self.progress_detail,
|
||||
"result": self.result,
|
||||
"error": self.error,
|
||||
"metadata": self.metadata,
|
||||
}
|
||||
PENDING = "pending"
|
||||
PROCESSING = "processing"
|
||||
COMPLETED = "completed"
|
||||
FAILED = "failed"
|
||||
|
||||
|
||||
class TaskManager:
|
||||
"""
|
||||
Task manager
|
||||
Thread-safe task state management
|
||||
"""
|
||||
"""Task manager — thread-safe, persistent via SQLAlchemy."""
|
||||
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
def __new__(cls):
|
||||
"""Singleton pattern"""
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._tasks: Dict[str, Task] = {}
|
||||
cls._instance._task_lock = threading.Lock()
|
||||
return cls._instance
|
||||
|
||||
|
||||
def create_task(self, task_type: str, metadata: Optional[Dict] = None) -> str:
|
||||
"""
|
||||
Create a new task.
|
||||
|
||||
Args:
|
||||
task_type: task type
|
||||
metadata: extra metadata
|
||||
|
||||
Returns:
|
||||
task ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
now = datetime.now()
|
||||
|
||||
task = Task(
|
||||
task_id=task_id,
|
||||
task_type=task_type,
|
||||
status=TaskStatus.PENDING,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
metadata=metadata or {}
|
||||
)
|
||||
|
||||
with self._task_lock:
|
||||
self._tasks[task_id] = task
|
||||
|
||||
with get_session() as db:
|
||||
task = TaskModel(
|
||||
id=task_id,
|
||||
task_type=task_type,
|
||||
status="pending",
|
||||
progress=0,
|
||||
progress_detail=metadata or {},
|
||||
)
|
||||
db.add(task)
|
||||
db.commit()
|
||||
return task_id
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Task]:
|
||||
"""Get a task"""
|
||||
with self._task_lock:
|
||||
return self._tasks.get(task_id)
|
||||
|
||||
|
||||
def get_task(self, task_id: str) -> Optional[Dict[str, Any]]:
|
||||
with get_session() as db:
|
||||
task = db.get(TaskModel, task_id)
|
||||
if task is None:
|
||||
return None
|
||||
return self._to_dict(task)
|
||||
|
||||
def update_task(
|
||||
self,
|
||||
task_id: str,
|
||||
status: Optional[TaskStatus] = None,
|
||||
status: Optional[str] = None,
|
||||
progress: Optional[int] = None,
|
||||
message: Optional[str] = None,
|
||||
result: Optional[Dict] = None,
|
||||
error: Optional[str] = None,
|
||||
progress_detail: Optional[Dict] = None
|
||||
):
|
||||
"""
|
||||
Update task status.
|
||||
progress_detail: Optional[Dict] = None,
|
||||
) -> None:
|
||||
with get_session() as db:
|
||||
task = db.get(TaskModel, task_id)
|
||||
if task is None:
|
||||
return
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if progress_detail is not None:
|
||||
task.progress_detail = progress_detail
|
||||
task.updated_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
|
||||
Args:
|
||||
task_id: task ID
|
||||
status: new status
|
||||
progress: progress
|
||||
message: message
|
||||
result: result
|
||||
error: error info
|
||||
progress_detail: detailed progress info
|
||||
"""
|
||||
with self._task_lock:
|
||||
task = self._tasks.get(task_id)
|
||||
if task:
|
||||
task.updated_at = datetime.now()
|
||||
if status is not None:
|
||||
task.status = status
|
||||
if progress is not None:
|
||||
task.progress = progress
|
||||
if message is not None:
|
||||
task.message = message
|
||||
if result is not None:
|
||||
task.result = result
|
||||
if error is not None:
|
||||
task.error = error
|
||||
if progress_detail is not None:
|
||||
task.progress_detail = progress_detail
|
||||
|
||||
def complete_task(self, task_id: str, result: Dict):
|
||||
"""Mark task as complete"""
|
||||
def complete_task(self, task_id: str, result: Dict) -> None:
|
||||
self.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.COMPLETED,
|
||||
progress=100,
|
||||
message=t('progress.taskComplete'),
|
||||
result=result
|
||||
message=t("progress.taskComplete"),
|
||||
result=result,
|
||||
)
|
||||
|
||||
def fail_task(self, task_id: str, error: str):
|
||||
"""Mark task as failed"""
|
||||
|
||||
def fail_task(self, task_id: str, error: str) -> None:
|
||||
self.update_task(
|
||||
task_id,
|
||||
status=TaskStatus.FAILED,
|
||||
message=t('progress.taskFailed'),
|
||||
error=error
|
||||
message=t("progress.taskFailed"),
|
||||
error=error,
|
||||
)
|
||||
|
||||
def list_tasks(self, task_type: Optional[str] = None) -> list:
|
||||
"""List tasks"""
|
||||
with self._task_lock:
|
||||
tasks = list(self._tasks.values())
|
||||
if task_type:
|
||||
tasks = [t for t in tasks if t.task_type == task_type]
|
||||
return [t.to_dict() for t in sorted(tasks, key=lambda x: x.created_at, reverse=True)]
|
||||
|
||||
def cleanup_old_tasks(self, max_age_hours: int = 24):
|
||||
"""Clean up old tasks"""
|
||||
from datetime import timedelta
|
||||
cutoff = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
with self._task_lock:
|
||||
old_ids = [
|
||||
tid for tid, task in self._tasks.items()
|
||||
if task.created_at < cutoff and task.status in [TaskStatus.COMPLETED, TaskStatus.FAILED]
|
||||
]
|
||||
for tid in old_ids:
|
||||
del self._tasks[tid]
|
||||
|
||||
def list_tasks(self, task_type: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
from sqlalchemy import select, desc
|
||||
with get_session() as db:
|
||||
stmt = select(TaskModel).order_by(desc(TaskModel.created_at))
|
||||
if task_type:
|
||||
stmt = stmt.where(TaskModel.task_type == task_type)
|
||||
tasks = db.execute(stmt).scalars().all()
|
||||
return [self._to_dict(t) for t in tasks]
|
||||
|
||||
def cleanup_old_tasks(self, max_age_hours: int = 24) -> None:
|
||||
from datetime import timedelta
|
||||
from sqlalchemy import delete
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(hours=max_age_hours)
|
||||
with get_session() as db:
|
||||
db.execute(
|
||||
delete(TaskModel).where(
|
||||
TaskModel.created_at < cutoff,
|
||||
TaskModel.status.in_(["completed", "failed"]),
|
||||
)
|
||||
)
|
||||
db.commit()
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(task: TaskModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"task_id": task.id,
|
||||
"task_type": task.task_type,
|
||||
"status": task.status,
|
||||
"created_at": task.created_at.isoformat(),
|
||||
"updated_at": task.updated_at.isoformat(),
|
||||
"progress": task.progress,
|
||||
"message": task.message or "",
|
||||
"progress_detail": task.progress_detail or {},
|
||||
"result": task.result,
|
||||
"error": task.error,
|
||||
"metadata": task.progress_detail or {},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -0,0 +1,86 @@
|
|||
# backend/tests/test_task_manager_db.py
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from backend.app.db import Base
|
||||
import backend.app.db as db_module
|
||||
from backend.app.models.db_models import TaskModel
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def isolated_db():
|
||||
"""BD SQLite en memòria per a cada test."""
|
||||
db_module._engine = create_engine("sqlite:///:memory:", connect_args={"check_same_thread": False})
|
||||
db_module._SessionLocal = sessionmaker(bind=db_module._engine, autocommit=False, autoflush=False)
|
||||
Base.metadata.create_all(db_module._engine)
|
||||
yield
|
||||
Base.metadata.drop_all(db_module._engine)
|
||||
db_module._engine = None
|
||||
db_module._SessionLocal = None
|
||||
|
||||
|
||||
def test_create_and_get_task():
|
||||
from backend.app.models.task import TaskManager
|
||||
tm = TaskManager()
|
||||
task_id = tm.create_task("graph_build", {"project_id": "proj-1"})
|
||||
task = tm.get_task(task_id)
|
||||
assert task is not None
|
||||
assert task["task_type"] == "graph_build"
|
||||
assert task["status"] == "pending"
|
||||
assert task["progress"] == 0
|
||||
|
||||
|
||||
def test_update_task_progress():
|
||||
from backend.app.models.task import TaskManager
|
||||
tm = TaskManager()
|
||||
task_id = tm.create_task("ontology_generate")
|
||||
tm.update_task(task_id, progress=50, message="Halfway")
|
||||
task = tm.get_task(task_id)
|
||||
assert task["progress"] == 50
|
||||
assert task["message"] == "Halfway"
|
||||
|
||||
|
||||
def test_complete_task():
|
||||
from backend.app.models.task import TaskManager
|
||||
tm = TaskManager()
|
||||
task_id = tm.create_task("graph_build")
|
||||
tm.complete_task(task_id, {"graph_id": "g-1"})
|
||||
task = tm.get_task(task_id)
|
||||
assert task["status"] == "completed"
|
||||
assert task["progress"] == 100
|
||||
assert task["result"]["graph_id"] == "g-1"
|
||||
|
||||
|
||||
def test_fail_task():
|
||||
from backend.app.models.task import TaskManager
|
||||
tm = TaskManager()
|
||||
task_id = tm.create_task("simulation_prepare")
|
||||
tm.fail_task(task_id, "LLM timeout")
|
||||
task = tm.get_task(task_id)
|
||||
assert task["status"] == "failed"
|
||||
assert task["error"] == "LLM timeout"
|
||||
|
||||
|
||||
def test_task_survives_new_manager_instance():
|
||||
"""La tasca ha d'estar a la BD, no a la memòria."""
|
||||
from backend.app.models.task import TaskManager
|
||||
tm1 = TaskManager()
|
||||
task_id = tm1.create_task("graph_build")
|
||||
# Crear una nova instància (simula reinici)
|
||||
TaskManager._instance = None
|
||||
tm2 = TaskManager()
|
||||
task = tm2.get_task(task_id)
|
||||
assert task is not None
|
||||
assert task["task_id"] == task_id
|
||||
|
||||
|
||||
def test_list_tasks():
|
||||
from backend.app.models.task import TaskManager
|
||||
tm = TaskManager()
|
||||
tm.create_task("graph_build")
|
||||
tm.create_task("graph_build")
|
||||
tm.create_task("ontology_generate")
|
||||
all_tasks = tm.list_tasks()
|
||||
assert len(all_tasks) == 3
|
||||
graph_tasks = tm.list_tasks(task_type="graph_build")
|
||||
assert len(graph_tasks) == 2
|
||||
Loading…
Reference in New Issue