This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit ca1db4aee2ad8f57dd45366a8218504a61c8a482
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Feb 19 17:52:39 2026 +0000

    Change attestable hashes to dict and reuse to resolve TOCTOU of check result.
    Use attestable hashes for check reports. Add version to cache key.
    Add file hash to hash and signature check and github SHA to source_tree.
---
 atr/api/__init__.py               | 32 ++++++++------
 atr/attestable.py                 | 19 +++++----
 atr/db/__init__.py                |  3 ++
 atr/get/checks.py                 | 17 +++++---
 atr/get/result.py                 |  1 -
 atr/models/attestable.py          |  2 +-
 atr/storage/readers/checks.py     | 51 +++++++----------------
 atr/storage/readers/releases.py   | 13 +++++-
 atr/tasks/__init__.py             | 86 +++++++++++++++++++++-----------------
 atr/tasks/checks/__init__.py      | 87 +++++++++++++++++++++++++++++++--------
 atr/tasks/checks/compare.py       |  5 ++-
 atr/tasks/checks/hashing.py       |  5 ++-
 atr/tasks/checks/license.py       |  5 ++-
 atr/tasks/checks/paths.py         | 13 ++++--
 atr/tasks/checks/rat.py           |  3 +-
 atr/tasks/checks/signature.py     |  5 ++-
 atr/tasks/checks/targz.py         |  5 ++-
 atr/tasks/checks/zipformat.py     |  5 ++-
 tests/unit/recorders.py           |  2 +-
 tests/unit/test_checks_compare.py |  6 +--
 20 files changed, 224 insertions(+), 141 deletions(-)

diff --git a/atr/api/__init__.py b/atr/api/__init__.py
index 1338f00e..4a0321ce 100644
--- a/atr/api/__init__.py
+++ b/atr/api/__init__.py
@@ -29,6 +29,7 @@ import sqlalchemy
 import sqlmodel
 import werkzeug.exceptions as exceptions
 
+import atr.attestable as attestable
 import atr.blueprints.api as api
 import atr.config as config
 import atr.db as db
@@ -75,24 +76,24 @@ async def checks_list(project: str, version: str) -> DictResponse:
     # TODO: Add phase in the response, and the revision too
     _simple_check(project, version)
     # TODO: Merge with checks_list_project_version_revision
+
     async with db.session() as data:
         release_name = sql.release_name(project, version)
         release = await data.release(name=release_name).demand(exceptions.NotFound(f"Release {release_name} not found"))
-        check_results = await data.check_result(release_name=release_name).all()
-
-        revision = None
-        for check_result in check_results:
-            if revision is None:
-                revision = check_result.revision_number
-            elif revision != check_result.revision_number:
-                raise exceptions.InternalServerError("Revision mismatch")
-        if revision is None:
-            raise exceptions.InternalServerError("No revision found")
+        if not release.latest_revision_number:
+            raise exceptions.InternalServerError("No latest revision found")
+        file_path_checks = await attestable.load_checks(project, version, release.latest_revision_number)
+        if file_path_checks:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h in inner.values()]
+            ).all()
+        else:
+            check_results = []
 
         return models.api.ChecksListResults(
             endpoint="/checks/list",
             checks=check_results,
-            checks_revision=revision,
+            checks_revision=release.latest_revision_number,
             current_phase=release.phase,
         ).model_dump(), 200
 
@@ -126,7 +127,14 @@ async def checks_list_revision(project: str, version: str, revision: str) -> Dic
         if revision_result is None:
             raise exceptions.NotFound(f"Revision '{revision}' does not exist for release '{project}-{version}'")
 
-        check_results = await data.check_result(release_name=release_name, revision_number=revision).all()
+        file_path_checks = await attestable.load_checks(project, version, revision)
+        if file_path_checks:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h in inner.values()]
+            ).all()
+        else:
+            check_results = []
+
         return models.api.ChecksListResults(
             endpoint="/checks/list",
             checks=check_results,
diff --git a/atr/attestable.py b/atr/attestable.py
index 50260093..08404503 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -98,7 +98,7 @@ async def load_checks(
     project_name: str,
     version_name: str,
     revision_number: str,
-) -> list[int] | None:
+) -> dict[str, dict[str, str]] | None:
     file_path = attestable_checks_path(project_name, version_name, revision_number)
     if await aiofiles.os.path.isfile(file_path):
         try:
@@ -107,7 +107,7 @@
             return models.AttestableChecksV1.model_validate(data).checks
         except (json.JSONDecodeError, pydantic.ValidationError) as e:
             log.warning(f"Could not parse {file_path}: {e}")
-    return []
+    return {}
 
 
 def migrate_to_paths_files() -> int:
@@ -182,18 +182,21 @@ async def write_checks_data(
     project_name: str,
     version_name: str,
     revision_number: str,
-    checks: list[int],
+    rel_path: str,
+    checks: dict[str, str],
 ) -> None:
-    log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}")
+    log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}/{rel_path}: {checks}")
 
     def modify(content: str) -> str:
         try:
             current = AttestableChecksV1.model_validate_json(content).checks
         except pydantic.ValidationError:
-            current = []
-        new_checks = set(current or [])
-        new_checks.update(checks)
-        result = models.AttestableChecksV1(checks=sorted(new_checks))
+            current = {}
+        if rel_path not in current:
+            current[rel_path] = checks
+        else:
+            current[rel_path].update(checks)
+        result = models.AttestableChecksV1(checks=current)
         return result.model_dump_json(indent=2)
 
     await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify)
diff --git a/atr/db/__init__.py b/atr/db/__init__.py
index fb282628..76df7d4e 100644
--- a/atr/db/__init__.py
+++ b/atr/db/__init__.py
@@ -166,6 +166,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
         message: Opt[str] = NOT_SET,
         data: Opt[Any] = NOT_SET,
         inputs_hash: Opt[str] = NOT_SET,
+        inputs_hash_in: Opt[list[str]] = NOT_SET,
         _release: bool = False,
     ) -> Query[sql.CheckResult]:
         query = sqlmodel.select(sql.CheckResult)
@@ -196,6 +197,8 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession):
             query = query.where(sql.CheckResult.data == data)
         if is_defined(inputs_hash):
             query = query.where(sql.CheckResult.inputs_hash == inputs_hash)
+        if is_defined(inputs_hash_in):
+            query = query.where(via(sql.CheckResult.inputs_hash).in_(inputs_hash_in))
         if _release:
             query = query.options(joined_load(sql.CheckResult.release))
diff --git a/atr/get/checks.py b/atr/get/checks.py
index cc88c1c5..e6297e9d 100644
--- a/atr/get/checks.py
+++ b/atr/get/checks.py
@@ -23,6 +23,7 @@ import asfquart.base as base
 import htpy
 import quart
 
+import atr.attestable as attestable
 import atr.blueprints.get as get
 import atr.db as db
 import atr.db.interaction as interaction
@@ -232,11 +233,17 @@ async def _compute_stats(  # noqa: C901
         empty_stats = FileStats(0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
         return {p: empty_stats for p in paths}, empty_stats
 
-    async with db.session() as data:
-        check_results = await data.check_result(
-            release_name=release.name,
-            revision_number=release.latest_revision_number,
-        ).all()
+    file_path_checks = await attestable.load_checks(
+        release.project_name, release.version, release.latest_revision_number
+    )
+
+    if file_path_checks:
+        async with db.session() as data:
+            check_results = await data.check_result(
+                inputs_hash_in=[h for inner in file_path_checks.values() for h in inner.values()]
+            ).all()
+    else:
+        check_results = []
 
     for cr in check_results:
         if not cr.primary_rel_path:
diff --git a/atr/get/result.py b/atr/get/result.py
index 06bbf328..5fccc29c 100644
--- a/atr/get/result.py
+++ b/atr/get/result.py
@@ -52,7 +52,6 @@ async def data(
     check_result = await data.check_result(
         id=check_id,
-        release_name=release.name,
     ).demand(base.ASFQuartException("Check result not found", errorcode=404))
 
     payload = check_result.model_dump(mode="json", exclude={"release"})
diff --git a/atr/models/attestable.py b/atr/models/attestable.py
index 2af1cc2f..4e000984 100644
--- a/atr/models/attestable.py
+++ b/atr/models/attestable.py
@@ -29,7 +29,7 @@ class HashEntry(schema.Strict):
 
 class AttestableChecksV1(schema.Strict):
     version: Literal[1] = 1
-    checks: list[int] = schema.factory(list)
+    checks: dict[str, dict[str, str]] = schema.factory(dict)
 
 
 class AttestablePathsV1(schema.Strict):
diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py
index 1a95df2c..26bc4ac5 100644
--- a/atr/storage/readers/checks.py
+++ b/atr/storage/readers/checks.py
@@ -18,51 +18,18 @@
 # Removing this will cause circular imports
 from __future__ import annotations
 
-import importlib
 from typing import TYPE_CHECKING
 
 import atr.attestable as attestable
 import atr.db as db
-import atr.hashes as hashes
 import atr.models.sql as sql
 import atr.storage as storage
 import atr.storage.types as types
-import atr.tasks.checks as checks
 import atr.util as util
 
 if TYPE_CHECKING:
     import pathlib
-    from collections.abc import Callable, Sequence
-
-
-async def _filter_check_results_by_hash(
-    all_check_results: Sequence[sql.CheckResult],
-    rel_path: pathlib.Path,
-    input_hash_by_module: dict[str, str | None],
-    release: sql.Release,
-) -> Sequence[sql.CheckResult]:
-    filtered_check_results = []
-    if release.latest_revision_number is None:
-        raise ValueError("Release has no revision - Invalid state")
-    for cr in all_check_results:
-        if cr.checker not in input_hash_by_module:
-            module_path = cr.checker.rsplit(".", 1)[0]
-            try:
-                module = importlib.import_module(module_path)
-                policy_keys: list[str] = module.INPUT_POLICY_KEYS
-                extra_arg_names: list[str] = getattr(module, "INPUT_EXTRA_ARGS", [])
-            except (ImportError, AttributeError):
-                policy_keys = []
-                extra_arg_names = []
-            extra_args = await checks.resolve_extra_args(extra_arg_names, release)
-            cache_key = await checks.resolve_cache_key(
-                cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name
-            )
-            input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None
-
-        if cr.inputs_hash == input_hash_by_module[cr.checker]:
-            filtered_check_results.append(cr)
-    return filtered_check_results
+    from collections.abc import Callable
 
 
 class GeneralPublic:
@@ -82,18 +49,28 @@ class GeneralPublic:
         if release.latest_revision_number is None:
             raise ValueError("Release has no revision - Invalid state")
 
-        check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number)
+        file_path_checks = await attestable.load_checks(
+            release.project_name, release.version, release.latest_revision_number
+        )
         all_check_results = (
             [
                 a
-                for a in await self.__data.check_result(id_in=check_ids)
+                for a in await self.__data.check_result(
+                    inputs_hash_in=[
+                        h
+                        for key in ("", str(rel_path))
+                        if key in file_path_checks
+                        for h in file_path_checks[key].values()
+                    ],
+                    primary_rel_path=str(rel_path),
+                )
                 .order_by(
                     sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(),
                     sql.validate_instrumented_attribute(sql.CheckResult.created).desc(),
                 )
                 .all()
             ]
-            if check_ids
+            if file_path_checks
             else []
         )
diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py
index 90445082..e0da2459 100644
--- a/atr/storage/readers/releases.py
+++ b/atr/storage/readers/releases.py
@@ -123,8 +123,17 @@ class GeneralPublic:
         self, release: sql.Release, latest_revision_number: str, info: types.PathInfo
     ) -> None:
         match_ignore = await self.__read_as.checks.ignores_matcher(release.project_name)
-        check_ids = await attestable.load_checks(release.project_name, release.version, latest_revision_number)
-        attestable_checks = [a for a in await self.__data.check_result(id_in=check_ids).all()] if check_ids else []
+        file_path_checks = await attestable.load_checks(release.project_name, release.version, latest_revision_number)
+        attestable_checks = (
+            [
+                a
+                for a in await self.__data.check_result(
+                    inputs_hash_in=[h for inner in file_path_checks.values() for h in inner.values()]
+                ).all()
+            ]
+            if file_path_checks
+            else []
+        )
 
         cs = types.ChecksSubset(
             checks=attestable_checks,
diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py
index 60b5df83..3acb79df 100644
--- a/atr/tasks/__init__.py
+++ b/atr/tasks/__init__.py
@@ -63,14 +63,15 @@ async def asc_checks(
         sql.TaskType.SIGNATURE_CHECK,
         release,
         revision,
-        data,
         signature_path,
         check_cache_key=await checks.resolve_cache_key(
             resolve(sql.TaskType.SIGNATURE_CHECK),
+            signature.CHECK_VERSION,
             signature.INPUT_POLICY_KEYS,
             release,
             revision,
-            await checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release),
+            await checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release, signature_path),
+            file=signature_path,
         ),
         extra_args={"committee_name": release.committee.name},
     )
@@ -177,9 +178,9 @@ async def draft_checks(
         sql.TaskType.PATHS_CHECK,
         release,
         revision_number,
-        caller_data,
         check_cache_key=await checks.resolve_cache_key(
             resolve(sql.TaskType.PATHS_CHECK),
+            paths.CHECK_VERSION,
             paths.INPUT_POLICY_KEYS,
             release,
             revision_number,
@@ -226,7 +227,6 @@ async def _draft_file_checks(
         sql.TaskType.SBOM_TOOL_SCORE,
         release,
         revision_number,
-        caller_data,
         path_str,
         extra_args={
             "project_name": project_name,
@@ -295,33 +295,33 @@ async def queued(
     task_type: sql.TaskType,
     release: sql.Release,
     revision_number: str,
-    data: db.Session | None = None,
    primary_rel_path: str | None = None,
    extra_args: dict[str, Any] | None = None,
    check_cache_key: dict[str, Any] | None = None,
 ) -> sql.Task | None:
-    hash_val = None
+    hash_val: str | None = None
     if check_cache_key is not None:
+        if "checker" not in check_cache_key:
+            raise ValueError("Cache key must contain a 'checker' key")
         hash_val = hashes.compute_dict_hash(check_cache_key)
-        if not data:
-            raise RuntimeError("DB Session is required for check_cache_key")
-        existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all()
-        if existing:
-            await attestable.write_checks_data(
-                release.project.name, release.version, revision_number, [c.id for c in existing]
-            )
-            return None
-    return sql.Task(
-        status=sql.TaskStatus.QUEUED,
-        task_type=task_type,
-        task_args=extra_args or {},
-        asf_uid=asf_uid,
-        project_name=release.project.name,
-        version_name=release.version,
-        revision_number=revision_number,
-        primary_rel_path=primary_rel_path,
-        inputs_hash=hash_val,
+        await attestable.write_checks_data(
+            release.project.name,
+            release.version,
+            revision_number,
+            primary_rel_path or "",
+            {check_cache_key["checker"]: hash_val},
     )
+    return sql.Task(
+        status=sql.TaskStatus.QUEUED,
+        task_type=task_type,
+        task_args=extra_args or {},
+        asf_uid=asf_uid,
+        project_name=release.project.name,
+        version_name=release.version,
+        revision_number=revision_number,
+        primary_rel_path=primary_rel_path,
+        inputs_hash=hash_val,
+    )
 
 
 async def _add_task(data: db.Session, task: sql.Task) -> None:
@@ -405,14 +405,14 @@ async def sha_checks(
         sql.TaskType.HASHING_CHECK,
         release,
         revision,
-        data,
         hash_file,
         check_cache_key=await checks.resolve_cache_key(
             resolve(sql.TaskType.HASHING_CHECK),
+            hashing.CHECK_VERSION,
             hashing.INPUT_POLICY_KEYS,
             release,
             revision,
-            await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release),
+            await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release, hash_file),
             file=hash_file,
         ),
     )
@@ -429,6 +429,7 @@ async def tar_gz_checks(
     is_podling = (release.project.committee is not None) and release.project.committee.is_podling
     compare_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.COMPARE_SOURCE_TREES),
+        compare.CHECK_VERSION,
         compare.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -437,6 +438,7 @@
     license_h_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_HEADERS),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -445,6 +447,7 @@
     license_f_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_FILES),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -453,6 +456,7 @@
     rat_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.RAT_CHECK),
+        rat.CHECK_VERSION,
         rat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -461,6 +465,7 @@
     targz_i_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.TARGZ_INTEGRITY),
+        targz.CHECK_VERSION,
         targz.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -469,6 +474,7 @@
     targz_s_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.TARGZ_STRUCTURE),
+        targz.CHECK_VERSION,
         targz.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -476,21 +482,20 @@
         file=path,
     )
     tasks = [
-        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck),
+        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path, check_cache_key=compare_ck),
         queued(
             asf_uid,
             sql.TaskType.LICENSE_FILES,
             release,
             revision,
-            data,
             path,
             check_cache_key=license_f_ck,
             extra_args={"is_podling": is_podling},
         ),
-        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck),
-        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck),
-        queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, path, check_cache_key=targz_i_ck),
-        queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, path, check_cache_key=targz_s_ck),
+        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path, check_cache_key=license_h_ck),
+        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path, check_cache_key=rat_ck),
+        queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, path, check_cache_key=targz_i_ck),
+        queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, path, check_cache_key=targz_s_ck),
     ]
     return await asyncio.gather(*tasks)
@@ -531,6 +536,7 @@ async def zip_checks(
     is_podling = (release.project.committee is not None) and release.project.committee.is_podling
     compare_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.COMPARE_SOURCE_TREES),
+        compare.CHECK_VERSION,
         compare.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -539,6 +545,7 @@
     license_h_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_HEADERS),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -547,6 +554,7 @@
     license_f_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.LICENSE_FILES),
+        license.CHECK_VERSION,
         license.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -555,6 +563,7 @@
     rat_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.RAT_CHECK),
+        rat.CHECK_VERSION,
         rat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -563,6 +572,7 @@
     zip_i_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.ZIPFORMAT_INTEGRITY),
+        zipformat.CHECK_VERSION,
         zipformat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -571,6 +581,7 @@
     zip_s_ck = await checks.resolve_cache_key(
         resolve(sql.TaskType.ZIPFORMAT_STRUCTURE),
+        zipformat.CHECK_VERSION,
         zipformat.INPUT_POLICY_KEYS,
         release,
         revision,
@@ -579,21 +590,20 @@
     )
     tasks = [
-        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck),
+        queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path, check_cache_key=compare_ck),
         queued(
             asf_uid,
             sql.TaskType.LICENSE_FILES,
             release,
             revision,
-            data,
             path,
             check_cache_key=license_f_ck,
             extra_args={"is_podling": is_podling},
         ),
-        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck),
-        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck),
-        queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, data, path, check_cache_key=zip_i_ck),
-        queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, data, path, check_cache_key=zip_s_ck),
+        queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path, check_cache_key=license_h_ck),
+        queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path, check_cache_key=rat_ck),
+        queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path, check_cache_key=zip_i_ck),
+        queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path, check_cache_key=zip_s_ck),
     ]
     return await asyncio.gather(*tasks)
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index db3616a9..5ece2282 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -20,10 +20,12 @@ from __future__ import annotations
 import dataclasses
 import datetime
 import functools
+import json
 from typing import TYPE_CHECKING, Any, Final
 
 import aiofiles
 import aiofiles.os
+import pydantic
 import sqlmodel
 
 if TYPE_CHECKING:
@@ -39,6 +41,7 @@ import atr.file_paths as file_paths
 import atr.hashes as hashes
 import atr.log as log
 import atr.models.sql as sql
+import atr.sbom.models.github as github_models
 import atr.util as util
 
 
@@ -104,7 +107,15 @@ class Recorder:
         member_rel_path: str | None = None,
         afresh: bool = True,
     ) -> Recorder:
-        recorder = cls(checker, project_name, version_name, revision_number, primary_rel_path, member_rel_path, afresh)
+        recorder = cls(
+            checker,
+            project_name,
+            version_name,
+            revision_number,
+            primary_rel_path,
+            member_rel_path,
+            afresh,
+        )
         if afresh is True:
             # Clear outer path whether it's specified or not
             await recorder.clear(primary_rel_path=primary_rel_path, member_rel_path=member_rel_path)
@@ -198,7 +209,11 @@ class Recorder:
         return matches(str(abs_path))
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None
+        self,
+        policy_keys: list[str],
+        version,
+        input_args: list[str] | None = None,
+        checker: str | None = None,
     ) -> bool:
         # TODO: Should this just be in the constructor?
@@ -216,9 +231,15 @@ class Recorder:
             release = await data.release(
                 name=self.release_name, _release_policy=True, _project_release_policy=True, _project=True
             ).demand(RuntimeError(f"Release {self.release_name} not found"))
-            args = await resolve_extra_args(input_args or [], release)
+            args = await resolve_extra_args(input_args or [], release, self.primary_rel_path)
             cache_key = await resolve_cache_key(
-                checker or self.checker, policy_keys, release, self.revision_number, args, file=self.primary_rel_path
+                checker or self.checker,
+                version,
+                policy_keys,
+                release,
+                self.revision_number,
+                args,
+                file=self.primary_rel_path,
             )
         self.__input_hash = hashes.compute_dict_hash(cache_key) if cache_key else None
         return True
@@ -257,7 +278,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id])
         return result
 
     async def exception(
@@ -274,7 +294,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id])
         return result
 
     async def failure(
@@ -291,7 +310,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
        )
-        await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id])
         return result
 
     async def success(
@@ -308,7 +326,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id])
         return result
 
     async def use_check_cache(self) -> bool:
@@ -337,7 +354,6 @@ class Recorder:
             primary_rel_path=primary_rel_path,
             member_rel_path=member_rel_path,
         )
-        await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id])
         return result
 
 
@@ -347,6 +363,7 @@ def function_key(func: Callable[..., Any] | str) -> str:
 
 async def resolve_cache_key(
     checker: str | Callable[..., Any],
+    checker_version: str,
     policy_keys: list[str],
     release: sql.Release,
     revision: str,
@@ -381,7 +398,7 @@
     return {**cache_key, **args}
 
 
-async def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, Any]:
+async def resolve_extra_args(arg_names: list[str], release: sql.Release, rel_path: str | None = None) -> dict[str, Any]:
     result: dict[str, Any] = {}
     for name in arg_names:
         resolver = _EXTRA_ARG_RESOLVERS.get(name, None)
@@ -389,7 +406,7 @@
         if resolver is None:
             log.warning(f"Unknown extra arg resolver: {name}")
             return {}
-        result[name] = await resolver(release)
+        result[name] = await resolver(release, rel_path)
     return result
 
 
@@ -407,7 +424,7 @@ def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Calla
     return decorator
 
 
-async def _resolve_all_files(release: sql.Release) -> list[str]:
+async def _resolve_all_files(release: sql.Release, rel_path: str | None = None) -> list[str]:
     if not release.latest_revision_number:
         return []
     if not (
@@ -422,21 +439,57 @@
         return []
     relative_paths = [p async for p in util.paths_recursive(base_path)]
     relative_paths_set = set(str(p) for p in relative_paths)
-    return list(relative_paths_set)
+    return list(sorted(relative_paths_set))
 
 
-async def _resolve_is_podling(release: sql.Release) -> bool:
+async def _resolve_is_podling(release: sql.Release, rel_path: str | None = None) -> bool:
     return (release.committee is not None) and release.committee.is_podling
 
 
-async def _resolve_committee_name(release: sql.Release) -> str:
+async def _resolve_github_tp_sha(release: sql.Release, rel_path: str | None = None) -> str:
+    if not release.latest_revision_number:
+        return ""
+    payload_path = attestable.github_tp_payload_path(
+        release.project_name, release.version, release.latest_revision_number
+    )
+    if not await aiofiles.os.path.isfile(payload_path):
+        return ""
+    try:
+        async with aiofiles.open(payload_path, encoding="utf-8") as f:
+            data = json.loads(await f.read())
+        if not isinstance(data, dict):
+            log.warning(f"TP payload was not a JSON object in {payload_path}")
+            return ""
+        tp_data = github_models.TrustedPublisherPayload.model_validate(data)
+        return tp_data.sha
+    except (OSError, json.JSONDecodeError) as e:
+        log.warning(f"Failed to read TP payload from {payload_path}: {e}")
+        return ""
+    except pydantic.ValidationError as e:
+        log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
+        return ""
+
+
+async def _resolve_committee_name(release: sql.Release, rel_path: str | None = None) -> str:
     if release.committee is None:
         raise ValueError("Release has no committee")
     return release.committee.name
 
 
-_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release], Any]]] = {
+async def _resolve_unsuffixed_file_hash(release: sql.Release, rel_path: str | None = None) -> str:
+    if (not rel_path) or (not release.latest_revision_number):
+        return ""
+    abs_path = file_paths.revision_path_for_file(
+        release.project_name, release.version, release.latest_revision_number, rel_path
+    )
+    plain_path = abs_path.with_suffix("")
+    return await hashes.compute_file_hash(plain_path)
+
+
+_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release, str | None], Any]]] = {
     "all_files": _resolve_all_files,
-    "is_podling": _resolve_is_podling,
     "committee_name": _resolve_committee_name,
+    "github_tp_sha": _resolve_github_tp_sha,
+    "is_podling": _resolve_is_podling,
+    "unsuffixed_file_hash": _resolve_unsuffixed_file_hash,
 }
diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index 162c9673..70f8daf1 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -53,7 +53,8 @@ _PERMITTED_ADDED_PATHS: Final[dict[str, list[str]]] = {
 }
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = []
+INPUT_EXTRA_ARGS: Final[list[str]] = ["github_tp_sha"]
+CHECK_VERSION: Final[str] = "1"
 
 
 @dataclasses.dataclass
@@ -93,7 +94,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None
         )
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number)
     checkout_dir: str | None = None
diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py
index cb3c8f77..60ef085f 100644
--- a/atr/tasks/checks/hashing.py
+++ b/atr/tasks/checks/hashing.py
@@ -27,7 +27,8 @@ import atr.tasks.checks as checks
 
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = []
+INPUT_EXTRA_ARGS: Final[list[str]] = ["unsuffixed_file_hash"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -41,7 +42,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm})
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     # Remove the hash file suffix to get the artifact path
     # This replaces the last suffix, which is what we want
diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py
index 49b20a7a..fecb3fa9 100644
--- a/atr/tasks/checks/license.py
+++ b/atr/tasks/checks/license.py
@@ -82,6 +82,7 @@ INCLUDED_PATTERNS: Final[list[str]] = [
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = [""]
 INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"]
+CHECK_VERSION: Final[str] = "1"
 
 
 # Types
@@ -139,7 +140,7 @@ async def files(args: checks.FunctionArguments) -> results.Results | None:
         return None
     is_podling = args.extra_args.get("is_podling", False)
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -172,7 +173,7 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None:
     if project.policy_license_check_mode == sql.LicenseCheckMode.RAT:
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     # if await recorder.check_cache(artifact_abs_path):
     #     log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})")
diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py
index d37fc66b..c28900a8 100644
--- a/atr/tasks/checks/paths.py
+++ b/atr/tasks/checks/paths.py
@@ -40,6 +40,7 @@ _ALLOWED_TOP_LEVEL: Final = frozenset(
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling", "all_files"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -89,9 +90,15 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
     relative_paths = [p async for p in util.paths_recursive(base_path)]
     relative_paths_set = set(str(p) for p in relative_paths)
 
-    await recorder_errors.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
-    await recorder_warnings.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
-    await recorder_success.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check))
+    await recorder_errors.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
+    await recorder_warnings.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
+    await recorder_success.cache_key_set(
+        INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS, checker=checks.function_key(check)
+    )
 
     for relative_path in relative_paths:
         # Delegate processing of each path to the helper function
diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py
index 9c065ec7..486ad18c 100644
--- a/atr/tasks/checks/rat.py
+++ b/atr/tasks/checks/rat.py
@@ -68,6 +68,7 @@ _STD_EXCLUSIONS_EXTENDED: Final[list[str]] = [
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 class RatError(RuntimeError):
@@ -88,7 +89,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)")
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py
index 830d9cfe..d6c88f2e 100644
--- a/atr/tasks/checks/signature.py
+++ b/atr/tasks/checks/signature.py
@@ -32,7 +32,8 @@ import atr.util as util
 
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
-INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name"]
+INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name", "unsuffixed_file_hash"]
+CHECK_VERSION: Final[str] = "1"
 
 
 async def check(args: checks.FunctionArguments) -> results.Results | None:
@@ -54,7 +55,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None:
         await recorder.exception("Committee name is required", {"committee_name": committee_name})
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(
         f"Checking signature {primary_abs_path} for {artifact_abs_path}"
diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py
index 7b369e5e..2415135d 100644
--- a/atr/tasks/checks/targz.py
+++ b/atr/tasks/checks/targz.py
@@ -28,6 +28,7 @@ import atr.util as util
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 class RootDirectoryError(Exception):
@@ -42,7 +43,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     if not (artifact_abs_path := await recorder.abs_path()):
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -101,7 +102,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None:
     if await recorder.primary_path_is_binary():
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     filename = artifact_abs_path.name
     basename_from_filename: Final[str] = (
diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py
index ed85f7b2..516f662f 100644
--- a/atr/tasks/checks/zipformat.py
+++ b/atr/tasks/checks/zipformat.py
@@ -29,6 +29,7 @@ import atr.util as util
 # Release policy fields which this check relies on - used for result caching
 INPUT_POLICY_KEYS: Final[list[str]] = []
 INPUT_EXTRA_ARGS: Final[list[str]] = []
+CHECK_VERSION: Final[str] = "1"
 
 
 async def integrity(args: checks.FunctionArguments) -> results.Results | None:
@@ -37,7 +38,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None:
     if not (artifact_abs_path := await recorder.abs_path()):
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
@@ -63,7 +64,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None:
     if await recorder.primary_path_is_binary():
         return None
 
-    await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS)
+    await recorder.cache_key_set(INPUT_POLICY_KEYS, CHECK_VERSION, INPUT_EXTRA_ARGS)
 
     log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})")
 
diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py
index bdb20dcf..5d5ce29f 100644
--- a/tests/unit/recorders.py
+++ b/tests/unit/recorders.py
@@ -42,7 +42,7 @@ class RecorderStub(checks.Recorder):
         return self._path if (rel_path is None) else self._path / rel_path
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None
+        self, policy_keys: list[str], version: str, input_args: list[str] | None = None, checker: str | None = None
     ) -> bool:
         return False
 
diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py
index e90e3314..5ec9b751 100644
--- a/tests/unit/test_checks_compare.py
+++ b/tests/unit/test_checks_compare.py
@@ -243,7 +243,7 @@ class RecorderStub(atr.tasks.checks.Recorder):
         self._is_source = is_source
 
     async def cache_key_set(
-        self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None
+        self, policy_keys: list[str], version: str, input_args: list[str] | None = None, checker: str | None = None
     ) -> bool:
         return False
 
@@ -272,8 +272,8 @@ class RecorderStub(atr.tasks.checks.Recorder):
     ) -> atr.models.sql.CheckResult:
         self.success_calls.append((message, data))
         return atr.models.sql.CheckResult(
-            release_name=self.release_name,
-            revision_number=self.revision_number,
+            release_name=None,
+            revision_number=None,
             checker=self.checker,
             primary_rel_path=primary_rel_path or self.primary_rel_path,
             member_rel_path=member_rel_path,

---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
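For readers following the commit message, the core change is that AttestableChecksV1.checks becomes a mapping from a file's primary_rel_path (with "" used for checks queued without one) to a mapping of checker key to inputs hash, and readers then query check results by those hashes rather than by stored row ids. A minimal sketch of what such a payload might look like; the rel_paths, checker keys, and hash values below are hypothetical, only the shape follows the new model:

```python
# Illustrative sketch only: paths, checker keys, and hashes are made up; the
# structure mirrors AttestableChecksV1 (checks: dict[str, dict[str, str]]).
import json

attestable_checks = {
    "version": 1,
    "checks": {
        # "" collects checks queued without a primary_rel_path (release-wide checks)
        "": {
            "atr.tasks.checks.paths.check": "0f3c1d...",  # hypothetical inputs hash
        },
        "apache-example-1.2.3.tar.gz": {
            "atr.tasks.checks.targz.integrity": "9b7e42...",
            "atr.tasks.checks.license.files": "c41a88...",
        },
    },
}

# Readers such as the /checks/list endpoint flatten the inner dicts into a list
# of hashes and query CheckResult rows with inputs_hash_in=[...].
hashes_to_query = [h for inner in attestable_checks["checks"].values() for h in inner.values()]

print(json.dumps(attestable_checks, indent=2))
print(hashes_to_query)
```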
