This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git

commit ca1db4aee2ad8f57dd45366a8218504a61c8a482
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Feb 19 17:52:39 2026 +0000

    Change attestable hashes to a dict and reuse them to resolve the TOCTOU
    issue with check results. Use attestable hashes for check reports. Add the
    check version to the cache key. Add the file hash to the hash and signature
    checks, and the GitHub SHA to source_tree.
---
 atr/api/__init__.py               | 32 ++++++++------
 atr/attestable.py                 | 19 +++++----
 atr/db/__init__.py                |  3 ++
 atr/get/checks.py                 | 17 +++++---
 atr/get/result.py                 |  1 -
 atr/models/attestable.py          |  2 +-
 atr/storage/readers/checks.py     | 51 +++++++----------------
 atr/storage/readers/releases.py   | 13 +++++-
 atr/tasks/__init__.py             | 86 +++++++++++++++++++++-----------------
 atr/tasks/checks/__init__.py      | 87 +++++++++++++++++++++++++++++++--------
 atr/tasks/checks/compare.py       |  5 ++-
 atr/tasks/checks/hashing.py       |  5 ++-
 atr/tasks/checks/license.py       |  5 ++-
 atr/tasks/checks/paths.py         | 13 ++++--
 atr/tasks/checks/rat.py           |  3 +-
 atr/tasks/checks/signature.py     |  5 ++-
 atr/tasks/checks/targz.py         |  5 ++-
 atr/tasks/checks/zipformat.py     |  5 ++-
 tests/unit/recorders.py           |  2 +-
 tests/unit/test_checks_compare.py |  6 +--
 20 files changed, 224 insertions(+), 141 deletions(-)

diff --git a/atr/api/__init__.py b/atr/api/__init__.py
index 1338f00e..4a0321ce 100644
--- a/atr/api/__init__.py
+++ b/atr/api/__init__.py
@@ -29,6 +29,7 @@ import sqlalchemy
 import sqlmodel
 import werkzeug.exceptions as exceptions
 
+import atr.attestable as attestable
 import atr.blueprints.api as api
 import atr.config as config
 import atr.db as db
@@ -75,24 +76,24 @@ async def checks_list(project: str, version: str) -> 
DictResponse:
     # TODO: Add phase in the response, and the revision too
     _simple_check(project, version)
     # TODO: Merge with checks_list_project_version_revision
+
     async with db.session() as data:
         release_name = sql.release_name(project, version)
         release = await 
data.release(name=release_name).demand(exceptions.NotFound(f"Release 
{release_name} not found"))
-        check_results = await 
data.check_result(release_name=release_name).all()
-
-    revision = None
-    for check_result in check_results:
-        if revision is None:
-            revision = check_result.revision_number
-        elif revision != check_result.revision_number:
-            raise exceptions.InternalServerError("Revision mismatch")
-    if revision is None:
-        raise exceptions.InternalServerError("No revision found")
+        if not release.latest_revision_number:
+            raise exceptions.InternalServerError("No latest revision found")
+        file_path_checks = await attestable.load_checks(project, version, 
release.latest_revision_number)
+        if file_path_checks:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h 
in inner.values()]
+            ).all()
+        else:
+            check_results = []
 
     return models.api.ChecksListResults(
         endpoint="/checks/list",
         checks=check_results,
-        checks_revision=revision,
+        checks_revision=release.latest_revision_number,
         current_phase=release.phase,
     ).model_dump(), 200
 
@@ -126,7 +127,14 @@ async def checks_list_revision(project: str, version: str, 
revision: str) -> Dic
         if revision_result is None:
             raise exceptions.NotFound(f"Revision '{revision}' does not exist 
for release '{project}-{version}'")
 
-        check_results = await data.check_result(release_name=release_name, 
revision_number=revision).all()
+        file_path_checks = await attestable.load_checks(project, version, 
revision)
+        if file_path_checks:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h 
in inner.values()]
+            ).all()
+        else:
+            check_results = []
+
     return models.api.ChecksListResults(
         endpoint="/checks/list",
         checks=check_results,
diff --git a/atr/attestable.py b/atr/attestable.py
index 50260093..08404503 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -98,7 +98,7 @@ async def load_checks(
     project_name: str,
     version_name: str,
     revision_number: str,
-) -> list[int] | None:
+) -> dict[str, dict[str, str]] | None:
     file_path = attestable_checks_path(project_name, version_name, 
revision_number)
     if await aiofiles.os.path.isfile(file_path):
         try:
