This is an automated email from the ASF dual-hosted git repository.
sbp pushed a commit to branch arm
in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
The following commit(s) were added to refs/heads/arm by this push:
new 5961d500 #676 Validate exp and nbf when loading pydantic model for Github token. Attestable class updated to store and load model instead of dict.
5961d500 is described below
commit 5961d5002f01c20f9fbc25669efbf877faf2cd93
Author: Alastair McFarlane <[email protected]>
AuthorDate: Thu Mar 12 14:32:51 2026 +0000
#676 Validate exp and nbf when loading pydantic model for Github token. Attestable class updated to store and load model instead of dict.
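As a rough illustration of the validation pattern this commit adds (a standalone sketch only: the real model is TrustedPublisherPayload in atr/models/github.py, which subclasses schema.Subset and carries many more claims), time-based checks are attached to the exp and nbf fields with pydantic v2 field validators:

    import time

    import pydantic


    class TokenClaims(pydantic.BaseModel):
        # exp and nbf are optional; older stored payloads may already have had them stripped.
        exp: int | None = None
        nbf: int | None = None

        @pydantic.field_validator("exp")
        @classmethod
        def _validate_exp(cls, value: int | None) -> int | None:
            # Reject tokens whose expiry time is in the past.
            if value is not None and int(time.time()) > value:
                raise ValueError("Token has expired")
            return value

        @pydantic.field_validator("nbf")
        @classmethod
        def _validate_nbf(cls, value: int | None) -> int | None:
            # Reject tokens that are not yet valid.
            if value is not None and int(time.time()) < value:
                raise ValueError("Token not yet valid")
            return value

With these checks running in model_validate(), expired or not-yet-valid tokens fail at load time, and exp/nbf are excluded via model_dump(exclude={"exp", "nbf"}) before the payload is persisted, so stale dates never reach storage.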
---
atr/api/__init__.py | 8 ++++----
atr/attestable.py | 35 +++++++++++++++++++++++++++++++++--
atr/db/interaction.py | 17 ++++++++++-------
atr/jwtoken.py | 4 ++--
atr/models/github.py | 24 ++++++++++++++++++++++--
atr/ssh.py | 9 +++++----
atr/storage/writers/ssh.py | 13 ++++++++++---
atr/tasks/checks/__init__.py | 22 +++-------------------
atr/tasks/checks/compare.py | 26 +-------------------------
tests/unit/test_checks_compare.py | 21 +++++++++++----------
10 files changed, 101 insertions(+), 78 deletions(-)
diff --git a/atr/api/__init__.py b/atr/api/__init__.py
index 79ff4d95..62d03f0c 100644
--- a/atr/api/__init__.py
+++ b/atr/api/__init__.py
@@ -294,8 +294,8 @@ async def distribute_ssh_register(
)
async with storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) as wacm:
fingerprint, expires = await wacm.ssh.add_workflow_key(
- payload["actor"],
- payload["actor_id"],
+ payload.actor,
+ payload.actor_id,
release.safe_project_key,
data.ssh_key,
payload,
@@ -893,8 +893,8 @@ async def publisher_ssh_register(
)
async with storage.write_as_committee_member(util.unwrap(project.committee).key, asf_uid) as wacm:
fingerprint, expires = await wacm.ssh.add_workflow_key(
- payload["actor"],
- payload["actor_id"],
+ payload.actor,
+ payload.actor_id,
project.safe_key,
data.ssh_key,
payload,
diff --git a/atr/attestable.py b/atr/attestable.py
index fc908ad2..a58c34ae 100644
--- a/atr/attestable.py
+++ b/atr/attestable.py
@@ -29,6 +29,7 @@ import atr.classify as classify
import atr.hashes as hashes
import atr.log as log
import atr.models.attestable as models
+import atr.models.github as github
import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as paths
@@ -128,14 +129,44 @@ def github_tp_payload_path(
return paths.get_attestable_dir() / str(project_key) / str(version_key) / f"{revision_number!s}.github-tp.json"
+async def github_tp_payload_read(
+ project_key: safe.ProjectKey, version_key: safe.VersionKey, revision_number: safe.RevisionNumber
+) -> github.TrustedPublisherPayload | None:
+ payload_path = github_tp_payload_path(project_key, version_key, revision_number)
+ if not await aiofiles.os.path.isfile(payload_path):
+ return None
+ try:
+ async with aiofiles.open(payload_path, encoding="utf-8") as f:
+ data = json.loads(await f.read())
+ if not isinstance(data, dict):
+ log.warning(f"TP payload was not a JSON object in {payload_path}")
+ return None
+ # Remove exp and nbf if they're stored - as of 2026-03-18 they're validated and then removed before storage
+ # but we might have older data
+ if "exp" in data:
+ del data["exp"]
+ if "nbf" in data:
+ del data["nbf"]
+ return github.TrustedPublisherPayload.model_validate(data)
+ except (OSError, json.JSONDecodeError) as e:
+ log.warning(f"Failed to read TP payload from {payload_path}: {e}")
+ return None
+ except pydantic.ValidationError as e:
+ log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
+ return None
+
+
async def github_tp_payload_write(
project_key: safe.ProjectKey,
version_key: safe.VersionKey,
revision_number: safe.RevisionNumber,
- github_payload: dict[str, Any],
+ github_payload: github.TrustedPublisherPayload,
) -> None:
payload_path = github_tp_payload_path(project_key, version_key, revision_number)
- await util.atomic_write_file(payload_path, json.dumps(github_payload, indent=2))
+ # Dump the workflow payload, excluding exp and nbf - which shouldn't have made it this far. If they do,
+ # it's safe to remove them as they've been validated by the model already, and we should never store
+ # stale dates
+ await util.atomic_write_file(payload_path, json.dumps(github_payload.model_dump(exclude={"exp", "nbf"}), indent=2))
async def load(
diff --git a/atr/db/interaction.py b/atr/db/interaction.py
index 33ef59d9..f18a1a93 100644
--- a/atr/db/interaction.py
+++ b/atr/db/interaction.py
@@ -19,7 +19,7 @@ import contextlib
import datetime
import enum
from collections.abc import AsyncGenerator, Sequence
-from typing import Any, Final
+from typing import Final
import packaging.version as version
import sqlalchemy
@@ -31,6 +31,7 @@ import atr.db as db
import atr.jwtoken as jwtoken
import atr.ldap as ldap
import atr.log as log
+import atr.models.github as github
import atr.models.results as results
import atr.models.safe as safe
import atr.models.sql as sql
@@ -455,12 +456,14 @@ async def tasks_ongoing_revision(
return task_count, latest_revision
-async def trusted_jwt(publisher: str, jwt: str, phase: TrustedProjectPhase) -> tuple[dict[str, Any], str, sql.Project]:
+async def trusted_jwt(
+ publisher: str, jwt: str, phase: TrustedProjectPhase
+) -> tuple[github.TrustedPublisherPayload, str, sql.Project]:
payload, asf_uid = await validate_trusted_jwt(publisher, jwt)
# JWT could be for an ASF user or the trusted role, but we need a user here.
if asf_uid is None:
raise InteractionError("ASF user not found")
- project = await _trusted_project(payload["repository"], payload["workflow_ref"], phase)
+ project = await _trusted_project(payload.repository, payload.workflow_ref, phase)
return payload, asf_uid, project
@@ -471,7 +474,7 @@ async def trusted_jwt_for_dist(
phase: TrustedProjectPhase,
project_key: safe.ProjectKey,
version_key: safe.VersionKey,
-) -> tuple[dict[str, Any], str, sql.Project, sql.Release]:
+) -> tuple[github.TrustedPublisherPayload, str, sql.Project, sql.Release]:
payload, asf_uid_from_jwt = await validate_trusted_jwt(publisher, jwt)
if asf_uid_from_jwt is not None:
raise InteractionError("Must use Trusted Publishing when specifying
ASF UID")
@@ -553,12 +556,12 @@ async def user_projects(asf_uid: str, caller_data: db.Session | None = None) ->
return [(p.key, p.display_name) for p in projects]
-async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[dict[str, Any], str | None]:
+async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[github.TrustedPublisherPayload, str | None]:
if publisher != "github":
raise InteractionError(f"Publisher {publisher} not supported")
payload = await jwtoken.verify_github_oidc(jwt)
- if int(payload["actor_id"]) != _GITHUB_TRUSTED_ROLE_NID:
- asf_uid = await ldap.github_to_apache(payload["actor_id"])
+ if payload.actor_id != _GITHUB_TRUSTED_ROLE_NID:
+ asf_uid = await ldap.github_to_apache(payload.actor_id)
else:
asf_uid = None
return payload, asf_uid
diff --git a/atr/jwtoken.py b/atr/jwtoken.py
index 7b9700df..38e00d32 100644
--- a/atr/jwtoken.py
+++ b/atr/jwtoken.py
@@ -154,7 +154,7 @@ async def verify(token: str) -> dict[str, Any]:
return claims
-async def verify_github_oidc(token: str) -> dict[str, Any]:
+async def verify_github_oidc(token: str) -> github.TrustedPublisherPayload:
header = jwt.get_unverified_header(token)
dangerous_headers = {"jku", "x5u", "jwk"}
if dangerous_headers.intersection(header.keys()):
@@ -215,7 +215,7 @@ async def verify_github_oidc(token: str) -> dict[str, Any]:
f"GitHub OIDC payload mismatch: {key} = {payload[key]} !=
{value}",
errorcode=401,
)
- return github.TrustedPublisherPayload.model_validate(payload).model_dump()
+ return github.TrustedPublisherPayload.model_validate(payload)
def write_new_signing_key() -> str:
diff --git a/atr/models/github.py b/atr/models/github.py
index d20ac73f..e5fe39fc 100644
--- a/atr/models/github.py
+++ b/atr/models/github.py
@@ -17,19 +17,23 @@
from __future__ import annotations
+import time
+
+import pydantic
+
from . import schema
class TrustedPublisherPayload(schema.Subset):
actor: str
- actor_id: str
+ actor_id: int
aud: str
base_ref: str
check_run_id: str
enterprise: str
enterprise_id: str
event_name: str
- exp: int
+ exp: int | None = None
head_ref: str
iat: int
iss: str
@@ -51,3 +55,19 @@ class TrustedPublisherPayload(schema.Subset):
workflow: str
workflow_ref: str
workflow_sha: str
+
+ @pydantic.field_validator("exp")
+ @classmethod
+ def _validate_exp(cls, value: int) -> int:
+ now = int(time.time())
+ if now > value:
+ raise ValueError("Token has expired")
+ return value
+
+ @pydantic.field_validator("nbf")
+ @classmethod
+ def _validate_nbf(cls, value: int | None) -> int | None:
+ now = int(time.time())
+ if value and now < value:
+ raise ValueError("Token not yet valid")
+ return value
diff --git a/atr/ssh.py b/atr/ssh.py
index 2b78de79..533fc86f 100644
--- a/atr/ssh.py
+++ b/atr/ssh.py
@@ -26,7 +26,7 @@ import pathlib
import stat
import string
import time
-from typing import Any, Final
+from typing import Final
import aiofiles
import aiofiles.os
@@ -40,6 +40,7 @@ import atr.attestable as attestable
import atr.config as config
import atr.db as db
import atr.log as log
+import atr.models.github as github
import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as paths
@@ -82,7 +83,7 @@ class SSHServer(asyncssh.SSHServer):
# Store connection for use in begin_auth
self._conn = conn
self._github_asf_uid: str | None = None
- self._github_payload: dict[str, Any] | None = None
+ self._github_payload: github.TrustedPublisherPayload | None = None
peer_addr = conn.get_extra_info("peername")[0]
log.info(f"SSH connection received from {peer_addr}")
@@ -163,7 +164,7 @@ class SSHServer(asyncssh.SSHServer):
log.failed_authentication("public_key_expired")
return False
- self._github_payload = workflow_key.github_payload
+ self._github_payload = github.TrustedPublisherPayload.model_validate(workflow_key.github_payload)
return True
def _get_asf_uid(self, process: asyncssh.SSHServerProcess) -> str:
@@ -174,7 +175,7 @@ class SSHServer(asyncssh.SSHServer):
return self._github_asf_uid
return username
- def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> dict[str, Any] | None:
+ def _get_github_payload(self, process: asyncssh.SSHServerProcess) -> github.TrustedPublisherPayload | None:
username = process.get_extra_info("username")
if username != "github":
return None
diff --git a/atr/storage/writers/ssh.py b/atr/storage/writers/ssh.py
index 5cfb5e28..fb776e4a 100644
--- a/atr/storage/writers/ssh.py
+++ b/atr/storage/writers/ssh.py
@@ -19,9 +19,9 @@
from __future__ import annotations
import time
-from typing import Any
import atr.db as db
+import atr.models.github as github
import atr.models.safe as safe
import atr.models.sql as sql
import atr.storage as storage
@@ -86,13 +86,20 @@ class CommitteeParticipant(FoundationCommitter):
self.__committee_key = committee_key
async def add_workflow_key(
- self, github_uid: str, github_nid: int, project_key: safe.ProjectKey, key: str, github_payload: dict[str, Any]
+ self,
+ github_uid: str,
+ github_nid: int,
+ project_key: safe.ProjectKey,
+ key: str,
+ github_payload: github.TrustedPublisherPayload,
) -> tuple[str, int]:
now = int(time.time())
# Twenty minutes to upload all files
ttl = 20 * 60
expires = now + ttl
fingerprint = util.key_ssh_fingerprint(key)
+ # Exclude nbf and exp as we've already validated this key - now protected by workflowkey "expires"
+ json_payload = github_payload.model_dump(exclude={"exp", "nbf"})
wsk = sql.WorkflowSSHKey(
fingerprint=fingerprint,
key=key,
@@ -100,7 +107,7 @@ class CommitteeParticipant(FoundationCommitter):
asf_uid=self.__asf_uid,
github_uid=github_uid,
github_nid=github_nid,
- github_payload=github_payload,
+ github_payload=json_payload,
expires=expires,
)
self.__data.add(wsk)
diff --git a/atr/tasks/checks/__init__.py b/atr/tasks/checks/__init__.py
index 6e26224c..24d3fce4 100644
--- a/atr/tasks/checks/__init__.py
+++ b/atr/tasks/checks/__init__.py
@@ -20,13 +20,11 @@ from __future__ import annotations
import dataclasses
import datetime
import functools
-import json
import pathlib
from typing import TYPE_CHECKING, Any, Final
import aiofiles
import aiofiles.os
-import pydantic
import sqlmodel
if TYPE_CHECKING:
@@ -39,7 +37,6 @@ import atr.classify as classify
import atr.db as db
import atr.hashes as hashes
import atr.log as log
-import atr.models.github as github_models
import atr.models.safe as safe
import atr.models.sql as sql
import atr.paths as file_paths
@@ -473,25 +470,12 @@ async def _resolve_committee_key(release: sql.Release, rel_path: str | None = No
async def _resolve_github_tp_sha(release: sql.Release, rel_path: str | None = None) -> str:
if not release.latest_revision_number:
return ""
- payload_path = attestable.github_tp_payload_path(
+ payload = await attestable.github_tp_payload_read(
release.safe_project_key, release.safe_version_key, release.safe_latest_revision_number
)
- if not await aiofiles.os.path.isfile(payload_path):
- return ""
- try:
- async with aiofiles.open(payload_path, encoding="utf-8") as f:
- data = json.loads(await f.read())
- if not isinstance(data, dict):
- log.warning(f"TP payload was not a JSON object in {payload_path}")
- return ""
- tp_data = github_models.TrustedPublisherPayload.model_validate(data)
- return tp_data.sha
- except (OSError, json.JSONDecodeError) as e:
- log.warning(f"Failed to read TP payload from {payload_path}: {e}")
- return ""
- except pydantic.ValidationError as e:
- log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
+ if not payload:
return ""
+ return payload.sha
async def _resolve_is_podling(release: sql.Release, rel_path: str | None = None) -> bool:
diff --git a/atr/tasks/checks/compare.py b/atr/tasks/checks/compare.py
index 858af724..d0b777ac 100644
--- a/atr/tasks/checks/compare.py
+++ b/atr/tasks/checks/compare.py
@@ -18,7 +18,6 @@
import asyncio
import contextlib
import dataclasses
-import json
import os
import pathlib
import shutil
@@ -34,14 +33,12 @@ import dulwich.objects
import dulwich.objectspec
import dulwich.porcelain
import dulwich.refs
-import pydantic
import atr.attestable as attestable
import atr.config as config
import atr.log as log
import atr.models.github as github_models
import atr.models.results as results
-import atr.models.safe as safe
import atr.paths as paths
import atr.tasks.checks as checks
import atr.util as util
@@ -95,7 +92,7 @@ async def source_trees(args: checks.FunctionArguments) -> results.Results | None
)
return None
- payload = await _load_tp_payload(args.project_key, args.version_key, args.revision_number)
+ payload = await attestable.github_tp_payload_read(args.project_key, args.version_key, args.revision_number)
checkout_dir: str | None = None
archive_dir: str | None = None
if payload is not None:
@@ -334,27 +331,6 @@ async def _find_archive_root(archive_path: pathlib.Path, extract_dir: pathlib.Pa
return ArchiveRootResult(root=found_root, extra_entries=extra_entries)
-async def _load_tp_payload(
- project_key: safe.ProjectKey, version_key: safe.VersionKey, revision_number: safe.RevisionNumber
-) -> github_models.TrustedPublisherPayload | None:
- payload_path = attestable.github_tp_payload_path(project_key, version_key, revision_number)
- if not await aiofiles.os.path.isfile(payload_path):
- return None
- try:
- async with aiofiles.open(payload_path, encoding="utf-8") as f:
- data = json.loads(await f.read())
- if not isinstance(data, dict):
- log.warning(f"TP payload was not a JSON object in {payload_path}")
- return None
- return github_models.TrustedPublisherPayload.model_validate(data)
- except (OSError, json.JSONDecodeError) as e:
- log.warning(f"Failed to read TP payload from {payload_path}: {e}")
- return None
- except pydantic.ValidationError as e:
- log.warning(f"Failed to validate TP payload from {payload_path}: {e}")
- return None
-
-
def _payload_summary(payload: github_models.TrustedPublisherPayload | None) -> dict[str, Any]:
if payload is None:
return {"present": False}
diff --git a/tests/unit/test_checks_compare.py b/tests/unit/test_checks_compare.py
index 5cbc4c8b..338f331d 100644
--- a/tests/unit/test_checks_compare.py
+++ b/tests/unit/test_checks_compare.py
@@ -26,6 +26,7 @@ import dulwich.objects
import dulwich.refs
import pytest
+import atr.attestable
import atr.models.github
import atr.models.safe
import atr.models.sql
@@ -624,7 +625,7 @@ async def test_source_trees_creates_temp_workspace_and_cleans_up(
compare = CompareRecorder(repo_only={"extra1.txt", "extra2.txt"})
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -653,7 +654,7 @@ async def test_source_trees_payload_none_skips_temp_workspace(monkeypatch: pytes
recorder = RecorderStub(True)
args = _make_args(recorder)
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(None))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(None))
monkeypatch.setattr(
atr.tasks.checks.compare,
"_checkout_github_source",
@@ -680,7 +681,7 @@ async def test_source_trees_permits_pkg_info_when_pyproject_toml_exists(
compare = CompareRecorder(invalid={"PKG-INFO"})
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -707,7 +708,7 @@ async def test_source_trees_records_failure_when_archive_has_invalid_files(
compare = CompareRecorder(invalid={"bad1.txt", "bad2.txt"}, repo_only={"ok.txt"})
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -738,7 +739,7 @@ async def test_source_trees_records_failure_when_archive_root_not_found(
find_root = FindArchiveRootRecorder(root=None)
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -760,7 +761,7 @@ async def test_source_trees_records_failure_when_cache_dir_unavailable(
args = _make_args(recorder)
payload = _make_payload()
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(None))
monkeypatch.setattr(
atr.tasks.checks.compare,
@@ -790,7 +791,7 @@ async def test_source_trees_records_failure_when_extra_entries_in_archive(
find_root = FindArchiveRootRecorder(root="artifact", extra_entries=["README.txt", "extra.txt"])
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -821,7 +822,7 @@ async def test_source_trees_reports_repo_only_sample_limited_to_five(
compare = CompareRecorder(repo_only=repo_only_files)
tmp_root = tmp_path / "temporary-root"
- monkeypatch.setattr(atr.tasks.checks.compare, "_load_tp_payload", PayloadLoader(payload))
+ monkeypatch.setattr(atr.attestable, "github_tp_payload_read", PayloadLoader(payload))
monkeypatch.setattr(atr.tasks.checks.compare, "_checkout_github_source", checkout)
monkeypatch.setattr(atr.tasks.checks, "resolve_archive_dir", ArchiveDirResolver(cache_dir))
monkeypatch.setattr(atr.tasks.checks.compare, "_find_archive_root", find_root)
@@ -843,7 +844,7 @@ async def test_source_trees_skips_when_not_source(monkeypatch: pytest.MonkeyPatc
args = _make_args(recorder)
monkeypatch.setattr(
- atr.tasks.checks.compare, "_load_tp_payload", RaiseAsync("_load_tp_payload should not be called")
+ atr.attestable, "github_tp_payload_read", RaiseAsync("github_tp_payload_read should not be called")
)
await atr.tasks.checks.compare.source_trees(args)
@@ -875,7 +876,7 @@ def _make_payload(
"enterprise": "",
"enterprise_id": "",
"event_name": "push",
- "exp": 1,
+ "exp": 99999999999,
"head_ref": "",
"iat": 1,
"iss": "issuer",
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]