This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6c6a4a6a29d Add a backup implementation in AWS MwaaHook for calling the MWAA API (#47035)
6c6a4a6a29d is described below

commit 6c6a4a6a29d3c1c6d8ce3715f91bed12c679b284
Author: Ramit Kataria <[email protected]>
AuthorDate: Thu Mar 6 12:20:17 2025 -0800

    Add a backup implementation in AWS MwaaHook for calling the MWAA API (#47035)
    
    The existing implementation doesn't work when the user doesn't have
    `airflow:InvokeRestApi` permission in their IAM policy or when they make
    more than 10 transactions per second.
    
    This implementation mitigates those issues by using a session token
    approach. However, the existing boto `invoke_rest_api` call is still
    used by default because it is simpler.
    
    Some context here:
    https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html
---
 docs/spelling_wordlist.txt                         |   1 +
 .../src/airflow/providers/amazon/aws/hooks/mwaa.py |  94 +++++++--
 .../tests/unit/amazon/aws/hooks/test_mwaa.py       | 214 +++++++++++++++------
 3 files changed, 238 insertions(+), 71 deletions(-)

diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 87d1909eda6..ff16e3f01ca 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1863,6 +1863,7 @@ urls
 useHCatalog
 useLegacySQL
 useQueryCache
+userguide
 userId
 userpass
 usr
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
index d7f01238e6a..0f47f0bafb6 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
@@ -18,6 +18,7 @@
 
 from __future__ import annotations
 
+import requests
 from botocore.exceptions import ClientError
 
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
 
     Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa") <MWAA.Client>`
 
+    If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the hook will use a fallback method
+    that uses the AWS credential to generate a local web login token for the Airflow Web UI and then directly
+    make requests to the Airflow API. This fallback method can be set as the default (and only) method used by
+    setting `generate_local_token` to True.  Learn more here:
+    https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
     Additional arguments (such as ``aws_conn_id``) may be specified and
     are passed down to the underlying AwsBaseHook.
 
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
         method: str,
         body: dict | None = None,
         query_params: dict | None = None,
+        generate_local_token: bool = False,
     ) -> dict:
         """
         Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):
 
         :param env_name: name of the MWAA environment
         :param path: Apache Airflow REST API endpoint path to be called
-        :param method: HTTP method used for making Airflow REST API calls
+        :param method: HTTP method used for making Airflow REST API calls: 'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
         :param body: Request body for the Apache Airflow REST API call
         :param query_params: Query parameters to be included in the Apache Airflow REST API call
+        :param generate_local_token: If True, only the local web token method is used without trying boto's
+            `invoke_rest_api` first. If False, the local web token method is used as a fallback after trying
+            boto's `invoke_rest_api`
         """
-        body = body or {}
+        # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
+        body = {k: v for k, v in body.items() if v is not None} if body else {}
+        query_params = query_params or {}
         api_kwargs = {
             "Name": env_name,
             "Path": path,
             "Method": method,
-            # Filter out keys with None values because Airflow REST API doesn't accept requests otherwise
-            "Body": {k: v for k, v in body.items() if v is not None},
-            "QueryParameters": query_params if query_params else {},
+            "Body": body,
+            "QueryParameters": query_params,
         }
+
+        if generate_local_token:
+            return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
         try:
-            result = self.conn.invoke_rest_api(**api_kwargs)
+            response = self.conn.invoke_rest_api(**api_kwargs)
             # ResponseMetadata is removed because it contains data that is either very unlikely to be useful
             # in XComs and logs, or redundant given the data already included in the response
-            result.pop("ResponseMetadata", None)
-            return result
+            response.pop("ResponseMetadata", None)
+            return response
+
         except ClientError as e:
-            to_log = e.response
-            # ResponseMetadata and Error are removed because they contain data that is either very unlikely to
-            # be useful in XComs and logs, or redundant given the data already included in the response
-            to_log.pop("ResponseMetadata", None)
-            to_log.pop("Error", None)
-            self.log.error(to_log)
-            raise e
+            if (
+                e.response["Error"]["Code"] == "AccessDeniedException"
+                and "Airflow role" in e.response["Error"]["Message"]
+            ):
+                self.log.info(
+                    "Access Denied due to missing airflow:InvokeRestApi in IAM policy. Trying again by generating local token..."
+                )
+                return self._invoke_rest_api_using_local_session_token(**api_kwargs)
+            else:
+                to_log = e.response
+                # ResponseMetadata is removed because it contains data that is either very unlikely to be
+                # useful in XComs and logs, or redundant given the data already included in the response
+                to_log.pop("ResponseMetadata", None)
+                self.log.error(to_log)
+                raise
+
+    def _invoke_rest_api_using_local_session_token(
+        self,
+        **api_kwargs,
+    ) -> dict:
+        try:
+            session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+            response = session.request(
+                method=api_kwargs["Method"],
+                url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+                params=api_kwargs["QueryParameters"],
+                json=api_kwargs["Body"],
+                timeout=10,
+            )
+            response.raise_for_status()
+
+        except requests.HTTPError as e:
+            self.log.error(e.response.json())
+            raise
+
+        return {
+            "RestApiStatusCode": response.status_code,
+            "RestApiResponse": response.json(),
+        }
+
+    # Based on: https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+    def _get_session_conn(self, env_name: str) -> tuple:
+        create_token_response = self.conn.create_web_login_token(Name=env_name)
+        web_server_hostname = create_token_response["WebServerHostname"]
+        web_token = create_token_response["WebToken"]
+
+        login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+        login_payload = {"token": web_token}
+        session = requests.Session()
+        login_response = session.post(login_url, data=login_payload, timeout=10)
+        login_response.raise_for_status()
+
+        return session, web_server_hostname
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
index 461e3258912..d8046db33a8 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from unittest import mock
 
 import pytest
+import requests
 from botocore.exceptions import ClientError
 from moto import mock_aws
 
@@ -27,16 +28,161 @@ from airflow.providers.amazon.aws.hooks.mwaa import MwaaHook
 ENV_NAME = "test_env"
 PATH = "/dags/test_dag/dagRuns"
 METHOD = "POST"
+BODY: dict = {"conf": {}}
 QUERY_PARAMS = {"limit": 30}
+HOSTNAME = "example.com"
 
 
 class TestMwaaHook:
+    @pytest.fixture
+    def mock_conn(self):
+        with mock.patch.object(MwaaHook, "conn") as m:
+            yield m
+
     def setup_method(self):
         self.hook = MwaaHook()
 
-        # these example responses are included here instead of as a constant because the hook will mutate
-        # responses causing subsequent tests to fail
-        self.example_responses = {
+    def test_init(self):
+        assert self.hook.client_type == "mwaa"
+
+    @mock_aws
+    def test_get_conn(self):
+        assert self.hook.conn is not None
+
+    @pytest.mark.parametrize(
+        "body",
+        [
+            pytest.param(None, id="no_body"),
+            pytest.param(BODY, id="non_empty_body"),
+        ],
+    )
+    def test_invoke_rest_api_success(self, body, mock_conn, example_responses):
+        boto_invoke_mock = mock.MagicMock(return_value=example_responses["success"])
+        mock_conn.invoke_rest_api = boto_invoke_mock
+
+        retval = self.hook.invoke_rest_api(
+            env_name=ENV_NAME, path=PATH, method=METHOD, body=body, query_params=QUERY_PARAMS
+        )
+        kwargs_to_assert = {
+            "Name": ENV_NAME,
+            "Path": PATH,
+            "Method": METHOD,
+            "Body": body if body else {},
+            "QueryParameters": QUERY_PARAMS,
+        }
+        boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
+        mock_conn.create_web_login_token.assert_not_called()
+        assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"}
+
+    def test_invoke_rest_api_failure(self, mock_conn, example_responses):
+        error = ClientError(error_response=example_responses["failure"], operation_name="invoke_rest_api")
+        mock_conn.invoke_rest_api = mock.MagicMock(side_effect=error)
+        mock_error_log = mock.MagicMock()
+        self.hook.log.error = mock_error_log
+
+        with pytest.raises(ClientError) as caught_error:
+            self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD)
+
+        assert caught_error.value == error
+        mock_conn.create_web_login_token.assert_not_called()
+        expected_log = {k: v for k, v in example_responses["failure"].items() if k != "ResponseMetadata"}
+        mock_error_log.assert_called_once_with(expected_log)
+
+    @pytest.mark.parametrize("generate_local_token", [pytest.param(True), pytest.param(False)])
+    @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
+    def test_invoke_rest_api_local_token_parameter(
+        self, mock_create_session, generate_local_token, mock_conn
+    ):
+        self.hook.invoke_rest_api(
+            env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=generate_local_token
+        )
+        if generate_local_token:
+            mock_conn.invoke_rest_api.assert_not_called()
+            mock_conn.create_web_login_token.assert_called_once()
+            mock_create_session.assert_called_once()
+            mock_create_session.return_value.request.assert_called_once()
+        else:
+            mock_conn.invoke_rest_api.assert_called_once()
+
+    @mock.patch.object(MwaaHook, "_get_session_conn")
+    def test_invoke_rest_api_fallback_success_when_iam_fails(
+        self, mock_get_session_conn, mock_conn, example_responses
+    ):
+        boto_invoke_error = ClientError(
+            error_response=example_responses["missingIamRole"], operation_name="invoke_rest_api"
+        )
+        mock_conn.invoke_rest_api = mock.MagicMock(side_effect=boto_invoke_error)
+
+        kwargs_to_assert = {
+            "method": METHOD,
+            "url": f"https://{HOSTNAME}/api/v1{PATH}",
+            "params": QUERY_PARAMS,
+            "json": BODY,
+            "timeout": 10,
+        }
+
+        mock_response = mock.MagicMock()
+        mock_response.status_code = example_responses["success"]["RestApiStatusCode"]
+        mock_response.json.return_value = example_responses["success"]["RestApiResponse"]
+        mock_session = mock.MagicMock()
+        mock_session.request.return_value = mock_response
+
+        mock_get_session_conn.return_value = (mock_session, HOSTNAME)
+
+        retval = self.hook.invoke_rest_api(
+            env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY, query_params=QUERY_PARAMS
+        )
+
+        mock_session.request.assert_called_once_with(**kwargs_to_assert)
+        mock_response.raise_for_status.assert_called_once()
+        assert retval == {k: v for k, v in example_responses["success"].items() if k != "ResponseMetadata"}
+
+    @mock.patch.object(MwaaHook, "_get_session_conn")
+    def test_invoke_rest_api_using_local_session_token_failure(
+        self, mock_get_session_conn, example_responses
+    ):
+        mock_response = mock.MagicMock()
+        mock_response.json.return_value = example_responses["failure"]["RestApiResponse"]
+        error = requests.HTTPError(response=mock_response)
+        mock_response.raise_for_status.side_effect = error
+
+        mock_session = mock.MagicMock()
+        mock_session.request.return_value = mock_response
+
+        mock_get_session_conn.return_value = (mock_session, HOSTNAME)
+
+        mock_error_log = mock.MagicMock()
+        self.hook.log.error = mock_error_log
+
+        with pytest.raises(requests.HTTPError) as caught_error:
+            self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH, method=METHOD, generate_local_token=True)
+
+        assert caught_error.value == error
+        mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"])
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
+    def test_get_session_conn(self, mock_create_session, mock_conn):
+        token = "token"
+        mock_conn.create_web_login_token.return_value = {"WebServerHostname": HOSTNAME, "WebToken": token}
+        login_url = f"https://{HOSTNAME}/aws_mwaa/login"
+        login_payload = {"token": token}
+
+        mock_session = mock.MagicMock()
+        mock_create_session.return_value = mock_session
+
+        retval = self.hook._get_session_conn(env_name=ENV_NAME)
+
+        mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME)
+        mock_create_session.assert_called_once_with()
+        mock_session.post.assert_called_once_with(login_url, data=login_payload, timeout=10)
+        mock_session.post.return_value.raise_for_status.assert_called_once()
+
+        assert retval == (mock_session, HOSTNAME)
+
+    @pytest.fixture
+    def example_responses(self):
+        """Fixture for test responses to avoid mutation between tests."""
+        return {
             "success": {
                 "ResponseMetadata": {
                     "RequestId": "some ID",
@@ -73,57 +219,13 @@ class TestMwaaHook:
                     "type": "https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound",
                 },
             },
+            "missingIamRole": {
+                "Error": {"Message": "No Airflow role granted in IAM.", "Code": "AccessDeniedException"},
+                "ResponseMetadata": {
+                    "RequestId": "some ID",
+                    "HTTPStatusCode": 403,
+                    "HTTPHeaders": {"header1": "value1"},
+                    "RetryAttempts": 0,
+                },
+            },
         }
