This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new df4cb30b116 Restructure Execution API security to better use FastAPI's
Security scopes (#62582)
df4cb30b116 is described below
commit df4cb30b116c8628afc465876e08d58f2bcb897b
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Tue Mar 10 13:33:18 2026 +0000
Restructure Execution API security to better use FastAPI's Security scopes
(#62582)
Before this change, `JWTBearer` in deps.py did everything: crypto validation
and sub-claim matching. It also ran twice per request on ti:self routes,
because FastAPI includes scopes in dependency cache keys for `HTTPBearer`
subclasses, defeating dedup.
An already-open (but not yet merged) PR needs per-endpoint token type
policies (e.g. the /run endpoint will need to accept workload tokens while
other routes stay execution-only). This change is the foundation that lets
that work in a clean, clear way.
`SecurityScopes` can't express this directly because FastAPI resolves outer
router deps before inner ones -- a `token:workload` scope on an endpoint needs
to *relax* the default restriction, but `SecurityScopes` only accumulate
additively.
The fix is a new security.py with a three-layer split:
- `JWTBearer` (`_jwt_bearer`) now does only crypto validation and caches the
  result on the ASGI request scope. It never looks at scopes or token types.
- `require_auth` is a plain function (not an `HTTPBearer` subclass) used via
  `Security(require_auth)` on routers. Because plain functions have
  `_uses_scopes=False` in FastAPI's dependency system, `_jwt_bearer` (its
  sub-dep) deduplicates correctly across multiple Security resolutions. It
  enforces `ti:self` via `SecurityScopes` and reads allowed token types from
  the matched route object.
- `ExecutionAPIRoute` is a custom `APIRoute` subclass that precomputes
  `allowed_token_types` from `token:*` Security scopes at route registration
  time — after `include_router` has merged all parent and child
  dependencies. This sidesteps the resolution ordering problem entirely.
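To opt a route into workload tokens, it's now a one-liner:
```python
@ti_id_router.patch(
    "/{task_instance_id}/run",
    dependencies=[Security(require_auth, scopes=["token:execution", "token:workload"])],
)
```
Note that `token:execution` is listed explicitly: once a route declares any
`token:*` scope, only the listed token types are accepted.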
Nothing uses workload-scoped tokens just yet -- this PR lays the foundation;
a follow-up PR will add `token:workload` to /run.
Also cleaned up the module boundaries: security.py owns all auth-related deps
(`CurrentTIToken`, `get_team_name_dep`, `require_auth`); deps.py is just the
svcs `DepContainer`. Renamed `JWTBearerDep` to `CurrentTIToken` to match
FastAPI's `current_user` convention.
I tried _lots_ of different approaches to get this merge/override behaviour,
and the cleanest was a custom route class.
---
.codespellignorelines | 1 +
.../airflow/api_fastapi/execution_api/AGENTS.md | 4 +
.../src/airflow/api_fastapi/execution_api/app.py | 26 ++-
.../src/airflow/api_fastapi/execution_api/deps.py | 90 +-------
.../api_fastapi/execution_api/routes/__init__.py | 6 +-
.../execution_api/routes/connections.py | 4 +-
.../execution_api/routes/task_instances.py | 10 +-
.../api_fastapi/execution_api/routes/variables.py | 4 +-
.../api_fastapi/execution_api/routes/xcoms.py | 4 +-
.../airflow/api_fastapi/execution_api/security.py | 243 +++++++++++++++++++++
.../unit/api_fastapi/execution_api/conftest.py | 60 +++--
.../unit/api_fastapi/execution_api/test_app.py | 17 ++
.../api_fastapi/execution_api/test_security.py | 138 ++++++++++++
.../versions/head/test_task_instances.py | 116 ++++++++--
.../execution_api/versions/head/test_variables.py | 4 +-
.../execution_api/versions/head/test_xcoms.py | 4 +-
16 files changed, 563 insertions(+), 168 deletions(-)
diff --git a/.codespellignorelines b/.codespellignorelines
index 1234698be30..5e8e3650862 100644
--- a/.codespellignorelines
+++ b/.codespellignorelines
@@ -4,3 +4,4 @@
The platform supports **C**reate, **R**ead, **U**pdate, and **D**elete
operations on most resources.
<pre><code>Code block\ndoes not\nrespect\nnewlines\n</code></pre>
"trough",
+ assert "task_instance_id" in route.dependant.path_param_names,
(
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md
b/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md
index 39e08334573..32500df182e 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/AGENTS.md
@@ -63,3 +63,7 @@ Adding a new Execution API feature touches multiple packages.
All of these must
- Triggerer handler: `airflow-core/src/airflow/jobs/triggerer_job_runner.py`
- Task SDK generated models:
`task-sdk/src/airflow/sdk/api/datamodels/_generated.py`
- Full versioning guide:
[`contributing-docs/19_execution_api_versioning.rst`](../../../../contributing-docs/19_execution_api_versioning.rst)
+
+## Token Scope Infrastructure
+
+Token types (`"execution"`, `"workload"`), route-level enforcement via
`ExecutionAPIRoute` + `require_auth`, and the `ti:self` path-parameter
validation are documented in the module docstring of `security.py`.
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index ac0d8012a90..c7a9593c3c8 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -220,6 +220,15 @@ class CadwynWithOpenAPICustomization(Cadwyn):
if prop.get("type") == "string" and (const :=
prop.pop("const", None)):
prop["enum"] = [const]
+ # Remove internal x-airflow-* extension fields from OpenAPI spec
+ # These are used for runtime validation but shouldn't be exposed in
the public API
+ for path_item in openapi_schema.get("paths", {}).values():
+ for operation in path_item.values():
+ if isinstance(operation, dict):
+ keys_to_remove = [key for key in operation.keys() if
key.startswith("x-airflow-")]
+ for key in keys_to_remove:
+ del operation[key]
+
return openapi_schema
@@ -304,23 +313,26 @@ class InProcessExecutionAPI:
if not self._app:
from airflow.api_fastapi.common.dagbag import create_dag_bag
from airflow.api_fastapi.execution_api.app import
create_task_execution_api_app
- from airflow.api_fastapi.execution_api.deps import (
- JWTBearerDep,
- JWTBearerTIPathDep,
- )
+ from airflow.api_fastapi.execution_api.datamodels.token import
TIToken
from airflow.api_fastapi.execution_api.routes.connections import
has_connection_access
from airflow.api_fastapi.execution_api.routes.variables import
has_variable_access
from airflow.api_fastapi.execution_api.routes.xcoms import
has_xcom_access
+ from airflow.api_fastapi.execution_api.security import _jwt_bearer
self._app = create_task_execution_api_app()
# Set up dag_bag in app state for dependency injection
self._app.state.dag_bag = create_dag_bag()
- async def always_allow(): ...
+ async def always_allow(request: Request):
+ from uuid import UUID
+
+ ti_id = UUID(
+ request.path_params.get("task_instance_id",
"00000000-0000-0000-0000-000000000000")
+ )
+ return TIToken(id=ti_id, claims={"scope": "execution"})
- self._app.dependency_overrides[JWTBearerDep.dependency] =
always_allow
- self._app.dependency_overrides[JWTBearerTIPathDep.dependency] =
always_allow
+ self._app.dependency_overrides[_jwt_bearer] = always_allow
self._app.dependency_overrides[has_connection_access] =
always_allow
self._app.dependency_overrides[has_variable_access] = always_allow
self._app.dependency_overrides[has_xcom_access] = always_allow
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
index 9fc8c30cb92..192309a8e40 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/deps.py
@@ -18,23 +18,8 @@
# Disable future annotations in this file to work around
https://github.com/fastapi/fastapi/issues/13056
# ruff: noqa: I002
-from typing import Any
-
-import structlog
import svcs
-from fastapi import Depends, HTTPException, Request, status
-from fastapi.security import HTTPBearer
-from sqlalchemy import select
-
-from airflow.api_fastapi.auth.tokens import JWTValidator
-from airflow.api_fastapi.common.db.common import AsyncSessionDep
-from airflow.api_fastapi.execution_api.datamodels.token import TIToken
-from airflow.configuration import conf
-from airflow.models import DagModel, TaskInstance
-from airflow.models.dagbundle import DagBundleModel
-from airflow.models.team import Team
-
-log = structlog.get_logger(logger_name=__name__)
+from fastapi import Depends, Request
# See https://github.com/fastapi/fastapi/issues/13056
@@ -44,76 +29,3 @@ async def _container(request: Request):
DepContainer: svcs.Container = Depends(_container)
-
-
-class JWTBearer(HTTPBearer):
- """
- A FastAPI security dependency that validates JWT tokens using for the
Execution API.
-
- This will validate the tokens are signed and that the ``sub`` is a UUID,
but nothing deeper than that.
-
- The dependency result will be an `TIToken` object containing the ``id``
UUID (from the ``sub``) and other
- validated claims.
- """
-
- def __init__(
- self,
- path_param_name: str | None = None,
- required_claims: dict[str, Any] | None = None,
- ):
- super().__init__(auto_error=False)
- self.path_param_name = path_param_name
- self.required_claims = required_claims or {}
-
- async def __call__( # type: ignore[override]
- self,
- request: Request,
- services=DepContainer,
- ) -> TIToken | None:
- creds = await super().__call__(request)
- if not creds:
- raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing auth token")
-
- validator: JWTValidator = await services.aget(JWTValidator)
-
- try:
- # Example: Validate "task_instance_id" component of the path
matches the one in the token
- if self.path_param_name:
- id = request.path_params[self.path_param_name]
- validators: dict[str, Any] = {
- **self.required_claims,
- "sub": {"essential": True, "value": id},
- }
- else:
- validators = self.required_claims
- claims = await validator.avalidated_claims(creds.credentials,
validators)
- return TIToken(id=claims["sub"], claims=claims)
- except Exception as err:
- log.warning(
- "Failed to validate JWT",
- exc_info=True,
- token=creds.credentials,
- )
- raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail=f"Invalid auth token: {err}")
-
-
-JWTBearerDep: TIToken = Depends(JWTBearer())
-
-# This checks that the UUID in the url matches the one in the token for us.
-JWTBearerTIPathDep = Depends(JWTBearer(path_param_name="task_instance_id"))
-
-
-async def get_team_name_dep(session: AsyncSessionDep, token=JWTBearerDep) ->
str | None:
- """Return the team name associated to the task (if any)."""
- if not conf.getboolean("core", "multi_team"):
- return None
-
- stmt = (
- select(Team.name)
- .select_from(TaskInstance)
- .join(DagModel, DagModel.dag_id == TaskInstance.dag_id)
- .join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name)
- .join(DagBundleModel.teams)
- .where(TaskInstance.id == token.id)
- )
- return await session.scalar(stmt)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
index 562b8588fbf..aeef4d092b1 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/__init__.py
@@ -17,9 +17,8 @@
from __future__ import annotations
from cadwyn import VersionedAPIRouter
-from fastapi import APIRouter
+from fastapi import APIRouter, Security
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes import (
asset_events,
assets,
@@ -32,12 +31,13 @@ from airflow.api_fastapi.execution_api.routes import (
variables,
xcoms,
)
+from airflow.api_fastapi.execution_api.security import require_auth
execution_api_router = APIRouter()
execution_api_router.include_router(health.router, prefix="/health",
tags=["Health"])
# _Every_ single endpoint under here must be authenticated. Some do further
checks on top of these
-authenticated_router = VersionedAPIRouter(dependencies=[JWTBearerDep]) #
type: ignore[list-item]
+authenticated_router =
VersionedAPIRouter(dependencies=[Security(require_auth)]) # type:
ignore[list-item]
authenticated_router.include_router(assets.router, prefix="/assets",
tags=["Assets"])
authenticated_router.include_router(asset_events.router,
prefix="/asset-events", tags=["Asset Events"])
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py
index 44cc3bfbd79..a7bb9959c6d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/connections.py
@@ -23,14 +23,14 @@ from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, Path, status
from airflow.api_fastapi.execution_api.datamodels.connection import
ConnectionResponse
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep,
get_team_name_dep
+from airflow.api_fastapi.execution_api.security import CurrentTIToken,
get_team_name_dep
from airflow.exceptions import AirflowNotFoundException
from airflow.models.connection import Connection
async def has_connection_access(
connection_id: str = Path(),
- token=JWTBearerDep,
+ token=CurrentTIToken,
) -> bool:
"""Check if the task has access to the connection."""
# TODO: Placeholder for actual implementation
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 53e3bbb2a9f..9273cc8b3d4 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -28,7 +28,7 @@ from uuid import UUID
import attrs
import structlog
from cadwyn import VersionedAPIRouter
-from fastapi import Body, HTTPException, Query, status
+from fastapi import Body, HTTPException, Query, Security, status
from pydantic import JsonValue
from sqlalchemy import func, or_, tuple_, update
from sqlalchemy.engine import CursorResult
@@ -59,7 +59,7 @@ from
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
TISuccessStatePayload,
TITerminalStatePayload,
)
-from airflow.api_fastapi.execution_api.deps import JWTBearerTIPathDep
+from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute,
require_auth
from airflow.exceptions import TaskNotFound
from airflow.models.asset import AssetActive
from airflow.models.dag import DagModel
@@ -78,10 +78,10 @@ if TYPE_CHECKING:
router = VersionedAPIRouter()
ti_id_router = VersionedAPIRouter(
+ # ExecutionAPIRoute precomputes each route's allowed token types from its token:* scopes.
+ route_class=ExecutionAPIRoute,
dependencies=[
- # This checks that the UUID in the url matches the one in the token
for us.
- JWTBearerTIPathDep
- ]
+ # ti:self makes require_auth check that the token subject matches the
+ # {task_instance_id} path parameter.
+ Security(require_auth, scopes=["ti:self"]),
+ ],
)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
index 5621b6cd081..1e2e2058932 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/variables.py
@@ -26,14 +26,14 @@ from airflow.api_fastapi.execution_api.datamodels.variable
import (
VariablePostBody,
VariableResponse,
)
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep,
get_team_name_dep
+from airflow.api_fastapi.execution_api.security import CurrentTIToken,
get_team_name_dep
from airflow.models.variable import Variable
async def has_variable_access(
request: Request,
variable_key: str = Path(),
- token=JWTBearerDep,
+ token=CurrentTIToken,
):
"""Check if the task has access to the variable."""
write = request.method not in {"GET", "HEAD", "OPTIONS"}
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
index ec77b64dc44..9b83c40db5e 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/xcoms.py
@@ -32,7 +32,7 @@ from airflow.api_fastapi.execution_api.datamodels.xcom import
(
XComSequenceIndexResponse,
XComSequenceSliceResponse,
)
-from airflow.api_fastapi.execution_api.deps import JWTBearerDep
+from airflow.api_fastapi.execution_api.security import CurrentTIToken
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XComModel
from airflow.utils.db import get_query_count
@@ -44,7 +44,7 @@ async def has_xcom_access(
task_id: str,
xcom_key: Annotated[str, Path(alias="key", min_length=1)],
request: Request,
- token=JWTBearerDep,
+ token=CurrentTIToken,
) -> bool:
"""Check if the task has access to the XCom."""
# TODO: Placeholder for actual implementation
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/security.py
b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
new file mode 100644
index 00000000000..215997d28d9
--- /dev/null
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/security.py
@@ -0,0 +1,243 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Execution API security: JWT validation, token scopes, and route-level access
control.
+
+Token types (``TokenType``):
+
+``"execution"``
+ Default scope, accepted by all endpoints. Short-lived, automatically
+ refreshed by ``JWTReissueMiddleware``.
+
+``"workload"``
+ Restricted scope, only accepted on routes that opt in via
+ ``Security(require_auth, scopes=["token:workload"])``.
+
+Tokens without a ``scope`` claim default to ``"execution"`` for backwards
+compatibility (``claims.setdefault("scope", "execution")``).
+
+Enforcement flow:
+ 1. ``JWTBearer.__call__`` validates the JWT once per request (crypto +
+ signature verification), caching the result on the ASGI request scope.
+ Subsequent FastAPI dependency resolutions and Cadwyn replays return
+ the cache.
+ 2. ``require_auth`` is the Security dependency on routers. It receives
+ the token from ``JWTBearer`` and enforces:
+ - Token type against the route's ``allowed_token_types`` (precomputed
+ by ``ExecutionAPIRoute`` from ``token:*`` Security scopes).
+ - ``ti:self`` scope — checks that the JWT ``sub`` matches the
+ ``{task_instance_id}`` path parameter.
+ 3. ``ExecutionAPIRoute`` precomputes ``allowed_token_types`` from
+ ``token:*`` Security scopes at route registration time. Routes
+ without explicit ``token:*`` scopes default to execution-only.
+
+Why ``ExecutionAPIRoute`` is needed:
+ FastAPI resolves router-level ``Security()`` dependencies from outermost
+ to innermost. A ``token:workload`` scope on an inner endpoint would need
+ to *relax* the outer router's default execution-only restriction, but
+ ``SecurityScopes`` only accumulate additively — an outer dependency
+ cannot see scopes declared by inner ones. ``ExecutionAPIRoute`` solves
+ this by inspecting the **merged** dependency list at route registration
+ time (after ``include_router`` has combined all parent and child
+ dependencies) and precomputing the full ``allowed_token_types`` set.
+ ``require_auth`` then reads this precomputed set from the matched route
+ at request time, avoiding the ordering problem entirely.
+
+ Any router whose routes need non-default token type policies must use
+ ``route_class=ExecutionAPIRoute``. Routers that only need the default
+ (execution-only) can use the standard route class — ``require_auth``
+ falls back to ``{"execution"}`` when the attribute is absent.
+"""
+
+# Disable future annotations in this file to work around
https://github.com/fastapi/fastapi/issues/13056
+# ruff: noqa: I002
+
+from typing import Any, Literal, get_args
+
+import structlog
+from fastapi import Depends, HTTPException, Request, status
+from fastapi.params import Security as SecurityParam
+from fastapi.routing import APIRoute
+from fastapi.security import HTTPBearer, SecurityScopes
+from sqlalchemy import select
+
+from airflow.api_fastapi.auth.tokens import JWTValidator
+from airflow.api_fastapi.common.db.common import AsyncSessionDep
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.deps import DepContainer
+
+log = structlog.get_logger(logger_name=__name__)
+
+# Closed set of token types understood by the Execution API.
+TokenType = Literal["execution", "workload"]
+
+# Runtime mirror of ``TokenType``, derived via ``get_args`` so the two cannot drift apart.
+VALID_TOKEN_TYPES: frozenset[str] = frozenset(get_args(TokenType))
+
+# Key under which ``JWTBearer`` caches the validated ``TIToken`` on the ASGI request scope.
+_REQUEST_SCOPE_TOKEN_KEY = "ti_token"
+
+
+class JWTBearer(HTTPBearer):
+ """
+ Validates JWT tokens for the Execution API.
+
+ Performs cryptographic validation once per request and caches the result
+ on the ASGI request scope. Subsequent resolutions (FastAPI dependency
+ dedup or Cadwyn replays) return the cached token.
+
+ This dependency handles ONLY crypto validation and token construction.
+ All route-specific authorization (token type, ti:self) is handled by
+ ``require_auth``.
+
+ Raises ``HTTPException`` 401 when no bearer credentials are sent and 403
+ when validation fails. Tokens without a ``scope`` claim are given
+ ``"execution"`` before the ``TIToken`` is built.
+ """
+
+ def __init__(self, required_claims: dict[str, Any] | None = None):
+ # auto_error=False: a missing header is turned into an explicit 401 in __call__.
+ super().__init__(auto_error=False)
+ self.required_claims = required_claims or {}
+
+ async def __call__( # type: ignore[override]
+ self,
+ request: Request,
+ services=DepContainer,
+ ) -> TIToken | None:
+ # NOTE(review): despite the ``| None`` annotation, every path below either
+ # returns a TIToken or raises.
+ # Return cached token (handles both FastAPI dependency dedup and
+ # Cadwyn replays).
+ # NOTE(review): this is a truthiness test; it assumes a cached TIToken is
+ # always truthy -- ``is not None`` would state the intent more precisely.
+ if cached := request.scope.get(_REQUEST_SCOPE_TOKEN_KEY):
+ return cached
+
+ # First resolution — full cryptographic validation.
+ creds = await super().__call__(request)
+ if not creds:
+ raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED,
detail="Missing auth token")
+
+ validator: JWTValidator = await services.aget(JWTValidator)
+
+ # A copy of required_claims is passed, presumably so the validator cannot
+ # mutate this instance's shared dict.
+ try:
+ claims = await validator.avalidated_claims(creds.credentials,
dict(self.required_claims))
+ except Exception as err:
+ # NOTE(review): ``token=creds.credentials`` writes the raw bearer token to the
+ # WARNING log; consider dropping or redacting it so credentials don't leak into logs.
+ log.warning("Failed to validate JWT", exc_info=True,
token=creds.credentials)
+ # NOTE(review): the validator's error text is echoed back to the client in ``detail``.
+ raise HTTPException(status_code=status.HTTP_403_FORBIDDEN,
detail=f"Invalid auth token: {err}")
+
+ # Tokens issued without a ``scope`` claim are treated as execution tokens
+ # (backwards compatibility).
+ claims.setdefault("scope", "execution")
+
+ token = TIToken(id=claims["sub"], claims=claims)
+ # Cache on the request scope so later resolutions in this request skip validation.
+ request.scope[_REQUEST_SCOPE_TOKEN_KEY] = token
+ return token
+
+
+# Shared module-level instance: ``require_auth`` depends on it, and it is the key
+# used in ``dependency_overrides`` to bypass JWT validation.
+_jwt_bearer = JWTBearer()
+
+
+async def require_auth(
+ security_scopes: SecurityScopes,
+ request: Request,
+ token: TIToken = Depends(_jwt_bearer),
+) -> TIToken:
+ """
+ Security dependency that enforces token type and ``ti:self`` scope.
+
+ Used via ``Security(require_auth)`` on routers. ``SecurityScopes`` are
+ accumulated by FastAPI from all parent ``Security()`` declarations.
+
+ Token type enforcement reads ``route.allowed_token_types`` (precomputed
+ by ``ExecutionAPIRoute``) or defaults to ``{"execution"}``.
+
+ Raises ``HTTPException`` (403) in three cases:
+ - the token's ``scope`` claim is not a known token type;
+ - the scope is not in the route's allowed token types;
+ - on ``ti:self`` routes, the token subject differs from the
+   ``task_instance_id`` path parameter.
+ Otherwise returns the token unchanged.
+ """
+ # JWTBearer normally sets a default scope already; this .get default keeps the
+ # check safe if the token came from elsewhere (e.g. a dependency override).
+ token_scope = token.claims.get("scope", "execution")
+
+ if token_scope not in VALID_TOKEN_TYPES:
+ log.warning("Invalid token scope in claims", token_scope=token_scope,
path=request.url.path)
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Invalid token scope: {token_scope}",
+ )
+
+ # Assumes the matched route object is available under scope["route"]. Routes that
+ # are not ExecutionAPIRoute instances have no ``allowed_token_types`` and fall back
+ # to execution-only.
+ route = request.scope.get("route")
+ allowed_token_types = getattr(route, "allowed_token_types",
frozenset({"execution"}))
+
+ if token_scope not in allowed_token_types:
+ log.warning(
+ "Token type not allowed for endpoint",
+ token_scope=token_scope,
+ allowed_types=sorted(allowed_token_types),
+ path=request.url.path,
+ )
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail=f"Token type '{token_scope}' not allowed for this endpoint.
"
+ f"Allowed types: {', '.join(sorted(allowed_token_types))}",
+ )
+
+ # ``security_scopes.scopes`` accumulates scopes from enclosing ``Security(...)``
+ # declarations, so ``ti:self`` on a router applies to all of its routes.
+ if "ti:self" in security_scopes.scopes:
+ # NOTE(review): a ti:self route without a {task_instance_id} path parameter
+ # raises KeyError here -- an unhandled error rather than a 403.
+ ti_self_id = str(request.path_params["task_instance_id"])
+ # Both sides are stringified so the comparison doesn't depend on whether the
+ # token id / path parameter are UUID or str.
+ if str(token.id) != ti_self_id:
+ raise HTTPException(
+ status_code=status.HTTP_403_FORBIDDEN,
+ detail="Token subject does not match task instance ID",
+ )
+
+ return token
+
+
+# Default-value marker for handlers/dependencies that need the authenticated token
+# (``token=CurrentTIToken``). It is a plain ``Depends`` and adds no scopes of its own.
+CurrentTIToken: TIToken = Depends(require_auth)
+
+
+class ExecutionAPIRoute(APIRoute):
+ """
+ Custom route class that precomputes allowed token types from Security scopes.
+
+ Scopes prefixed with ``token:`` (e.g., ``token:execution``, ``token:workload``)
+ are extracted at route registration time and stored as ``allowed_token_types``.
+ If no ``token:*`` scopes are declared, defaults to ``{"execution"}``.
+ Declaring any ``token:*`` scope replaces that default, so list
+ ``token:execution`` explicitly if execution tokens should still be accepted.
+
+ ``require_auth`` reads ``route.allowed_token_types`` at request time.
+
+ Raises ``ValueError`` during construction if a ``token:*`` scope names an
+ unknown token type.
+ """
+
+ allowed_token_types: frozenset[str]
+
+ def __init__(self, *args: Any, **kwargs: Any) -> None:
+ super().__init__(*args, **kwargs)
+
+ # Collect scopes from the route's ``dependencies`` list (per the module docstring,
+ # this already includes dependencies merged in by ``include_router``).
+ # NOTE(review): only top-level ``Security(...)`` entries in ``self.dependencies``
+ # are inspected. A ``Security(...)`` declared as an endpoint parameter, or nested
+ # inside another dependency, is not seen, so its ``token:*`` scopes would not
+ # affect ``allowed_token_types``.
+ all_scopes: set[str] = set()
+ for dep in self.dependencies:
+ if isinstance(dep, SecurityParam):
+ all_scopes.update(dep.scopes or [])
+
+ # Keep only ``token:*`` scopes (e.g. ``ti:self`` is ignored) and strip the prefix.
+ token_scopes = {s.removeprefix("token:") for s in all_scopes if
s.startswith("token:")}
+
+ if token_scopes and not token_scopes <= VALID_TOKEN_TYPES:
+ invalid = token_scopes - VALID_TOKEN_TYPES
+ raise ValueError(f"Invalid token types in Security scopes:
{invalid}")
+
+ self.allowed_token_types = frozenset(token_scopes) if token_scopes
else frozenset({"execution"})
+
+
+async def get_team_name_dep(session: AsyncSessionDep, token=CurrentTIToken) ->
str | None:
+ """
+ Return the team name associated to the task (if any).
+
+ The task is identified by the authenticated token's ``id``. Returns ``None``
+ when ``[core] multi_team`` is disabled or when the query matches no row.
+ """
+ # NOTE(review): these imports are function-local; presumably to avoid an import
+ # cycle or import-time cost -- confirm.
+ from airflow.configuration import conf
+ from airflow.models import DagModel, TaskInstance
+ from airflow.models.dagbundle import DagBundleModel
+ from airflow.models.team import Team
+
+ if not conf.getboolean("core", "multi_team"):
+ return None
+
+ # TaskInstance -> DagModel -> DagBundleModel -> Team, filtered to the token's task instance.
+ stmt = (
+ select(Team.name)
+ .select_from(TaskInstance)
+ .join(DagModel, DagModel.dag_id == TaskInstance.dag_id)
+ .join(DagBundleModel, DagBundleModel.name == DagModel.bundle_name)
+ .join(DagBundleModel.teams)
+ .where(TaskInstance.id == token.id)
+ )
+ # NOTE(review): ``scalar`` returns only the first row's value; if a bundle can be
+ # linked to several teams, which one is returned is unspecified (no ORDER BY).
+ return await session.scalar(stmt)
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
index 9e26937b63c..78bd0548df9 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/conftest.py
@@ -16,51 +16,43 @@
# under the License.
from __future__ import annotations
-from unittest.mock import AsyncMock
-
import pytest
+from fastapi import FastAPI, Request
from fastapi.testclient import TestClient
from airflow.api_fastapi.app import cached_app
-from airflow.api_fastapi.auth.tokens import JWTValidator
-from airflow.api_fastapi.execution_api.app import lifespan
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.security import _jwt_bearer
+
+
+def _get_execution_api_app(root_app: FastAPI) -> FastAPI:
+ """
+ Find the mounted execution API sub-app.
+
+ Returns the ``app`` of the first route on ``root_app`` whose ``path`` is
+ ``"/execution"``; raises ``RuntimeError`` if there is none.
+ """
+ for route in root_app.routes:
+ # hasattr guard: not every entry in ``routes`` is assumed to expose ``path``.
+ if hasattr(route, "path") and route.path == "/execution":
+ return route.app
+ raise RuntimeError("Execution API sub-app not found")
+
+
+@pytest.fixture
+def exec_app(client):
+    """Return the execution API sub-app mounted on the ``client`` fixture's root app."""
+    return _get_execution_api_app(client.app)
@pytest.fixture
def client(request: pytest.FixtureRequest):
app = cached_app(apps="execution")
+ exec_app = _get_execution_api_app(app)
- with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
- auth = AsyncMock(spec=JWTValidator)
-
- # Create a side_effect function that dynamically extracts the task
instance ID from validators
- def smart_validated_claims(cred, validators=None):
- # Extract task instance ID from validators if present
- # This handles the JWTBearerTIPathDep case where the validator
contains the task ID from the path
- if (
- validators
- and "sub" in validators
- and isinstance(validators["sub"], dict)
- and "value" in validators["sub"]
- ):
- return {
- "sub": validators["sub"]["value"],
- "exp": 9999999999, # Far future expiration
- "iat": 1000000000, # Past issuance time
- "aud": "test-audience",
- }
+ # Stand-in for ``_jwt_bearer``: skips JWT validation and returns an execution
+ # token whose subject is the request's ``task_instance_id`` path parameter (or
+ # the all-zero UUID when there is none), so ``ti:self`` checks normally pass.
+ async def mock_jwt_bearer(request: Request):
+ from uuid import UUID
- # For other cases (like JWTBearerDep) where no specific validators
are provided
- # Return a default UUID with all required claims
- return {
- "sub": "00000000-0000-0000-0000-000000000000",
- "exp": 9999999999, # Far future expiration
- "iat": 1000000000, # Past issuance time
- "aud": "test-audience",
- }
+ ti_id = UUID(request.path_params.get("task_instance_id",
"00000000-0000-0000-0000-000000000000"))
+ return TIToken(id=ti_id, claims={"sub": str(ti_id), "scope":
"execution"})
- # Set the side_effect for avalidated_claims
- auth.avalidated_claims.side_effect = smart_validated_claims
- lifespan.registry.register_value(JWTValidator, auth)
+ # The override is registered on the mounted execution sub-app (not the root app),
+ # presumably because that is the app whose routes resolve ``_jwt_bearer``.
+ exec_app.dependency_overrides[_jwt_bearer] = mock_jwt_bearer
+ with TestClient(app, headers={"Authorization": "Bearer fake"}) as client:
yield client
+
+ # Teardown: remove the override so it does not leak to later users of the
+ # (presumably shared) cached app.
+ exec_app.dependency_overrides.pop(_jwt_bearer, None)
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index 640d920137c..b0cb1d85c2e 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -43,6 +43,23 @@ def test_access_api_contract(client):
assert response.headers["airflow-api-version"] == bundle.versions[0].value
+def test_ti_self_routes_have_task_instance_id_param(exec_app):
+    """Every route with ti:self scope must have a {task_instance_id} path parameter."""
+    from fastapi.params import Security as SecurityParam
+    from fastapi.routing import APIRoute
+
+    # ``client.app`` is the root app, where the execution API is only a mounted
+    # sub-app (see ``_get_execution_api_app`` in conftest). The root's ``routes``
+    # would therefore not contain the ti:self APIRoutes, so inspect the sub-app instead.
+    # NOTE(review): consider also asserting that at least one ti:self route was
+    # checked, once it is confirmed how Cadwyn exposes versioned routes here.
+    for route in exec_app.routes:
+        if not isinstance(route, APIRoute):
+            continue
+        for dep in route.dependencies:
+            if isinstance(dep, SecurityParam) and "ti:self" in (dep.scopes or []):
+                assert "task_instance_id" in route.dependant.path_param_names, (
+                    f"Route {route.path} has ti:self scope but no {{task_instance_id}} path parameter"
+                )
+
+
class TestCorrelationIdMiddleware:
def test_correlation_id_echoed_in_response_headers(self, client):
"""Test that correlation-id from request is echoed back in response
headers."""
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
new file mode 100644
index 00000000000..8fff2c9f732
--- /dev/null
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_security.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from uuid import UUID
+
+import pytest
+from fastapi import APIRouter, FastAPI, Request, Security
+from fastapi.testclient import TestClient
+
+from airflow.api_fastapi.execution_api.datamodels.token import TIToken
+from airflow.api_fastapi.execution_api.security import ExecutionAPIRoute,
_jwt_bearer, require_auth
+
+
+class TestExecutionAPIRoute:
+ """Unit tests for ExecutionAPIRoute precomputing allowed_token_types from
Security scopes."""
+
+ def test_defaults_to_execution_only(self):
+ route = ExecutionAPIRoute(
+ path="/test",
+ endpoint=lambda: None,
+ dependencies=[Security(require_auth)],
+ )
+ assert route.allowed_token_types == frozenset({"execution"})
+
+ def test_extracts_token_scopes(self):
+ route = ExecutionAPIRoute(
+ path="/test",
+ endpoint=lambda: None,
+ dependencies=[
+ Security(require_auth),
+ Security(require_auth, scopes=["token:execution",
"token:workload"]),
+ ],
+ )
+ assert route.allowed_token_types == frozenset({"execution",
"workload"})
+
+ def test_ignores_non_token_scopes(self):
+ route = ExecutionAPIRoute(
+ path="/test",
+ endpoint=lambda: None,
+ dependencies=[
+ Security(require_auth, scopes=["ti:self", "token:execution"]),
+ ],
+ )
+ assert route.allowed_token_types == frozenset({"execution"})
+
+ def test_rejects_invalid_token_types(self):
+ with pytest.raises(ValueError, match="Invalid token types"):
+ ExecutionAPIRoute(
+ path="/test",
+ endpoint=lambda: None,
+ dependencies=[
+ Security(require_auth, scopes=["token:bogus"]),
+ ],
+ )
+
+
+class TestTokenTypeScopeEnforcement:
+ """End-to-end: ExecutionAPIRoute + require_auth enforce token types via
Security scopes."""
+
+ @pytest.fixture
+ def token_type_app(self):
+ """
+ Mirrors the real router structure: an authenticated_router with
Security(require_auth),
+ a child ti_id_router with ExecutionAPIRoute and ti:self, and a
specific endpoint on that
+ router opting in to workload tokens via endpoint-level Security scopes.
+ """
+ app = FastAPI()
+
+ authenticated_router = APIRouter(dependencies=[Security(require_auth)])
+ ti_id_router = APIRouter(
+ route_class=ExecutionAPIRoute,
+ dependencies=[Security(require_auth, scopes=["ti:self"])],
+ )
+
+ @ti_id_router.get("/{task_instance_id}/state")
+ def default_endpoint(task_instance_id: str):
+ return {"ok": True}
+
+ @ti_id_router.get(
+ "/{task_instance_id}/run",
+ dependencies=[Security(require_auth, scopes=["token:execution",
"token:workload"])],
+ )
+ def workload_endpoint(task_instance_id: str):
+ return {"ok": True}
+
+ authenticated_router.include_router(ti_id_router,
prefix="/task-instances")
+ app.include_router(authenticated_router)
+
+ return app
+
+ TI_ID = "00000000-0000-0000-0000-000000000001"
+
+ def _override_jwt(self, app, scope: str):
+ ti_id = self.TI_ID
+
+ async def mock_jwt(request: Request):
+ return TIToken(id=UUID(ti_id), claims={"scope": scope})
+
+ app.dependency_overrides[_jwt_bearer] = mock_jwt
+
+ def test_workload_token_rejected_on_default_route(self, token_type_app):
+ self._override_jwt(token_type_app, "workload")
+ client = TestClient(token_type_app)
+
+ resp = client.get(f"/task-instances/{self.TI_ID}/state",
headers={"Authorization": "Bearer fake"})
+ assert resp.status_code == 403
+ assert "Token type 'workload' not allowed" in resp.json()["detail"]
+
+ def test_workload_token_accepted_on_opted_in_route(self, token_type_app):
+ self._override_jwt(token_type_app, "workload")
+ client = TestClient(token_type_app)
+
+ resp = client.get(f"/task-instances/{self.TI_ID}/run",
headers={"Authorization": "Bearer fake"})
+ assert resp.status_code == 200
+
+ def test_execution_token_accepted_on_both_routes(self, token_type_app):
+ self._override_jwt(token_type_app, "execution")
+ client = TestClient(token_type_app)
+
+ state = client.get(f"/task-instances/{self.TI_ID}/state",
headers={"Authorization": "Bearer fake"})
+ run = client.get(f"/task-instances/{self.TI_ID}/run",
headers={"Authorization": "Bearer fake"})
+ assert state.status_code == 200
+ assert run.status_code == 200
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index ea1153f01cb..d9ec3916187 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -76,13 +76,16 @@ def _create_asset_aliases(session, num: int = 2) -> None:
@pytest.fixture
-def client_with_extra_route(): ...
+def _use_real_jwt_bearer(exec_app):
+ """Remove the mock jwt_bearer override so the real JWTBearer.__call__
runs."""
+ from airflow.api_fastapi.execution_api.security import _jwt_bearer
+ exec_app.dependency_overrides.pop(_jwt_bearer, None)
-def test_id_matches_sub_claim(client, session, create_task_instance):
- # Test that this is validated at the router level, so we don't have to
test it in each component
- # We validate it is set correctly, and test it once
[email protected]("_use_real_jwt_bearer")
+def test_id_matches_sub_claim(client, session, create_task_instance):
+ """Test that scope validation (ti:self) is enforced at the router level."""
ti = create_task_instance(
task_id="test_ti_run_state_conflict_if_not_queued",
state="queued",
@@ -90,17 +93,10 @@ def test_id_matches_sub_claim(client, session,
create_task_instance):
session.commit()
validator = mock.AsyncMock(spec=JWTValidator)
- claims = {"sub": ti.id}
-
- def side_effect(cred, validators):
- if not validators:
- return claims
- if str(validators["sub"]["value"]) != str(ti.id):
- raise RuntimeError("Fake auth denied")
- return claims
-
- validator.avalidated_claims.side_effect = side_effect
-
+ validator.avalidated_claims.return_value = {
+ "sub": str(ti.id),
+ "scope": "execution",
+ }
lifespan.registry.register_value(JWTValidator, validator)
payload = {
@@ -113,15 +109,10 @@ def test_id_matches_sub_claim(client, session,
create_task_instance):
resp =
client.patch("/execution/task-instances/9c230b40-da03-451d-8bd7-be30471be383/run",
json=payload)
assert resp.status_code == 403
- assert validator.avalidated_claims.call_args_list[1] == mock.call(
- mock.ANY, {"sub": {"essential": True, "value":
"9c230b40-da03-451d-8bd7-be30471be383"}}
- )
validator.avalidated_claims.reset_mock()
resp = client.patch(f"/execution/task-instances/{ti.id}/run", json=payload)
-
assert resp.status_code == 200, resp.json()
-
validator.avalidated_claims.assert_awaited()
@@ -2925,3 +2916,88 @@ class TestTIPatchRenderedMapIndex:
)
assert response.status_code == 422
+
+
[email protected]("_use_real_jwt_bearer")
+class TestTokenTypeValidation:
+ """Test token scope enforcement (workload vs execution)."""
+
+ def test_workload_scope_rejected_on_default_endpoints(self, client,
session, create_task_instance):
+ """workload scoped tokens should be rejected on endpoints without
token:workload Security scope."""
+ ti = create_task_instance(task_id="test_ti_run_heartbeat",
state=State.RUNNING)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.side_effect = lambda cred, validators: {
+ "sub": str(ti.id),
+ "scope": "workload",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ payload = {"hostname": "test-host", "pid": 100}
+ resp = client.put(f"/execution/task-instances/{ti.id}/heartbeat",
json=payload)
+ assert resp.status_code == 403
+ assert "Token type 'workload' not allowed" in resp.json()["detail"]
+
+ def test_execution_scope_accepted_on_all_endpoints(self, client, session,
create_task_instance):
+ """execution scoped tokens should be able to call all endpoints."""
+ ti = create_task_instance(task_id="test_ti_star", state=State.RUNNING)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.side_effect = lambda cred, validators: {
+ "sub": str(ti.id),
+ "scope": "execution",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
+ resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
+ assert resp.status_code in [200, 204]
+
+ def test_invalid_scope_value_rejected(self, client, session,
create_task_instance):
+ """Tokens with unrecognized scope values should be rejected."""
+ ti = create_task_instance(task_id="test_invalid_scope",
state=State.QUEUED)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.side_effect = lambda cred, validators: {
+ "sub": str(ti.id),
+ "scope": "bogus:scope",
+ "exp": 9999999999,
+ "iat": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ payload = {
+ "state": "running",
+ "hostname": "test-host",
+ "unixname": "test-user",
+ "pid": 100,
+ "start_date": "2024-10-31T12:00:00Z",
+ }
+
+ resp = client.patch(f"/execution/task-instances/{ti.id}/run",
json=payload)
+ assert resp.status_code == 403
+ assert "Invalid token scope" in resp.json()["detail"]
+
+ def test_no_scope_defaults_to_execution(self, client, session,
create_task_instance):
+ """Tokens without scope claim should default to 'execution'."""
+ ti = create_task_instance(task_id="test_no_scope", state=State.RUNNING)
+ session.commit()
+
+ validator = mock.AsyncMock(spec=JWTValidator)
+ validator.avalidated_claims.side_effect = lambda cred, validators: {
+ "sub": str(ti.id),
+ "exp": 9999999999,
+ "iat": 1000000000,
+ }
+ lifespan.registry.register_value(JWTValidator, validator)
+
+ payload = {"state": "success", "end_date": "2024-10-31T13:00:00Z"}
+ resp = client.patch(f"/execution/task-instances/{ti.id}/state",
json=payload)
+ assert resp.status_code in [200, 204]
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
index 59b206441de..93cd8ca672e 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_variables.py
@@ -41,15 +41,15 @@ def setup_method():
@pytest.fixture
def access_denied(client):
- from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes.variables import
has_variable_access
+ from airflow.api_fastapi.execution_api.security import CurrentTIToken
last_route = client.app.routes[-1]
assert isinstance(last_route, Mount)
assert isinstance(last_route.app, FastAPI)
exec_app = last_route.app
- async def _(request: Request, variable_key: str, token=JWTBearerDep):
+ async def _(request: Request, variable_key: str, token=CurrentTIToken):
await has_variable_access(request, variable_key, token)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
index 554c2ad2c84..2135cb970a4 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_xcoms.py
@@ -48,8 +48,8 @@ def reset_db():
@pytest.fixture
def access_denied(client):
- from airflow.api_fastapi.execution_api.deps import JWTBearerDep
from airflow.api_fastapi.execution_api.routes.xcoms import has_xcom_access
+ from airflow.api_fastapi.execution_api.security import CurrentTIToken
last_route = client.app.routes[-1]
assert isinstance(last_route.app, FastAPI)
@@ -61,7 +61,7 @@ def access_denied(client):
run_id: str = Path(),
task_id: str = Path(),
xcom_key: str = Path(alias="key"),
- token=JWTBearerDep,
+ token=CurrentTIToken,
):
await has_xcom_access(dag_id, run_id, task_id, xcom_key, request,
token)
raise HTTPException(