This is an automated email from the ASF dual-hosted git repository.

tn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new b41da93  clean up worker manager, ensure that datetimes are stored / retrieved in UTC
b41da93 is described below

commit b41da934825b4b5eed28269778bba397d14536cd
Author: Thomas Neidhart <[email protected]>
AuthorDate: Mon Mar 31 17:48:38 2025 +0200

    clean up worker manager, ensure that datetimes are stored / retrieved in UTC
---
 atr/db/models.py |  48 +++++++++++++++++++++++--
 atr/manager.py   | 108 ++++++++++++++++++++++---------------------------------
 2 files changed, 87 insertions(+), 69 deletions(-)

diff --git a/atr/db/models.py b/atr/db/models.py
index e27b4a6..e8530a1 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -32,6 +32,39 @@ import sqlmodel
 import atr.db as db
 
 
+class UTCDateTime(sqlalchemy.types.TypeDecorator):
+    """
+    A custom column type to store datetime in sqlite.
+
+    As sqlite does not have timezone support, we ensure that all datetimes stored
+    within sqlite are converted to UTC. When retrieved, the datetimes are constructed
+    as offset-aware datetime with UTC as their timezone.
+    """
+
+    impl = sqlalchemy.types.TIMESTAMP(timezone=True)
+
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):  # type: ignore
+        if value is not None:  # only None bypasses validation; other falsy values are rejected below
+            if not isinstance(value, datetime.datetime):
+                raise ValueError(f"unexpected value type {type(value)}")
+
+            if value.tzinfo is None:
+                raise ValueError("encountered offset-naive datetime")
+
+            # store the datetime in UTC in sqlite as it does not support timezones
+            return value.astimezone(datetime.UTC)
+        else:
+            return value
+
+    def process_result_value(self, value, dialect):  # type: ignore
+        if isinstance(value, datetime.datetime):
+            return value.replace(tzinfo=datetime.UTC)  # values are stored in UTC, so only the tzinfo is attached
+        else:
+            return value
+
+
 class UserRole(str, enum.Enum):
     COMMITTEE_MEMBER = "committee_member"
     RELEASE_MANAGER = "release_manager"
@@ -258,10 +291,19 @@ class Task(sqlmodel.SQLModel, table=True):
     status: TaskStatus = sqlmodel.Field(default=TaskStatus.QUEUED, index=True)
     task_type: str
     task_args: Any = 
sqlmodel.Field(sa_column=sqlalchemy.Column(sqlalchemy.JSON))
-    added: datetime.datetime = sqlmodel.Field(default_factory=lambda: 
datetime.datetime.now(datetime.UTC), index=True)
-    started: datetime.datetime | None = None
+    added: datetime.datetime = sqlmodel.Field(
+        default_factory=lambda: datetime.datetime.now(datetime.UTC),
+        sa_column=sqlalchemy.Column(UTCDateTime, index=True),
+    )
+    started: datetime.datetime | None = sqlmodel.Field(
+        default=None,
+        sa_column=sqlalchemy.Column(UTCDateTime),
+    )
     pid: int | None = None
-    completed: datetime.datetime | None = None
+    completed: datetime.datetime | None = sqlmodel.Field(
+        default=None,
+        sa_column=sqlalchemy.Column(UTCDateTime),
+    )
     result: Any | None = sqlmodel.Field(default=None, 
sa_column=sqlalchemy.Column(sqlalchemy.JSON))
     error: str | None = None
     release_name: str | None = sqlmodel.Field(default=None, 
foreign_key="release.name")
diff --git a/atr/manager.py b/atr/manager.py
index 086a094..f0dcafe 100644
--- a/atr/manager.py
+++ b/atr/manager.py
@@ -17,6 +17,8 @@
 
 """Worker process manager."""
 
+from __future__ import annotations
+
 import asyncio
 import datetime
 import io
@@ -24,11 +26,12 @@ import logging
 import os
 import signal
 import sys
-from typing import Final, Optional
+from typing import Final
 
-import sqlalchemy
+import sqlmodel
 
 import atr.db as db
