This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch check_caching
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 9c5343134cbecbd4ec5546b283198475166ae14b
Author: Alastair McFarlane <[email protected]>
AuthorDate: Mon Feb 16 14:26:00 2026 +0000

    Update check caching to use hash keys of inputs
---
 atr/attestable.py                 |  71 ++++++--
 atr/db/__init__.py                |   8 +
 atr/docs/checks.md                |   2 +-
 atr/docs/tasks.md                 |   2 +-
 atr/get/report.py                 |  11 +-
 atr/merge.py                      |   7 +-
 atr/models/attestable.py          |   8 +-
 atr/models/sql.py                 |   2 +-
 atr/storage/readers/checks.py     |  57 ++++++-
 atr/storage/readers/releases.py   |  36 ++--
 atr/storage/types.py              |   3 +-
 atr/storage/writers/revision.py   |   7 +-
 atr/tasks/__init__.py             | 336 +++++++++++++++++++++++++++++++------
 atr/tasks/checks/__init__.py      | 218 +++++++++++++++++--------
 atr/tasks/checks/compare.py       |   5 +
 atr/tasks/checks/hashing.py       |   7 +
 atr/tasks/checks/license.py       |  16 +-
 atr/tasks/checks/paths.py         |  18 +-
 atr/tasks/checks/rat.py           |   7 +-
 atr/tasks/checks/signature.py     |   6 +
 atr/tasks/checks/targz.py         |   8 +
 atr/tasks/checks/zipformat.py     |  10 +-
 atr/util.py                       |  22 +++
 tests/unit/recorders.py           |  28 +++-
 tests/unit/test_cache.py          |  10 +-
 tests/unit/test_checks_compare.py |   5 +
 26 files changed, 694 insertions(+), 216 deletions(-)

diff --git a/atr/attestable.py b/atr/attestable.py
index d4d6d15a..50260093 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -18,22 +18,21 @@ from __future__ import annotations
 
 import json
-from typing import TYPE_CHECKING, Any, Final
+from typing import TYPE_CHECKING, Any
 
 import aiofiles
 import aiofiles.os
-import blake3
 import pydantic
 
+import atr.hashes as hashes
 import atr.log as log
 import atr.models.attestable as models
 import atr.util as util
+from atr.models.attestable import AttestableChecksV1
 
 if TYPE_CHECKING:
     import pathlib
 
-_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024
-
 
 def attestable_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.json"
 
@@ -43,18 +42,14 @@ def attestable_paths_path(project_name: str, version_name: str, revision_number:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.paths.json"
 
 
-async def compute_file_hash(path: pathlib.Path) -> str:
-    hasher = blake3.blake3()
-    async with aiofiles.open(path, "rb") as f:
-        while chunk := await f.read(_HASH_CHUNK_SIZE):
-            hasher.update(chunk)
-    return f"blake3:{hasher.hexdigest()}"
-
-
 def github_tp_payload_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
     return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.github-tp.json"
 
 
+def attestable_checks_path(project_name: str, version_name: str, revision_number: str) -> pathlib.Path:
+    return util.get_attestable_dir() / project_name / version_name / f"{revision_number}.checks.json"
+
+
 async def github_tp_payload_write(
     project_name: str, version_name: str, revision_number: str, github_payload: dict[str, Any]
 ) -> None:
@@ -99,6 +94,22 @@ async def load_paths(
     return None
 
 
+async def load_checks(
+    project_name: str,
+    version_name: str,
+    revision_number: str,
+) -> list[int] | None:
+    file_path = attestable_checks_path(project_name, version_name, revision_number)
+    if await aiofiles.os.path.isfile(file_path):
+        try:
+            async with aiofiles.open(file_path, encoding="utf-8") as f:
+                data = json.loads(await f.read())
+            return models.AttestableChecksV1.model_validate(data).checks
+        except (json.JSONDecodeError, pydantic.ValidationError) as e:
+            log.warning(f"Could not parse {file_path}: {e}")
+    return []
+
+
 def migrate_to_paths_files() -> int:
attestable_dir = util.get_attestable_dir() if not attestable_dir.is_dir(): @@ -140,26 +151,52 @@ async def paths_to_hashes_and_sizes(directory: pathlib.Path) -> tuple[dict[str, if "\\" in path_key: # TODO: We should centralise this, and forbid some other characters too raise ValueError(f"Backslash in path is forbidden: {path_key}") - path_to_hash[path_key] = await compute_file_hash(full_path) + path_to_hash[path_key] = await hashes.compute_file_hash(full_path) path_to_size[path_key] = (await aiofiles.os.stat(full_path)).st_size return path_to_hash, path_to_size -async def write( +async def write_files_data( project_name: str, version_name: str, revision_number: str, + release_policy: dict[str, Any] | None, uploader_uid: str, previous: models.AttestableV1 | None, path_to_hash: dict[str, str], path_to_size: dict[str, int], ) -> None: - result = _generate(path_to_hash, path_to_size, revision_number, uploader_uid, previous) + result = _generate_files_data(path_to_hash, path_to_size, revision_number, release_policy, uploader_uid, previous) file_path = attestable_path(project_name, version_name, revision_number) await util.atomic_write_file(file_path, result.model_dump_json(indent=2)) paths_result = models.AttestablePathsV1(paths=result.paths) paths_file_path = attestable_paths_path(project_name, version_name, revision_number) await util.atomic_write_file(paths_file_path, paths_result.model_dump_json(indent=2)) + checks_file_path = attestable_checks_path(project_name, version_name, revision_number) + if not checks_file_path.exists(): + async with aiofiles.open(checks_file_path, "w", encoding="utf-8") as f: + await f.write(models.AttestableChecksV1().model_dump_json(indent=2)) + + +async def write_checks_data( + project_name: str, + version_name: str, + revision_number: str, + checks: list[int], +) -> None: + log.info(f"Writing checks for {project_name}/{version_name}/{revision_number}: {checks}") + + def modify(content: str) -> str: + try: + current = AttestableChecksV1.model_validate_json(content).checks + except pydantic.ValidationError: + current = [] + new_checks = set(current or []) + new_checks.update(checks) + result = models.AttestableChecksV1(checks=sorted(new_checks)) + return result.model_dump_json(indent=2) + + await util.atomic_modify_file(attestable_checks_path(project_name, version_name, revision_number), modify) def _compute_hashes_with_attribution( @@ -197,10 +234,11 @@ def _compute_hashes_with_attribution( return new_hashes -def _generate( +def _generate_files_data( path_to_hash: dict[str, str], path_to_size: dict[str, int], revision_number: str, + release_policy: dict[str, Any] | None, uploader_uid: str, previous: models.AttestableV1 | None, ) -> models.AttestableV1: @@ -215,4 +253,5 @@ def _generate( return models.AttestableV1( paths=dict(path_to_hash), hashes=dict(new_hashes), + policy=release_policy or {}, ) diff --git a/atr/db/__init__.py b/atr/db/__init__.py index eb454e0b..2c6d579e 100644 --- a/atr/db/__init__.py +++ b/atr/db/__init__.py @@ -155,6 +155,7 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): def check_result( self, id: Opt[int] = NOT_SET, + id_in: Opt[list[int]] = NOT_SET, release_name: Opt[str] = NOT_SET, revision_number: Opt[str] = NOT_SET, checker: Opt[str] = NOT_SET, @@ -164,12 +165,17 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): status: Opt[sql.CheckResultStatus] = NOT_SET, message: Opt[str] = NOT_SET, data: Opt[Any] = NOT_SET, + inputs_hash: Opt[str] = NOT_SET, _release: bool = False, ) -> Query[sql.CheckResult]: query = 
sqlmodel.select(sql.CheckResult) + via = sql.validate_instrumented_attribute + if is_defined(id): query = query.where(sql.CheckResult.id == id) + if is_defined(id_in): + query = query.where(via(sql.CheckResult.id).in_(id_in)) if is_defined(release_name): query = query.where(sql.CheckResult.release_name == release_name) if is_defined(revision_number): @@ -188,6 +194,8 @@ class Session(sqlalchemy.ext.asyncio.AsyncSession): query = query.where(sql.CheckResult.message == message) if is_defined(data): query = query.where(sql.CheckResult.data == data) + if is_defined(inputs_hash): + query = query.where(sql.CheckResult.inputs_hash == inputs_hash) if _release: query = query.options(joined_load(sql.CheckResult.release)) diff --git a/atr/docs/checks.md b/atr/docs/checks.md index 7a348179..30badd04 100644 --- a/atr/docs/checks.md +++ b/atr/docs/checks.md @@ -52,7 +52,7 @@ This check records separate checker keys for errors, warnings, and success. Use For each `.sha256` or `.sha512` file, ATR computes the hash of the referenced artifact and compares it with the expected value. It supports files that contain just the hash as well as files that include a filename and hash on the same line. If the suffix does not indicate `sha256` or `sha512`, the check fails. -The checker key is `atr.tasks.checks.hashing.check`. +The checker key is `atr.tasks.checks.file_hash.check`. ### Signature verification diff --git a/atr/docs/tasks.md b/atr/docs/tasks.md index 98c409a3..8761a25c 100644 --- a/atr/docs/tasks.md +++ b/atr/docs/tasks.md @@ -41,7 +41,7 @@ In `atr/tasks/checks` you will find several modules that perform these check tas In `atr/tasks/__init__.py` you will see imports for existing modules where you can add an import for new check task, for example: ```python -import atr.tasks.checks.hashing as hashing +import atr.tasks.checks.hashing as file_hash import atr.tasks.checks.license as license ``` diff --git a/atr/get/report.py b/atr/get/report.py index 6559b934..c987755a 100644 --- a/atr/get/report.py +++ b/atr/get/report.py @@ -40,10 +40,17 @@ async def selected_path(session: web.Committer, project_name: str, version_name: # If the draft is not found, we try to get the release candidate try: - release = await session.release(project_name, version_name, with_committee=True) + release = await session.release( + project_name, version_name, with_committee=True, with_release_policy=True, with_project_release_policy=True + ) except base.ASFQuartException: release = await session.release( - project_name, version_name, phase=sql.ReleasePhase.RELEASE_CANDIDATE, with_committee=True + project_name, + version_name, + phase=sql.ReleasePhase.RELEASE_CANDIDATE, + with_committee=True, + with_release_policy=True, + with_project_release_policy=True, ) if release.committee is None: diff --git a/atr/merge.py b/atr/merge.py index 75e18620..9e341b43 100644 --- a/atr/merge.py +++ b/atr/merge.py @@ -24,6 +24,7 @@ from typing import TYPE_CHECKING import aiofiles.os import atr.attestable as attestable +import atr.hashes as hashes import atr.util as util if TYPE_CHECKING: @@ -131,7 +132,7 @@ async def _add_from_prior( if (prior_hashes is not None) and (path in prior_hashes): n_hashes[path] = prior_hashes[path] else: - n_hashes[path] = await attestable.compute_file_hash(target) + n_hashes[path] = await hashes.compute_file_hash(target) stat_result = await aiofiles.os.stat(target) n_sizes[path] = stat_result.st_size return prior_hashes @@ -211,7 +212,7 @@ async def _merge_all_present( if (prior_hashes is not None) and (path in 
prior_hashes): p_hash = prior_hashes[path] else: - p_hash = await attestable.compute_file_hash(prior_dir / path) + p_hash = await hashes.compute_file_hash(prior_dir / path) if p_hash != b_hash: # Case 11 via hash: base and new have the same content but prior differs return await _replace_with_prior( @@ -250,7 +251,7 @@ async def _replace_with_prior( if (prior_hashes is not None) and (path in prior_hashes): n_hashes[path] = prior_hashes[path] else: - n_hashes[path] = await attestable.compute_file_hash(file_path) + n_hashes[path] = await hashes.compute_file_hash(file_path) stat_result = await aiofiles.os.stat(file_path) n_sizes[path] = stat_result.st_size return prior_hashes diff --git a/atr/models/attestable.py b/atr/models/attestable.py index 4bc574bd..2af1cc2f 100644 --- a/atr/models/attestable.py +++ b/atr/models/attestable.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. -from typing import Annotated, Literal +from typing import Annotated, Any, Literal import pydantic @@ -27,6 +27,11 @@ class HashEntry(schema.Strict): uploaders: list[Annotated[tuple[str, str], pydantic.BeforeValidator(tuple)]] +class AttestableChecksV1(schema.Strict): + version: Literal[1] = 1 + checks: list[int] = schema.factory(list) + + class AttestablePathsV1(schema.Strict): version: Literal[1] = 1 paths: dict[str, str] = schema.factory(dict) @@ -36,3 +41,4 @@ class AttestableV1(schema.Strict): version: Literal[1] = 1 paths: dict[str, str] = schema.factory(dict) hashes: dict[str, HashEntry] = schema.factory(dict) + policy: dict[str, Any] = schema.factory(dict) diff --git a/atr/models/sql.py b/atr/models/sql.py index a6e9aed5..c03a3953 100644 --- a/atr/models/sql.py +++ b/atr/models/sql.py @@ -946,7 +946,7 @@ class CheckResult(sqlmodel.SQLModel, table=True): data: Any = sqlmodel.Field( sa_column=sqlalchemy.Column(sqlalchemy.JSON), **example({"expected": "...", "found": "..."}) ) - input_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) + inputs_hash: str | None = sqlmodel.Field(default=None, index=True, **example("blake3:7f83b1657ff1fc...")) cached: bool = sqlmodel.Field(default=False, **example(False)) diff --git a/atr/storage/readers/checks.py b/atr/storage/readers/checks.py index b48daf35..1a95df2c 100644 --- a/atr/storage/readers/checks.py +++ b/atr/storage/readers/checks.py @@ -18,17 +18,51 @@ # Removing this will cause circular imports from __future__ import annotations +import importlib from typing import TYPE_CHECKING +import atr.attestable as attestable import atr.db as db +import atr.hashes as hashes import atr.models.sql as sql import atr.storage as storage import atr.storage.types as types +import atr.tasks.checks as checks import atr.util as util if TYPE_CHECKING: import pathlib - from collections.abc import Callable + from collections.abc import Callable, Sequence + + +async def _filter_check_results_by_hash( + all_check_results: Sequence[sql.CheckResult], + rel_path: pathlib.Path, + input_hash_by_module: dict[str, str | None], + release: sql.Release, +) -> Sequence[sql.CheckResult]: + filtered_check_results = [] + if release.latest_revision_number is None: + raise ValueError("Release has no revision - Invalid state") + for cr in all_check_results: + if cr.checker not in input_hash_by_module: + module_path = cr.checker.rsplit(".", 1)[0] + try: + module = importlib.import_module(module_path) + policy_keys: list[str] = module.INPUT_POLICY_KEYS + extra_arg_names: list[str] = getattr(module, 
"INPUT_EXTRA_ARGS", []) + except (ImportError, AttributeError): + policy_keys = [] + extra_arg_names = [] + extra_args = await checks.resolve_extra_args(extra_arg_names, release) + cache_key = await checks.resolve_cache_key( + cr.checker, policy_keys, release, release.latest_revision_number, extra_args, file=rel_path.name + ) + input_hash_by_module[cr.checker] = hashes.compute_dict_hash(cache_key) if cache_key else None + + if cr.inputs_hash == input_hash_by_module[cr.checker]: + filtered_check_results.append(cr) + return filtered_check_results class GeneralPublic: @@ -48,15 +82,20 @@ class GeneralPublic: if release.latest_revision_number is None: raise ValueError("Release has no revision - Invalid state") - query = self.__data.check_result( - release_name=release.name, - revision_number=release.latest_revision_number, - primary_rel_path=str(rel_path), - ).order_by( - sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), - sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), + check_ids = await attestable.load_checks(release.project_name, release.version, release.latest_revision_number) + all_check_results = ( + [ + a + for a in await self.__data.check_result(id_in=check_ids) + .order_by( + sql.validate_instrumented_attribute(sql.CheckResult.checker).asc(), + sql.validate_instrumented_attribute(sql.CheckResult.created).desc(), + ) + .all() + ] + if check_ids + else [] ) - all_check_results = await query.all() # Filter out any results that are ignored unignored_checks = [] diff --git a/atr/storage/readers/releases.py b/atr/storage/readers/releases.py index e57a9786..90445082 100644 --- a/atr/storage/readers/releases.py +++ b/atr/storage/readers/releases.py @@ -21,6 +21,7 @@ from __future__ import annotations import dataclasses import pathlib +import atr.attestable as attestable import atr.classify as classify import atr.db as db import atr.models.sql as sql @@ -122,36 +123,29 @@ class GeneralPublic: self, release: sql.Release, latest_revision_number: str, info: types.PathInfo ) -> None: match_ignore = await self.__read_as.checks.ignores_matcher(release.project_name) + check_ids = await attestable.load_checks(release.project_name, release.version, latest_revision_number) + attestable_checks = [a for a in await self.__data.check_result(id_in=check_ids).all()] if check_ids else [] cs = types.ChecksSubset( - release=release, - latest_revision_number=latest_revision_number, + checks=attestable_checks, info=info, match_ignore=match_ignore, ) + # TODO: These get just the ones for the revision. 
+ # It might be better to get all like we do in by_release_path, filter by hash, then filter by status await self.__successes(cs) await self.__warnings(cs) await self.__errors(cs) await self.__blocker(cs) async def __blocker(self, cs: types.ChecksSubset) -> None: - blocker = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.BLOCKER, - ).all() + blocker = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.BLOCKER] for result in blocker: if primary_rel_path := result.primary_rel_path: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(result) async def __errors(self, cs: types.ChecksSubset) -> None: - errors = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.FAILURE, - ).all() + errors = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.FAILURE] for error in errors: if cs.match_ignore(error): cs.info.ignored_errors.append(error) @@ -160,24 +154,14 @@ class GeneralPublic: cs.info.errors.setdefault(pathlib.Path(primary_rel_path), []).append(error) async def __successes(self, cs: types.ChecksSubset) -> None: - successes = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.SUCCESS, - ).all() + successes = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.SUCCESS] for success in successes: # Successes cannot be ignored if primary_rel_path := success.primary_rel_path: cs.info.successes.setdefault(pathlib.Path(primary_rel_path), []).append(success) async def __warnings(self, cs: types.ChecksSubset) -> None: - warnings = await self.__data.check_result( - release_name=cs.release.name, - revision_number=cs.latest_revision_number, - member_rel_path=None, - status=sql.CheckResultStatus.WARNING, - ).all() + warnings = [cr for cr in cs.checks if cr.status == sql.CheckResultStatus.WARNING] for warning in warnings: if cs.match_ignore(warning): cs.info.ignored_warnings.append(warning) diff --git a/atr/storage/types.py b/atr/storage/types.py index 3cd74f6b..3ed5d423 100644 --- a/atr/storage/types.py +++ b/atr/storage/types.py @@ -75,8 +75,7 @@ class PathInfo(schema.Strict): @dataclasses.dataclass class ChecksSubset: - release: sql.Release - latest_revision_number: str + checks: list[sql.CheckResult] info: PathInfo match_ignore: Callable[[sql.CheckResult], bool] diff --git a/atr/storage/writers/revision.py b/atr/storage/writers/revision.py index 5371a0b5..b2cd2a38 100644 --- a/atr/storage/writers/revision.py +++ b/atr/storage/writers/revision.py @@ -118,7 +118,7 @@ class CommitteeParticipant(FoundationCommitter): # Get the release release_name = sql.release_name(project_name, version_name) async with db.session() as data: - release = await data.release(name=release_name).demand( + release = await data.release(name=release_name, _release_policy=True, _project_release_policy=True).demand( RuntimeError("Release does not exist for new revision creation") ) old_revision = await interaction.latest_revision(release) @@ -243,10 +243,13 @@ class CommitteeParticipant(FoundationCommitter): await aioshutil.rmtree(temp_dir) raise - await attestable.write( + policy = release.release_policy or release.project.release_policy + + await attestable.write_files_data( project_name, version_name, new_revision.number, + policy.model_dump() if 
policy else None, asf_uid, previous_attestable, path_to_hash, diff --git a/atr/tasks/__init__.py b/atr/tasks/__init__.py index 8030727d..00beb6d7 100644 --- a/atr/tasks/__init__.py +++ b/atr/tasks/__init__.py @@ -15,15 +15,21 @@ # specific language governing permissions and limitations # under the License. +import asyncio import datetime +import logging +import pathlib from collections.abc import Awaitable, Callable, Coroutine from typing import Any, Final import sqlmodel +import atr.attestable as attestable import atr.db as db +import atr.hashes as hashes import atr.models.results as results import atr.models.sql as sql +import atr.tasks.checks as checks import atr.tasks.checks.compare as compare import atr.tasks.checks.hashing as hashing import atr.tasks.checks.license as license @@ -43,19 +49,29 @@ import atr.tasks.vote as vote import atr.util as util -async def asc_checks(asf_uid: str, release: sql.Release, revision: str, signature_path: str) -> list[sql.Task]: +async def asc_checks( + asf_uid: str, release: sql.Release, revision: str, signature_path: str, data: db.Session +) -> list[sql.Task | None]: """Create signature check task for a .asc file.""" tasks = [] if release.committee: tasks.append( - queued( + await queued( asf_uid, sql.TaskType.SIGNATURE_CHECK, release, revision, + data, signature_path, - {"committee_name": release.committee.name}, + check_cache_key=await checks.resolve_cache_key( + resolve(sql.TaskType.SIGNATURE_CHECK), + signature.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(signature.INPUT_EXTRA_ARGS, release), + ), + extra_args={"committee_name": release.committee.name}, ) ) @@ -120,9 +136,12 @@ async def draft_checks( relative_paths = [path async for path in util.paths_recursive(revision_path)] async with db.ensure_session(caller_data) as data: - release = await data.release(name=sql.release_name(project_name, release_version), _committee=True).demand( - RuntimeError("Release not found") - ) + release = await data.release( + name=sql.release_name(project_name, release_version), + _committee=True, + _release_policy=True, + _project_release_policy=True, + ).demand(RuntimeError("Release not found")) other_releases = ( await data.release(project_name=project_name, phase=sql.ReleasePhase.RELEASE) .order_by(sql.Release.released) @@ -136,51 +155,91 @@ async def draft_checks( (v for v in release_versions if util.version_sort_key(v.version) < release_version_sortable), None ) for path in relative_paths: - path_str = str(path) - task_function: Callable[[str, sql.Release, str, str], Awaitable[list[sql.Task]]] | None = None - for suffix, func in TASK_FUNCTIONS.items(): - if path.name.endswith(suffix): - task_function = func - break - if task_function: - for task in await task_function(asf_uid, release, revision_number, path_str): - task.revision_number = revision_number - data.add(task) - # TODO: Should we check .json files for their content? 
-            # Ideally we would not have to do that
-            if path.name.endswith(".cdx.json"):
-                data.add(
-                    queued(
-                        asf_uid,
-                        sql.TaskType.SBOM_TOOL_SCORE,
-                        release,
-                        revision_number,
-                        path_str,
-                        extra_args={
-                            "project_name": project_name,
-                            "version_name": release_version,
-                            "revision_number": revision_number,
-                            "previous_release_version": previous_version.version if previous_version else None,
-                            "file_path": path_str,
-                            "asf_uid": asf_uid,
-                        },
-                    )
-                )
+            await _draft_file_checks(
+                asf_uid,
+                caller_data,
+                data,
+                path,
+                previous_version,
+                project_name,
+                release,
+                release_version,
+                revision_number,
+            )
 
         is_podling = False
         if release.project.committee is not None:
             if release.project.committee.is_podling:
                 is_podling = True
-        path_check_task = queued(
-            asf_uid, sql.TaskType.PATHS_CHECK, release, revision_number, extra_args={"is_podling": is_podling}
+        path_check_task = await queued(
+            asf_uid,
+            sql.TaskType.PATHS_CHECK,
+            release,
+            revision_number,
+            caller_data,
+            check_cache_key=await checks.resolve_cache_key(
+                resolve(sql.TaskType.PATHS_CHECK),
+                paths.INPUT_POLICY_KEYS,
+                release,
+                revision_number,
+                await checks.resolve_extra_args(paths.INPUT_EXTRA_ARGS, release),
+                ignore_path=True,
+            ),
+            extra_args={"is_podling": is_podling},
         )
-        data.add(path_check_task)
-        if caller_data is None:
-            await data.commit()
+        if path_check_task:
+            data.add(path_check_task)
+        if caller_data is None:
+            await data.commit()
 
     return len(relative_paths)
 
 
+async def _draft_file_checks(
+    asf_uid: str,
+    caller_data: db.Session | None,
+    data: db.Session,
+    path: pathlib.Path,
+    previous_version: sql.Release | None,
+    project_name: str,
+    release: sql.Release,
+    release_version: str,
+    revision_number: str,
+):
+    path_str = str(path)
+    task_function: Callable[[str, sql.Release, str, str, db.Session], Awaitable[list[sql.Task | None]]] | None = None
+    for suffix, func in TASK_FUNCTIONS.items():
+        if path.name.endswith(suffix):
+            task_function = func
+            break
+    if task_function:
+        for task in await task_function(asf_uid, release, revision_number, path_str, data):
+            if task:
+                task.revision_number = revision_number
+                data.add(task)
+    # TODO: Should we check .json files for their content?
+ # Ideally we would not have to do that + if path.name.endswith(".cdx.json"): + data.add( + await queued( + asf_uid, + sql.TaskType.SBOM_TOOL_SCORE, + release, + revision_number, + caller_data, + path_str, + extra_args={ + "project_name": project_name, + "version_name": release_version, + "revision_number": revision_number, + "previous_release_version": previous_version.version if previous_version else None, + "file_path": path_str, + "asf_uid": asf_uid, + }, + ) + ) + + async def keys_import_file( asf_uid: str, project_name: str, version_name: str, revision_number: str, caller_data: db.Session | None = None ) -> None: @@ -230,14 +289,27 @@ async def metadata_update( return task -def queued( +async def queued( asf_uid: str, task_type: sql.TaskType, release: sql.Release, revision_number: str, + data: db.Session | None = None, primary_rel_path: str | None = None, extra_args: dict[str, Any] | None = None, -) -> sql.Task: + check_cache_key: dict[str, Any] | None = None, +) -> sql.Task | None: + if check_cache_key is not None: + logging.info("cache key", check_cache_key) + hash_val = hashes.compute_dict_hash(check_cache_key) + if not data: + raise RuntimeError("DB Session is required for check_cache_key") + existing = await data.check_result(inputs_hash=hash_val, release_name=release.name).all() + if existing: + await attestable.write_checks_data( + release.project.name, release.version, revision_number, [c.id for c in existing] + ) + return None return sql.Task( status=sql.TaskStatus.QUEUED, task_type=task_type, @@ -304,29 +376,107 @@ def resolve(task_type: sql.TaskType) -> Callable[..., Awaitable[results.Results # Otherwise we lose exhaustiveness checking -async def sha_checks(asf_uid: str, release: sql.Release, revision: str, hash_file: str) -> list[sql.Task]: +async def sha_checks( + asf_uid: str, release: sql.Release, revision: str, hash_file: str, data: db.Session +) -> list[sql.Task | None]: """Create hash check task for a .sha256 or .sha512 file.""" tasks = [] - tasks.append(queued(asf_uid, sql.TaskType.HASHING_CHECK, release, revision, hash_file)) + tasks.append( + queued( + asf_uid, + sql.TaskType.HASHING_CHECK, + release, + revision, + data, + hash_file, + check_cache_key=await checks.resolve_cache_key( + resolve(sql.TaskType.HASHING_CHECK), + hashing.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(hashing.INPUT_EXTRA_ARGS, release), + file=hash_file, + ), + ) + ) - return tasks + return await asyncio.gather(*tasks) -async def tar_gz_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def tar_gz_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .tar.gz or .tgz file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling + compare_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.COMPARE_SOURCE_TREES), + compare.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_h_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_HEADERS), + license.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_f_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_FILES), + license.INPUT_POLICY_KEYS, + release, + revision, + await 
checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + rat_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.RAT_CHECK), + rat.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + file=path, + ) + targz_i_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.TARGZ_INTEGRITY), + targz.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + file=path, + ) + targz_s_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.TARGZ_STRUCTURE), + targz.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(targz.INPUT_EXTRA_ARGS, release), + file=path, + ) tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck), + queued( + asf_uid, + sql.TaskType.LICENSE_FILES, + release, + revision, + data, + path, + check_cache_key=license_f_ck, + extra_args={"is_podling": is_podling}, + ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck), + queued(asf_uid, sql.TaskType.TARGZ_INTEGRITY, release, revision, data, path, check_cache_key=targz_i_ck), + queued(asf_uid, sql.TaskType.TARGZ_STRUCTURE, release, revision, data, path, check_cache_key=targz_s_ck), ] - return tasks + return await asyncio.gather(*tasks) async def workflow_update( @@ -356,22 +506,82 @@ async def workflow_update( return task -async def zip_checks(asf_uid: str, release: sql.Release, revision: str, path: str) -> list[sql.Task]: +async def zip_checks( + asf_uid: str, release: sql.Release, revision: str, path: str, data: db.Session +) -> list[sql.Task | None]: """Create check tasks for a .zip file.""" # This release has committee, as guaranteed in draft_checks is_podling = (release.project.committee is not None) and release.project.committee.is_podling + compare_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.COMPARE_SOURCE_TREES), + compare.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(compare.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_h_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_HEADERS), + license.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + license_f_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.LICENSE_FILES), + license.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(license.INPUT_EXTRA_ARGS, release), + file=path, + ) + rat_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.RAT_CHECK), + rat.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(rat.INPUT_EXTRA_ARGS, release), + file=path, + ) + zip_i_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.ZIPFORMAT_INTEGRITY), + zipformat.INPUT_POLICY_KEYS, + release, + revision, + await 
checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + file=path, + ) + zip_s_ck = await checks.resolve_cache_key( + resolve(sql.TaskType.ZIPFORMAT_STRUCTURE), + zipformat.INPUT_POLICY_KEYS, + release, + revision, + await checks.resolve_extra_args(zipformat.INPUT_EXTRA_ARGS, release), + file=path, + ) + tasks = [ - queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, path), - queued(asf_uid, sql.TaskType.LICENSE_FILES, release, revision, path, extra_args={"is_podling": is_podling}), - queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, path), - queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, path), - queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, path), + queued(asf_uid, sql.TaskType.COMPARE_SOURCE_TREES, release, revision, data, path, check_cache_key=compare_ck), + queued( + asf_uid, + sql.TaskType.LICENSE_FILES, + release, + revision, + data, + path, + check_cache_key=license_f_ck, + extra_args={"is_podling": is_podling}, + ), + queued(asf_uid, sql.TaskType.LICENSE_HEADERS, release, revision, data, path, check_cache_key=license_h_ck), + queued(asf_uid, sql.TaskType.RAT_CHECK, release, revision, data, path, check_cache_key=rat_ck), + queued(asf_uid, sql.TaskType.ZIPFORMAT_INTEGRITY, release, revision, data, path, check_cache_key=zip_i_ck), + queued(asf_uid, sql.TaskType.ZIPFORMAT_STRUCTURE, release, revision, data, path, check_cache_key=zip_s_ck), ] - return tasks + return await asyncio.gather(*tasks) -TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task]]]]] = { +TASK_FUNCTIONS: Final[dict[str, Callable[..., Coroutine[Any, Any, list[sql.Task | None]]]]] = { ".asc": asc_checks, ".sha256": sha_checks, ".sha512": sha_checks, diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py index 1b78f68e..8b0d2830 100644 --- a/atr/tasks/checks/__init__.py +++ b/atr/tasks/checks/__init__.py @@ -20,26 +20,27 @@ from __future__ import annotations import dataclasses import datetime import functools -import pathlib from typing import TYPE_CHECKING, Any, Final import aiofiles import aiofiles.os -import blake3 import sqlmodel if TYPE_CHECKING: + import pathlib from collections.abc import Awaitable, Callable import atr.models.schema as schema +import atr.attestable as attestable import atr.config as config import atr.db as db +import atr.file_paths as file_paths +import atr.hashes as hashes +import atr.log as log import atr.models.sql as sql import atr.util as util -_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024 - # Pydantic does not like Callable types, so we use a dataclass instead # It says: "you should define `Callable`, then call `FunctionArguments.model_rebuild()`" @@ -61,7 +62,7 @@ class Recorder: version_name: str primary_rel_path: str | None member_rel_path: str | None - revision: str + revision_number: str afresh: bool __cached: bool __input_hash: str | None @@ -77,7 +78,7 @@ class Recorder: member_rel_path: str | None = None, afresh: bool = True, ) -> None: - self.checker = function_key(checker) if callable(checker) else checker + self.checker = function_key(checker) self.release_name = sql.release_name(project_name, version_name) self.revision_number = revision_number self.primary_rel_path = primary_rel_path @@ -142,7 +143,7 @@ class Recorder: message=message, data=data, cached=False, - input_hash=self.__input_hash, + inputs_hash=self.input_hash, ) # It would be more efficient to keep a session open @@ 
-167,7 +168,7 @@ class Recorder: return self.abs_path_base() / rel_path_part def abs_path_base(self) -> pathlib.Path: - return pathlib.Path(util.get_unfinished_dir(), self.project_name, self.version_name, self.revision_number) + return file_paths.base_path_for_revision(self.project_name, self.version_name, self.revision_number) async def project(self) -> sql.Project: # TODO: Cache project @@ -196,13 +197,10 @@ class Recorder: abs_path = await self.abs_path() return matches(str(abs_path)) - @property - def cached(self) -> bool: - return self.__cached - - async def check_cache(self, path: pathlib.Path) -> bool: - if not await aiofiles.os.path.isfile(path): - return False + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None + ) -> bool: + # TODO: Should this just be in the constructor? if config.get().DISABLE_CHECK_CACHE: return False @@ -214,48 +212,21 @@ class Recorder: if await aiofiles.os.path.exists(no_cache_file): return False - self.__input_hash = await _compute_file_hash(path) - async with db.session() as data: - via = sql.validate_instrumented_attribute - subquery = ( - sqlmodel.select( - sql.CheckResult.member_rel_path, - sqlmodel.func.max(via(sql.CheckResult.id)).label("max_id"), - ) - .where(sql.CheckResult.checker == self.checker) - .where(sql.CheckResult.input_hash == self.__input_hash) - .where(sql.CheckResult.primary_rel_path == self.primary_rel_path) - .group_by(sql.CheckResult.member_rel_path) - .subquery() + release = await data.release( + name=self.release_name, _release_policy=True, _project_release_policy=True, _project=True + ).demand(RuntimeError(f"Release {self.release_name} not found")) + args = await resolve_extra_args(input_args or [], release) + cache_key = await resolve_cache_key( + checker or self.checker, policy_keys, release, self.revision_number, args, file=self.primary_rel_path ) - stmt = sqlmodel.select(sql.CheckResult).join(subquery, via(sql.CheckResult.id) == subquery.c.max_id) - results = await data.execute(stmt) - cached_results = results.scalars().all() - - if not cached_results: - return False - - for cached in cached_results: - new_result = sql.CheckResult( - release_name=self.release_name, - revision_number=self.revision_number, - checker=self.checker, - primary_rel_path=self.primary_rel_path, - member_rel_path=cached.member_rel_path, - created=datetime.datetime.now(datetime.UTC), - status=cached.status, - message=cached.message, - data=cached.data, - cached=True, - input_hash=self.__input_hash, - ) - data.add(new_result) - await data.commit() - - self.__cached = True + self.__input_hash = hashes.compute_dict_hash(cache_key) if cache_key else None return True + @property + def cached(self) -> bool: + return self.__cached + async def clear(self, primary_rel_path: str | None = None, member_rel_path: str | None = None) -> None: async with db.session() as data: stmt = sqlmodel.delete(sql.CheckResult).where( @@ -273,48 +244,72 @@ class Recorder: return self.__input_hash async def blocker( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.BLOCKER, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, ) + await attestable.write_checks_data(self.project_name, self.version_name, 
self.revision_number, [result.id]) + return result async def exception( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.EXCEPTION, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def failure( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.FAILURE, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def success( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.SUCCESS, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result async def use_check_cache(self) -> bool: if self.__use_check_cache is not None: @@ -329,19 +324,73 @@ class Recorder: return self.__use_check_cache async def warning( - self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + self, + message: str, + data: Any, + primary_rel_path: str | None = None, + member_rel_path: str | None = None, ) -> sql.CheckResult: - return await self._add( + result = await self._add( sql.CheckResultStatus.WARNING, message, data, primary_rel_path=primary_rel_path, member_rel_path=member_rel_path, ) + await attestable.write_checks_data(self.project_name, self.version_name, self.revision_number, [result.id]) + return result -def function_key(func: Callable[..., Any]) -> str: - return func.__module__ + "." + func.__name__ +def function_key(func: Callable[..., Any] | str) -> str: + return func.__module__ + "." + func.__name__ if callable(func) else func + + +async def resolve_cache_key( + checker: str | Callable[..., Any], + policy_keys: list[str], + release: sql.Release, + revision: str, + args: dict[str, Any] | None = None, + file: str | None = None, + path: pathlib.Path | None = None, + ignore_path: bool = False, +) -> dict[str, Any] | None: + if not args: + args = {} + cache_key = {"checker": function_key(checker)} + file_hash = None + attestable_data = await attestable.load(release.project_name, release.version, revision) + if attestable_data: + policy = sql.ReleasePolicy.model_validate(attestable_data.policy) + if not ignore_path: + file_hash = attestable_data.paths.get(file) if file else None + else: + # TODO: Is this fallback valid / necessary? Or should we bail out if there's no attestable data? 
+ policy = release.release_policy or release.project.release_policy + if not ignore_path: + if path is None: + path = file_paths.revision_path_for_file(release.project_name, release.version, revision, file or "") + file_hash = await hashes.compute_file_hash(path) + if file_hash: + cache_key["file_hash"] = file_hash + + if len(policy_keys) > 0 and policy is not None: + policy_dict = policy.model_dump(exclude_none=True) + return {**cache_key, **args, **{k: policy_dict[k] for k in policy_keys if k in policy_dict}} + else: + return {**cache_key, **args} + + +async def resolve_extra_args(arg_names: list[str], release: sql.Release) -> dict[str, Any]: + result: dict[str, Any] = {} + for name in arg_names: + resolver = _EXTRA_ARG_RESOLVERS.get(name, None) + # If we can't find a resolver, we'll carry on anyway since it'll just mean no cache potentially + if resolver is None: + log.warning(f"Unknown extra arg resolver: {name}") + return {} + result[name] = await resolver(release) + return result def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Callable[..., Any]]: @@ -358,9 +407,36 @@ def with_model(cls: type[schema.Strict]) -> Callable[[Callable[..., Any]], Calla return decorator -async def _compute_file_hash(path: pathlib.Path) -> str: - hasher = blake3.blake3() - async with aiofiles.open(path, "rb") as f: - while chunk := await f.read(_HASH_CHUNK_SIZE): - hasher.update(chunk) - return f"blake3:{hasher.hexdigest()}" +async def _resolve_all_files(release: sql.Release) -> list[str]: + if not release.latest_revision_number: + return [] + if not ( + base_path := file_paths.base_path_for_revision( + release.project_name, release.version, release.latest_revision_number + ) + ): + return [] + + if not await aiofiles.os.path.isdir(base_path): + log.error(f"Base release directory does not exist or is not a directory: {base_path}") + return [] + relative_paths = [p async for p in util.paths_recursive(base_path)] + relative_paths_set = set(str(p) for p in relative_paths) + return list(relative_paths_set) + + +async def _resolve_is_podling(release: sql.Release) -> bool: + return (release.committee is not None) and release.committee.is_podling + + +async def _resolve_committee_name(release: sql.Release) -> str: + if release.committee is None: + raise ValueError("Release has no committee") + return release.committee.name + + +_EXTRA_ARG_RESOLVERS: Final[dict[str, Callable[[sql.Release], Any]]] = { + "all_files": _resolve_all_files, + "is_podling": _resolve_is_podling, + "committee_name": _resolve_committee_name, +} diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index 0df5ee48..162c9673 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -51,6 +51,9 @@ _DEFAULT_USER: Final[str] = "atr" _PERMITTED_ADDED_PATHS: Final[dict[str, list[str]]] = { "PKG-INFO": ["pyproject.toml"], } +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] @dataclasses.dataclass @@ -90,6 +93,8 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None ) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + payload = await _load_tp_payload(args.project_name, args.version_name, args.revision_number) checkout_dir: str | None = None archive_dir: str | None = None diff --git a/atr/tasks/checks/hashing.py b/atr/tasks/checks/hashing.py index e8ae78fe..cb3c8f77 100644 --- a/atr/tasks/checks/hashing.py +++ 
b/atr/tasks/checks/hashing.py @@ -17,6 +17,7 @@ import hashlib import secrets +from typing import Final import aiofiles @@ -24,6 +25,10 @@ import atr.log as log import atr.models.results as results import atr.tasks.checks as checks +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + async def check(args: checks.FunctionArguments) -> results.Results | None: """Check the hash of a file.""" @@ -36,6 +41,8 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.failure("Unsupported hash algorithm", {"algorithm": algorithm}) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + # Remove the hash file suffix to get the artifact path # This replaces the last suffix, which is what we want # >>> pathlib.Path("a/b/c.d.e.f.g").with_suffix(".x") diff --git a/atr/tasks/checks/license.py b/atr/tasks/checks/license.py index 066fef3a..49b20a7a 100644 --- a/atr/tasks/checks/license.py +++ b/atr/tasks/checks/license.py @@ -79,6 +79,10 @@ INCLUDED_PATTERNS: Final[list[str]] = [ r"\.(pl|pm|t)$", # Perl ] +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [""] +INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling"] + # Types @@ -134,10 +138,12 @@ async def files(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None + is_podling = args.extra_args.get("is_podling", False) + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + log.info(f"Checking license files for {artifact_abs_path} (rel: {args.primary_rel_path})") try: - is_podling = args.extra_args.get("is_podling", False) for result in await asyncio.to_thread(_files_check_core_logic, str(artifact_abs_path), is_podling): match result: case ArtifactResult(): @@ -166,9 +172,11 @@ async def headers(args: checks.FunctionArguments) -> results.Results | None: if project.policy_license_check_mode == sql.LicenseCheckMode.RAT: return None - if await recorder.check_cache(artifact_abs_path): - log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") - return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + + # if await recorder.check_cache(artifact_abs_path): + # log.info(f"Using cached license headers result for {artifact_abs_path} (rel: {args.primary_rel_path})") + # return None log.info(f"Checking license headers for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/paths.py b/atr/tasks/checks/paths.py index 299e4b25..d37fc66b 100644 --- a/atr/tasks/checks/paths.py +++ b/atr/tasks/checks/paths.py @@ -37,6 +37,9 @@ _ALLOWED_TOP_LEVEL: Final = frozenset( "README", } ) +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = ["is_podling", "all_files"] async def check(args: checks.FunctionArguments) -> results.Results | None: @@ -85,6 +88,11 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: is_podling = args.extra_args.get("is_podling", False) relative_paths = [p async for p in util.paths_recursive(base_path)] relative_paths_set = set(str(p) for p in relative_paths) + + await recorder_errors.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + await 
recorder_warnings.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + await recorder_success.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS, checker=checks.function_key(check)) + for relative_path in relative_paths: # Delegate processing of each path to the helper function await _check_path_process_single( @@ -289,12 +297,14 @@ async def _record( warnings: list[str], ) -> None: for error in errors: - await recorder_errors.failure(error, {}, primary_rel_path=relative_path_str) + await recorder_errors.failure(f"{relative_path_str}: {error}", {}, primary_rel_path=relative_path_str) for item in blockers: - await recorder_errors.blocker(item, {}, primary_rel_path=relative_path_str) + await recorder_errors.blocker(f"{relative_path_str}: {item}", {}, primary_rel_path=relative_path_str) for warning in warnings: - await recorder_warnings.warning(warning, {}, primary_rel_path=relative_path_str) + await recorder_warnings.warning(f"{relative_path_str}: {warning}", {}, primary_rel_path=relative_path_str) if not (errors or blockers or warnings): await recorder_success.success( - "Path structure and naming conventions conform to policy", {}, primary_rel_path=relative_path_str + f"{relative_path_str}: Path structure and naming conventions conform to policy", + {}, + primary_rel_path=relative_path_str, ) diff --git a/atr/tasks/checks/rat.py b/atr/tasks/checks/rat.py index 9e5757ce..9c065ec7 100644 --- a/atr/tasks/checks/rat.py +++ b/atr/tasks/checks/rat.py @@ -65,6 +65,9 @@ _STD_EXCLUSIONS_EXTENDED: Final[list[str]] = [ "GIT", "STANDARD_SCMS", ] +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] class RatError(RuntimeError): @@ -85,9 +88,7 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: log.info(f"Skipping RAT check for {artifact_abs_path} (mode is LIGHTWEIGHT)") return None - if await recorder.check_cache(artifact_abs_path): - log.info(f"Using cached RAT result for {artifact_abs_path} (rel: {args.primary_rel_path})") - return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) log.info(f"Checking RAT licenses for {artifact_abs_path} (rel: {args.primary_rel_path})") diff --git a/atr/tasks/checks/signature.py b/atr/tasks/checks/signature.py index 81ac1acf..830d9cfe 100644 --- a/atr/tasks/checks/signature.py +++ b/atr/tasks/checks/signature.py @@ -30,6 +30,10 @@ import atr.models.sql as sql import atr.tasks.checks as checks import atr.util as util +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = ["committee_name"] + async def check(args: checks.FunctionArguments) -> results.Results | None: """Check a signature file.""" @@ -50,6 +54,8 @@ async def check(args: checks.FunctionArguments) -> results.Results | None: await recorder.exception("Committee name is required", {"committee_name": committee_name}) return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + log.info( f"Checking signature {primary_abs_path} for {artifact_abs_path}" f" using {committee_name} keys (rel: {primary_rel_path})" diff --git a/atr/tasks/checks/targz.py b/atr/tasks/checks/targz.py index ffe22b3c..7b369e5e 100644 --- a/atr/tasks/checks/targz.py +++ b/atr/tasks/checks/targz.py @@ -25,6 +25,10 @@ import atr.tarzip as tarzip import atr.tasks.checks as checks import atr.util as util +# Release policy fields which 
this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + class RootDirectoryError(Exception): """Exception raised when a root directory is not found in an archive.""" @@ -38,6 +42,8 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + log.info(f"Checking integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") chunk_size = 4096 @@ -95,6 +101,8 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + filename = artifact_abs_path.name basename_from_filename: Final[str] = ( filename.removesuffix(".tar.gz") if filename.endswith(".tar.gz") else filename.removesuffix(".tgz") diff --git a/atr/tasks/checks/zipformat.py b/atr/tasks/checks/zipformat.py index d1888734..ed85f7b2 100644 --- a/atr/tasks/checks/zipformat.py +++ b/atr/tasks/checks/zipformat.py @@ -18,7 +18,7 @@ import asyncio import os import zipfile -from typing import Any +from typing import Any, Final import atr.log as log import atr.models.results as results @@ -26,6 +26,10 @@ import atr.tarzip as tarzip import atr.tasks.checks as checks import atr.util as util +# Release policy fields which this check relies on - used for result caching +INPUT_POLICY_KEYS: Final[list[str]] = [] +INPUT_EXTRA_ARGS: Final[list[str]] = [] + async def integrity(args: checks.FunctionArguments) -> results.Results | None: """Check that the zip archive is not corrupted and can be opened.""" @@ -33,6 +37,8 @@ async def integrity(args: checks.FunctionArguments) -> results.Results | None: if not (artifact_abs_path := await recorder.abs_path()): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + log.info(f"Checking zip integrity for {artifact_abs_path} (rel: {args.primary_rel_path})") try: @@ -57,6 +63,8 @@ async def structure(args: checks.FunctionArguments) -> results.Results | None: if await recorder.primary_path_is_binary(): return None + await recorder.cache_key_set(INPUT_POLICY_KEYS, INPUT_EXTRA_ARGS) + log.info(f"Checking zip structure for {artifact_abs_path} (rel: {args.primary_rel_path})") try: diff --git a/atr/util.py b/atr/util.py index 7ddb9788..9cb3c820 100644 --- a/atr/util.py +++ b/atr/util.py @@ -21,6 +21,7 @@ import binascii import contextlib import dataclasses import datetime +import fcntl import hashlib import json import os @@ -207,6 +208,27 @@ async def atomic_write_file(file_path: pathlib.Path, content: str, encoding: str raise +async def atomic_modify_file( + file_path: pathlib.Path, + modify: Callable[[str], str], +) -> None: + # This function assumes that file_path already exists and its a regular file + lock_path = file_path.with_suffix(file_path.suffix + ".lock") + lock_fd = await asyncio.to_thread(os.open, str(lock_path), os.O_CREAT | os.O_RDWR) + try: + await asyncio.to_thread(fcntl.flock, lock_fd, fcntl.LOCK_EX) + try: + async with aiofiles.open(file_path, encoding="utf-8") as rf: + old_value = await rf.read() + new_value = modify(old_value) + if new_value != old_value: + await atomic_write_file(file_path, new_value) + finally: + fcntl.flock(lock_fd, fcntl.LOCK_UN) + finally: + await asyncio.to_thread(os.close, lock_fd) + + def chmod_directories(path: pathlib.Path, permissions: int = DIRECTORY_PERMISSIONS) 
-> None: # codeql[py/overly-permissive-file] os.chmod(path, permissions) diff --git a/tests/unit/recorders.py b/tests/unit/recorders.py index 33e5af03..bdb20dcf 100644 --- a/tests/unit/recorders.py +++ b/tests/unit/recorders.py @@ -18,6 +18,7 @@ import datetime import pathlib from collections.abc import Awaitable, Callable +from typing import Any import atr.models.sql as sql import atr.tasks.checks as checks @@ -40,6 +41,11 @@ class RecorderStub(checks.Recorder): async def abs_path(self, rel_path: str | None = None) -> pathlib.Path | None: return self._path if (rel_path is None) else self._path / rel_path + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = None + ) -> bool: + return False + async def primary_path_is_binary(self) -> bool: return False @@ -63,9 +69,29 @@ class RecorderStub(checks.Recorder): status=status, message=message, data=data, - input_hash=None, + inputs_hash=None, ) + async def exception( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.EXCEPTION, message, data, primary_rel_path, member_rel_path) + + async def failure( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.FAILURE, message, data, primary_rel_path, member_rel_path) + + async def success( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.SUCCESS, message, data, primary_rel_path, member_rel_path) + + async def warning( + self, message: str, data: Any, primary_rel_path: str | None = None, member_rel_path: str | None = None + ) -> sql.CheckResult: + return await self._add(sql.CheckResultStatus.WARNING, message, data, primary_rel_path, member_rel_path) + def get_recorder(recorder: checks.Recorder) -> Callable[[], Awaitable[checks.Recorder]]: async def _recorder() -> checks.Recorder: diff --git a/tests/unit/test_cache.py b/tests/unit/test_cache.py index a1657783..12bf275b 100644 --- a/tests/unit/test_cache.py +++ b/tests/unit/test_cache.py @@ -115,17 +115,17 @@ async def test_admins_get_async_uses_extensions_when_available(mock_app: MockApp assert result == frozenset({"async_alice"}) -def test_admins_get_returns_empty_frozenset_when_not_set(mock_app: MockApp): - result = cache.admins_get() - assert result == frozenset() - - def test_admins_get_returns_frozenset_from_extensions(mock_app: MockApp): mock_app.extensions["admins"] = frozenset({"alice", "bob"}) result = cache.admins_get() assert result == frozenset({"alice", "bob"}) +def test_admins_get_returns_empty_frozenset_when_not_set(mock_app: MockApp): + result = cache.admins_get() + assert result == frozenset() + + @pytest.mark.asyncio async def test_admins_read_from_file_returns_none_for_invalid_json(state_dir: pathlib.Path): cache_path = state_dir / "cache" / "admins.json" diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py index 5a90a029..e90e3314 100644 --- a/tests/unit/test_checks_compare.py +++ b/tests/unit/test_checks_compare.py @@ -242,6 +242,11 @@ class RecorderStub(atr.tasks.checks.Recorder): self.success_calls: list[tuple[str, object]] = [] self._is_source = is_source + async def cache_key_set( + self, policy_keys: list[str], input_args: list[str] | None = None, checker: str | None = 
None + ) -> bool: + return False + async def primary_path_is_source(self) -> bool: return self._is_source --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
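For reference: the new `atr/hashes.py` module that this commit calls throughout (`hashes.compute_file_hash`, `hashes.compute_dict_hash`) is not itself included in the diff. Below is a minimal sketch of what it plausibly contains. `compute_file_hash` mirrors the helper removed from `atr/attestable.py` above; `compute_dict_hash` is an assumption inferred from its call sites, which hash a cache-key dict into the stable `blake3:...` string stored in `CheckResult.inputs_hash`.

```python
# Sketch only: atr/hashes.py is not part of this diff, so these bodies are
# assumptions based on how the rest of the commit calls them.
import json
import pathlib
from typing import Any, Final

import aiofiles
import blake3

_HASH_CHUNK_SIZE: Final[int] = 4 * 1024 * 1024  # same constant the commit removes above


async def compute_file_hash(path: pathlib.Path) -> str:
    # Matches the compute_file_hash removed from atr/attestable.py: stream the
    # file in chunks and prefix the digest with its algorithm name
    hasher = blake3.blake3()
    async with aiofiles.open(path, "rb") as f:
        while chunk := await f.read(_HASH_CHUNK_SIZE):
            hasher.update(chunk)
    return f"blake3:{hasher.hexdigest()}"


def compute_dict_hash(cache_key: dict[str, Any]) -> str:
    # Assumed implementation: serialise with sorted keys so that logically
    # equal cache-key dicts always yield the same inputs_hash
    canonical = json.dumps(cache_key, sort_keys=True, separators=(",", ":"))
    return f"blake3:{blake3.blake3(canonical.encode('utf-8')).hexdigest()}"
```

Keeping the `blake3:` prefix in the stored value would let the algorithm change later without old `inputs_hash` rows being misread.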
