This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new e9bf1e8  Use a unified archive interface to eliminate separate zip license checks
e9bf1e8 is described below

commit e9bf1e8375dc6f174abafa6637c057565c227cc5
Author: Sean B. Palmer <[email protected]>
AuthorDate: Tue May 13 15:51:30 2025 +0100

    Use a unified archive interface to eliminate separate zip license checks
---
 atr/db/models.py                                |   2 -
 atr/tarzip.py                                   | 117 +++++++++++++
 atr/tasks/__init__.py                           |   8 +-
 atr/tasks/checks/license.py                     |  34 ++--
 atr/tasks/checks/zipformat.py                   | 211 ------------------------
 migrations/versions/0004_2025.05.13_657bf05b.py |  61 +++++++
 6 files changed, 197 insertions(+), 236 deletions(-)

diff --git a/atr/db/models.py b/atr/db/models.py
index 1096074..4b6506f 100644
--- a/atr/db/models.py
+++ b/atr/db/models.py
@@ -316,8 +316,6 @@ class TaskType(str, enum.Enum):
     TARGZ_STRUCTURE = "targz_structure"
     VOTE_INITIATE = "vote_initiate"
     ZIPFORMAT_INTEGRITY = "zipformat_integrity"
-    ZIPFORMAT_LICENSE_FILES = "zipformat_license_files"
-    ZIPFORMAT_LICENSE_HEADERS = "zipformat_license_headers"
     ZIPFORMAT_STRUCTURE = "zipformat_structure"
 
 
diff --git a/atr/tarzip.py b/atr/tarzip.py
new file mode 100644
index 0000000..532c0ff
--- /dev/null
+++ b/atr/tarzip.py
@@ -0,0 +1,117 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tarfile
+import zipfile
+from collections.abc import Generator, Iterator
+from contextlib import contextmanager
+from typing import IO, Generic, TypeVar
+from typing import Protocol as TypingProtocol
+
+ArchiveT = TypeVar("ArchiveT", tarfile.TarFile, zipfile.ZipFile)
+MemberT = TypeVar("MemberT", tarfile.TarInfo, zipfile.ZipInfo)
+
+
+class AbstractArchiveMember(TypingProtocol, Generic[MemberT]):
+    name: str
+    _original_info: MemberT
+
+    def isfile(self) -> bool: ...
+    def isdir(self) -> bool: ...
+
+
+class TarMember(AbstractArchiveMember[tarfile.TarInfo]):
+    def __init__(self, original: tarfile.TarInfo):
+        self.name: str = original.name
+        self._original_info: tarfile.TarInfo = original
+
+    def isfile(self) -> bool:
+        return self._original_info.isfile()
+
+    def isdir(self) -> bool:
+        return self._original_info.isdir()
+
+
+class ZipMember(AbstractArchiveMember[zipfile.ZipInfo]):
+    def __init__(self, original: zipfile.ZipInfo):
+        self.name: str = original.filename
+        self._original_info: zipfile.ZipInfo = original
+
+    def isfile(self) -> bool:
+        return not self._original_info.is_dir()
+
+    def isdir(self) -> bool:
+        return self._original_info.is_dir()
+
+
+Member = TarMember | ZipMember
+
+
+class ArchiveContext(Generic[ArchiveT]):
+    _archive_obj: ArchiveT
+
+    def __init__(self, archive_obj: ArchiveT):
+        self._archive_obj = archive_obj
+
+    def __iter__(self) -> Iterator[TarMember | ZipMember]:
+        match self._archive_obj:
+            case tarfile.TarFile() as tf:
+                for member_orig in tf:
+                    yield TarMember(member_orig)
+            case zipfile.ZipFile() as zf:
+                for member_orig in zf.infolist():
+                    yield ZipMember(member_orig)
+
+    def extractfile(self, member_wrapper: Member) -> IO[bytes] | None:
+        try:
+            match self._archive_obj:
+                case tarfile.TarFile() as tf:
+                    if not isinstance(member_wrapper, TarMember):
+                        raise TypeError("Archive is TarFile, but member_wrapper is not TarMember")
+                    return tf.extractfile(member_wrapper._original_info)
+                case zipfile.ZipFile() as zf:
+                    if not isinstance(member_wrapper, ZipMember):
+                        raise TypeError("Archive is ZipFile, but member_wrapper is not ZipMember")
+                    return zf.open(member_wrapper._original_info)
+        except (KeyError, AttributeError, Exception):
+            return None
+
+
+Archive = ArchiveContext[tarfile.TarFile] | ArchiveContext[zipfile.ZipFile]
+
+
+@contextmanager
+def open_archive(archive_path: str) -> Generator[Archive]:
+    archive_file: tarfile.TarFile | zipfile.ZipFile | None = None
+    try:
+        try:
+            archive_file = tarfile.open(archive_path, "r:*")
+        except tarfile.ReadError:
+            try:
+                archive_file = zipfile.ZipFile(archive_path, "r")
+            except zipfile.BadZipFile:
+                raise ValueError(f"Unsupported or corrupted archive: {archive_path}")
+
+        match archive_file:
+            case tarfile.TarFile() as tf_concrete:
+                yield ArchiveContext[tarfile.TarFile](tf_concrete)
+            case zipfile.ZipFile() as zf_concrete:
+                yield ArchiveContext[zipfile.ZipFile](zf_concrete)
+
+    finally:
+        if archive_file:
+            archive_file.close()
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 891b26e..ad037d0 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -162,10 +162,6 @@ def resolve(task_type: models.TaskType) -> Callable[..., Awaitable[str | None]]:
             return zipformat.integrity
         case models.TaskType.ZIPFORMAT_STRUCTURE:
             return zipformat.structure
