This is an automated email from the ASF dual-hosted git repository. arm pushed a commit to branch jwtoken_multiple_sources in repository https://gitbox.apache.org/repos/asf/tooling-trusted-releases.git
commit 03f2f5ec9c1552fadc721bf36f5d565c36675040 Author: Alastair McFarlane <[email protected]> AuthorDate: Mon Jan 26 11:23:46 2026 +0000 #504 - enable jwtoken.require to take arguments, check tokens from multiple locations and process claims. Update asf_uid handling in API --- atr/api/__init__.py | 115 +++++++++++++++++++++++++++++--------------------- atr/db/interaction.py | 37 +++------------- atr/jwtoken.py | 96 ++++++++++++++++++++++++++++++++++------- 3 files changed, 154 insertions(+), 94 deletions(-) diff --git a/atr/api/__init__.py b/atr/api/__init__.py index 79260cf..452e8db 100644 --- a/atr/api/__init__.py +++ b/atr/api/__init__.py @@ -253,24 +253,24 @@ async def committees_list() -> DictResponse: @api.route("/distribute/ssh/register", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.DistributeSshRegisterArgs) async def distribute_ssh_register(data: models.api.DistributeSshRegisterArgs) -> DictResponse: """ Register an SSH key sent with a corroborating Trusted Publisher JWT, validating the requested release is in the correct phase. """ - payload, asf_uid, project, release = await interaction.trusted_jwt_for_dist( - data.publisher, - data.jwt, - data.asf_uid, + asf_uid = _jwt_asf_uid(github=True, atr=False) + project, release = await interaction.check_release_phase( interaction.TrustedProjectPhase(data.phase), data.project_name, data.version, ) + claims = getattr(quart.g, "jwt_claims", {}) async with storage.write_as_committee_member(util.unwrap(project.committee).name, asf_uid) as wacm: fingerprint, expires = await wacm.ssh.add_workflow_key( - payload["actor"], - payload["actor_id"], + claims["actor"], + claims["actor_id"], release.project_name, data.ssh_key, ) @@ -284,12 +284,14 @@ async def distribute_ssh_register(data: models.api.DistributeSshRegisterArgs) -> @api.route("/distribute/task/status", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.DistributeStatusUpdateArgs) async def update_distribution_task_status(data: models.api.DistributeStatusUpdateArgs) -> DictResponse: """ Update the status of a distribution task """ - _payload, _asf_uid = await interaction.validate_trusted_jwt(data.publisher, data.jwt) + + # asf_uid = _jwt_asf_uid() async with db.session() as db_data: status = await db_data.workflow_status( workflow_id=data.workflow, @@ -306,7 +308,7 @@ async def update_distribution_task_status(data: models.api.DistributeStatusUpdat @api.route("/distribution/record", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.DistributionRecordArgs) @quart_schema.validate_response(models.api.DistributionRecordResults, 200) @@ -345,15 +347,14 @@ async def distribution_record(data: models.api.DistributionRecordArgs) -> DictRe @api.route("/distribute/record_from_workflow", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.DistributionRecordFromWorkflowArgs) async def distribution_record_from_workflow(data: models.api.DistributionRecordFromWorkflowArgs) -> DictResponse: """ Record a distribution. """ - _payload, asf_uid, _project, release = await interaction.trusted_jwt_for_dist( - data.publisher, - data.jwt, - data.asf_uid, + asf_uid = _jwt_asf_uid(github=True, atr=False) + _project, release = await interaction.check_release_phase( interaction.TrustedProjectPhase(data.phase), data.project, data.version, @@ -383,7 +384,7 @@ async def distribution_record_from_workflow(data: models.api.DistributionRecordF @api.route("/ignore/add", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.IgnoreAddArgs) @quart_schema.validate_response(models.api.IgnoreAddResults, 200) @@ -412,7 +413,7 @@ async def ignore_add(data: models.api.IgnoreAddArgs) -> DictResponse: @api.route("/ignore/delete", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.IgnoreDeleteArgs) @quart_schema.validate_response(models.api.IgnoreDeleteResults, 200) @@ -477,7 +478,7 @@ async def jwt_create(data: models.api.JwtCreateArgs) -> DictResponse: @api.route("/key/add", methods=["POST"]) @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.KeyAddArgs) @quart_schema.validate_response(models.api.KeyAddResults, 200) @@ -510,7 +511,7 @@ async def key_add(data: models.api.KeyAddArgs) -> DictResponse: @api.route("/key/delete", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.KeyDeleteArgs) @quart_schema.validate_response(models.api.KeyDeleteResults, 200) @@ -564,7 +565,7 @@ async def key_get(fingerprint: str) -> DictResponse: @api.route("/keys/upload", methods=["POST"]) @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.KeysUploadArgs) @quart_schema.validate_response(models.api.KeysUploadResults, 200) @@ -719,22 +720,25 @@ async def projects_list() -> DictResponse: @api.route("/publisher/distribution/record", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.PublisherDistributionRecordArgs) async def publisher_distribution_record(data: models.api.PublisherDistributionRecordArgs) -> DictResponse: """ Record a distribution with a corroborating Trusted Publisher JWT. """ + asf_uid = _jwt_asf_uid(github=True, atr=False) + claims = getattr(quart.g, "jwt_claims", {}) try: - _payload, asf_uid, project = await interaction.trusted_jwt( - data.publisher, - data.jwt, + project = await interaction.trusted_project( + claims["repository"], + claims["workflow_ref"], interaction.TrustedProjectPhase.FINISH, ) except interaction.ReleasePolicyNotFoundError: # TODO: We could perform a more advanced query with multiple in_ statements - _payload, asf_uid, project = await interaction.trusted_jwt( - data.publisher, - data.jwt, + project = await interaction.trusted_project( + claims["repository"], + claims["workflow_ref"], interaction.TrustedProjectPhase.COMPOSE, ) async with db.session() as db_data: @@ -767,14 +771,17 @@ async def publisher_distribution_record(data: models.api.PublisherDistributionRe @api.route("/publisher/release/announce", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.PublisherReleaseAnnounceArgs) async def publisher_release_announce(data: models.api.PublisherReleaseAnnounceArgs) -> DictResponse: """ Announce a release with a corroborating Trusted Publisher JWT. """ - _payload, asf_uid, project = await interaction.trusted_jwt( - data.publisher, - data.jwt, + asf_uid = _jwt_asf_uid(github=True, atr=False) + claims = getattr(quart.g, "jwt_claims", {}) + project = await interaction.trusted_project( + claims["repository"], + claims["workflow_ref"], interaction.TrustedProjectPhase.FINISH, ) try: @@ -801,19 +808,24 @@ async def publisher_release_announce(data: models.api.PublisherReleaseAnnounceAr @api.route("/publisher/ssh/register", methods=["POST"]) [email protected](token_types=["github"]) @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) @quart_schema.validate_request(models.api.PublisherSshRegisterArgs) async def publisher_ssh_register(data: models.api.PublisherSshRegisterArgs) -> DictResponse: """ Register an SSH key sent with a corroborating Trusted Publisher JWT. """ - payload, asf_uid, project = await interaction.trusted_jwt( - data.publisher, data.jwt, interaction.TrustedProjectPhase.COMPOSE + asf_uid = _jwt_asf_uid(github=True, atr=False) + claims = getattr(quart.g, "jwt_claims", {}) + project = await interaction.trusted_project( + claims["repository"], + claims["workflow_ref"], + interaction.TrustedProjectPhase.COMPOSE, ) async with storage.write_as_committee_member(util.unwrap(project.committee).name, asf_uid) as wacm: fingerprint, expires = await wacm.ssh.add_workflow_key( - payload["actor"], - payload["actor_id"], + claims["actor"], + claims["actor_id"], project.name, data.ssh_key, ) @@ -827,15 +839,18 @@ async def publisher_ssh_register(data: models.api.PublisherSshRegisterArgs) -> D @api.route("/publisher/vote/resolve", methods=["POST"]) [email protected](token_types=["github"]) @quart_schema.validate_request(models.api.PublisherVoteResolveArgs) async def publisher_vote_resolve(data: models.api.PublisherVoteResolveArgs) -> DictResponse: """ Resolve a vote with a corroborating Trusted Publisher JWT. """ # TODO: Need to be able to resolve and make the release immutable - _payload, asf_uid, project = await interaction.trusted_jwt( - data.publisher, - data.jwt, + asf_uid = _jwt_asf_uid(github=True, atr=False) + claims = getattr(quart.g, "jwt_claims", {}) + project = await interaction.trusted_project( + claims["repository"], + claims["workflow_ref"], interaction.TrustedProjectPhase.VOTE, ) async with storage.write_as_project_committee_member(project.name, asf_uid) as wacm: @@ -856,7 +871,7 @@ async def publisher_vote_resolve(data: models.api.PublisherVoteResolveArgs) -> D @api.route("/release/announce", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.ReleaseAnnounceArgs) @quart_schema.validate_response(models.api.ReleaseAnnounceResults, 201) @@ -894,7 +909,7 @@ async def release_announce(data: models.api.ReleaseAnnounceArgs) -> DictResponse @api.route("/release/create", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.ReleaseCreateArgs) @quart_schema.validate_response(models.api.ReleaseCreateResults, 201) @@ -921,7 +936,7 @@ async def release_create(data: models.api.ReleaseCreateArgs) -> DictResponse: # TODO: Duplicates the below @api.route("/release/delete", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.ReleaseDeleteArgs) @quart_schema.validate_response(models.api.ReleaseDeleteResults, 200) @@ -1004,7 +1019,7 @@ async def release_revisions(project: str, version: str) -> DictResponse: @api.route("/release/upload", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.ReleaseUploadArgs) @quart_schema.validate_response(models.api.ReleaseUploadResults, 201) @@ -1072,7 +1087,7 @@ async def releases_list(query_args: models.api.ReleasesListQuery) -> DictRespons @api.route("/signature/provenance", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.SignatureProvenanceArgs) @quart_schema.validate_response(models.api.SignatureProvenanceResults, 200) @@ -1136,7 +1151,7 @@ async def signature_provenance(data: models.api.SignatureProvenanceArgs) -> Dict @api.route("/ssh-key/add", methods=["POST"]) @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.SshKeyAddArgs) @quart_schema.validate_response(models.api.SshKeyAddResults, 201) @@ -1158,7 +1173,7 @@ async def ssh_key_add(data: models.api.SshKeyAddArgs) -> DictResponse: @api.route("/ssh-key/delete", methods=["POST"]) @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.SshKeyDeleteArgs) @quart_schema.validate_response(models.api.SshKeyDeleteResults, 201) @@ -1237,7 +1252,7 @@ async def tasks_list(query_args: models.api.TasksListQuery) -> DictResponse: @api.route("/user/info") @rate_limiter.rate_limit(10, datetime.timedelta(hours=1)) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_response(models.api.UserInfoResults, 200) async def user_info() -> DictResponse: @@ -1295,7 +1310,7 @@ async def users_list() -> DictResponse: # TODO: Add endpoints to allow users to vote @api.route("/vote/resolve", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.VoteResolveArgs) @quart_schema.validate_response(models.api.VoteResolveResults, 200) @@ -1330,7 +1345,7 @@ async def vote_resolve(data: models.api.VoteResolveArgs) -> DictResponse: @api.route("/vote/start", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.VoteStartArgs) @quart_schema.validate_response(models.api.VoteStartResults, 201) @@ -1373,7 +1388,7 @@ async def vote_start(data: models.api.VoteStartArgs) -> DictResponse: @api.route("/vote/tabulate", methods=["POST"]) [email protected] [email protected](token_types=["atr"]) @quart_schema.security_scheme([{"BearerAuth": []}]) @quart_schema.validate_request(models.api.VoteTabulateArgs) @quart_schema.validate_response(models.api.VoteTabulateResults, 200) @@ -1409,9 +1424,15 @@ async def vote_tabulate(data: models.api.VoteTabulateArgs) -> DictResponse: ).model_dump(), 200 -def _jwt_asf_uid() -> str: - claims = getattr(quart.g, "jwt_claims", {}) - asf_uid = claims.get("sub") +def _jwt_asf_uid(atr: bool = True, github: bool = False) -> str: + asf_uid = None + if atr: + claims = getattr(quart.g, "jwt_claims", {}) + asf_uid = claims.get("sub") + elif github: + asf_uid = getattr(quart.g, "github_asf_uid", "") + if asf_uid is None: + raise exceptions.Unauthorized("No ASF UID found in request") if not isinstance(asf_uid, str): raise base.ASFQuartException(f"Invalid token subject: {asf_uid!r}, type: {type(asf_uid)}", errorcode=401) return asf_uid diff --git a/atr/db/interaction.py b/atr/db/interaction.py index 35e2f30..488ce22 100644 --- a/atr/db/interaction.py +++ b/atr/db/interaction.py @@ -19,7 +19,6 @@ import contextlib import datetime import enum from collections.abc import AsyncGenerator, Sequence -from typing import Any, Final import packaging.version as version import sqlalchemy @@ -27,8 +26,6 @@ import sqlalchemy.orm as orm import sqlmodel import atr.db as db -import atr.jwtoken as jwtoken -import atr.ldap as ldap import atr.log as log import atr.models.results as results import atr.models.sql as sql @@ -36,8 +33,6 @@ import atr.user as user import atr.util as util import atr.web as web -_GITHUB_TRUSTED_ROLE_NID: Final[int] = 254436773 - class ApacheUserMissingError(RuntimeError): def __init__(self, message: str, fingerprint: str | None, primary_uid: str | None) -> None: @@ -388,20 +383,9 @@ async def tasks_ongoing_revision( return task_count, latest_revision -async def trusted_jwt(publisher: str, jwt: str, phase: TrustedProjectPhase) -> tuple[dict[str, Any], str, sql.Project]: - payload, asf_uid = await validate_trusted_jwt(publisher, jwt) - # JWT could be for an ASF user or the trusted role, but we need a user here. - if asf_uid is None: - raise InteractionError("ASF user not found") - project = await _trusted_project(payload["repository"], payload["workflow_ref"], phase) - return payload, asf_uid, project - - -async def trusted_jwt_for_dist( - publisher: str, jwt: str, asf_uid: str, phase: TrustedProjectPhase, project_name: str, version_name: str -) -> tuple[dict[str, Any], str, sql.Project, sql.Release]: - payload, _asf_uid = await validate_trusted_jwt(publisher, jwt) - # payload, asf_uid, project = await trusted_jwt(publisher, jwt, phase) +async def check_release_phase( + phase: TrustedProjectPhase, project_name: str, version_name: str +) -> tuple[sql.Project, sql.Release]: async with db.session() as db_data: project = await db_data.project(name=project_name, _committee=True).demand( InteractionError(f"Project {project_name} does not exist") @@ -416,7 +400,7 @@ async def trusted_jwt_for_dist( if (phase == TrustedProjectPhase.FINISH) and (release.phase != sql.ReleasePhase.RELEASE_PREVIEW): raise InteractionError(f"Release {version_name} is not in finish phase") - return payload, asf_uid, project, release + return project, release async def unfinished_releases(asfuid: str) -> list[tuple[str, str, list[sql.Release]]]: @@ -479,17 +463,6 @@ async def user_projects(asf_uid: str, caller_data: db.Session | None = None) -> return [(p.name, p.display_name) for p in projects] -async def validate_trusted_jwt(publisher: str, jwt: str) -> tuple[dict[str, Any], str | None]: - if publisher != "github": - raise InteractionError(f"Publisher {publisher} not supported") - payload = await jwtoken.verify_github_oidc(jwt) - if int(payload["actor_id"]) != _GITHUB_TRUSTED_ROLE_NID: - asf_uid = await ldap.github_to_apache(payload["actor_id"]) - else: - asf_uid = None - return payload, asf_uid - - async def wait_for_task( task: sql.Task, caller_data: db.Session | None = None, @@ -514,7 +487,7 @@ async def wait_for_task( return False -async def _trusted_project(repository: str, workflow_ref: str, phase: TrustedProjectPhase) -> sql.Project: +async def trusted_project(repository: str, workflow_ref: str, phase: TrustedProjectPhase) -> sql.Project: # Debugging log.info(f"GitHub OIDC JWT payload: {repository} {workflow_ref}") repository_name, workflow_path = _trusted_project_checks(repository, workflow_ref) diff --git a/atr/jwtoken.py b/atr/jwtoken.py index 8f3b7ee..dc67ff8 100644 --- a/atr/jwtoken.py +++ b/atr/jwtoken.py @@ -28,6 +28,7 @@ import jwt import quart import atr.config as config +import atr.ldap as ldap import atr.log as log _ALGORITHM: Final[str] = "HS256" @@ -41,6 +42,8 @@ _GITHUB_OIDC_EXPECTED: Final[dict[str, str]] = { "runner_environment": "github-hosted", } _GITHUB_OIDC_ISSUER: Final[str] = "https://token.actions.githubusercontent.com" +_GITHUB_TOKEN_FIELD: Final[str] = "jwt" +_GITHUB_TRUSTED_ROLE_NID: Final[int] = 254436773 _JWT_SECRET_KEY: Final[str] = config.get().JWT_SECRET_KEY if TYPE_CHECKING: @@ -60,23 +63,33 @@ def issue(uid: str, *, ttl: int = 90 * 60) -> str: return jwt.encode(payload, _JWT_SECRET_KEY, algorithm=_ALGORITHM) -def require[**P, R](func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Awaitable[R]]: - @functools.wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: - token = _extract_bearer_token(quart.request) - try: - claims = verify(token) - except jwt.ExpiredSignatureError as exc: - raise base.ASFQuartException("Token has expired", errorcode=401) from exc - except jwt.InvalidTokenError as exc: - raise base.ASFQuartException("Invalid Bearer JWT format", errorcode=401) from exc - except jwt.PyJWTError as exc: - raise base.ASFQuartException(f"Invalid Bearer JWT: {exc}", errorcode=401) from exc +def require[**P, R](*, token_types=None) -> Callable[[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Awaitable[R]]]: + if token_types is None: + token_types = ["atr"] - quart.g.jwt_claims = claims - return await func(*args, **kwargs) + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Awaitable[R]]: + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + errors: list[str] = [] + claims = None + gh_asf_uid = None - return wrapper + for token_type in token_types: + claims, gh_asf_uid = await _try_verify_token(token_type, quart.request, errors) + if claims is not None: + break + + if claims is None: + error_msg = "; ".join(errors) if errors else "Authentication required" + raise base.ASFQuartException(error_msg, errorcode=401) + + quart.g.jwt_claims = claims + quart.g.github_asf_uid = gh_asf_uid + return await func(*args, **kwargs) + + return wrapper + + return decorator def unverified_header_and_payload(jwt_value: str) -> dict[str, Any]: @@ -163,3 +176,56 @@ def _extract_bearer_token(request: quart.Request) -> str: "Authentication required. Please provide a valid Bearer token in the Authorization header", errorcode=401 ) return token + + +async def _extract_token_from_body(request: quart.Request, field: str) -> tuple[str, dict[str, Any]]: + try: + body = await request.get_json() + except Exception as exc: + raise base.ASFQuartException("Invalid JSON in request body", errorcode=400) from exc + + if not body: + raise base.ASFQuartException("Request body is required", errorcode=400) + + token = body.get(field) + if not token: + raise base.ASFQuartException(f"Missing '{field}' field in request body", errorcode=400) + + if not isinstance(token, str): + raise base.ASFQuartException(f"'{field}' must be a string", errorcode=400) + + return token, body + + +async def _try_verify_token( + token_type: str, request: quart.Request, errors: list[str] +) -> tuple[dict[str, Any] | None, str | None] | tuple[None, None]: + try: + if token_type == "atr": + token = _extract_bearer_token(request) + return verify(token), None + if token_type == "github": + token, body = await _extract_token_from_body(request, _GITHUB_TOKEN_FIELD) + github_payload, asf_uid = await _validate_trusted_jwt(body.get("publisher", ""), token) + return github_payload, asf_uid + raise RuntimeError(f"Invalid token type: {token_type}") + except jwt.ExpiredSignatureError: + errors.append(f"{token_type}: Token has expired") + except jwt.InvalidTokenError: + errors.append(f"{token_type}: Invalid Bearer JWT format") + except jwt.PyJWTError as exc: + errors.append(f"{token_type}: Invalid Bearer JWT: {exc!s}") + except Exception as exc: + errors.append(f"{token_type}: {exc!s}") + return None, None + + +async def _validate_trusted_jwt(publisher: str, token: str) -> tuple[dict[str, Any], str | None]: + if publisher != "github": + raise jwt.InvalidTokenError(f"Publisher {publisher} not supported") + payload = await verify_github_oidc(token) + if int(payload["actor_id"]) != _GITHUB_TRUSTED_ROLE_NID: + asf_uid = await ldap.github_to_apache(payload["actor_id"]) + else: + asf_uid = None + return payload, asf_uid --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
