This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new e3cd3ed  Make the SBOM task generate a file, and use type safe symbols 
for tasks
e3cd3ed is described below

commit e3cd3ed5d421432f7e837fa87f9edae2770aec89
Author: Sean B. Palmer <[email protected]>
AuthorDate: Tue Apr 1 15:46:28 2025 +0100

    Make the SBOM task generate a file, and use type safe symbols for tasks
    
    - Instead of generating JSON which is stored in a task object, the
      SBOM task now writes that JSON to a file and is started from the
      candidate draft action pages.
    - We previously used a function to construct the string values that
      identify tasks. We now use an enum instead, with a match statement
      in a function that maps those enum values to the actual task
      functions. The type checker can check that match for exhaustiveness.
    - Other miscellaneous small changes such as moving functions.
---
 atr/db/__init__.py                   |   4 +-
 atr/db/models.py                     |  16 +-
 atr/routes/candidate.py              |   3 +-
 atr/routes/draft.py                  |  82 +++++++--
 atr/routes/keys.py                   |  11 +-
 atr/ssh.py                           |   3 +-
 atr/tasks/__init__.py                |  55 ++++--
 atr/tasks/checks/__init__.py         |   8 +-
 atr/tasks/checks/rat.py              |   2 +-
 atr/tasks/rsync.py                   |  27 ++-
 atr/tasks/sbom.py                    | 324 ++++++++++++++++++++---------------
 atr/tasks/vote.py                    |  24 ---
 atr/templates/draft-review-path.html |   2 +-
 atr/templates/draft-tools.html       |   9 +
 atr/util.py                          |  16 ++
 atr/worker.py                        |  88 ++--------
 16 files changed, 391 insertions(+), 283 deletions(-)

diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index 010c0d7..9e81ee9 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -563,8 +563,8 @@ async def recent_tasks(data: Session, release_name: str, 
file_path: str, modifie
     recent_tasks: dict[str, models.Task] = {}
     for task in tasks:
         # If we haven't seen this task type before or if this task is newer
-        if (task.task_type not in recent_tasks) or (task.id > 
recent_tasks[task.task_type].id):
-            recent_tasks[task.task_type] = task
+        if (task.task_type.value not in recent_tasks) or (task.id > 
recent_tasks[task.task_type.value].id):
+            recent_tasks[task.task_type.value] = task
 
     return recent_tasks
 
diff --git a/atr/db/models.py b/atr/db/models.py
index 1e96477..2279502 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -284,12 +284,26 @@ class TaskStatus(str, enum.Enum):
     FAILED = "failed"
 
 
+class TaskType(str, enum.Enum):
+    ARCHIVE_INTEGRITY = "archive_integrity"
+    ARCHIVE_STRUCTURE = "archive_structure"
+    HASHING_CHECK = "hashing_check"
+    LICENSE_FILES = "license_files"
+    LICENSE_HEADERS = "license_headers"
+    # PATHS_CHECK = "paths_check"
+    RAT_CHECK = "rat_check"
+    RSYNC_ANALYSE = "rsync_analyse"
+    SIGNATURE_CHECK = "signature_check"
+    VOTE_INITIATE = "vote_initiate"
+    SBOM_GENERATE_CYCLONEDX = "sbom_generate_cyclonedx"
+
+
 class Task(sqlmodel.SQLModel, table=True):
     """A task in the task queue."""
 
     id: int = sqlmodel.Field(default=None, primary_key=True)
     status: TaskStatus = sqlmodel.Field(default=TaskStatus.QUEUED, index=True)
-    task_type: str
+    task_type: TaskType
     task_args: Any = 
sqlmodel.Field(sa_column=sqlalchemy.Column(sqlalchemy.JSON))
     added: datetime.datetime = sqlmodel.Field(
         default_factory=lambda: datetime.datetime.now(datetime.UTC),
diff --git a/atr/routes/candidate.py b/atr/routes/candidate.py
index 9eb4623..2a804d1 100644
--- a/atr/routes/candidate.py
+++ b/atr/routes/candidate.py
@@ -28,7 +28,6 @@ import wtforms
 import atr.db as db
 import atr.db.models as models
 import atr.routes as routes
-import atr.tasks.checks as checks
 import atr.tasks.vote as tasks_vote
 import atr.user as user
 import atr.util as util
@@ -158,7 +157,7 @@ async def vote_project(session: routes.CommitterSession, 
project_name: str, vers
         # Create a task for vote initiation
         task = models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(tasks_vote.initiate),
+            task_type=models.TaskType.VOTE_INITIATE,
             task_args=tasks_vote.Initiate(
                 release_name=release_name,
                 email_to=email_to,
diff --git a/atr/routes/draft.py b/atr/routes/draft.py
index 79f09f7..19f015a 100644
--- a/atr/routes/draft.py
+++ b/atr/routes/draft.py
@@ -38,6 +38,7 @@ import atr.db as db
 import atr.db.models as models
 import atr.routes as routes
 import atr.tasks as tasks
+import atr.tasks.sbom as sbom
 import atr.util as util
 
 if TYPE_CHECKING:
@@ -83,6 +84,18 @@ class DeleteFileForm(util.QuartFormTyped):
     submit = wtforms.SubmitField("Delete file")
 
 
+class PromoteForm(util.QuartFormTyped):
+    """Form for promoting a candidate draft."""
+
+    candidate_draft_name = wtforms.StringField(
+        "Candidate draft name", 
validators=[wtforms.validators.InputRequired("Candidate draft name is 
required")]
+    )
+    confirm_promote = wtforms.BooleanField(
+        "Confirmation", validators=[wtforms.validators.DataRequired("You must 
confirm to proceed with promotion")]
+    )
+    submit = wtforms.SubmitField("Promote to candidate")
+
+
 async def _number_of_release_files(release: models.Release) -> int:
     """Return the number of files in the release."""
     path_project = release.project.name
@@ -417,18 +430,6 @@ async def directory(session: routes.CommitterSession) -> 
str:
     )
 
 
-class PromoteForm(util.QuartFormTyped):
-    """Form for promoting a candidate draft."""
-
-    candidate_draft_name = wtforms.StringField(
-        "Candidate draft name", 
validators=[wtforms.validators.InputRequired("Candidate draft name is 
required")]
-    )
-    confirm_promote = wtforms.BooleanField(
-        "Confirmation", validators=[wtforms.validators.DataRequired("You must 
confirm to proceed with promotion")]
-    )
-    submit = wtforms.SubmitField("Promote to candidate")
-
-
 @routes.committer("/draft/promote", methods=["GET", "POST"])
 async def promote(session: routes.CommitterSession) -> str | response.Response:
     """Allow the user to promote a candidate draft."""
@@ -634,6 +635,63 @@ async def review_path(session: routes.CommitterSession, 
project_name: str, versi
     )
 
 
[email protected]("/draft/sbomgen/<project_name>/<version_name>/<path:file_path>",
 methods=["POST"])
+async def sbomgen(
+    session: routes.CommitterSession, project_name: str, version_name: str, 
file_path: str
+) -> response.Response:
+    """Generate a CycloneDX SBOM file for a candidate draft file."""
+    # Check that the user has access to the project
+    if not any((p.name == project_name) for p in (await 
session.user_projects)):
+        raise base.ASFQuartException("You do not have access to this project", 
errorcode=403)
+
+    async with db.session() as data:
+        # Check that the release exists
+        release_name = f"{project_name}-{version_name}"
+        await data.release(name=release_name, _project=True).demand(
+            base.ASFQuartException("Release does not exist", errorcode=404)
+        )
+
+        # Construct paths
+        base_path = util.get_release_candidate_draft_dir() / project_name / 
version_name
+        full_path = base_path / file_path
+        # Standard CycloneDX extension
+        sbom_path_rel = file_path + ".cdx.json"
+        full_sbom_path = base_path / sbom_path_rel
+
+        # Check that the source file exists
+        if not await aiofiles.os.path.exists(full_path):
+            raise base.ASFQuartException("Source artifact file does not 
exist", errorcode=404)
+
+        # Check that the file is a .tar.gz archive
+        if not file_path.endswith(".tar.gz"):
+            raise base.ASFQuartException("SBOM generation is only supported 
for .tar.gz files", errorcode=400)
+
+        # Check that the SBOM file does not already exist
+        if await aiofiles.os.path.exists(full_sbom_path):
+            raise base.ASFQuartException("SBOM file already exists", 
errorcode=400)
+
+        # Create and queue the task
+        sbom_task = models.Task(
+            task_type=models.TaskType.SBOM_GENERATE_CYCLONEDX,
+            task_args=sbom.GenerateCycloneDX(
+                artifact_path=str(full_path),
+                output_path=str(full_sbom_path),
+            ).model_dump(),
+            added=datetime.datetime.now(datetime.UTC),
+            status=models.TaskStatus.QUEUED,
+            release_name=release_name,
+        )
+        data.add(sbom_task)
+        await data.commit()
+
+    return await session.redirect(
+        review,
+        success=f"SBOM generation task queued for 
{pathlib.Path(file_path).name}",
+        project_name=project_name,
+        version_name=version_name,
+    )
+
+
 
@routes.committer("/draft/tools/<project_name>/<version_name>/<path:file_path>")
 async def tools(session: routes.CommitterSession, project_name: str, 
version_name: str, file_path: str) -> str:
     """Show the tools for a specific file."""
diff --git a/atr/routes/keys.py b/atr/routes/keys.py
index 127b163..13c9bb6 100644
--- a/atr/routes/keys.py
+++ b/atr/routes/keys.py
@@ -26,8 +26,6 @@ import logging
 import logging.handlers
 import pprint
 import re
-import shutil
-import tempfile
 from collections.abc import AsyncGenerator, Sequence
 
 import asfquart as asfquart
@@ -52,13 +50,8 @@ class AddSSHKeyForm(util.QuartFormTyped):
 @contextlib.asynccontextmanager
 async def ephemeral_gpg_home() -> AsyncGenerator[str]:
     """Create a temporary directory for an isolated GPG home, and clean it up 
on exit."""
-    # TODO: This is only used in key_user_add
-    # We could even inline it there
-    temp_dir = await asyncio.to_thread(tempfile.mkdtemp, prefix="gpg-")
-    try:
-        yield temp_dir
-    finally:
-        await asyncio.to_thread(shutil.rmtree, temp_dir)
+    async with util.async_temporary_directory(prefix="gpg-") as temp_dir:
+        yield str(temp_dir)
 
 
 async def key_add_post(
diff --git a/atr/ssh.py b/atr/ssh.py
index b1a43ba..99d546f 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -32,7 +32,6 @@ import asyncssh
 import atr.config as config
 import atr.db as db
 import atr.db.models as models
-import atr.tasks.checks as checks
 import atr.tasks.rsync as rsync
 import atr.user as user
 import atr.util as util
@@ -315,7 +314,7 @@ async def _handle_client(process: 
asyncssh.SSHServerProcess) -> None:
             data.add(
                 models.Task(
                     status=models.TaskStatus.QUEUED,
-                    task_type=checks.function_key(rsync.analyse),
+                    task_type=models.TaskType.RSYNC_ANALYSE,
                     task_args=rsync.Analyse(
                         project_name=project_name,
                         release_version=release_version,
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 45eb28d..e642cd7 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -15,15 +15,19 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from collections.abc import Awaitable, Callable
+
 import aiofiles.os
 
 import atr.db.models as models
-import atr.tasks.checks as checks
 import atr.tasks.checks.archive as archive
 import atr.tasks.checks.hashing as hashing
 import atr.tasks.checks.license as license
 import atr.tasks.checks.rat as rat
 import atr.tasks.checks.signature as signature
+import atr.tasks.rsync as rsync
+import atr.tasks.sbom as sbom
+import atr.tasks.vote as vote
 import atr.util as util
 
 
@@ -43,7 +47,7 @@ async def asc_checks(release: models.Release, signature_path: 
str) -> list[model
         tasks.append(
             models.Task(
                 status=models.TaskStatus.QUEUED,
-                task_type=checks.function_key(signature.check),
+                task_type=models.TaskType.SIGNATURE_CHECK,
                 task_args=signature.Check(
                     release_name=release.name,
                     committee_name=release.committee.name,
@@ -59,6 +63,34 @@ async def asc_checks(release: models.Release, 
signature_path: str) -> list[model
     return tasks
 
 
+def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[str | 
None]]:  # noqa: C901
+    match task_type:
+        case models.TaskType.ARCHIVE_INTEGRITY:
+            return archive.integrity
+        case models.TaskType.ARCHIVE_STRUCTURE:
+            return archive.structure
+        case models.TaskType.HASHING_CHECK:
+            return hashing.check
+        case models.TaskType.LICENSE_FILES:
+            return license.files
+        case models.TaskType.LICENSE_HEADERS:
+            return license.headers
+        # case models.TaskType.PATHS_CHECK:
+        #     return paths.check
+        case models.TaskType.RAT_CHECK:
+            return rat.check
+        case models.TaskType.RSYNC_ANALYSE:
+            return rsync.analyse
+        case models.TaskType.SIGNATURE_CHECK:
+            return signature.check
+        case models.TaskType.VOTE_INITIATE:
+            return vote.initiate
+        case models.TaskType.SBOM_GENERATE_CYCLONEDX:
+            return sbom.generate_cyclonedx
+        # NOTE: Do NOT add "case _" here
+        # Otherwise we lose exhaustiveness checking
+
+
 async def sha_checks(release: models.Release, hash_file: str) -> 
list[models.Task]:
     tasks = []
 
@@ -78,7 +110,7 @@ async def sha_checks(release: models.Release, hash_file: 
str) -> list[models.Tas
     tasks.append(
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(hashing.check),
+            task_type=models.TaskType.HASHING_CHECK,
             task_args=hashing.Check(
                 release_name=release.name,
                 abs_path=original_file,
@@ -103,7 +135,7 @@ async def tar_gz_checks(release: models.Release, path: str) 
-> list[models.Task]
     tasks = [
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(archive.integrity),
+            task_type=models.TaskType.ARCHIVE_INTEGRITY,
             task_args=archive.Integrity(release_name=release.name, 
abs_path=full_path).model_dump(),
             release_name=release.name,
             path=path,
@@ -111,7 +143,7 @@ async def tar_gz_checks(release: models.Release, path: str) 
-> list[models.Task]
         ),
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(archive.structure),
+            task_type=models.TaskType.ARCHIVE_STRUCTURE,
             task_args=archive.Structure(release_name=release.name, 
abs_path=full_path).model_dump(),
             release_name=release.name,
             path=path,
@@ -119,7 +151,7 @@ async def tar_gz_checks(release: models.Release, path: str) 
-> list[models.Task]
         ),
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(license.files),
+            task_type=models.TaskType.LICENSE_FILES,
             task_args=license.Files(release_name=release.name, 
abs_path=full_path).model_dump(),
             release_name=release.name,
             path=path,
@@ -127,7 +159,7 @@ async def tar_gz_checks(release: models.Release, path: str) 
-> list[models.Task]
         ),
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(license.headers),
+            task_type=models.TaskType.LICENSE_HEADERS,
             task_args=license.Headers(release_name=release.name, 
abs_path=full_path).model_dump(),
             release_name=release.name,
             path=path,
@@ -135,7 +167,7 @@ async def tar_gz_checks(release: models.Release, path: str) 
-> list[models.Task]
         ),
         models.Task(
             status=models.TaskStatus.QUEUED,
-            task_type=checks.function_key(rat.check),
+            task_type=models.TaskType.RAT_CHECK,
             task_args=rat.Check(release_name=release.name, 
abs_path=full_path).model_dump(),
             release_name=release.name,
             path=path,
@@ -143,8 +175,11 @@ async def tar_gz_checks(release: models.Release, path: 
str) -> list[models.Task]
         ),
         # models.Task(
         #     status=models.TaskStatus.QUEUED,
-        #     task_type="generate_cyclonedx_sbom",
-        #     task_args=[full_path],
+        #     task_type=models.TaskType.SBOM_GENERATE_CYCLONEDX,
+        #     task_args=tasks.SbomGenerateCyclonedx(
+        #         artifact_path=str(full_path),
+        #         output_path=str(full_sbom_path),
+        #     ).model_dump(),
         #     release_name=release.name,
         #     path=path,
         #     modified=modified,
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index 48da530..dcc34b5 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -33,10 +33,6 @@ import atr.db as db
 import atr.db.models as models
 
 
-def function_key(func: Callable[..., Any]) -> str:
-    return func.__module__ + "." + func.__name__
-
-
 class Check:
     def __init__(
         self, checker: Callable[..., Any], release_name: str, path: str | None 
= None, afresh: bool = True
@@ -111,6 +107,10 @@ class Check:
         return await self._add(models.CheckResultStatus.WARNING, message, 
data, path=path)
 
 
+def function_key(func: Callable[..., Any]) -> str:
+    return func.__module__ + "." + func.__name__
+
+
 def rel_path(abs_path: str) -> str:
     """Return the relative path for a given absolute path."""
     conf = config.get()
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 09d25f4..92284f6 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -177,7 +177,7 @@ def _check_core_logic(
 
             # Extract the archive to the temporary directory
             _LOGGER.info(f"Extracting {artifact_path} to {temp_dir}")
-            extracted_size = sbom.archive_extract_safe(
+            extracted_size = sbom._archive_extract_safe(
                 artifact_path, temp_dir, max_size=max_extract_size, 
chunk_size=chunk_size
             )
             _LOGGER.info(f"Extracted {extracted_size} bytes")
diff --git a/atr/tasks/rsync.py b/atr/tasks/rsync.py
index 37205e5..8aa9cca 100644
--- a/atr/tasks/rsync.py
+++ b/atr/tasks/rsync.py
@@ -25,6 +25,8 @@ import atr.db as db
 import atr.db.models as models
 import atr.tasks as tasks
 import atr.tasks.checks as checks
+
+# import atr.tasks.checks.paths as paths
 import atr.util as util
 
 if TYPE_CHECKING:
@@ -62,12 +64,12 @@ async def analyse(args: Analyse) -> str | None:
 async def _analyse_core(project_name: str, release_version: str) -> dict[str, 
Any]:
     """Core logic to analyse an rsync upload and queue checks."""
     base_path = util.get_release_candidate_draft_dir() / project_name / 
release_version
-    paths = await util.paths_recursive(base_path)
+    paths_recursive = await util.paths_recursive(base_path)
     release_name = f"{project_name}-{release_version}"
 
     async with db.session() as data:
         release = await data.release(name=release_name, 
_committee=True).demand(RuntimeError("Release not found"))
-        for path in paths:
+        for path in paths_recursive:
             # This works because path is relative
             full_path = base_path / path
 
@@ -90,5 +92,24 @@ async def _analyse_core(project_name: str, release_version: 
str) -> dict[str, An
                     for task in await task_function(release, str(path)):
                         if task.task_type not in cached_tasks:
                             data.add(task)
+
+            # # Add the generic path check task for every file
+            # if path_check_task_key not in cached_tasks:
+            #     path_check_task_args = paths.Check(
+            #         release_name=release_name,
+            #         base_release_dir=str(base_path),
+            #         path=str(path),
+            #     ).model_dump()
+
+        #     path_check_task = models.Task(
+        #         status=models.TaskStatus.QUEUED,
+        #         task_type=tasks.Type.PATHS_CHECK,
+        #         task_args=paths.Check(
+        #             release_name=release_name,
+        #             base_release_dir=str(base_path),
+        #             path=str(path),
+        #         ).model_dump(),
+        #     )
+
         await data.commit()
-    return {"paths": [str(path) for path in paths]}
+    return {"paths": [str(path) for path in paths_recursive]}
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index b21356e..14d63d5 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -15,191 +15,231 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import asyncio
+import json
 import logging
 import os
 import tarfile
 from typing import Any, Final
 
+import aiofiles
+import pydantic
+
 import atr.config as config
-import atr.db.models as models
+import atr.tasks.checks as checks
 import atr.tasks.checks.archive as archive
-import atr.tasks.task as task
+import atr.util as util
 
 _CONFIG: Final = config.get()
 _LOGGER: Final = logging.getLogger(__name__)
 
 
-def archive_extract_safe(
-    archive_path: str,
-    extract_dir: str,
-    max_size: int = _CONFIG.MAX_EXTRACT_SIZE,
-    chunk_size: int = _CONFIG.EXTRACT_CHUNK_SIZE,
-) -> int:
-    """Safely extract an archive with size limits."""
-    total_extracted = 0
+class SBOMGenerationError(Exception):
+    """Custom exception for SBOM generation failures."""
 
-    with tarfile.open(archive_path, mode="r|gz") as tf:
-        for member in tf:
-            # Skip anything that's not a file or directory
-            if not (member.isreg() or member.isdir()):
-                continue
-
-            # Check whether extraction would exceed the size limit
-            if member.isreg() and ((total_extracted + member.size) > max_size):
-                raise task.Error(
-                    f"Extraction would exceed maximum size limit of {max_size} 
bytes",
-                    {"max_size": max_size, "current_size": total_extracted, 
"file_size": member.size},
-                )
+    def __init__(self, message: str, details: dict[str, Any] | None = None) -> 
None:
+        super().__init__(message)
+        self.details = details or {}
 
-            # Extract directories directly
-            if member.isdir():
-                tf.extract(member, extract_dir)
-                continue
-
-            target_path = os.path.join(extract_dir, member.name)
-            os.makedirs(os.path.dirname(target_path), exist_ok=True)
-
-            source = tf.extractfile(member)
-            if source is None:
-                continue
-
-            # For files, extract in chunks to avoid saturating memory
-            with open(target_path, "wb") as target:
-                extracted_file_size = 0
-                while True:
-                    chunk = source.read(chunk_size)
-                    if not chunk:
-                        break
-                    target.write(chunk)
-                    extracted_file_size += len(chunk)
-
-                    # Check size limits during extraction
-                    if (total_extracted + extracted_file_size) > max_size:
-                        # Clean up the partial file
-                        target.close()
-                        os.unlink(target_path)
-                        raise task.Error(
-                            f"Extraction exceeded maximum size limit of 
{max_size} bytes",
-                            {"max_size": max_size, "current_size": 
total_extracted},
-                        )
-
-            total_extracted += extracted_file_size
 
-    return total_extracted
+class GenerateCycloneDX(pydantic.BaseModel):
+    """Arguments for the task to generate a CycloneDX SBOM."""
 
+    artifact_path: str = pydantic.Field(..., description="Absolute path to the 
artifact")
+    output_path: str = pydantic.Field(..., description="Absolute path where 
the generated SBOM JSON should be written")
 
-def generate_cyclonedx(args: list[str]) -> tuple[models.TaskStatus, str | 
None, tuple[Any, ...]]:
-    """Generate a CycloneDX SBOM for the given artifact."""
-    # First argument should be the artifact path
-    artifact_path = args[0]
 
-    task_results = task.results_as_tuple(_cyclonedx_generate(artifact_path))
-    _LOGGER.info(f"Generated CycloneDX SBOM for {artifact_path}")
+def _archive_extract_safe_process_file(
+    tf: tarfile.TarFile,
+    member: tarfile.TarInfo,
+    extract_dir: str,
+    total_extracted: int,
+    max_size: int,
+    chunk_size: int,
+) -> int:
+    """Process a single file member during safe archive extraction."""
+    target_path = os.path.join(extract_dir, member.name)
+    if not 
os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
+        _LOGGER.warning(f"Skipping potentially unsafe path: {member.name}")
+        return 0
 
-    # Check whether the generation was successful
-    result = task_results[0]
-    if not result.get("valid", False):
-        return task.FAILED, result.get("message", "SBOM generation failed"), 
task_results
+    os.makedirs(os.path.dirname(target_path), exist_ok=True)
 
-    return task.COMPLETED, None, task_results
+    source = tf.extractfile(member)
+    if source is None:
+        # Should not happen if member.isreg() is true
+        _LOGGER.warning(f"Could not extract file object for member: 
{member.name}")
+        return 0
 
+    extracted_file_size = 0
+    try:
+        with open(target_path, "wb") as target:
+            while chunk := source.read(chunk_size):
+                target.write(chunk)
+                extracted_file_size += len(chunk)
+
+                # Check size limits during extraction
+                if (total_extracted + extracted_file_size) > max_size:
+                    # Clean up the partial file before raising
+                    target.close()
+                    os.unlink(target_path)
+                    raise SBOMGenerationError(
+                        f"Extraction exceeded maximum size limit of {max_size} 
bytes",
+                        {"max_size": max_size, "current_size": 
total_extracted},
+                    )
+    finally:
+        source.close()
+
+    return extracted_file_size
+
+
+def _archive_extract_safe(
+    archive_path: str,
+    extract_dir: str,
+    max_size: int,
+    chunk_size: int,
+) -> int:
+    """Safe archive extraction."""
+    total_extracted = 0
 
-def _cyclonedx_generate(artifact_path: str) -> dict[str, Any]:
-    """Generate a CycloneDX SBOM for the given artifact."""
-    _LOGGER.info(f"Generating CycloneDX SBOM for {artifact_path}")
     try:
-        return _cyclonedx_generate_core(artifact_path)
-    except Exception as e:
-        _LOGGER.error(f"Failed to generate CycloneDX SBOM: {e}")
-        return {
-            "valid": False,
-            "message": f"Failed to generate CycloneDX SBOM: {e!s}",
-        }
-
-
-def _cyclonedx_generate_core(artifact_path: str) -> dict[str, Any]:
-    """Generate a CycloneDX SBOM for the given artifact, raising potential 
exceptions."""
-    import json
-    import subprocess
-    import tempfile
-
-    # Create a temporary directory for extraction
-    with tempfile.TemporaryDirectory(prefix="cyclonedx_sbom_") as temp_dir:
+        with tarfile.open(archive_path, mode="r|gz") as tf:
+            for member in tf:
+                # Skip anything that's not a file or directory
+                if not (member.isreg() or member.isdir()):
+                    continue
+
+                # Check whether extraction would exceed the size limit
+                if member.isreg() and ((total_extracted + member.size) > 
max_size):
+                    raise SBOMGenerationError(
+                        f"Extraction would exceed maximum size limit of 
{max_size} bytes",
+                        {"max_size": max_size, "current_size": 
total_extracted, "file_size": member.size},
+                    )
+
+                # Extract directories directly
+                if member.isdir():
+                    # Ensure the path is safe before extracting
+                    target_path = os.path.join(extract_dir, member.name)
+                    if not 
os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
+                        _LOGGER.warning(f"Skipping potentially unsafe path: 
{member.name}")
+                        continue
+                    tf.extract(member, extract_dir, numeric_owner=True)
+                    continue
+
+                if member.isreg():
+                    extracted_size = _archive_extract_safe_process_file(
+                        tf, member, extract_dir, total_extracted, max_size, 
chunk_size
+                    )
+                    total_extracted += extracted_size
+
+                # TODO: Add other types here
+
+    except tarfile.ReadError as e:
+        raise SBOMGenerationError(f"Failed to read archive: {e}", 
{"archive_path": archive_path}) from e
+
+    return total_extracted
+
+
[email protected]_model(GenerateCycloneDX)
+async def generate_cyclonedx(args: GenerateCycloneDX) -> str | None:
+    """Generate a CycloneDX SBOM for the given artifact and write it to the 
output path."""
+    try:
+        result_data = await _generate_cyclonedx_core(args.artifact_path, 
args.output_path)
+        _LOGGER.info(f"Successfully generated CycloneDX SBOM for 
{args.artifact_path}")
+        msg = result_data["message"]
+        if not isinstance(msg, str):
+            raise SBOMGenerationError(f"Invalid message type: {type(msg)}")
+        return msg
+    except SBOMGenerationError as e:
+        _LOGGER.error(f"SBOM generation failed for {args.artifact_path}: 
{e.details}")
+        raise
+
+
+async def _generate_cyclonedx_core(artifact_path: str, output_path: str) -> 
dict[str, Any]:
+    """Core logic to generate CycloneDX SBOM, raising SBOMGenerationError on 
failure."""
+    _LOGGER.info(f"Generating CycloneDX SBOM for {artifact_path} -> 
{output_path}")
+
+    async with util.async_temporary_directory(prefix="cyclonedx_sbom_") as 
temp_dir:
         _LOGGER.info(f"Created temporary directory: {temp_dir}")
 
         # Find and validate the root directory
         try:
-            root_dir = archive.root_directory(artifact_path)
-        except task.Error as e:
-            _LOGGER.error(f"Archive root directory issue: {e}")
-            return {
-                "valid": False,
-                "message": str(e),
-                "errors": [str(e)],
-            }
+            root_dir = await asyncio.to_thread(archive.root_directory, 
artifact_path)
+        except ValueError as e:
+            raise SBOMGenerationError(f"Archive root directory issue: {e}", 
{"artifact_path": artifact_path}) from e
+        except Exception as e:
+            raise SBOMGenerationError(
+                f"Failed to determine archive root directory: {e}", 
{"artifact_path": artifact_path}
+            ) from e
 
         extract_dir = os.path.join(temp_dir, root_dir)
 
         # Extract the archive to the temporary directory
+        # TODO: Ideally we'd have task dependencies or archive caching
         _LOGGER.info(f"Extracting {artifact_path} to {temp_dir}")
-        # TODO: We need task dependencies, because we don't want to do this 
more than once
-        extracted_size = archive_extract_safe(
-            artifact_path, temp_dir, max_size=_CONFIG.MAX_EXTRACT_SIZE, 
chunk_size=_CONFIG.EXTRACT_CHUNK_SIZE
+        extracted_size = await asyncio.to_thread(
+            _archive_extract_safe,
+            artifact_path,
+            str(temp_dir),
+            max_size=_CONFIG.MAX_EXTRACT_SIZE,
+            chunk_size=_CONFIG.EXTRACT_CHUNK_SIZE,
         )
-        _LOGGER.info(f"Extracted {extracted_size} bytes")
+        _LOGGER.info(f"Extracted {extracted_size} bytes into {extract_dir}")
+
+        # Run syft to generate the CycloneDX SBOM
+        syft_command = ["syft", extract_dir, "-o", "cyclonedx-json"]
+        _LOGGER.info(f"Running syft: {' '.join(syft_command)}")
 
-        # Run syft to generate CycloneDX SBOM
         try:
-            _LOGGER.info(f"Running syft on {extract_dir}")
-            process = subprocess.run(
-                ["syft", extract_dir, "-o", "cyclonedx-json"],
-                capture_output=True,
-                text=True,
-                check=True,
-                timeout=300,
+            process = await asyncio.create_subprocess_exec(
+                *syft_command,
+                stdout=asyncio.subprocess.PIPE,
+                stderr=asyncio.subprocess.PIPE,
             )
+            stdout, stderr = await asyncio.wait_for(process.communicate(), timeout=300)
+
+            stdout_str = stdout.decode("utf-8").strip() if stdout else ""
+            stderr_str = stderr.decode("utf-8").strip() if stderr else ""
+
+            if process.returncode != 0:
+                _LOGGER.error(f"syft command failed with code {process.returncode}")
+                _LOGGER.error(f"syft stderr: {stderr_str}")
+                _LOGGER.error(f"syft stdout: {stdout_str[:1000]}...")
+                raise SBOMGenerationError(
+                    f"syft command failed with code {process.returncode}",
+                    {"returncode": process.returncode, "stderr": stderr_str, "stdout": stdout_str[:1000]},
+                )
 
             # Parse the JSON output from syft
             try:
-                sbom_data = json.loads(process.stdout)
+                sbom_data = json.loads(stdout_str)
+                _LOGGER.info(f"Successfully parsed syft output for {artifact_path}")
+
+                # Write the SBOM data to the specified output path
+                try:
+                    async with aiofiles.open(output_path, "w", encoding="utf-8") as f:
+                        await f.write(json.dumps(sbom_data, indent=2))
+                    _LOGGER.info(f"Successfully wrote SBOM to {output_path}")
+                except Exception as write_err:
+                    _LOGGER.exception(f"Failed to write SBOM JSON to {output_path}: {write_err}")
+                    raise SBOMGenerationError(f"Failed to write SBOM to {output_path}: {write_err}") from write_err
+
                 return {
-                    "valid": True,
-                    "message": "Successfully generated CycloneDX SBOM",
+                    "message": "Successfully generated and saved CycloneDX SBOM",
                     "sbom": sbom_data,
                     "format": "CycloneDX",
                     "components": len(sbom_data.get("components", [])),
                 }
             except json.JSONDecodeError as e:
                 _LOGGER.error(f"Failed to parse syft output as JSON: {e}")
-                # Include first 1000 chars of output for debugging
-                return {
-                    "valid": False,
-                    "message": f"Failed to parse syft output: {e}",
-                    "errors": [str(e), process.stdout[:1000]],
-                }
-
-        except subprocess.CalledProcessError as e:
-            _LOGGER.error(f"syft command failed: {e}")
-            return {
-                "valid": False,
-                "message": f"syft command failed with code {e.returncode}",
-                "errors": [
-                    f"Process error code: {e.returncode}",
-                    f"STDOUT: {e.stdout}",
-                    f"STDERR: {e.stderr}",
-                ],
-            }
-        except subprocess.TimeoutExpired as e:
-            _LOGGER.error(f"syft command timed out: {e}")
-            return {
-                "valid": False,
-                "message": "syft command timed out after 5 minutes",
-                "errors": [str(e)],
-            }
-        except Exception as e:
-            _LOGGER.error(f"Unexpected error running syft: {e}")
-            return {
-                "valid": False,
-                "message": f"Unexpected error running syft: {e}",
-                "errors": [str(e)],
-            }
+                raise SBOMGenerationError(
+                    f"Failed to parse syft output: {e}",
+                    {"error": str(e), "syft_output": stdout_str[:1000]},
+                ) from e
+
+        except TimeoutError:
+            _LOGGER.error("syft command timed out after 5 minutes")
+            raise SBOMGenerationError("syft command timed out after 5 minutes")
+        except FileNotFoundError:
+            _LOGGER.error("syft command not found. Is it installed and in PATH?")
+            raise SBOMGenerationError("syft command not found")
diff --git a/atr/tasks/vote.py b/atr/tasks/vote.py
index d26004c..08a6965 100644
--- a/atr/tasks/vote.py
+++ b/atr/tasks/vote.py
@@ -31,23 +31,6 @@ import atr.tasks.checks as checks
 _LOGGER: Final = logging.getLogger(__name__)
 _LOGGER.setLevel(logging.DEBUG)
 
-# Create file handler for tasks-vote.log
-_HANDLER: Final = logging.FileHandler("tasks-vote.log")
-_HANDLER.setLevel(logging.DEBUG)
-
-# Create formatter with detailed information
-_HANDLER.setFormatter(
-    logging.Formatter(
-        "[%(asctime)s.%(msecs)03d] [%(process)d] [%(levelname)s] [%(name)s:%(funcName)s:%(lineno)d] %(message)s",
-        datefmt="%Y-%m-%d %H:%M:%S",
-    )
-)
-_LOGGER.addHandler(_HANDLER)
-# Ensure parent loggers don't duplicate messages
-_LOGGER.propagate = False
-
-_LOGGER.info("Vote module imported")
-
 
 class VoteInitiationError(Exception): ...
 
@@ -86,13 +69,6 @@ async def _initiate_core_logic(args: Initiate) -> dict[str, Any]:
     test_recipients = ["sbp"]
     _LOGGER.info("Starting initiate_core")
 
-    root_logger = logging.getLogger()
-    has_our_handler = any(
-        (isinstance(h, logging.FileHandler) and h.baseFilename.endswith("tasks-vote.log")) for h in root_logger.handlers
-    )
-    if not has_our_handler:
-        root_logger.addHandler(_HANDLER)
-
     async with db.session() as data:
         release = await data.release(name=args.release_name, _project=True, _committee=True).demand(
             VoteInitiationError(f"Release {args.release_name} not found")
diff --git a/atr/templates/draft-review-path.html b/atr/templates/draft-review-path.html
index 967fa37..1b05b55 100644
--- a/atr/templates/draft-review-path.html
+++ b/atr/templates/draft-review-path.html
@@ -225,5 +225,5 @@
 {% endblock javascripts %}
 
 {% macro function_name_from_key(key) -%}
-  {{- key.split(".")[-1] .replace("_", " ") | title -}}
+  {{- key.removeprefix("atr.tasks.checks.").replace("_", " ").replace(".", " ") | title -}}
 {%- endmacro %}
diff --git a/atr/templates/draft-tools.html b/atr/templates/draft-tools.html
index 767bcf8..33d3f4b 100644
--- a/atr/templates/draft-tools.html
+++ b/atr/templates/draft-tools.html
@@ -46,6 +46,15 @@
     </form>
   </div>
 
+  {% if file_path.endswith(".tar.gz") %}
+    <h3>Generate SBOM</h3>
+    <p>Generate a CycloneDX Software Bill of Materials (SBOM) file for this artifact.</p>
+    <form method="post"
+          action="{{ as_url(routes.draft.sbomgen, project_name=project_name, version_name=version_name, file_path=file_path) }}">
+      <button type="submit" class="btn btn-outline-secondary">Generate CycloneDX SBOM (.cdx.json)</button>
+    </form>
+  {% endif %}
+
   <h3>Delete file</h3>
   <p>This tool deletes the file from the candidate draft.</p>
   <form method="post"
diff --git a/atr/util.py b/atr/util.py
index 260df0c..d2dcbd0 100644
--- a/atr/util.py
+++ b/atr/util.py
@@ -15,9 +15,13 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import asyncio
+import contextlib
 import dataclasses
 import hashlib
 import pathlib
+import shutil
+import tempfile
 from collections.abc import AsyncGenerator, Callable, Mapping, Sequence
 from typing import Annotated, Any, TypeVar
 
@@ -276,3 +280,15 @@ def _get_dict_to_list_validator(inner_adapter: pydantic.TypeAdapter[dict[Any, An
         return val
 
     return validator
+
+
+@contextlib.asynccontextmanager
+async def async_temporary_directory(
+    suffix: str | None = None, prefix: str | None = None, dir: str | pathlib.Path | None = None
+) -> AsyncGenerator[pathlib.Path]:
+    """Create an async temporary directory similar to tempfile.TemporaryDirectory."""
+    temp_dir_path: str = await asyncio.to_thread(tempfile.mkdtemp, suffix=suffix, prefix=prefix, dir=dir)
+    try:
+        yield pathlib.Path(temp_dir_path)
+    finally:
+        await asyncio.to_thread(shutil.rmtree, temp_dir_path, ignore_errors=True)
diff --git a/atr/worker.py b/atr/worker.py
index 65791a4..9312437 100644
--- a/atr/worker.py
+++ b/atr/worker.py
@@ -30,26 +30,14 @@ import os
 import resource
 import signal
 import traceback
-from typing import TYPE_CHECKING, Any, Final
+from typing import Any, Final
 
 import sqlmodel
 
 import atr.db as db
 import atr.db.models as models
-import atr.tasks.bulk as bulk
-import atr.tasks.checks as checks
-import atr.tasks.checks.archive as archive
-import atr.tasks.checks.hashing as hashing
-import atr.tasks.checks.license as license
-import atr.tasks.checks.rat as rat
-import atr.tasks.checks.signature as signature
-import atr.tasks.rsync as rsync
-import atr.tasks.sbom as sbom
+import atr.tasks as tasks
 import atr.tasks.task as task
-import atr.tasks.vote as vote
-
-if TYPE_CHECKING:
-    from collections.abc import Awaitable, Callable
 
 _LOGGER: Final = logging.getLogger(__name__)
 
@@ -202,64 +190,24 @@ async def _task_process(task_id: int, task_type: str, task_args: list[str] | dic
     """Process a claimed task."""
     _LOGGER.info(f"Processing task {task_id} ({task_type}) with args {task_args}")
     try:
-        # Map task types to their handler functions
-        modern_task_handlers: dict[str, Callable[..., Awaitable[str | None]]] = {
-            checks.function_key(archive.integrity): archive.integrity,
-            checks.function_key(archive.structure): archive.structure,
-            checks.function_key(hashing.check): hashing.check,
-            checks.function_key(license.files): license.files,
-            checks.function_key(license.headers): license.headers,
-            checks.function_key(rat.check): rat.check,
-            checks.function_key(signature.check): signature.check,
-            checks.function_key(rsync.analyse): rsync.analyse,
-            checks.function_key(vote.initiate): vote.initiate,
-        }
-        # TODO: We should use a decorator to register these automatically
-        dict_task_handlers = {
-            "package_bulk_download": bulk.download,
-        }
-        # TODO: These are synchronous
-        # We plan to convert these to async dict handlers
-        list_task_handlers = {
-            "generate_cyclonedx_sbom": sbom.generate_cyclonedx,
-        }
-
-        task_results: tuple[Any, ...]
-        if task_type in modern_task_handlers:
-            # NOTE: The other two branches below are deprecated
-            # This is transitional code, which we will tidy up significantly
-            handler = modern_task_handlers[task_type]
-            try:
-                handler_result = await handler(task_args)
-                task_results = tuple()
-                if handler_result is not None:
-                    task_results = (handler_result,)
-                status = task.COMPLETED
-                error = None
-            except Exception as e:
-                task_results = tuple()
-                status = task.FAILED
-                error = str(e)
-                _LOGGER.exception(f"Task {task_id} ({task_type}) failed: {e}")
-        elif isinstance(task_args, dict):
-            dict_handler = dict_task_handlers.get(task_type)
-            if not dict_handler:
-                msg = f"Unknown task type: {task_type}"
-                _LOGGER.error(msg)
-                raise Exception(msg)
-            status, error, task_results = await dict_handler(task_args)
-        else:
-            list_handler = list_task_handlers.get(task_type)
-            if not list_handler:
-                msg = f"Unknown task type: {task_type}"
-                _LOGGER.error(msg)
-                raise Exception(msg)
-            status, error, task_results = list_handler(task_args)
-        _LOGGER.info(f"Task {task_id} completed with status {status}, error {error}, results {task_results}")
-        await _task_result_process(task_id, task_results, status, error)
+        task_type_member = models.TaskType(task_type)
+    except ValueError as e:
+        _LOGGER.error(f"Invalid task type: {task_type}")
+        await _task_result_process(task_id, tuple(), task.FAILED, str(e))
+        return
 
+    task_results: tuple[Any, ...]
+    try:
+        handler = tasks.resolve(task_type_member)
+        handler_result = await handler(task_args)
+        task_results = (handler_result,)
+        status = task.COMPLETED
+        error = None
     except Exception as e:
-        await _task_error_handle(task_id, e)
+        task_results = tuple()
+        status = task.FAILED
+        error = str(e)
+    await _task_result_process(task_id, task_results, status, error)
 
 
 # Worker functions


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to