This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 06e3197ccb Fix bug for ECS Executor where tasks were being skipped if one task failed. (#37979) 06e3197ccb is described below commit 06e3197ccb206590b38e475b880092cf5283176c Author: Syed Hussain <103602455+syeda...@users.noreply.github.com> AuthorDate: Sat Mar 9 05:42:25 2024 -0800 Fix bug for ECS Executor where tasks were being skipped if one task failed. (#37979) Add unit tests to catch case where tasks fail --- .../amazon/aws/executors/ecs/ecs_executor.py | 10 +- .../amazon/aws/executors/ecs/test_ecs_executor.py | 168 +++++++++++++++++++++ 2 files changed, 173 insertions(+), 5 deletions(-) diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py index 6a72459144..6167ddcbf4 100644 --- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py +++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py @@ -306,7 +306,7 @@ class AwsEcsExecutor(BaseExecutor): task_arn, ) self.active_workers.increment_failure_count(task_key) - self.pending_tasks.appendleft( + self.pending_tasks.append( EcsQueuedTask( task_key, task_cmd, @@ -350,12 +350,12 @@ class AwsEcsExecutor(BaseExecutor): try: run_task_response = self._run_task(task_key, cmd, queue, exec_config) except NoCredentialsError: - self.pending_tasks.appendleft(ecs_task) + self.pending_tasks.append(ecs_task) raise except ClientError as e: error_code = e.response["Error"]["Code"] if error_code in INVALID_CREDENTIALS_EXCEPTIONS: - self.pending_tasks.appendleft(ecs_task) + self.pending_tasks.append(ecs_task) raise _failure_reasons.append(str(e)) except Exception as e: @@ -375,12 +375,12 @@ class AwsEcsExecutor(BaseExecutor): for reason in _failure_reasons: failure_reasons[reason] += 1 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS - if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS): + if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS): ecs_task.attempt_number += 1 ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( attempt_number ) - self.pending_tasks.appendleft(ecs_task) + self.pending_tasks.append(ecs_task) else: self.log.error( "ECS task %s has failed a maximum of %s times. Marking as failed", diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py index dcfe1babc5..35ce4f0a9b 100644 --- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py +++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py @@ -60,6 +60,29 @@ pytestmark = pytest.mark.db_test ARN1 = "arn1" ARN2 = "arn2" ARN3 = "arn3" +RUN_TASK_KWARGS = { + "cluster": "some-cluster", + "launchType": "FARGATE", + "taskDefinition": "some-task-def", + "platformVersion": "LATEST", + "count": 1, + "overrides": { + "containerOverrides": [ + { + "name": "container-name", + "command": [""], + "environment": [{"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}], + } + ] + }, + "networkConfiguration": { + "awsvpcConfiguration": { + "subnets": ["sub1", "sub2"], + "securityGroups": ["sg1", "sg2"], + "assignPublicIp": "DISABLED", + } + }, +} def mock_task(arn=ARN1, state=State.RUNNING): @@ -430,6 +453,151 @@ class TestAwsEcsExecutor: # Task is not stored in active workers. assert len(mock_executor.active_workers) == 0 + @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, caplog): + """ + Test case when all tasks fail to run. + + The executor should attempt each task exactly once per sync() iteration. + It should preserve the order of tasks, and attempt each task up to + `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. + """ + airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)] + airflow_cmd1 = mock.Mock(spec=list) + airflow_cmd2 = mock.Mock(spec=list) + caplog.set_level("ERROR") + commands = [airflow_cmd1, airflow_cmd2] + + failures = [Exception("Failure 1"), Exception("Failure 2")] + + mock_executor.execute_async(airflow_keys[0], commands[0]) + mock_executor.execute_async(airflow_keys[1], commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.active_workers.get_all_arns()) == 0 + + mock_executor.ecs.run_task.side_effect = failures + mock_executor.attempt_task_runs() + + for i in range(2): + RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i] + assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS + assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0] + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.active_workers.get_all_arns()) == 0 + + caplog.clear() + mock_executor.ecs.run_task.call_args_list.clear() + + mock_executor.ecs.run_task.side_effect = failures + mock_executor.attempt_task_runs() + + for i in range(2): + RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i] + assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS + assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0] + assert len(mock_executor.pending_tasks) == 2 + assert len(mock_executor.active_workers.get_all_arns()) == 0 + + caplog.clear() + mock_executor.ecs.run_task.call_args_list.clear() + + mock_executor.ecs.run_task.side_effect = failures + mock_executor.attempt_task_runs() + + assert len(mock_executor.active_workers.get_all_arns()) == 0 + assert len(mock_executor.pending_tasks) == 0 + + assert len(caplog.messages) == 3 + for i in range(2): + assert ( + f"ECS task {airflow_keys[i]} has failed a maximum of 3 times. Marking as failed" + == caplog.messages[i] + ) + + @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0)) + def test_attempt_task_runs_attempts_when_some_tasks_fal(self, _, mock_executor, caplog): + """ + Test case when one task fail to run, and a new task gets queued. + + The executor should attempt each task exactly once per sync() iteration. + It should preserve the order of tasks, and attempt each task up to + `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task + should be removed from pending_jobs and into active_workers. + """ + airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)] + airflow_cmd1 = mock.Mock(spec=list) + airflow_cmd2 = mock.Mock(spec=list) + caplog.set_level("ERROR") + airflow_commands = [airflow_cmd1, airflow_cmd2] + task = { + "taskArn": ARN1, + "lastStatus": "", + "desiredStatus": "", + "containers": [{"name": "some-ecs-container"}], + } + success_response = {"tasks": [task], "failures": []} + + responses = [Exception("Failure 1"), success_response] + + mock_executor.execute_async(airflow_keys[0], airflow_commands[0]) + mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + + mock_executor.ecs.run_task.side_effect = responses + mock_executor.attempt_task_runs() + + for i in range(2): + RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i] + assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS + + assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.active_workers.get_all_arns()) == 1 + + caplog.clear() + mock_executor.ecs.run_task.call_args_list.clear() + + # queue new task + airflow_keys[1] = mock.Mock(spec=tuple) + airflow_commands[1] = mock.Mock(spec=list) + mock_executor.execute_async(airflow_keys[1], airflow_commands[1]) + + assert len(mock_executor.pending_tasks) == 2 + # assert that the order of pending tasks is preserved i.e. the first task is 1st etc. + assert mock_executor.pending_tasks[0].key == airflow_keys[0] + assert mock_executor.pending_tasks[0].command == airflow_commands[0] + + task["taskArn"] = ARN2 + success_response = {"tasks": [task], "failures": []} + responses = [Exception("Failure 1"), success_response] + mock_executor.ecs.run_task.side_effect = responses + mock_executor.attempt_task_runs() + + for i in range(2): + RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i] + assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS + + assert len(mock_executor.pending_tasks) == 1 + assert len(mock_executor.active_workers.get_all_arns()) == 2 + + caplog.clear() + mock_executor.ecs.run_task.call_args_list.clear() + + responses = [Exception("Failure 1")] + mock_executor.ecs.run_task.side_effect = responses + mock_executor.attempt_task_runs() + + RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0] + assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS + + assert len(caplog.messages) == 2 + + assert ( + f"ECS task {airflow_keys[0]} has failed a maximum of 3 times. Marking as failed" + == caplog.messages[0] + ) + @mock.patch.object(BaseExecutor, "fail") @mock.patch.object(BaseExecutor, "success") def test_sync(self, success_mock, fail_mock, mock_executor):