This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e9a895cee53 Pass team name to `is_authorized_connection`,
`is_authorized_variable` and `is_authorized_pool` in Airflow API (#55193)
e9a895cee53 is described below
commit e9a895cee533dd91556f84baba06f60bbd75ece6
Author: Vincent <[email protected]>
AuthorDate: Thu Sep 4 08:54:31 2025 -0400
Pass team name to `is_authorized_connection`, `is_authorized_variable` and
`is_authorized_pool` in Airflow API (#55193)
---
.../auth/managers/models/resource_details.py | 3 +
.../src/airflow/api_fastapi/core_api/security.py | 12 ++-
airflow-core/src/airflow/models/connection.py | 16 +++-
airflow-core/src/airflow/models/pool.py | 7 ++
airflow-core/src/airflow/models/variable.py | 14 +++-
.../unit/api_fastapi/core_api/test_security.py | 92 +++++++++++++++++++++-
airflow-core/tests/unit/models/test_connection.py | 18 +++++
airflow-core/tests/unit/models/test_dag.py | 14 ----
airflow-core/tests/unit/models/test_pool.py | 13 +++
airflow-core/tests/unit/models/test_variable.py | 12 +++
devel-common/src/tests_common/pytest_plugin.py | 18 +++++
11 files changed, 197 insertions(+), 22 deletions(-)
diff --git a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
index aed68f27aa8..afce15e6b0a 100644
--- a/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
+++ b/airflow-core/src/airflow/api_fastapi/auth/managers/models/resource_details.py
@@ -35,6 +35,7 @@ class ConnectionDetails:
"""Represents the details of a connection."""
conn_id: str | None = None
+ team_name: str | None = None
@dataclass
@@ -71,6 +72,7 @@ class PoolDetails:
"""Represents the details of a pool."""
name: str | None = None
+ team_name: str | None = None
@dataclass
@@ -78,6 +80,7 @@ class VariableDetails:
"""Represents the details of a variable."""
key: str | None = None
+ team_name: str | None = None
class AccessView(Enum):
diff --git a/airflow-core/src/airflow/api_fastapi/core_api/security.py b/airflow-core/src/airflow/api_fastapi/core_api/security.py
index 900ad3fbb6a..6053900ce33 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/security.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/security.py
@@ -42,6 +42,7 @@ from airflow.api_fastapi.auth.managers.models.resource_details import (
)
from airflow.api_fastapi.core_api.base import OrmClause
from airflow.configuration import conf
+from airflow.models import Connection, Pool, Variable
from airflow.models.dag import DagModel, DagRun, DagTag
from airflow.models.dagwarning import DagWarning
from airflow.models.taskinstance import TaskInstance as TI
@@ -223,10 +224,11 @@ def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser]
user: GetUserDep,
) -> None:
pool_name = request.path_params.get("pool_name")
+ team_name = Pool.get_team_name(pool_name) if pool_name else None
_requires_access(
is_authorized_callback=lambda:
get_auth_manager().is_authorized_pool(
- method=method, details=PoolDetails(name=pool_name), user=user
+ method=method, details=PoolDetails(name=pool_name,
team_name=team_name), user=user
)
)
@@ -239,10 +241,13 @@ def requires_access_connection(method: ResourceMethod) -> Callable[[Request, Bas
user: GetUserDep,
) -> None:
connection_id = request.path_params.get("connection_id")
+ team_name = Connection.get_team_name(connection_id) if connection_id
else None
_requires_access(
is_authorized_callback=lambda:
get_auth_manager().is_authorized_connection(
- method=method,
details=ConnectionDetails(conn_id=connection_id), user=user
+ method=method,
+ details=ConnectionDetails(conn_id=connection_id,
team_name=team_name),
+ user=user,
)
)
@@ -273,10 +278,11 @@ def requires_access_variable(method: ResourceMethod) -> Callable[[Request, BaseU
user: GetUserDep,
) -> None:
variable_key: str | None = request.path_params.get("variable_key")
+ team_name = Variable.get_team_name(variable_key) if variable_key else
None
_requires_access(
is_authorized_callback=lambda:
get_auth_manager().is_authorized_variable(
- method=method, details=VariableDetails(key=variable_key),
user=user
+ method=method, details=VariableDetails(key=variable_key,
team_name=team_name), user=user
),
)
diff --git a/airflow-core/src/airflow/models/connection.py b/airflow-core/src/airflow/models/connection.py
index 33d292a645a..1567cab1dd2 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -27,7 +27,7 @@ from json import JSONDecodeError
from typing import Any
from urllib.parse import parse_qsl, quote, unquote, urlencode, urlsplit
-from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text
+from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, Text,
select
from sqlalchemy.orm import declared_attr, reconstructor, synonym
from sqlalchemy_utils import UUIDType
@@ -36,10 +36,12 @@ from airflow.configuration import ensure_secrets_loaded
from airflow.exceptions import AirflowException, AirflowNotFoundException
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
+from airflow.models.team import Team
from airflow.sdk import SecretCache
from airflow.utils.helpers import prune_dict
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.module_loading import import_string
+from airflow.utils.session import NEW_SESSION, provide_session
log = logging.getLogger(__name__)
# sanitize the `conn_id` pattern by allowing alphanumeric characters plus
@@ -150,6 +152,7 @@ class Connection(Base, LoggingMixin):
port: int | None = None,
extra: str | dict | None = None,
uri: str | None = None,
+ team_id: str | None = None,
):
super().__init__()
self.conn_id = sanitize_conn_id(conn_id)
@@ -178,6 +181,7 @@ class Connection(Base, LoggingMixin):
if self.password:
mask_secret(self.password)
mask_secret(quote(self.password))
+ self.team_id = team_id
@staticmethod
def _validate_extra(extra, conn_id) -> None:
@@ -584,3 +588,13 @@ class Connection(Base, LoggingMixin):
conn_repr = self.to_dict(prune_empty=True, validate=False)
conn_repr.pop("conn_id", None)
return json.dumps(conn_repr)
+
+ @staticmethod
+ @provide_session
+ def get_team_name(connection_id: str, session=NEW_SESSION) -> str | None:
+ stmt = (
+ select(Team.name)
+ .join(Connection, Team.id == Connection.team_id)
+ .where(Connection.conn_id == connection_id)
+ )
+ return session.scalar(stmt)
diff --git a/airflow-core/src/airflow/models/pool.py b/airflow-core/src/airflow/models/pool.py
index 2a8b9b57c3f..7a205036530 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -24,6 +24,7 @@ from sqlalchemy_utils import UUIDType
from airflow.exceptions import AirflowException, PoolNotFound
from airflow.models.base import Base
+from airflow.models.team import Team
from airflow.ti_deps.dependencies_states import EXECUTION_STATES
from airflow.utils.db import exists_query
from airflow.utils.session import NEW_SESSION, provide_session
@@ -352,3 +353,9 @@ class Pool(Base):
if self.slots == -1:
return float("inf")
return self.slots - self.occupied_slots(session)
+
+ @staticmethod
+ @provide_session
+ def get_team_name(pool_name: str, session=NEW_SESSION) -> str | None:
+ stmt = select(Team.name).join(Pool, Team.id ==
Pool.team_id).where(Pool.pool == pool_name)
+ return session.scalar(stmt)
diff --git a/airflow-core/src/airflow/models/variable.py b/airflow-core/src/airflow/models/variable.py
index 6e9a7537821..62bb268e89d 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -33,10 +33,11 @@ from airflow._shared.secrets_masker import mask_secret
from airflow.configuration import ensure_secrets_loaded
from airflow.models.base import ID_LEN, Base
from airflow.models.crypto import get_fernet
+from airflow.models.team import Team
from airflow.sdk import SecretCache
from airflow.secrets.metastore import MetastoreBackend
from airflow.utils.log.logging_mixin import LoggingMixin
-from airflow.utils.session import create_session
+from airflow.utils.session import NEW_SESSION, create_session, provide_session
if TYPE_CHECKING:
from sqlalchemy.orm import Session
@@ -57,11 +58,12 @@ class Variable(Base, LoggingMixin):
is_encrypted = Column(Boolean, unique=False, default=False)
team_id = Column(UUIDType(binary=False), ForeignKey("team.id"),
nullable=True)
- def __init__(self, key=None, val=None, description=None):
+ def __init__(self, key=None, val=None, description=None, team_id=None):
super().__init__()
self.key = key
self.val = val
self.description = description
+ self.team_id = team_id
@reconstructor
def on_db_load(self):
@@ -452,3 +454,11 @@ class Variable(Base, LoggingMixin):
SecretCache.save_variable(key, var_val) # we save None as well
return var_val
+
+ @staticmethod
+ @provide_session
+ def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
+ stmt = (
+ select(Team.name).join(Variable, Team.id ==
Variable.team_id).where(Variable.key == variable_key)
+ )
+ return session.scalar(stmt)
diff --git a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
index f407fbc7f8f..20cf233c1c5 100644
--- a/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
+++ b/airflow-core/tests/unit/api_fastapi/core_api/test_security.py
@@ -23,9 +23,22 @@ from fastapi import HTTPException
from jwt import ExpiredSignatureError, InvalidTokenError
from airflow.api_fastapi.app import create_app
-from airflow.api_fastapi.auth.managers.models.resource_details import
DagAccessEntity
+from airflow.api_fastapi.auth.managers.models.resource_details import (
+ ConnectionDetails,
+ DagAccessEntity,
+ PoolDetails,
+ VariableDetails,
+)
from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
-from airflow.api_fastapi.core_api.security import is_safe_url,
requires_access_dag, resolve_user_from_token
+from airflow.api_fastapi.core_api.security import (
+ is_safe_url,
+ requires_access_connection,
+ requires_access_dag,
+ requires_access_pool,
+ requires_access_variable,
+ resolve_user_from_token,
+)
+from airflow.models import Connection, Pool, Variable
from tests_common.test_utils.config import conf_vars
@@ -141,3 +154,78 @@ class TestFastApiSecurity:
request = Mock()
request.base_url = "https://requesting_server_base_url.com/prefix2"
assert is_safe_url(url, request=request) == expected_is_safe
+
+ @pytest.mark.db_test
+ @pytest.mark.parametrize(
+ "team_name",
+ [None, "team1"],
+ )
+ @patch.object(Connection, "get_team_name")
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ async def test_requires_access_connection(self, mock_get_auth_manager,
mock_get_team_name, team_name):
+ auth_manager = Mock()
+ auth_manager.is_authorized_connection.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+ fastapi_request = Mock()
+ fastapi_request.path_params = {"connection_id": "conn_id"}
+ mock_get_team_name.return_value = team_name
+ user = Mock()
+
+ requires_access_connection("GET")(fastapi_request, user)
+
+ auth_manager.is_authorized_connection.assert_called_once_with(
+ method="GET",
+ details=ConnectionDetails(conn_id="conn_id", team_name=team_name),
+ user=user,
+ )
+ mock_get_team_name.assert_called_once_with("conn_id")
+
+ @pytest.mark.db_test
+ @pytest.mark.parametrize(
+ "team_name",
+ [None, "team1"],
+ )
+ @patch.object(Variable, "get_team_name")
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ async def test_requires_access_variable(self, mock_get_auth_manager,
mock_get_team_name, team_name):
+ auth_manager = Mock()
+ auth_manager.is_authorized_variable.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+ fastapi_request = Mock()
+ fastapi_request.path_params = {"variable_key": "var_key"}
+ mock_get_team_name.return_value = team_name
+ user = Mock()
+
+ requires_access_variable("GET")(fastapi_request, user)
+
+ auth_manager.is_authorized_variable.assert_called_once_with(
+ method="GET",
+ details=VariableDetails(key="var_key", team_name=team_name),
+ user=user,
+ )
+ mock_get_team_name.assert_called_once_with("var_key")
+
+ @pytest.mark.db_test
+ @pytest.mark.parametrize(
+ "team_name",
+ [None, "team1"],
+ )
+ @patch.object(Pool, "get_team_name")
+ @patch("airflow.api_fastapi.core_api.security.get_auth_manager")
+ async def test_requires_access_pool(self, mock_get_auth_manager,
mock_get_team_name, team_name):
+ auth_manager = Mock()
+ auth_manager.is_authorized_pool.return_value = True
+ mock_get_auth_manager.return_value = auth_manager
+ fastapi_request = Mock()
+ fastapi_request.path_params = {"pool_name": "pool"}
+ mock_get_team_name.return_value = team_name
+ user = Mock()
+
+ requires_access_pool("GET")(fastapi_request, user)
+
+ auth_manager.is_authorized_pool.assert_called_once_with(
+ method="GET",
+ details=PoolDetails(name="pool", team_name=team_name),
+ user=user,
+ )
+ mock_get_team_name.assert_called_once_with("pool")
diff --git a/airflow-core/tests/unit/models/test_connection.py b/airflow-core/tests/unit/models/test_connection.py
index 62c9377e2b6..6fdea1b2d05 100644
--- a/airflow-core/tests/unit/models/test_connection.py
+++ b/airflow-core/tests/unit/models/test_connection.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import re
import sys
+from typing import TYPE_CHECKING
from unittest import mock
import pytest
@@ -28,6 +29,12 @@ from airflow.models import Connection
from airflow.sdk.exceptions import AirflowRuntimeError, ErrorType
from airflow.sdk.execution_time.comms import ErrorResponse
+from tests_common.test_utils.db import clear_db_connections
+
+if TYPE_CHECKING:
+ from airflow.models.team import Team
+ from airflow.settings import Session
+
class TestConnection:
@pytest.mark.parametrize(
@@ -355,3 +362,14 @@ class TestConnection:
# Verify the backends were called
mock_env_backend.assert_called_once_with(conn_id="test_conn")
mock_db_backend.assert_called_once_with(conn_id="test_conn")
+
+ @pytest.mark.db_test
+ def test_get_team_name(self, testing_team: Team, session: Session):
+ clear_db_connections()
+
+ connection = Connection(conn_id="test_conn", conn_type="test_type",
team_id=testing_team.id)
+ session.add(connection)
+ session.flush()
+
+ assert Connection.get_team_name("test_conn", session=session) ==
"testing"
+ clear_db_connections()
diff --git a/airflow-core/tests/unit/models/test_dag.py b/airflow-core/tests/unit/models/test_dag.py
index cf3814fa9fd..796ed807df8 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -21,7 +21,6 @@ import datetime
import logging
import os
import pickle
-import uuid
from datetime import timedelta
from pathlib import Path
from typing import TYPE_CHECKING
@@ -58,7 +57,6 @@ from airflow.models.dagbundle import DagBundleModel
from airflow.models.dagrun import DagRun
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskinstance import TaskInstance as TI
-from airflow.models.team import Team
from airflow.providers.standard.operators.bash import BashOperator
from airflow.providers.standard.operators.empty import EmptyOperator
from airflow.providers.standard.operators.python import PythonOperator
@@ -143,18 +141,6 @@ def test_dags_bundle(configure_testing_dag_bundle):
yield
-@pytest.fixture
-def testing_team():
- from airflow.utils.session import create_session
-
- with create_session() as session:
- team = session.query(Team).filter_by(name="testing").one_or_none()
- if not team:
- team = Team(id=uuid.uuid4(), name="testing")
- session.add(team)
- yield team
-
-
def _create_dagrun(
dag: DAG,
*,
diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py
index 81d4ce66dc7..5f53474ae24 100644
--- a/airflow-core/tests/unit/models/test_pool.py
+++ b/airflow-core/tests/unit/models/test_pool.py
@@ -17,6 +17,8 @@
# under the License.
from __future__ import annotations
+from typing import TYPE_CHECKING
+
import pytest
from airflow import settings
@@ -36,6 +38,10 @@ from tests_common.test_utils.db import (
set_default_pool_slots,
)
+if TYPE_CHECKING:
+ from airflow.models.team import Team
+ from airflow.settings import Session
+
pytestmark = pytest.mark.db_test
DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -319,3 +325,10 @@ class TestPool:
default_pool = Pool.get_default_pool()
assert not Pool.is_default_pool(id=pool.id)
assert Pool.is_default_pool(str(default_pool.id))
+
+ def test_get_team_name(self, testing_team: Team, session: Session):
+ pool = Pool(pool="test", include_deferred=False,
team_id=testing_team.id)
+ session.add(pool)
+ session.flush()
+
+ assert Pool.get_team_name("test", session=session) == "testing"
diff --git a/airflow-core/tests/unit/models/test_variable.py b/airflow-core/tests/unit/models/test_variable.py
index d7a035bbb8d..91fb045e2c5 100644
--- a/airflow-core/tests/unit/models/test_variable.py
+++ b/airflow-core/tests/unit/models/test_variable.py
@@ -19,6 +19,7 @@ from __future__ import annotations
import logging
import os
+from typing import TYPE_CHECKING
from unittest import mock
import pytest
@@ -31,6 +32,10 @@ from airflow.secrets.metastore import MetastoreBackend
from tests_common.test_utils import db
from tests_common.test_utils.config import conf_vars
+if TYPE_CHECKING:
+ from airflow.models.team import Team
+ from airflow.settings import Session
+
pytestmark = pytest.mark.db_test
@@ -311,6 +316,13 @@ class TestVariable:
assert c != b
+ def test_get_team_name(self, testing_team: Team, session: Session):
+ var = Variable(key="key", val="value", team_id=testing_team.id)
+ session.add(var)
+ session.flush()
+
+ assert Variable.get_team_name("key", session=session) == "testing"
+
@pytest.mark.parametrize(
"variable_value, deserialize_json, expected_masked_values",
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index ce588352d4e..9f05f0ec8b0 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -25,6 +25,7 @@ import platform
import re
import subprocess
import sys
+import uuid
import warnings
from collections.abc import Callable, Generator
from contextlib import ExitStack, suppress
@@ -2672,6 +2673,23 @@ def testing_dag_bundle():
session.add(testing)
+@pytest.fixture
+def testing_team():
+ from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS
+
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.models.team import Team
+ from airflow.utils.session import create_session
+
+ with create_session() as session:
+ team = session.query(Team).filter_by(name="testing").one_or_none()
+ if not team:
+ team = Team(id=uuid.uuid4(), name="testing")
+ session.add(team)
+ session.flush()
+ yield team
+
+
@pytest.fixture
def create_connection_without_db(monkeypatch):
"""