This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch check_caching in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 133ab8372425fe94b080cde4e37d7ec44c79b3d1 Author: Alastair McFarlane <[email protected]> AuthorDate: Thu Feb 12 17:31:16 2026 +0000 Include checker name in cache key and tidy up some code. --- atr/attestable.py | 4 +- atr/docs/tasks.md | 2 +- atr/{hashing.py => hashes.py} | 0 atr/merge.py | 7 +- atr/storage/readers/checks.py | 6 +- atr/tasks/__init__.py | 288 +++++++++++--------------- atr/tasks/checks/__init__.py | 17 +- atr/tasks/checks/{file_hash.py => hashing.py} | 0 atr/tasks/checks/paths.py | 6 +- 9 files changed, 139 insertions(+), 191 deletions(-) diff --git a/atr/attestable.py b/atr/attestable.py index e44be1fe..b2512279 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -24,7 +24,7 @@ import aiofiles import aiofiles.os import pydantic -import atr.hashing as hashing +import atr.hashes as hashes import atr.log as log import atr.models.attestable as models import atr.util as util @@ -130,7 +130,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, if "\\" in path_key: # TODO: We should centralise this, and forbid some other characters too raise ValueError(f"Backslash in path is forbidden: {path_key}") - path_to_hash[path_key] = await hashing.compute_file_hash(full_path) + path_to_hash[path_key] = await hashes.compute_file_hash(full_path) path_to_size[path_key] = (await aiofiles.os.stat(full_path)).st_size return path_to_hash, path_to_size diff --git a/atr/docs/tasks.md b/atr/docs/tasks.md index 43c8ae2b..8761a25c 100644 --- a/atr/docs/tasks.md +++ b/atr/docs/tasks.md @@ -41,7 +41,7 @@ In `atr/tasks/checks` you will find several modules that perform these check tas In `atr/tasks/__init__.py` you will see imports for existing modules where you can add an import for new check task, for example: ```python -import atr.tasks.checks.file_hash as file_hash +import atr.tasks.checks.hashing as file_hash import atr.tasks.checks.license as license ``` diff --git a/atr/hashing.py b/atr/hashes.py similarity index 100% rename from atr/hashing.py rename to atr/hashes.py diff --git a/atr/merge.py b/atr/merge.py index 75e18620..9e341b43 100644 --- a/atr/merge.py +++ b/atr/merge.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING import aiofiles.os import atr.attestable as attestable +import atr.hashes as hashes import atr.util as util if TYPE_CHECKING: @@ -131,7 +132,7 @@ async def _add_from_prior( if (prior_hashes is not None) and (path in prior_hashes): n_hashes[path] = prior_hashes[path] else: - n_hashes[path] = await attestable.compute_file_hash(target) + n_hashes[path] = await hashes.compute_file_hash(target) stat_result = await aiofiles.os.stat(target) n_sizes[path] = stat_result.st_size return prior_hashes @@ -211,7 +212,7 @@ async def _merge_all_present( if (prior_hashes is not None) and (path in prior_hashes): p_hash = prior_hashes[path] else: - p_hash = await attestable.compute_file_hash(prior_dir / path) + p_hash = await hashes.compute_file_hash(prior_dir / path) if p_hash != b_hash: # Case 11 via hash: base and new have the same content but prior differs return await _replace_with_prior( @@ -250,7 +251,7 @@ async def _replace_with_prior( if (prior_hashes is not None) and (path in prior_hashes): n_hashes[path] = prior_hashes[path] else: - n_hashes[path] = await attestable.compute_file_hash(file_path) + n_hashes[path] = await hashes.compute_file_hash(file_path) stat_result = await aiofiles.os.stat(file_path) n_sizes[path] = stat_result.st_size return prior_hashes diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index 8c5c3e2e..f9b74301 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -22,7 +22,7 @@ import importlib from typing import TYPE_CHECKING import atr.db as db -import atr.hashing as hashing +import atr.hashes as hashes import atr.models.sql as sql import atr.storage as storage import atr.storage.types as types @@ -55,9 +55,9 @@ async def _filter_check_results_by_hash( extra_arg_names = [] extra_args = checks.resolve_extra_args(extra_arg_names, release) cache_key = await checks.resolve_cache_key( - policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name + cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name ) - input_hash_by_module[module_path] = hashing.compute_dict_hash(cache_key) if cache_key else None + input_hash_by_module[module_path] = hashes.compute_dict_hash(cache_key) if cache_key else None if cr.inputs_hash == input_hash_by_module[module_path]: filtered_check_results.append(cr) diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 93797bcc..18780bd3 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -25,12 +25,12 @@ from typing import Any, Final import sqlmodel import atr.db as db -import atr.hashing as hashing +import atr.hashes as hashes import atr.models.results as results import atr.models.sql as sql import atr.tasks.checks as checks import atr.tasks.checks.compare as compare -import atr.tasks.checks.file_hash as file_hash +import atr.tasks.checks.hashing as hashing import atr.tasks.checks.license as license import atr.tasks.checks.paths as paths import atr.tasks.checks.rat as rat @@ -64,6 +64,7 @@ async def asc_checks( data, signature_path, check_cache_key=await checks.resolve_cache_key( + resolve(sql.TaskType.SIGNATURE_CHECK), signature.INPUT_POLICY_KEYS, release, revision, @@ -290,7 +291,7 @@ async def queued( ) -> sql.Task | None: if check_cache_key is not None: logging.info("cache key", check_cache_key) - hash_val = hashing.compute_dict_hash(check_cache_key) + hash_val = hashes.compute_dict_hash(check_cache_key) if not data: raise RuntimeError("DB Session is required for check_cache_key") existing = await data.check_result(inputs_hash=hash_val).all() @@ -317,7 +318,7 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results case sql.TaskType.DISTRIBUTION_WORKFLOW: return gha.trigger_workflow case sql.TaskType.HASHING_CHECK: - return file_hash.check + return hashing.check case sql.TaskType.KEYS_IMPORT_FILE: return keys.import_file case sql.TaskType.LICENSE_FILES: @@ -377,10 +378,11 @@ async def sha_checks( data, hash_file, check_cache_key=await checks.resolve_cache_key( - file_hash.INPUT_POLICY_KEYS, + resolve(sql.TaskType.HASHING_CHECK), + hashing.INPUT_POLICY_KEYS, release, revision, - checks.resolve_extra_args(file_hash.INPUT_EXTRA_ARGS, release), + checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release), file=hash_file, ), ) @@ -395,23 +397,56 @@ async def tar_gz_checks( """Create check tasks for a .tar.gz or .tgz file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling - + compare_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.COMPARE_SOURCE_TREES), + compare.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_h_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_HEADERS), + license.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_f_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_FILES), + license.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + rat_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.RAT_CHECK), + rat.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + file=path, + ) + targz_i_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.TARGZ_INTEGRITY), + targz.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + file=path, + ) + targz_s_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.TARGZ_STRUCTURE), + targz.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + file=path, + ) tasks = [ - queued( - asf_uid, - sql.TaskType.COMPARE_SOURCE_TREES, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - compare.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck), queued( asf_uid, sql.TaskType.LICENSE_FILES, @@ -419,75 +454,13 @@ async def tar_gz_checks( revision, data, path, - check_cache_key=await checks.resolve_cache_key( - license.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), - file=path, - ), + check_cache_key=license_f_ck, extra_args={"is_podling": is_podling}, ), - queued( - asf_uid, - sql.TaskType.LICENSE_HEADERS, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - license.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.RAT_CHECK, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - rat.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.TARGZ_INTEGRITY, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - targz.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.TARGZ_STRUCTURE, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - targz.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck), + queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, path, check_cache_key=targz_i_ck), + queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, path, check_cache_key=targz_s_ck), ] return await asyncio.gather(*tasks) @@ -526,22 +499,57 @@ async def zip_checks( """Create check tasks for a .zip file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling + compare_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.COMPARE_SOURCE_TREES), + compare.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_h_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_HEADERS), + license.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_f_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_FILES), + license.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + rat_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.RAT_CHECK), + rat.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + file=path, + ) + zip_i_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.ZIPFORMAT_INTEGRITY), + zipformat.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + file=path, + ) + zip_s_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.ZIPFORMAT_STRUCTURE), + zipformat.INPUT_POLICY_KEYS, + release, + revision, + checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + file=path, + ) + tasks = [ - queued( - asf_uid, - sql.TaskType.COMPARE_SOURCE_TREES, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - compare.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck), queued( asf_uid, sql.TaskType.LICENSE_FILES, @@ -549,75 +557,13 @@ async def zip_checks( revision, data, path, - check_cache_key=await checks.resolve_cache_key( - license.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), - file=path, - ), + check_cache_key=license_f_ck, extra_args={"is_podling": is_podling}, ), - queued( - asf_uid, - sql.TaskType.LICENSE_HEADERS, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - license.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.RAT_CHECK, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - rat.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.TARGZ_INTEGRITY, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - zipformat.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), - queued( - asf_uid, - sql.TaskType.TARGZ_STRUCTURE, - release, - revision, - data, - path, - check_cache_key=await checks.resolve_cache_key( - zipformat.INPUT_POLICY_KEYS, - release, - revision, - checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), - file=path, - ), - ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck), + queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, data, path, check_cache_key=zip_i_ck), + queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, data, path, check_cache_key=zip_s_ck), ] return await asyncio.gather(*tasks) diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index a6494cae..add95588 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -36,7 +36,7 @@ import atr.attestable as attestable import atr.config as config import atr.db as db import atr.file_paths as file_paths -import atr.hashing as hashing +import atr.hashes as hashes import atr.log as log import atr.models.sql as sql import atr.util as util @@ -78,7 +78,7 @@ class Recorder: member_rel_path: str | None = None, afresh: bool = True, ) -> None: - self.checker = function_key(checker) if callable(checker) else checker + self.checker = function_key(checker) self.release_name = sql.release_name(project_name, version_name) self.revision_number = revision_number self.primary_rel_path = primary_rel_path @@ -219,9 +219,9 @@ class Recorder: name=self.release_name, _release_policy=True, _project_release_policy=True ).demand(RuntimeError(f"Release {self.release_name} not found")) cache_key = await resolve_cache_key( - policy_keys, release, self.revision_number, input_args, file=self.primary_rel_path + self.checker, policy_keys, release, self.revision_number, input_args, file=self.primary_rel_path ) - self.__input_hash = hashing.compute_dict_hash(cache_key) if cache_key else None + self.__input_hash = hashes.compute_dict_hash(cache_key) if cache_key else None return True @property @@ -342,11 +342,12 @@ class Recorder: ) -def function_key(func: Callable[..., Any]) -> str: - return func.__module__ + "." + func.__name__ +def function_key(func: Callable[..., Any] | str) -> str: + return func.__module__ + "." + func.__name__ if callable(func) else func async def resolve_cache_key( + checker: str | Callable[..., Any], policy_keys: list[str], release: sql.Release, revision: str, @@ -368,8 +369,8 @@ async def resolve_cache_key( if path is None: # We know file isn't None here but type checker doesn't path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file or "") - file_hash = await hashing.compute_file_hash(path) - cache_key = {"file_hash": file_hash} + file_hash = await hashes.compute_file_hash(path) + cache_key = {"file_hash": file_hash, "checker": function_key(checker)} if len(policy_keys) > 0 and policy is not None: policy_dict = policy.model_dump(exclude_none=True) diff --git a/atr/tasks/checks/file_hash.py b/atr/tasks/checks/hashing.py similarity index 100% rename from atr/tasks/checks/file_hash.py rename to atr/tasks/checks/hashing.py diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py index bcd36215..39d0cfe3 100644 --- a/atr/tasks/checks/paths.py +++ b/atr/tasks/checks/paths.py @@ -23,7 +23,7 @@ from typing import Final import aiofiles.os import atr.analysis as analysis -import atr.hashing as hashing +import atr.hashes as hashes import atr.log as log import atr.models.results as results import atr.tasks.checks as checks @@ -196,8 +196,8 @@ async def _check_path_process_single( # noqa: C901 full_path = base_path / relative_path relative_path_str = str(relative_path) - file_hash = await hashing.compute_file_hash(full_path) - inputs_hash = hashing.compute_dict_hash({"file_hash": file_hash, "is_podling": is_podling}) + file_hash = await hashes.compute_file_hash(full_path) + inputs_hash = hashes.compute_dict_hash({"file_hash": file_hash, "is_podling": is_podling}) # For debugging and testing if (await user.is_admin_async(asf_uid)) and (full_path.name == "deliberately_slow_ATR_task_filename.txt"): --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