+import atr.db.models as models
 
 # Configure logging
 logging.basicConfig(
@@ -44,7 +47,7 @@ global_worker_debug = False
 
 # Global worker manager instance
 # Can't use "StringClass" | None, must use Optional["StringClass"] for forward references
-global_worker_manager: Optional["WorkerManager"] = None
+global_worker_manager: WorkerManager | None = None
 
 
 class WorkerManager:
@@ -195,7 +198,7 @@ class WorkerManager:
             except asyncio.CancelledError:
                 break
             except Exception as e:
-                _LOGGER.error(f"Error in worker monitor: {e}")
+                _LOGGER.error(f"Error in worker monitor: {e}", exc_info=e)
                 # TODO: How long should we wait before trying again?
                 await asyncio.sleep(1.0)
 
@@ -220,7 +223,7 @@ class WorkerManager:
         for pid in exited_workers:
             self.workers.pop(pid, None)
 
-        # # Check for active tasks
+        # Check for active tasks
         # try:
         #     async with get_session() as session:
         #         result = await session.execute(
@@ -238,12 +241,11 @@ class WorkerManager:
         # Spawn new workers if needed
         await self.maintain_worker_pool()
 
-        # Reset any tasks that were being processed by exited workers
-        if exited_workers:
-            await self.reset_broken_tasks(exited_workers)
+        # Reset any tasks that were being processed by now inactive workers
+        await self.reset_broken_tasks()
 
     async def terminate_long_running_task(
-        self, session: sqlalchemy.ext.asyncio.AsyncSession, worker: 
"WorkerProcess", task_id: int, pid: int
+        self, task: models.Task, worker: WorkerProcess, task_id: int, pid: int
     ) -> None:
         """
         Terminate a task that has been running for too long.
@@ -251,20 +253,10 @@ class WorkerManager:
         """
         try:
             # Mark the task as failed
-            # TODO: Replace with ORM
-            await session.execute(
-                sqlalchemy.text("""
-                    UPDATE task
-                    SET status = 'FAILED', completed = :now, error = :error
-                    WHERE id = :task_id
-                    AND status = 'ACTIVE'
-                """),
-                {
-                    "now": datetime.datetime.now(datetime.UTC),
-                    "task_id": task_id,
-                    "error": f"Task terminated after exceeding time limit of 
{self.max_task_seconds} seconds",
-                },
-            )
+            task.status = models.TaskStatus.FAILED
+            task.completed = datetime.datetime.now(datetime.UTC)
+            task.error = f"Task terminated after exceeding time limit of 
{self.max_task_seconds} seconds"
+
             if worker.pid:
                 os.kill(worker.pid, signal.SIGTERM)
                 _LOGGER.info(f"Worker {pid} terminated after processing task 
{task_id} for > {self.max_task_seconds}s")
@@ -273,39 +265,21 @@ class WorkerManager:
         except Exception as e:
             _LOGGER.error(f"Error stopping long-running worker {pid}: {e}")
 
-    async def check_task_duration(self, pid: int, worker: "WorkerProcess") -> 
bool:
+    async def check_task_duration(self, pid: int, worker: WorkerProcess) -> 
bool:
         """
         Check if a worker has been processing its task for too long.
         Returns True if the worker has been terminated.
         """
         try:
-            async with db.create_async_db_session() as session:
-                async with session.begin():
-                    # TODO: Replace with ORM
-                    result = await session.execute(
-                        sqlalchemy.text("""
-                            SELECT id, started FROM task
-                            WHERE status = 'ACTIVE'
-                            AND pid = :pid
-                        """),
-                        {"pid": pid},
-                    )
-                    task = result.first()
-                    if not task or not task[1]:
+            async with db.session() as data:
+                async with data.begin():
+                    task = await data.task(pid=pid, 
status=models.TaskStatus.ACTIVE).get()
+                    if not task or not task.started:
                         return False
 
-                    task_id, started = task
-                    # Convert started to datetime if it's a string
-                    if isinstance(started, str):
-                        try:
-                            started = 
datetime.datetime.fromisoformat(started.replace("Z", "+00:00"))
-                        except ValueError:
-                            _LOGGER.error(f"Could not parse started time 
'{started}' for task {task_id}")
-                            return False
-
-                    task_duration = (datetime.datetime.now(datetime.UTC) - 
started).total_seconds()
+                    task_duration = (datetime.datetime.now(datetime.UTC) - 
task.started).total_seconds()
                     if task_duration > self.max_task_seconds:
-                        await self.terminate_long_running_task(session, 
worker, task_id, pid)
+                        await self.terminate_long_running_task(task, worker, 
task.id, pid)
                         return True
 
                     return False
@@ -323,26 +297,28 @@ class WorkerManager:
                 await self.spawn_worker()
             _LOGGER.info(f"Worker pool restored to {len(self.workers)} 
workers")
 
-    async def reset_broken_tasks(self, exited_pids: list[int]) -> None:
+    async def reset_broken_tasks(self) -> None:
         """Reset any tasks that were being processed by exited workers."""
         try:
-            async with db.create_async_db_session() as session:
-                async with session.begin():
-                    # Generate named parameters for each PID
-                    placeholders = ",".join(f":pid_{i}" for i in range(len(exited_pids)))
-                    params = {f"pid_{i}": pid for i, pid in enumerate(exited_pids)}
-
-                    # Execute update with proper parameter binding
-                    # TODO: Replace with ORM
-                    await session.execute(
-                        sqlalchemy.text(f"""
-                            UPDATE task
-                            SET status = 'QUEUED', started = NULL, pid = NULL
-                            WHERE status = 'ACTIVE'
-                            AND pid IN ({placeholders})
-                        """),
-                        params,
+            async with db.session() as data:
+                async with data.begin():
+                    active_worker_pids = list(self.workers)  # self.workers is keyed by worker pid
+
+                    update_stmt = (
+                        sqlmodel.update(models.Task)
+                        .where(  # ACTIVE tasks whose owning worker pid is no longer in the pool
+                            sqlmodel.and_(
+                                db.validate_instrumented_attribute(models.Task.pid).notin_(active_worker_pids),
+                                models.Task.status == models.TaskStatus.ACTIVE,
+                            )
+                        )
+                        .values(status=models.TaskStatus.QUEUED, started=None, pid=None)
                     )
+
+                    result = await data.execute(update_stmt)
+                    if result.rowcount > 0:
+                        _LOGGER.info(f"Reset {result.rowcount} tasks to state 'QUEUED' as their worker died")
+
         except Exception as e:
             _LOGGER.error(f"Error resetting broken tasks: {e}")
 
@@ -377,7 +353,7 @@ class WorkerProcess:
             # Process no longer exists
             return False
         except PermissionError:
-            # Process exists but we don't have permission to signal it
+            # Process exists, but we don't have permission to signal it
             # This shouldn't happen in our case since we own the process
             _LOGGER.warning(f"Permission error checking process {self.pid}")
             return False


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to