@@ -107,7 +107,7 @@ async def load_checks(
             return models.AttestableChecksV1.model_validate(data).checks
         except (json.JSONDecodeError, pydantic.ValidationError) as e:
             log.warning(f"Could not parse {file_path}: {e}")
-    return []
+    return {}
 
 
 def migrate_to_paths_files() -> int:
@@ -182,18 +182,21 @@ async def write_checks_data(
     project_name: str,
     version_name: str,
     revision_number: str,
-    checks: list[int],
+    rel_path: str,
+    checks: dict[str, str],
 ) -> None:
-    log.info(f"Writing checks for 
{project_name}/{version_name}/{revision_number}: {checks}")
+    log.info(f"Writing checks for 
{project_name}/{version_name}/{revision_number}/{rel_path}: {checks}")
 
     def modify(content: str) -> str:
         try:
             current = AttestableChecksV1.model_validate_json(content).checks
         except pydantic.ValidationError:
-            current = []
-        new_checks = set(current or [])
-        new_checks.update(checks)
-        result = models.AttestableChecksV1(checks=sorted(new_checks))
+            current = {}
+        if rel_path not in current:
+            current[rel_path] = checks
+        else:
+            current[rel_path].update(checks)
+        result = models.AttestableChecksV1(checks=current)
         return result.model_dump_json(indent=2)
 
     await util.atomic_modify_file(attestable_checks_path(project_name, 
version_name, revision_number), modify)
diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index fb282628..76df7d4e 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -166,6 +166,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
         message: Opt[str] = NOT_SET,
         data: Opt[Any] = NOT_SET,
         inputs_hash: Opt[str] = NOT_SET,
+        inputs_hash_in: Opt[list[str]] = NOT_SET,
         _release: bool = False,
     ) -> Query[sql.CheckResult]:
         query = sqlmodel.select(sql.CheckResult)
@@ -196,6 +197,8 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
             query = query.where(sql.CheckResult.data == data)
         if is_defined(inputs_hash):
             query = query.where(sql.CheckResult.inputs_hash == inputs_hash)
+        if is_defined(inputs_hash_in):
+            query = 
query.where(via(sql.CheckResult.inputs_hash).in_(inputs_hash_in))
 
         if _release:
             query = query.options(joined_load(sql.CheckResult.release))
diff --git a/atr/get/checks.py b/atr/get/checks.py
index cc88c1c5..e6297e9d 100644
--- a/atr/get/checks.py
+++ b/atr/get/checks.py
@@ -23,6 +23,7 @@ import asfquart.base as base
 import htpy
 import quart
 
+import atr.attestable as attestable
 import atr.blueprints.get as get
 import atr.db as db
 import atr.db.interaction as interaction
@@ -232,11 +233,17 @@ async def _compute_stats(  # noqa: C901
         empty_stats = FileStats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
         return {p: empty_stats for p in paths}, empty_stats
 
-    async with db.session() as data:
-        check_results = await data.check_result(
-            release_name=release.name,
-            revision_number=release.latest_revision_number,
-        ).all()
+    file_path_checks = await attestable.load_checks(
+        release.project_name, release.version, release.latest_revision_number
+    )
+
+    if file_path_checks:
+        async with db.session() as data:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h 
in inner.values()]
+            ).all()
+    else:
+        check_results = []
 
     for cr in check_results:
         if not cr.primary_rel_path:
diff --git a/atr/get/result.py b/atr/get/result.py
index 06bbf328..5fccc29c 100644
--- a/atr/get/result.py
+++ b/atr/get/result.py
@@ -52,7 +52,6 @@ async def data(
 
         check_result = await data.check_result(
             id=check_id,
-            release_name=release.name,
         ).demand(base.ASFQuartException("Check result not found", 
errorcode=404))
 
     payload = check_result.model_dump(mode="json", exclude={"release"})
diff --git a/atr/models/attestable.py b/atr/models/attestable.py
index 2af1cc2f..4e000984 100644
--- a/atr/models/attestable.py
+++ b/atr/models/attestable.py
@@ -29,7 +29,7 @@ class HashEntry(schema.Strict):
 
 class AttestableChecksV1(schema.Strict):
     version: Literal[1] = 1
-    checks: list[int] = schema.factory(list)
+    checks: dict[str, dict[str, str]] = schema.factory(dict)
 
 
 class AttestablePathsV1(schema.Strict):
diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py
index 1a95df2c..26bc4ac5 100644
--- a/atr/storage/readers/checks.py
+++ b/atr/storage/readers/checks.py
@@ -18,51 +18,18 @@
 # Removing this will cause circular imports
 from __future__ import annotations
 
-import importlib
 from typing import TYPE_CHECKING
 
 import atr.attestable as attestable
 import atr.db as db
-import atr.hashes as hashes
 import atr.models.sql as sql
 import atr.storage as storage
 import atr.storage.types as types
-import atr.tasks.checks as checks
 import atr.util as util
 
 if TYPE_CHECKING:
     import pathlib
-    from collections.abc import Callable, Sequence
-
-
-async def _filter_check_results_by_hash(
-    all_check_results: Sequence[sql.CheckResult],
-    rel_path: pathlib.Path,
-    input_hash_by_module: dict[str, str | None],
-    release: sql.Release,
-) -> Sequence[sql.CheckResult]:
-    filtered_check_results = []
-    if release.latest_revision_number is None:
-        raise ValueError("Release has no revision - Invalid state")
-    for cr in all_check_results:
-        if cr.checker not in input_hash_by_module:
-            module_path = cr.checker.rsplit(".", 1)[0]
-            try:
-                module = importlib.import_module(module_path)
-                policy_keys: list[str] = module.INPUT_POLICY_KEYS
-                extra_arg_names: list[str] = getattr(module, 
"INPUT_EXTRA_ARGS", [])
-            except (ImportError, AttributeError):
-                policy_keys = []
-                extra_arg_names = []
-            extra_args = await checks.resolve_extra_args(extra_arg_names, 
release)
-            cache_key = await checks.resolve_cache_key(
-                cr.checker, policy_keys, release, 
release.latest_revision_number, extra_args, file=rel_path.name
-            )
-            input_hash_by_module[cr.checker] = 
hashes.compute_dict_hash(cache_key) if cache_key else None
-
-        if cr.inputs_hash == input_hash_by_module[cr.checker]:
-            filtered_check_results.append(cr)
-    return filtered_check_results
+    from collections.abc import Callable
 
 
 class GeneralPublic:
@@ -82,18 +49,28 @@ class GeneralPublic:
         if release.latest_revision_number is None:
             raise ValueError("Release has no revision - Invalid state")
 
-        check_ids = await attestable.load_checks(release.project_name, 
release.version, release.latest_revision_number)
+        file_path_checks = await attestable.load_checks(
+            release.project_name, release.version, 
release.latest_revision_number
+        )
         all_check_results = (
             [
                 a
-                for a in await self.__data.check_result(id_in=check_ids)
+                for a in await self.__data.check_result(
+                    inputs_hash_in=[
+                        h
+                        for key in ("", str(rel_path))
+                        if key in file_path_checks
+                        for h in file_path_checks[key].values()
+                    ],
+                    primary_rel_path=str(rel_path),
+                )
                 .order_by(
                     
sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
                     
sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
                 )
                 .all()
             ]
-            if check_ids
+            if file_path_checks
             else []
         )
 
diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py
index 90445082..e0da2459 100644
--- a/atr/storage/readers/releases.py
+++ b/atr/storage/readers/releases.py
@@ -123,8 +123,17 @@ class GeneralPublic:
         self, release: sql.Release, latest_revision_number: str, info: 
types.PathInfo
     ) -> None:
         match_ignore = await 
self.__read_as.checks.ignores_matcher(release.project_name)
-        check_ids = await attestable.load_checks(release.project_name, 
release.version, latest_revision_number)
-        attestable_checks = [a for a in await 
self.__data.check_result(id_in=check_ids).all()] if check_ids else []
+        file_path_checks = await attestable.load_checks(release.project_name, 
release.version, latest_revision_number)
+        attestable_checks = (
+            [
+                a
+                for a in await self.__data.check_result(
+                    inputs_hash_in=[h for inner in file_path_checks.values() 
for h in inner.values()]
+                ).all()
+            ]
+            if file_path_checks
+            else []
+        )
 
         cs = types.ChecksSubset(
             checks=attestable_checks,
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 60b5df83..3acb79df 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -63,14 +63,15 @@ async def asc_checks(
                 sql.TaskType.SIGNATURE_CHECK,
                 release,
                 revision,
-                data,
                 signature_path,
                 check_cache_key=await checks.resolve_cache_key(
                     resolve(sql.TaskType.SIGNATURE_CHECK),
+                    signature.CHECK_VERSION,
                     signature.INPUT_POLICY_KEYS,
                     release,
                     revision,
-                    await 
checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release),
+                    await 
checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release, signature_path),
+                    file=signature_path,
                 ),
                 extra_args={"committee_name": release.committee.name},
             )
