This is an automated email from the ASF dual-hosted git repository.
vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 97b3949f0ac Implement `batch_is_authorized_` methods in AWS auth manager (#55307)
97b3949f0ac is described below
commit 97b3949f0ac582685976063f4292438e705d815d
Author: Vincent <[email protected]>
AuthorDate: Mon Sep 8 14:12:46 2025 -0400
Implement `batch_is_authorized_` methods in AWS auth manager (#55307)
---
.../amazon/aws/auth_manager/aws_auth_manager.py | 95 ++++++++++++++++---
.../aws/auth_manager/test_aws_auth_manager.py | 102 +++++++++++++++++++++
2 files changed, 185 insertions(+), 12 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
index 387d968ec15..6f3e5851c2f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/auth_manager/aws_auth_manager.py
@@ -44,7 +44,10 @@ from airflow.providers.amazon.version_compat import AIRFLOW_V_3_0_PLUS
 if TYPE_CHECKING:
     from airflow.api_fastapi.auth.managers.base_auth_manager import ResourceMethod
     from airflow.api_fastapi.auth.managers.models.batch_apis import (
+        IsAuthorizedConnectionRequest,
         IsAuthorizedDagRequest,
+        IsAuthorizedPoolRequest,
+        IsAuthorizedVariableRequest,
     )
     from airflow.api_fastapi.auth.managers.models.resource_details import (
         AccessView,
@@ -244,6 +247,27 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
 
         return [menu_item for menu_item in menu_items if _has_access_to_menu_item(requests[menu_item.value])]
 
+    def batch_is_authorized_connection(
+        self,
+        requests: Sequence[IsAuthorizedConnectionRequest],
+        *,
+        user: AwsAuthManagerUser,
+    ) -> bool:
+        facade_requests: Sequence[IsAuthorizedRequest] = [
+            cast(
+                "IsAuthorizedRequest",
+                {
+                    "method": request["method"],
+                    "entity_type": AvpEntities.CONNECTION,
+                    "entity_id": cast("ConnectionDetails", request["details"]).conn_id
+                    if request.get("details")
+                    else None,
+                },
+            )
+            for request in requests
+        ]
+        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
+
     def batch_is_authorized_dag(
         self,
         requests: Sequence[IsAuthorizedDagRequest],
@@ -251,18 +275,65 @@ class AwsAuthManager(BaseAuthManager[AwsAuthManagerUser]):
         user: AwsAuthManagerUser,
     ) -> bool:
         facade_requests: Sequence[IsAuthorizedRequest] = [
-            {
-                "method": request["method"],
-                "entity_type": AvpEntities.DAG,
-                "entity_id": cast("DagDetails", request["details"]).id if request.get("details") else None,
-                "context": {
-                    "dag_entity": {
-                        "string": cast("DagAccessEntity", request["access_entity"]).value,
-                    },
-                }
-                if request.get("access_entity")
-                else None,
-            }
+            cast(
+                "IsAuthorizedRequest",
+                {
+                    "method": request["method"],
+                    "entity_type": AvpEntities.DAG,
+                    "entity_id": cast("DagDetails", request["details"]).id
+                    if request.get("details")
+                    else None,
+                    "context": {
+                        "dag_entity": {
+                            "string": cast("DagAccessEntity", request["access_entity"]).value,
+                        },
+                    }
+                    if request.get("access_entity")
+                    else None,
+                },
+            )
+            for request in requests
+        ]
+        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
+
+    def batch_is_authorized_pool(
+        self,
+        requests: Sequence[IsAuthorizedPoolRequest],
+        *,
+        user: AwsAuthManagerUser,
+    ) -> bool:
+        facade_requests: Sequence[IsAuthorizedRequest] = [
+            cast(
+                "IsAuthorizedRequest",
+                {
+                    "method": request["method"],
+                    "entity_type": AvpEntities.POOL,
+                    "entity_id": cast("PoolDetails", request["details"]).name
+                    if request.get("details")
+                    else None,
+                },
+            )
+            for request in requests
+        ]
+        return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
+
+    def batch_is_authorized_variable(
+        self,
+        requests: Sequence[IsAuthorizedVariableRequest],
+        *,
+        user: AwsAuthManagerUser,
+    ) -> bool:
+        facade_requests: Sequence[IsAuthorizedRequest] = [
+            cast(
+                "IsAuthorizedRequest",
+                {
+                    "method": request["method"],
+                    "entity_type": AvpEntities.VARIABLE,
+                    "entity_id": cast("VariableDetails", request["details"]).key
+                    if request.get("details")
+                    else None,
+                },
+            )
             for request in requests
         ]
         return self.avp_facade.batch_is_authorized(requests=facade_requests, user=user)
diff --git a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
index eab473ea9a5..70d5a31986e 100644
--- a/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
+++ b/providers/amazon/tests/unit/amazon/aws/auth_manager/test_aws_auth_manager.py
@@ -439,6 +439,40 @@ class TestAwsAuthManager:
         )
         assert result == [MenuItem.VARIABLES, MenuItem.DAGS]
 
+    @patch.object(AwsAuthManager, "avp_facade")
+    def test_batch_is_authorized_connection(
+        self,
+        mock_avp_facade,
+        auth_manager,
+    ):
+        batch_is_authorized = Mock(return_value=True)
+        mock_avp_facade.batch_is_authorized = batch_is_authorized
+
+        result = auth_manager.batch_is_authorized_connection(
+            requests=[
+                {"method": "GET"},
+                {"method": "PUT", "details": ConnectionDetails(conn_id="test")},
+            ],
+            user=mock,
+        )
+
+        batch_is_authorized.assert_called_once_with(
+            requests=[
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.CONNECTION,
+                    "entity_id": None,
+                },
+                {
+                    "method": "PUT",
+                    "entity_type": AvpEntities.CONNECTION,
+                    "entity_id": "test",
+                },
+            ],
+            user=ANY,
+        )
+        assert result
+
     @patch.object(AwsAuthManager, "avp_facade")
     def test_batch_is_authorized_dag(
         self,
@@ -510,6 +544,74 @@ class TestAwsAuthManager:
         )
         assert result
 
+    @patch.object(AwsAuthManager, "avp_facade")
+    def test_batch_is_authorized_pool(
+        self,
+        mock_avp_facade,
+        auth_manager,
+    ):
+        batch_is_authorized = Mock(return_value=True)
+        mock_avp_facade.batch_is_authorized = batch_is_authorized
+
+        result = auth_manager.batch_is_authorized_pool(
+            requests=[
+                {"method": "GET"},
+                {"method": "PUT", "details": PoolDetails(name="test")},
+            ],
+            user=mock,
+        )
+
+        batch_is_authorized.assert_called_once_with(
+            requests=[
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.POOL,
+                    "entity_id": None,
+                },
+                {
+                    "method": "PUT",
+                    "entity_type": AvpEntities.POOL,
+                    "entity_id": "test",
+                },
+            ],
+            user=ANY,
+        )
+        assert result
+
+    @patch.object(AwsAuthManager, "avp_facade")
+    def test_batch_is_authorized_variable(
+        self,
+        mock_avp_facade,
+        auth_manager,
+    ):
+        batch_is_authorized = Mock(return_value=True)
+        mock_avp_facade.batch_is_authorized = batch_is_authorized
+
+        result = auth_manager.batch_is_authorized_variable(
+            requests=[
+                {"method": "GET"},
+                {"method": "PUT", "details": VariableDetails(key="test")},
+            ],
+            user=mock,
+        )
+
+        batch_is_authorized.assert_called_once_with(
+            requests=[
+                {
+                    "method": "GET",
+                    "entity_type": AvpEntities.VARIABLE,
+                    "entity_id": None,
+                },
+                {
+                    "method": "PUT",
+                    "entity_type": AvpEntities.VARIABLE,
+                    "entity_id": "test",
+                },
+            ],
+            user=ANY,
+        )
+        assert result
+
     @pytest.mark.parametrize(
         "method, user, expected_result",
         [