This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 41ebf28103 ECS Executor - Add backoff on failed task retry (#37109)
41ebf28103 is described below

commit 41ebf28103007d4894d86783dbcdc3afc16ec2f6
Author: D. Ferruzzi <ferru...@amazon.com>
AuthorDate: Mon Feb 5 10:45:50 2024 -0800

    ECS Executor - Add backoff on failed task retry (#37109)
    
    * ECS Executor - Add backoff on failed task retry
---
 .../amazon/aws/executors/ecs/ecs_executor.py       | 23 +++++++++++++++---
 .../providers/amazon/aws/executors/ecs/utils.py    |  2 ++
 .../executors/utils/exponential_backoff_retry.py   | 27 ++++++++++++++++++----
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 17 ++++++++++----
 .../utils/test_exponential_backoff_retry.py        | 20 +++++++++++++++-
 5 files changed, 77 insertions(+), 12 deletions(-)
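
For reference, the backoff schedule added here is min(exponent_base**attempt_number, max_delay) seconds. A minimal sketch mirroring the calculate_next_attempt_delay() helper introduced below, with its defaults (exponent_base=4, max_delay=120):

    from datetime import timedelta

    def backoff_delay(attempt_number: int, max_delay: int = 120, exponent_base: int = 4) -> timedelta:
        # Delay grows exponentially with the attempt number, capped at max_delay.
        return timedelta(seconds=min(exponent_base**attempt_number, max_delay))

    for attempt in range(1, 5):
        print(attempt, backoff_delay(attempt))
    # 1 0:00:04
    # 2 0:00:16
    # 3 0:01:04
    # 4 0:02:00  (capped at max_delay)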

diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 4c805d9b53..2f0564ed9a 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -42,7 +42,10 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
     EcsQueuedTask,
     EcsTaskCollection,
 )
-from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils import timezone
 from airflow.utils.state import State
@@ -300,7 +303,14 @@ class AwsEcsExecutor(BaseExecutor):
             )
             self.active_workers.increment_failure_count(task_key)
             self.pending_tasks.appendleft(
-                EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1)
+                EcsQueuedTask(
+                    task_key,
+                    task_cmd,
+                    queue,
+                    exec_info,
+                    failure_count + 1,
+                    timezone.utcnow() + calculate_next_attempt_delay(failure_count),
+                )
             )
         else:
             self.log.error(
@@ -331,6 +341,8 @@ class AwsEcsExecutor(BaseExecutor):
             exec_config = ecs_task.executor_config
             attempt_number = ecs_task.attempt_number
             _failure_reasons = []
+            if timezone.utcnow() < ecs_task.next_attempt_time:
+                continue
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
@@ -361,6 +373,9 @@ class AwsEcsExecutor(BaseExecutor):
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
                 if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
+                    ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
+                        attempt_number
+                    )
                     self.pending_tasks.appendleft(ecs_task)
                 else:
                     self.log.error(
@@ -422,7 +437,9 @@ class AwsEcsExecutor(BaseExecutor):
         """Save the task to be executed in the next sync by inserting the 
commands into a queue."""
         if executor_config and ("name" in executor_config or "command" in executor_config):
             raise ValueError('Executor Config should never override "name" or "command"')
-        self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1))
+        self.pending_tasks.append(
+            EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow())
+        )
 
     def end(self, heartbeat_interval=10):
         """Waits for all currently running tasks to end, and doesn't launch 
any tasks."""
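
The executor change above follows a defer-and-requeue pattern: a task that failed to launch goes back on the pending deque stamped with a next_attempt_time, and attempt_task_runs() skips any task whose time has not yet come. A self-contained sketch of that pattern, assuming a simplified task type (the real loop also tracks failure reasons and MAX_RUN_TASK_ATTEMPTS):

    import datetime
    from collections import deque
    from dataclasses import dataclass, field

    def utcnow() -> datetime.datetime:
        return datetime.datetime.now(datetime.timezone.utc)

    @dataclass
    class QueuedTask:  # hypothetical stand-in for EcsQueuedTask
        name: str
        attempt_number: int = 1
        next_attempt_time: datetime.datetime = field(default_factory=utcnow)

    pending: deque[QueuedTask] = deque()

    def sync_once() -> None:
        for _ in range(len(pending)):
            task = pending.popleft()
            if utcnow() < task.next_attempt_time:
                pending.append(task)  # not due yet; revisit on a later sync
                continue
            print("launching", task.name, "attempt", task.attempt_number)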
diff --git a/airflow/providers/amazon/aws/executors/ecs/utils.py b/airflow/providers/amazon/aws/executors/ecs/utils.py
index 7913bdf227..139ef35d71 100644
--- a/airflow/providers/amazon/aws/executors/ecs/utils.py
+++ b/airflow/providers/amazon/aws/executors/ecs/utils.py
@@ -23,6 +23,7 @@ Data classes and utility functions used by the ECS executor.
 
 from __future__ import annotations
 
+import datetime
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, Any, Callable, Dict, List
@@ -58,6 +59,7 @@ class EcsQueuedTask:
     queue: str
     executor_config: ExecutorConfigType
     attempt_number: int
+    next_attempt_time: datetime.datetime
 
 
 @dataclass
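
Because EcsQueuedTask is a plain dataclass, the new next_attempt_time field becomes a required positional argument, which is why every construction site in the executor diff above now passes a timestamp. A minimal mirror of the shape (hypothetical, trimmed to two fields):

    import datetime
    from dataclasses import dataclass

    @dataclass
    class QueuedTask:
        attempt_number: int
        next_attempt_time: datetime.datetime  # new: earliest time to retry

    QueuedTask(1, datetime.datetime.now(datetime.timezone.utc))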
diff --git a/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py b/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
index 8a69b6f3b6..fa53011f04 100644
--- a/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
+++ b/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py
@@ -25,6 +25,21 @@ from airflow.utils import timezone
 log = logging.getLogger(__name__)
 
 
+def calculate_next_attempt_delay(
+    attempt_number: int,
+    max_delay: int = 60 * 2,
+    exponent_base: int = 4,
+) -> timedelta:
+    """
+    Calculate the exponential backoff (in seconds) until the next attempt.
+
+    :param attempt_number: Number of attempts since last success.
+    :param max_delay: Maximum delay in seconds between retries. Default 120.
+    :param exponent_base: Exponent base to calculate delay. Default 4.
+    """
+    return timedelta(seconds=min((exponent_base**attempt_number), max_delay))
+
+
 def exponential_backoff_retry(
     last_attempt_time: datetime,
     attempts_since_last_successful: int,
@@ -34,7 +49,7 @@ def exponential_backoff_retry(
     exponent_base: int = 4,
 ) -> None:
     """
-    Retries a callable function with exponential backoff between attempts if it raises an exception.
+    Retry a callable function with exponential backoff between attempts if it raises an exception.
 
     :param last_attempt_time: Timestamp of last attempt call.
     :param attempts_since_last_successful: Number of attempts since last success.
@@ -47,8 +62,10 @@ def exponential_backoff_retry(
         log.error("Max attempts reached. Exiting.")
         return
 
-    delay = min((exponent_base**attempts_since_last_successful), max_delay)
-    next_retry_time = last_attempt_time + timedelta(seconds=delay)
+    next_retry_time = last_attempt_time + calculate_next_attempt_delay(
+        attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base
+    )
+
     current_time = timezone.utcnow()
 
     if current_time >= next_retry_time:
@@ -56,5 +73,7 @@ def exponential_backoff_retry(
             callable_function()
         except Exception:
             log.exception("Error calling %r", callable_function.__name__)
-            next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay)
+            next_delay = calculate_next_attempt_delay(
+                attempts_since_last_successful + 1, max_delay, exponent_base
+            )
             log.info("Waiting for %s seconds before retrying.", next_delay)
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index 78c3c1bc28..04e7774555 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,6 +25,7 @@ import time
 from functools import partial
 from typing import Callable
 from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 import yaml
@@ -33,7 +34,7 @@ from inflection import camelize
 
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
-from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
+from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config
 from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema
 from airflow.providers.amazon.aws.executors.ecs.ecs_executor import (
     CONFIG_GROUP_NAME,
@@ -50,6 +51,7 @@ from airflow.providers.amazon.aws.executors.ecs.utils import (
 from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils.helpers import convert_camel_to_snake
 from airflow.utils.state import State, TaskInstanceState
+from airflow.utils.timezone import utcnow
 
 pytestmark = pytest.mark.db_test
 
@@ -365,7 +367,8 @@ class TestAwsEcsExecutor:
         assert 1 == len(mock_executor.active_workers)
         assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn
 
-    def test_success_execute_api_exception(self, mock_executor):
+    @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
+    def test_success_execute_api_exception(self, mock_backoff, mock_executor):
         """Test what happens when ECS throws an exception, but ultimately runs 
the task."""
         run_task_exception = Exception("Test exception")
         run_task_success = {
@@ -381,9 +384,10 @@ class TestAwsEcsExecutor:
         }
         mock_executor.ecs.run_task.side_effect = [run_task_exception, run_task_exception, run_task_success]
         mock_executor.execute_async(mock_airflow_key, mock_cmd)
+        expected_retry_count = 2
 
         # Fail 2 times
-        for _ in range(2):
+        for _ in range(expected_retry_count):
             mock_executor.attempt_task_runs()
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
@@ -392,6 +396,9 @@ class TestAwsEcsExecutor:
         mock_executor.attempt_task_runs()
         assert len(mock_executor.pending_tasks) == 0
         assert ARN1 in mock_executor.active_workers.get_all_arns()
+        assert mock_backoff.call_count == expected_retry_count
+        for attempt_number in range(1, expected_retry_count):
+            mock_backoff.assert_has_calls([mock.call(attempt_number)])
 
     def test_failed_execute_api_exception(self, mock_executor):
         """Test what happens when ECS refuses to execute a task and throws an 
exception"""
@@ -479,7 +486,8 @@ class TestAwsEcsExecutor:
 
     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")
-    def test_failed_sync_cumulative_fail(self, success_mock, fail_mock, mock_airflow_key, mock_executor):
+    @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
+    def test_failed_sync_cumulative_fail(self, _, success_mock, fail_mock, mock_airflow_key, mock_executor):
         """Test that failure_count/attempt_number is cumulative for pending 
tasks and active workers."""
         AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5"
         mock_executor.ecs.run_task.return_value = {
@@ -488,6 +496,7 @@ class TestAwsEcsExecutor:
                 {"arn": ARN1, "reason": "Sample Failure", "detail": "UnitTest 
Failure - Please ignore"}
             ],
         }
+        mock_executor._calculate_next_attempt_time = MagicMock(return_value=utcnow())
         task_key = mock_airflow_key()
         mock_executor.execute_async(task_key, mock_cmd)
         for _ in range(2):
diff --git a/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py b/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py
index c7091139e1..c180184c30 100644
--- a/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py
+++ b/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py
@@ -21,7 +21,10 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry
+from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import (
+    calculate_next_attempt_delay,
+    exponential_backoff_retry,
+)
 
 
 class TestExponentialBackoffRetry:
@@ -279,3 +282,18 @@ class TestExponentialBackoffRetry:
             exponent_base=3,
         )
         assert mock_callable_function.call_count == expected_calls
+
+    def test_calculate_next_attempt_delay(self):
+        exponent_base: int = 4
+        num_loops: int = 3
+        # Setting max_delay this way means the loop exercises both branches: delays
+        # under max_delay are returned as-is, and delays over it are capped.
+        max_delay: int = exponent_base**num_loops - 1
+
+        for attempt_number in range(1, num_loops + 1):
+            returned_delay = calculate_next_attempt_delay(attempt_number, max_delay, exponent_base).seconds
+
+            if (expected_delay := exponent_base**attempt_number) <= max_delay:
+                assert returned_delay == expected_delay
+            else:
+                assert returned_delay == max_delay
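
Worked values for the test above, using its own constants (a sketch of the arithmetic, not additional test code):

    exponent_base = 4
    num_loops = 3
    max_delay = exponent_base**num_loops - 1  # 63

    # attempt 1 -> 4**1 = 4   (under max_delay, returned as-is)
    # attempt 2 -> 4**2 = 16  (under max_delay, returned as-is)
    # attempt 3 -> 4**3 = 64  (over max_delay, capped to 63)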