@@ -177,9 +178,9 @@ async def draft_checks(
             sql.TaskType.PATHS_CHECK,
             release,
             revision_number,
-            caller_data,
             check_cache_key=await checks.resolve_cache_key(
                 resolve(sql.TaskType.PATHS_CHECK),
+                paths.CHECK_VERSION,
                 paths.INPUT_POLICY_KEYS,
                 release,
                 revision_number,
@@ -226,7 +227,6 @@ async def _draft_file_checks(
             sql.TaskType.SBOM_TOOL_SCORE,
             release,
             revision_number,
-            caller_data,
             path_str,
             extra_args={
                 "project_name": project_name,
@@ -295,33 +295,33 @@ async def queued(
     task_type: sql.TaskType,
     release: sql.Release,
     revision_number: str,
-    data: db.Session | None = None,
     primary_rel_path: str | None = None,
     extra_args: dict[str, Any] | None = None,
     check_cache_key: dict[str, Any] | None = None,
 ) -> sql.Task | None:
-    hash_val = None
+    hash_val: str | None = None
     if check_cache_key is not None:
+        if "checker" not in check_cache_key:
+            raise ValueError("Cache key must contain a 'checker' key")
         hash_val = hashes.compute_dict_hash(check_cache_key)
-        if not data:
-            raise RuntimeError("DB Session is required for check_cache_key")
-        existing = await data.check_result(inputs_hash=hash_val, 
release_name=release.name).all()
-        if existing:
-            await attestable.write_checks_data(
-                release.project.name, release.version, revision_number, [c.id 
for c in existing]
-            )
-            return None
-        return sql.Task(
-            status=sql.TaskStatus.QUEUED,
-            task_type=task_type,
-            task_args=extra_args or {},
-            asf_uid=asf_uid,
-            project_name=release.project.name,
-            version_name=release.version,
-            revision_number=revision_number,
-            primary_rel_path=primary_rel_path,
-            inputs_hash=hash_val,
+        await attestable.write_checks_data(
+            release.project.name,
+            release.version,
+            revision_number,
+            primary_rel_path or "",
+            {check_cache_key["checker"]: hash_val},
         )
+    return sql.Task(
+        status=sql.TaskStatus.QUEUED,
+        task_type=task_type,
+        task_args=extra_args or {},
+        asf_uid=asf_uid,
+        project_name=release.project.name,
+        version_name=release.version,
+        revision_number=revision_number,
+        primary_rel_path=primary_rel_path,
+        inputs_hash=hash_val,
+    )
 
 
 async def _add_task(data: db.Session, task: sql.Task) -> None:
@@ -405,14 +405,14 @@ async def sha_checks(
             sql.TaskType.HASHING_CHECK,
             release,
             revision,
-            data,
             hash_file,
             check_cache_key=await checks.resolve_cache_key(
                 resolve(sql.TaskType.HASHING_CHECK),
+                hashing.CHECK_VERSION,
                 hashing.INPUT_POLICY_KEYS,
                 release,
                 revision,
-                await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, 
release),
+                await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, 
release, hash_file),
                 file=hash_file,
             ),
         )
@@ -429,6 +429,7 @@ async def tar_gz_checks(
     is_podling = (release.project.committee is not None) and 
release.project.committee.is_podling
     compare_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.COMPARE_SOURCE_TREES),
+        compare.CHECK_VERSION,
         compare.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -437,6 +438,7 @@ async def tar_gz_checks(
     )
     license_h_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_HEADERS),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -445,6 +447,7 @@ async def tar_gz_checks(
     )
     license_f_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_FILES),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -453,6 +456,7 @@ async def tar_gz_checks(
     )
     rat_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.RAT_CHECK),