-        case models.TaskType.ZIPFORMAT_LICENSE_FILES:
-            return zipformat.license_files
-        case models.TaskType.ZIPFORMAT_LICENSE_HEADERS:
-            return zipformat.license_headers
         # NOTE: Do NOT add "case _" here
         # Otherwise we lose exhaustiveness checking
 
@@ -196,8 +192,8 @@ async def zip_checks(release: models.Release, revision: str, path: str) -> list[
     """Create check tasks for a .zip file."""
     tasks = [
         queued(models.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path),
-        queued(models.TaskType.ZIPFORMAT_LICENSE_FILES, release, revision, path),
-        queued(models.TaskType.ZIPFORMAT_LICENSE_HEADERS, release, revision, path),
+        queued(models.TaskType.LICENSE_FILES, release, revision, path),
+        queued(models.TaskType.LICENSE_HEADERS, release, revision, path),
         queued(models.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path),
     ]
     return tasks
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index fba7f68..5385e2c 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -20,13 +20,13 @@ import hashlib
 import logging
 import os
 import re
-import tarfile
 from collections.abc import Iterator
 from typing import Any, Final
 
+import atr.db.models as models
+import atr.schema as schema
+import atr.tarzip as tarzip
 import atr.tasks.checks as checks
-from atr import schema
-from atr.db import models
 
 _LOGGER: Final = logging.getLogger(__name__)
 
@@ -210,8 +210,8 @@ def _files_check_core_logic(artifact_path: str) -> Iterator[Result]:
     #     # Continue checking files
 
     # Check for license files in the root directory
-    with tarfile.open(artifact_path, mode="r|gz") as tf:
-        for member in tf:
+    with tarzip.open_archive(artifact_path) as archive:
+        for member in archive:
             _LOGGER.warning(f"Checking member: {member.name}")
             if member.name and member.name.split("/")[-1].startswith("._"):
                 # Metadata convention
@@ -226,10 +226,10 @@ def _files_check_core_logic(artifact_path: str) -> Iterator[Result]:
                 files_found.append(filename)
                 if filename == "LICENSE":
                     # TODO: Check length, should be 11,358 bytes
-                    license_ok = _files_check_core_logic_license(tf, member)
+                    license_ok = _files_check_core_logic_license(archive, member)
                 elif filename == "NOTICE":
                     # TODO: Check length doesn't exceed some preset
-                    notice_ok, notice_issues = _files_check_core_logic_notice(tf, member)
+                    notice_ok, notice_issues = _files_check_core_logic_notice(archive, member)
 
     yield from _files_messages_build(files_found, license_ok, notice_ok, notice_issues)
 
@@ -259,9 +259,9 @@ def _files_check_core_logic(artifact_path: str) -> Iterator[Result]:
         )
 
 
-def _files_check_core_logic_license(tf: tarfile.TarFile, member: tarfile.TarInfo) -> bool:
+def _files_check_core_logic_license(archive: tarzip.Archive, member: tarzip.Member) -> bool:
     """Verify that the LICENSE file matches the Apache 2.0 license."""
-    f = tf.extractfile(member)
+    f = archive.extractfile(member)
     if not f:
         return False
 
