This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch github_tp_validation in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 1c4f70a13c8f6261dd585a93ed7de7a85d324251 Author: Alastair McFarlane <[email protected]> AuthorDate: Thu Mar 12 14:32:51 2026 +0000 #676 Validate exp and nbf when loading pydantic model for GitHub token. Attestable module updated to store and load model instead of dict. --- atr/api/__init__.py | 8 ++++---- atr/attestable.py | 35 +++++++++++++++++++++++++++++++++-- atr/db/interaction.py | 17 ++++++++++------- atr/jwtoken.py | 4 ++-- atr/models/github.py | 24 ++++++++++++++++++++++-- atr/ssh.py | 9 +++++---- atr/storage/writers/ssh.py | 13 ++++++++++--- atr/tasks/checks/compare.py | 26 +------------------------- tests/unit/test_checks_compare.py | 21 +++++++++++---------- 9 files changed, 98 insertions(+), 59 deletions(-) diff --git a/atr/api/__init__.py b/atr/api/__init__.py index 79ff4d95..62d03f0c 100644 --- a/atr/api/__init__.py +++ b/atr/api/__init__.py @@ -294,8 +294,8 @@ async def distribute_ssh_register( ) async with storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) as wacm: fingerprint, expires = await wacm.ssh.add_workflow_key( - payload["actor"], - payload["actor_id"], + payload.actor, + payload.actor_id, release.safe_project_key, data.ssh_key, payload, @@ -893,8 +893,8 @@ async def publisher_ssh_register( ) async with storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) as wacm: fingerprint, expires = await wacm.ssh.add_workflow_key( - payload["actor"], - payload["actor_id"], + payload.actor, + payload.actor_id, project.safe_key, data.ssh_key, payload, diff --git a/atr/attestable.py b/atr/attestable.py index fc908ad2..a58c34ae 100644 --- a/atr/attestable.py +++ b/atr/attestable.py @@ -29,6 +29,7 @@ import atr.classify as classify import atr.hashes as hashes import atr.log as log import atr.models.attestable as models +import atr.models.github as github import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths @@ -128,14 +129,44 @@ def github_tp_payload_path( return
paths.get_attestable_dir() / str(project_key) / str(version_key) / f"{revision_number!s}.github-tp.json" +async def github_tp_payload_read( + project_key: safe.ProjectKey, version_key: safe.VersionKey, revision_number: safe.RevisionNumber +) -> github.TrustedPublisherPayload | None: + payload_path = github_tp_payload_path(project_key, version_key, revision_number) + if not await aiofiles.os.path.isfile(payload_path): + return None + try: + async with aiofiles.open(payload_path, encoding="utf-8") as f: + data = json.loads(await f.read()) + if not isinstance(data, dict): + log.warning(f"TP payload was not a JSON object in {payload_path}") + return None + # Remove exp and nbf if they're stored - as of 2026-03-18 they're validated and then removed before storage + # but we might have older data + if "exp" in data: + del data["exp"] + if "nbf" in data: + del data["nbf"] + return github.TrustedPublisherPayload.model_validate(data) + except (OSError, json.JSONDecodeError) as e: + log.warning(f"Failed to read TP payload from {payload_path}: {e}") + return None + except pydantic.ValidationError as e: + log.warning(f"Failed to validate TP payload from {payload_path}: {e}") + return None + + async def github_tp_payload_write( project_key: safe.ProjectKey, version_key: safe.VersionKey, revision_number: safe.RevisionNumber, - github_payload: dict[str, Any], + github_payload: github.TrustedPublisherPayload, ) -> None: payload_path = github_tp_payload_path(project_key, version_key, revision_number) - await util.atomic_write_file(payload_path, json.dumps(github_payload, indent=2)) + # Dump the workflow payload, excluding exp and nbf - which shouldn't have made it this far. 
If they do, + # it's safe to remove them as they've been validated by the model already, and we should never store + # stale dates + await util.atomic_write_file(payload_path, json.dumps(github_payload.model_dump(exclude={"exp", "nbf"}), indent=2)) async def load( diff --git a/atr/db/interaction.py b/atr/db/interaction.py index 33ef59d9..f18a1a93 100644 --- a/atr/db/interaction.py +++ b/atr/db/interaction.py @@ -19,7 +19,7 @@ import contextlib import datetime import enum from collections.abc import AsyncGenerator, Sequence -from typing import Any, Final +from typing import Final import packaging.version as version import sqlalchemy @@ -31,6 +31,7 @@ import atr.db as db import atr.jwtoken as jwtoken import atr.ldap as ldap import atr.log as log +import atr.models.github as github import atr.models.results as results import atr.models.safe as safe import atr.models.sql as sql @@ -455,12 +456,14 @@ async def tasks_ongoing_revision( return task_count, latest_revision -async def trusted_jwt(publisher: str, jwt: str, phase: TrustedProjectPhase) -> tuple[dict[str, Any], str, sql.Project]: +async def trusted_jwt( + publisher: str, jwt: str, phase: TrustedProjectPhase +) -> tuple[github.TrustedPublisherPayload, str, sql.Project]: payload, asf_uid = await validate_trusted_jwt(publisher, jwt) # JWT could be for an ASF user or the trusted role, but we need a user here. 
if asf_uid is None: raise InteractionError("ASF user not found") - project = await _trusted_project(payload["repository"], payload["workflow_ref"], phase) + project = await _trusted_project(payload.repository, payload.workflow_ref, phase) return payload, asf_uid, project @@ -471,7 +474,7 @@ async def trusted_jwt_for_dist( phase: TrustedProjectPhase, project_key: safe.ProjectKey, version_key: safe.VersionKey, -) -> tuple[dict[str, Any], str, sql.Project, sql.Release]: +) -> tuple[github.TrustedPublisherPayload, str, sql.Project, sql.Release]: payload, asf_uid_from_jwt = await validate_trusted_jwt(publisher, jwt) if asf_uid_from_jwt is not None: raise InteractionError("Must use Trusted Publishing when specifying ASF UID") @@ -553,12 +556,12 @@ async def user_projects(asf_uid: str, caller_data: db.Session | None = None) -> return [(p.key, p.display_name) for p in projects] -async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[dict[str, Any], str | None]: +async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[github.TrustedPublisherPayload, str | None]: if publisher != "github": raise InteractionError(f"Publisher {publisher} not supported") payload = await jwtoken.verify_github_oidc(jwt) - if int(payload["actor_id"]) != _GITHUB_TRUSTED_ROLE_NID: - asf_uid = await ldap.github_to_apache(payload["actor_id"]) + if payload.actor_id != _GITHUB_TRUSTED_ROLE_NID: + asf_uid = await ldap.github_to_apache(payload.actor_id) else: asf_uid = None return payload, asf_uid diff --git a/atr/jwtoken.py b/atr/jwtoken.py index 7b9700df..38e00d32 100644 --- a/atr/jwtoken.py +++ b/atr/jwtoken.py @@ -154,7 +154,7 @@ async def verify(token: str) -> dict[str, Any]: return claims -async def verify_github_oidc(token: str) -> dict[str, Any]: +async def verify_github_oidc(token: str) -> github.TrustedPublisherPayload: header = jwt.get_unverified_header(token) dangerous_headers = {"jku", "x5u", "jwk"} if dangerous_headers.intersection(header.keys()): @@ -215,7 +215,7 @@ 
async def verify_github_oidc(token: str) -> dict[str, Any]: f"GitHub OIDC payload mismatch: {key} = {payload[key]} != {value}", errorcode=401, ) - return github.TrustedPublisherPayload.model_validate(payload).model_dump() + return github.TrustedPublisherPayload.model_validate(payload) def write_new_signing_key() -> str: diff --git a/atr/models/github.py b/atr/models/github.py index d20ac73f..e5fe39fc 100644 --- a/atr/models/github.py +++ b/atr/models/github.py @@ -17,19 +17,23 @@ from __future__ import annotations +import time + +import pydantic + from . import schema class TrustedPublisherPayload(schema.Subset): actor: str - actor_id: str + actor_id: int aud: str base_ref: str check_run_id: str enterprise: str enterprise_id: str event_name: str - exp: int + exp: int | None = None head_ref: str iat: int iss: str @@ -51,3 +55,19 @@ class TrustedPublisherPayload(schema.Subset): workflow: str workflow_ref: str workflow_sha: str + + @pydantic.field_validator("exp") + @classmethod + def _validate_exp(cls, value: int) -> int: + now = int(time.time()) + if now > value: + raise ValueError("Token has expired") + return value + + @pydantic.field_validator("nbf") + @classmethod + def _validate_nbf(cls, value: int | None) -> int | None: + now = int(time.time()) + if value and now < value: + raise ValueError("Token not yet valid") + return value diff --git a/atr/ssh.py b/atr/ssh.py index 2b78de79..533fc86f 100644 --- a/atr/ssh.py +++ b/atr/ssh.py @@ -26,7 +26,7 @@ import pathlib import stat import string import time -from typing import Any, Final +from typing import Final import aiofiles import aiofiles.os @@ -40,6 +40,7 @@ import atr.attestable as attestable import atr.config as config import atr.db as db import atr.log as log +import atr.models.github as github import atr.models.safe as safe import atr.models.sql as sql import atr.paths as paths @@ -82,7 +83,7 @@ class SSHServer(asyncssh.SSHServer): # Store connection for use in begin_auth self._conn = conn 
self._github_asf_uid: str | None = None - self._github_payload: dict[str, Any] | None = None + self._github_payload: github.TrustedPublisherPayload | None = None peer_addr = conn.get_extra_info("peername")[0] log.info(f"SSH connection received from {peer_addr}") @@ -163,7 +164,7 @@ class SSHServer(asyncssh.SSHServer): log.failed_authentication("public_key_expired") return False - self._github_payload = workflow_key.github_payload + self._github_payload = github.TrustedPublisherPayload.model_validate(workflow_key.github_payload) return True def _get_asf_uid(self, process: asyncssh.SSHServerProcess) -> str: @@ -174,7 +175,7 @@ class SSHServer(asyncssh.SSHServer): return self._github_asf_uid return username - def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> dict[str, Any] | None: + def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> github.TrustedPublisherPayload | None: username = process.get_extra_info("username") if username != "github": return None diff --git a/atr/storage/writers/ssh.py b/atr/storage/writers/ssh.py index 5cfb5e28..fb776e4a 100644 --- a/atr/storage/writers/ssh.py +++ b/atr/storage/writers/ssh.py @@ -19,9 +19,9 @@ from __future__ import annotations import time -from typing import Any import atr.db as db +import atr.models.github as github import atr.models.safe as safe import atr.models.sql as sql import atr.storage as storage @@ -86,13 +86,20 @@ class CommitteeParticipant(FoundationCommitter): self.__committee_key = committee_key async def add_workflow_key( - self, github_uid: str, github_nid: int, project_key: safe.ProjectKey, key: str, github_payload: dict[str, Any] + self, + github_uid: str, + github_nid: int, + project_key: safe.ProjectKey, + key: str, + github_payload: github.TrustedPublisherPayload, ) -> tuple[str, int]: now = int(time.time()) # Twenty minutes to upload all files ttl = 20 * 60 expires = now + ttl fingerprint = util.key_ssh_fingerprint(key) + # Exclude nbf and exp as we've already 
validated this key - now protected by workflowkey "expires" + json_payload = github_payload.model_dump(exclude={"exp", "nbf"}) wsk = sql.WorkflowSSHKey( fingerprint=fingerprint, key=key, @@ -100,7 +107,7 @@ class CommitteeParticipant(FoundationCommitter): asf_uid=self.__asf_uid, github_uid=github_uid, github_nid=github_nid, - github_payload=github_payload, + github_payload=json_payload, expires=expires, ) self.__data.add(wsk) diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py index 73264b64..ace6f52a 100644 --- a/atr/tasks/checks/compare.py +++ b/atr/tasks/checks/compare.py @@ -18,7 +18,6 @@ import asyncio import contextlib import dataclasses -import json import os import pathlib import shutil @@ -34,14 +33,12 @@ import dulwich.objects import dulwich.objectspec import dulwich.porcelain import dulwich.refs -import pydantic import atr.attestable as attestable import atr.config as config import atr.log as log import atr.models.github as github_models import atr.models.results as results -import atr.models.safe as safe import atr.paths as paths import atr.tasks.checks as checks import atr.util as util @@ -95,7 +92,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None ) return None - payload = await _load_tp_payload(args.project_key, args.version_key, args.revision_number) + payload = await attestable.github_tp_payload_read(args.project_key, args.version_key, args.revision_number) checkout_dir: str | None = None archive_dir: str | None = None if payload is not None: @@ -334,27 +331,6 @@ async def _find_archive_root(archive_path: pathlib.Path, extract_dir: pathlib.Pa return ArchiveRootResult(root=found_root, extra_entries=extra_entries) -async def _load_tp_payload( - project_key: safe.ProjectKey, version_key: safe.VersionKey, revision_number: safe.RevisionNumber -) -> github_models.TrustedPublisherPayload | None: - payload_path = attestable.github_tp_payload_path(project_key, version_key, revision_number) - if not await 
aiofiles.os.path.isfile(payload_path): - return None - try: - async with aiofiles.open(payload_path, encoding="utf-8") as f: - data = json.loads(await f.read()) - if not isinstance(data, dict): - log.warning(f"TP payload was not a JSON object in {payload_path}") - return None - return github_models.TrustedPublisherPayload.model_validate(data) - except (OSError, json.JSONDecodeError) as e: - log.warning(f"Failed to read TP payload from {payload_path}: {e}") - return None - except pydantic.ValidationError as e: - log.warning(f"Failed to validate TP payload from {payload_path}: {e}") - return None - - def _payload_summary(payload: github_models.TrustedPublisherPayload | None) -> dict[str, Any]: if payload is None: return {"present": False} diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py index 229f3928..5e970ceb 100644 --- a/tests/unit/test_checks_compare.py +++ b/tests/unit/test_checks_compare.py @@ -26,6 +26,7 @@ import dulwich.objects import dulwich.refs import pytest +import atr.attestable import atr.models.github import atr.models.sql import atr.tasks.checks @@ -620,7 +621,7 @@ async def test_source_trees_creates_temp_workspace_and_cleans_up( compare = CompareRecorder(repo_only={"extra1.txt", "extra2.txt"}) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -649,7 +650,7 @@ async def test_source_trees_payload_none_skips_temp_workspace(monkeypatch: pytes recorder = RecorderStub(True) args = _make_args(recorder) - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(None)) + 
monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(None)) monkeypatch.setattr( atr.tasks.checks.compare, "_checkout_github_source", @@ -676,7 +677,7 @@ async def test_source_trees_permits_pkg_info_when_pyproject_toml_exists( compare = CompareRecorder(invalid={"PKG-INFO"}) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -703,7 +704,7 @@ async def test_source_trees_records_failure_when_archive_has_invalid_files( compare = CompareRecorder(invalid={"bad1.txt", "bad2.txt"}, repo_only={"ok.txt"}) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -734,7 +735,7 @@ async def test_source_trees_records_failure_when_archive_root_not_found( find_root = FindArchiveRootRecorder(root=None) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) 
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -756,7 +757,7 @@ async def test_source_trees_records_failure_when_cache_dir_unavailable( args = _make_args(recorder) payload = _make_payload() - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(None)) monkeypatch.setattr( atr.tasks.checks.compare, @@ -786,7 +787,7 @@ async def test_source_trees_records_failure_when_extra_entries_in_archive( find_root = FindArchiveRootRecorder(root="artifact", extra_entries=["README.txt", "extra.txt"]) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -817,7 +818,7 @@ async def test_source_trees_reports_repo_only_sample_limited_to_five( compare = CompareRecorder(repo_only=repo_only_files) tmp_root = tmp_path / "temporary-root" - monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload)) + monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload)) monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout) monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir)) monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root) @@ -839,7 +840,7 @@ async def test_source_trees_skips_when_not_source(monkeypatch: pytest.MonkeyPatc args = _make_args(recorder) monkeypatch.setattr( - atr.tasks.checks.compare, 
"_load_tp_payload", RaiseAsync("_load_tp_payload should not be called") + atr.attestable, "github_tp_payload_read", RaiseAsync("github_tp_payload_read should not be called") ) await atr.tasks.checks.compare.source_trees(args) @@ -871,7 +872,7 @@ def _make_payload( "enterprise": "", "enterprise_id": "", "event_name": "push", - "exp": 1, + "exp": 99999999999, "head_ref": "", "iat": 1, "iss": "issuer", --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