+        rat.CHECK_VERSION,
         rat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -461,6 +465,7 @@ async def tar_gz_checks(
     )
     targz_i_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.TARGZ_INTEGRITY),
+        targz.CHECK_VERSION,
         targz.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -469,6 +474,7 @@ async def tar_gz_checks(
     )
     targz_s_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.TARGZ_STRUCTURE),
+        targz.CHECK_VERSION,
         targz.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -476,21 +482,20 @@ async def tar_gz_checks(
         file=path,
     )
     tasks = [
-        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, 
data, path, check_cache_key=compare_ck),
+        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, 
path, check_cache_key=compare_ck),
         queued(
             asf_uid,
             sql.TaskType.LICENSE_FILES,
             release,
             revision,
-            data,
             path,
             check_cache_key=license_f_ck,
             extra_args={"is_podling": is_podling},
         ),
-        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, 
path, check_cache_key=license_h_ck),
-        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, 
check_cache_key=rat_ck),
-        queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, 
path, check_cache_key=targz_i_ck),
-        queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, 
path, check_cache_key=targz_s_ck),
+        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path, 
check_cache_key=license_h_ck),
+        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path, 
check_cache_key=rat_ck),
+        queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, path, 
check_cache_key=targz_i_ck),
+        queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, path, 
check_cache_key=targz_s_ck),
     ]
 
     return await asyncio.gather(*tasks)
