This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 8199e76a718 Fix EMR Serverless task failure on transient AWS throttling errors (#67222)
8199e76a718 is described below

commit 8199e76a718c657907e887952401d2533d74f967
Author: Subham <[email protected]>
AuthorDate: Fri May 22 12:08:26 2026 +0530

    Fix EMR Serverless task failure on transient AWS throttling errors (#67222)
    
    * Fix EMR Serverless task failure on transient AWS throttling errors
    
    * fix: Don't count throttling retries against waiter_max_attempts
---
 .../amazon/aws/utils/waiter_with_logging.py        | 66 +++++++++++++++++++---
 .../amazon/aws/utils/test_waiter_with_logging.py   | 53 +++++++++++++++++
 2 files changed, 112 insertions(+), 7 deletions(-)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
index 615523a7047..60e20bbf2cd 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/waiter_with_logging.py
@@ -30,6 +30,28 @@ from airflow.providers.common.compat.sdk import AirflowException
 if TYPE_CHECKING:
     from botocore.waiter import Waiter
 
+# Standard throttling and transient error codes to retry on
+# https://docs.aws.amazon.com/general/latest/gr/api-retries.html
+# and https://github.com/boto/botocore/blob/develop/botocore/retryhandler.py
+RETRIABLE_ERROR_CODES = {
+    "ThrottlingException",
+    "Throttling",
+    "RequestLimitExceeded",
+    "ProvisionedThroughputExceededException",
+    "LimitExceededException",
+    "RequestThrottled",
+    "RequestThrottledException",
+    "TooManyRequestsException",
+    "ServerException",
+    "InternalServerError",
+    "InternalFailure",
+    "ServiceUnavailable",
+    "BadGateway",
+    "GatewayTimeout",
+    "RequestTimeout",
+    "RequestTimeoutException",
+}
+
 
 def wait(
     waiter: Waiter,
@@ -65,9 +87,12 @@ def wait(
         status_args = ["Clusters[0].state", "Clusters[0].details"]
     """
     log = logging.getLogger(__name__)
-    for attempt in range(waiter_max_attempts):
-        if attempt:
+    first_attempt = True
+    attempt = 0
+    while attempt < waiter_max_attempts:
+        if not first_attempt:
             time.sleep(waiter_delay)
+        first_attempt = False
         try:
             waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
 
@@ -87,11 +112,23 @@ def wait(
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}: {error}")
+                error_code = last_response["Error"]["Code"]
+                if error_code not in RETRIABLE_ERROR_CODES:
+                    raise AirflowException(f"{failure_message}: {error}")
+
+                log.info(
+                    "Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...",
+                    error_code,
+                    attempt + 1,
+                    waiter_max_attempts,
+                )
+                # Don't increment attempt counter for retriable errors; continue looping
+                continue
 
             log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
         else:
             break
+        attempt += 1
     else:
         raise AirflowException("Waiter error: max attempts reached")
 
@@ -104,7 +141,7 @@ async def async_wait(
     failure_message: str,
     status_message: str,
     status_args: list[str],
-):
+) -> None:
     """
     Use an async boto waiter to poll an AWS service for the specified state.
 
@@ -130,9 +167,12 @@ async def async_wait(
         status_args = ["Clusters[0].state", "Clusters[0].details"]
     """
     log = logging.getLogger(__name__)
-    for attempt in range(waiter_max_attempts):
-        if attempt:
+    first_attempt = True
+    attempt = 0
+    while attempt < waiter_max_attempts:
+        if not first_attempt:
             await asyncio.sleep(waiter_delay)
+        first_attempt = False
         try:
             await waiter.wait(**args, WaiterConfig={"MaxAttempts": 1})
 
@@ -153,11 +193,23 @@ async def async_wait(
                 and isinstance(last_response.get("Error"), dict)
                 and "Code" in last_response.get("Error")
             ):
-                raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
+                error_code = last_response["Error"]["Code"]
+                if error_code not in RETRIABLE_ERROR_CODES:
+                    raise AirflowException(f"{failure_message}\n{last_response}\n{error}")
+
+                log.info(
+                    "Waiter encountered retriable error: %s. Retrying (attempt %d/%d)...",
+                    error_code,
+                    attempt + 1,
+                    waiter_max_attempts,
+                )
+                # Don't increment attempt counter for retriable errors; continue looping
+                continue
 
             log.info("%s: %s", status_message, _LazyStatusFormatter(status_args, last_response))
         else:
             break
+        attempt += 1
     else:
         raise AirflowException("Waiter error: max attempts reached")
 
diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py b/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py
index 716af7fbdd1..3414403578e 100644
--- a/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py
+++ b/providers/amazon/tests/unit/amazon/aws/utils/test_waiter_with_logging.py
@@ -367,3 +367,56 @@ class TestWaiter:
         finally:
             logger.setLevel(level)
         status_format_mock.assert_not_called()
+
+    @mock.patch("time.sleep")
+    def test_wait_with_retriable_throttling_error(self, mock_sleep):
+        mock_sleep.return_value = True
+        mock_waiter = mock.MagicMock()
+        throttling_error = WaiterError(
+            name="test_waiter",
+            reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded",
+            last_response={
+                "Error": {
+                    "Message": "Rate exceeded",
+                    "Code": "ThrottlingException",
+                }
+            },
+        )
+        mock_waiter.wait.side_effect = [throttling_error, throttling_error, True]
+        wait(
+            waiter=mock_waiter,
+            waiter_delay=123,
+            waiter_max_attempts=10,
+            args={"test_arg": "test_value"},
+            failure_message="test failure message",
+            status_message="test status message",
+            status_args=["Status.State"],
+        )
+        assert mock_waiter.wait.call_count == 3
+        mock_sleep.assert_called_with(123)
+
+    @pytest.mark.asyncio
+    async def test_async_wait_with_retriable_throttling_error(self):
+        mock_waiter = mock.MagicMock()
+        throttling_error = WaiterError(
+            name="test_waiter",
+            reason="An error occurred (ThrottlingException) when calling the GetJobRun operation: Rate exceeded",
+            last_response={
+                "Error": {
+                    "Message": "Rate exceeded",
+                    "Code": "ThrottlingException",
+                }
+            },
+        )
+        mock_waiter.wait = AsyncMock()
+        mock_waiter.wait.side_effect = [throttling_error, throttling_error, True]
+        await async_wait(
+            waiter=mock_waiter,
+            waiter_delay=0,
+            waiter_max_attempts=10,
+            args={"test_arg": "test_value"},
+            failure_message="test failure message",
+            status_message="test status message",
+            status_args=["Status.State"],
+        )
+        assert mock_waiter.wait.call_count == 3

Reply via email to