This is an automated email from the ASF dual-hosted git repository.

kaxil pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new db04a4ef8c2 Fix older and custom secrets backends breaking on Airflow 3.2 (#68302)
db04a4ef8c2 is described below

commit db04a4ef8c2042c0e770e23749b66e2525c39569
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Jun 9 20:59:41 2026 +0100

    Fix older and custom secrets backends breaking on Airflow 3.2 (#68302)
    
    Airflow 3.2 (AIP-67 multi-team) added a `team_name` keyword to the
    secrets backend API. BaseSecretsBackend.get_connection() forwarded it
    to get_conn_value() unconditionally, and the core get_*_from_secrets()
    call sites passed it straight to get_connection()/get_variable(). Any
    backend whose override predates the keyword (the old (self, conn_id) /
    (self, key) signature) raised TypeError on every lookup, which the
    backend-iteration loop swallowed as "not found", so connections and
    variables silently failed to resolve after upgrading to 3.2.
    
    Route those calls through a helper that introspects the override and
    forwards team_name only when it is accepted (a team_name parameter or
    **kwargs), omitting it for pre-3.2 signatures so older bundled providers
    and custom backends keep working -- in both single-team and multi-team
    deployments -- without being forced to add the parameter. A TypeError
    from inside an accepting backend is left to propagate, not retried
    without team_name, so a team-scoped lookup is never silently resolved
    against the global scope. Applied to the get_conn_value forward in the
    base class and to the Connection/Variable get_*_from_secrets call sites,
    which also covers backends overriding get_connection/get_variable
    directly (e.g. Vault).
---
 airflow-core/src/airflow/models/connection.py      |  5 +-
 airflow-core/src/airflow/models/variable.py        |  5 +-
 airflow-core/tests/unit/always/test_secrets.py     | 50 ++++++++++++
 .../src/airflow_shared/secrets_backend/base.py     | 40 ++++++++-
 .../tests/secrets_backend/test_base.py             | 95 ++++++++++++++++++++++
 5 files changed, 192 insertions(+), 3 deletions(-)

diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py
index ee9fc222e40..1b4b0f8f867 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -31,6 +31,7 @@ from sqlalchemy import ForeignKey, Integer, String, Text, select
 from sqlalchemy.orm import Mapped, mapped_column, reconstructor
 
 from airflow._shared.module_loading import import_string
+from airflow._shared.secrets_backend.base import call_secrets_backend_method
 from airflow._shared.secrets_masker import mask_secret
 from airflow.exceptions import AirflowException, AirflowNotFoundException
 from airflow.models.base import ID_LEN, Base
@@ -517,7 +518,9 @@ class Connection(Base, FernetFieldsMixin, LoggingMixin):
         # iterate over backends if not in cache (or expired)
         for secrets_backend in ensure_secrets_loaded():
             try:
-                conn = secrets_backend.get_connection(conn_id=conn_id, team_name=team_name)
+                conn = call_secrets_backend_method(
+                    secrets_backend.get_connection, team_name=team_name, conn_id=conn_id
+                )
                 if conn:
                     SecretCache.save_connection_uri(conn_id, conn.get_uri(), team_name=team_name)
                     return conn
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index f7a2aba5dda..b06e73cd5f5 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -28,6 +28,7 @@ from sqlalchemy import Boolean, ForeignKey, Integer, String, Text, delete, or_,
 from sqlalchemy.dialects.mysql import MEDIUMTEXT
 from sqlalchemy.orm import Mapped, declared_attr, mapped_column, reconstructor, synonym
 
+from airflow._shared.secrets_backend.base import call_secrets_backend_method
 from airflow._shared.secrets_masker import mask_secret
 from airflow.configuration import conf, ensure_secrets_loaded
 from airflow.models.base import ID_LEN, Base
@@ -479,7 +480,9 @@ class Variable(Base, LoggingMixin):
         # iterate over backends if not in cache (or expired)
         for secrets_backend in ensure_secrets_loaded():
             try:
-                var_val = secrets_backend.get_variable(key=key, team_name=team_name)
+                var_val = call_secrets_backend_method(
+                    secrets_backend.get_variable, team_name=team_name, key=key
+                )
                 if var_val is not None:
                     break
             except AirflowSecretsBackendAccessDenied:
diff --git a/airflow-core/tests/unit/always/test_secrets.py b/airflow-core/tests/unit/always/test_secrets.py
index d09f04df7df..0c2749ca9d1 100644
--- a/airflow-core/tests/unit/always/test_secrets.py
+++ b/airflow-core/tests/unit/always/test_secrets.py
@@ -25,6 +25,7 @@ from airflow.configuration import ensure_secrets_loaded, initialize_secrets_back
 from airflow.models import Connection, Variable
 from airflow.sdk import SecretCache
 from airflow.sdk.exceptions import AirflowNotFoundException
+from airflow.secrets import BaseSecretsBackend
 
 from tests_common.test_utils.config import conf_vars
 from tests_common.test_utils.db import clear_db_variables
@@ -227,6 +228,55 @@ class TestVariableFromSecrets:
         assert Variable.get_variable_from_secrets(key="_team___myvar") is None
 
 
+class _LegacyGetConnectionBackend(BaseSecretsBackend):
+    """Backend overriding ``get_connection`` with the pre-3.2 ``(self, conn_id)`` signature (e.g. Vault)."""
+
+    def __init__(self, conns: dict[str, Connection]):
+        self._conns = conns
+
+    def get_connection(self, conn_id):
+        return self._conns.get(conn_id)
+
+
+class _LegacyGetVariableBackend(BaseSecretsBackend):
+    """Backend overriding ``get_variable`` with the pre-3.2 ``(self, key)`` signature."""
+
+    def __init__(self, variables: dict[str, str]):
+        self._vars = variables
+
+    def get_variable(self, key):
+        return self._vars.get(key)
+
+
+@skip_if_force_lowest_dependencies_marker
+class TestLegacyBackendSignatureCompat:
+    """Backends whose overrides predate the ``team_name`` keyword must keep working (issue #1333)."""
+
+    def setup_method(self) -> None:
+        SecretCache.reset()
+
+    @pytest.mark.parametrize("team_name", [None, "team_a"])
+    @conf_vars({("core", "multi_team"): "True"})
+    @mock.patch.dict("sys.modules", {"airflow.sdk.execution_time.task_runner": None})
+    def test_get_connection_with_legacy_get_connection_override(self, team_name):
+        backend = _LegacyGetConnectionBackend(
+            {"legacy_conn": Connection(conn_id="legacy_conn", conn_type="mysql", host="h")}
+        )
+        with mock.patch("airflow.configuration.ensure_secrets_loaded", return_value=[backend]):
+            conn = Connection.get_connection_from_secrets("legacy_conn", team_name=team_name)
+
+        assert conn.conn_id == "legacy_conn"
+        assert conn.conn_type == "mysql"
+
+    @pytest.mark.parametrize("team_name", [None, "team_a"])
+    def test_get_variable_with_legacy_get_variable_override(self, team_name):
+        backend = _LegacyGetVariableBackend({"legacy_var": "secret_value"})
+        with mock.patch("airflow.models.variable.ensure_secrets_loaded", return_value=[backend]):
+            value = Variable.get_variable_from_secrets("legacy_var", team_name=team_name)
+
+        assert value == "secret_value"
+
+
 @skip_if_force_lowest_dependencies_marker
 class TestSecretBackendKwargEnvVars:
     """Test per-key env var overrides for secrets backend kwargs."""
diff --git a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
index a2715b93ca9..e61197d2baa 100644
--- a/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
+++ b/shared/secrets_backend/src/airflow_shared/secrets_backend/base.py
@@ -16,7 +16,45 @@
 # under the License.
 from __future__ import annotations
 
+import inspect
 from abc import ABC
+from collections.abc import Callable
+
+
+def _accepts_team_name(method: Callable) -> bool:
+    """
+    Return whether a secrets-backend method accepts the ``team_name`` keyword.
+
+    Backends written before Airflow 3.2 override ``get_conn_value`` / ``get_variable`` /
+    ``get_connection`` with the legacy ``(self, conn_id)`` / ``(self, key)`` signature.
+    AIP-67 (multi-team) added a ``team_name`` keyword; forwarding it to those raises
+    ``TypeError``. A method accepts it if it declares a ``team_name`` parameter or a
+    ``**kwargs`` catch-all.
+    """
+    try:
+        parameters = inspect.signature(method).parameters
+    except (TypeError, ValueError):
+        # Un-introspectable callable (e.g. C-implemented): assume the 3.2+ signature.
+        return True
+    return "team_name" in parameters or any(
+        p.kind is inspect.Parameter.VAR_KEYWORD for p in parameters.values()
+    )
+
+
+def call_secrets_backend_method(method: Callable, *, team_name: str | None, **kwargs):
+    """
+    Call a secrets-backend lookup ``method``, forwarding ``team_name`` only when supported.
+
+    Forward ``team_name`` to backends that accept it (3.2+ overrides) and omit it for
+    pre-3.2 overrides, so older bundled providers and custom backends keep working --
+    in both single-team and multi-team deployments -- without being forced to add the
+    parameter. A ``TypeError`` raised inside an accepting backend is left to propagate
+    rather than retried without ``team_name``, which could mask the error and resolve a
+    team-scoped lookup against the global scope.
+    """
+    if _accepts_team_name(method):
+        return method(team_name=team_name, **kwargs)
+    return method(**kwargs)
 
 
 class BaseSecretsBackend(ABC):
@@ -113,7 +151,7 @@ class BaseSecretsBackend(ABC):
         :param team_name: Team name associated to the task trying to access the connection (if any)
         :return: Connection object or None
         """
-        value = self.get_conn_value(conn_id=conn_id, team_name=team_name)
+        value = call_secrets_backend_method(self.get_conn_value, team_name=team_name, conn_id=conn_id)
         if value:
             return self.deserialize_connection(conn_id=conn_id, value=value)
         return None
diff --git a/shared/secrets_backend/tests/secrets_backend/test_base.py b/shared/secrets_backend/tests/secrets_backend/test_base.py
index cd2b7a0934a..e14e6204243 100644
--- a/shared/secrets_backend/tests/secrets_backend/test_base.py
+++ b/shared/secrets_backend/tests/secrets_backend/test_base.py
@@ -133,3 +133,98 @@ class TestBaseSecretsBackend:
         assert isinstance(conn, MockConnection)
         assert conn.conn_id == "test_conn"
         assert conn.uri == sample_conn_uri
+
+
+class _LegacyConnValueBackend(BaseSecretsBackend):
+    """Backend overriding ``get_conn_value`` with the pre-3.2 ``(self, conn_id)`` signature."""
+
+    def __init__(self, conn_values: dict[str, str]):
+        self.conn_values = conn_values
+        self._set_connection_class(MockConnection)
+
+    def get_conn_value(self, conn_id: str) -> str | None:
+        return self.conn_values.get(conn_id)
+
+
+class _TeamAwareConnValueBackend(BaseSecretsBackend):
+    """Backend whose ``get_conn_value`` accepts ``team_name`` (3.2+ signature)."""
+
+    def __init__(self, conn_values: dict[str, str]):
+        self.conn_values = conn_values
+        self.received_team_name: str | None = None
+        self._set_connection_class(MockConnection)
+
+    def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
+        self.received_team_name = team_name
+        return self.conn_values.get(conn_id)
+
+
+class _KwargsConnValueBackend(BaseSecretsBackend):
+    """Backend whose ``get_conn_value`` swallows extra kwargs via ``**kwargs``."""
+
+    def __init__(self, conn_values: dict[str, str]):
+        self.conn_values = conn_values
+        self.received_kwargs: dict = {}
+        self._set_connection_class(MockConnection)
+
+    def get_conn_value(self, conn_id: str, **kwargs) -> str | None:
+        self.received_kwargs = kwargs
+        return self.conn_values.get(conn_id)
+
+
+class _TeamAwareRaisingBackend(BaseSecretsBackend):
+    """``get_conn_value`` declares ``team_name`` but its body raises ``TypeError``."""
+
+    def __init__(self):
+        self.call_count = 0
+        self._set_connection_class(MockConnection)
+
+    def get_conn_value(self, conn_id: str, team_name: str | None = None) -> str | None:
+        self.call_count += 1
+        raise TypeError("boom from inside the backend body")
+
+
+class TestTeamNameBackwardCompat:
+    """``get_connection`` must not forward ``team_name`` to overrides that predate it (issue #1333)."""
+
+    @pytest.mark.parametrize("team_name", [None, "team_a"])
+    def test_legacy_get_conn_value_signature_does_not_break(self, sample_conn_uri, team_name):
+        backend = _LegacyConnValueBackend(conn_values={"test_conn": sample_conn_uri})
+
+        conn = backend.get_connection(conn_id="test_conn", team_name=team_name)
+
+        assert isinstance(conn, MockConnection)
+        assert conn.conn_id == "test_conn"
+
+    def test_team_name_forwarded_when_override_accepts_it(self, sample_conn_uri):
+        backend = _TeamAwareConnValueBackend(conn_values={"test_conn": sample_conn_uri})
+
+        conn = backend.get_connection(conn_id="test_conn", team_name="team_a")
+
+        assert isinstance(conn, MockConnection)
+        assert backend.received_team_name == "team_a"
+
+    def test_team_name_forwarded_to_kwargs_override(self, sample_conn_uri):
+        backend = _KwargsConnValueBackend(conn_values={"test_conn": sample_conn_uri})
+
+        conn = backend.get_connection(conn_id="test_conn", team_name="team_a")
+
+        assert isinstance(conn, MockConnection)
+        assert backend.received_kwargs == {"team_name": "team_a"}
+
+    def test_team_aware_backend_typeerror_not_masked(self):
+        # A TypeError from inside a team_name-aware backend must propagate, not be
+        # retried without team_name (which would hide the error and could resolve the
+        # lookup against the global scope instead of the requested team).
+        backend = _TeamAwareRaisingBackend()
+
+        with pytest.raises(TypeError, match="boom from inside the backend body"):
+            backend.get_connection(conn_id="test_conn", team_name="team_a")
+
+        assert backend.call_count == 1
+
+    @pytest.mark.parametrize("team_name", [None, "team_a"])
+    def test_legacy_backend_missing_conn_returns_none(self, team_name):
+        backend = _LegacyConnValueBackend(conn_values={})
+
+        assert backend.get_connection(conn_id="missing", team_name=team_name) is None

Reply via email to