This is an automated email from the ASF dual-hosted git repository.
tn pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git
The following commit(s) were added to refs/heads/main by this push:
new b41da93 clean up worker manager, ensure that datetimes are stored /
retrieved in UTC
b41da93 is described below
commit b41da934825b4b5eed28269778bba397d14536cd
Author: Thomas Neidhart <[email protected]>
AuthorDate: Mon Mar 31 17:48:38 2025 +0200
clean up worker manager, ensure that datetimes are stored / retrieved in UTC
---
atr/db/models.py | 48 +++++++++++++++++++++++--
atr/manager.py | 108 ++++++++++++++++++++++---------------------------------
2 files changed, 87 insertions(+), 69 deletions(-)
diff --git a/atr/db/models.py b/atr/db/models.py
index e27b4a6..e8530a1 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -32,6 +32,39 @@ import sqlmodel
import atr.db as db
+class UTCDateTime(sqlalchemy.types.TypeDecorator):
+    """
+    A custom column type to store datetime in sqlite.
+
+    As sqlite does not have timezone support, we ensure that all datetimes stored
+    within sqlite are converted to UTC. When retrieved, the datetimes are constructed
+    as offset-aware datetime with UTC as their timezone.
+    """
+
+    impl = sqlalchemy.types.TIMESTAMP(timezone=True)
+
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):  # type: ignore
+        """Return value converted to UTC; raise ValueError unless it is None or an aware datetime."""
+        # explicit None check: any other value, even a falsy one, must pass validation
+        if value is None:
+            return None
+        if not isinstance(value, datetime.datetime):
+            raise ValueError(f"unexpected value type {type(value)}")
+        if value.tzinfo is None:
+            raise ValueError("encountered offset-naive datetime")
+
+        # store the datetime in UTC in sqlite as it does not support timezones
+        return value.astimezone(datetime.UTC)
+
+    def process_result_value(self, value, dialect):  # type: ignore
+        """Attach UTC as timezone to a loaded datetime; return any other value unchanged."""
+        if isinstance(value, datetime.datetime):
+            return value.replace(tzinfo=datetime.UTC)
+        return value
+
+
class UserRole(str, enum.Enum):
COMMITTEE_MEMBER = "committee_member"
RELEASE_MANAGER = "release_manager"
@@ -258,10 +291,19 @@ class Task(sqlmodel.SQLModel, table=True):
status: TaskStatus = sqlmodel.Field(default=TaskStatus.QUEUED, index=True)
task_type: str
task_args: Any =
sqlmodel.Field(sa_column=sqlalchemy.Column(sqlalchemy.JSON))
- added: datetime.datetime = sqlmodel.Field(default_factory=lambda:
datetime.datetime.now(datetime.UTC), index=True)
- started: datetime.datetime | None = None
+ added: datetime.datetime = sqlmodel.Field(
+ default_factory=lambda: datetime.datetime.now(datetime.UTC),
+ sa_column=sqlalchemy.Column(UTCDateTime, index=True),
+ )
+ started: datetime.datetime | None = sqlmodel.Field(
+ default=None,
+ sa_column=sqlalchemy.Column(UTCDateTime),
+ )
pid: int | None = None
- completed: datetime.datetime | None = None
+ completed: datetime.datetime | None = sqlmodel.Field(
+ default=None,
+ sa_column=sqlalchemy.Column(UTCDateTime),
+ )
result: Any | None = sqlmodel.Field(default=None,
sa_column=sqlalchemy.Column(sqlalchemy.JSON))
error: str | None = None
release_name: str | None = sqlmodel.Field(default=None,
foreign_key="release.name")
diff --git a/atr/manager.py b/atr/manager.py
index 086a094..f0dcafe 100644
--- a/atr/manager.py
+++ b/atr/manager.py
@@ -17,6 +17,8 @@
"""Worker process manager."""
+from __future__ import annotations
+
import asyncio
import datetime
import io
@@ -24,11 +26,12 @@ import logging
import os
import signal
import sys
-from typing import Final, Optional
+from typing import Final
-import sqlalchemy
+import sqlmodel
import atr.db as db
+import atr.db.models as models
# Configure logging
logging.basicConfig(
@@ -44,7 +47,7 @@ global_worker_debug = False
# Global worker manager instance
-# Can't use "StringClass" | None, must use Optional["StringClass"] for forward references
+# Forward references work in annotations here thanks to "from __future__ import annotations"
-global_worker_manager: Optional["WorkerManager"] = None
+global_worker_manager: WorkerManager | None = None
class WorkerManager:
@@ -195,7 +198,7 @@ class WorkerManager:
except asyncio.CancelledError:
break
except Exception as e:
- _LOGGER.error(f"Error in worker monitor: {e}")
+ _LOGGER.error(f"Error in worker monitor: {e}", exc_info=e)
# TODO: How long should we wait before trying again?
await asyncio.sleep(1.0)
@@ -220,7 +223,7 @@ class WorkerManager:
for pid in exited_workers:
self.workers.pop(pid, None)
- # # Check for active tasks
+ # Check for active tasks
# try:
# async with get_session() as session:
# result = await session.execute(
@@ -238,12 +241,11 @@ class WorkerManager:
# Spawn new workers if needed
await self.maintain_worker_pool()
- # Reset any tasks that were being processed by exited workers
- if exited_workers:
- await self.reset_broken_tasks(exited_workers)
+ # Reset any tasks that were being processed by now inactive workers
+ await self.reset_broken_tasks()
async def terminate_long_running_task(
- self, session: sqlalchemy.ext.asyncio.AsyncSession, worker:
"WorkerProcess", task_id: int, pid: int
+ self, task: models.Task, worker: WorkerProcess, task_id: int, pid: int
) -> None:
"""
Terminate a task that has been running for too long.
@@ -251,20 +253,10 @@ class WorkerManager:
"""
try:
# Mark the task as failed
- # TODO: Replace with ORM
- await session.execute(
- sqlalchemy.text("""
- UPDATE task
- SET status = 'FAILED', completed = :now, error = :error
- WHERE id = :task_id
- AND status = 'ACTIVE'
- """),
- {
- "now": datetime.datetime.now(datetime.UTC),
- "task_id": task_id,
- "error": f"Task terminated after exceeding time limit of
{self.max_task_seconds} seconds",
- },
- )
+ task.status = models.TaskStatus.FAILED
+ task.completed = datetime.datetime.now(datetime.UTC)
+ task.error = f"Task terminated after exceeding time limit of
{self.max_task_seconds} seconds"
+
if worker.pid:
os.kill(worker.pid, signal.SIGTERM)
_LOGGER.info(f"Worker {pid} terminated after processing task
{task_id} for > {self.max_task_seconds}s")
@@ -273,39 +265,21 @@ class WorkerManager:
except Exception as e:
_LOGGER.error(f"Error stopping long-running worker {pid}: {e}")
- async def check_task_duration(self, pid: int, worker: "WorkerProcess") ->
bool:
+ async def check_task_duration(self, pid: int, worker: WorkerProcess) ->
bool:
"""
Check if a worker has been processing its task for too long.
Returns True if the worker has been terminated.
"""
try:
- async with db.create_async_db_session() as session:
- async with session.begin():
- # TODO: Replace with ORM
- result = await session.execute(
- sqlalchemy.text("""
- SELECT id, started FROM task
- WHERE status = 'ACTIVE'
- AND pid = :pid
- """),
- {"pid": pid},
- )
- task = result.first()
- if not task or not task[1]:
+ async with db.session() as data:
+ async with data.begin():
+ task = await data.task(pid=pid,
status=models.TaskStatus.ACTIVE).get()
+ if not task or not task.started:
return False
- task_id, started = task
- # Convert started to datetime if it's a string
- if isinstance(started, str):
- try:
- started =
datetime.datetime.fromisoformat(started.replace("Z", "+00:00"))
- except ValueError:
- _LOGGER.error(f"Could not parse started time
'{started}' for task {task_id}")
- return False
-
- task_duration = (datetime.datetime.now(datetime.UTC) -
started).total_seconds()
+ task_duration = (datetime.datetime.now(datetime.UTC) -
task.started).total_seconds()
if task_duration > self.max_task_seconds:
- await self.terminate_long_running_task(session,
worker, task_id, pid)
+ await self.terminate_long_running_task(task, worker,
task.id, pid)
return True
return False
@@ -323,26 +297,28 @@ class WorkerManager:
await self.spawn_worker()
_LOGGER.info(f"Worker pool restored to {len(self.workers)}
workers")
-    async def reset_broken_tasks(self, exited_pids: list[int]) -> None:
-        """Reset any tasks that were being processed by exited workers."""
+    async def reset_broken_tasks(self) -> None:
+        """Requeue every ACTIVE task whose pid is not the pid of a currently tracked worker."""
         try:
-            async with db.create_async_db_session() as session:
-                async with session.begin():
-                    # Generate named parameters for each PID
-                    placeholders = ",".join(f":pid_{i}" for i in range(len(exited_pids)))
-                    params = {f"pid_{i}": pid for i, pid in enumerate(exited_pids)}
-
-                    # Execute update with proper parameter binding
-                    # TODO: Replace with ORM
-                    await session.execute(
-                        sqlalchemy.text(f"""
-                            UPDATE task
-                            SET status = 'QUEUED', started = NULL, pid = NULL
-                            WHERE status = 'ACTIVE'
-                            AND pid IN ({placeholders})
-                        """),
-                        params,
+            async with db.session() as data:
+                async with data.begin():
+                    active_worker_pids = list(self.workers)
+                    # match on Task.pid (not Task.id): self.workers is keyed by worker pid
+                    update_stmt = (
+                        sqlmodel.update(models.Task)
+                        .where(
+                            sqlmodel.and_(
+                                db.validate_instrumented_attribute(models.Task.pid).notin_(active_worker_pids),
+                                models.Task.status == models.TaskStatus.ACTIVE,
+                            )
+                        )
+                        .values(status=models.TaskStatus.QUEUED, started=None, pid=None)
                     )
+
+                    result = await data.execute(update_stmt)
+                    if result.rowcount > 0:
+                        _LOGGER.info(f"Reset {result.rowcount} tasks to state 'QUEUED' as their worker died")
+
         except Exception as e:
             _LOGGER.error(f"Error resetting broken tasks: {e}")
@@ -377,7 +353,7 @@ class WorkerProcess:
# Process no longer exists
return False
except PermissionError:
- # Process exists but we don't have permission to signal it
+ # Process exists, but we don't have permission to signal it
# This shouldn't happen in our case since we own the process
_LOGGER.warning(f"Permission error checking process {self.pid}")
return False
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]