This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit a9f6090275d49dee785110b49f6b36a95054e676 Author: Alastair McFarlane <[email protected]> AuthorDate: Thu Jan 15 16:13:31 2026 +0000 Add scheduled column for tasks, allow asf_uid to be passed in task arguments --- atr/models/sql.py | 4 + atr/server.py | 39 +++----- atr/tasks/__init__.py | 14 +-- atr/tasks/gha.py | 14 +-- atr/tasks/metadata.py | 6 +- atr/worker.py | 118 +++++++++++++----------- migrations/versions/0040_2026.01.15_31d91cc5.py | 31 +++++++ 7 files changed, 130 insertions(+), 96 deletions(-) diff --git a/atr/models/sql.py b/atr/models/sql.py index 0a80912..5891a04 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -358,6 +358,10 @@ class Task(sqlmodel.SQLModel, table=True): default_factory=lambda: datetime.datetime.now(datetime.UTC), sa_column=sqlalchemy.Column(UTCDateTime, index=True), ) + scheduled: datetime.datetime | None = sqlmodel.Field( + default=None, + sa_column=sqlalchemy.Column(UTCDateTime, index=True), + ) started: datetime.datetime | None = sqlmodel.Field( default=None, sa_column=sqlalchemy.Column(UTCDateTime), diff --git a/atr/server.py b/atr/server.py index d64f7d1..9dde3d0 100644 --- a/atr/server.py +++ b/atr/server.py @@ -208,7 +208,8 @@ def _app_setup_lifecycle(app: base.QuartApp) -> None: await worker_manager.start() # Register recurring tasks (metadata updates, workflow status checks, etc.) - await _register_recurrent_tasks() + scheduler_task = asyncio.create_task(_register_recurrent_tasks()) + app.extensions["scheduler_task"] = scheduler_task await _initialise_test_environment() @@ -250,13 +251,13 @@ def _app_setup_lifecycle(app: base.QuartApp) -> None: await worker_manager.stop() # Stop the metadata scheduler - # metadata_scheduler = app.extensions.get("metadata_scheduler") - # if metadata_scheduler: - # metadata_scheduler.cancel() - # try: - # await metadata_scheduler - # except asyncio.CancelledError: - # ... 
+ scheduler_task = app.extensions.get("scheduler_task") + if scheduler_task: + scheduler_task.cancel() + try: + await scheduler_task + except asyncio.CancelledError: + ... ssh_server = app.extensions.get("ssh_server") if ssh_server: @@ -514,31 +515,15 @@ async def _initialise_test_environment() -> None: await data.commit() -# -# async def _metadata_update_scheduler() -> None: -# """Periodically schedule remote metadata updates.""" -# # Wait one minute to allow the server to start -# await asyncio.sleep(60) -# -# while True: -# try: -# task = await tasks.metadata_update(asf_uid="system") -# log.info(f"Scheduled remote metadata update with ID {task.id}") -# except Exception as e: -# log.exception(f"Failed to schedule remote metadata update: {e!s}") -# -# # Schedule next update in 24 hours -# await asyncio.sleep(86400) - - async def _register_recurrent_tasks() -> None: """Schedule recurring tasks""" - # Wait one minute to allow the server to start - await asyncio.sleep(30) + # Start scheduled tasks 5 min after server start + await asyncio.sleep(300) try: await tasks.clear_scheduled() metadata = await tasks.metadata_update(asf_uid="system", schedule_next=True) log.info(f"Scheduled remote metadata update with ID {metadata.id}") + await asyncio.sleep(60) workflow = await tasks.workflow_update(asf_uid="system", schedule_next=True) log.info(f"Scheduled workflow status update with ID {workflow.id}") diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 99f2328..b575a83 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -74,7 +74,7 @@ async def clear_scheduled(caller_data: db.Session | None = None): ] ), via(sql.Task.status) == sql.TaskStatus.QUEUED, - via(sql.Task.added) > now, + sqlmodel.or_(via(sql.Task.scheduled).is_(None), via(sql.Task.scheduled) > now), ) await data.execute(delete_stmt) @@ -181,9 +181,9 @@ async def metadata_update( schedule_next: bool = False, ) -> sql.Task: """Queue a metadata update task.""" - args = 
metadata.Update(asf_uid=asf_uid, next_schedule=0) + args = metadata.Update(asf_uid=asf_uid, next_schedule_seconds=0) if schedule_next: - args.next_schedule = 60 * 24 + args.next_schedule_seconds = 60 * 60 * 24 async with db.ensure_session(caller_data) as data: task = sql.Task( status=sql.TaskStatus.QUEUED, @@ -194,7 +194,7 @@ async def metadata_update( primary_rel_path=None, ) if schedule: - task.added = schedule + task.scheduled = schedule data.add(task) await data.commit() await data.flush() @@ -302,9 +302,9 @@ async def workflow_update( schedule_next: bool = False, ) -> sql.Task: """Queue a workflow status update task.""" - args = gha.WorkflowStatusCheck(next_schedule=0, run_id=0) + args = gha.WorkflowStatusCheck(next_schedule_seconds=0, run_id=0) if schedule_next: - args.next_schedule = 2 + args.next_schedule_seconds = 2 * 60 async with db.ensure_session(caller_data) as data: task = sql.Task( status=sql.TaskStatus.QUEUED, @@ -315,7 +315,7 @@ async def workflow_update( primary_rel_path=None, ) if schedule: - task.added = schedule + task.scheduled = schedule data.add(task) await data.commit() await data.flush() diff --git a/atr/tasks/gha.py b/atr/tasks/gha.py index 9ccaef8..1e8eda1 100644 --- a/atr/tasks/gha.py +++ b/atr/tasks/gha.py @@ -63,7 +63,7 @@ class DistributionWorkflow(schema.Strict): class WorkflowStatusCheck(schema.Strict): run_id: int | None = schema.description("Run ID") - next_schedule: int = pydantic.Field(default=0, description="The next scheduled time (in minutes)") + next_schedule_seconds: int = pydantic.Field(default=0, description="The next scheduled time") @checks.with_model(DistributionWorkflow) @@ -123,7 +123,7 @@ async def trigger_workflow(args: DistributionWorkflow, *, task_id: int | None = @checks.with_model(WorkflowStatusCheck) -async def status_check(args: WorkflowStatusCheck) -> DistributionWorkflowStatus: +async def status_check(args: WorkflowStatusCheck, asf_uid: str) -> DistributionWorkflowStatus: """Check remote workflow 
statuses.""" headers = {"Accept": "application/vnd.github+json", "Authorization": f"Bearer {config.get().GITHUB_TOKEN}"} @@ -182,7 +182,7 @@ async def status_check(args: WorkflowStatusCheck) -> DistributionWorkflowStatus: ) # Schedule next update - await _schedule_next(args) + await _schedule_next(args, asf_uid) return results.DistributionWorkflowStatus( kind="distribution_workflow_status", @@ -274,10 +274,10 @@ async def _request_and_retry( return None -async def _schedule_next(args: WorkflowStatusCheck): - if args.next_schedule: - next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=args.next_schedule) - await tasks.workflow_update("system", schedule=next_schedule, schedule_next=True) +async def _schedule_next(args: WorkflowStatusCheck, asf_uid: str) -> None: + if args.next_schedule_seconds: + next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=args.next_schedule_seconds) + await tasks.workflow_update(asf_uid, schedule=next_schedule, schedule_next=True) log.info( f"Scheduled next workflow status update for: {next_schedule.strftime('%Y-%m-%d %H:%M:%S')}", ) diff --git a/atr/tasks/metadata.py b/atr/tasks/metadata.py index ba2b6e5..81e070a 100644 --- a/atr/tasks/metadata.py +++ b/atr/tasks/metadata.py @@ -32,7 +32,7 @@ class Update(schema.Strict): """Arguments for the task to update metadata from remote data sources.""" asf_uid: str = schema.description("The ASF UID of the user triggering the update") - next_schedule: int = pydantic.Field(default=0, description="The next scheduled time (in minutes)") + next_schedule_seconds: int = pydantic.Field(default=0, description="The next scheduled time") class UpdateError(Exception): @@ -52,8 +52,8 @@ async def update(args: Update) -> results.Results | None: ) # Schedule next update - if args.next_schedule and args.next_schedule > 0: - next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(minutes=args.next_schedule) + if args.next_schedule_seconds and 
args.next_schedule_seconds > 0: + next_schedule = datetime.datetime.now(datetime.UTC) + datetime.timedelta(seconds=args.next_schedule_seconds) await tasks.metadata_update(args.asf_uid, schedule=next_schedule, schedule_next=True) log.info( f"Scheduled next metadata update for: {next_schedule.strftime('%Y-%m-%d %H:%M:%S')}", diff --git a/atr/worker.py b/atr/worker.py index db3a9ac..aa19e28 100644 --- a/atr/worker.py +++ b/atr/worker.py @@ -28,6 +28,7 @@ import inspect import os import signal import traceback +from collections.abc import Awaitable, Callable from typing import Any, Final import sqlmodel @@ -103,12 +104,13 @@ def _setup_logging() -> None: # Task functions -async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | None: +async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any], str] | None: """ Attempt to claim the oldest unclaimed task. Returns (task_id, task_type, task_args) if successful. Returns None if no tasks are available. """ + via = sql.validate_instrumented_attribute async with db.session() as data: async with data.begin(): # Get the ID of the oldest queued task @@ -117,10 +119,12 @@ async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No .where( sqlmodel.and_( sql.Task.status == task.QUEUED, - sql.Task.added <= datetime.datetime.now(datetime.UTC) - datetime.timedelta(seconds=2), + sqlmodel.or_( + via(sql.Task.scheduled).is_(None), sql.Task.scheduled <= datetime.datetime.now(datetime.UTC) + ), ) ) - .order_by(sql.validate_instrumented_attribute(sql.Task.added).asc()) + .order_by(via(sql.Task.added).asc()) .limit(1) ) @@ -135,6 +139,7 @@ async def _task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No sql.validate_instrumented_attribute(sql.Task.id), sql.validate_instrumented_attribute(sql.Task.task_type), sql.validate_instrumented_attribute(sql.Task.task_args), + sql.validate_instrumented_attribute(sql.Task.asf_uid), ) ) @@ -142,14 +147,14 @@ async def 
_task_next_claim() -> tuple[int, str, list[str] | dict[str, Any]] | No claimed_task = result.first() if claimed_task: - task_id, task_type, task_args = claimed_task + task_id, task_type, task_args, asf_uid = claimed_task log.info(f"Claimed task {task_id} ({task_type}) with args {task_args}") - return task_id, task_type, task_args + return task_id, task_type, task_args, asf_uid return None -async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any]) -> None: +async def _task_process(task_id: int, task_type: str, task_args: list[str] | dict[str, Any], asf_uid: str) -> None: """Process a claimed task.""" log.info(f"Processing task {task_id} ({task_type}) with raw args {task_args}") try: @@ -167,52 +172,15 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic # Check whether the handler is a check handler if (len(params) == 1) and (params[0].annotation == checks.FunctionArguments): - log.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details") - async with db.session() as data: - task_obj = await data.task(id=task_id).demand( - ValueError(f"Task {task_id} disappeared during processing") - ) - - # Validate required fields from the Task object itself - if task_obj.project_name is None: - raise ValueError(f"Task {task_id} is missing required project_name") - if task_obj.version_name is None: - raise ValueError(f"Task {task_id} is missing required version_name") - if task_obj.revision_number is None: - raise ValueError(f"Task {task_id} is missing required revision_number") - - if not isinstance(task_args, dict): - raise TypeError( - f"Task {task_id} ({task_type}) has non-dict raw args" - f" {task_args} which should represent keyword_args" - ) - - async def recorder_factory() -> checks.Recorder: - return await checks.Recorder.create( - checker=handler, - project_name=task_obj.project_name or "", - version_name=task_obj.version_name or "", - 
revision_number=task_obj.revision_number or "", - primary_rel_path=task_obj.primary_rel_path, - ) - - function_arguments = checks.FunctionArguments( - recorder=recorder_factory, - asf_uid=task_obj.asf_uid, - project_name=task_obj.project_name or "", - version_name=task_obj.version_name or "", - revision_number=task_obj.revision_number, - primary_rel_path=task_obj.primary_rel_path, - extra_args=task_args, - ) - log.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}") - handler_result = await handler(function_arguments) + handler_result = await _execute_check_task(handler, task_args, task_id, task_type) else: # Otherwise, it's not a check handler - if sig.parameters.get("task_id") is None: - handler_result = await handler(task_args) - else: - handler_result = await handler(task_args, task_id=task_id) + additional_kwargs = {} + if sig.parameters.get("task_id") is not None: + additional_kwargs["task_id"] = task_id + if sig.parameters.get("asf_uid") is not None: + additional_kwargs["asf_uid"] = asf_uid + handler_result = await handler(task_args, **additional_kwargs) task_results = handler_result status = task.COMPLETED @@ -226,6 +194,52 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic await _task_result_process(task_id, task_results, status, error) +async def _execute_check_task( + handler: Callable[..., Awaitable[results.Results | None]], + task_args: list[str] | dict[str, Any], + task_id: int, + task_type: str, +) -> results.Results | None: + log.debug(f"Handler {handler.__name__} expects checks.FunctionArguments, fetching full task details") + async with db.session() as data: + task_obj = await data.task(id=task_id).demand(ValueError(f"Task {task_id} disappeared during processing")) + + # Validate required fields from the Task object itself + if task_obj.project_name is None: + raise ValueError(f"Task {task_id} is missing required project_name") + if task_obj.version_name is None: + raise 
ValueError(f"Task {task_id} is missing required version_name") + if task_obj.revision_number is None: + raise ValueError(f"Task {task_id} is missing required revision_number") + + if not isinstance(task_args, dict): + raise TypeError( + f"Task {task_id} ({task_type}) has non-dict raw args {task_args} which should represent keyword_args" + ) + + async def recorder_factory() -> checks.Recorder: + return await checks.Recorder.create( + checker=handler, + project_name=task_obj.project_name or "", + version_name=task_obj.version_name or "", + revision_number=task_obj.revision_number or "", + primary_rel_path=task_obj.primary_rel_path, + ) + + function_arguments = checks.FunctionArguments( + recorder=recorder_factory, + asf_uid=task_obj.asf_uid, + project_name=task_obj.project_name or "", + version_name=task_obj.version_name or "", + revision_number=task_obj.revision_number, + primary_rel_path=task_obj.primary_rel_path, + extra_args=task_args, + ) + log.debug(f"Calling {handler.__name__} with structured arguments: {function_arguments}") + handler_result = await handler(function_arguments) + return handler_result + + async def _task_result_process( task_id: int, task_results: results.Results | None, status: sql.TaskStatus, error: str | None = None ) -> None: @@ -255,8 +269,8 @@ async def _worker_loop_run() -> None: try: task = await _task_next_claim() if task: - task_id, task_type, task_args = task - await _task_process(task_id, task_type, task_args) + task_id, task_type, task_args, asf_uid = task + await _task_process(task_id, task_type, task_args, asf_uid) processed += 1 # Only process max_to_process tasks and then exit # This prevents memory leaks from accumulating diff --git a/migrations/versions/0040_2026.01.15_31d91cc5.py b/migrations/versions/0040_2026.01.15_31d91cc5.py new file mode 100644 index 0000000..ceac8fa --- /dev/null +++ b/migrations/versions/0040_2026.01.15_31d91cc5.py @@ -0,0 +1,31 @@ +"""Add schedule column for tasks + +Revision ID: 
0040_2026.01.15_31d91cc5 +Revises: 0039_2026.01.14_cd44f0ea +Create Date: 2026-01-15 15:34:00.515650+00:00 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +import atr.models.sql + +# Revision identifiers, used by Alembic +revision: str = "0040_2026.01.15_31d91cc5" +down_revision: str | None = "0039_2026.01.14_cd44f0ea" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.add_column(sa.Column("scheduled", atr.models.sql.UTCDateTime(timezone=True), nullable=True)) + batch_op.create_index(batch_op.f("ix_task_scheduled"), ["scheduled"], unique=False) + + +def downgrade() -> None: + with op.batch_alter_table("task", schema=None) as batch_op: + batch_op.drop_index(batch_op.f("ix_task_scheduled")) + batch_op.drop_column("scheduled") --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