@@ -531,6 +536,7 @@ async def zip_checks(
     is_podling = (release.project.committee is not None) and 
release.project.committee.is_podling
     compare_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.COMPARE_SOURCE_TREES),
+        compare.CHECK_VERSION,
         compare.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -539,6 +545,7 @@ async def zip_checks(
     )
     license_h_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_HEADERS),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -547,6 +554,7 @@ async def zip_checks(
     )
     license_f_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_FILES),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -555,6 +563,7 @@ async def zip_checks(
     )
     rat_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.RAT_CHECK),
+        rat.CHECK_VERSION,
         rat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -563,6 +572,7 @@ async def zip_checks(
     )
     zip_i_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.ZIPFORMAT_INTEGRITY),
+        zipformat.CHECK_VERSION,
         zipformat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -571,6 +581,7 @@ async def zip_checks(
     )
     zip_s_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.ZIPFORMAT_STRUCTURE),
+        zipformat.CHECK_VERSION,
         zipformat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -579,21 +590,20 @@ async def zip_checks(
     )
 
     tasks = [
-        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, 
data, path, check_cache_key=compare_ck),
+        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, 
path, check_cache_key=compare_ck),
         queued(
             asf_uid,
             sql.TaskType.LICENSE_FILES,
             release,
             revision,
-            data,
             path,
             check_cache_key=license_f_ck,
             extra_args={"is_podling": is_podling},
         ),
