This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch check_caching in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit a449e4cbfe3b81c41f681c2e2cdc9650c9297437 Author: Alastair McFarlane <[email protected]> AuthorDate: Fri Feb 13 14:53:31 2026 +0000 Read and write checks to/from attestable data --- atr/attestable.py | 52 ++++++++++++++++++++++++++++++++++++++--- atr/db/__init__.py | 5 ++++ atr/models/attestable.py | 5 ++++ atr/storage/readers/checks.py | 31 +++++++++++++----------- atr/storage/readers/releases.py | 35 +++++++-------------------- atr/storage/types.py | 3 +-- atr/storage/writers/revision.py | 2 +- atr/tasks/__init__.py | 6 ++++- atr/tasks/checks/__init__.py | 27 ++++++++++++++------- atr/tasks/checks/compare.py | 2 +- atr/tasks/checks/hashing.py | 2 +- atr/tasks/checks/license.py | 4 ++-- atr/tasks/checks/paths.py | 20 ++++++++++++---- atr/tasks/checks/rat.py | 2 +- atr/tasks/checks/signature.py | 2 +- atr/tasks/checks/targz.py | 4 ++-- atr/tasks/checks/zipformat.py | 4 ++-- atr/util.py | 22 +++++++++++++++++ 18 files changed, 160 insertions(+), 68 deletions(-) diff --git a/atr/attestable.py b/atr/attestable.py index b2512279..50260093 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -28,6 +28,7 @@ import atr.hashes as hashes import atr.log as log import atr.models.attestable as models import atr.util as util +from atr.models.attestable import AttestableChecksV1 if TYPE_CHECKING: import pathlib @@ -45,6 +46,10 @@ def github_tp_payload_path(project_name: str, version_name: str, revision_number return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json" +def attestable_checks_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: + return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.checks.json" + + async def github_tp_payload_write( project_name: str, version_name: str, revision_number: str, github_payload: dict[str, Any] ) -> None: @@ -89,6 +94,22 @@ async def load_paths( return None +async def load_checks( + project_name: str, + version_name: str, + revision_number: str, +) -> list[int] | None: + file_path = attestable_checks_path(project_name, version_name, revision_number) + if await aiofiles.os.path.isfile(file_path): + try: + async with aiofiles.open(file_path, encoding="utf-8") as f: + data = json.loads(await f.read()) + return models.AttestableChecksV1.model_validate(data).checks + except (json.JSONDecodeError, pydantic.ValidationError) as e: + log.warning(f"Could not parse {file_path}: {e}") + return [] + + def migrate_to_paths_files() -> int: attestable_dir = util.get_attestable_dir() if not attestable_dir.is_dir(): @@ -135,7 +156,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, return path_to_hash, path_to_size -async def write( +async def write_files_data( project_name: str, version_name: str, revision_number: str, @@ -145,12 +166,37 @@ async def write( path_to_hash: dict[str, str], path_to_size: dict[str, int], ) -> None: - result = _generate(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous) + result = _generate_files_data(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous) file_path = attestable_path(project_name, version_name, revision_number) await util.atomic_write_file(file_path, result.model_dump_json(indent=2)) paths_result = models.AttestablePathsV1(paths=result.paths) paths_file_path = attestable_paths_path(project_name, version_name, revision_number) await util.atomic_write_file(paths_file_path, paths_result.model_dump_json(indent=2)) + checks_file_path = attestable_checks_path(project_name, version_name, revision_number) + if not checks_file_path.exists(): + async with aiofiles.open(checks_file_path, "w", encoding="utf-8") as f: + await f.write(models.AttestableChecksV1().model_dump_json(indent=2)) + + +async def write_checks_data( + project_name: str, + version_name: str, + revision_number: str, + checks: list[int], +) -> None: + log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}") + + def modify(content: str) -> str: + try: + current = AttestableChecksV1.model_validate_json(content).checks + except pydantic.ValidationError: + current = [] + new_checks = set(current or []) + new_checks.update(checks) + result = models.AttestableChecksV1(checks=sorted(new_checks)) + return result.model_dump_json(indent=2) + + await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify) def _compute_hashes_with_attribution( @@ -188,7 +234,7 @@ def _compute_hashes_with_attribution( return new_hashes -def _generate( +def _generate_files_data( path_to_hash: dict[str, str], path_to_size: dict[str, int], revision_number: str, diff --git a/atr/db/__init__.py b/atr/db/__init__.py index ed4a76a9..2c6d579e 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -155,6 +155,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): def check_result( self, id: Opt[int] = NOT_SET, + id_in: Opt[list[int]] = NOT_SET, release_name: Opt[str] = NOT_SET, revision_number: Opt[str] = NOT_SET, checker: Opt[str] = NOT_SET, @@ -169,8 +170,12 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): ) -> Query[sql.CheckResult]: query = sqlmodel.select(sql.CheckResult) + via = sql.validate_instrumented_attribute + if is_defined(id): query = query.where(sql.CheckResult.id == id) + if is_defined(id_in): + query = query.where(via(sql.CheckResult.id).in_(id_in)) if is_defined(release_name): query = query.where(sql.CheckResult.release_name == release_name) if is_defined(revision_number): diff --git a/atr/models/attestable.py b/atr/models/attestable.py index 45e3ac3d..2af1cc2f 100644 --- a/atr/models/attestable.py +++ b/atr/models/attestable.py @@ -27,6 +27,11 @@ class HashEntry(schema.Strict): uploaders: list[Annotated[tuple[str, str], pydantic.BeforeValidator(tuple)]] +class AttestableChecksV1(schema.Strict): + version: Literal[1] = 1 + checks: list[int] = schema.factory(list) + + class AttestablePathsV1(schema.Strict): version: Literal[1] = 1 paths: dict[str, str] = schema.factory(dict) diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index f9b74301..4edbb7ce 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -21,6 +21,7 @@ from __future__ import annotations import importlib from typing import TYPE_CHECKING +import atr.attestable as attestable import atr.db as db import atr.hashes as hashes import atr.models.sql as sql @@ -44,8 +45,8 @@ async def _filter_check_results_by_hash( if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") for cr in all_check_results: - module_path = cr.checker.rsplit(".", 1)[0] - if module_path not in input_hash_by_module: + if cr.checker not in input_hash_by_module: + module_path = cr.checker.rsplit(".", 1)[0] try: module = importlib.import_module(module_path) policy_keys: list[str] = module.INPUT_POLICY_KEYS @@ -57,9 +58,9 @@ async def _filter_check_results_by_hash( cache_key = await checks.resolve_cache_key( cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name ) - input_hash_by_module[module_path] = hashes.compute_dict_hash(cache_key) if cache_key else None + input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None - if cr.inputs_hash == input_hash_by_module[module_path]: + if cr.inputs_hash == input_hash_by_module[cr.checker]: filtered_check_results.append(cr) return filtered_check_results @@ -81,16 +82,20 @@ class GeneralPublic: if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") - # TODO: Is this potentially too much data? Within a revision I hope it's not too bad? - - query = self.__data.check_result( - release_name=release.name, - primary_rel_path=str(rel_path), - ).order_by( - sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), - sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), + check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number) + all_check_results = ( + [ + a + for a in await self.__data.check_result(id_in=check_ids) + .order_by( + sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), + sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), + ) + .all() + ] + if check_ids + else [] ) - all_check_results = await query.all() # Filter to checks for the current file version / policy # Cache the computed input hash per checker module, since all results here share the same file and release diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py index c4969aca..1a6f090a 100644 --- a/atr/storage/readers/releases.py +++ b/atr/storage/readers/releases.py @@ -21,6 +21,8 @@ from __future__ import annotations import dataclasses import pathlib +import atr.analysis as analysis +import atr.attestable as attestable import atr.classify as classify import atr.db as db import atr.models.sql as sql @@ -122,10 +124,11 @@ class GeneralPublic: self, release: sql.Release, latest_revision_number: str, info: types.PathInfo ) -> None: match_ignore = await self.__read_as.checks.ignores_matcher(release.project_name) + check_ids = await attestable.load_checks(release.project_name, release.version, latest_revision_number) + attestable_checks = [a for a in await self.__data.check_result(id_in=check_ids).all()] if check_ids else [] cs = types.ChecksSubset( - release=release, - latest_revision_number=latest_revision_number, + checks=attestable_checks, info=info, match_ignore=match_ignore, ) @@ -137,23 +140,13 @@ class GeneralPublic: await self.__blocker(cs) async def __blocker(self, cs: types.ChecksSubset) -> None: - blocker = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.BLOCKER, - ).all() + blocker = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.BLOCKER] for result in blocker: if primary_rel_path := result.primary_rel_path: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(result) async def __errors(self, cs: types.ChecksSubset) -> None: - errors = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.FAILURE, - ).all() + errors = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.FAILURE] for error in errors: if cs.match_ignore(error): cs.info.ignored_errors.append(error) @@ -162,24 +155,14 @@ class GeneralPublic: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(error) async def __successes(self, cs: types.ChecksSubset) -> None: - successes = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.SUCCESS, - ).all() + successes = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.SUCCESS] for success in successes: # Successes cannot be ignored if primary_rel_path := success.primary_rel_path: cs.info.successes.setdefault(pathlib.Path(primary_rel_path), []).append(success) async def __warnings(self, cs: types.ChecksSubset) -> None: - warnings = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.WARNING, - ).all() + warnings = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.WARNING] for warning in warnings: if cs.match_ignore(warning): cs.info.ignored_warnings.append(warning) diff --git a/atr/storage/types.py b/atr/storage/types.py index 3cd74f6b..3ed5d423 100644 --- a/atr/storage/types.py +++ b/atr/storage/types.py @@ -75,8 +75,7 @@ class PathInfo(schema.Strict): @dataclasses.dataclass class ChecksSubset: - release: sql.Release - latest_revision_number: str + checks: list[sql.CheckResult] info: PathInfo match_ignore: Callable[[sql.CheckResult], bool] diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py index cb273606..b2cd2a38 100644 --- a/atr/storage/writers/revision.py +++ b/atr/storage/writers/revision.py @@ -245,7 +245,7 @@ class CommitteeParticipant(FoundationCommitter): policy = release.release_policy or release.project.release_policy - await attestable.write( + await attestable.write_files_data( project_name, version_name, new_revision.number, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 18780bd3..cb234a31 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -24,6 +24,7 @@ from typing import Any, Final import sqlmodel +import atr.attestable as attestable import atr.db as db import atr.hashes as hashes import atr.models.results as results @@ -294,8 +295,11 @@ async def queued( hash_val = hashes.compute_dict_hash(check_cache_key) if not data: raise RuntimeError("DB Session is required for check_cache_key") - existing = await data.check_result(inputs_hash=hash_val).all() + existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all() if existing: + await attestable.write_checks_data( + release.project.name, release.version, revision_number, [c.id for c in existing] + ) return None return sql.Task( status=sql.TaskStatus.QUEUED, diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index add95588..c5b795ee 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -198,7 +198,7 @@ class Recorder: abs_path = await self.abs_path() return matches(str(abs_path)) - async def cache_key_set(self, policy_keys: list[str], input_args: dict[str, Any] | None = None) -> bool: + async def cache_key_set(self, policy_keys: list[str], input_args: list[str] | None = None) -> bool: # TODO: Should this just be in the constructor? path = await self.abs_path() if (not path) or (not await aiofiles.os.path.isfile(path)): @@ -216,10 +216,11 @@ class Recorder: async with db.session() as data: release = await data.release( - name=self.release_name, _release_policy=True, _project_release_policy=True + name=self.release_name, _release_policy=True, _project_release_policy=True, _project=True ).demand(RuntimeError(f"Release {self.release_name} not found")) + args = resolve_extra_args(input_args or [], release) cache_key = await resolve_cache_key( - self.checker, policy_keys, release, self.revision_number, input_args, file=self.primary_rel_path + self.checker, policy_keys, release, self.revision_number, args, file=self.primary_rel_path ) self.__input_hash = hashes.compute_dict_hash(cache_key) if cache_key else None return True @@ -252,7 +253,7 @@ class Recorder: member_rel_path: str | None = None, inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.BLOCKER, message, data, @@ -260,6 +261,8 @@ class Recorder: member_rel_path=member_rel_path, inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def exception( self, @@ -269,7 +272,7 @@ class Recorder: member_rel_path: str | None = None, inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.EXCEPTION, message, data, @@ -277,6 +280,8 @@ class Recorder: member_rel_path=member_rel_path, inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def failure( self, @@ -286,7 +291,7 @@ class Recorder: member_rel_path: str | None = None, inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.FAILURE, message, data, @@ -294,6 +299,8 @@ class Recorder: member_rel_path=member_rel_path, inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def success( self, @@ -303,7 +310,7 @@ class Recorder: member_rel_path: str | None = None, inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.SUCCESS, message, data, @@ -311,6 +318,8 @@ class Recorder: member_rel_path=member_rel_path, inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def use_check_cache(self) -> bool: if self.__use_check_cache is not None: @@ -332,7 +341,7 @@ class Recorder: member_rel_path: str | None = None, inputs_hash: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.WARNING, message, data, @@ -340,6 +349,8 @@ class Recorder: member_rel_path=member_rel_path, inputs_hash=inputs_hash, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result def function_key(func: Callable[..., Any] | str) -> str: diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index 54c98ecf..162c9673 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -93,7 +93,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None ) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number) checkout_dir: str | None = None diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py index 19d4fcbf..cb3c8f77 100644 --- a/atr/tasks/checks/hashing.py +++ b/atr/tasks/checks/hashing.py @@ -41,7 +41,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm}) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) # Remove the hash file suffix to get the artifact path # This replaces the last suffix, which is what we want diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py index 0e09cd23..49b20a7a 100644 --- a/atr/tasks/checks/license.py +++ b/atr/tasks/checks/license.py @@ -139,7 +139,7 @@ async def files(args: checks.FunctionArguments) -> results.Results | None: return None is_podling = args.extra_args.get("is_podling", False) - await recorder.cache_key_set(INPUT_POLICY_KEYS, {"is_podling": is_podling}) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -172,7 +172,7 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) # if await recorder.check_cache(artifact_abs_path): # log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py index 39d0cfe3..16bd10bc 100644 --- a/atr/tasks/checks/paths.py +++ b/atr/tasks/checks/paths.py @@ -18,7 +18,7 @@ import asyncio import pathlib import re -from typing import Final +from typing import Any, Final import aiofiles.os @@ -197,7 +197,7 @@ async def _check_path_process_single( # noqa: C901 relative_path_str = str(relative_path) file_hash = await hashes.compute_file_hash(full_path) - inputs_hash = hashes.compute_dict_hash({"file_hash": file_hash, "is_podling": is_podling}) + inputs_hash_key = {"file_hash": file_hash, "is_podling": is_podling} # For debugging and testing if (await user.is_admin_async(asf_uid)) and (full_path.name == "deliberately_slow_ATR_task_filename.txt"): @@ -283,7 +283,7 @@ async def _check_path_process_single( # noqa: C901 errors, blockers, warnings, - inputs_hash, + inputs_hash_key, ) @@ -295,15 +295,27 @@ async def _record( errors: list[str], blockers: list[str], warnings: list[str], - inputs_hash: str, + inputs_hash_key: dict[str, Any], ) -> None: for error in errors: + hash_key = inputs_hash_key.copy() + hash_key["checker"] = recorder_errors.checker + inputs_hash = hashes.compute_dict_hash(hash_key) await recorder_errors.failure(error, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) for item in blockers: + hash_key = inputs_hash_key.copy() + hash_key["checker"] = recorder_errors.checker + inputs_hash = hashes.compute_dict_hash(hash_key) await recorder_errors.blocker(item, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) for warning in warnings: + hash_key = inputs_hash_key.copy() + hash_key["checker"] = recorder_warnings.checker + inputs_hash = hashes.compute_dict_hash(hash_key) await recorder_warnings.warning(warning, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) if not (errors or blockers or warnings): + hash_key = inputs_hash_key.copy() + hash_key["checker"] = recorder_success.checker + inputs_hash = hashes.compute_dict_hash(hash_key) await recorder_success.success( "Path structure and naming conventions conform to policy", {}, diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py index 0b374a71..9c065ec7 100644 --- a/atr/tasks/checks/rat.py +++ b/atr/tasks/checks/rat.py @@ -88,7 +88,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)") return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py index 6405d137..830d9cfe 100644 --- a/atr/tasks/checks/signature.py +++ b/atr/tasks/checks/signature.py @@ -54,7 +54,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.exception("Committee name is required", {"committee_name": committee_name}) return None - await recorder.cache_key_set(INPUT_POLICY_KEYS, {"committee_name": committee_name}) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info( f"Checking signature {primary_abs_path} for {artifact_abs_path}" diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py index e4c83fb4..7b369e5e 100644 --- a/atr/tasks/checks/targz.py +++ b/atr/tasks/checks/targz.py @@ -42,7 +42,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -101,7 +101,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) filename = artifact_abs_path.name basename_from_filename: Final[str] = ( diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py index 428a01a2..ed85f7b2 100644 --- a/atr/tasks/checks/zipformat.py +++ b/atr/tasks/checks/zipformat.py @@ -37,7 +37,7 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") @@ -63,7 +63,7 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None - await recorder.cache_key_set(INPUT_POLICY_KEYS) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/util.py b/atr/util.py index 7ddb9788..9cb3c820 100644 --- a/atr/util.py +++ b/atr/util.py @@ -21,6 +21,7 @@ import binascii import contextlib import dataclasses import datetime +import fcntl import hashlib import json import os @@ -207,6 +208,27 @@ async def atomic_write_file(file_path: pathlib.Path, content: str, encoding: str raise +async def atomic_modify_file( + file_path: pathlib.Path, + modify: Callable[[str], str], +) -> None: + # This function assumes that file_path already exists and its a regular file + lock_path = file_path.with_suffix(file_path.suffix + ".lock") + lock_fd = await asyncio.to_thread(os.open, str(lock_path), os.O_CREAT | os.O_RDWR) + try: + await asyncio.to_thread(fcntl.flock, lock_fd, fcntl.LOCK_EX) + try: + async with aiofiles.open(file_path, encoding="utf-8") as rf: + old_value = await rf.read() + new_value = modify(old_value) + if new_value != old_value: + await atomic_write_file(file_path, new_value) + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + finally: + await asyncio.to_thread(os.close, lock_fd) + + def chmod_directories(path: pathlib.Path, permissions: int = DIRECTORY_PERMISSIONS) -> None: # codeql[py/overly-permissive-file] os.chmod(path, permissions) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
