This is an automated email from the ASF dual-hosted git repository.

sbp pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-release.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d0d9f7  Run RAT checks on zip files too
9d0d9f7 is described below

commit 9d0d9f7385032a59c84729204639312ca1f4ffb8
Author: Sean B. Palmer <[email protected]>
AuthorDate: Mon Jun 23 19:48:19 2025 +0100

    Run RAT checks on zip files too
---
 atr/archives.py           | 162 ++++++++++++++++++++++++++++++++-------
 atr/tarzip.py             |  47 ++++++++++--
 atr/tasks/__init__.py     |   2 +-
 atr/tasks/checks/rat.py   | 190 +++++++++++++++++++++++++---------------------
 atr/tasks/checks/targz.py |   2 +-
 atr/tasks/sbom.py         |   2 +-
 6 files changed, 283 insertions(+), 122 deletions(-)

diff --git a/atr/archives.py b/atr/archives.py
index 5a0959d..0175fb7 100644
--- a/atr/archives.py
+++ b/atr/archives.py
@@ -19,8 +19,11 @@ import logging
 import os
 import os.path
 import tarfile
+import zipfile
 from typing import Final
 
+import atr.tarzip as tarzip
+
 _LOGGER: Final = logging.getLogger(__name__)
 
 
@@ -28,46 +31,53 @@ class ExtractionError(Exception):
     pass
 
 
