This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dbab68b23e9 Fix: Ensure JWTValidator handles GUESS algorithm with JWKS (#63115)
dbab68b23e9 is described below

commit dbab68b23e92f4335dfd1b0bec730105956d2d8d
Author: Henry Chen <[email protected]>
AuthorDate: Wed Mar 11 21:07:36 2026 +0800

    Fix: Ensure JWTValidator handles GUESS algorithm with JWKS (#63115)
    
    * Fix: Ensure JWTValidator handles GUESS algorithm with JWKS
    
    - Updated `avalidated_claims` to read the signing algorithm (`alg`) from the token header when `jwt_algorithm` is set to "GUESS".
    - Passed the raw key (`key.key`) instead of the `PyJWK` object to prevent pyjwt from overriding the algorithm with `PyJWK.algorithm_name`.
    
    * Use algorithm_name instead of the token header
---
 .../src/airflow/api_fastapi/auth/tokens.py         | 24 +++++++++--------
 .../tests/unit/api_fastapi/auth/test_tokens.py     | 31 ++++++++++++++++++++++
 2 files changed, 44 insertions(+), 11 deletions(-)

diff --git a/airflow-core/src/airflow/api_fastapi/auth/tokens.py b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
index 4732164be71..3375853a29a 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/tokens.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/tokens.py
@@ -93,7 +93,7 @@ def _guess_best_algorithm(key: AllowedPrivateKeys):
     from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
 
     if isinstance(key, RSAPrivateKey):
-        return "RS512"
+        return "RS256"
     if isinstance(key, Ed25519PrivateKey):
         return "EdDSA"
     raise ValueError(f"Unknown key object {type(key)}")
@@ -291,14 +291,8 @@ class JWTValidator:
             raise ValueError("Exactly one of private_key and secret_key must be specified")
 
         if self.algorithm == ["GUESS"]:
-            if self.jwks:
-                # TODO: We could probably populate this from the jwks document, but we don't have that at
-                # construction time.
-                raise ValueError(
-                    "Cannot guess the algorithm when using JWKS - please specify it in the config option "
-                    "[api_auth] jwt_algorithm"
-                )
-            self.algorithm = ["HS512"]
+            if not self.jwks:
+                self.algorithm = ["HS512"]
 
     def _get_kid_from_header(self, unvalidated: str) -> str:
         header = jwt.get_unverified_header(unvalidated)
@@ -326,13 +320,21 @@ class JWTValidator:
     ) -> dict[str, Any]:
         """Decode the JWT token, returning the validated claims or raising an exception."""
         key = await self._get_validation_key(unvalidated)
+        algorithms = self.algorithm
+        validation_key: str | jwt.PyJWK | Any = key
+        if algorithms == ["GUESS"] and isinstance(key, jwt.PyJWK):
+            if not key.algorithm_name:
+                raise jwt.InvalidTokenError("Missing algorithm in JWK")
+            algorithms = [key.algorithm_name]
+            validation_key = key.key
+
         claims = jwt.decode(
             unvalidated,
-            key,
+            validation_key,
             audience=self.audience,
             issuer=self.issuer,
             options={"require": list(self.required_claims)},
-            algorithms=self.algorithm,
+            algorithms=algorithms,
             leeway=self.leeway,
         )
 
diff --git a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
index e477c42af4f..6b848f723a0 100644
--- a/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
+++ b/airflow-core/tests/unit/api_fastapi/auth/test_tokens.py
@@ -264,6 +264,37 @@ async def test_jwt_generate_validate_roundtrip_with_jwks(private_key, algorithm,
         assert await validator.avalidated_claims(token)
 
 
+@pytest.mark.parametrize("private_key", ["rsa_private_key", "ed25519_private_key"], indirect=True)
+async def test_jwt_validate_roundtrip_with_jwks_and_guess_algorithm(private_key, tmp_path: pathlib.Path):
+    jwk_content = json.dumps({"keys": [key_to_jwk_dict(private_key, "custom-kid")]})
+
+    jwks = tmp_path.joinpath("jwks.json")
+    await anyio.Path(jwks).write_text(jwk_content)
+
+    priv_key = tmp_path.joinpath("key.pem")
+    await anyio.Path(priv_key).write_bytes(key_to_pem(private_key))
+
+    with conf_vars(
+        {
+            ("api_auth", "trusted_jwks_url"): str(jwks),
+            ("api_auth", "jwt_kid"): "custom-kid",
+            ("api_auth", "jwt_issuer"): "http://my-issuer.localdomain",
+            ("api_auth", "jwt_private_key_path"): str(priv_key),
+            ("api_auth", "jwt_algorithm"): "GUESS",
+            ("api_auth", "jwt_secret"): "",
+        }
+    ):
+        gen = JWTGenerator(audience="airflow1", valid_for=300)
+        token = gen.generate({"sub": "test"})
+
+        validator = JWTValidator(
+            audience="airflow1",
+            leeway=0,
+            **get_sig_validation_args(make_secret_key_if_needed=False),
+        )
+        assert await validator.avalidated_claims(token)
+
+
 class TestRevokeToken:
     pytestmark = [pytest.mark.db_test]
 

Reply via email to