@@ -276,9 +276,9 @@ def _files_check_core_logic_license(tf: tarfile.TarFile, member: tarfile.TarInfo
     return False
 
 
-def _files_check_core_logic_notice(tf: tarfile.TarFile, member: tarfile.TarInfo) -> tuple[bool, list[str]]:
+def _files_check_core_logic_notice(archive: tarzip.Archive, member: tarzip.Member) -> tuple[bool, list[str]]:
     """Verify that the NOTICE file follows the required format."""
-    f = tf.extractfile(member)
+    f = archive.extractfile(member)
     if not f:
         return False, ["Could not read NOTICE file"]
 
@@ -372,13 +372,13 @@ def _headers_check_core_logic(artifact_path: str) -> Iterator[Result]:
     #     )
 
     # Check files in the archive
-    with tarfile.open(artifact_path, mode="r|gz") as tf:
-        for member in tf:
+    with tarzip.open_archive(artifact_path) as archive:
+        for member in archive:
             if member.name and member.name.split("/")[-1].startswith("._"):
                 # Metadata convention
                 continue
 
-            match _headers_check_core_logic_process_file(tf, member):
+            match _headers_check_core_logic_process_file(archive, member):
                 case ArtifactResult() | MemberResult() as result:
                     artifact_data.files_checked += 1
                     match result.status:
@@ -405,8 +405,8 @@ def _headers_check_core_logic(artifact_path: str) -> Iterator[Result]:
 
 
 def _headers_check_core_logic_process_file(
-    tf: tarfile.TarFile,
-    member: tarfile.TarInfo,
+    archive: tarzip.Archive,
+    member: tarzip.Member,
 ) -> Result:
     """Process a single file in an archive for license header verification."""
     if not member.isfile():
@@ -424,7 +424,7 @@ def _headers_check_core_logic_process_file(
 
     # Extract and check the file
     try:
-        f = tf.extractfile(member)
+        f = archive.extractfile(member)
         if f is None:
             return MemberResult(
                 status=models.CheckResultStatus.EXCEPTION,
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index f4b0351..03b044a 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -22,7 +22,6 @@ import zipfile
 from typing import Any, Final
 
 import atr.tasks.checks as checks
-import atr.tasks.checks.license as license
 
 _LOGGER: Final = logging.getLogger(__name__)
 
@@ -47,68 +46,6 @@ async def integrity(args: checks.FunctionArguments) -> str | None:
     return None
 
 
-async def license_files(args: checks.FunctionArguments) -> str | None:
-    """Check that the LICENSE and NOTICE files exist and are valid within the zip."""
-    recorder = await args.recorder()
-    if not (artifact_abs_path := await recorder.abs_path()):
-        return None
-
-    _LOGGER.info(f"Checking zip license files for {artifact_abs_path} (rel: {args.primary_rel_path})")
-
-    try:
-        result_data = await asyncio.to_thread(_license_files_check_core_logic_zip, str(artifact_abs_path))
-
-        if result_data.get("error"):
-            await recorder.failure(result_data["error"], result_data)
-        elif result_data.get("license_valid") and result_data.get("notice_valid"):
-            await recorder.success("LICENSE and NOTICE files present and valid in zip", result_data)
-        else:
-            issues = []
-            if not result_data.get("license_found"):
-                issues.append("LICENSE missing")
-            elif not result_data.get("license_valid"):
-                issues.append("LICENSE invalid or empty")
-            if not result_data.get("notice_found"):
-                issues.append("NOTICE missing")
-            elif not result_data.get("notice_valid"):
-                issues.append("NOTICE invalid or empty")
-            issue_str = ", ".join(issues) if issues else "Issues found with LICENSE or NOTICE files"
-            await recorder.failure(issue_str, result_data)
-
-    except Exception as e:
-        await recorder.failure("Error checking zip license files", {"error": str(e)})
-
-    return None
-
-
-async def license_headers(args: checks.FunctionArguments) -> str | None:
-    """Check that all source files within the zip have valid license headers."""
-    recorder = await args.recorder()
-    if not (artifact_abs_path := await recorder.abs_path()):
-        return None
-
-    _LOGGER.info(f"Checking zip license headers for {artifact_abs_path} (rel: {args.primary_rel_path})")
-
-    try:
-        result_data = await asyncio.to_thread(_license_headers_check_core_logic_zip, str(artifact_abs_path))
-
-        if result_data.get("error_message"):
-            await recorder.failure(result_data["error_message"], result_data)
-        elif not result_data.get("valid"):
-            num_issues = len(result_data.get("files_without_headers", []))
-            failure_msg = f"{num_issues} file(s) missing or having invalid license headers"
-            await recorder.failure(failure_msg, result_data)
-        else:
-            await recorder.success(
-                f"License headers OK ({result_data.get('files_checked', 0)} files checked)", result_data
-            )
-
-    except Exception as e:
-        await recorder.failure("Error checking zip license headers", {"error": str(e)})
-
-    return None
-
-
 async def structure(args: checks.FunctionArguments) -> str | None:
     """Check that the zip archive has a single root directory matching the artifact name."""
     recorder = await args.recorder()
@@ -148,154 +85,6 @@ def _integrity_check_core_logic(artifact_path: str) -> dict[str, Any]:
         return {"error": f"Unexpected error: {e}"}
 
 
-def _license_files_check_core_logic_zip(artifact_path: str) -> dict[str, Any]:
-    """Verify LICENSE and NOTICE files within a zip archive."""
-    # TODO: Obviously we want to reuse the license files check logic from license.py
-    # But we'd need to have task dependencies to do that, ideally
-    try:
-        with zipfile.ZipFile(artifact_path, "r") as zf:
-            members = zf.namelist()
-            if not members:
-                return {"error": "Archive is empty"}
-
-            root_dir = _license_files_find_root_dir_zip(members)
-            # _LOGGER.info(f"Root dir of {artifact_path}: {root_dir}")
-            if not root_dir:
-                return {"error": "Could not determine root directory"}
-
-            expected_license_path = root_dir + "/LICENSE"
-            expected_notice_path = root_dir + "/NOTICE"
-
-            member_set = set(members)
-
-            license_found, license_valid = (
-                _license_files_check_file_zip(zf, artifact_path, expected_license_path)
-                if (expected_license_path in member_set)
-                else (False, False)
-            )
-            notice_found, notice_valid = (
-                _license_files_check_file_zip(zf, artifact_path, expected_notice_path)
-                if (expected_notice_path in member_set)
-                else (False, False)
-            )
-
-            return {
-                "root_dir": root_dir,
-                "license_found": license_found,
-                "license_valid": license_valid,
-                "notice_found": notice_found,
-                "notice_valid": notice_valid,
-            }
-
-    except zipfile.BadZipFile as e:
-        return {"error": f"Bad zip file: {e}"}
-    except FileNotFoundError:
-        return {"error": "File not found"}
-    except Exception as e:
-        return {"error": f"Unexpected error: {e}"}
-
-
-def _license_files_check_file_zip(zf: zipfile.ZipFile, artifact_path: str, expected_path: str) -> tuple[bool, bool]:
-    """Check for the presence and basic validity of a specific file in a zip."""
-    found = False
-    valid = False
-    try:
-        with zf.open(expected_path) as file_handle:
-            found = True
-            content = file_handle.read().strip()
-            if content:
-                # TODO: Add more specific NOTICE checks if needed
-                valid = True
-    except KeyError:
-        # File not found in zip
-        ...
-    except Exception as e:
-        filename = os.path.basename(expected_path)
-        _LOGGER.warning(f"Error reading {filename} in zip {artifact_path}: {e}")
-    return found, valid
-
-
-def _license_files_find_root_dir_zip(members: list[str]) -> str | None:
-    """Find the root directory in a list of zip members."""
-    for member in members:
-        if "/" in member:
-            return member.split("/", 1)[0]
-    return None
-
-
-def _license_headers_check_core_logic_zip(artifact_path: str) -> dict[str, Any]:
-    """Verify license headers for files within a zip archive."""
-    files_checked = 0
-    files_with_issues: list[str] = []
-    try:
-        with zipfile.ZipFile(artifact_path, "r") as zf:
-            members = zf.infolist()
-
-            for member_info in members:
-                if member_info.is_dir():
-                    continue
-
-                member_path = member_info.filename
-                _, extension = os.path.splitext(member_path)
-                extension = extension.lower().lstrip(".")
-
-                if not _license_headers_check_should_check_zip(member_path, extension):
-                    continue
-
-                files_checked += 1
-                is_valid, error_msg = _license_headers_check_single_file_zip(zf, member_info, extension)
-
-                if error_msg:
-                    # Already includes path and error type
-                    files_with_issues.append(error_msg)
-                elif not is_valid:
-                    # Just append path for header mismatch
-                    files_with_issues.append(member_path)
-
-            if files_with_issues:
-                return {
-                    "valid": False,
-                    "files_checked": files_checked,
-                    "files_without_headers": files_with_issues,
-                    "error_message": None,
-                }
-            else:
-                return {
-                    "valid": True,
-                    "files_checked": files_checked,
-                    "files_without_headers": [],
-                    "error_message": None,
-                }
-
-    except zipfile.BadZipFile as e:
-        return {"valid": False, "error_message": f"Bad zip file: {e}"}
-    except FileNotFoundError:
-        return {"valid": False, "error_message": "File not found"}
-    except Exception as e:
-        return {"valid": False, "error_message": f"Unexpected error: {e}"}
-
-
-def _license_headers_check_should_check_zip(member_path: str, extension: str) -> bool:
-    """Determine whether a file in a zip should be checked for license headers."""
-    for pattern in license.INCLUDED_PATTERNS:
-        if license.re.match(pattern, f".{extension}"):
-            return True
-    return False
-
-
-def _license_headers_check_single_file_zip(
-    zf: zipfile.ZipFile, member_info: zipfile.ZipInfo, extension: str
-) -> tuple[bool, str | None]:
-    """Check the license header of a single file within a zip. Returns (is_valid, error_message)."""
-    member_path = member_info.filename
-    try:
-        with zf.open(member_path) as file_in_zip:
-            content_bytes = file_in_zip.read(4096)
-            return license.headers_validate(content_bytes, member_path)
-    except Exception as read_error:
-        return False, f"{member_path} (Read Error: {read_error})"
-
-
 def _structure_check_core_logic(artifact_path: str) -> dict[str, Any]:
     """Verify the internal structure of the zip archive."""
     try:
diff --git a/migrations/versions/0004_2025.05.13_657bf05b.py b/migrations/versions/0004_2025.05.13_657bf05b.py
new file mode 100644
index 0000000..6f16dc9
--- /dev/null
+++ b/migrations/versions/0004_2025.05.13_657bf05b.py
@@ -0,0 +1,61 @@
+"""Remove some check functions from TaskType
+
+Revision ID: 0004_2025.05.13_657bf05b
+Revises: 0003_2025.05.09_ee553bee
+Create Date: 2025-05-13 14:41:31.781711+00:00
+"""
+
+from collections.abc import Sequence
+
+import sqlalchemy as sa
+from alembic import op
+
+# Revision identifiers, used by Alembic
+revision: str = "0004_2025.05.13_657bf05b"
+down_revision: str | None = "0003_2025.05.09_ee553bee"
+branch_labels: str | Sequence[str] | None = None
+depends_on: str | Sequence[str] | None = None
+
+_ENUM_MEMBERS_BEFORE_REMOVAL = (
+    "HASHING_CHECK",
+    "KEYS_IMPORT_FILE",
+    "LICENSE_FILES",
+    "LICENSE_HEADERS",
+    "MESSAGE_SEND",
+    "PATHS_CHECK",
+    "RAT_CHECK",
+    "SBOM_GENERATE_CYCLONEDX",
+    "SIGNATURE_CHECK",
+    "SVN_IMPORT_FILES",
+    "TARGZ_INTEGRITY",
+    "TARGZ_STRUCTURE",
+    "VOTE_INITIATE",
+    "ZIPFORMAT_INTEGRITY",
+    "ZIPFORMAT_LICENSE_FILES",
+    "ZIPFORMAT_LICENSE_HEADERS",
+    "ZIPFORMAT_STRUCTURE",
+)
+
+_ENUM_MEMBERS_AFTER_REMOVAL = tuple(
+    m for m in _ENUM_MEMBERS_BEFORE_REMOVAL if m not in {"ZIPFORMAT_LICENSE_FILES", "ZIPFORMAT_LICENSE_HEADERS"}
+)
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.alter_column(
+            "task_type",
+            existing_type=sa.Enum(*_ENUM_MEMBERS_BEFORE_REMOVAL, name="tasktype"),
+            type_=sa.Enum(*_ENUM_MEMBERS_AFTER_REMOVAL, name="tasktype"),
+            existing_nullable=False,
+        )
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("task", schema=None) as batch_op:
+        batch_op.alter_column(
+            "task_type",
+            existing_type=sa.Enum(*_ENUM_MEMBERS_AFTER_REMOVAL, name="tasktype"),
+            type_=sa.Enum(*_ENUM_MEMBERS_BEFORE_REMOVAL, name="tasktype"),
+            existing_nullable=False,
+        )


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to