This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1fd74fb47ca Config improvements for using JWKS with the JWT code (#48054)
1fd74fb47ca is described below
commit 1fd74fb47ca2904b5e2f5000f8e4d04da337319d
Author: Ash Berlin-Taylor <[email protected]>
AuthorDate: Fri Mar 21 17:45:00 2025 +0000
Config improvements for using JWKS with the JWT code (#48054)
Previously the `kid` was hard-coded to the thumbprint of the private key.
That is a sensible default, but it is not the only way JWKS documents work,
and the `kid` has to match between the generator and the validator.
Additionally, the JWTValidator had no way of specifying the algorithm(s) to
use with a JWKS document, so validation would always fail. It now respects the
same config option (`[api_auth] jwt_algorithm`) that the JWTGenerator uses.
---
.../src/airflow/api_fastapi/auth/tokens.py | 40 ++++++++++++++----
.../src/airflow/api_fastapi/execution_api/app.py | 1 -
.../src/airflow/config_templates/config.yml | 15 ++++++-
.../src/airflow/utils/log/file_task_handler.py | 2 +
.../tests/unit/api_fastapi/auth/test_tokens.py | 48 +++++++++++++++++-----
5 files changed, 86 insertions(+), 20 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
index 10f71078171..7bb2f9d4a4a 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
@@ -22,7 +22,7 @@ import time
from base64 import urlsafe_b64encode
from collections.abc import Sequence
from datetime import datetime
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, Callable, Literal, overload
import attrs
import httpx
@@ -231,6 +231,16 @@ def _conf_factory(section, key, **kwargs):
return factory
+@overload
+def _conf_list_factory(section, key, first_only: Literal[True], **kwargs) ->
Callable[[], str]: ...
+
+
+@overload
+def _conf_list_factory(
+ section, key, first_only: Literal[False] = False, **kwargs
+) -> Callable[[], list[str]]: ...
+
+
def _conf_list_factory(section, key, first_only: bool = False, **kwargs):
def factory() -> list[str] | str:
from airflow.configuration import conf
@@ -239,7 +249,7 @@ def _conf_list_factory(section, key, first_only: bool =
False, **kwargs):
if first_only and val:
return val[0]
- return val
+ return val or []
return factory
@@ -262,12 +272,16 @@ class JWTValidator:
jwks: JWKS | None = None
secret_key: str | None = attrs.field(repr=False, default=None,
converter=lambda v: None if v == "" else v)
issuer: str | list[str] | None = attrs.field(
- factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None)
+ factory=_conf_list_factory("api_auth", "jwt_issuer", fallback=None),
+ # Ensure we have None, instead of an empty list, else pyjwt will fail
to validate it
+ converter=lambda v: None if v == [] else v,
)
# By default, we just validate these
required_claims: frozenset[str] = frozenset({"exp", "iat", "nbf"})
audience: str | Sequence[str]
- algorithm: list[str] = attrs.field(default=["GUESS"], converter=_to_list)
+ algorithm: list[str] = attrs.field(
+ factory=_conf_list_factory("api_auth", "jwt_algorithm",
fallback="GUESS"), converter=_to_list
+ )
leeway: float = attrs.field(factory=_conf_factory("api_auth",
"jwt_leeway"), converter=int)
@@ -277,8 +291,12 @@ class JWTValidator:
if self.algorithm == ["GUESS"]:
if self.jwks:
- # TODO: We could probably populate this from the jwks document?
- raise ValueError("Cannot guess the algorithm when using JWKS")
+ # TODO: We could probably populate this from the jwks
document, but we don't have that at
+ # construction time.
+ raise ValueError(
+ "Cannot guess the algorithm when using JWKS - please
specify it in the config option "
+ "[api_auth] jwt_algorithm"
+ )
else:
self.algorithm = ["HS512"]
@@ -380,19 +398,25 @@ class JWTGenerator:
issuer: str | list[str] | None = attrs.field(
factory=_conf_list_factory("api_auth", "jwt_issuer", first_only=True,
fallback=None)
)
- algorithm: str = attrs.field(factory=_conf_factory("api_auth",
"jwt_algorithm", fallback="GUESS"))
+ algorithm: str = attrs.field(
+ factory=_conf_list_factory("api_auth", "jwt_algorithm",
first_only=True, fallback="GUESS")
+ )
@kid.default
def _generate_kid(self):
if not self._private_key:
return "not-used"
+ if kid := _conf_factory("api_auth", "jwt_kid", fallback=None)():
+ return kid
+
+ # Generate it from the thumbprint of the private key
info = key_to_jwk_dict(self._private_key)
return info["kid"]
def __attrs_post_init__(self):
if not (self._private_key is None) ^ (self._secret_key is None):
- raise ValueError("Exactly one of privaate_key and secret_key must
be specified")
+ raise ValueError("Exactly one of private_key and secret_key must
be specified")
if self.algorithm == "GUESS":
if self._private_key:
diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
index 3546a57eeb4..12e86129faa 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/app.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/app.py
@@ -58,7 +58,6 @@ def _jwt_validator():
validator = JWTValidator(
required_claims=required_claims,
issuer=issuer,
- leeway=conf.getint("api_auth", "jwt_leeway"),
audience=conf.get_mandatory_list_value("execution_api",
"jwt_audience"),
**get_sig_validation_args(make_secret_key_if_needed=False),
)
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index 9c1991e813d..eb99d48a06b 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -1510,10 +1510,23 @@ api_auth:
This value must be appropriate for the given private key type.
- Default is "HS512" if ``jwt_secret`` is set, or "EdDSA" otherwise
+ If this is not specified, Airflow guesses which algorithm is best based on the key type.
+
+ ("HS512" if ``jwt_secret`` is set, otherwise a key-type specific guess)
example: '"EdDSA" or "HS512"'
type: string
default: ~
+ jwt_kid:
+ version_added: 3.0.0
+ description: |
+ The Key ID to place in the header when generating JWTs. Not used in the validation path.
+
+ If this is not specified, the RFC 7638 thumbprint of the private key will be used.
+
+ Ignored when ``jwt_secret`` is used.
+ type: string
+ example: "my-key-id"
+ default: ~
trusted_jwks_url:
version_added: 3.0.0
description: |
diff --git a/airflow-core/src/airflow/utils/log/file_task_handler.py
b/airflow-core/src/airflow/utils/log/file_task_handler.py
index d4b637d3d49..2581c7f1e74 100644
--- a/airflow-core/src/airflow/utils/log/file_task_handler.py
+++ b/airflow-core/src/airflow/utils/log/file_task_handler.py
@@ -101,6 +101,8 @@ def _fetch_logs_from_service(url, log_relative_path):
timeout = conf.getint("webserver", "log_fetch_timeout_sec", fallback=None)
generator = JWTGenerator(
secret_key=get_signing_key("webserver", "secret_key"),
+ private_key=None,
+ issuer=None,
valid_for=conf.getint("webserver", "log_request_clock_grace",
fallback=30),
audience="task-instance-logs",
)
diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
index 02d12996d14..82a3f680cb1 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
@@ -33,11 +33,14 @@ from airflow.api_fastapi.auth.tokens import (
JWTGenerator,
JWTValidator,
generate_private_key,
+ get_sig_validation_args,
key_to_jwk_dict,
key_to_pem,
)
from airflow.utils import timezone
+from tests_common.test_utils.config import conf_vars
+
if TYPE_CHECKING:
from cryptography.hazmat.primitives.asymmetric.ed25519 import
Ed25519PrivateKey
from kgb import SpyAgency
@@ -52,16 +55,6 @@ def private_key(request):
return request.getfixturevalue(request.param or "ed25519_private_key")
[email protected]
-def mock_kid():
- return "test-kid"
-
-
[email protected]
-def mock_subject():
- return "test-subject"
-
-
class TestJWKS:
@pytest.mark.parametrize("private_key", ["rsa_private_key",
"ed25519_private_key"], indirect=True)
async def test_fetch_jwks_success(self, private_key):
@@ -200,6 +193,41 @@ async def test_jwt_wrong_subject(jwt_generator,
jwt_validator):
)
[email protected](
+ ["private_key", "algorithm"],
+ [("rsa_private_key", "RS256"), ("ed25519_private_key", "EdDSA")],
+ indirect=["private_key"],
+)
+async def test_jwt_generate_validate_roundtrip_with_jwks(private_key,
algorithm, tmp_path: pathlib.Path):
+ jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key,
"custom-kid")]})
+
+ jwks = tmp_path.joinpath("jwks.json")
+ jwks.write_text(jwk_content)
+
+ priv_key = tmp_path.joinpath("key.pem")
+ priv_key.write_bytes(key_to_pem(private_key))
+
+ with conf_vars(
+ {
+ ("api_auth", "trusted_jwks_url"): str(jwks),
+ ("api_auth", "jwt_kid"): "custom-kid",
+ ("api_auth", "jwt_issuer"): "http://my-issuer.localdomain",
+ ("api_auth", "jwt_private_key_path"): str(priv_key),
+ ("api_auth", "jwt_algorithm"): algorithm,
+ ("api_auth", "jwt_secret"): "",
+ }
+ ):
+ gen = JWTGenerator(audience="airflow1", valid_for=300)
+ token = gen.generate({"sub": "test"})
+
+ validator = JWTValidator(
+ audience="airflow1",
+ leeway=0,
+ **get_sig_validation_args(make_secret_key_if_needed=False),
+ )
+ assert await validator.avalidated_claims(token)
+
+
@pytest.fixture(scope="session")
def rsa_private_key():
return generate_private_key()