-
-    def test_init(self):
-        assert self.hook.client_type == "mwaa"
-
-    @mock_aws
-    def test_get_conn(self):
-        assert self.hook.conn is not None
-
-    @pytest.mark.parametrize(
-        "body",
-        [
-            pytest.param(None, id="no_body"),
-            pytest.param({"conf": {}}, id="non_empty_body"),
-        ],
-    )
-    @mock.patch.object(MwaaHook, "conn")
-    def test_invoke_rest_api_success(self, mock_conn, body) -> None:
-        boto_invoke_mock = mock.MagicMock(return_value=self.example_responses["success"])
-        mock_conn.invoke_rest_api = boto_invoke_mock
-
-        retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body, QUERY_PARAMS)
-        kwargs_to_assert = {
-            "Name": ENV_NAME,
-            "Path": PATH,
-            "Method": METHOD,
-            "Body": body if body else {},
-            "QueryParameters": QUERY_PARAMS,
-        }
-        boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
-        assert retval == {
-            k: v for k, v in self.example_responses["success"].items() if k != "ResponseMetadata"
-        }
-
-    @mock.patch.object(MwaaHook, "conn")
-    def test_invoke_rest_api_failure(self, mock_conn) -> None:
-        error = ClientError(
-            error_response=self.example_responses["failure"], operation_name="invoke_rest_api"
-        )
-        boto_invoke_mock = mock.MagicMock(side_effect=error)
-        mock_conn.invoke_rest_api = boto_invoke_mock
-        mock_log = mock.MagicMock()
-        self.hook.log.error = mock_log
-
-        with pytest.raises(ClientError) as caught_error:
-            self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD)
-
-        assert caught_error.value == error
-        expected_log = {
-            k: v
-            for k, v in self.example_responses["failure"].items()
-            if k != "ResponseMetadata" and k != "Error"
-        }
-        mock_log.assert_called_once_with(expected_log)

Reply via email to