This is an automated email from the ASF dual-hosted git repository.
onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6c6a4a6a29d Add a backup implementation in AWS MwaaHook for calling
the MWAA API (#47035)
6c6a4a6a29d is described below
commit 6c6a4a6a29d3c1c6d8ce3715f91bed12c679b284
Author: Ramit Kataria <[email protected]>
AuthorDate: Thu Mar 6 12:20:17 2025 -0800
Add a backup implementation in AWS MwaaHook for calling the MWAA API
(#47035)
The existing implementation doesn't work when the user doesn't have the
`airflow:InvokeRestApi` permission in their IAM policy or when they make
more than 10 transactions per second.
The new backup implementation mitigates those issues by using a session
token approach. However, the existing implementation is still used by
default because it is simpler.
Some context here:
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html
---
docs/spelling_wordlist.txt | 1 +
.../src/airflow/providers/amazon/aws/hooks/mwaa.py | 94 +++++++--
.../tests/unit/amazon/aws/hooks/test_mwaa.py | 214 +++++++++++++++------
3 files changed, 238 insertions(+), 71 deletions(-)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 87d1909eda6..ff16e3f01ca 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1863,6 +1863,7 @@ urls
useHCatalog
useLegacySQL
useQueryCache
+userguide
userId
userpass
usr
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
index d7f01238e6a..0f47f0bafb6 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/mwaa.py
@@ -18,6 +18,7 @@
from __future__ import annotations
+import requests
from botocore.exceptions import ClientError
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -29,6 +30,12 @@ class MwaaHook(AwsBaseHook):
Provide thin wrapper around :external+boto3:py:class:`boto3.client("mwaa")
<MWAA.Client>`
+ If your IAM policy doesn't have `airflow:InvokeRestApi` permission, the
hook will use a fallback method
+ that uses the AWS credential to generate a local web login token for the
Airflow Web UI and then directly
+ make requests to the Airflow API. This fallback method can be set as the
default (and only) method used by
+ setting `generate_local_token` to True. Learn more here:
+
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#granting-access-MWAA-Enhanced-REST-API
+
Additional arguments (such as ``aws_conn_id``) may be specified and
are passed down to the underlying AwsBaseHook.
@@ -47,6 +54,7 @@ class MwaaHook(AwsBaseHook):
method: str,
body: dict | None = None,
query_params: dict | None = None,
+ generate_local_token: bool = False,
) -> dict:
"""
Invoke the REST API on the Airflow webserver with the specified inputs.
@@ -56,30 +64,86 @@ class MwaaHook(AwsBaseHook):
:param env_name: name of the MWAA environment
:param path: Apache Airflow REST API endpoint path to be called
- :param method: HTTP method used for making Airflow REST API calls
+ :param method: HTTP method used for making Airflow REST API calls:
'GET'|'PUT'|'POST'|'PATCH'|'DELETE'
:param body: Request body for the Apache Airflow REST API call
:param query_params: Query parameters to be included in the Apache
Airflow REST API call
+ :param generate_local_token: If True, only the local web token method
is used without trying boto's
+ `invoke_rest_api` first. If False, the local web token method is
used as a fallback after trying
+ boto's `invoke_rest_api`
"""
- body = body or {}
+ # Filter out keys with None values because Airflow REST API doesn't
accept requests otherwise
+ body = {k: v for k, v in body.items() if v is not None} if body else {}
+ query_params = query_params or {}
api_kwargs = {
"Name": env_name,
"Path": path,
"Method": method,
- # Filter out keys with None values because Airflow REST API
doesn't accept requests otherwise
- "Body": {k: v for k, v in body.items() if v is not None},
- "QueryParameters": query_params if query_params else {},
+ "Body": body,
+ "QueryParameters": query_params,
}
+
+ if generate_local_token:
+ return
self._invoke_rest_api_using_local_session_token(**api_kwargs)
+
try:
- result = self.conn.invoke_rest_api(**api_kwargs)
+ response = self.conn.invoke_rest_api(**api_kwargs)
# ResponseMetadata is removed because it contains data that is
either very unlikely to be useful
# in XComs and logs, or redundant given the data already included
in the response
- result.pop("ResponseMetadata", None)
- return result
+ response.pop("ResponseMetadata", None)
+ return response
+
except ClientError as e:
- to_log = e.response
- # ResponseMetadata and Error are removed because they contain data
that is either very unlikely to
- # be useful in XComs and logs, or redundant given the data already
included in the response
- to_log.pop("ResponseMetadata", None)
- to_log.pop("Error", None)
- self.log.error(to_log)
- raise e
+ if (
+ e.response["Error"]["Code"] == "AccessDeniedException"
+ and "Airflow role" in e.response["Error"]["Message"]
+ ):
+ self.log.info(
+ "Access Denied due to missing airflow:InvokeRestApi in IAM
policy. Trying again by generating local token..."
+ )
+ return
self._invoke_rest_api_using_local_session_token(**api_kwargs)
+ else:
+ to_log = e.response
+ # ResponseMetadata is removed because it contains data that is
either very unlikely to be
+ # useful in XComs and logs, or redundant given the data
already included in the response
+ to_log.pop("ResponseMetadata", None)
+ self.log.error(to_log)
+ raise
+
+ def _invoke_rest_api_using_local_session_token(
+ self,
+ **api_kwargs,
+ ) -> dict:
+ try:
+ session, hostname = self._get_session_conn(api_kwargs["Name"])
+
+ response = session.request(
+ method=api_kwargs["Method"],
+ url=f"https://{hostname}/api/v1{api_kwargs['Path']}",
+ params=api_kwargs["QueryParameters"],
+ json=api_kwargs["Body"],
+ timeout=10,
+ )
+ response.raise_for_status()
+
+ except requests.HTTPError as e:
+ self.log.error(e.response.json())
+ raise
+
+ return {
+ "RestApiStatusCode": response.status_code,
+ "RestApiResponse": response.json(),
+ }
+
+ # Based on:
https://docs.aws.amazon.com/mwaa/latest/userguide/access-mwaa-apache-airflow-rest-api.html#create-web-server-session-token
+ def _get_session_conn(self, env_name: str) -> tuple:
+ create_token_response = self.conn.create_web_login_token(Name=env_name)
+ web_server_hostname = create_token_response["WebServerHostname"]
+ web_token = create_token_response["WebToken"]
+
+ login_url = f"https://{web_server_hostname}/aws_mwaa/login"
+ login_payload = {"token": web_token}
+ session = requests.Session()
+ login_response = session.post(login_url, data=login_payload,
timeout=10)
+ login_response.raise_for_status()
+
+ return session, web_server_hostname
diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
index 461e3258912..d8046db33a8 100644
--- a/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
+++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_mwaa.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from unittest import mock
import pytest
+import requests
from botocore.exceptions import ClientError
from moto import mock_aws
@@ -27,16 +28,161 @@ from airflow.providers.amazon.aws.hooks.mwaa import
MwaaHook
ENV_NAME = "test_env"
PATH = "/dags/test_dag/dagRuns"
METHOD = "POST"
+BODY: dict = {"conf": {}}
QUERY_PARAMS = {"limit": 30}
+HOSTNAME = "example.com"
class TestMwaaHook:
+ @pytest.fixture
+ def mock_conn(self):
+ with mock.patch.object(MwaaHook, "conn") as m:
+ yield m
+
def setup_method(self):
self.hook = MwaaHook()
- # these example responses are included here instead of as a constant
because the hook will mutate
- # responses causing subsequent tests to fail
- self.example_responses = {
+ def test_init(self):
+ assert self.hook.client_type == "mwaa"
+
+ @mock_aws
+ def test_get_conn(self):
+ assert self.hook.conn is not None
+
+ @pytest.mark.parametrize(
+ "body",
+ [
+ pytest.param(None, id="no_body"),
+ pytest.param(BODY, id="non_empty_body"),
+ ],
+ )
+ def test_invoke_rest_api_success(self, body, mock_conn, example_responses):
+ boto_invoke_mock =
mock.MagicMock(return_value=example_responses["success"])
+ mock_conn.invoke_rest_api = boto_invoke_mock
+
+ retval = self.hook.invoke_rest_api(
+ env_name=ENV_NAME, path=PATH, method=METHOD, body=body,
query_params=QUERY_PARAMS
+ )
+ kwargs_to_assert = {
+ "Name": ENV_NAME,
+ "Path": PATH,
+ "Method": METHOD,
+ "Body": body if body else {},
+ "QueryParameters": QUERY_PARAMS,
+ }
+ boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
+ mock_conn.create_web_login_token.assert_not_called()
+ assert retval == {k: v for k, v in
example_responses["success"].items() if k != "ResponseMetadata"}
+
+ def test_invoke_rest_api_failure(self, mock_conn, example_responses):
+ error = ClientError(error_response=example_responses["failure"],
operation_name="invoke_rest_api")
+ mock_conn.invoke_rest_api = mock.MagicMock(side_effect=error)
+ mock_error_log = mock.MagicMock()
+ self.hook.log.error = mock_error_log
+
+ with pytest.raises(ClientError) as caught_error:
+ self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH,
method=METHOD)
+
+ assert caught_error.value == error
+ mock_conn.create_web_login_token.assert_not_called()
+ expected_log = {k: v for k, v in example_responses["failure"].items()
if k != "ResponseMetadata"}
+ mock_error_log.assert_called_once_with(expected_log)
+
+ @pytest.mark.parametrize("generate_local_token", [pytest.param(True),
pytest.param(False)])
+ @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
+ def test_invoke_rest_api_local_token_parameter(
+ self, mock_create_session, generate_local_token, mock_conn
+ ):
+ self.hook.invoke_rest_api(
+ env_name=ENV_NAME, path=PATH, method=METHOD,
generate_local_token=generate_local_token
+ )
+ if generate_local_token:
+ mock_conn.invoke_rest_api.assert_not_called()
+ mock_conn.create_web_login_token.assert_called_once()
+ mock_create_session.assert_called_once()
+ mock_create_session.return_value.request.assert_called_once()
+ else:
+ mock_conn.invoke_rest_api.assert_called_once()
+
+ @mock.patch.object(MwaaHook, "_get_session_conn")
+ def test_invoke_rest_api_fallback_success_when_iam_fails(
+ self, mock_get_session_conn, mock_conn, example_responses
+ ):
+ boto_invoke_error = ClientError(
+ error_response=example_responses["missingIamRole"],
operation_name="invoke_rest_api"
+ )
+ mock_conn.invoke_rest_api =
mock.MagicMock(side_effect=boto_invoke_error)
+
+ kwargs_to_assert = {
+ "method": METHOD,
+ "url": f"https://{HOSTNAME}/api/v1{PATH}",
+ "params": QUERY_PARAMS,
+ "json": BODY,
+ "timeout": 10,
+ }
+
+ mock_response = mock.MagicMock()
+ mock_response.status_code =
example_responses["success"]["RestApiStatusCode"]
+ mock_response.json.return_value =
example_responses["success"]["RestApiResponse"]
+ mock_session = mock.MagicMock()
+ mock_session.request.return_value = mock_response
+
+ mock_get_session_conn.return_value = (mock_session, HOSTNAME)
+
+ retval = self.hook.invoke_rest_api(
+ env_name=ENV_NAME, path=PATH, method=METHOD, body=BODY,
query_params=QUERY_PARAMS
+ )
+
+ mock_session.request.assert_called_once_with(**kwargs_to_assert)
+ mock_response.raise_for_status.assert_called_once()
+ assert retval == {k: v for k, v in
example_responses["success"].items() if k != "ResponseMetadata"}
+
+ @mock.patch.object(MwaaHook, "_get_session_conn")
+ def test_invoke_rest_api_using_local_session_token_failure(
+ self, mock_get_session_conn, example_responses
+ ):
+ mock_response = mock.MagicMock()
+ mock_response.json.return_value =
example_responses["failure"]["RestApiResponse"]
+ error = requests.HTTPError(response=mock_response)
+ mock_response.raise_for_status.side_effect = error
+
+ mock_session = mock.MagicMock()
+ mock_session.request.return_value = mock_response
+
+ mock_get_session_conn.return_value = (mock_session, HOSTNAME)
+
+ mock_error_log = mock.MagicMock()
+ self.hook.log.error = mock_error_log
+
+ with pytest.raises(requests.HTTPError) as caught_error:
+ self.hook.invoke_rest_api(env_name=ENV_NAME, path=PATH,
method=METHOD, generate_local_token=True)
+
+ assert caught_error.value == error
+
mock_error_log.assert_called_once_with(example_responses["failure"]["RestApiResponse"])
+
+ @mock.patch("airflow.providers.amazon.aws.hooks.mwaa.requests.Session")
+ def test_get_session_conn(self, mock_create_session, mock_conn):
+ token = "token"
+ mock_conn.create_web_login_token.return_value = {"WebServerHostname":
HOSTNAME, "WebToken": token}
+ login_url = f"https://{HOSTNAME}/aws_mwaa/login"
+ login_payload = {"token": token}
+
+ mock_session = mock.MagicMock()
+ mock_create_session.return_value = mock_session
+
+ retval = self.hook._get_session_conn(env_name=ENV_NAME)
+
+ mock_conn.create_web_login_token.assert_called_once_with(Name=ENV_NAME)
+ mock_create_session.assert_called_once_with()
+ mock_session.post.assert_called_once_with(login_url,
data=login_payload, timeout=10)
+ mock_session.post.return_value.raise_for_status.assert_called_once()
+
+ assert retval == (mock_session, HOSTNAME)
+
+ @pytest.fixture
+ def example_responses(self):
+ """Fixture for test responses to avoid mutation between tests."""
+ return {
"success": {
"ResponseMetadata": {
"RequestId": "some ID",
@@ -73,57 +219,13 @@ class TestMwaaHook:
"type":
"https://airflow.apache.org/docs/apache-airflow/2.10.3/stable-rest-api-ref.html#section/Errors/NotFound",
},
},
+ "missingIamRole": {
+ "Error": {"Message": "No Airflow role granted in IAM.",
"Code": "AccessDeniedException"},
+ "ResponseMetadata": {
+ "RequestId": "some ID",
+ "HTTPStatusCode": 403,
+ "HTTPHeaders": {"header1": "value1"},
+ "RetryAttempts": 0,
+ },
+ },
}
-
- def test_init(self):
- assert self.hook.client_type == "mwaa"
-
- @mock_aws
- def test_get_conn(self):
- assert self.hook.conn is not None
-
- @pytest.mark.parametrize(
- "body",
- [
- pytest.param(None, id="no_body"),
- pytest.param({"conf": {}}, id="non_empty_body"),
- ],
- )
- @mock.patch.object(MwaaHook, "conn")
- def test_invoke_rest_api_success(self, mock_conn, body) -> None:
- boto_invoke_mock =
mock.MagicMock(return_value=self.example_responses["success"])
- mock_conn.invoke_rest_api = boto_invoke_mock
-
- retval = self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD, body,
QUERY_PARAMS)
- kwargs_to_assert = {
- "Name": ENV_NAME,
- "Path": PATH,
- "Method": METHOD,
- "Body": body if body else {},
- "QueryParameters": QUERY_PARAMS,
- }
- boto_invoke_mock.assert_called_once_with(**kwargs_to_assert)
- assert retval == {
- k: v for k, v in self.example_responses["success"].items() if k !=
"ResponseMetadata"
- }
-
- @mock.patch.object(MwaaHook, "conn")
- def test_invoke_rest_api_failure(self, mock_conn) -> None:
- error = ClientError(
- error_response=self.example_responses["failure"],
operation_name="invoke_rest_api"
- )
- boto_invoke_mock = mock.MagicMock(side_effect=error)
- mock_conn.invoke_rest_api = boto_invoke_mock
- mock_log = mock.MagicMock()
- self.hook.log.error = mock_log
-
- with pytest.raises(ClientError) as caught_error:
- self.hook.invoke_rest_api(ENV_NAME, PATH, METHOD)
-
- assert caught_error.value == error
- expected_log = {
- k: v
- for k, v in self.example_responses["failure"].items()
- if k != "ResponseMetadata" and k != "Error"
- }
- mock_log.assert_called_once_with(expected_log)