This is an automated email from the ASF dual-hosted git repository. ferruzzi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 41ebf28103 ECS Executor - Add backoff on failed task retry (#37109) 41ebf28103 is described below commit 41ebf28103007d4894d86783dbcdc3afc16ec2f6 Author: D. Ferruzzi <ferru...@amazon.com> AuthorDate: Mon Feb 5 10:45:50 2024 -0800 ECS Executor - Add backoff on failed task retry (#37109) * ECS Executor - Add backoff on failed task retry --- .../amazon/aws/executors/ecs/ecs_executor.py | 23 +++++++++++++++--- .../providers/amazon/aws/executors/ecs/utils.py | 2 ++ .../executors/utils/exponential_backoff_retry.py | 27 ++++++++++++++++++---- .../amazon/aws/executors/ecs/test_ecs_executor.py | 17 ++++++++++---- .../utils/test_exponential_backoff_retry.py | 20 +++++++++++++++- 5 files changed, 77 insertions(+), 12 deletions(-) diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 4c805d9b53..2f0564ed9a 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -42,7 +42,10 @@ from airflow.providers.amazon.aws.executors.ecs.utils import ( EcsQueuedTask, EcsTaskCollection, ) -from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry +from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import ( + calculate_next_attempt_delay, + exponential_backoff_retry, +) from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils import timezone from airflow.utils.state import State @@ -300,7 +303,14 @@ class AwsEcsExecutor(BaseExecutor): ) self.active_workers.increment_failure_count(task_key) self.pending_tasks.appendleft( - EcsQueuedTask(task_key, task_cmd, queue, exec_info, failure_count + 1) + EcsQueuedTask( + task_key, + task_cmd, + queue, + exec_info, + failure_count + 1, + timezone.utcnow() + calculate_next_attempt_delay(failure_count), + ) ) else: 
self.log.error( @@ -331,6 +341,8 @@ class AwsEcsExecutor(BaseExecutor): exec_config = ecs_task.executor_config attempt_number = ecs_task.attempt_number _failure_reasons = [] + if timezone.utcnow() < ecs_task.next_attempt_time: + continue try: run_task_response = self._run_task(task_key, cmd, queue, exec_config) except NoCredentialsError: @@ -361,6 +373,9 @@ class AwsEcsExecutor(BaseExecutor): # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS): ecs_task.attempt_number += 1 + ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( + attempt_number + ) self.pending_tasks.appendleft(ecs_task) else: self.log.error( @@ -422,7 +437,9 @@ class AwsEcsExecutor(BaseExecutor): """Save the task to be executed in the next sync by inserting the commands into a queue.""" if executor_config and ("name" in executor_config or "command" in executor_config): raise ValueError('Executor Config should never override "name" or "command"') - self.pending_tasks.append(EcsQueuedTask(key, command, queue, executor_config or {}, 1)) + self.pending_tasks.append( + EcsQueuedTask(key, command, queue, executor_config or {}, 1, timezone.utcnow()) + ) def end(self, heartbeat_interval=10): """Waits for all currently running tasks to end, and doesn't launch any tasks.""" diff --git a/airflow/providers/amazon/aws/executors/ecs/utils.py b/airflow/providers/amazon/aws/executors/ecs/utils.py index 7913bdf227..139ef35d71 100644 --- a/airflow/providers/amazon/aws/executors/ecs/utils.py +++ b/airflow/providers/amazon/aws/executors/ecs/utils.py @@ -23,6 +23,7 @@ Data classes and utility functions used by the ECS executor. 
from __future__ import annotations +import datetime from collections import defaultdict from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Callable, Dict, List @@ -58,6 +59,7 @@ class EcsQueuedTask: queue: str executor_config: ExecutorConfigType attempt_number: int + next_attempt_time: datetime.datetime @dataclass diff --git a/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py b/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py index 8a69b6f3b6..fa53011f04 100644 --- a/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py +++ b/airflow/providers/amazon/aws/executors/utils/exponential_backoff_retry.py @@ -25,6 +25,21 @@ from airflow.utils import timezone log = logging.getLogger(__name__) +def calculate_next_attempt_delay( + attempt_number: int, + max_delay: int = 60 * 2, + exponent_base: int = 4, +) -> timedelta: + """ + Calculate the exponential backoff (in seconds) until the next attempt. + + :param attempt_number: Number of attempts since last success. + :param max_delay: Maximum delay in seconds between retries. Default 120. + :param exponent_base: Exponent base to calculate delay. Default 4. + """ + return timedelta(seconds=min((exponent_base**attempt_number), max_delay)) + + def exponential_backoff_retry( last_attempt_time: datetime, attempts_since_last_successful: int, @@ -34,7 +49,7 @@ def exponential_backoff_retry( exponent_base: int = 4, ) -> None: """ - Retries a callable function with exponential backoff between attempts if it raises an exception. + Retry a callable function with exponential backoff between attempts if it raises an exception. :param last_attempt_time: Timestamp of last attempt call. :param attempts_since_last_successful: Number of attempts since last success. @@ -47,8 +62,10 @@ def exponential_backoff_retry( log.error("Max attempts reached. 
Exiting.") return - delay = min((exponent_base**attempts_since_last_successful), max_delay) - next_retry_time = last_attempt_time + timedelta(seconds=delay) + next_retry_time = last_attempt_time + calculate_next_attempt_delay( + attempt_number=attempts_since_last_successful, max_delay=max_delay, exponent_base=exponent_base + ) + current_time = timezone.utcnow() if current_time >= next_retry_time: @@ -56,5 +73,7 @@ def exponential_backoff_retry( callable_function() except Exception: log.exception("Error calling %r", callable_function.__name__) - next_delay = min((exponent_base ** (attempts_since_last_successful + 1)), max_delay) + next_delay = calculate_next_attempt_delay( + attempts_since_last_successful + 1, max_delay, exponent_base + ) log.info("Waiting for %s seconds before retrying.", next_delay) diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index 78c3c1bc28..04e7774555 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -25,6 +25,7 @@ import time from functools import partial from typing import Callable from unittest import mock +from unittest.mock import MagicMock import pytest import yaml @@ -33,7 +34,7 @@ from inflection import camelize from airflow.exceptions import AirflowException from airflow.executors.base_executor import BaseExecutor -from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config +from airflow.providers.amazon.aws.executors.ecs import ecs_executor, ecs_executor_config from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoTaskSchema from airflow.providers.amazon.aws.executors.ecs.ecs_executor import ( CONFIG_GROUP_NAME, @@ -50,6 +51,7 @@ from airflow.providers.amazon.aws.executors.ecs.utils import ( from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils.helpers import convert_camel_to_snake from 
airflow.utils.state import State, TaskInstanceState +from airflow.utils.timezone import utcnow pytestmark = pytest.mark.db_test @@ -365,7 +367,8 @@ class TestAwsEcsExecutor: assert 1 == len(mock_executor.active_workers) assert ARN1 in mock_executor.active_workers.task_by_key(airflow_key).task_arn - def test_success_execute_api_exception(self, mock_executor): + @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_success_execute_api_exception(self, mock_backoff, mock_executor): """Test what happens when ECS throws an exception, but ultimately runs the task.""" run_task_exception = Exception("Test exception") run_task_success = { @@ -381,9 +384,10 @@ class TestAwsEcsExecutor: } mock_executor.ecs.run_task.side_effect = [run_task_exception, run_task_exception, run_task_success] mock_executor.execute_async(mock_airflow_key, mock_cmd) + expected_retry_count = 2 # Fail 2 times - for _ in range(2): + for _ in range(expected_retry_count): mock_executor.attempt_task_runs() # Task is not stored in active workers. 
assert len(mock_executor.active_workers) == 0 @@ -392,6 +396,9 @@ class TestAwsEcsExecutor: mock_executor.attempt_task_runs() assert len(mock_executor.pending_tasks) == 0 assert ARN1 in mock_executor.active_workers.get_all_arns() + assert mock_backoff.call_count == expected_retry_count + for attempt_number in range(1, expected_retry_count): + mock_backoff.assert_has_calls([mock.call(attempt_number)]) def test_failed_execute_api_exception(self, mock_executor): """Test what happens when ECS refuses to execute a task and throws an exception""" @@ -479,7 +486,8 @@ class TestAwsEcsExecutor: @mock.patch.object(BaseExecutor, "fail") @mock.patch.object(BaseExecutor, "success") - def test_failed_sync_cumulative_fail(self, success_mock, fail_mock, mock_airflow_key, mock_executor): + @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_failed_sync_cumulative_fail(self, _, success_mock, fail_mock, mock_airflow_key, mock_executor): """Test that failure_count/attempt_number is cumulative for pending tasks and active workers.""" AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5" mock_executor.ecs.run_task.return_value = { @@ -488,6 +496,7 @@ class TestAwsEcsExecutor: {"arn": ARN1, "reason": "Sample Failure", "detail": "UnitTest Failure - Please ignore"} ], } + mock_executor._calculate_next_attempt_time = MagicMock(return_value=utcnow()) task_key = mock_airflow_key() mock_executor.execute_async(task_key, mock_cmd) for _ in range(2): diff --git a/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py b/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py index c7091139e1..c180184c30 100644 --- a/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py +++ b/tests/providers/amazon/aws/executors/utils/test_exponential_backoff_retry.py @@ -21,7 +21,10 @@ from unittest import mock import pytest -from 
airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import exponential_backoff_retry +from airflow.providers.amazon.aws.executors.utils.exponential_backoff_retry import ( + calculate_next_attempt_delay, + exponential_backoff_retry, +) class TestExponentialBackoffRetry: @@ -279,3 +282,18 @@ class TestExponentialBackoffRetry: exponent_base=3, ) assert mock_callable_function.call_count == expected_calls + + def test_calculate_next_attempt_delay(self): + exponent_base: int = 4 + num_loops: int = 3 + # Setting max_delay this way means three loops will run to test: + # one will return a value under max_delay, one equal to max_delay, and one over. + max_delay: int = exponent_base**num_loops - 1 + + for attempt_number in range(1, num_loops): + returned_delay = calculate_next_attempt_delay(attempt_number, max_delay, exponent_base).seconds + + if (expected_delay := exponent_base**attempt_number) <= max_delay: + assert returned_delay == expected_delay + else: + assert returned_delay == max_delay