This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-atr-experiments.git


The following commit(s) were added to refs/heads/main by this push:
     new d17213c  Use a Pydantic class for archive integrity check arguments
d17213c is described below

commit d17213cffa83696c6a999ff6080873236ff8014b
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Mar 10 21:11:28 2025 +0200

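    Use a Pydantic class for archive integrity check arguments

    Also return task.Status constants instead of raw status strings from the
    bulk, mailtest, and vote tasks, remove the debug bulk download task from
    the worker, and split the worker's handler table into handlers that take
    dict arguments and handlers that take list arguments.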
---
 atr/routes/package.py |   3 +-
 atr/tasks/archive.py  |  19 +++++---
 atr/tasks/bulk.py     |  12 +++--
 atr/tasks/mailtest.py |  20 ++++----
 atr/tasks/vote.py     |  18 ++++---
 atr/worker.py         | 132 +++++++-------------------------------------------
 6 files changed, 61 insertions(+), 143 deletions(-)

diff --git a/atr/routes/package.py b/atr/routes/package.py
index f6c7ac5..74deaf8 100644
--- a/atr/routes/package.py
+++ b/atr/routes/package.py
@@ -37,6 +37,7 @@ from sqlmodel import select
 from werkzeug.datastructures import FileStorage, MultiDict
 from werkzeug.wrappers.response import Response
 
+import atr.tasks.archive as archive
 from asfquart.auth import Requirements, require
 from asfquart.base import ASFQuartException
 from asfquart.session import read as session_read
@@ -578,7 +579,7 @@ async def task_verification_create(db_session: 
AsyncSession, package: Package) -
         Task(
             status=TaskStatus.QUEUED,
             task_type="verify_archive_integrity",
-            task_args=["releases/" + package.artifact_sha3],
+            task_args=archive.CheckIntegrity(path="releases/" + 
package.artifact_sha3).model_dump(),
             package_sha3=package.artifact_sha3,
         ),
         Task(
diff --git a/atr/tasks/archive.py b/atr/tasks/archive.py
index 57a0017..3f8ee59 100644
--- a/atr/tasks/archive.py
+++ b/atr/tasks/archive.py
@@ -20,21 +20,28 @@ import os.path
 import tarfile
 from typing import Any, Final
 
+from pydantic import BaseModel, Field
+
 import atr.tasks.task as task
 
 _LOGGER = logging.getLogger(__name__)
 
 
-def check_integrity(args: list[str]) -> tuple[task.Status, str | None, 
tuple[Any, ...]]:
+class CheckIntegrity(BaseModel):
+    """Parameters for archive integrity checking."""
+
+    path: str = Field(..., description="Path to the .tar.gz file to check")
+    chunk_size: int = Field(default=4096, description="Size of chunks to read 
when checking the file")
+
+
+def check_integrity(args: dict[str, Any]) -> tuple[task.Status, str | None, 
tuple[Any, ...]]:
     """Check the integrity of a .tar.gz file."""
     # TODO: We should standardise the "ERROR" mechanism here in the data
     # Then we can have a single task wrapper for all tasks
     # TODO: We should use task.TaskError as standard, and maybe typeguard each 
function
-    # First argument should be the path, second is optional chunk_size
-    path = args[0]
-    chunk_size = int(args[1]) if len(args) > 1 else 4096
-    task_results = task.results_as_tuple(_check_integrity_core(path, 
chunk_size))
-    _LOGGER.info(f"Verified {args} and computed size {task_results[0]}")
+    data = CheckIntegrity(**args)
+    task_results = task.results_as_tuple(_check_integrity_core(data.path, 
data.chunk_size))
+    _LOGGER.info(f"Verified {data.path} and computed size {task_results[0]}")
     return task.COMPLETED, None, task_results
 
 
diff --git a/atr/tasks/bulk.py b/atr/tasks/bulk.py
index f6e65ee..2654920 100644
--- a/atr/tasks/bulk.py
+++ b/atr/tasks/bulk.py
@@ -29,6 +29,8 @@ import aiohttp
 from sqlalchemy import text
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, 
create_async_engine
 
+import atr.tasks.task as task
+
 # Configure detailed logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -280,7 +282,7 @@ def database_progress_percentage_calculate(progress: 
tuple[int, int] | None) ->
     return percentage
 
 
-def download(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def download(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, 
...]]:
     """Download bulk package from URL."""
     # Returns (status, error, result)
     # This is the main task entry point, called by worker.py
@@ -294,10 +296,10 @@ def download(args: list[str]) -> tuple[str, str | None, 
tuple[Any, ...]]:
     except Exception as e:
         logger.exception(f"Error in download function: {e}")
         # Return a tuple with a dictionary that matches what the template 
expects
-        return "FAILED", str(e), ({"message": f"Error: {e}", "progress": 0},)
+        return task.FAILED, str(e), ({"message": f"Error: {e}", "progress": 
0},)
 
 
-def download_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any, 
...]]:
+def download_core(args_list: list[str]) -> tuple[task.Status, str | None, 
tuple[Any, ...]]:
     """Download bulk package from URL."""
     logger.info("Starting download_core")
     try:
@@ -329,7 +331,7 @@ def download_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...
         # Return a result dictionary
         # This matches what we have in templates/release-bulk.html
         return (
-            "COMPLETED",
+            task.COMPLETED,
             None,
             (
                 {
@@ -345,7 +347,7 @@ def download_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...
     except Exception as e:
         logger.exception(f"Error in download_core: {e}")
         return (
-            "FAILED",
+            task.FAILED,
             str(e),
             (
                 {
diff --git a/atr/tasks/mailtest.py b/atr/tasks/mailtest.py
index 72c7f27..81d9e25 100644
--- a/atr/tasks/mailtest.py
+++ b/atr/tasks/mailtest.py
@@ -20,6 +20,8 @@ import os
 from dataclasses import dataclass
 from typing import Any
 
+import atr.tasks.task as task
+
 # Configure detailed logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -81,7 +83,7 @@ class Args:
         return args_obj
 
 
-def send(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def send(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, ...]]:
     """Send a test email."""
     logger.info(f"Sending with args: {args}")
     try:
@@ -91,10 +93,10 @@ def send(args: list[str]) -> tuple[str, str | None, 
tuple[Any, ...]]:
         return status, error, result
     except Exception as e:
         logger.exception(f"Error in send function: {e}")
-        return "FAILED", str(e), tuple()
+        return task.FAILED, str(e), tuple()
 
 
-def send_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def send_core(args_list: list[str]) -> tuple[task.Status, str | None, 
tuple[Any, ...]]:
     """Send a test email."""
     import asyncio
 
@@ -137,17 +139,17 @@ def send_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...]]:
             if not tooling_pmc:
                 error_msg = "Tooling PMC not found in database"
                 logger.error(error_msg)
-                return "FAILED", error_msg, tuple()
+                return task.FAILED, error_msg, tuple()
 
             if domain != "apache.org":
                 error_msg = f"Email domain must be apache.org, got {domain}"
                 logger.error(error_msg)
-                return "FAILED", error_msg, tuple()
+                return task.FAILED, error_msg, tuple()
 
             if local_part not in tooling_pmc.pmc_members:
                 error_msg = f"Email recipient {local_part} is not a member of 
the tooling PMC"
                 logger.error(error_msg)
-                return "FAILED", error_msg, tuple()
+                return task.FAILED, error_msg, tuple()
 
             logger.info(f"Recipient {email_recipient} is a tooling PMC member, 
allowed")
 
@@ -163,7 +165,7 @@ def send_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...]]:
         except Exception as e:
             error_msg = f"Failed to load DKIM key: {e}"
             logger.error(error_msg)
-            return "FAILED", error_msg, tuple()
+            return task.FAILED, error_msg, tuple()
 
         event = atr.mail.ArtifactEvent(
             artifact_name=args.artifact_name,
@@ -173,8 +175,8 @@ def send_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...]]:
         atr.mail.send(event)
         logger.info(f"Email sent successfully to {args.email_recipient}")
 
-        return "COMPLETED", None, tuple()
+        return task.COMPLETED, None, tuple()
 
     except Exception as e:
         logger.exception(f"Error in send_core: {e}")
-        return "FAILED", str(e), tuple()
+        return task.FAILED, str(e), tuple()
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index 1227bf5..cfca283 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -22,6 +22,8 @@ from dataclasses import dataclass
 from datetime import UTC
 from typing import Any
 
+import atr.tasks.task as task
+
 # Configure detailed logging
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)
@@ -98,7 +100,7 @@ class Args:
         return args_obj
 
 
-def initiate(args: list[str]) -> tuple[str, str | None, tuple[Any, ...]]:
+def initiate(args: list[str]) -> tuple[task.Status, str | None, tuple[Any, 
...]]:
     """Initiate a vote for a release."""
     logger.info(f"Initiating vote with args: {args}")
     try:
@@ -108,10 +110,10 @@ def initiate(args: list[str]) -> tuple[str, str | None, 
tuple[Any, ...]]:
         return status, error, result
     except Exception as e:
         logger.exception(f"Error in initiate function: {e}")
-        return "FAILED", str(e), tuple()
+        return task.FAILED, str(e), tuple()
 
 
-def initiate_core(args_list: list[str]) -> tuple[str, str | None, tuple[Any, 
...]]:
+def initiate_core(args_list: list[str]) -> tuple[task.Status, str | None, 
tuple[Any, ...]]:
     """Get arguments, create an email, and then send it to the recipient."""
     import atr.mail
     from atr.db.service import get_release_by_key_sync
@@ -141,7 +143,7 @@ def initiate_core(args_list: list[str]) -> tuple[str, str | 
None, tuple[Any, ...
         if not release:
             error_msg = f"Release with key {args.release_key} not found"
             logger.error(error_msg)
-            return "FAILED", error_msg, tuple()
+            return task.FAILED, error_msg, tuple()
 
         # GPG key ID, just for testing the UI
         gpg_key_id = args.gpg_key_id
@@ -166,13 +168,13 @@ def initiate_core(args_list: list[str]) -> tuple[str, str 
| None, tuple[Any, ...
         except Exception as e:
             error_msg = f"Failed to load DKIM key: {e}"
             logger.error(error_msg)
-            return "FAILED", error_msg, tuple()
+            return task.FAILED, error_msg, tuple()
 
         # Get PMC and product details
         if release.pmc is None:
             error_msg = "Release has no associated PMC"
             logger.error(error_msg)
-            return "FAILED", error_msg, tuple()
+            return task.FAILED, error_msg, tuple()
 
         pmc_name = release.pmc.project_name
         pmc_display = release.pmc.display_name
@@ -235,7 +237,7 @@ Thanks,
         # TODO: Update release status to indicate a vote is in progress
         # This would involve updating the database with the vote details 
somehow
         return (
-            "COMPLETED",
+            task.COMPLETED,
             None,
             (
                 {
@@ -250,4 +252,4 @@ Thanks,
 
     except Exception as e:
         logger.exception(f"Error in initiate_core: {e}")
-        return "FAILED", str(e), tuple()
+        return task.FAILED, str(e), tuple()
diff --git a/atr/worker.py b/atr/worker.py
index 3a1c1c7..a816f72 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -191,106 +191,6 @@ def task_result_process(
                 )
 
 
-def task_bulk_download_debug(args: list[str] | dict) -> tuple[str, str | None, 
tuple[Any, ...]]:
-    # This was a debug function; pay no attention to this
-    # TODO: Remove once we're sure everything is working
-    _LOGGER.info(f"Bulk download debug task received args: {args}")
-
-    try:
-        # Extract parameters from args (support both list and dict inputs)
-        if isinstance(args, list):
-            # If it's a list, the release_key is the first element
-            # release_key = args[0] if args else "unknown"
-            url = args[1] if len(args) > 1 else "unknown"
-            file_types = args[2] if len(args) > 2 else []
-            require_signatures = args[3] if len(args) > 3 else False
-        elif isinstance(args, dict):
-            # release_key = args.get("release_key", "unknown")
-            url = args.get("url", "unknown")
-            file_types = args.get("file_types", [])
-            require_signatures = args.get("require_signatures", False)
-        # else:
-        #     _LOGGER.warning(f"Unexpected args type: {type(args)}")
-        #     release_key = "unknown"
-        #     url = "unknown"
-        #     file_types = []
-        #     require_signatures = False
-
-        # Progress messages to display over time
-        progress_messages = [
-            f"Connecting to {url}...",
-            f"Connected to {url}. Scanning for {', '.join(file_types) if 
file_types else 'all'} files...",
-            "Found 15 files matching criteria. Downloading...",
-            "Downloaded 7/15 files (47%)...",
-            "Downloaded 15/15 files (100%). Processing...",
-        ]
-
-        # Get task_id from the current process
-        current_pid = os.getpid()
-        task_id = None
-
-        # Get the task ID for the current process
-        with db.create_sync_db_session() as session:
-            result = session.execute(
-                sqlalchemy.text("SELECT id FROM task WHERE pid = :pid AND 
status = 'ACTIVE'"), {"pid": current_pid}
-            )
-            task_row = result.first()
-            if task_row:
-                task_id = task_row[0]
-
-        if not task_id:
-            _LOGGER.warning(f"Could not find active task for PID 
{current_pid}")
-
-        # Process each progress message with a delay
-        for i, message in enumerate(progress_messages):
-            progress_pct = (i + 1) * 20
-
-            update = {
-                "message": message,
-                "progress": progress_pct,
-                "url": url,
-                "timestamp": datetime.datetime.now(datetime.UTC).isoformat(),
-            }
-
-            # Log the progress
-            _LOGGER.info(f"Progress update {i + 1}/{len(progress_messages)}: 
{message} ({progress_pct}%)")
-
-            # Update the database with the current progress if we have a 
task_id
-            if task_id:
-                with db.create_sync_db_session() as session:
-                    # Update the task with the current progress message
-                    with session.begin():
-                        session.execute(
-                            sqlalchemy.text("""
-                                UPDATE task
-                                SET result = :result
-                                WHERE id = :task_id AND status = 'ACTIVE'
-                            """),
-                            {"task_id": task_id, "result": json.dumps(update)},
-                        )
-
-            # Sleep before the next update, except for the last one
-            if i < len(progress_messages) - 1:
-                time.sleep(2.75)
-
-        final_result = {
-            "message": f"Successfully processed {url}",
-            "progress": 100,
-            "files_processed": 15,
-            "files_downloaded": 15,
-            "url": url,
-            "file_types": file_types,
-            "require_signatures": require_signatures,
-            "completed_at": datetime.datetime.now(datetime.UTC).isoformat(),
-        }
-
-        return "COMPLETED", None, (final_result,)
-
-    except Exception as e:
-        _LOGGER.exception(f"Error in bulk download debug task: {e}")
-        return "FAILED", str(e), ({"error": str(e), "message": f"Error: 
{e!s}", "progress": 0},)
-
-
 def task_process(task_id: int, task_type: str, task_args: str) -> None:
     """Process a claimed task."""
     _LOGGER.info(f"Processing task {task_id} ({task_type}) with args 
{task_args}")
@@ -299,8 +199,10 @@ def task_process(task_id: int, task_type: str, task_args: 
str) -> None:
 
         # Map task types to their handler functions
         # TODO: We should use a decorator to register these automatically
-        task_handlers = {
+        dict_task_handlers = {
             "verify_archive_integrity": archive.check_integrity,
+        }
+        list_task_handlers = {
             "verify_archive_structure": archive.check_structure,
             "verify_license_files": license.check_files,
             "verify_signature": signature.check,
@@ -312,20 +214,22 @@ def task_process(task_id: int, task_type: str, task_args: 
str) -> None:
             "vote_initiate": vote.initiate,
         }
 
-        handler = task_handlers.get(task_type)
-        if not handler:
-            msg = f"Unknown task type: {task_type}, {task_handlers.keys()}"
-            _LOGGER.error(msg)
-            raise Exception(msg)
-
-        raw_status, error, task_results = handler(args)
-        if isinstance(raw_status, task.Status):
-            status = raw_status.value.upper()
-        elif isinstance(raw_status, str):
-            status = raw_status.upper()
+        if isinstance(args, dict):
+            dict_handler = dict_task_handlers.get(task_type)
+            if not dict_handler:
+                msg = f"Unknown task type: {task_type}"
+                _LOGGER.error(msg)
+                raise Exception(msg)
+            status, error, task_results = dict_handler(args)
         else:
-            raise Exception(f"Unknown task status type: {type(raw_status)}")
-        task_result_process(task_id, task_results, status=status, error=error)
+            list_handler = list_task_handlers.get(task_type)
+            if not list_handler:
+                msg = f"Unknown task type: {task_type}"
+                _LOGGER.error(msg)
+                raise Exception(msg)
+            status, error, task_results = list_handler(args)
+
+        task_result_process(task_id, task_results, 
status=status.value.upper(), error=error)
 
     except Exception as e:
         task_error_handle(task_id, e)


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to