This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 06e3197ccb Fix bug for ECS Executor where tasks were being skipped if one task failed. (#37979)
06e3197ccb is described below

commit 06e3197ccb206590b38e475b880092cf5283176c
Author: Syed Hussain <103602455+syeda...@users.noreply.github.com>
AuthorDate: Sat Mar 9 05:42:25 2024 -0800

    Fix bug for ECS Executor where tasks were being skipped if one task failed. (#37979)
    
    Add unit tests to catch case where tasks fail
---
 .../amazon/aws/executors/ecs/ecs_executor.py       |  10 +-
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 168 +++++++++++++++++++++
 2 files changed, 173 insertions(+), 5 deletions(-)
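
The core of the fix below is switching the failed-task re-queue from appendleft() to append(). As a rough illustration only (this is a standalone sketch, not the executor's actual code; the task names and the run() stub are invented), the following shows why appendleft() can starve the other pending tasks within a single pass over the queue:

    from collections import deque

    # Hypothetical pending queue; the real executor queues EcsQueuedTask objects.
    pending = deque(["task_a", "task_b", "task_c"])

    def run(task):
        # Pretend the first task always fails to launch.
        return task != "task_a"

    attempted = []
    for _ in range(len(pending)):      # one pass over the queue, like one sync() iteration
        task = pending.popleft()
        attempted.append(task)
        if not run(task):
            # Buggy behaviour: appendleft() puts the failed task straight back at the
            # front, so the next popleft() retries it and task_b / task_c are skipped.
            pending.appendleft(task)
            # Fixed behaviour: pending.append(task) sends the retry to the back,
            # preserving the order of the remaining tasks.

    print(attempted)  # ['task_a', 'task_a', 'task_a'] with appendleft; ['task_a', 'task_b', 'task_c'] with append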

diff --git a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index 6a72459144..6167ddcbf4 100644
--- a/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -306,7 +306,7 @@ class AwsEcsExecutor(BaseExecutor):
                 task_arn,
             )
             self.active_workers.increment_failure_count(task_key)
-            self.pending_tasks.appendleft(
+            self.pending_tasks.append(
                 EcsQueuedTask(
                     task_key,
                     task_cmd,
@@ -350,12 +350,12 @@ class AwsEcsExecutor(BaseExecutor):
             try:
                 run_task_response = self._run_task(task_key, cmd, queue, exec_config)
             except NoCredentialsError:
-                self.pending_tasks.appendleft(ecs_task)
+                self.pending_tasks.append(ecs_task)
                 raise
             except ClientError as e:
                 error_code = e.response["Error"]["Code"]
                 if error_code in INVALID_CREDENTIALS_EXCEPTIONS:
-                    self.pending_tasks.appendleft(ecs_task)
+                    self.pending_tasks.append(ecs_task)
                     raise
                 _failure_reasons.append(str(e))
             except Exception as e:
@@ -375,12 +375,12 @@ class AwsEcsExecutor(BaseExecutor):
                 for reason in _failure_reasons:
                     failure_reasons[reason] += 1
                 # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
-                if int(attempt_number) <= int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+                if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
                     ecs_task.attempt_number += 1
                     ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                         attempt_number
                     )
-                    self.pending_tasks.appendleft(ecs_task)
+                    self.pending_tasks.append(ecs_task)
                 else:
                     self.log.error(
                         "ECS task %s has failed a maximum of %s times. Marking 
as failed",
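
The same hunk also tightens the retry bound from <= to <. A small hedged sketch (assuming the attempt counter starts at 1 and MAX_RUN_TASK_ATTEMPTS is 3, as the tests below imply) of why the old check allowed one launch attempt more than the configured maximum:

    MAX_RUN_TASK_ATTEMPTS = 3  # assumed value, matching the tests below

    def total_attempts(use_old_check):
        attempt_number, attempts = 1, 0
        while True:
            attempts += 1  # one run_task call
            retry = (
                attempt_number <= MAX_RUN_TASK_ATTEMPTS  # old check
                if use_old_check
                else attempt_number < MAX_RUN_TASK_ATTEMPTS  # new check
            )
            if not retry:
                break  # task is marked as failed
            attempt_number += 1
        return attempts

    print(total_attempts(True))   # 4: one attempt too many with <=
    print(total_attempts(False))  # 3: exactly MAX_RUN_TASK_ATTEMPTS with <
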
diff --git a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
index dcfe1babc5..35ce4f0a9b 100644
--- a/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/tests/providers/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -60,6 +60,29 @@ pytestmark = pytest.mark.db_test
 ARN1 = "arn1"
 ARN2 = "arn2"
 ARN3 = "arn3"
+RUN_TASK_KWARGS = {
+    "cluster": "some-cluster",
+    "launchType": "FARGATE",
+    "taskDefinition": "some-task-def",
+    "platformVersion": "LATEST",
+    "count": 1,
+    "overrides": {
+        "containerOverrides": [
+            {
+                "name": "container-name",
+                "command": [""],
+                "environment": [{"name": "AIRFLOW_IS_EXECUTOR_CONTAINER", "value": "true"}],
+            }
+        ]
+    },
+    "networkConfiguration": {
+        "awsvpcConfiguration": {
+            "subnets": ["sub1", "sub2"],
+            "securityGroups": ["sg1", "sg2"],
+            "assignPublicIp": "DISABLED",
+        }
+    },
+}
 
 
 def mock_task(arn=ARN1, state=State.RUNNING):
@@ -430,6 +453,151 @@ class TestAwsEcsExecutor:
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
 
+    @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
+    def test_attempt_task_runs_attempts_when_tasks_fail(self, _, mock_executor, caplog):
+        """
+        Test case when all tasks fail to run.
+
+        The executor should attempt each task exactly once per sync() iteration.
+        It should preserve the order of tasks, and attempt each task up to
+        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task.
+        """
+        airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
+        airflow_cmd1 = mock.Mock(spec=list)
+        airflow_cmd2 = mock.Mock(spec=list)
+        caplog.set_level("ERROR")
+        commands = [airflow_cmd1, airflow_cmd2]
+
+        failures = [Exception("Failure 1"), Exception("Failure 2")]
+
+        mock_executor.execute_async(airflow_keys[0], commands[0])
+        mock_executor.execute_async(airflow_keys[1], commands[1])
+
+        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.active_workers.get_all_arns()) == 0
+
+        mock_executor.ecs.run_task.side_effect = failures
+        mock_executor.attempt_task_runs()
+
+        for i in range(2):
+            RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
+            assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
+        assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0]
+        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.active_workers.get_all_arns()) == 0
+
+        caplog.clear()
+        mock_executor.ecs.run_task.call_args_list.clear()
+
+        mock_executor.ecs.run_task.side_effect = failures
+        mock_executor.attempt_task_runs()
+
+        for i in range(2):
+            RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = commands[i]
+            assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
+        assert "Pending ECS tasks failed to launch for the following reasons: " in caplog.messages[0]
+        assert len(mock_executor.pending_tasks) == 2
+        assert len(mock_executor.active_workers.get_all_arns()) == 0
+
+        caplog.clear()
+        mock_executor.ecs.run_task.call_args_list.clear()
+
+        mock_executor.ecs.run_task.side_effect = failures
+        mock_executor.attempt_task_runs()
+
+        assert len(mock_executor.active_workers.get_all_arns()) == 0
+        assert len(mock_executor.pending_tasks) == 0
+
+        assert len(caplog.messages) == 3
+        for i in range(2):
+            assert (
+                f"ECS task {airflow_keys[i]} has failed a maximum of 3 times. Marking as failed"
+                == caplog.messages[i]
+            )
+
+    @mock.patch.object(ecs_executor, "calculate_next_attempt_delay", return_value=dt.timedelta(seconds=0))
+    def test_attempt_task_runs_attempts_when_some_tasks_fail(self, _, mock_executor, caplog):
+        """
+        Test case when one task fails to run, and a new task gets queued.
+
+        The executor should attempt each task exactly once per sync() iteration.
+        It should preserve the order of tasks, and attempt each task up to
+        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task
+        should be removed from pending_tasks and added to active_workers.
+        """
+        airflow_keys = [mock.Mock(spec=tuple), mock.Mock(spec=tuple)]
+        airflow_cmd1 = mock.Mock(spec=list)
+        airflow_cmd2 = mock.Mock(spec=list)
+        caplog.set_level("ERROR")
+        airflow_commands = [airflow_cmd1, airflow_cmd2]
+        task = {
+            "taskArn": ARN1,
+            "lastStatus": "",
+            "desiredStatus": "",
+            "containers": [{"name": "some-ecs-container"}],
+        }
+        success_response = {"tasks": [task], "failures": []}
+
+        responses = [Exception("Failure 1"), success_response]
+
+        mock_executor.execute_async(airflow_keys[0], airflow_commands[0])
+        mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
+
+        assert len(mock_executor.pending_tasks) == 2
+
+        mock_executor.ecs.run_task.side_effect = responses
+        mock_executor.attempt_task_runs()
+
+        for i in range(2):
+            RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i]
+            assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
+
+        assert len(mock_executor.pending_tasks) == 1
+        assert len(mock_executor.active_workers.get_all_arns()) == 1
+
+        caplog.clear()
+        mock_executor.ecs.run_task.call_args_list.clear()
+
+        # queue new task
+        airflow_keys[1] = mock.Mock(spec=tuple)
+        airflow_commands[1] = mock.Mock(spec=list)
+        mock_executor.execute_async(airflow_keys[1], airflow_commands[1])
+
+        assert len(mock_executor.pending_tasks) == 2
+        # assert that the order of pending tasks is preserved, i.e. the first task queued is still first
+        assert mock_executor.pending_tasks[0].key == airflow_keys[0]
+        assert mock_executor.pending_tasks[0].command == airflow_commands[0]
+
+        task["taskArn"] = ARN2
+        success_response = {"tasks": [task], "failures": []}
+        responses = [Exception("Failure 1"), success_response]
+        mock_executor.ecs.run_task.side_effect = responses
+        mock_executor.attempt_task_runs()
+
+        for i in range(2):
+            RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[i]
+            assert mock_executor.ecs.run_task.call_args_list[i].kwargs == RUN_TASK_KWARGS
+
+        assert len(mock_executor.pending_tasks) == 1
+        assert len(mock_executor.active_workers.get_all_arns()) == 2
+
+        caplog.clear()
+        mock_executor.ecs.run_task.call_args_list.clear()
+
+        responses = [Exception("Failure 1")]
+        mock_executor.ecs.run_task.side_effect = responses
+        mock_executor.attempt_task_runs()
+
+        RUN_TASK_KWARGS["overrides"]["containerOverrides"][0]["command"] = airflow_commands[0]
+        assert mock_executor.ecs.run_task.call_args_list[0].kwargs == RUN_TASK_KWARGS
+
+        assert len(caplog.messages) == 2
+
+        assert (
+            f"ECS task {airflow_keys[0]} has failed a maximum of 3 times. Marking as failed"
+            == caplog.messages[0]
+        )
+
     @mock.patch.object(BaseExecutor, "fail")
     @mock.patch.object(BaseExecutor, "success")
     def test_sync(self, success_mock, fail_mock, mock_executor):
