This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-atr-experiments.git
The following commit(s) were added to refs/heads/main by this push:
new d17213c Use a Pydantic class for archive integrity check arguments
d17213c is described below
commit d17213cffa83696c6a999ff6080873236ff8014b
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Mar 10 21:11:28 2025 +0200
Use a Pydantic class for archive integrity check arguments
---
atr/routes/package.py | 3 +-
atr/tasks/archive.py | 19 +++++---
atr/tasks/bulk.py | 12 +++--
atr/tasks/mailtest.py | 20 ++++----
atr/tasks/vote.py | 18 ++++---
atr/worker.py | 132 +++++++-------------------------------------------
6 files changed, 61 insertions(+), 143 deletions(-)
diff --git a/atr/routes/package.py b/atr/routes/package.py
index f6c7ac5..74deaf8 100644
--- a/atr/routes/package.py
+++ b/atr/routes/package.py
@@ -37,6 +37,7 @@ from sqlmodel import select
from werkzeug.datastructures import FileStorage, MultiDict
from werkzeug.wrappers.response import Response
+import atr.tasks.archive as archive
from asfquart.auth import Requirements, require
from asfquart.base import ASFQuartException
from asfquart.session import read as session_read
@@ -578,7 +579,7 @@ async def task_verification_create(db_session:
AsyncSession, package: Package) -
Task(
status=TaskStatus.QUEUED,
task_type="verify_archive_integrity",
- task_args=["releases/" + package.artifact_sha3],
+ task_args=archive.CheckIntegrity(path="releases/" +
package.artifact_sha3).model_dump(),
package_sha3=package.artifact_sha3,
),
Task(
diff --git a/atr/tasks/archive.py b/atr/tasks/archive.py
index 57a0017..3f8ee59 100644
--- a/atr/tasks/archive.py
+++ b/atr/tasks/archive.py
@@ -20,21 +20,28 @@ import os.path
import tarfile
from typing import Any, Final
+from pydantic import BaseModel, Field
+
import atr.tasks.task as task
_LOGGER = logging.getLogger(__name__)
-def check_integrity(args: list[str]) -> tuple[task.Status, str | None,
tuple[Any, ...]]:
+class CheckIntegrity(BaseModel):
+ """Parameters for archive integrity checking."""
+
+ path: str = Field(..., description="Path to the .tar.gz file to check")
+ chunk_size: int = Field(default=4096, description="Size of chunks to read
when checking the file")
+
+
+def check_integrity(args: dict[str, Any]) -> tuple[task.Status, str | None,
tuple[Any, ...]]:
"""Check the integrity of a .tar.gz file."""
# TODO: We should standardise the "ERROR" mechanism here in the data
# Then we can have a single task wrapper for all tasks
# TODO: We should use task.TaskError as standard, and maybe typeguard each
function
- # First argument should be the path, second is optional chunk_size
- path = args[0]
- chunk_size = int(args[1]) if len(args) > 1 else 4096
- task_results = task.results_as_tuple(_check_integrity_core(path,
chunk_size))
- _LOGGER.info(f"Verified {args} and computed size {task_results[0]}")
+ data = CheckIntegrity(**args)
+ task_results = task.results_as_tuple(_check_integrity_core(data.path,
data.chunk_size))
+ _LOGGER.info(f"Verified {data.path} and computed size {task_results[0]}")
return task.COMPLETED, None, task_results
diff --git a/atr/tasks/bulk.py b/atr/tasks/bulk.py
index f6e65ee..2654920 100644
--- a/atr/tasks/bulk.py
+++ b/atr/tasks/bulk.py
@@ -29,6 +29,8 @@ import aiohttp
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker,
create_async_engine
+import atr.tasks.task as task
+
# Configure detailed logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -280,7 +282,7 @@ def database_progress_percentage_calculate(progress:
tuple[int, int] | None) ->
return percentage
-def download(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def download(args: list[str]) -> tuple[task.Status, str | None, tuple[Any,
...]]:
"""Download bulk package from URL."""
# Returns (status, error, result)
# This is the main task entry point, called by worker.py
@@ -294,10 +296,10 @@ def download(args: list[str]) -> tuple[str, str | None,
tuple[Any, ...]]:
except Exception as e:
logger.exception(f"Error in download function: {e}")
# Return a tuple with a dictionary that matches what the template
expects
- return "FAILED", str(e), ({"message": f"Error: {e}", "progress": 0},)
+ return task.FAILED, str(e), ({"message": f"Error: {e}", "progress":
0},)
-def download_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any,
...]]:
+def download_core(args_list: list[str]) -> tuple[task.Status, str | None,
tuple[Any, ...]]:
"""Download bulk package from URL."""
logger.info("Starting download_core")
try:
@@ -329,7 +331,7 @@ def download_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...
# Return a result dictionary
# This matches what we have in templates/release-bulk.html
return (
- "COMPLETED",
+ task.COMPLETED,
None,
(
{
@@ -345,7 +347,7 @@ def download_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...
except Exception as e:
logger.exception(f"Error in download_core: {e}")
return (
- "FAILED",
+ task.FAILED,
str(e),
(
{
diff --git a/atr/tasks/mailtest.py b/atr/tasks/mailtest.py
index 72c7f27..81d9e25 100644
--- a/atr/tasks/mailtest.py
+++ b/atr/tasks/mailtest.py
@@ -20,6 +20,8 @@ import os
from dataclasses import dataclass
from typing import Any
+import atr.tasks.task as task
+
# Configure detailed logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -81,7 +83,7 @@ class Args:
return args_obj
-def send(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def send(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]]:
"""Send a test email."""
logger.info(f"Sending with args: {args}")
try:
@@ -91,10 +93,10 @@ def send(args: list[str]) -> tuple[str, str | None,
tuple[Any, ...]]:
return status, error, result
except Exception as e:
logger.exception(f"Error in send function: {e}")
- return "FAILED", str(e), tuple()
+ return task.FAILED, str(e), tuple()
-def send_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def send_core(args_list: list[str]) -> tuple[task.Status, str | None,
tuple[Any, ...]]:
"""Send a test email."""
import asyncio
@@ -137,17 +139,17 @@ def send_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...]]:
if not tooling_pmc:
error_msg = "Tooling PMC not found in database"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
if domain != "apache.org":
error_msg = f"Email domain must be apache.org, got {domain}"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
if local_part not in tooling_pmc.pmc_members:
error_msg = f"Email recipient {local_part} is not a member of
the tooling PMC"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
logger.info(f"Recipient {email_recipient} is a tooling PMC member,
allowed")
@@ -163,7 +165,7 @@ def send_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...]]:
except Exception as e:
error_msg = f"Failed to load DKIM key: {e}"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
event = atr.mail.ArtifactEvent(
artifact_name=args.artifact_name,
@@ -173,8 +175,8 @@ def send_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...]]:
atr.mail.send(event)
logger.info(f"Email sent successfully to {args.email_recipient}")
- return "COMPLETED", None, tuple()
+ return task.COMPLETED, None, tuple()
except Exception as e:
logger.exception(f"Error in send_core: {e}")
- return "FAILED", str(e), tuple()
+ return task.FAILED, str(e), tuple()
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index 1227bf5..cfca283 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -22,6 +22,8 @@ from dataclasses import dataclass
from datetime import UTC
from typing import Any
+import atr.tasks.task as task
+
# Configure detailed logging
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
@@ -98,7 +100,7 @@ class Args:
return args_obj
-def initiate(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def initiate(args: list[str]) -> tuple[task.Status, str | None, tuple[Any,
...]]:
"""Initiate a vote for a release."""
logger.info(f"Initiating vote with args: {args}")
try:
@@ -108,10 +110,10 @@ def initiate(args: list[str]) -> tuple[str, str | None,
tuple[Any, ...]]:
return status, error, result
except Exception as e:
logger.exception(f"Error in initiate function: {e}")
- return "FAILED", str(e), tuple()
+ return task.FAILED, str(e), tuple()
-def initiate_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any,
...]]:
+def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None,
tuple[Any, ...]]:
"""Get arguments, create an email, and then send it to the recipient."""
import atr.mail
from atr.db.service import get_release_by_key_sync
@@ -141,7 +143,7 @@ def initiate_core(args_list: list[str]) -> tuple[str, str |
None, tuple[Any, ...
if not release:
error_msg = f"Release with key {args.release_key} not found"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
# GPG key ID, just for testing the UI
gpg_key_id = args.gpg_key_id
@@ -166,13 +168,13 @@ def initiate_core(args_list: list[str]) -> tuple[str, str
| None, tuple[Any, ...
except Exception as e:
error_msg = f"Failed to load DKIM key: {e}"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
# Get PMC and product details
if release.pmc is None:
error_msg = "Release has no associated PMC"
logger.error(error_msg)
- return "FAILED", error_msg, tuple()
+ return task.FAILED, error_msg, tuple()
pmc_name = release.pmc.project_name
pmc_display = release.pmc.display_name
@@ -235,7 +237,7 @@ Thanks,
# TODO: Update release status to indicate a vote is in progress
# This would involve updating the database with the vote details
somehow
return (
- "COMPLETED",
+ task.COMPLETED,
None,
(
{
@@ -250,4 +252,4 @@ Thanks,
except Exception as e:
logger.exception(f"Error in initiate_core: {e}")
- return "FAILED", str(e), tuple()
+ return task.FAILED, str(e), tuple()
diff --git a/atr/worker.py b/atr/worker.py
index 3a1c1c7..a816f72 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -191,106 +191,6 @@ def task_result_process(
)
-def task_bulk_download_debug(args: list[str] | dict) -> tuple[str, str | None,
tuple[Any, ...]]:
- # This was a debug function; pay no attention to this
- # TODO: Remove once we're sure everything is working
- _LOGGER.info(f"Bulk download debug task received args: {args}")
-
- try:
- # Extract parameters from args (support both list and dict inputs)
- if isinstance(args, list):
- # If it's a list, the release_key is the first element
- # release_key = args[0] if args else "unknown"
- url = args[1] if len(args) > 1 else "unknown"
- file_types = args[2] if len(args) > 2 else []
- require_signatures = args[3] if len(args) > 3 else False
- elif isinstance(args, dict):
- # release_key = args.get("release_key", "unknown")
- url = args.get("url", "unknown")
- file_types = args.get("file_types", [])
- require_signatures = args.get("require_signatures", False)
- # else:
- # _LOGGER.warning(f"Unexpected args type: {type(args)}")
- # release_key = "unknown"
- # url = "unknown"
- # file_types = []
- # require_signatures = False
-
- # Progress messages to display over time
- progress_messages = [
- f"Connecting to {url}...",
- f"Connected to {url}. Scanning for {', '.join(file_types) if
file_types else 'all'} files...",
- "Found 15 files matching criteria. Downloading...",
- "Downloaded 7/15 files (47%)...",
- "Downloaded 15/15 files (100%). Processing...",
- ]
-
- # Get task_id from the current process
- current_pid = os.getpid()
- task_id = None
-
- # Get the task ID for the current process
- with db.create_sync_db_session() as session:
- result = session.execute(
- sqlalchemy.text("SELECT id FROM task WHERE pid = :pid AND
status = 'ACTIVE'"), {"pid": current_pid}
- )
- task_row = result.first()
- if task_row:
- task_id = task_row[0]
-
- if not task_id:
- _LOGGER.warning(f"Could not find active task for PID
{current_pid}")
-
- # Process each progress message with a delay
- for i, message in enumerate(progress_messages):
- progress_pct = (i + 1) * 20
-
- update = {
- "message": message,
- "progress": progress_pct,
- "url": url,
- "timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
- }
-
- # Log the progress
- _LOGGER.info(f"Progress update {i + 1}/{len(progress_messages)}:
{message} ({progress_pct}%)")
-
- # Update the database with the current progress if we have a
task_id
- if task_id:
- with db.create_sync_db_session() as session:
- # Update the task with the current progress message
- with session.begin():
- session.execute(
- sqlalchemy.text("""
- UPDATE task
- SET result = :result
- WHERE id = :task_id AND status = 'ACTIVE'
- """),
- {"task_id": task_id, "result": json.dumps(update)},
- )
-
- # Sleep before the next update, except for the last one
- if i < len(progress_messages) - 1:
- time.sleep(2.75)
-
- final_result = {
- "message": f"Successfully processed {url}",
- "progress": 100,
- "files_processed": 15,
- "files_downloaded": 15,
- "url": url,
- "file_types": file_types,
- "require_signatures": require_signatures,
- "completed_at": datetime.datetime.now(datetime.UTC).isoformat(),
- }
-
- return "COMPLETED", None, (final_result,)
-
- except Exception as e:
- _LOGGER.exception(f"Error in bulk download debug task: {e}")
- return "FAILED", str(e), ({"error": str(e), "message": f"Error:
{e!s}", "progress": 0},)
-
-
def task_process(task_id: int, task_type: str, task_args: str) -> None:
"""Process a claimed task."""
_LOGGER.info(f"Processing task {task_id} ({task_type}) with args
{task_args}")
@@ -299,8 +199,10 @@ def task_process(task_id: int, task_type: str, task_args:
str) -> None:
# Map task types to their handler functions
# TODO: We should use a decorator to register these automatically
- task_handlers = {
+ dict_task_handlers = {
"verify_archive_integrity": archive.check_integrity,
+ }
+ list_task_handlers = {
"verify_archive_structure": archive.check_structure,
"verify_license_files": license.check_files,
"verify_signature": signature.check,
@@ -312,20 +214,22 @@ def task_process(task_id: int, task_type: str, task_args:
str) -> None:
"vote_initiate": vote.initiate,
}
- handler = task_handlers.get(task_type)
- if not handler:
- msg = f"Unknown task type: {task_type}, {task_handlers.keys()}"
- _LOGGER.error(msg)
- raise Exception(msg)
-
- raw_status, error, task_results = handler(args)
- if isinstance(raw_status, task.Status):
- status = raw_status.value.upper()
- elif isinstance(raw_status, str):
- status = raw_status.upper()
+ if isinstance(args, dict):
+ dict_handler = dict_task_handlers.get(task_type)
+ if not dict_handler:
+ msg = f"Unknown task type: {task_type}"
+ _LOGGER.error(msg)
+ raise Exception(msg)
+ status, error, task_results = dict_handler(args)
else:
- raise Exception(f"Unknown task status type: {type(raw_status)}")
- task_result_process(task_id, task_results, status=status, error=error)
+ list_handler = list_task_handlers.get(task_type)
+ if not list_handler:
+ msg = f"Unknown task type: {task_type}"
+ _LOGGER.error(msg)
+ raise Exception(msg)
+ status, error, task_results = list_handler(args)
+
+ task_result_process(task_id, task_results,
status=status.value.upper(), error=error)
except Exception as e:
task_error_handle(task_id, e)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]