This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch check_caching in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit d7c4a17504eae32b3300647a4b3a43585970c259 Author: Alastair McFarlane <[email protected]> AuthorDate: Wed Feb 11 16:13:22 2026 +0000 Start to move caching out of check tasks --- atr/attestable.py | 16 +- atr/db/__init__.py | 3 + atr/docs/checks.md | 2 +- atr/docs/tasks.md | 2 +- atr/file_paths.py | 28 ++++ atr/get/report.py | 11 +- atr/hashing.py | 42 +++++ atr/models/sql.py | 2 +- atr/storage/readers/checks.py | 53 ++++++- atr/storage/readers/releases.py | 2 + atr/tasks/__init__.py | 202 +++++++++++++++++------- atr/tasks/checks/__init__.py | 171 ++++++++++++-------- atr/tasks/checks/compare.py | 5 + atr/tasks/checks/{hashing.py => file_hash.py} | 7 + atr/tasks/checks/license.py | 16 +- atr/tasks/checks/paths.py | 20 ++- atr/tasks/checks/rat.py | 7 +- atr/tasks/checks/signature.py | 6 + atr/tasks/checks/targz.py | 8 + atr/tasks/checks/zipformat.py | 10 +- migrations/versions/0049_2026.02.11_5b874ed2.py | 37 +++++ tests/unit/recorders.py | 2 +- 22 files changed, 494 insertions(+), 158 deletions(-) diff --git a/atr/attestable.py b/atr/attestable.py index d4d6d15a..cac950c8 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -18,13 +18,13 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any import aiofiles import aiofiles.os -import blake3 import pydantic +import atr.hashing as hashing import atr.log as log import atr.models.attestable as models import atr.util as util @@ -32,8 +32,6 @@ import atr.util as util if TYPE_CHECKING: import pathlib -_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 - def attestable_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.json" @@ -43,14 +41,6 @@ def attestable_paths_path(project_name: str, version_name: str, revision_number: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.paths.json" -async def compute_file_hash(path: pathlib.Path) -> str: - hasher = blake3.blake3() - async with aiofiles.open(path, "rb") as f: - while chunk := await f.read(_HASH_CHUNK_SIZE): - hasher.update(chunk) - return f"blake3:{hasher.hexdigest()}" - - def github_tp_payload_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path: return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json" @@ -140,7 +130,7 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, if "\\" in path_key: # TODO: We should centralise this, and forbid some other characters too raise ValueError(f"Backslash in path is forbidden: {path_key}") - path_to_hash[path_key] = await compute_file_hash(full_path) + path_to_hash[path_key] = await hashing.compute_file_hash(full_path) path_to_size[path_key] = (await aiofiles.os.stat(full_path)).st_size return path_to_hash, path_to_size diff --git a/atr/db/__init__.py b/atr/db/__init__.py index eb454e0b..ed4a76a9 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -164,6 +164,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): status: Opt[sql.CheckResultStatus] = NOT_SET, message: Opt[str] = NOT_SET, data: Opt[Any] = NOT_SET, + inputs_hash: Opt[str] = NOT_SET, _release: bool = False, ) -> Query[sql.CheckResult]: query = sqlmodel.select(sql.CheckResult) @@ -188,6 +189,8 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): query = query.where(sql.CheckResult.message == message) if is_defined(data): query = query.where(sql.CheckResult.data == data) + if is_defined(inputs_hash): + query = query.where(sql.CheckResult.inputs_hash == inputs_hash) if _release: query = query.options(joined_load(sql.CheckResult.release)) diff --git a/atr/docs/checks.md b/atr/docs/checks.md index 7a348179..30badd04 100644 --- a/atr/docs/checks.md +++ b/atr/docs/checks.md @@ -52,7 +52,7 @@ This check records separate checker keys for errors, warnings, and success. Use For each `.sha256` or `.sha512` file, ATR computes the hash of the referenced artifact and compares it with the expected value. It supports files that contain just the hash as well as files that include a filename and hash on the same line. If the suffix does not indicate `sha256` or `sha512`, the check fails. -The checker key is `atr.tasks.checks.hashing.check`. +The checker key is `atr.tasks.checks.file_hash.check`. ### Signature verification diff --git a/atr/docs/tasks.md b/atr/docs/tasks.md index 98c409a3..43c8ae2b 100644 --- a/atr/docs/tasks.md +++ b/atr/docs/tasks.md @@ -41,7 +41,7 @@ In `atr/tasks/checks` you will find several modules that perform these check tas In `atr/tasks/__init__.py` you will see imports for existing modules where you can add an import for new check task, for example: ```python -import atr.tasks.checks.hashing as hashing +import atr.tasks.checks.file_hash as file_hash import atr.tasks.checks.license as license ``` diff --git a/atr/file_paths.py b/atr/file_paths.py new file mode 100644 index 00000000..d29d6b96 --- /dev/null +++ b/atr/file_paths.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pathlib + +import atr.util as util + + +def base_path_for_revision(project_name: str, version_name: str, revision: str) -> pathlib.Path: + return pathlib.Path(util.get_unfinished_dir(), project_name, version_name, revision) + + +def revision_path_for_file(project_name: str, version_name: str, revision: str, file_name: str) -> pathlib.Path: + return base_path_for_revision(project_name, version_name, revision) / file_name diff --git a/atr/get/report.py b/atr/get/report.py index 6559b934..c987755a 100644 --- a/atr/get/report.py +++ b/atr/get/report.py @@ -40,10 +40,17 @@ async def selected_path(session: web.Committer, project_name: str, version_name: # If the draft is not found, we try to get the release candidate try: - release = await session.release(project_name, version_name, with_committee=True) + release = await session.release( + project_name, version_name, with_committee=True, with_release_policy=True, with_project_release_policy=True + ) except base.ASFQuartException: release = await session.release( - project_name, version_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE, with_committee=True + project_name, + version_name, + phase=sql.ReleasePhase.RELEASE_CANDIDATE, + with_committee=True, + with_release_policy=True, + with_project_release_policy=True, ) if release.committee is None: diff --git a/atr/hashing.py b/atr/hashing.py new file mode 100644 index 00000000..2970e086 --- /dev/null +++ b/atr/hashing.py @@ -0,0 +1,42 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import pathlib +from typing import Any, Final + +import aiofiles +import aiofiles.os +import blake3 + +_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 + + +async def compute_file_hash(path: str | pathlib.Path) -> str: + path = pathlib.Path(path) + hasher = blake3.blake3() + async with aiofiles.open(path, "rb") as f: + while chunk := await f.read(_HASH_CHUNK_SIZE): + hasher.update(chunk) + return f"blake3:{hasher.hexdigest()}" + + +def compute_dict_hash(to_hash: dict[Any, Any]) -> str: + hasher = blake3.blake3() + for k in sorted(to_hash.keys()): + hasher.update(str(k).encode("utf-8")) + hasher.update(str(to_hash[k]).encode("utf-8")) + return f"blake3:{hasher.hexdigest()}" diff --git a/atr/models/sql.py b/atr/models/sql.py index a6e9aed5..c03a3953 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -946,7 +946,7 @@ class CheckResult(sqlmodel.SQLModel, table=True): data: Any = sqlmodel.Field( sa_column=sqlalchemy.Column(sqlalchemy.JSON), **example({"expected": "...", "found": "..."}) ) - input_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) + inputs_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) cached: bool = sqlmodel.Field(default=False, **example(False)) diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index b48daf35..5797ba04 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -18,17 +18,51 @@ # Removing this will cause circular imports from __future__ import annotations +import importlib from typing import TYPE_CHECKING import atr.db as db +import atr.file_paths as file_paths +import atr.hashing as hashing import atr.models.sql as sql import atr.storage as storage import atr.storage.types as types +import atr.tasks.checks as checks import atr.util as util if TYPE_CHECKING: import pathlib - from collections.abc import Callable + from collections.abc import Callable, Sequence + + +async def _filter_check_results_by_hash( + all_check_results: Sequence[sql.CheckResult], + file_path: pathlib.Path, + input_hash_by_module: dict[str, str | None], + release: sql.Release, +) -> Sequence[sql.CheckResult]: + filtered_check_results = [] + if release.latest_revision_number is None: + raise ValueError("Release has no revision - Invalid state") + for cr in all_check_results: + module_path = cr.checker.rsplit(".", 1)[0] + if module_path not in input_hash_by_module: + try: + module = importlib.import_module(module_path) + policy_keys: list[str] = module.INPUT_POLICY_KEYS + extra_arg_names: list[str] = getattr(module, "INPUT_EXTRA_ARGS", []) + except (ImportError, AttributeError): + policy_keys = [] + extra_arg_names = [] + extra_args = checks.resolve_extra_args(extra_arg_names, release) + cache_key = await checks.resolve_cache_key( + policy_keys, release, release.latest_revision_number, extra_args, path=file_path + ) + input_hash_by_module[module_path] = hashing.compute_dict_hash(cache_key) if cache_key else None + + if cr.inputs_hash == input_hash_by_module[module_path]: + filtered_check_results.append(cr) + return filtered_check_results class GeneralPublic: @@ -48,9 +82,10 @@ class GeneralPublic: if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") + # TODO: Is this potentially too much data? Within a revision I hope it's not too bad? + query = self.__data.check_result( release_name=release.name, - revision_number=release.latest_revision_number, primary_rel_path=str(rel_path), ).order_by( sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), @@ -58,11 +93,23 @@ class GeneralPublic: ) all_check_results = await query.all() + file_path = file_paths.revision_path_for_file( + release.project_name, release.version, release.latest_revision_number, rel_path.name + ) + # Filter to checks for the current file version / policy + # Cache the computed input hash per checker module, since all results here share the same file and release + input_hash_by_module: dict[str, str | None] = {} + # TODO: This has a bug - create an archive, it'll scan with a hash and show missing checksum. + # Then generate a checksum. It'll re-scan the file with the same hash, but now has one. Two checks shown. + filtered_check_results = await _filter_check_results_by_hash( + all_check_results, file_path, input_hash_by_module, release + ) + # Filter out any results that are ignored unignored_checks = [] ignored_checks = [] match_ignore = await self.ignores_matcher(release.project_name) - for cr in all_check_results: + for cr in filtered_check_results: if not match_ignore(cr): unignored_checks.append(cr) else: diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py index e57a9786..c4969aca 100644 --- a/atr/storage/readers/releases.py +++ b/atr/storage/readers/releases.py @@ -129,6 +129,8 @@ class GeneralPublic: info=info, match_ignore=match_ignore, ) + # TODO: These get just the ones for the revision. + # It might be better to get all like we do in by_release_path, filter by hash, then filter by status await self.__successes(cs) await self.__warnings(cs) await self.__errors(cs) diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 8030727d..9eca7a8e 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -15,17 +15,22 @@ # specific language governing permissions and limitations # under the License. +import asyncio import datetime +import logging +import pathlib from collections.abc import Awaitable, Callable, Coroutine from typing import Any, Final import sqlmodel import atr.db as db +import atr.hashing as hashing import atr.models.results as results import atr.models.sql as sql +import atr.tasks.checks as checks import atr.tasks.checks.compare as compare -import atr.tasks.checks.hashing as hashing +import atr.tasks.checks.file_hash as file_hash import atr.tasks.checks.license as license import atr.tasks.checks.paths as paths import atr.tasks.checks.rat as rat @@ -43,17 +48,20 @@ import atr.tasks.vote as vote import atr.util as util -async def asc_checks(asf_uid: str, release: sql.Release, revision: str, signature_path: str) -> list[sql.Task]: +async def asc_checks( + asf_uid: str, release: sql.Release, revision: str, signature_path: str, data: db.Session +) -> list[sql.Task | None]: """Create signature check task for a .asc file.""" tasks = [] if release.committee: tasks.append( - queued( + await queued( asf_uid, sql.TaskType.SIGNATURE_CHECK, release, revision, + data, signature_path, {"committee_name": release.committee.name}, ) @@ -120,9 +128,12 @@ async def draft_checks( relative_paths = [path async for path in util.paths_recursive(revision_path)] async with db.ensure_session(caller_data) as data: - release = await data.release(name=sql.release_name(project_name, release_version), _committee=True).demand( - RuntimeError("Release not found") - ) + release = await data.release( + name=sql.release_name(project_name, release_version), + _committee=True, + _release_policy=True, + _project_release_policy=True, + ).demand(RuntimeError("Release not found")) other_releases = ( await data.release(project_name=project_name, phase=sql.ReleasePhase.RELEASE) .order_by(sql.Release.released) @@ -136,43 +147,29 @@ async def draft_checks( (v for v in release_versions if util.version_sort_key(v.version) < release_version_sortable), None ) for path in relative_paths: - path_str = str(path) - task_function: Callable[[str, sql.Release, str, str], Awaitable[list[sql.Task]]] | None = None - for suffix, func in TASK_FUNCTIONS.items(): - if path.name.endswith(suffix): - task_function = func - break - if task_function: - for task in await task_function(asf_uid, release, revision_number, path_str): - task.revision_number = revision_number - data.add(task) - # TODO: Should we check .json files for their content? - # Ideally we would not have to do that - if path.name.endswith(".cdx.json"): - data.add( - queued( - asf_uid, - sql.TaskType.SBOM_TOOL_SCORE, - release, - revision_number, - path_str, - extra_args={ - "project_name": project_name, - "version_name": release_version, - "revision_number": revision_number, - "previous_release_version": previous_version.version if previous_version else None, - "file_path": path_str, - "asf_uid": asf_uid, - }, - ) - ) + await _draft_file_checks( + asf_uid, + caller_data, + data, + path, + previous_version, + project_name, + release, + release_version, + revision_number, + ) is_podling = False if release.project.committee is not None: if release.project.committee.is_podling: is_podling = True - path_check_task = queued( - asf_uid, sql.TaskType.PATHS_CHECK, release, revision_number, extra_args={"is_podling": is_podling} + path_check_task = await queued( + asf_uid, + sql.TaskType.PATHS_CHECK, + release, + revision_number, + caller_data, + extra_args={"is_podling": is_podling}, ) data.add(path_check_task) if caller_data is None: @@ -181,6 +178,51 @@ async def draft_checks( return len(relative_paths) +async def _draft_file_checks( + asf_uid: str, + caller_data: db.Session | None, + data: db.Session, + path: pathlib.Path, + previous_version: sql.Release | None, + project_name: str, + release: sql.Release, + release_version: str, + revision_number: str, +): + path_str = str(path) + task_function: Callable[[str, sql.Release, str, str, db.Session], Awaitable[list[sql.Task | None]]] | None = None + for suffix, func in TASK_FUNCTIONS.items(): + if path.name.endswith(suffix): + task_function = func + break + if task_function: + for task in await task_function(asf_uid, release, revision_number, path_str, data): + if task: + task.revision_number = revision_number + data.add(task) + # TODO: Should we check .json files for their content? + # Ideally we would not have to do that + if path.name.endswith(".cdx.json"): + data.add( + await queued( + asf_uid, + sql.TaskType.SBOM_TOOL_SCORE, + release, + revision_number, + caller_data, + path_str, + extra_args={ + "project_name": project_name, + "version_name": release_version, + "revision_number": revision_number, + "previous_release_version": previous_version.version if previous_version else None, + "file_path": path_str, + "asf_uid": asf_uid, + }, + ) + ) + + async def keys_import_file( asf_uid: str, project_name: str, version_name: str, revision_number: str, caller_data: db.Session | None = None ) -> None: @@ -230,14 +272,24 @@ async def metadata_update( return task -def queued( +async def queued( asf_uid: str, task_type: sql.TaskType, release: sql.Release, revision_number: str, + data: db.Session | None = None, primary_rel_path: str | None = None, extra_args: dict[str, Any] | None = None, -) -> sql.Task: + check_cache_key: dict[str, Any] | None = None, +) -> sql.Task | None: + if check_cache_key is not None: + logging.info("cache key", check_cache_key) + hash_val = hashing.compute_dict_hash(check_cache_key) + if not data: + raise RuntimeError("DB Session is required for check_cache_key") + existing = await data.check_result(inputs_hash=hash_val).all() + if existing: + return None return sql.Task( status=sql.TaskStatus.QUEUED, task_type=task_type, @@ -259,7 +311,7 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results case sql.TaskType.DISTRIBUTION_WORKFLOW: return gha.trigger_workflow case sql.TaskType.HASHING_CHECK: - return hashing.check + return file_hash.check case sql.TaskType.KEYS_IMPORT_FILE: return keys.import_file case sql.TaskType.LICENSE_FILES: @@ -304,29 +356,53 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results # Otherwise we lose exhaustiveness checking -async def sha_checks(asf_uid: str, release: sql.Release, revision: str, hash_file: str) -> list[sql.Task]: +async def sha_checks( + asf_uid: str, release: sql.Release, revision: str, hash_file: str, data: db.Session +) -> list[sql.Task | None]: """Create hash check task for a .sha256 or .sha512 file.""" tasks = [] - tasks.append(queued(asf_uid, sql.TaskType.HASHING_CHECK, release, revision, hash_file)) + tasks.append(queued(asf_uid, sql.TaskType.HASHING_CHECK, release, revision, data, hash_file)) - return tasks + return await asyncio.gather(*tasks) -async def tar_gz_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def tar_gz_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .tar.gz or .tgz file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling + tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path), + queued( + asf_uid, + sql.TaskType.LICENSE_FILES, + release, + revision, + data, + path, + check_cache_key=await checks.resolve_cache_key( + license.INPUT_POLICY_KEYS, release, revision, {**{"is_podling": is_podling}}, file=path + ), + extra_args={"is_podling": is_podling}, + ), + queued( + asf_uid, + sql.TaskType.LICENSE_HEADERS, + release, + revision, + data, + path, + check_cache_key=await checks.resolve_cache_key(license.INPUT_POLICY_KEYS, release, revision, file=path), + ), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path), + queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, path), + queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, path), ] - return tasks + return await asyncio.gather(*tasks) async def workflow_update( @@ -356,22 +432,26 @@ async def workflow_update( return task -async def zip_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def zip_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .zip file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path), + queued( + asf_uid, sql.TaskType.LICENSE_FILES, release, revision, data, path, extra_args={"is_podling": is_podling} + ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path), + queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, data, path), + queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, data, path), ] - return tasks + return await asyncio.gather(*tasks) -TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task]]]]] = { +TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task | None]]]]] = { ".asc": asc_checks, ".sha256": sha_checks, ".sha512": sha_checks, diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index 1b78f68e..1b3c1760 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -20,26 +20,26 @@ from __future__ import annotations import dataclasses import datetime import functools -import pathlib from typing import TYPE_CHECKING, Any, Final import aiofiles import aiofiles.os -import blake3 import sqlmodel if TYPE_CHECKING: + import pathlib from collections.abc import Awaitable, Callable import atr.models.schema as schema import atr.config as config import atr.db as db +import atr.file_paths as file_paths +import atr.hashing as hashing +import atr.log as log import atr.models.sql as sql import atr.util as util -_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 - # Pydantic does not like Callable types, so we use a dataclass instead # It says: "you should define `Callable`, then call `FunctionArguments.model_rebuild()`" @@ -61,7 +61,7 @@ class Recorder: version_name: str primary_rel_path: str | None member_rel_path: str | None - revision: str + revision_number: str afresh: bool __cached: bool __input_hash: str | None @@ -117,6 +117,7 @@ class Recorder: data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: if self.constructed is False: raise RuntimeError("Cannot add check result to a recorder that has not been constructed") @@ -142,7 +143,7 @@ class Recorder: message=message, data=data, cached=False, - input_hash=self.__input_hash, + inputs_hash=inputs_hash or self.__input_hash, ) # It would be more efficient to keep a session open @@ -167,7 +168,7 @@ class Recorder: return self.abs_path_base() / rel_path_part def abs_path_base(self) -> pathlib.Path: - return pathlib.Path(util.get_unfinished_dir(), self.project_name, self.version_name, self.revision_number) + return file_paths.base_path_for_revision(self.project_name, self.version_name, self.revision_number) async def project(self) -> sql.Project: # TODO: Cache project @@ -196,12 +197,10 @@ class Recorder: abs_path = await self.abs_path() return matches(str(abs_path)) - @property - def cached(self) -> bool: - return self.__cached - - async def check_cache(self, path: pathlib.Path) -> bool: - if not await aiofiles.os.path.isfile(path): + async def cache_key_set(self, policy_keys: list[str], input_args: dict[str, Any] | None = None) -> bool: + # TODO: Should this just be in the constructor? + path = await self.abs_path() + if (not path) or (not await aiofiles.os.path.isfile(path)): return False if config.get().DISABLE_CHECK_CACHE: @@ -214,48 +213,18 @@ class Recorder: if await aiofiles.os.path.exists(no_cache_file): return False - self.__input_hash = await _compute_file_hash(path) - async with db.session() as data: - via = sql.validate_instrumented_attribute - subquery = ( - sqlmodel.select( - sql.CheckResult.member_rel_path, - sqlmodel.func.max(via(sql.CheckResult.id)).label("max_id"), - ) - .where(sql.CheckResult.checker == self.checker) - .where(sql.CheckResult.input_hash == self.__input_hash) - .where(sql.CheckResult.primary_rel_path == self.primary_rel_path) - .group_by(sql.CheckResult.member_rel_path) - .subquery() - ) - stmt = sqlmodel.select(sql.CheckResult).join(subquery, via(sql.CheckResult.id) == subquery.c.max_id) - results = await data.execute(stmt) - cached_results = results.scalars().all() - - if not cached_results: - return False - - for cached in cached_results: - new_result = sql.CheckResult( - release_name=self.release_name, - revision_number=self.revision_number, - checker=self.checker, - primary_rel_path=self.primary_rel_path, - member_rel_path=cached.member_rel_path, - created=datetime.datetime.now(datetime.UTC), - status=cached.status, - message=cached.message, - data=cached.data, - cached=True, - input_hash=self.__input_hash, - ) - data.add(new_result) - await data.commit() - - self.__cached = True + release = await data.release( + name=self.release_name, _release_policy=True, _project_release_policy=True + ).demand(RuntimeError(f"Release {self.release_name} not found")) + cache_key = await resolve_cache_key(policy_keys, release, self.revision_number, input_args, path=path) + self.__input_hash = hashing.compute_dict_hash(cache_key) if cache_key else None return True + @property + def cached(self) -> bool: + return self.__cached + async def clear(self, primary_rel_path: str | None = None, member_rel_path: str | None = None) -> None: async with db.session() as data: stmt = sqlmodel.delete(sql.CheckResult).where( @@ -273,7 +242,12 @@ class Recorder: return self.__input_hash async def blocker( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: return await self._add( sql.CheckResultStatus.BLOCKER, @@ -281,10 +255,16 @@ class Recorder: data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, + inputs_hash=inputs_hash, ) async def exception( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: return await self._add( sql.CheckResultStatus.EXCEPTION, @@ -292,10 +272,16 @@ class Recorder: data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, + inputs_hash=inputs_hash, ) async def failure( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: return await self._add( sql.CheckResultStatus.FAILURE, @@ -303,10 +289,16 @@ class Recorder: data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, + inputs_hash=inputs_hash, ) async def success( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: return await self._add( sql.CheckResultStatus.SUCCESS, @@ -314,6 +306,7 @@ class Recorder: data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, + inputs_hash=inputs_hash, ) async def use_check_cache(self) -> bool: @@ -329,7 +322,12 @@ class Recorder: return self.__use_check_cache async def warning( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, + inputs_hash: str | None = None, ) -> sql.CheckResult: return await self._add( sql.CheckResultStatus.WARNING, @@ -337,6 +335,7 @@ class Recorder: data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, + inputs_hash=inputs_hash, ) @@ -344,6 +343,44 @@ def function_key(func: Callable[..., Any]) -> str: return func.__module__ + "." + func.__name__ +async def resolve_cache_key( + policy_keys: list[str], + release: sql.Release, + revision: str, + args: dict[str, Any] | None = None, + file: str | None = None, + path: pathlib.Path | None = None, +) -> dict[str, Any] | None: + if file is None and path is None: + raise ValueError("Must specify either file or path") + if not args: + args = {} + if path is None: + # We know file isn't None here but type checker doesn't + path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file or "") + file_hash = await hashing.compute_file_hash(path) + cache_key = {"file_hash": file_hash} + + policy = release.release_policy or release.project.release_policy + if len(policy_keys) > 0 and policy is not None: + policy_dict = policy.model_dump(exclude_none=True) + return {**cache_key, **args, **{k: policy_dict[k] for k in policy_keys if k in policy_dict}} + else: + return {**cache_key, **args} + + +def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, Any]: + result: dict[str, Any] = {} + for name in arg_names: + resolver = _EXTRA_ARG_RESOLVERS.get(name, None) + # If we can't find a resolver, we'll carry on anyway since it'll just mean no cache potentially + if resolver is None: + log.warning(f"Unknown extra arg resolver: {name}") + return {} + result[name] = resolver(release) + return result + + def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: """Decorator to specify the parameters for a check.""" @@ -358,9 +395,17 @@ def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Calla return decorator -async def _compute_file_hash(path: pathlib.Path) -> str: - hasher = blake3.blake3() - async with aiofiles.open(path, "rb") as f: - while chunk := await f.read(_HASH_CHUNK_SIZE): - hasher.update(chunk) - return f"blake3:{hasher.hexdigest()}" +def _resolve_is_podling(release: sql.Release) -> bool: + return (release.committee is not None) and release.committee.is_podling + + +def _resolve_committee_name(release: sql.Release) -> str: + if release.committee is None: + raise ValueError("Release has no committee") + return release.committee.name + + +_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release], Any]]] = { + "is_podling": _resolve_is_podling, + "committee_name": _resolve_committee_name, +} diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index 0df5ee48..54c98ecf 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -51,6 +51,9 @@ _DEFAULT_USER: Final[str] = "atr" _PERMITTED_ADDED_PATHS: Final[dict[str, list[str]]] = { "PKG-INFO": ["pyproject.toml"], } +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] @dataclasses.dataclass @@ -90,6 +93,8 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None ) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number) checkout_dir: str | None = None archive_dir: str | None = None diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/file_hash.py similarity index 93% rename from atr/tasks/checks/hashing.py rename to atr/tasks/checks/file_hash.py index e8ae78fe..19d4fcbf 100644 --- a/atr/tasks/checks/hashing.py +++ b/atr/tasks/checks/file_hash.py @@ -17,6 +17,7 @@ import hashlib import secrets +from typing import Final import aiofiles @@ -24,6 +25,10 @@ import atr.log as log import atr.models.results as results import atr.tasks.checks as checks +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + async def check(args: checks.FunctionArguments) -> results.Results | None: """Check the hash of a file.""" @@ -36,6 +41,8 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm}) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + # Remove the hash file suffix to get the artifact path # This replaces the last suffix, which is what we want # >>> pathlib.Path("a/b/c.d.e.f.g").with_suffix(".x") diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py index 066fef3a..0e09cd23 100644 --- a/atr/tasks/checks/license.py +++ b/atr/tasks/checks/license.py @@ -79,6 +79,10 @@ INCLUDED_PATTERNS: Final[list[str]] = [ r"\.(pl|pm|t)$", # Perl ] +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [""] +INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"] + # Types @@ -134,10 +138,12 @@ async def files(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None + is_podling = args.extra_args.get("is_podling", False) + await recorder.cache_key_set(INPUT_POLICY_KEYS, {"is_podling": is_podling}) + log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})") try: - is_podling = args.extra_args.get("is_podling", False) for result in await asyncio.to_thread(_files_check_core_logic, str(artifact_abs_path), is_podling): match result: case ArtifactResult(): @@ -166,9 +172,11 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None - if await recorder.check_cache(artifact_abs_path): - log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") - return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + + # if await recorder.check_cache(artifact_abs_path): + # log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") + # return None log.info(f"Checking license headers for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py index 299e4b25..bcd36215 100644 --- a/atr/tasks/checks/paths.py +++ b/atr/tasks/checks/paths.py @@ -23,6 +23,7 @@ from typing import Final import aiofiles.os import atr.analysis as analysis +import atr.hashing as hashing import atr.log as log import atr.models.results as results import atr.tasks.checks as checks @@ -37,6 +38,9 @@ _ALLOWED_TOP_LEVEL: Final = frozenset( "README", } ) +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"] async def check(args: checks.FunctionArguments) -> results.Results | None: @@ -192,6 +196,9 @@ async def _check_path_process_single( # noqa: C901 full_path = base_path / relative_path relative_path_str = str(relative_path) + file_hash = await hashing.compute_file_hash(full_path) + inputs_hash = hashing.compute_dict_hash({"file_hash": file_hash, "is_podling": is_podling}) + # For debugging and testing if (await user.is_admin_async(asf_uid)) and (full_path.name == "deliberately_slow_ATR_task_filename.txt"): await asyncio.sleep(20) @@ -276,6 +283,7 @@ async def _check_path_process_single( # noqa: C901 errors, blockers, warnings, + inputs_hash, ) @@ -287,14 +295,18 @@ async def _record( errors: list[str], blockers: list[str], warnings: list[str], + inputs_hash: str, ) -> None: for error in errors: - await recorder_errors.failure(error, {}, primary_rel_path=relative_path_str) + await recorder_errors.failure(error, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) for item in blockers: - await recorder_errors.blocker(item, {}, primary_rel_path=relative_path_str) + await recorder_errors.blocker(item, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) for warning in warnings: - await recorder_warnings.warning(warning, {}, primary_rel_path=relative_path_str) + await recorder_warnings.warning(warning, {}, primary_rel_path=relative_path_str, inputs_hash=inputs_hash) if not (errors or blockers or warnings): await recorder_success.success( - "Path structure and naming conventions conform to policy", {}, primary_rel_path=relative_path_str + "Path structure and naming conventions conform to policy", + {}, + primary_rel_path=relative_path_str, + inputs_hash=inputs_hash, ) diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py index 9e5757ce..0b374a71 100644 --- a/atr/tasks/checks/rat.py +++ b/atr/tasks/checks/rat.py @@ -65,6 +65,9 @@ _STD_EXCLUSIONS_EXTENDED: Final[list[str]] = [ "GIT", "STANDARD_SCMS", ] +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] class RatError(RuntimeError): @@ -85,9 +88,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)") return None - if await recorder.check_cache(artifact_abs_path): - log.info(f"Using cached RAT result for {artifact_abs_path} (rel: {args.primary_rel_path})") - return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py index 81ac1acf..6405d137 100644 --- a/atr/tasks/checks/signature.py +++ b/atr/tasks/checks/signature.py @@ -30,6 +30,10 @@ import atr.models.sql as sql import atr.tasks.checks as checks import atr.util as util +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name"] + async def check(args: checks.FunctionArguments) -> results.Results | None: """Check a signature file.""" @@ -50,6 +54,8 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.exception("Committee name is required", {"committee_name": committee_name}) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, {"committee_name": committee_name}) + log.info( f"Checking signature {primary_abs_path} for {artifact_abs_path}" f" using {committee_name} keys (rel: {primary_rel_path})" diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py index ffe22b3c..e4c83fb4 100644 --- a/atr/tasks/checks/targz.py +++ b/atr/tasks/checks/targz.py @@ -25,6 +25,10 @@ import atr.tarzip as tarzip import atr.tasks.checks as checks import atr.util as util +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + class RootDirectoryError(Exception): """Exception raised when a root directory is not found in an archive.""" @@ -38,6 +42,8 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") chunk_size = 4096 @@ -95,6 +101,8 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + filename = artifact_abs_path.name basename_from_filename: Final[str] = ( filename.removesuffix(".tar.gz") if filename.endswith(".tar.gz") else filename.removesuffix(".tgz") diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py index d1888734..428a01a2 100644 --- a/atr/tasks/checks/zipformat.py +++ b/atr/tasks/checks/zipformat.py @@ -18,7 +18,7 @@ import asyncio import os import zipfile -from typing import Any +from typing import Any, Final import atr.log as log import atr.models.results as results @@ -26,6 +26,10 @@ import atr.tarzip as tarzip import atr.tasks.checks as checks import atr.util as util +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + async def integrity(args: checks.FunctionArguments) -> results.Results | None: """Check that the zip archive is not corrupted and can be opened.""" @@ -33,6 +37,8 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") try: @@ -57,6 +63,8 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS) + log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})") try: diff --git a/migrations/versions/0049_2026.02.11_5b874ed2.py b/migrations/versions/0049_2026.02.11_5b874ed2.py new file mode 100644 index 00000000..e4730cd6 --- /dev/null +++ b/migrations/versions/0049_2026.02.11_5b874ed2.py @@ -0,0 +1,37 @@ +"""Rename input_hash inputs_hash + +Revision ID: 0049_2026.02.11_5b874ed2 +Revises: 0048_2026.02.06_blocking_to_blocker +Create Date: 2026-02-11 13:42:59.712570+00:00 +""" + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op + +# Revision identifiers, used by Alembic +revision: str = "0049_2026.02.11_5b874ed2" +down_revision: str | None = "0048_2026.02.06_blocking_to_blocker" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.add_column(sa.Column("inputs_hash", sa.String(), nullable=True)) + batch_op.drop_index(batch_op.f("ix_checkresult_input_hash")) + batch_op.create_index(batch_op.f("ix_checkresult_inputs_hash"), ["inputs_hash"], unique=False) + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.execute("UPDATE checkresult SET inputs_hash = input_hash") + batch_op.drop_column("input_hash") + + +def downgrade() -> None: + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.add_column(sa.Column("input_hash", sa.VARCHAR(), nullable=True)) + batch_op.drop_index(batch_op.f("ix_checkresult_inputs_hash")) + batch_op.create_index(batch_op.f("ix_checkresult_input_hash"), ["input_hash"], unique=False) + with op.batch_alter_table("checkresult", schema=None) as batch_op: + batch_op.execute("UPDATE checkresult SET input_hash = inputs_hash") + batch_op.drop_column("inputs_hash") diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py index 33e5af03..47c772eb 100644 --- a/tests/unit/recorders.py +++ b/tests/unit/recorders.py @@ -63,7 +63,7 @@ class RecorderStub(checks.Recorder): status=status, message=message, data=data, - input_hash=None, + inputs_hash=None, ) --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
