This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit d3a256375b92a1e62131727532cfdf484636cca3 Author: Alastair McFarlane <[email protected]> AuthorDate: Fri Feb 13 14:53:31 2026 +0000 Read and write checks to/from attestable data --- atr/attestable.py | 52 ++++++++++++++++++++-- atr/db/__init__.py | 5 +++ atr/models/attestable.py | 5 +++ atr/storage/readers/checks.py | 44 +++++++++---------- atr/storage/readers/releases.py | 34 ++++----------- atr/storage/types.py | 3 +- atr/storage/writers/revision.py | 2 +- atr/tasks/__init__.py | 49 +++++++++++++-------- atr/tasks/checks/__init__.py | 91 ++++++++++++++++++++++++--------------- atr/tasks/checks/compare.py | 2 +- atr/tasks/checks/hashing.py | 2 +- atr/tasks/checks/license.py | 4 +- atr/tasks/checks/paths.py | 22 +++++----- atr/tasks/checks/rat.py | 2 +- atr/tasks/checks/signature.py | 2 +- atr/tasks/checks/targz.py | 4 +- atr/tasks/checks/zipformat.py | 4 +- atr/util.py | 22 ++++++++++ tests/unit/recorders.py | 26 +++++++++++ tests/unit/test_cache.py | 10 ++--- tests/unit/test_checks_compare.py | 5 +++ 21 files changed, 254 insertions(+), 136 deletions(-) diff --git a/atr/attestable.py b/atr/attestable.py index b2512279..50260093 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -28,6 +28,7 @@ import atr.hashes as hashes import atr.log as log import atr.models.attestable as models import atr.util as util +from atr.models.attestable import AttestableChecksV1 if TYPE_CHECKING: import pathlib @@ -45,6 +46,10 @@ def github_tp_payload_path(project_name: str, version_name: str, revision_number return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json" +def attestable_checks_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: + return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.checks.json" + + async def github_tp_payload_write( project_name: str, version_name: str, revision_number: str, github_payload: dict[str, Any] ) -> None: @@ -89,6 +94,22 @@ async def load_paths( return None +async def load_checks( + project_name: str, + version_name: str, + revision_number: str, +) -> list[int] | None: + file_path = attestable_checks_path(project_name, version_name, revision_number) + if await aiofiles.os.path.isfile(file_path): + try: + async with aiofiles.open(file_path, encoding="utf-8") as f: + data = json.loads(await f.read()) + return models.AttestableChecksV1.model_validate(data).checks + except (json.JSONDecodeError, pydantic.ValidationError) as e: + log.warning(f"Could not parse {file_path}: {e}") + return [] + + def migrate_to_paths_files() -> int: attestable_dir = util.get_attestable_dir() if not attestable_dir.is_dir(): @@ -135,7 +156,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, return path_to_hash, path_to_size -async def write( +async def write_files_data( project_name: str, version_name: str, revision_number: str, @@ -145,12 +166,37 @@ async def write( path_to_hash: dict[str, str], path_to_size: dict[str, int], ) -> None: - result = _generate(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous) + result = _generate_files_data(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous) file_path = attestable_path(project_name, version_name, revision_number) await util.atomic_write_file(file_path, result.model_dump_json(indent=2)) paths_result = models.AttestablePathsV1(paths=result.paths) paths_file_path = attestable_paths_path(project_name, version_name, revision_number) await util.atomic_write_file(paths_file_path, paths_result.model_dump_json(indent=2)) + checks_file_path = attestable_checks_path(project_name, version_name, revision_number) + if not checks_file_path.exists(): + async with aiofiles.open(checks_file_path, "w", encoding="utf-8") as f: + await f.write(models.AttestableChecksV1().model_dump_json(indent=2)) + + +async def write_checks_data( + project_name: str, + version_name: str, + revision_number: str, + checks: list[int], +) -> None: + log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}") + + def modify(content: str) -> str: + try: + current = AttestableChecksV1.model_validate_json(content).checks + except pydantic.ValidationError: + current = [] + new_checks = set(current or []) + new_checks.update(checks) + result = models.AttestableChecksV1(checks=sorted(new_checks)) + return result.model_dump_json(indent=2) + + await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify) def _compute_hashes_with_attribution( @@ -188,7 +234,7 @@ def _compute_hashes_with_attribution( return new_hashes -def _generate( +def _generate_files_data( path_to_hash: dict[str, str], path_to_size: dict[str, int], revision_number: str, diff --git a/atr/db/__init__.py b/atr/db/__init__.py index ed4a76a9..2c6d579e 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -155,6 +155,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): def check_result( self, id: Opt[int] = NOT_SET, + id_in: Opt[list[int]] = NOT_SET, release_name: Opt[str] = NOT_SET, revision_number: Opt[str] = NOT_SET, checker: Opt[str] = NOT_SET, @@ -169,8 +170,12 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): ) -> Query[sql.CheckResult]: query = sqlmodel.select(sql.CheckResult) + via = sql.validate_instrumented_attribute + if is_defined(id): query = query.where(sql.CheckResult.id == id) + if is_defined(id_in): + query = query.where(via(sql.CheckResult.id).in_(id_in)) if is_defined(release_name): query = query.where(sql.CheckResult.release_name == release_name) if is_defined(revision_number): diff --git a/atr/models/attestable.py b/atr/models/attestable.py index 45e3ac3d..2af1cc2f 100644 --- a/atr/models/attestable.py +++ b/atr/models/attestable.py @@ -27,6 +27,11 @@ class HashEntry(schema.Strict): uploaders: list[Annotated[tuple[str, str], pydantic.BeforeValidator(tuple)]] +class AttestableChecksV1(schema.Strict): + version: Literal[1] = 1 + checks: list[int] = schema.factory(list) + + class AttestablePathsV1(schema.Strict): version: Literal[1] = 1 paths: dict[str, str] = schema.factory(dict) diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index f9b74301..1a95df2c 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -21,6 +21,7 @@ from __future__ import annotations import importlib from typing import TYPE_CHECKING +import atr.attestable as attestable import atr.db as db import atr.hashes as hashes import atr.models.sql as sql @@ -44,8 +45,8 @@ async def _filter_check_results_by_hash( if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") for cr in all_check_results: - module_path = cr.checker.rsplit(".", 1)[0] - if module_path not in input_hash_by_module: + if cr.checker not in input_hash_by_module: + module_path = cr.checker.rsplit(".", 1)[0] try: module = importlib.import_module(module_path) policy_keys: list[str] = module.INPUT_POLICY_KEYS @@ -53,13 +54,13 @@ async def _filter_check_results_by_hash( except (ImportError, AttributeError): policy_keys = [] extra_arg_names = [] - extra_args = checks.resolve_extra_args(extra_arg_names, release) + extra_args = await checks.resolve_extra_args(extra_arg_names, release) cache_key = await checks.resolve_cache_key( cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name ) - input_hash_by_module[module_path] = hashes.compute_dict_hash(cache_key) if cache_key else None + input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None - if cr.inputs_hash == input_hash_by_module[module_path]: + if cr.inputs_hash == input_hash_by_module[cr.checker]: filtered_check_results.append(cr) return filtered_check_results @@ -81,31 +82,26 @@ class GeneralPublic: if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") - # TODO: Is this potentially too much data? Within a revision I hope it's not too bad? - - query = self.__data.check_result( - release_name=release.name, - primary_rel_path=str(rel_path), - ).order_by( - sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), - sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), - ) - all_check_results = await query.all() - - # Filter to checks for the current file version / policy - # Cache the computed input hash per checker module, since all results here share the same file and release - input_hash_by_module: dict[str, str | None] = {} - # TODO: This has a bug - create an archive, it'll scan with a hash and show missing checksum. - # Then generate a checksum. It'll re-scan the file with the same hash, but now has one. Two checks shown. - filtered_check_results = await _filter_check_results_by_hash( - all_check_results, rel_path, input_hash_by_module, release + check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number) + all_check_results = ( + [ + a + for a in await self.__data.check_result(id_in=check_ids) + .order_by( + sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), + sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), + ) + .all() + ] + if check_ids + else [] ) # Filter out any results that are ignored unignored_checks = [] ignored_checks = [] match_ignore = await self.ignores_matcher(release.project_name) - for cr in filtered_check_results: + for cr in all_check_results: if not match_ignore(cr): unignored_checks.append(cr) else: diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py index c4969aca..90445082 100644 --- a/atr/storage/readers/releases.py +++ b/atr/storage/readers/releases.py @@ -21,6 +21,7 @@ from __future__ import annotations import dataclasses import pathlib +import atr.attestable as attestable import atr.classify as classify import atr.db as db import atr.models.sql as sql @@ -122,10 +123,11 @@ class GeneralPublic: self, release: sql.Release, latest_revision_number: str, info: types.PathInfo ) -> None: match_ignore = await self.__read_as.checks.ignores_matcher(release.project_name) + check_ids = await attestable.load_checks(release.project_name, release.version, latest_revision_number) + attestable_checks = [a for a in await self.__data.check_result(id_in=check_ids).all()] if check_ids else [] cs = types.ChecksSubset( - release=release, - latest_revision_number=latest_revision_number, + checks=attestable_checks, info=info, match_ignore=match_ignore, ) @@ -137,23 +139,13 @@ class GeneralPublic: await self.__blocker(cs) async def __blocker(self, cs: types.ChecksSubset) -> None: - blocker = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.BLOCKER, - ).all() + blocker = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.BLOCKER] for result in blocker: if primary_rel_path := result.primary_rel_path: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(result) async def __errors(self, cs: types.ChecksSubset) -> None: - errors = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.FAILURE, - ).all() + errors = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.FAILURE] for error in errors: if cs.match_ignore(error): cs.info.ignored_errors.append(error) @@ -162,24 +154,14 @@ class GeneralPublic: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(error) async def __successes(self, cs: types.ChecksSubset) -> None: - successes = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.SUCCESS, - ).all() + successes = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.SUCCESS] for success in successes: # Successes cannot be ignored if primary_rel_path := success.primary_rel_path: cs.info.successes.setdefault(pathlib.Path(primary_rel_path), []).append(success) async def __warnings(self, cs: types.ChecksSubset) -> None: - warnings = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.WARNING, - ).all() + warnings = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.WARNING] for warning in warnings: if cs.match_ignore(warning): cs.info.ignored_warnings.append(warning) diff --git a/atr/storage/types.py b/atr/storage/types.py index 3cd74f6b..3ed5d423 100644 --- a/atr/storage/types.py +++ b/atr/storage/types.py @@ -75,8 +75,7 @@ class PathInfo(schema.Strict): @dataclasses.dataclass class ChecksSubset: - release: sql.Release - latest_revision_number: str + checks: list[sql.CheckResult] info: PathInfo match_ignore: Callable[[sql.CheckResult], bool] diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py index cb273606..b2cd2a38 100644 --- a/atr/storage/writers/revision.py +++ b/atr/storage/writers/revision.py @@ -245,7 +245,7 @@ class CommitteeParticipant(FoundationCommitter): policy = release.release_policy or release.project.release_policy - await attestable.write( + await attestable.write_files_data( project_name, version_name, new_revision.number, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 18780bd3..00beb6d7 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -24,6 +24,7 @@ from typing import Any, Final import sqlmodel +import atr.attestable as attestable import atr.db as db import atr.hashes as hashes import atr.models.results as results @@ -68,7 +69,7 @@ async def asc_checks( signature.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release), ), extra_args={"committee_name": release.committee.name}, ) @@ -176,11 +177,20 @@ async def draft_checks( release, revision_number, caller_data, + check_cache_key=await checks.resolve_cache_key( + resolve(sql.TaskType.PATHS_CHECK), + paths.INPUT_POLICY_KEYS, + release, + revision_number, + await checks.resolve_extra_args(paths.INPUT_EXTRA_ARGS, release), + ignore_path=True, + ), extra_args={"is_podling": is_podling}, ) - data.add(path_check_task) - if caller_data is None: - await data.commit() + if path_check_task: + data.add(path_check_task) + if caller_data is None: + await data.commit() return len(relative_paths) @@ -294,8 +304,11 @@ async def queued( hash_val = hashes.compute_dict_hash(check_cache_key) if not data: raise RuntimeError("DB Session is required for check_cache_key") - existing = await data.check_result(inputs_hash=hash_val).all() + existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all() if existing: + await attestable.write_checks_data( + release.project.name, release.version, revision_number, [c.id for c in existing] + ) return None return sql.Task( status=sql.TaskStatus.QUEUED, @@ -382,7 +395,7 @@ async def sha_checks( hashing.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release), file=hash_file, ), ) @@ -402,7 +415,7 @@ async def tar_gz_checks( compare.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), file=path, ) license_h_ck = await checks.resolve_cache_key( @@ -410,7 +423,7 @@ async def tar_gz_checks( license.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), file=path, ) license_f_ck = await checks.resolve_cache_key( @@ -418,7 +431,7 @@ async def tar_gz_checks( license.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), file=path, ) rat_ck = await checks.resolve_cache_key( @@ -426,7 +439,7 @@ async def tar_gz_checks( rat.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), file=path, ) targz_i_ck = await checks.resolve_cache_key( @@ -434,7 +447,7 @@ async def tar_gz_checks( targz.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), file=path, ) targz_s_ck = await checks.resolve_cache_key( @@ -442,7 +455,7 @@ async def tar_gz_checks( targz.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), file=path, ) tasks = [ @@ -504,7 +517,7 @@ async def zip_checks( compare.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), file=path, ) license_h_ck = await checks.resolve_cache_key( @@ -512,7 +525,7 @@ async def zip_checks( license.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), file=path, ) license_f_ck = await checks.resolve_cache_key( @@ -520,7 +533,7 @@ async def zip_checks( license.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), file=path, ) rat_ck = await checks.resolve_cache_key( @@ -528,7 +541,7 @@ async def zip_checks( rat.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), file=path, ) zip_i_ck = await checks.resolve_cache_key( @@ -536,7 +549,7 @@ async def zip_checks( zipformat.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), file=path, ) zip_s_ck = await checks.resolve_cache_key( @@ -544,7 +557,7 @@ async def zip_checks( zipformat.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + await checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), file=path, ) diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index add95588..8b0d2830 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -118,7 +118,6 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: if self.constructed is False: raise RuntimeError("Cannot add check result to a recorder that has not been constructed") @@ -144,7 +143,7 @@ class Recorder: message=message, data=data, cached=False, - inputs_hash=inputs_hash or self.__input_hash, + inputs_hash=self.input_hash, ) # It would be more efficient to keep a session open @@ -198,11 +197,10 @@ class Recorder: abs_path = await self.abs_path() return matches(str(abs_path)) - async def cache_key_set(self, policy_keys: list[str], input_args: dict[str, Any] | None = None) -> bool: + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None + ) -> bool: # TODO: Should this just be in the constructor? - path = await self.abs_path() - if (not path) or (not await aiofiles.os.path.isfile(path)): - return False if config.get().DISABLE_CHECK_CACHE: return False @@ -216,10 +214,11 @@ class Recorder: async with db.session() as data: release = await data.release( - name=self.release_name, _release_policy=True, _project_release_policy=True + name=self.release_name, _release_policy=True, _project_release_policy=True, _project=True ).demand(RuntimeError(f"Release {self.release_name} not found")) + args = await resolve_extra_args(input_args or [], release) cache_key = await resolve_cache_key( - self.checker, policy_keys, release, self.revision_number, input_args, file=self.primary_rel_path + checker or self.checker, policy_keys, release, self.revision_number, args, file=self.primary_rel_path ) self.__input_hash = hashes.compute_dict_hash(cache_key) if cache_key else None return True @@ -250,16 +249,16 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.BLOCKER, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, - inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def exception( self, @@ -267,16 +266,16 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.EXCEPTION, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, - inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def failure( self, @@ -284,16 +283,16 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.FAILURE, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, - inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def success( self, @@ -301,16 +300,16 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.SUCCESS, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, - inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def use_check_cache(self) -> bool: if self.__use_check_cache is not None: @@ -330,16 +329,16 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, - inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.WARNING, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, - inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result def function_key(func: Callable[..., Any] | str) -> str: @@ -354,23 +353,26 @@ async def resolve_cache_key( args: dict[str, Any] | None = None, file: str | None = None, path: pathlib.Path | None = None, + ignore_path: bool = False, ) -> dict[str, Any] | None: - if file is None and path is None: - raise ValueError("Must specify either file or path") if not args: args = {} + cache_key = {"checker": function_key(checker)} + file_hash = None attestable_data = await attestable.load(release.project_name, release.version, revision) if attestable_data: policy = sql.ReleasePolicy.model_validate(attestable_data.policy) - file_hash = attestable_data.paths[file or ""] + if not ignore_path: + file_hash = attestable_data.paths.get(file) if file else None else: # TODO: Is this fallback valid / necessary? Or should we bail out if there's no attestable data? policy = release.release_policy or release.project.release_policy - if path is None: - # We know file isn't None here but type checker doesn't - path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file or "") - file_hash = await hashes.compute_file_hash(path) - cache_key = {"file_hash": file_hash, "checker": function_key(checker)} + if not ignore_path: + if path is None: + path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file or "") + file_hash = await hashes.compute_file_hash(path) + if file_hash: + cache_key["file_hash"] = file_hash if len(policy_keys) > 0 and policy is not None: policy_dict = policy.model_dump(exclude_none=True) @@ -379,7 +381,7 @@ async def resolve_cache_key( return {**cache_key, **args} -def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, Any]: +async def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, Any]: result: dict[str, Any] = {} for name in arg_names: resolver = _EXTRA_ARG_RESOLVERS.get(name, None) @@ -387,7 +389,7 @@ def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, if resolver is None: log.warning(f"Unknown extra arg resolver: {name}") return {} - result[name] = resolver(release) + result[name] = await resolver(release) return result @@ -405,17 +407,36 @@ def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Calla return decorator -def _resolve_is_podling(release: sql.Release) -> bool: +async def _resolve_all_files(release: sql.Release) -> list[str]: + if not release.latest_revision_number: + return [] + if not ( + base_path := file_paths.base_path_for_revision( + release.project_name, release.version, release.latest_revision_number + ) + ): + return [] + + if not await aiofiles.os.path.isdir(base_path): + log.error(f"Base release directory does not exist or is not a directory: {base_path}") + return [] + relative_paths = [p async for p in util.paths_recursive(base_path)] + relative_paths_set = set(str(p) for p in relative_paths) + return list(relative_paths_set) + + +async def _resolve_is_podling(release: sql.Release) -> bool: return (release.committee is not None) and release.committee.is_podling -def _resolve_committee_name(release: sql.Release) -> str: +async def _resolve_committee_name(release: sql.Release) -> str: if release.committee is None: raise ValueError("Release has no committee") return release.committee.name _EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release], Any]]] = { + "all_files": _resolve_all_files, "is_podling": _resolve_is_podling, "committee_name": _resolve_committee_name, } diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index 54c98ecf..162c9673 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -93,7 +93,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None ) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number) checkout_dir: str | None = None diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py index 19d4fcbf..cb3c8f77 100644 --- a/atr/tasks/checks/hashing.py +++ b/atr/tasks/checks/hashing.py @@ -41,7 +41,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm}) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) # Remove the hash file suffix to get the artifact path # This replaces the last suffix, which is what we want diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py index 0e09cd23..49b20a7a 100644 --- a/atr/tasks/checks/license.py +++ b/atr/tasks/checks/license.py @@ -139,7 +139,7 @@ async def files(args: checks.FunctionArguments) -> results.Results | None: return None is_podling = args.extra_args.get("is_podling", False) - await recorder.cache_key_set(INPUT_POLICY_KEYS, {"is_podling": is_podling}) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -172,7 +172,7 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) # if await recorder.check_cache(artifact_abs_path): # log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py index 39d0cfe3..d37fc66b 100644 --- a/atr/tasks/checks/paths.py +++ b/atr/tasks/checks/paths.py @@ -23,7 +23,6 @@ from typing import Final import aiofiles.os import atr.analysis as analysis -import atr.hashes as hashes import atr.log as log import atr.models.results as results import atr.tasks.checks as checks @@ -40,7 +39,7 @@ _ALLOWED_TOP_LEVEL: Final = frozenset( ) # Release policy fields which this check relies on - used for result caching INPUT_POLICY_KEYS: Final[list[str]] = [] -INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"] +INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling", "all_files"] async def check(args: checks.FunctionArguments) -> results.Results | None: @@ -89,6 +88,11 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: is_podling = args.extra_args.get("is_podling", False) relative_paths = [p async for p in util.paths_recursive(base_path)] relative_paths_set = set(str(p) for p in relative_paths) + + await recorder_errors.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + await recorder_warnings.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + await recorder_success.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + for relative_path in relative_paths: # Delegate processing of each path to the helper function await _check_path_process_single( @@ -196,9 +200,6 @@ async def _check_path_process_single( # noqa: C901 full_path = base_path / relative_path relative_path_str = str(relative_path) - file_hash = await hashes.compute_file_hash(full_path) - inputs_hash = hashes.compute_dict_hash({"file_hash": file_hash, "is_podling": is_podling}) - # For debugging and testing if (await user.is_admin_async(asf_uid)) and (full_path.name == "deliberately_slow_ATR_task_filename.txt"): await asyncio.sleep(20) @@ -283,7 +284,6 @@ async def _check_path_process_single( # noqa: C901 errors, blockers, warnings, - inputs_hash, ) @@ -295,18 +295,16 @@ async def _record( errors: list[str], blockers: list[str], warnings: list[str], - inputs_hash: str, ) -> None: for error in errors: - await recorder_errors.failure(error, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) + await recorder_errors.failure(f"{relative_path_str}: {error}", {}, primary_rel_path=relative_path_str) for item in blockers: - await recorder_errors.blocker(item, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) + await recorder_errors.blocker(f"{relative_path_str}: {item}", {}, primary_rel_path=relative_path_str) for warning in warnings: - await recorder_warnings.warning(warning, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) + await recorder_warnings.warning(f"{relative_path_str}: {warning}", {}, primary_rel_path=relative_path_str) if not (errors or blockers or warnings): await recorder_success.success( - "Path structure and naming conventions conform to policy", + f"{relative_path_str}: Path structure and naming conventions conform to policy", {}, primary_rel_path=relative_path_str, - inputs_hash=inputs_hash, ) diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py index 0b374a71..9c065ec7 100644 --- a/atr/tasks/checks/rat.py +++ b/atr/tasks/checks/rat.py @@ -88,7 +88,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)") return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py index 6405d137..830d9cfe 100644 --- a/atr/tasks/checks/signature.py +++ b/atr/tasks/checks/signature.py @@ -54,7 +54,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.exception("Committee name is required", {"committee_name": committee_name}) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS, {"committee_name": committee_name}) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info( f"Checking signature {primary_abs_path} for {artifact_abs_path}" diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py index e4c83fb4..7b369e5e 100644 --- a/atr/tasks/checks/targz.py +++ b/atr/tasks/checks/targz.py @@ -42,7 +42,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -101,7 +101,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) filename = artifact_abs_path.name basename_from_filename: Final[str] = ( diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py index 428a01a2..ed85f7b2 100644 --- a/atr/tasks/checks/zipformat.py +++ b/atr/tasks/checks/zipformat.py @@ -37,7 +37,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -63,7 +63,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/util.py b/atr/util.py index 7ddb9788..9cb3c820 100644 --- a/atr/util.py +++ b/atr/util.py @@ -21,6 +21,7 @@ import binascii import contextlib import dataclasses import datetime +import fcntl import hashlib import json import os @@ -207,6 +208,27 @@ async def atomic_write_file(file_path: pathlib.Path, content: str, encoding: str raise +async def atomic_modify_file( + file_path: pathlib.Path, + modify: Callable[[str], str], +) -> None: + # This function assumes that file_path already exists and its a regular file + lock_path = file_path.with_suffix(file_path.suffix + ".lock") + lock_fd = await asyncio.to_thread(os.open, str(lock_path), os.O_CREAT | os.O_RDWR) + try: + await asyncio.to_thread(fcntl.flock, lock_fd, fcntl.LOCK_EX) + try: + async with aiofiles.open(file_path, encoding="utf-8") as rf: + old_value = await rf.read() + new_value = modify(old_value) + if new_value != old_value: + await atomic_write_file(file_path, new_value) + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + finally: + await asyncio.to_thread(os.close, lock_fd) + + def chmod_directories(path: pathlib.Path, permissions: int = DIRECTORY_PERMISSIONS) -> None: # codeql[py/overly-permissive-file] os.chmod(path, permissions) diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py index 47c772eb..bdb20dcf 100644 --- a/tests/unit/recorders.py +++ b/tests/unit/recorders.py @@ -18,6 +18,7 @@ import datetime import pathlib from collections.abc import Awaitable, Callable +from typing import Any import atr.models.sql as sql import atr.tasks.checks as checks @@ -40,6 +41,11 @@ class RecorderStub(checks.Recorder): async def abs_path(self, rel_path: str | None = None) -> pathlib.Path | None: return self._path if (rel_path is None) else self._path / rel_path + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None + ) -> bool: + return False + async def primary_path_is_binary(self) -> bool: return False @@ -66,6 +72,26 @@ class RecorderStub(checks.Recorder): inputs_hash=None, ) + async def exception( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.EXCEPTION, message, data, primary_rel_path, member_rel_path) + + async def failure( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.FAILURE, message, data, primary_rel_path, member_rel_path) + + async def success( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.SUCCESS, message, data, primary_rel_path, member_rel_path) + + async def warning( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.WARNING, message, data, primary_rel_path, member_rel_path) + def get_recorder(recorder: checks.Recorder) -> Callable[[], Awaitable[checks.Recorder]]: async def _recorder() -> checks.Recorder: diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index a1657783..12bf275b 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -115,17 +115,17 @@ async def test_admins_get_async_uses_extensions_when_available(mock_app: MockApp assert result == frozenset({"async_alice"}) -def test_admins_get_returns_empty_frozenset_when_not_set(mock_app: MockApp): - result = cache.admins_get() - assert result == frozenset() - - def test_admins_get_returns_frozenset_from_extensions(mock_app: MockApp): mock_app.extensions["admins"] = frozenset({"alice", "bob"}) result = cache.admins_get() assert result == frozenset({"alice", "bob"}) +def test_admins_get_returns_empty_frozenset_when_not_set(mock_app: MockApp): + result = cache.admins_get() + assert result == frozenset() + + @pytest.mark.asyncio async def test_admins_read_from_file_returns_none_for_invalid_json(state_dir: pathlib.Path): cache_path = state_dir / "cache" / "admins.json" diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py index 5a90a029..e90e3314 100644 --- a/tests/unit/test_checks_compare.py +++ b/tests/unit/test_checks_compare.py @@ -242,6 +242,11 @@ class RecorderStub(atr.tasks.checks.Recorder): self.success_calls: list[tuple[str, object]] = [] self._is_source = is_source + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None + ) -> bool: + return False + async def primary_path_is_source(self) -> bool: return self._is_source --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
