This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch execution-api-scope-infra in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 830e56dbedde5e56bf20d2cffb8335571f0a1718 Author: Ash Berlin-Taylor <[email protected]> AuthorDate: Fri Feb 27 15:10:58 2026 +0000 Restructure Execution API security to better use FastAPI's Security scopes Before this change, `JWTBearer` in deps.py does everything: crypto validation, sub-claim matching, and it runs twice per request on ti:self routes because FastAPI includes scopes in dependency cache keys for `HTTPBearer` subclasses, defeating dedup. In a PR that is already created (but not yet merged) we want per-endpoint token type policies (e.g. the /run endpoint will need to accept workload tokens while other routes stay execution-only). This change is the "foundation" that enables that to work in a nice, clear fashion. `SecurityScopes` can't express this directly because FastAPI resolves outer router deps before inner ones -- a `token:workload` scope on an endpoint needs to *relax* the default restriction, but `SecurityScopes` only accumulate additively. The fix is a new security.py with a three-layer split: - `JWTBearer` (`_jwt_bearer`) now does only crypto validation and caches the result on the ASGI request scope. It never looks at scopes or token types. - `require_auth` is a plain function (not an `HTTPBearer` subclass) used via `Security(require_auth)` on routers. Because plain functions have `_uses_scopes=False` in FastAPI's dependency system, `_jwt_bearer` (its sub-dep) deduplicates correctly across multiple Security resolutions. It enforces `ti:self` via `SecurityScopes` and reads allowed token types from the matched route object. - `ExecutionAPIRoute` is a custom `APIRoute` subclass that precomputes `allowed_token_types` from `token:*` Security scopes at route registration time — after `include_router` has merged all parent and child dependencies. This sidesteps the resolution ordering problem entirely. 
To opt a route into workload tokens, it's now a one-liner: ```python @ti_id_router.patch( "/{task_instance_id}/run", dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])], ) ``` Nothing uses the workload-scoped tokens just yet -- this PR lays the foundation; a follow-up PR will add token:workload to /run. Also cleaned up the module boundaries: security.py owns all auth-related deps (CurrentTIToken, get_team_name_dep, require_auth); deps.py is just the svcs DepContainer. Renamed JWTBearerDep to CurrentTIToken to match the FastAPI current_user convention. I tried _lots_ of different approaches to get this merge/override behaviour, and the cleanest was a custom route class --- .../airflow/api_fastapi/execution_api/AGENTS.md | 4 + .../src/airflow/api_fastapi/execution_api/app.py | 21 +- .../src/airflow/api_fastapi/execution_api/deps.py | 90 +------- .../api_fastapi/execution_api/routes/__init__.py | 6 +- .../execution_api/routes/connections.py | 4 +- .../api_fastapi/execution_api/routes/health.py | 2 +- .../execution_api/routes/task_instances.py | 10 +- .../api_fastapi/execution_api/routes/variables.py | 4 +- .../api_fastapi/execution_api/routes/xcoms.py | 4 +- .../airflow/api_fastapi/execution_api/security.py | 243 +++++++++++++++++++++ .../unit/api_fastapi/execution_api/conftest.py | 60 +++-- .../unit/api_fastapi/execution_api/test_app.py | 17 ++ .../api_fastapi/execution_api/test_security.py | 136 ++++++++++++ .../versions/head/test_task_instances.py | 116 ++++++++-- .../execution_api/versions/head/test_variables.py | 4 +- .../execution_api/versions/head/test_xcoms.py | 4 +- 16 files changed, 555 insertions(+), 170 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md b/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md index 39e08334573..32500df182e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md +++ b/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md @@ 
-63,3 +63,7 @@ Adding a new Execution API feature touches multiple packages. All of these must - Triggerer handler: `airflow-core/src/airflow/jobs/triggerer_job_runner.py` - Task SDK generated models: `task-sdk/src/airflow/sdk/api/datamodels/_generated.py` - Full versioning guide: [`contributing-docs/19_execution_api_versioning.rst`](../../../../contributing-docs/19_execution_api_versioning.rst) + +## Token Scope Infrastructure + +Token types (`"execution"`, `"workload"`), route-level enforcement via `ExecutionAPIRoute` + `require_auth`, and the `ti:self` path-parameter validation are documented in the module docstring of `security.py`. diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py b/airflow-core/src/airflow/api_fastapi/execution_api/app.py index ac0d8012a90..f948790730b 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py @@ -220,6 +220,15 @@ class CadwynWithOpenAPICustomization(Cadwyn): if prop.get("type") == "string" and (const := prop.pop("const", None)): prop["enum"] = [const] + # Remove internal x-airflow-* extension fields from OpenAPI spec + # These are used for runtime validation but shouldn't be exposed in the public API + for path_item in openapi_schema.get("paths", {}).values(): + for operation in path_item.values(): + if isinstance(operation, dict): + keys_to_remove = [key for key in operation.keys() if key.startswith("x-airflow-")] + for key in keys_to_remove: + del operation[key] + return openapi_schema @@ -304,23 +313,21 @@ class InProcessExecutionAPI: if not self._app: from airflow.api_fastapi.common.dagbag import create_dag_bag from airflow.api_fastapi.execution_api.app import create_task_execution_api_app - from airflow.api_fastapi.execution_api.deps import ( - JWTBearerDep, - JWTBearerTIPathDep, - ) + from airflow.api_fastapi.execution_api.datamodels.token import TIToken from airflow.api_fastapi.execution_api.routes.connections import 
has_connection_access from airflow.api_fastapi.execution_api.routes.variables import has_variable_access from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.api_fastapi.execution_api.security import _jwt_bearer self._app = create_task_execution_api_app() # Set up dag_bag in app state for dependency injection self._app.state.dag_bag = create_dag_bag() - async def always_allow(): ... + async def always_allow(): + return TIToken(id="00000000-0000-0000-0000-000000000000", claims={"scope": "execution"}) - self._app.dependency_overrides[JWTBearerDep.dependency] = always_allow - self._app.dependency_overrides[JWTBearerTIPathDep.dependency] = always_allow + self._app.dependency_overrides[_jwt_bearer] = always_allow self._app.dependency_overrides[has_connection_access] = always_allow self._app.dependency_overrides[has_variable_access] = always_allow self._app.dependency_overrides[has_xcom_access] = always_allow diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py index 9fc8c30cb92..192309a8e40 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py @@ -18,23 +18,8 @@ # Disable future annotations in this file to work around https://github.com/fastapi/fastapi/issues/13056 # ruff: noqa: I002 -from typing import Any - -import structlog import svcs -from fastapi import Depends, HTTPException, Request, status -from fastapi.security import HTTPBearer -from sqlalchemy import select - -from airflow.api_fastapi.auth.tokens import JWTValidator -from airflow.api_fastapi.common.db.common import AsyncSessionDep -from airflow.api_fastapi.execution_api.datamodels.token import TIToken -from airflow.configuration import conf -from airflow.models import DagModel, TaskInstance -from airflow.models.dagbundle import DagBundleModel -from airflow.models.team import Team - -log = 
structlog.get_logger(logger_name=__name__) +from fastapi import Depends, Request # See https://github.com/fastapi/fastapi/issues/13056 @@ -44,76 +29,3 @@ async def _container(request: Request): DepContainer: svcs.Container = Depends(_container) - - -class JWTBearer(HTTPBearer): - """ - A FastAPI security dependency that validates JWT tokens using for the Execution API. - - This will validate the tokens are signed and that the ``sub`` is a UUID, but nothing deeper than that. - - The dependency result will be an `TIToken` object containing the ``id`` UUID (from the ``sub``) and other - validated claims. - """ - - def __init__( - self, - path_param_name: str | None = None, - required_claims: dict[str, Any] | None = None, - ): - super().__init__(auto_error=False) - self.path_param_name = path_param_name - self.required_claims = required_claims or {} - - async def __call__( # type: ignore[override] - self, - request: Request, - services=DepContainer, - ) -> TIToken | None: - creds = await super().__call__(request) - if not creds: - raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") - - validator: JWTValidator = await services.aget(JWTValidator) - - try: - # Example: Validate "task_instance_id" component of the path matches the one in the token - if self.path_param_name: - id = request.path_params[self.path_param_name] - validators: dict[str, Any] = { - **self.required_claims, - "sub": {"essential": True, "value": id}, - } - else: - validators = self.required_claims - claims = await validator.avalidated_claims(creds.credentials, validators) - return TIToken(id=claims["sub"], claims=claims) - except Exception as err: - log.warning( - "Failed to validate JWT", - exc_info=True, - token=creds.credentials, - ) - raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") - - -JWTBearerDep: TIToken = Depends(JWTBearer()) - -# This checks that the UUID in the url matches the one in the token for us. 
-JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id")) - - -async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) -> str | None: - """Return the team name associated to the task (if any).""" - if not conf.getboolean("core", "multi_team"): - return None - - stmt = ( - select(Team.name) - .select_from(TaskInstance) - .join(DagModel, DagModel.dag_id == TaskInstance.dag_id) - .join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name) - .join(DagBundleModel.teams) - .where(TaskInstance.id == token.id) - ) - return await session.scalar(stmt) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py index 562b8588fbf..aeef4d092b1 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py @@ -17,9 +17,8 @@ from __future__ import annotations from cadwyn import VersionedAPIRouter -from fastapi import APIRouter +from fastapi import APIRouter, Security -from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.api_fastapi.execution_api.routes import ( asset_events, assets, @@ -32,12 +31,13 @@ from airflow.api_fastapi.execution_api.routes import ( variables, xcoms, ) +from airflow.api_fastapi.execution_api.security import require_auth execution_api_router = APIRouter() execution_api_router.include_router(health.router, prefix="/health", tags=["Health"]) # _Every_ single endpoint under here must be authenticated. 
Some do further checks on top of these -authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) # type: ignore[list-item] +authenticated_router = VersionedAPIRouter(dependencies=[Security(require_auth)]) # type: ignore[list-item] authenticated_router.include_router(assets.router, prefix="/assets", tags=["Assets"]) authenticated_router.include_router(asset_events.router, prefix="/asset-events", tags=["Asset Events"]) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py index 44cc3bfbd79..a7bb9959c6d 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py @@ -23,14 +23,14 @@ from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, Path, status from airflow.api_fastapi.execution_api.datamodels.connection import ConnectionResponse -from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep +from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep from airflow.exceptions import AirflowNotFoundException from airflow.models.connection import Connection async def has_connection_access( connection_id: str = Path(), - token=JWTBearerDep, + token=CurrentTIToken, ) -> bool: """Check if the task has access to the connection.""" # TODO: Placeholder for actual implementation diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/health.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/health.py index d808f51e1db..fb6e40d66a5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/health.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/health.py @@ -20,7 +20,7 @@ from __future__ import annotations from fastapi import APIRouter from fastapi.responses import JSONResponse -from 
airflow.api_fastapi.execution_api.deps import DepContainer +from airflow.api_fastapi.execution_api.security import DepContainer router = APIRouter() diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index f22d7c12585..f4c8ed082d5 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -28,7 +28,7 @@ from uuid import UUID import attrs import structlog from cadwyn import VersionedAPIRouter -from fastapi import Body, HTTPException, Query, status +from fastapi import Body, HTTPException, Query, Security, status from pydantic import JsonValue from sqlalchemy import func, or_, tuple_, update from sqlalchemy.engine import CursorResult @@ -59,7 +59,7 @@ from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( TISuccessStatePayload, TITerminalStatePayload, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep +from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, require_auth from airflow.exceptions import TaskNotFound from airflow.models.asset import AssetActive from airflow.models.dag import DagModel @@ -78,10 +78,10 @@ if TYPE_CHECKING: router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( + route_class=ExecutionAPIRoute, dependencies=[ - # This checks that the UUID in the url matches the one in the token for us. 
- JWTBearerTIPathDep - ] + Security(require_auth, scopes=["ti:self"]), + ], ) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py index 5621b6cd081..1e2e2058932 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py @@ -26,14 +26,14 @@ from airflow.api_fastapi.execution_api.datamodels.variable import ( VariablePostBody, VariableResponse, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerDep, get_team_name_dep +from airflow.api_fastapi.execution_api.security import CurrentTIToken, get_team_name_dep from airflow.models.variable import Variable async def has_variable_access( request: Request, variable_key: str = Path(), - token=JWTBearerDep, + token=CurrentTIToken, ): """Check if the task has access to the variable.""" write = request.method not in {"GET", "HEAD", "OPTIONS"} diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py index ec77b64dc44..9b83c40db5e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py @@ -32,7 +32,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import ( XComSequenceIndexResponse, XComSequenceSliceResponse, ) -from airflow.api_fastapi.execution_api.deps import JWTBearerDep +from airflow.api_fastapi.execution_api.security import CurrentTIToken from airflow.models.taskmap import TaskMap from airflow.models.xcom import XComModel from airflow.utils.db import get_query_count @@ -44,7 +44,7 @@ async def has_xcom_access( task_id: str, xcom_key: Annotated[str, Path(alias="key", min_length=1)], request: Request, - token=JWTBearerDep, + token=CurrentTIToken, ) -> bool: """Check if the task has access to the XCom.""" # TODO: 
Placeholder for actual implementation diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/security.py b/airflow-core/src/airflow/api_fastapi/execution_api/security.py new file mode 100644 index 00000000000..6819d44a2e7 --- /dev/null +++ b/airflow-core/src/airflow/api_fastapi/execution_api/security.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Execution API security: JWT validation, token scopes, and route-level access control. + +Token types (``TokenType``): + +``"execution"`` + Default scope, accepted by all endpoints. Short-lived, automatically + refreshed by ``JWTReissueMiddleware``. + +``"workload"`` + Restricted scope, only accepted on routes that opt in via + ``Security(require_auth, scopes=["token:workload"])``. + +Tokens without a ``scope`` claim default to ``"execution"`` for backwards +compatibility (``claims.setdefault("scope", "execution")``). + +Enforcement flow: + 1. ``JWTBearer.__call__`` validates the JWT once per request (crypto + + signature verification), caching the result on the ASGI request scope. + Subsequent FastAPI dependency resolutions and Cadwyn replays return + the cache. + 2. ``require_auth`` is the Security dependency on routers. 
It receives + the token from ``JWTBearer`` and enforces: + - Token type against the route's ``allowed_token_types`` (precomputed + by ``ExecutionAPIRoute`` from ``token:*`` Security scopes). + - ``ti:self`` scope — checks that the JWT ``sub`` matches the + ``{task_instance_id}`` path parameter. + 3. ``ExecutionAPIRoute`` precomputes ``allowed_token_types`` from + ``token:*`` Security scopes at route registration time. Routes + without explicit ``token:*`` scopes default to execution-only. + +Why ``ExecutionAPIRoute`` is needed: + FastAPI resolves router-level ``Security()`` dependencies from outermost + to innermost. A ``token:workload`` scope on an inner endpoint would need + to *relax* the outer router's default execution-only restriction, but + ``SecurityScopes`` only accumulate additively — an outer dependency + cannot see scopes declared by inner ones. ``ExecutionAPIRoute`` solves + this by inspecting the **merged** dependency list at route registration + time (after ``include_router`` has combined all parent and child + dependencies) and precomputing the full ``allowed_token_types`` set. + ``require_auth`` then reads this precomputed set from the matched route + at request time, avoiding the ordering problem entirely. + + Any router whose routes need non-default token type policies must use + ``route_class=ExecutionAPIRoute``. Routers that only need the default + (execution-only) can use the standard route class — ``require_auth`` + falls back to ``{"execution"}`` when the attribute is absent. 
+""" + +# Disable future annotations in this file to work around https://github.com/fastapi/fastapi/issues/13056 +# ruff: noqa: I002 + +from typing import Any, Literal, get_args + +import structlog +from fastapi import Depends, HTTPException, Request, status +from fastapi.routing import APIRoute +from fastapi.security import HTTPBearer, SecurityScopes +from sqlalchemy import select + +from airflow.api_fastapi.auth.tokens import JWTValidator +from airflow.api_fastapi.common.db.common import AsyncSessionDep +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.deps import DepContainer + +log = structlog.get_logger(logger_name=__name__) + +TokenType = Literal["execution", "workload"] + +VALID_TOKEN_TYPES: frozenset[str] = frozenset(get_args(TokenType)) + +_REQUEST_SCOPE_TOKEN_KEY = "ti_token" + + +class JWTBearer(HTTPBearer): + """ + Validates JWT tokens for the Execution API. + + Performs cryptographic validation once per request and caches the result + on the ASGI request scope. Subsequent resolutions (FastAPI dependency + dedup or Cadwyn replays) return the cached token. + + This dependency handles ONLY crypto validation and token construction. + All route-specific authorization (token type, ti:self) is handled by + ``require_auth``. + """ + + def __init__(self, required_claims: dict[str, Any] | None = None): + super().__init__(auto_error=False) + self.required_claims = required_claims or {} + + async def __call__( # type: ignore[override] + self, + request: Request, + services=DepContainer, + ) -> TIToken | None: + # Return cached token (handles both FastAPI dependency dedup and Cadwyn replays). + if cached := request.scope.get(_REQUEST_SCOPE_TOKEN_KEY): + return cached + + # First resolution — full cryptographic validation. 
+ creds = await super().__call__(request) + if not creds: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing auth token") + + validator: JWTValidator = await services.aget(JWTValidator) + + try: + claims = await validator.avalidated_claims(creds.credentials, dict(self.required_claims)) + except Exception as err: + log.warning("Failed to validate JWT", exc_info=True, token=creds.credentials) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=f"Invalid auth token: {err}") + + claims.setdefault("scope", "execution") + + token = TIToken(id=claims["sub"], claims=claims) + request.scope[_REQUEST_SCOPE_TOKEN_KEY] = token + return token + + +_jwt_bearer = JWTBearer() + + +async def require_auth( + security_scopes: SecurityScopes, + request: Request, + token: TIToken = Depends(_jwt_bearer), +) -> TIToken: + """ + Security dependency that enforces token type and ``ti:self`` scope. + + Used via ``Security(require_auth)`` on routers. ``SecurityScopes`` are + accumulated by FastAPI from all parent ``Security()`` declarations. + + Token type enforcement reads ``route.allowed_token_types`` (precomputed + by ``ExecutionAPIRoute``) or defaults to ``{"execution"}``. + """ + token_scope = token.claims.get("scope", "execution") + + if token_scope not in VALID_TOKEN_TYPES: + log.warning("Invalid token scope in claims", token_scope=token_scope, path=request.url.path) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Invalid token scope: {token_scope}", + ) + + route = request.scope.get("route") + allowed_token_types = getattr(route, "allowed_token_types", frozenset({"execution"})) + + if token_scope not in allowed_token_types: + log.warning( + "Token type not allowed for endpoint", + token_scope=token_scope, + allowed_types=sorted(allowed_token_types), + path=request.url.path, + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Token type '{token_scope}' not allowed for this endpoint. 
" + f"Allowed types: {', '.join(sorted(allowed_token_types))}", + ) + + if "ti:self" in security_scopes.scopes: + ti_self_id = str(request.path_params["task_instance_id"]) + if str(token.id) != ti_self_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Token subject does not match task instance ID", + ) + + return token + + +CurrentTIToken: TIToken = Depends(require_auth) + + +class ExecutionAPIRoute(APIRoute): + """ + Custom route class that precomputes allowed token types from Security scopes. + + Scopes prefixed with ``token:`` (e.g., ``token:execution``, ``token:workload``) + are extracted at route registration time and stored as ``allowed_token_types``. + If no ``token:*`` scopes are declared, defaults to ``{"execution"}``. + + ``require_auth`` reads ``route.allowed_token_types`` at request time. + """ + + allowed_token_types: frozenset[str] + + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + from fastapi.params import Security as SecurityParam + + all_scopes: set[str] = set() + for dep in self.dependencies: + if isinstance(dep, SecurityParam): + all_scopes.update(dep.scopes or []) + + token_scopes = {s.removeprefix("token:") for s in all_scopes if s.startswith("token:")} + + if token_scopes and not token_scopes <= VALID_TOKEN_TYPES: + invalid = token_scopes - VALID_TOKEN_TYPES + raise ValueError(f"Invalid token types in Security scopes: {invalid}") + + self.allowed_token_types = frozenset(token_scopes) if token_scopes else frozenset({"execution"}) + + +async def get_team_name_dep(session: AsyncSessionDep, token=CurrentTIToken) -> str | None: + """Return the team name associated to the task (if any).""" + from airflow.configuration import conf + from airflow.models import DagModel, TaskInstance + from airflow.models.dagbundle import DagBundleModel + from airflow.models.team import Team + + if not conf.getboolean("core", "multi_team"): + return None + + stmt = ( + select(Team.name) + 
.select_from(TaskInstance) + .join(DagModel, DagModel.dag_id == TaskInstance.dag_id) + .join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name) + .join(DagBundleModel.teams) + .where(TaskInstance.id == token.id) + ) + return await session.scalar(stmt) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py index 9e26937b63c..71b40abab40 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py @@ -16,51 +16,41 @@ # under the License. from __future__ import annotations -from unittest.mock import AsyncMock - import pytest +from fastapi import FastAPI, Request from fastapi.testclient import TestClient from airflow.api_fastapi.app import cached_app -from airflow.api_fastapi.auth.tokens import JWTValidator -from airflow.api_fastapi.execution_api.app import lifespan +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.security import _jwt_bearer + + +def _get_execution_api_app(root_app: FastAPI) -> FastAPI: + """Find the mounted execution API sub-app.""" + for route in root_app.routes: + if hasattr(route, "path") and route.path == "/execution": + return route.app + raise RuntimeError("Execution API sub-app not found") @pytest.fixture -def client(request: pytest.FixtureRequest): - app = cached_app(apps="execution") +def exec_app(client): + """Return the execution API sub-app.""" + return _get_execution_api_app(client.app) - with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: - auth = AsyncMock(spec=JWTValidator) - # Create a side_effect function that dynamically extracts the task instance ID from validators - def smart_validated_claims(cred, validators=None): - # Extract task instance ID from validators if present - # This handles the JWTBearerTIPathDep case where the validator contains the task ID from the path - if 
( - validators - and "sub" in validators - and isinstance(validators["sub"], dict) - and "value" in validators["sub"] - ): - return { - "sub": validators["sub"]["value"], - "exp": 9999999999, # Far future expiration - "iat": 1000000000, # Past issuance time - "aud": "test-audience", - } [email protected] +def client(request: pytest.FixtureRequest): + app = cached_app(apps="execution") + exec_app = _get_execution_api_app(app) - # For other cases (like JWTBearerDep) where no specific validators are provided - # Return a default UUID with all required claims - return { - "sub": "00000000-0000-0000-0000-000000000000", - "exp": 9999999999, # Far future expiration - "iat": 1000000000, # Past issuance time - "aud": "test-audience", - } + async def mock_jwt_bearer(request: Request): + ti_id = request.path_params.get("task_instance_id", "00000000-0000-0000-0000-000000000000") + return TIToken(id=str(ti_id), claims={"sub": str(ti_id), "scope": "execution"}) - # Set the side_effect for avalidated_claims - auth.avalidated_claims.side_effect = smart_validated_claims - lifespan.registry.register_value(JWTValidator, auth) + exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer + with TestClient(app, headers={"Authorization": "Bearer fake"}) as client: yield client + + exec_app.dependency_overrides.pop(_jwt_bearer, None) diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py index 640d920137c..b0cb1d85c2e 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py @@ -43,6 +43,23 @@ def test_access_api_contract(client): assert response.headers["airflow-api-version"] == bundle.versions[0].value +def test_ti_self_routes_have_task_instance_id_param(client): + """Every route with ti:self scope must have a {task_instance_id} path parameter.""" + from fastapi.params import Security as SecurityParam + from 
fastapi.routing import APIRoute + + app = client.app + + for route in app.routes: + if not isinstance(route, APIRoute): + continue + for dep in route.dependencies: + if isinstance(dep, SecurityParam) and "ti:self" in (dep.scopes or []): + assert "task_instance_id" in route.dependant.path_param_names, ( + f"Route {route.path} has ti:self scope but no {{task_instance_id}} path parameter" + ) + + class TestCorrelationIdMiddleware: def test_correlation_id_echoed_in_response_headers(self, client): """Test that correlation-id from request is echoed back in response headers.""" diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py new file mode 100644 index 00000000000..080d11419fd --- /dev/null +++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py @@ -0,0 +1,136 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from fastapi import APIRouter, FastAPI, Request, Security +from fastapi.testclient import TestClient + +from airflow.api_fastapi.execution_api.datamodels.token import TIToken +from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute, _jwt_bearer, require_auth + + +class TestExecutionAPIRoute: + """Unit tests for ExecutionAPIRoute precomputing allowed_token_types from Security scopes.""" + + def test_defaults_to_execution_only(self): + route = ExecutionAPIRoute( + path="/test", + endpoint=lambda: None, + dependencies=[Security(require_auth)], + ) + assert route.allowed_token_types == frozenset({"execution"}) + + def test_extracts_token_scopes(self): + route = ExecutionAPIRoute( + path="/test", + endpoint=lambda: None, + dependencies=[ + Security(require_auth), + Security(require_auth, scopes=["token:execution", "token:workload"]), + ], + ) + assert route.allowed_token_types == frozenset({"execution", "workload"}) + + def test_ignores_non_token_scopes(self): + route = ExecutionAPIRoute( + path="/test", + endpoint=lambda: None, + dependencies=[ + Security(require_auth, scopes=["ti:self", "token:execution"]), + ], + ) + assert route.allowed_token_types == frozenset({"execution"}) + + def test_rejects_invalid_token_types(self): + with pytest.raises(ValueError, match="Invalid token types"): + ExecutionAPIRoute( + path="/test", + endpoint=lambda: None, + dependencies=[ + Security(require_auth, scopes=["token:bogus"]), + ], + ) + + +class TestTokenTypeScopeEnforcement: + """End-to-end: ExecutionAPIRoute + require_auth enforce token types via Security scopes.""" + + @pytest.fixture + def token_type_app(self): + """ + Mirrors the real router structure: an authenticated_router with Security(require_auth), + a child ti_id_router with ExecutionAPIRoute and ti:self, and a specific endpoint on that + router opting in to workload tokens via endpoint-level Security scopes. 
+ """ + app = FastAPI() + + authenticated_router = APIRouter(dependencies=[Security(require_auth)]) + ti_id_router = APIRouter( + route_class=ExecutionAPIRoute, + dependencies=[Security(require_auth, scopes=["ti:self"])], + ) + + @ti_id_router.get("/{task_instance_id}/state") + def default_endpoint(task_instance_id: str): + return {"ok": True} + + @ti_id_router.get( + "/{task_instance_id}/run", + dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])], + ) + def workload_endpoint(task_instance_id: str): + return {"ok": True} + + authenticated_router.include_router(ti_id_router, prefix="/task-instances") + app.include_router(authenticated_router) + + return app + + TI_ID = "00000000-0000-0000-0000-000000000001" + + def _override_jwt(self, app, scope: str): + ti_id = self.TI_ID + + async def mock_jwt(request: Request): + return TIToken(id=ti_id, claims={"scope": scope}) + + app.dependency_overrides[_jwt_bearer] = mock_jwt + + def test_workload_token_rejected_on_default_route(self, token_type_app): + self._override_jwt(token_type_app, "workload") + client = TestClient(token_type_app) + + resp = client.get(f"/task-instances/{self.TI_ID}/state", headers={"Authorization": "Bearer fake"}) + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + + def test_workload_token_accepted_on_opted_in_route(self, token_type_app): + self._override_jwt(token_type_app, "workload") + client = TestClient(token_type_app) + + resp = client.get(f"/task-instances/{self.TI_ID}/run", headers={"Authorization": "Bearer fake"}) + assert resp.status_code == 200 + + def test_execution_token_accepted_on_both_routes(self, token_type_app): + self._override_jwt(token_type_app, "execution") + client = TestClient(token_type_app) + + state = client.get(f"/task-instances/{self.TI_ID}/state", headers={"Authorization": "Bearer fake"}) + run = client.get(f"/task-instances/{self.TI_ID}/run", headers={"Authorization": "Bearer fake"}) + 
assert state.status_code == 200 + assert run.status_code == 200 diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index ea1153f01cb..bd65b6f94ee 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -76,13 +76,16 @@ def _create_asset_aliases(session, num: int = 2) -> None: @pytest.fixture -def client_with_extra_route(): ... +def _use_real_jwt_bearer(exec_app): + """Remove the mock jwt_bearer override so the real JWTBearer.__call__ runs.""" + from airflow.api_fastapi.execution_api.security import _jwt_bearer + exec_app.dependency_overrides.pop(_jwt_bearer, None) -def test_id_matches_sub_claim(client, session, create_task_instance): - # Test that this is validated at the router level, so we don't have to test it in each component - # We validate it is set correctly, and test it once [email protected]("_use_real_jwt_bearer") +def test_id_matches_sub_claim(client, session, create_task_instance): + """Test that scope validation (ti:self) is enforced at the router level.""" ti = create_task_instance( task_id="test_ti_run_state_conflict_if_not_queued", state="queued", @@ -90,17 +93,10 @@ def test_id_matches_sub_claim(client, session, create_task_instance): session.commit() validator = mock.AsyncMock(spec=JWTValidator) - claims = {"sub": ti.id} - - def side_effect(cred, validators): - if not validators: - return claims - if str(validators["sub"]["value"]) != str(ti.id): - raise RuntimeError("Fake auth denied") - return claims - - validator.avalidated_claims.side_effect = side_effect - + validator.avalidated_claims.return_value = { + "sub": str(ti.id), + "scope": "execution", + } lifespan.registry.register_value(JWTValidator, validator) payload = { @@ -113,15 +109,10 @@ def 
test_id_matches_sub_claim(client, session, create_task_instance): resp = client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run", json=payload) assert resp.status_code == 403 - assert validator.avalidated_claims.call_args_list[1] == mock.call( - mock.ANY, {"sub": {"essential": True, "value": "9c230b40-da03-451d-8bd7-be30471be383"}} - ) validator.avalidated_claims.reset_mock() resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload) - assert resp.status_code == 200, resp.json() - validator.avalidated_claims.assert_awaited() @@ -2925,3 +2916,88 @@ class TestTIPatchRenderedMapIndex: ) assert response.status_code == 422 + + [email protected]("_use_real_jwt_bearer") +class TestTokenTypeValidation: + """Test token scope enforcement (workload vs execution).""" + + def test_workload_scope_rejected_on_default_endpoints(self, client, session, create_task_instance): + """workload scoped tokens should be rejected on endpoints without openapi_extra.""" + ti = create_task_instance(task_id="test_ti_run_heartbeat", state=State.RUNNING) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.side_effect = lambda cred, validators: { + "sub": str(ti.id), + "scope": "workload", + "exp": 9999999999, + "iat": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + payload = {"hostname": "test-host", "pid": 100} + resp = client.put(f"/execution/task-instances/{ti.id}/heartbeat", json=payload) + assert resp.status_code == 403 + assert "Token type 'workload' not allowed" in resp.json()["detail"] + + def test_execution_scope_accepted_on_all_endpoints(self, client, session, create_task_instance): + """execution scoped tokens should be able to call all endpoints.""" + ti = create_task_instance(task_id="test_ti_star", state=State.RUNNING) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.side_effect = lambda cred, validators: { + 
"sub": str(ti.id), + "scope": "execution", + "exp": 9999999999, + "iat": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} + resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + assert resp.status_code in [200, 204] + + def test_invalid_scope_value_rejected(self, client, session, create_task_instance): + """Tokens with unrecognized scope values should be rejected.""" + ti = create_task_instance(task_id="test_invalid_scope", state=State.QUEUED) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.side_effect = lambda cred, validators: { + "sub": str(ti.id), + "scope": "bogus:scope", + "exp": 9999999999, + "iat": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + payload = { + "state": "running", + "hostname": "test-host", + "unixname": "test-user", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + } + + resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload) + assert resp.status_code == 403 + assert "Invalid token scope" in resp.json()["detail"] + + def test_no_scope_defaults_to_execution(self, client, session, create_task_instance): + """Tokens without scope claim should default to 'execution'.""" + ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING) + session.commit() + + validator = mock.AsyncMock(spec=JWTValidator) + validator.avalidated_claims.side_effect = lambda cred, validators: { + "sub": str(ti.id), + "exp": 9999999999, + "iat": 1000000000, + } + lifespan.registry.register_value(JWTValidator, validator) + + payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"} + resp = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + assert resp.status_code in [200, 204] diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py index 59b206441de..93cd8ca672e 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py @@ -41,15 +41,15 @@ def setup_method(): @pytest.fixture def access_denied(client): - from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.api_fastapi.execution_api.routes.variables import has_variable_access + from airflow.api_fastapi.execution_api.security import CurrentTIToken last_route = client.app.routes[-1] assert isinstance(last_route, Mount) assert isinstance(last_route.app, FastAPI) exec_app = last_route.app - async def _(request: Request, variable_key: str, token=JWTBearerDep): + async def _(request: Request, variable_key: str, token=CurrentTIToken): await has_variable_access(request, variable_key, token) raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py index f805971bf52..29269d41054 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py @@ -48,8 +48,8 @@ def reset_db(): @pytest.fixture def access_denied(client): - from airflow.api_fastapi.execution_api.deps import JWTBearerDep from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access + from airflow.api_fastapi.execution_api.security import CurrentTIToken last_route = client.app.routes[-1] assert isinstance(last_route.app, FastAPI) @@ -61,7 +61,7 @@ def access_denied(client): run_id: str = Path(), task_id: str = Path(), xcom_key: str = Path(alias="key"), - token=JWTBearerDep, + token=CurrentTIToken, ): await has_xcom_access(dag_id, run_id, task_id, xcom_key, request, token) 
raise HTTPException(
