This is an automated email from the ASF dual-hosted git repository.

arm pushed a commit to branch github_tp_validation
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git

commit 1c4f70a13c8f6261dd585a93ed7de7a85d324251
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Mar 12 14:32:51 2026 +0000

    #676 Validate exp and nbf when loading pydantic model for GitHub token.
Attestable class updated to store and load model instead of dict.
---
 atr/api/__init__.py               |  8 ++++----
 atr/attestable.py                 | 35 +++++++++++++++++++++++++++++++++--
 atr/db/interaction.py             | 17 ++++++++++-------
 atr/jwtoken.py                    |  4 ++--
 atr/models/github.py              | 24 ++++++++++++++++++++++--
 atr/ssh.py                        |  9 +++++----
 atr/storage/writers/ssh.py        | 13 ++++++++++---
 atr/tasks/checks/compare.py       | 26 +-------------------------
 tests/unit/test_checks_compare.py | 21 +++++++++++----------
 9 files changed, 98 insertions(+), 59 deletions(-)

diff --git a/atr/api/__init__.py b/atr/api/__init__.py
index 79ff4d95..62d03f0c 100644
--- a/atr/api/__init__.py
+++ b/atr/api/__init__.py
@@ -294,8 +294,8 @@ async def distribute_ssh_register(
     )
     async with 
storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) 
as wacm:
         fingerprint, expires = await wacm.ssh.add_workflow_key(
-            payload["actor"],
-            payload["actor_id"],
+            payload.actor,
+            payload.actor_id,
             release.safe_project_key,
             data.ssh_key,
             payload,
@@ -893,8 +893,8 @@ async def publisher_ssh_register(
     )
     async with 
storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) 
as wacm:
         fingerprint, expires = await wacm.ssh.add_workflow_key(
-            payload["actor"],
-            payload["actor_id"],
+            payload.actor,
+            payload.actor_id,
             project.safe_key,
             data.ssh_key,
             payload,
diff --git a/atr/attestable.py b/atr/attestable.py
index fc908ad2..a58c34ae 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -29,6 +29,7 @@ import atr.classify as classify
 import atr.hashes as hashes
 import atr.log as log
 import atr.models.attestable as models
+import atr.models.github as github
 import atr.models.safe as safe
 import atr.models.sql as sql
 import atr.paths as paths
@@ -128,14 +129,44 @@ def github_tp_payload_path(
     return paths.get_attestable_dir() / str(project_key) / str(version_key) / 
f"{revision_number!s}.github-tp.json"
 
 
+async def github_tp_payload_read(
+    project_key: safe.ProjectKey, version_key: safe.VersionKey, 
revision_number: safe.RevisionNumber
+) -> github.TrustedPublisherPayload | None:
+    payload_path = github_tp_payload_path(project_key, version_key, 
revision_number)
+    if not await aiofiles.os.path.isfile(payload_path):
+        return None
+    try:
+        async with aiofiles.open(payload_path, encoding="utf-8") as f:
+            data = json.loads(await f.read())
+        if not isinstance(data, dict):
+            log.warning(f"TP payload was not a JSON object in {payload_path}")
+            return None
+        # Remove exp and nbf if they're stored - as of 2026-03-18 they're 
validated and then removed before storage
+        # but we might have older data
+        if "exp" in data:
+            del data["exp"]
+        if "nbf" in data:
+            del data["nbf"]
+        return github.TrustedPublisherPayload.model_validate(data)
+    except (OSError, json.JSONDecodeError) as e:
+        log.warning(f"Failed to read TP payload from {payload_path}: {e}")
+        return None
+    except pydantic.ValidationError as e:
+        log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
+        return None
+
+
 async def github_tp_payload_write(
     project_key: safe.ProjectKey,
     version_key: safe.VersionKey,
     revision_number: safe.RevisionNumber,
-    github_payload: dict[str, Any],
+    github_payload: github.TrustedPublisherPayload,
 ) -> None:
     payload_path = github_tp_payload_path(project_key, version_key, 
revision_number)
-    await util.atomic_write_file(payload_path, json.dumps(github_payload, 
indent=2))
+    # Dump the workflow payload, excluding exp and nbf - which shouldn't have 
made it this far. If they do,
+    # it's safe to remove them as they've been validated by the model already, 
and we should never store
+    # stale dates
+    await util.atomic_write_file(payload_path, 
json.dumps(github_payload.model_dump(exclude={"exp", "nbf"}), indent=2))
 
 
 async def load(
diff --git a/atr/db/interaction.py b/atr/db/interaction.py
index 33ef59d9..f18a1a93 100644
--- a/atr/db/interaction.py
+++ b/atr/db/interaction.py
@@ -19,7 +19,7 @@ import contextlib
 import datetime
 import enum
 from collections.abc import AsyncGenerator, Sequence
-from typing import Any, Final
+from typing import Final
 
 import packaging.version as version
 import sqlalchemy
@@ -31,6 +31,7 @@ import atr.db as db
 import atr.jwtoken as jwtoken
 import atr.ldap as ldap
 import atr.log as log
+import atr.models.github as github
 import atr.models.results as results
 import atr.models.safe as safe
 import atr.models.sql as sql
@@ -455,12 +456,14 @@ async def tasks_ongoing_revision(
         return task_count, latest_revision
 
 
-async def trusted_jwt(publisher: str, jwt: str, phase: TrustedProjectPhase) -> 
tuple[dict[str, Any], str, sql.Project]:
+async def trusted_jwt(
+    publisher: str, jwt: str, phase: TrustedProjectPhase
+) -> tuple[github.TrustedPublisherPayload, str, sql.Project]:
     payload, asf_uid = await validate_trusted_jwt(publisher, jwt)
     # JWT could be for an ASF user or the trusted role, but we need a user 
here.
     if asf_uid is None:
         raise InteractionError("ASF user not found")
-    project = await _trusted_project(payload["repository"], 
payload["workflow_ref"], phase)
+    project = await _trusted_project(payload.repository, payload.workflow_ref, 
phase)
     return payload, asf_uid, project
 
 
@@ -471,7 +474,7 @@ async def trusted_jwt_for_dist(
     phase: TrustedProjectPhase,
     project_key: safe.ProjectKey,
     version_key: safe.VersionKey,
-) -> tuple[dict[str, Any], str, sql.Project, sql.Release]:
+) -> tuple[github.TrustedPublisherPayload, str, sql.Project, sql.Release]:
     payload, asf_uid_from_jwt = await validate_trusted_jwt(publisher, jwt)
     if asf_uid_from_jwt is not None:
         raise InteractionError("Must use Trusted Publishing when specifying 
ASF UID")
@@ -553,12 +556,12 @@ async def user_projects(asf_uid: str, caller_data: 
db.Session | None = None) ->
     return [(p.key, p.display_name) for p in projects]
 
 
-async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[dict[str, 
Any], str | None]:
+async def validate_trusted_jwt(publisher: str, jwt: str) -> 
tuple[github.TrustedPublisherPayload, str | None]:
     if publisher != "github":
         raise InteractionError(f"Publisher {publisher} not supported")
     payload = await jwtoken.verify_github_oidc(jwt)
-    if int(payload["actor_id"]) != _GITHUB_TRUSTED_ROLE_NID:
-        asf_uid = await ldap.github_to_apache(payload["actor_id"])
+    if payload.actor_id != _GITHUB_TRUSTED_ROLE_NID:
+        asf_uid = await ldap.github_to_apache(payload.actor_id)
     else:
         asf_uid = None
     return payload, asf_uid
diff --git a/atr/jwtoken.py b/atr/jwtoken.py
index 7b9700df..38e00d32 100644
--- a/atr/jwtoken.py
+++ b/atr/jwtoken.py
@@ -154,7 +154,7 @@ async def verify(token: str) -> dict[str, Any]:
     return claims
 
 
-async def verify_github_oidc(token: str) -> dict[str, Any]:
+async def verify_github_oidc(token: str) -> github.TrustedPublisherPayload:
     header = jwt.get_unverified_header(token)
     dangerous_headers = {"jku", "x5u", "jwk"}
     if dangerous_headers.intersection(header.keys()):
@@ -215,7 +215,7 @@ async def verify_github_oidc(token: str) -> dict[str, Any]:
                 f"GitHub OIDC payload mismatch: {key} = {payload[key]} != 
{value}",
                 errorcode=401,
             )
-    return github.TrustedPublisherPayload.model_validate(payload).model_dump()
+    return github.TrustedPublisherPayload.model_validate(payload)
 
 
 def write_new_signing_key() -> str:
diff --git a/atr/models/github.py b/atr/models/github.py
index d20ac73f..e5fe39fc 100644
--- a/atr/models/github.py
+++ b/atr/models/github.py
@@ -17,19 +17,23 @@
 
 from __future__ import annotations
 
+import time
+
+import pydantic
+
 from . import schema
 
 
 class TrustedPublisherPayload(schema.Subset):
     actor: str
-    actor_id: str
+    actor_id: int
     aud: str
     base_ref: str
     check_run_id: str
     enterprise: str
     enterprise_id: str
     event_name: str
-    exp: int
+    exp: int | None = None
     head_ref: str
     iat: int
     iss: str
@@ -51,3 +55,19 @@ class TrustedPublisherPayload(schema.Subset):
     workflow: str
     workflow_ref: str
     workflow_sha: str
+
+    @pydantic.field_validator("exp")
+    @classmethod
+    def _validate_exp(cls, value: int) -> int:
+        now = int(time.time())
+        if now > value:
+            raise ValueError("Token has expired")
+        return value
+
+    @pydantic.field_validator("nbf")
+    @classmethod
+    def _validate_nbf(cls, value: int | None) -> int | None:
+        now = int(time.time())
+        if value and now < value:
+            raise ValueError("Token not yet valid")
+        return value
diff --git a/atr/ssh.py b/atr/ssh.py
index 2b78de79..533fc86f 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -26,7 +26,7 @@ import pathlib
 import stat
 import string
 import time
-from typing import Any, Final
+from typing import Final
 
 import aiofiles
 import aiofiles.os
@@ -40,6 +40,7 @@ import atr.attestable as attestable
 import atr.config as config
 import atr.db as db
 import atr.log as log
+import atr.models.github as github
 import atr.models.safe as safe
 import atr.models.sql as sql
 import atr.paths as paths
@@ -82,7 +83,7 @@ class SSHServer(asyncssh.SSHServer):
         # Store connection for use in begin_auth
         self._conn = conn
         self._github_asf_uid: str | None = None
-        self._github_payload: dict[str, Any] | None = None
+        self._github_payload: github.TrustedPublisherPayload | None = None
         peer_addr = conn.get_extra_info("peername")[0]
         log.info(f"SSH connection received from {peer_addr}")
 
@@ -163,7 +164,7 @@ class SSHServer(asyncssh.SSHServer):
                 log.failed_authentication("public_key_expired")
                 return False
 
-            self._github_payload = workflow_key.github_payload
+            self._github_payload = 
github.TrustedPublisherPayload.model_validate(workflow_key.github_payload)
             return True
 
     def _get_asf_uid(self, process: asyncssh.SSHServerProcess) -> str:
@@ -174,7 +175,7 @@ class SSHServer(asyncssh.SSHServer):
             return self._github_asf_uid
         return username
 
-    def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> 
dict[str, Any] | None:
+    def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> 
github.TrustedPublisherPayload | None:
         username = process.get_extra_info("username")
         if username != "github":
             return None
diff --git a/atr/storage/writers/ssh.py b/atr/storage/writers/ssh.py
index 5cfb5e28..fb776e4a 100644
--- a/atr/storage/writers/ssh.py
+++ b/atr/storage/writers/ssh.py
@@ -19,9 +19,9 @@
 from __future__ import annotations
 
 import time
-from typing import Any
 
 import atr.db as db
+import atr.models.github as github
 import atr.models.safe as safe
 import atr.models.sql as sql
 import atr.storage as storage
@@ -86,13 +86,20 @@ class CommitteeParticipant(FoundationCommitter):
         self.__committee_key = committee_key
 
     async def add_workflow_key(
-        self, github_uid: str, github_nid: int, project_key: safe.ProjectKey, 
key: str, github_payload: dict[str, Any]
+        self,
+        github_uid: str,
+        github_nid: int,
+        project_key: safe.ProjectKey,
+        key: str,
+        github_payload: github.TrustedPublisherPayload,
     ) -> tuple[str, int]:
         now = int(time.time())
         # Twenty minutes to upload all files
         ttl = 20 * 60
         expires = now + ttl
         fingerprint = util.key_ssh_fingerprint(key)
+        # Exclude nbf and exp as we've already validated this key - now 
protected by workflowkey "expires"
+        json_payload = github_payload.model_dump(exclude={"exp", "nbf"})
         wsk = sql.WorkflowSSHKey(
             fingerprint=fingerprint,
             key=key,
@@ -100,7 +107,7 @@ class CommitteeParticipant(FoundationCommitter):
             asf_uid=self.__asf_uid,
             github_uid=github_uid,
             github_nid=github_nid,
-            github_payload=github_payload,
+            github_payload=json_payload,
             expires=expires,
         )
         self.__data.add(wsk)
diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index 73264b64..ace6f52a 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -18,7 +18,6 @@
 import asyncio
 import contextlib
 import dataclasses
-import json
 import os
 import pathlib
 import shutil
@@ -34,14 +33,12 @@ import dulwich.objects
 import dulwich.objectspec
 import dulwich.porcelain
 import dulwich.refs
-import pydantic
 
 import atr.attestable as attestable
 import atr.config as config
 import atr.log as log
 import atr.models.github as github_models
 import atr.models.results as results
-import atr.models.safe as safe
 import atr.paths as paths
 import atr.tasks.checks as checks
 import atr.util as util
@@ -95,7 +92,7 @@ async def source_trees(args: checks.FunctionArguments) -> 
results.Results | None
         )
         return None
 
-    payload = await _load_tp_payload(args.project_key, args.version_key, 
args.revision_number)
+    payload = await attestable.github_tp_payload_read(args.project_key, 
args.version_key, args.revision_number)
     checkout_dir: str | None = None
     archive_dir: str | None = None
     if payload is not None:
@@ -334,27 +331,6 @@ async def _find_archive_root(archive_path: pathlib.Path, 
extract_dir: pathlib.Pa
     return ArchiveRootResult(root=found_root, extra_entries=extra_entries)
 
 
-async def _load_tp_payload(
-    project_key: safe.ProjectKey, version_key: safe.VersionKey, 
revision_number: safe.RevisionNumber
-) -> github_models.TrustedPublisherPayload | None:
-    payload_path = attestable.github_tp_payload_path(project_key, version_key, 
revision_number)
-    if not await aiofiles.os.path.isfile(payload_path):
-        return None
-    try:
-        async with aiofiles.open(payload_path, encoding="utf-8") as f:
-            data = json.loads(await f.read())
-        if not isinstance(data, dict):
-            log.warning(f"TP payload was not a JSON object in {payload_path}")
-            return None
-        return github_models.TrustedPublisherPayload.model_validate(data)
-    except (OSError, json.JSONDecodeError) as e:
-        log.warning(f"Failed to read TP payload from {payload_path}: {e}")
-        return None
-    except pydantic.ValidationError as e:
-        log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
-        return None
-
-
 def _payload_summary(payload: github_models.TrustedPublisherPayload | None) -> 
dict[str, Any]:
     if payload is None:
         return {"present": False}
diff --git a/tests/unit/test_checks_compare.py 
b/tests/unit/test_checks_compare.py
index 229f3928..5e970ceb 100644
--- a/tests/unit/test_checks_compare.py
+++ b/tests/unit/test_checks_compare.py
@@ -26,6 +26,7 @@ import dulwich.objects
 import dulwich.refs
 import pytest
 
+import atr.attestable
 import atr.models.github
 import atr.models.sql
 import atr.tasks.checks
@@ -620,7 +621,7 @@ async def 
test_source_trees_creates_temp_workspace_and_cleans_up(
     compare = CompareRecorder(repo_only={"extra1.txt", "extra2.txt"})
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -649,7 +650,7 @@ async def 
test_source_trees_payload_none_skips_temp_workspace(monkeypatch: pytes
     recorder = RecorderStub(True)
     args = _make_args(recorder)
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(None))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(None))
     monkeypatch.setattr(
         atr.tasks.checks.compare,
         "_checkout_github_source",
@@ -676,7 +677,7 @@ async def 
test_source_trees_permits_pkg_info_when_pyproject_toml_exists(
     compare = CompareRecorder(invalid={"PKG-INFO"})
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -703,7 +704,7 @@ async def 
test_source_trees_records_failure_when_archive_has_invalid_files(
     compare = CompareRecorder(invalid={"bad1.txt", "bad2.txt"}, 
repo_only={"ok.txt"})
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -734,7 +735,7 @@ async def 
test_source_trees_records_failure_when_archive_root_not_found(
     find_root = FindArchiveRootRecorder(root=None)
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -756,7 +757,7 @@ async def 
test_source_trees_records_failure_when_cache_dir_unavailable(
     args = _make_args(recorder)
     payload = _make_payload()
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(None))
     monkeypatch.setattr(
         atr.tasks.checks.compare,
@@ -786,7 +787,7 @@ async def 
test_source_trees_records_failure_when_extra_entries_in_archive(
     find_root = FindArchiveRootRecorder(root="artifact", 
extra_entries=["README.txt", "extra.txt"])
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -817,7 +818,7 @@ async def 
test_source_trees_reports_repo_only_sample_limited_to_five(
     compare = CompareRecorder(repo_only=repo_only_files)
     tmp_root = tmp_path / "temporary-root"
 
-    monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", 
PayloadLoader(payload))
+    monkeypatch.setattr(atr.attestable, "github_tp_payload_read", 
PayloadLoader(payload))
     monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", 
checkout)
     monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", 
ArchiveDirResolver(cache_dir))
     monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", 
find_root)
@@ -839,7 +840,7 @@ async def 
test_source_trees_skips_when_not_source(monkeypatch: pytest.MonkeyPatc
     args = _make_args(recorder)
 
     monkeypatch.setattr(
-        atr.tasks.checks.compare, "_load_tp_payload", 
RaiseAsync("_load_tp_payload should not be called")
+        atr.attestable, "github_tp_payload_read", 
RaiseAsync("github_tp_payload_read should not be called")
     )
 
     await atr.tasks.checks.compare.source_trees(args)
@@ -871,7 +872,7 @@ def _make_payload(
         "enterprise": "",
         "enterprise_id": "",
         "event_name": "push",
-        "exp": 1,
+        "exp": 99999999999,
         "head_ref": "",
         "iat": 1,
         "iss": "issuer",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to