This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d49fa999a9 bugfix: break down run+wait method in ECS operator (#32104)
d49fa999a9 is described below

commit d49fa999a94a2269dd6661fe5eebbb4c768c7848
Author: Raphaël Vandon <vand...@amazon.com>
AuthorDate: Fri Jun 23 14:31:07 2023 -0700

    bugfix: break down run+wait method in ECS operator (#32104)
    
    This method was causing trouble by handling several things at once, which hid the logic.
    A bug fixed in #31838 was reintroduced in #31881 because the check that used to be
    skipped when `wait_for_completion` was false was no longer skipped.

    The bug is that checking the status always fails when not waiting for completion,
    because the task is obviously not ready just after creation.
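
    For context, a condensed sketch of the execute() flow that results from this change
    (drawn from the diff below; the method and attribute names are the operator's own, but
    the body is illustrative only, the signature is abridged, and the log-fetcher wait
    branch is folded into a comment):

        def execute(self, context, session=None):  # signature abridged
            if self.reattach:
                self._try_reattach_task(context)
            if not self.arn:
                # start the task unless we just reattached to an existing one
                self._start_task(context)
            if self.deferrable:
                self.defer(...)  # arguments elided; defer raises, so execution stops here
            if not self.wait_for_completion:
                # return before any status check: the task was only just created,
                # so a status check here would always fail (the bug being fixed)
                return None
            self._wait_for_task_ended()  # (or consume logs via task_log_fetcher)
            self._after_execution(session)  # checks success, clears the reattach XCom
            if self.do_xcom_push and self.task_log_fetcher:
                return self.task_log_fetcher.get_last_log_message()
            return None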
---
 airflow/providers/amazon/aws/operators/ecs.py | 77 +++++++++++++--------------
 1 file changed, 38 insertions(+), 39 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index 2c2e93af35..91533cfa62 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -539,46 +539,8 @@ class EcsRunTaskOperator(EcsBaseOperator):
         if self.reattach:
             self._try_reattach_task(context)
 
-        self._start_wait_task(context)
-
-        self._after_execution(session)
-
-        if self.do_xcom_push and self.task_log_fetcher:
-            return self.task_log_fetcher.get_last_log_message()
-        else:
-            return None
-
-    def execute_complete(self, context, event=None):
-        if event["status"] != "success":
-            raise AirflowException(f"Error in task execution: {event}")
-        self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
-        self._after_execution()
-        if self._aws_logs_enabled():
-            # same behavior as non-deferrable mode, return last line of logs of the task.
-            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
-            one_log = logs_client.get_log_events(
-                logGroupName=self.awslogs_group,
-                logStreamName=self._get_logs_stream_name(),
-                startFromHead=False,
-                limit=1,
-            )
-            if len(one_log["events"]) > 0:
-                return one_log["events"][0]["message"]
-
-    @provide_session
-    def _after_execution(self, session=None):
-        self._check_success_task()
-
-        self.log.info("ECS Task has been successfully executed")
-
-        if self.reattach:
-            # Clear the XCom value storing the ECS task ARN if the task has completed
-            # as we can't reattach it anymore
-            self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
-
-    @AwsBaseHook.retry(should_retry_eni)
-    def _start_wait_task(self, context):
         if not self.arn:
+            # start the task except if we reattached to an existing one just before.
             self._start_task(context)
 
         if self.deferrable:
@@ -598,6 +560,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
                 # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
                 timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
             )
+            # self.defer raises a special exception, so execution stops here in this case.
 
         if not self.wait_for_completion:
             return
@@ -615,9 +578,45 @@ class EcsRunTaskOperator(EcsBaseOperator):
         else:
             self._wait_for_task_ended()
 
+        self._after_execution(session)
+
+        if self.do_xcom_push and self.task_log_fetcher:
+            return self.task_log_fetcher.get_last_log_message()
+        else:
+            return None
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error in task execution: {event}")
+        self.arn = event["task_arn"]  # restore arn to its updated value, needed for next steps
+        self._after_execution()
+        if self._aws_logs_enabled():
+            # same behavior as non-deferrable mode, return last line of logs of the task.
+            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
+            one_log = logs_client.get_log_events(
+                logGroupName=self.awslogs_group,
+                logStreamName=self._get_logs_stream_name(),
+                startFromHead=False,
+                limit=1,
+            )
+            if len(one_log["events"]) > 0:
+                return one_log["events"][0]["message"]
+
+    @provide_session
+    def _after_execution(self, session=None):
+        self._check_success_task()
+
+        self.log.info("ECS Task has been successfully executed")
+
+        if self.reattach:
+            # Clear the XCom value storing the ECS task ARN if the task has completed
+            # as we can't reattach it anymore
+            self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
+
     def _xcom_del(self, session, task_id):
         session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
 
+    @AwsBaseHook.retry(should_retry_eni)
     def _start_task(self, context):
         run_opts = {
             "cluster": self.cluster,