-        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, 
path, check_cache_key=license_h_ck),
-        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, 
check_cache_key=rat_ck),
-        queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, 
data, path, check_cache_key=zip_i_ck),
-        queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, 
data, path, check_cache_key=zip_s_ck),
+        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path, 
check_cache_key=license_h_ck),
+        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path, 
check_cache_key=rat_ck),
+        queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, 
path, check_cache_key=zip_i_ck),
+        queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, 
path, check_cache_key=zip_s_ck),
     ]
     return await asyncio.gather(*tasks)
 
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index db3616a9..5ece2282 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -20,10 +20,12 @@ from __future__ import annotations
 import dataclasses
 import datetime
 import functools
+import json
 from typing import TYPE_CHECKING, Any, Final
 
 import aiofiles
 import aiofiles.os
+import pydantic
 import sqlmodel
 
 if TYPE_CHECKING:
@@ -39,6 +41,7 @@ import atr.file_paths as file_paths
 import atr.hashes as hashes
 import atr.log as log
 import atr.models.sql as sql
+import atr.sbom.models.github as github_models
 import atr.util as util
 
 
@@ -104,7 +107,15 @@ class Recorder:
         member_rel_path: str | None = None,
         afresh: bool = True,
     ) -> Recorder:
-        recorder = cls(checker, project_name, version_name, revision_number, 
primary_rel_path, member_rel_path, afresh)
+        recorder = cls(
+            checker,
+            project_name,
+            version_name,
+            revision_number,
+            primary_rel_path,
+            member_rel_path,
+            afresh,
+        )
         if afresh is True:
             # Clear outer path whether it's specified or not
             await recorder.clear(primary_rel_path=primary_rel_path, 
member_rel_path=member_rel_path)
@@ -198,7 +209,11 @@ class Recorder:
         return matches(str(abs_path))
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, 
checker: str | None = None
+        self,
+        policy_keys: list[str],
+        version,
+        input_args: list[str] | None = None,
+        checker: str | None = None,
     ) -> bool:
         # TODO: Should this just be in the constructor?
 
@@ -216,9 +231,15 @@ class Recorder:
             release = await data.release(
                 name=self.release_name, _release_policy=True, 
_project_release_policy=True, _project=True
             ).demand(RuntimeError(f"Release {self.release_name} not found"))
-            args = await resolve_extra_args(input_args or [], release)
+            args = await resolve_extra_args(input_args or [], release, 
self.primary_rel_path)
             cache_key = await resolve_cache_key(
-                checker or self.checker, policy_keys, release, 
self.revision_number, args, file=self.primary_rel_path
+                checker or self.checker,
+                version,
+                policy_keys,
+                release,
+                self.revision_number,
+                args,
+                file=self.primary_rel_path,
             )
             self.__input_hash = hashes.compute_dict_hash(cache_key) if 
cache_key else None
         return True
@@ -257,7 +278,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, 
self.version_name, self.revision_number, [result.id])
         return result
 
     async def exception(
@@ -274,7 +294,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, 
self.version_name, self.revision_number, [result.id])
         return result
 
     async def failure(
@@ -291,7 +310,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, 
self.version_name, self.revision_number, [result.id])
         return result
 
     async def success(
@@ -308,7 +326,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, 
self.version_name, self.revision_number, [result.id])
         return result
 
     async def use_check_cache(self) -> bool:
@@ -337,7 +354,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, 
self.version_name, self.revision_number, [result.id])
         return result
 
 