-def targz_extract(
+def extract(
     archive_path: str,
     extract_dir: str,
     max_size: int,
     chunk_size: int,
 ) -> int:
-    """Safe archive extraction."""
     total_extracted = 0
 
     try:
-        with tarfile.open(archive_path, mode="r|gz") as tf:
-            for member in tf:
-                keep_going, total_extracted = archive_extract_member(
-                    tf, member, extract_dir, total_extracted, max_size, chunk_size
-                )
-                if not keep_going:
-                    break
-
-    except tarfile.ReadError as e:
+        with tarzip.open_archive(archive_path) as archive:
+            match archive.specific():
+                case tarfile.TarFile() as tf:
+                    for member in tf:
+                        keep_going, total_extracted = archive_extract_member(
+                            tf, member, extract_dir, total_extracted, max_size, chunk_size
+                        )
+                        if not keep_going:
+                            break
+
+                case zipfile.ZipFile():
+                    for member in archive:
+                        if not isinstance(member, tarzip.ZipMember):
+                            continue
+                        keep_going, total_extracted = _zip_archive_extract_member(
+                            archive, member, extract_dir, total_extracted, max_size, chunk_size
+                        )
+                        if not keep_going:
+                            break
+
+                case _:
+                    raise ExtractionError("Unsupported archive type", {"archive_path": archive_path})
+
+    except (tarfile.TarError, zipfile.BadZipFile, ValueError) as e:
         raise ExtractionError(f"Failed to read archive: {e}", {"archive_path": archive_path}) from e
 
     return total_extracted
 
 
-def targz_total_size(tgz_path: str, chunk_size: int = 4096) -> int:
-    """Verify a .tar.gz file and compute its uncompressed size."""
-    total_size = 0
+def total_size(tgz_path: str, chunk_size: int = 4096) -> int:
+    with tarzip.open_archive(tgz_path) as archive:
+        match archive.specific():
+            case tarfile.TarFile() as tf:
+                total_size = _size_tar(tf, chunk_size)
+
+            case zipfile.ZipFile():
+                total_size = _size_zip(archive, chunk_size)
 
-    with tarfile.open(tgz_path, mode="r|gz") as tf:
-        for member in tf:
-            # Do not skip metadata here
-            total_size += member.size
-            # Verify file by extraction
-            if member.isfile():
-                f = tf.extractfile(member)
-                if f is not None:
-                    while True:
-                        data = f.read(chunk_size)
-                        if not data:
-                            break
     return total_size
 
 
@@ -216,3 +226,103 @@ def _safe_path(base_dir: str, *paths: str) -> str | None:
     if target.startswith(os.path.abspath(base_dir)):
         return target
     return None
+
+
+def _size_tar(tf: tarfile.TarFile, chunk_size: int) -> int:
+    total_size = 0
+    for member in tf:
+        total_size += member.size
+        if member.isfile():
+            fileobj = tf.extractfile(member)
+            if fileobj is not None:
+                while fileobj.read(chunk_size):
+                    pass
+    return total_size
+
+
+def _size_zip(archive: tarzip.Archive, chunk_size: int) -> int:
+    total_size = 0
+    for member in archive:
+        if not isinstance(member, tarzip.ZipMember):
+            continue
+        total_size += member.size
+        if member.isfile():
+            fileobj = archive.extractfile(member)
+            if fileobj is not None:
+                while fileobj.read(chunk_size):
+                    pass
+    return total_size
+
+
+def _zip_archive_extract_member(
+    archive: tarzip.Archive,
+    member: tarzip.ZipMember,
+    extract_dir: str,
+    total_extracted: int,
+    max_size: int,
+    chunk_size: int,
+) -> tuple[bool, int]:
+    if member.name.split("/")[-1].startswith("._"):
+        return False, 0
+
+    if member.isfile() and (total_extracted + member.size) > max_size:
+        raise ExtractionError(
+            f"Extraction would exceed maximum size limit of {max_size} bytes",
+            {"max_size": max_size, "current_size": total_extracted, "file_size": member.size},
+        )
+
+    if member.isdir():
+        target_path = os.path.join(extract_dir, member.name)
+        if not os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
+            _LOGGER.warning("Skipping potentially unsafe path: %s", member.name)
+            return False, 0
+        os.makedirs(target_path, exist_ok=True)
+        return True, total_extracted
+
+    if member.isfile():
+        extracted_size = _zip_extract_safe_process_file(
+            archive, member, extract_dir, total_extracted, max_size, chunk_size
+        )
+        return True, total_extracted + extracted_size
+
+    return False, total_extracted
+
+
+def _zip_extract_safe_process_file(
+    archive: tarzip.Archive,
+    member: tarzip.ZipMember,
+    extract_dir: str,
+    total_extracted: int,
+    max_size: int,
+    chunk_size: int,
+) -> int:
+    target_path = os.path.join(extract_dir, member.name)
+    if not os.path.abspath(target_path).startswith(os.path.abspath(extract_dir)):
+        _LOGGER.warning(f"Skipping potentially unsafe path: {member.name}")
+        return 0
+
+    os.makedirs(os.path.dirname(target_path), exist_ok=True)
+
+    source = archive.extractfile(member)
+    if source is None:
+        _LOGGER.warning(f"Could not extract {member.name} from archive")
+        return 0
+
+    extracted_file_size = 0
+    try:
+        with open(target_path, "wb") as target:
+            while chunk := source.read(chunk_size):
+                target.write(chunk)
+                extracted_file_size += len(chunk)
+
+                if (total_extracted + extracted_file_size) > max_size:
+                    target.close()
+                    os.unlink(target_path)
+                    raise ExtractionError(
+                        f"Extraction exceeded maximum size limit of {max_size} bytes",
+                        {"max_size": max_size, "current_size": total_extracted},
+                    )
+    finally:
+        source.close()
+
+    return extracted_file_size
diff --git a/atr/tarzip.py b/atr/tarzip.py
index ebeab18..414f37a 100644
--- a/atr/tarzip.py
+++ b/atr/tarzip.py
@@ -31,16 +31,24 @@ MemberT = TypeVar("MemberT", tarfile.TarInfo, zipfile.ZipInfo, covariant=True)
 
 class AbstractArchiveMember[MemberT: (tarfile.TarInfo, zipfile.ZipInfo)](TypingProtocol):  # type: ignore[misc]
     name: str
+    size: int
+    linkname: str | None
+
     _original_info: MemberT
 
     def isfile(self) -> bool: ...
     def isdir(self) -> bool: ...
+    def issym(self) -> bool: ...
+    def islnk(self) -> bool: ...
+    def isdev(self) -> bool: ...
 
 
 class TarMember(AbstractArchiveMember[tarfile.TarInfo]):
     def __init__(self, original: tarfile.TarInfo):
-        self.name: str = original.name
-        self._original_info: tarfile.TarInfo = original
+        self.name = original.name
+        self._original_info = original
+        self.size = original.size
+        self.linkname = original.linkname if hasattr(original, "linkname") else None
 
     def isfile(self) -> bool:
         return self._original_info.isfile()
@@ -48,11 +56,24 @@ class TarMember(AbstractArchiveMember[tarfile.TarInfo]):
     def isdir(self) -> bool:
         return self._original_info.isdir()
 
+    def issym(self) -> bool:
+        return self._original_info.issym()
+
+    def islnk(self) -> bool:
+        return self._original_info.islnk()
+
+    def isdev(self) -> bool:
+        return self._original_info.isdev()
+
 
 class ZipMember(AbstractArchiveMember[zipfile.ZipInfo]):
     def __init__(self, original: zipfile.ZipInfo):
-        self.name: str = original.filename
-        self._original_info: zipfile.ZipInfo = original
+        self.name = original.filename
+        self._original_info = original
+
+        self.size = original.file_size
+        # Link targets are not encoded in ZIP files
+        self.linkname: str | None = None
 
     def isfile(self) -> bool:
         return not self._original_info.is_dir()
@@ -60,6 +81,15 @@ class ZipMember(AbstractArchiveMember[zipfile.ZipInfo]):
     def isdir(self) -> bool:
         return self._original_info.is_dir()
 
+    def issym(self) -> bool:
+        return False
+
+    def islnk(self) -> bool:
+        return False
+
+    def isdev(self) -> bool:
+        return False
+
 
 Member = TarMember | ZipMember
 
@@ -74,6 +104,8 @@ class ArchiveContext[ArchiveT: (tarfile.TarFile, zipfile.ZipFile)]:
         match self._archive_obj:
             case tarfile.TarFile() as tf:
                 for member_orig in tf:
+                    if member_orig.isdev():
+                        continue
                     yield TarMember(member_orig)
             case zipfile.ZipFile() as zf:
                 for member_orig in zf.infolist():
@@ -93,8 +125,13 @@ class ArchiveContext[ArchiveT: (tarfile.TarFile, zipfile.ZipFile)]:
         except (KeyError, AttributeError, Exception):
             return None
 
+    def specific(self) -> tarfile.TarFile | zipfile.ZipFile:
+        return self._archive_obj
+
 
-Archive = ArchiveContext[tarfile.TarFile] | ArchiveContext[zipfile.ZipFile]
+TarArchive = ArchiveContext[tarfile.TarFile]
+ZipArchive = ArchiveContext[zipfile.ZipFile]
+Archive = TarArchive | ZipArchive
 
 
 @contextmanager
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 15a8101..4c80ebc 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -196,7 +196,7 @@ async def zip_checks(release: models.Release, revision: str, path: str) -> list[
     tasks = [
         queued(models.TaskType.LICENSE_FILES, release, revision, path),
         queued(models.TaskType.LICENSE_HEADERS, release, revision, path),
-        # queued(models.TaskType.RAT_CHECK, release, revision, path),
+        queued(models.TaskType.RAT_CHECK, release, revision, path),
         queued(models.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path),
         queued(models.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path),
     ]
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 4449a9b..716a9e1 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -26,7 +26,6 @@ from typing import Any, Final
 import atr.archives as archives
 import atr.config as config
 import atr.tasks.checks as checks
-import atr.tasks.checks.targz as targz
 
 _CONFIG: Final = config.get()
 _JAVA_MEMORY_ARGS: Final[list[str]] = []
@@ -89,51 +88,11 @@ def _check_core_logic(
 ) -> dict[str, Any]:
     """Verify license headers using Apache RAT."""
     _LOGGER.info(f"Verifying licenses with Apache RAT for {artifact_path}")
-
-    # Log the PATH environment variable
     _LOGGER.info(f"PATH environment variable: {os.environ.get('PATH', 'PATH not found')}")
 
-    # Check that Java is installed
-    # TODO: Run this only once, when the server starts
-    try:
-        java_version = subprocess.check_output(
-            ["java", *_JAVA_MEMORY_ARGS, "-version"], stderr=subprocess.STDOUT, text=True
-        )
-        _LOGGER.info(f"Java version: {java_version.splitlines()[0]}")
-    except (subprocess.SubprocessError, FileNotFoundError) as e:
-        _LOGGER.error(f"Java is not properly installed or not in PATH: {e}")
-
-        # Try to get some output even if the command failed
-        try:
-            # Use run instead of check_output to avoid exceptions
-            java_result = subprocess.run(
-                ["java", *_JAVA_MEMORY_ARGS, "-version"],
-                stderr=subprocess.STDOUT,
-                stdout=subprocess.PIPE,
-                text=True,
-                check=False,
-            )
-            _LOGGER.info(f"Java command return code: {java_result.returncode}")
-            _LOGGER.info(f"Java command output: {java_result.stdout or java_result.stderr}")
-
-            # Try to find where Java might be located
-            which_java = subprocess.run(["which", "java"], capture_output=True, text=True, check=False)
-            which_java_result = which_java.stdout.strip() if (which_java.returncode == 0) else "not found"
-            _LOGGER.info(f"Result for which java: {which_java_result}")
-        except Exception as inner_e:
-            _LOGGER.error(f"Additional error while trying to debug java: {inner_e}")
-
-        return {
-            "valid": False,
-            "message": "Java is not properly installed or not in PATH",
-            "total_files": 0,
-            "approved_licenses": 0,
-            "unapproved_licenses": 0,
-            "unknown_licenses": 0,
-            "unapproved_files": [],
-            "unknown_license_files": [],
-            "errors": [f"Java error: {e}"],
-        }
+    java_check = _check_java_installed()
+    if java_check is not None:
+        return java_check
 
     # Verify RAT JAR exists and is accessible
     rat_jar_path, jar_error = _check_core_logic_jar_exists(rat_jar_path)
@@ -146,33 +105,42 @@ def _check_core_logic(
         with tempfile.TemporaryDirectory(prefix="rat_verify_") as temp_dir:
             _LOGGER.info(f"Created temporary directory: {temp_dir}")
 
-            # Find and validate the root directory
-            try:
-                root_dir = targz.root_directory(artifact_path)
-            except targz.RootDirectoryError as e:
-                error_msg = str(e)
-                _LOGGER.error(f"Archive root directory issue: {error_msg}")
+            # # Find and validate the root directory
+            # try:
+            #     root_dir = targz.root_directory(artifact_path)
+            # except targz.RootDirectoryError as e:
+            #     error_msg = str(e)
+            #     _LOGGER.error(f"Archive root directory issue: {error_msg}")
+            #     return {
+            #         "valid": False,
+            #         "message": "No root directory found",
+            #         "total_files": 0,
+            #         "approved_licenses": 0,
+            #         "unapproved_licenses": 0,
+            #         "unknown_licenses": 0,
+            #         "unapproved_files": [],
+            #         "unknown_license_files": [],
+            #         "warning": error_msg or "No root directory found",
+            #         "errors": [],
+            #     }
+
+            # extract_dir = os.path.join(temp_dir, root_dir)
+
+            # Extract the archive to the temporary directory
+            _LOGGER.info(f"Extracting {artifact_path} to {temp_dir}")
+            extracted_size = archives.extract(artifact_path, temp_dir, max_size=max_extract_size, chunk_size=chunk_size)
+            _LOGGER.info(f"Extracted {extracted_size} bytes")
+
+            # Find the root directory
+            if (extract_dir := _extracted_dir(temp_dir)) is None:
+                _LOGGER.error("No root directory found in archive")
                 return {
                     "valid": False,
-                    "message": "No root directory found",
-                    "total_files": 0,
-                    "approved_licenses": 0,
-                    "unapproved_licenses": 0,
-                    "unknown_licenses": 0,
-                    "unapproved_files": [],
-                    "unknown_license_files": [],
-                    "warning": error_msg or "No root directory found",
+                    "message": "No root directory found in archive",
                     "errors": [],
                 }
 
-            extract_dir = os.path.join(temp_dir, root_dir)
-
-            # Extract the archive to the temporary directory
-            _LOGGER.info(f"Extracting {artifact_path} to {temp_dir}")
-            extracted_size = archives.targz_extract(
-                artifact_path, temp_dir, max_size=max_extract_size, chunk_size=chunk_size
-            )
-            _LOGGER.info(f"Extracted {extracted_size} bytes")
+            _LOGGER.info(f"Using root directory: {extract_dir}")
 
             # Execute RAT and get results or error
             error_result, xml_output_path = _check_core_logic_execute_rat(rat_jar_path, extract_dir, temp_dir)
@@ -180,28 +148,14 @@ def _check_core_logic(
                 return error_result
 
             # Parse the XML output
-            try:
-                _LOGGER.info(f"Parsing RAT XML output: {xml_output_path}")
-                # Make sure xml_output_path is not None before parsing
-                if xml_output_path is None:
-                    raise ValueError("XML output path is None")
-
-                results = _check_core_logic_parse_output(xml_output_path, extract_dir)
-                _LOGGER.info(f"Successfully parsed RAT output with {results.get('total_files', 0)} files")
-                return results
-            except Exception as e:
-                _LOGGER.error(f"Error parsing RAT output: {e}")
-                return {
-                    "valid": False,
-                    "message": f"Failed to parse Apache RAT output: {e!s}",
-                    "total_files": 0,
-                    "approved_licenses": 0,
-                    "unapproved_licenses": 0,
-                    "unknown_licenses": 0,
-                    "unapproved_files": [],
-                    "unknown_license_files": [],
-                    "errors": [f"Parse error: {e}"],
-                }
+            _LOGGER.info(f"Parsing RAT XML output: {xml_output_path}")
+            # Make sure xml_output_path is not None before parsing
+            if xml_output_path is None:
+                raise ValueError("XML output path is None")
+
+            results = _check_core_logic_parse_output(xml_output_path, extract_dir)
+            _LOGGER.info(f"Successfully parsed RAT output with {results.get('total_files', 0)} files")
+            return results
 
     except Exception as e:
         import traceback
@@ -475,3 +429,63 @@ with unapproved licenses, and {unknown_licenses} with unknown licenses"""
             "unknown_licenses": 0,
             "errors": [f"XML parsing error: {e!s}"],
         }
+
+
+def _check_java_installed() -> dict[str, Any] | None:
+    # Check that Java is installed
+    # TODO: Run this only once, when the server starts
+    try:
+        java_version = subprocess.check_output(
+            ["java", *_JAVA_MEMORY_ARGS, "-version"], stderr=subprocess.STDOUT, text=True
+        )
+        _LOGGER.info(f"Java version: {java_version.splitlines()[0]}")
+    except (subprocess.SubprocessError, FileNotFoundError) as e:
+        _LOGGER.error(f"Java is not properly installed or not in PATH: {e}")
+
+        # Try to get some output even if the command failed
+        try:
+            # Use run instead of check_output to avoid exceptions
+            java_result = subprocess.run(
+                ["java", *_JAVA_MEMORY_ARGS, "-version"],
+                stderr=subprocess.STDOUT,
+                stdout=subprocess.PIPE,
+                text=True,
+                check=False,
+            )
+            _LOGGER.info(f"Java command return code: {java_result.returncode}")
+            _LOGGER.info(f"Java command output: {java_result.stdout or java_result.stderr}")
+
+            # Try to find where Java might be located
+            which_java = subprocess.run(["which", "java"], capture_output=True, text=True, check=False)
+            which_java_result = which_java.stdout.strip() if (which_java.returncode == 0) else "not found"
+            _LOGGER.info(f"Result for which java: {which_java_result}")
+        except Exception as inner_e:
+            _LOGGER.error(f"Additional error while trying to debug java: {inner_e}")
+
+        return {
+            "valid": False,
+            "message": "Java is not properly installed or not in PATH",
+            "total_files": 0,
+            "approved_licenses": 0,
+            "unapproved_licenses": 0,
+            "unknown_licenses": 0,
+            "unapproved_files": [],
+            "unknown_license_files": [],
+            "errors": [f"Java error: {e}"],
+        }
+
+
+def _extracted_dir(temp_dir: str) -> str | None:
+    # Loop through all the dirs in temp_dir
+    extract_dir = None
+    for dir_name in os.listdir(temp_dir):
+        if dir_name.startswith("."):
+            continue
+        dir_path = os.path.join(temp_dir, dir_name)
+        if not os.path.isdir(dir_path):
+            raise ValueError(f"Unknown file type found in temporary directory: {dir_path}")
+        if extract_dir is None:
+            extract_dir = dir_path
+        else:
+            raise ValueError(f"Multiple root directories found: {extract_dir}, {dir_path}")
+    return extract_dir
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index 656edfe..b0454a8 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -42,7 +42,7 @@ async def integrity(args: checks.FunctionArguments) -> str | None:
 
     chunk_size = 4096
     try:
-        size = await asyncio.to_thread(archives.targz_total_size, str(artifact_abs_path), chunk_size)
+        size = await asyncio.to_thread(archives.total_size, str(artifact_abs_path), chunk_size)
         await recorder.success("Able to read all entries of the archive using tarfile", {"size": size})
     except Exception as e:
         await recorder.failure("Unable to read all entries of the archive using tarfile", {"error": str(e)})
diff --git a/atr/tasks/sbom.py b/atr/tasks/sbom.py
index d136ad1..7992ac1 100644
--- a/atr/tasks/sbom.py
+++ b/atr/tasks/sbom.py
@@ -87,7 +87,7 @@ async def _generate_cyclonedx_core(artifact_path: str, output_path: str) -> dict
         # TODO: Ideally we'd have task dependencies or archive caching
         _LOGGER.info(f"Extracting {artifact_path} to {temp_dir}")
         extracted_size = await asyncio.to_thread(
-            archives.targz_extract,
+            archives.extract,
             artifact_path,
             str(temp_dir),
             max_size=_CONFIG.MAX_EXTRACT_SIZE,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to