@@ -347,6 +363,7 @@ def function_key(func: Callable[..., Any] | str) -> str:
 
 async def resolve_cache_key(
     checker: str | Callable[..., Any],
+    checker_version: str,
     policy_keys: list[str],
     release: sql.Release,
     revision: str,
@@ -381,7 +398,7 @@ async def resolve_cache_key(
         return {**cache_key, **args}
 
 
-async def resolve_extra_args(arg_names: list[str], release: sql.Release) -> 
dict[str, Any]:
+async def resolve_extra_args(arg_names: list[str], release: sql.Release, 
rel_path: str | None = None) -> dict[str, Any]:
     result: dict[str, Any] = {}
     for name in arg_names:
         resolver = _EXTRA_ARG_RESOLVERS.get(name, None)
@@ -389,7 +406,7 @@ async def resolve_extra_args(arg_names: list[str], release: 
sql.Release) -> dict
         if resolver is None:
             log.warning(f"Unknown extra arg resolver: {name}")
             return {}
-        result[name] = await resolver(release)
+        result[name] = await resolver(release, rel_path)
     return result
 
 
@@ -407,7 +424,7 @@ def with_model(cls: type[schema.Strict]) -> 
Callable[[Callable[..., Any]], Calla
     return decorator
 
 
-async def _resolve_all_files(release: sql.Release) -> list[str]:
+async def _resolve_all_files(release: sql.Release, rel_path: str | None = 
None) -> list[str]:
     if not release.latest_revision_number:
         return []
     if not (
@@ -422,21 +439,57 @@ async def _resolve_all_files(release: sql.Release) -> list[str]:
         return []
     relative_paths = [p async for p in util.paths_recursive(base_path)]
     relative_paths_set = set(str(p) for p in relative_paths)
-    return list(relative_paths_set)
+    return list(sorted(relative_paths_set))
 
 
-async def _resolve_is_podling(release: sql.Release) -> bool:
+async def _resolve_is_podling(release: sql.Release, rel_path: str | None = None) -> bool:
     return (release.committee is not None) and release.committee.is_podling
 
 
-async def _resolve_committee_name(release: sql.Release) -> str:
+async def _resolve_github_tp_sha(release: sql.Release, rel_path: str | None = None) -> str:
+    if not release.latest_revision_number:
+        return ""
+    payload_path = attestable.github_tp_payload_path(
+        release.project_name, release.version, release.latest_revision_number
+    )
+    if not await aiofiles.os.path.isfile(payload_path):
+        return ""
+    try:
+        async with aiofiles.open(payload_path, encoding="utf-8") as f:
+            data = json.loads(await f.read())
+        if not isinstance(data, dict):
+            log.warning(f"TP payload was not a JSON object in {payload_path}")
+            return ""
+        tp_data = github_models.TrustedPublisherPayload.model_validate(data)
+        return tp_data.sha
+    except (OSError, json.JSONDecodeError) as e:
+        log.warning(f"Failed to read TP payload from {payload_path}: {e}")
+        return ""
+    except pydantic.ValidationError as e:
+        log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
+        return ""
+
+
+async def _resolve_committee_name(release: sql.Release, rel_path: str | None = None) -> str:
     if release.committee is None:
         raise ValueError("Release has no committee")
     return release.committee.name
 
 
-_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release], Any]]] = {
+async def _resolve_unsuffixed_file_hash(release: sql.Release, rel_path: str | None = None) -> str:
+    if (not rel_path) or (not release.latest_revision_number):
+        return ""
+    abs_path = file_paths.revision_path_for_file(
+        release.project_name, release.version, release.latest_revision_number, rel_path
+    )
+    plain_path = abs_path.with_suffix("")
+    return await hashes.compute_file_hash(plain_path)
+
+
+_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release, str | None], Any]]] = {
     "all_files": _resolve_all_files,
-    "is_podling": _resolve_is_podling,
     "committee_name": _resolve_committee_name,
+    "github_tp_sha": _resolve_github_tp_sha,
+    "is_podling": _resolve_is_podling,
+    "unsuffixed_file_hash": _resolve_unsuffixed_file_hash,
 }
diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index 162c9673..70f8daf1 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -53,7 +53,8 @@ _PERMITTED_ADDED_PATHS: Final[dict[str, list[str]]] = {
 }
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = []
+INPUT_EXTRA_ARGS: Final[list[str]] = ["github_tp_sha"]
+CHECK_VERSION: Final[str] = "1"
 
 
 @dataclasses.dataclass
@@ -93,7 +94,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None
         )
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number)
     checkout_dir: str | None = None
diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py
index cb3c8f77..60ef085f 100644
--- a/atr/tasks/checks/hashing.py
+++ b/atr/tasks/checks/hashing.py
@@ -27,7 +27,8 @@ import atr.tasks.checks as checks
 
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = []
+INPUT_EXTRA_ARGS: Final[list[str]] = ["unsuffixed_file_hash"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -41,7 +42,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm})
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     # Remove the hash file suffix to get the artifact path
     # This replaces the last suffix, which is what we want
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 49b20a7a..fecb3fa9 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -82,6 +82,7 @@ INCLUDED_PATTERNS: Final[list[str]] = [
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = [""]
 INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"]
+CHECK_VERSION: Final[str] = "1"
 
 # Types
 
@@ -139,7 +140,7 @@ async def files(args: checks.FunctionArguments) -> results.Results | None:
             return None
 
     is_podling = args.extra_args.get("is_podling", False)
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -172,7 +173,7 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None:
         if project.policy_license_check_mode == sql.LicenseCheckMode.RAT:
             return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     # if await recorder.check_cache(artifact_abs_path):
     #     log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})")
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index d37fc66b..c28900a8 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -40,6 +40,7 @@ _ALLOWED_TOP_LEVEL: Final = frozenset(
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling", "all_files"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -89,9 +90,15 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
     relative_paths = [p async for p in util.paths_recursive(base_path)]
     relative_paths_set = set(str(p) for p in relative_paths)
 
-    await recorder_errors.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
-    await recorder_warnings.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
-    await recorder_success.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
+    await recorder_errors.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
+    await recorder_warnings.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
+    await recorder_success.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
 
     for relative_path in relative_paths:
         # Delegate processing of each path to the helper function
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 9c065ec7..486ad18c 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -68,6 +68,7 @@ _STD_EXCLUSIONS_EXTENDED: Final[list[str]] = [
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 class RatError(RuntimeError):
@@ -88,7 +89,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)")
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py
index 830d9cfe..d6c88f2e 100644
--- a/atr/tasks/checks/signature.py
+++ b/atr/tasks/checks/signature.py
@@ -32,7 +32,8 @@ import atr.util as util
 
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name"]
+INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name", "unsuffixed_file_hash"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -54,7 +55,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         await recorder.exception("Committee name is required", {"committee_name": committee_name})
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(
         f"Checking signature {primary_abs_path} for {artifact_abs_path}"
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index 7b369e5e..2415135d 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -28,6 +28,7 @@ import atr.util as util
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 class RootDirectoryError(Exception):
@@ -42,7 +43,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     if not (artifact_abs_path := await recorder.abs_path()):
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -101,7 +102,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None:
     if await recorder.primary_path_is_binary():
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     filename = artifact_abs_path.name
     basename_from_filename: Final[str] = (
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index ed85f7b2..516f662f 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -29,6 +29,7 @@ import atr.util as util
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 async def integrity(args: checks.FunctionArguments) -> results.Results | None:
@@ -37,7 +38,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     if not (artifact_abs_path := await recorder.abs_path()):
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -63,7 +64,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None:
     if await recorder.primary_path_is_binary():
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py
index bdb20dcf..5d5ce29f 100644
--- a/tests/unit/recorders.py
+++ b/tests/unit/recorders.py
@@ -42,7 +42,7 @@ class RecorderStub(checks.Recorder):
         return self._path if (rel_path is None) else self._path / rel_path
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None
+        self, policy_keys: list[str], version: str, input_args: list[str] | None = None, checker: str | None = None
     ) -> bool:
         return False
 
diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py
index e90e3314..5ec9b751 100644
--- a/tests/unit/test_checks_compare.py
+++ b/tests/unit/test_checks_compare.py
@@ -243,7 +243,7 @@ class RecorderStub(atr.tasks.checks.Recorder):
         self._is_source = is_source
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None
+        self, policy_keys: list[str], version: str, input_args: list[str] | None = None, checker: str | None = None
     ) -> bool:
         return False
 
@@ -272,8 +272,8 @@ class RecorderStub(atr.tasks.checks.Recorder):
     ) -> atr.models.sql.CheckResult:
         self.success_calls.append((message, data))
         return atr.models.sql.CheckResult(
-            release_name=self.release_name,
-            revision_number=self.revision_number,
+            release_name=None,
+            revision_number=None,
             checker=self.checker,
             primary_rel_path=primary_rel_path or self.primary_rel_path,
             member_rel_path=member_rel_path,


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to