This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 415e076761 Deferrable mode for ECS operators (#31881)
415e076761 is described below

commit 415e0767616121854b6a29b3e44387f708cdf81e
Author: Raphaël Vandon <vand...@amazon.com>
AuthorDate: Fri Jun 23 10:13:13 2023 -0700

    Deferrable mode for ECS operators (#31881)
---
 airflow/providers/amazon/aws/operators/ecs.py      | 160 ++++++++++++++---
 airflow/providers/amazon/aws/triggers/ecs.py       | 198 +++++++++++++++++++++
 .../providers/amazon/aws/utils/task_log_fetcher.py |   5 +-
 airflow/providers/amazon/provider.yaml             |   3 +
 tests/providers/amazon/aws/operators/test_ecs.py   |  74 +++++++-
 tests/providers/amazon/aws/triggers/test_ecs.py    | 123 +++++++++++++
 .../amazon/aws/utils/test_task_log_fetcher.py      |   2 +-
 7 files changed, 532 insertions(+), 33 deletions(-)
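
The headline change is a new deferrable flag on the ECS operators: when set, the operator frees its worker slot and hands the wait to the triggerer instead of polling synchronously. A minimal usage sketch (the cluster name, task definition, and launch type below are hypothetical; the parameters are the ones this commit adds or reuses):

    from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

    run_task = EcsRunTaskOperator(
        task_id="run_ecs_task",
        cluster="my-cluster",            # hypothetical cluster name
        task_definition="my-task-def",   # hypothetical task definition
        launch_type="FARGATE",
        overrides={},                    # no container overrides in this sketch
        deferrable=True,                 # wait asynchronously on the triggerer
        waiter_delay=6,                  # seconds between polls (new default)
        waiter_max_attempts=100,         # polls before failing (new default)
    )

As the docstrings in the diff note, deferrable mode implies waiting for completion and requires the aiobotocore module to be installed.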

diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py
index bc8c4b70d7..2c2e93af35 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -35,6 +35,11 @@ from airflow.providers.amazon.aws.hooks.ecs import (
     EcsHook,
     should_retry_eni,
 )
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.ecs import (
+    ClusterWaiterTrigger,
+    TaskDoneTrigger,
+)
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.helpers import prune_dict
 from airflow.utils.session import provide_session
@@ -67,6 +72,15 @@ class EcsBaseOperator(BaseOperator):
         """Must overwrite in child classes."""
         raise NotImplementedError("Please implement execute() in subclass")
 
+    def _complete_exec_with_cluster_desc(self, context, event=None):
+        """To be used as trigger callback for operators that return the 
cluster description."""
+        if event["status"] != "success":
+            raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}")
+        cluster_arn = event.get("arn")
+        # We cannot get the cluster definition from the waiter on success, so we have to query it here.
+        details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0]
+        return details
+
 
 class EcsCreateClusterOperator(EcsBaseOperator):
     """
@@ -84,9 +98,17 @@ class EcsCreateClusterOperator(EcsBaseOperator):
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion")
+    template_fields: Sequence[str] = (
+        "cluster_name",
+        "create_cluster_kwargs",
+        "wait_for_completion",
+        "deferrable",
+    )
 
     def __init__(
         self,
@@ -94,8 +116,9 @@ class EcsCreateClusterOperator(EcsBaseOperator):
         cluster_name: str,
         create_cluster_kwargs: dict | None = None,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 15,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -104,6 +127,7 @@ class EcsCreateClusterOperator(EcsBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         self.log.info(
@@ -119,6 +143,21 @@ class EcsCreateClusterOperator(EcsBaseOperator):
             # In some circumstances the ECS Cluster is created immediately,
             # and there is no reason to wait for completion.
             self.log.info("Cluster %r in state: %r.", self.cluster_name, 
cluster_state)
+        elif self.deferrable:
+            self.defer(
+                trigger=ClusterWaiterTrigger(
+                    waiter_name="cluster_active",
+                    cluster_arn=cluster_details["clusterArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                ),
+                method_name="_complete_exec_with_cluster_desc",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
         elif self.wait_for_completion:
             waiter = self.hook.get_waiter("cluster_active")
             waiter.wait(
@@ -148,17 +187,21 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
-    template_fields: Sequence[str] = ("cluster_name", "wait_for_completion")
+    template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable")
 
     def __init__(
         self,
         *,
         cluster_name: str,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 15,
+        waiter_max_attempts: int = 60,
+        deferrable: bool = False,
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -166,6 +209,7 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
     def execute(self, context: Context):
         self.log.info("Deleting cluster %r.", self.cluster_name)
@@ -174,9 +218,24 @@ class EcsDeleteClusterOperator(EcsBaseOperator):
         cluster_state = cluster_details.get("status")
 
         if cluster_state == EcsClusterStates.INACTIVE:
-            # In some circumstances the ECS Cluster is deleted immediately,
-            # so there is no reason to wait for completion.
+            # if the cluster doesn't have capacity providers that are associated with it,
+            # the deletion is instantaneous, and we don't need to wait for it.
             self.log.info("Cluster %r in state: %r.", self.cluster_name, 
cluster_state)
+        elif self.deferrable:
+            self.defer(
+                trigger=ClusterWaiterTrigger(
+                    waiter_name="cluster_inactive",
+                    cluster_arn=cluster_details["clusterArn"],
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                ),
+                method_name="_complete_exec_with_cluster_desc",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
         elif self.wait_for_completion:
             waiter = self.hook.get_waiter("cluster_inactive")
             waiter.wait(
@@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         finished.
     :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait
         in between each Cloudwatch logs fetches.
+        If deferrable is set to True, that parameter is ignored and waiter_delay is used instead.
     :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle
         transient errors.
     :param reattach: If set to True, will check if the task previously launched by the task_instance
@@ -361,6 +421,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
         if not set then the default waiter value will be used.
     :param waiter_max_attempts: The maximum number of attempts to be made,
         if not set then the default waiter value will be used.
+    :param deferrable: If True, the operator will wait asynchronously for the job to complete.
+        This implies waiting for completion. This mode requires aiobotocore module to be installed.
+        (default: False)
     """
 
     ui_color = "#f0ede4"
@@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         "reattach",
         "number_logs_exception",
         "wait_for_completion",
+        "deferrable",
     )
     template_fields_renderers = {
         "overrides": "json",
@@ -416,8 +480,9 @@ class EcsRunTaskOperator(EcsBaseOperator):
         reattach: bool = False,
         number_logs_exception: int = 10,
         wait_for_completion: bool = True,
-        waiter_delay: int | None = None,
-        waiter_max_attempts: int | None = None,
+        waiter_delay: int = 6,
+        waiter_max_attempts: int = 100,
+        deferrable: bool = False,
         **kwargs,
     ):
         super().__init__(**kwargs)
@@ -451,6 +516,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.wait_for_completion = wait_for_completion
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
+        self.deferrable = deferrable
 
         if self._aws_logs_enabled() and not self.wait_for_completion:
             self.log.warning(
@@ -473,7 +539,35 @@ class EcsRunTaskOperator(EcsBaseOperator):
         if self.reattach:
             self._try_reattach_task(context)
 
-        self._start_wait_check_task(context)
+        self._start_wait_task(context)
+
+        self._after_execution(session)
+
+        if self.do_xcom_push and self.task_log_fetcher:
+            return self.task_log_fetcher.get_last_log_message()
+        else:
+            return None
+
+    def execute_complete(self, context, event=None):
+        if event["status"] != "success":
+            raise AirflowException(f"Error in task execution: {event}")
+        self.arn = event["task_arn"]  # restore arn to its updated value, 
needed for next steps
+        self._after_execution()
+        if self._aws_logs_enabled():
+            # same behavior as non-deferrable mode, return last line of logs of the task.
+            logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn
+            one_log = logs_client.get_log_events(
+                logGroupName=self.awslogs_group,
+                logStreamName=self._get_logs_stream_name(),
+                startFromHead=False,
+                limit=1,
+            )
+            if len(one_log["events"]) > 0:
+                return one_log["events"][0]["message"]
+
+    @provide_session
+    def _after_execution(self, session=None):
+        self._check_success_task()
 
         self.log.info("ECS Task has been successfully executed")
 
@@ -482,16 +576,29 @@ class EcsRunTaskOperator(EcsBaseOperator):
             # as we can't reattach it anymore
             self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id))
 
-        if self.do_xcom_push and self.task_log_fetcher:
-            return self.task_log_fetcher.get_last_log_message()
-
-        return None
-
     @AwsBaseHook.retry(should_retry_eni)
-    def _start_wait_check_task(self, context):
+    def _start_wait_task(self, context):
         if not self.arn:
             self._start_task(context)
 
+        if self.deferrable:
+            self.defer(
+                trigger=TaskDoneTrigger(
+                    cluster=self.cluster,
+                    task_arn=self.arn,
+                    waiter_delay=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    aws_conn_id=self.aws_conn_id,
+                    region=self.region,
+                    log_group=self.awslogs_group,
+                    log_stream=self._get_logs_stream_name(),
+                ),
+                method_name="execute_complete",
+                # timeout is set to ensure that if a trigger dies, the timeout does not restart
+                # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent)
+                timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60),
+            )
+
         if not self.wait_for_completion:
             return
 
@@ -508,8 +615,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
         else:
             self._wait_for_task_ended()
 
-        self._check_success_task()
-
     def _xcom_del(self, session, task_id):
         session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete()
 
@@ -584,12 +689,10 @@ class EcsRunTaskOperator(EcsBaseOperator):
         waiter.wait(
             cluster=self.cluster,
             tasks=[self.arn],
-            WaiterConfig=prune_dict(
-                {
-                    "Delay": self.waiter_delay,
-                    "MaxAttempts": self.waiter_max_attempts,
-                }
-            ),
+            WaiterConfig={
+                "Delay": self.waiter_delay,
+                "MaxAttempts": self.waiter_max_attempts,
+            },
         )
 
         return
@@ -597,20 +700,23 @@ class EcsRunTaskOperator(EcsBaseOperator):
     def _aws_logs_enabled(self):
         return self.awslogs_group and self.awslogs_stream_prefix
 
+    def _get_logs_stream_name(self) -> str:
+        return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
+
     def _get_task_log_fetcher(self) -> AwsTaskLogFetcher:
         if not self.awslogs_group:
             raise ValueError("must specify awslogs_group to fetch task logs")
-        log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
         return AwsTaskLogFetcher(
             aws_conn_id=self.aws_conn_id,
             region_name=self.awslogs_region,
             log_group=self.awslogs_group,
-            log_stream_name=log_stream_name,
+            log_stream_name=self._get_logs_stream_name(),
             fetch_interval=self.awslogs_fetch_interval,
             logger=self.log,
         )
 
+    @AwsBaseHook.retry(should_retry_eni)
     def _check_success_task(self) -> None:
         if not self.client or not self.arn:
             return
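
A note on the defer() calls in the hunks above: each one caps the deferral with a timeout derived from the waiter budget, so a trigger that dies and gets re-queued cannot extend the total wait. A sketch of the arithmetic, assuming the cluster-operator defaults introduced here (waiter_delay=15, waiter_max_attempts=60):

    from datetime import timedelta

    waiter_delay = 15         # seconds between polls
    waiter_max_attempts = 60  # polls before giving up
    # the extra 60 seconds give the trigger room to exit gracefully (yield its TriggerEvent)
    timeout = timedelta(seconds=waiter_max_attempts * waiter_delay + 60)
    assert timeout == timedelta(minutes=16)
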
diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py
new file mode 100644
index 0000000000..8ba8350588
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/ecs.py
@@ -0,0 +1,198 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, AsyncIterator
+
+from botocore.exceptions import ClientError, WaiterError
+
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
+from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class ClusterWaiterTrigger(BaseTrigger):
+    """
+    Polls the status of a cluster using a given waiter. Can be used to poll for an active or inactive cluster.
+
+    :param waiter_name: Name of the waiter to use, for instance 'cluster_active' or 'cluster_inactive'
+    :param cluster_arn: ARN of the cluster to watch.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The number of times to ping for status.
+        Will fail after that many unsuccessful attempts.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region: The AWS region where the cluster is located.
+    """
+
+    def __init__(
+        self,
+        waiter_name: str,
+        cluster_arn: str,
+        waiter_delay: int | None,
+        waiter_max_attempts: int | None,
+        aws_conn_id: str | None,
+        region: str | None,
+    ):
+        self.cluster_arn = cluster_arn
+        self.waiter_name = waiter_name
+        self.waiter_delay = waiter_delay if waiter_delay is not None else 15  # written like this to allow 0
+        self.attempts = waiter_max_attempts or 999999999
+        self.aws_conn_id = aws_conn_id
+        self.region = region
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "waiter_name": self.waiter_name,
+                "cluster_arn": self.cluster_arn,
+                "waiter_delay": self.waiter_delay,
+                "waiter_max_attempts": self.attempts,
+                "aws_conn_id": self.aws_conn_id,
+                "region": self.region,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client:
+            waiter = client.get_waiter(self.waiter_name)
+            await async_wait(
+                waiter,
+                self.waiter_delay,
+                self.attempts,
+                {"clusters": [self.cluster_arn]},
+                "error when checking cluster status",
+                "Status of cluster",
+                ["clusters[].status"],
+            )
+            yield TriggerEvent({"status": "success", "arn": self.cluster_arn})
+
+
+class TaskDoneTrigger(BaseTrigger):
+    """
+    Waits for an ECS task to be done, while eventually polling logs.
+
+    :param cluster: short name or full ARN of the cluster where the task is running.
+    :param task_arn: ARN of the task to watch.
+    :param waiter_delay: The amount of time in seconds to wait between attempts.
+    :param waiter_max_attempts: The number of times to ping for status.
+        Will fail after that many unsuccessful attempts.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+    :param region: The AWS region where the cluster is located.
+    """
+
+    def __init__(
+        self,
+        cluster: str,
+        task_arn: str,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str | None,
+        region: str | None,
+        log_group: str | None = None,
+        log_stream: str | None = None,
+    ):
+        self.cluster = cluster
+        self.task_arn = task_arn
+
+        self.waiter_delay = waiter_delay
+        self.waiter_max_attempts = waiter_max_attempts
+        self.aws_conn_id = aws_conn_id
+        self.region = region
+
+        self.log_group = log_group
+        self.log_stream = log_stream
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        return (
+            self.__class__.__module__ + "." + self.__class__.__qualname__,
+            {
+                "cluster": self.cluster,
+                "task_arn": self.task_arn,
+                "waiter_delay": self.waiter_delay,
+                "waiter_max_attempts": self.waiter_max_attempts,
+                "aws_conn_id": self.aws_conn_id,
+                "region": self.region,
+                "log_group": self.log_group,
+                "log_stream": self.log_stream,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        # fmt: off
+        async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,\
+                AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client:
+            # fmt: on
+            waiter = ecs_client.get_waiter("tasks_stopped")
+            logs_token = None
+            while self.waiter_max_attempts >= 1:
+                self.waiter_max_attempts = self.waiter_max_attempts - 1
+                try:
+                    await waiter.wait(
+                        cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1}
+                    )
+                    break  # we reach this point only if the waiter met a success criteria
+                except WaiterError as error:
+                    if "terminal failure" in str(error):
+                        raise
+                    self.log.info("Status of the task is %s", 
error.last_response["tasks"][0]["lastStatus"])
+                    await asyncio.sleep(int(self.waiter_delay))
+                finally:
+                    if self.log_group and self.log_stream:
+                        logs_token = await self._forward_logs(logs_client, logs_token)
+
+        yield TriggerEvent({"status": "success", "task_arn": self.task_arn})
+
+    async def _forward_logs(self, logs_client, next_token: str | None = None) -> str | None:
+        """
+        Reads logs from the cloudwatch stream and prints them to the task logs.
+        :return: the token to pass to the next iteration to resume where we started.
+        """
+        while True:
+            if next_token is not None:
+                token_arg: dict[str, str] = {"nextToken": next_token}
+            else:
+                token_arg = {}
+            try:
+                response = await logs_client.get_log_events(
+                    logGroupName=self.log_group,
+                    logStreamName=self.log_stream,
+                    startFromHead=True,
+                    **token_arg,
+                )
+            except ClientError as ce:
+                if ce.response["Error"]["Code"] == "ResourceNotFoundException":
+                    self.log.info(
+                        "Tried to get logs from stream %s in group %s but it 
didn't exist (yet). "
+                        "Will try again.",
+                        self.log_stream,
+                        self.log_group,
+                    )
+                    return None
+                raise
+
+            events = response["events"]
+            for log_event in events:
+                self.log.info(AwsTaskLogFetcher.event_to_str(log_event))
+
+            if len(events) == 0 or next_token == response["nextForwardToken"]:
+                return response["nextForwardToken"]
+            next_token = response["nextForwardToken"]
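
Note how TaskDoneTrigger.run drives the tasks_stopped waiter one attempt at a time (WaiterConfig={"MaxAttempts": 1}) inside its own loop instead of letting botocore retry internally: that is what lets it forward CloudWatch logs between polls. A stripped-down sketch of the pattern (waiter and forward_logs are hypothetical stand-ins for the real client calls):

    import asyncio

    from botocore.exceptions import WaiterError

    async def poll_then_forward(waiter, forward_logs, attempts: int, delay: int) -> None:
        """One waiter attempt per iteration, forwarding logs after each (sketch)."""
        while attempts >= 1:
            attempts -= 1
            try:
                await waiter.wait(WaiterConfig={"MaxAttempts": 1})
                break  # the waiter met its success criteria
            except WaiterError as error:
                if "terminal failure" in str(error):
                    raise  # task failed for good, stop polling
                await asyncio.sleep(delay)
            finally:
                await forward_logs()  # runs after every attempt, success or failure
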
diff --git a/airflow/providers/amazon/aws/utils/task_log_fetcher.py b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
index 97b43a67b2..22a5e5f2a1 100644
--- a/airflow/providers/amazon/aws/utils/task_log_fetcher.py
+++ b/airflow/providers/amazon/aws/utils/task_log_fetcher.py
@@ -62,7 +62,7 @@ class AwsTaskLogFetcher(Thread):
             time.sleep(self.fetch_interval.total_seconds())
             log_events = self._get_log_events(continuation_token)
             for log_event in log_events:
-                self.logger.info(self._event_to_str(log_event))
+                self.logger.info(self.event_to_str(log_event))
 
     def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None = None) -> Generator:
         if skip_token is None:
@@ -87,7 +87,8 @@ class AwsTaskLogFetcher(Thread):
             self.logger.warning("ConnectionClosedError on retrieving 
Cloudwatch log events", error)
             yield from ()
 
-    def _event_to_str(self, event: dict) -> str:
+    @staticmethod
+    def event_to_str(event: dict) -> str:
         event_dt = datetime.utcfromtimestamp(event["timestamp"] / 1000.0)
         formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3]
         message = event["message"]
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 3680915dc3..223bc553ef 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -532,6 +532,9 @@ triggers:
   - integration-name: Amazon Elastic Kubernetes Service (EKS)
     python-modules:
       - airflow.providers.amazon.aws.triggers.eks
+  - integration-name: Amazon ECS
+    python-modules:
+      - airflow.providers.amazon.aws.triggers.ecs
 
 transfers:
   - source-integration-name: Amazon DynamoDB
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index 8c99a02ce8..b89ea59c56 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -20,13 +20,14 @@ from __future__ import annotations
 import sys
 from copy import deepcopy
 from unittest import mock
+from unittest.mock import MagicMock, PropertyMock
 
 import boto3
 import pytest
 
-from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart
-from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook
 from airflow.providers.amazon.aws.operators.ecs import (
     DEFAULT_CONN_ID,
     EcsBaseOperator,
@@ -36,6 +37,7 @@ from airflow.providers.amazon.aws.operators.ecs import (
     EcsRegisterTaskDefinitionOperator,
     EcsRunTaskOperator,
 )
+from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger
 from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher
 from airflow.utils.types import NOTSET
 
@@ -186,6 +188,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
             "reattach",
             "number_logs_exception",
             "wait_for_completion",
+            "deferrable",
         )
 
     @pytest.mark.parametrize(
@@ -343,7 +346,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         self.ecs._wait_for_task_ended()
         client_mock.get_waiter.assert_called_once_with("tasks_stopped")
         client_mock.get_waiter.return_value.wait.assert_called_once_with(
-            cluster="c", tasks=["arn"], WaiterConfig={}
+            cluster="c", tasks=["arn"], WaiterConfig={"Delay": 6, 
"MaxAttempts": 100}
         )
         assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts
 
@@ -654,6 +657,31 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         self.ecs.do_xcom_push = False
         assert self.ecs.execute(None) is None
 
+    @mock.patch.object(EcsRunTaskOperator, "client")
+    def test_with_defer(self, client_mock):
+        self.ecs.deferrable = True
+
+        client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES
+
+        with pytest.raises(TaskDeferred) as deferred:
+            self.ecs.execute(None)
+
+        assert isinstance(deferred.value.trigger, TaskDoneTrigger)
+        assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
+
+    @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock)
+    @mock.patch.object(EcsRunTaskOperator, "_xcom_del")
+    def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock):
+        event = {"status": "success", "task_arn": "my_arn"}
+        self.ecs.reattach = True
+
+        self.ecs.execute_complete(None, event)
+
+        # task gets described to assert its success
+        client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"])
+        # if reattach mode, xcom value is deleted on success
+        xcom_del_mock.assert_called_once()
+
 
 class TestEcsCreateClusterOperator(EcsBaseTestCase):
     @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", 
WAITERS_TEST_CASES)
@@ -680,6 +708,26 @@ class TestEcsCreateClusterOperator(EcsBaseTestCase):
         mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, WaiterConfig=expected_waiter_config)
         assert result is not None
 
+    @mock.patch.object(EcsCreateClusterOperator, "client")
+    def test_execute_deferrable(self, mock_client: MagicMock):
+        op = EcsCreateClusterOperator(
+            task_id="task",
+            cluster_name=CLUSTER_NAME,
+            deferrable=True,
+            waiter_delay=12,
+            waiter_max_attempts=34,
+        )
+        mock_client.create_cluster.return_value = {
+            "cluster": {"status": EcsClusterStates.PROVISIONING, "clusterArn": 
"my arn"}
+        }
+
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute(None)
+
+        assert defer.value.trigger.cluster_arn == "my arn"
+        assert defer.value.trigger.waiter_delay == 12
+        assert defer.value.trigger.attempts == 34
+
     def test_execute_immediate_create(self, patch_hook_waiters):
         """Test if cluster created during initial request."""
         op = EcsCreateClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True)
@@ -725,6 +773,26 @@ class TestEcsDeleteClusterOperator(EcsBaseTestCase):
         mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, WaiterConfig=expected_waiter_config)
         assert result is not None
 
+    @mock.patch.object(EcsDeleteClusterOperator, "client")
+    def test_execute_deferrable(self, mock_client: MagicMock):
+        op = EcsDeleteClusterOperator(
+            task_id="task",
+            cluster_name=CLUSTER_NAME,
+            deferrable=True,
+            waiter_delay=12,
+            waiter_max_attempts=34,
+        )
+        mock_client.delete_cluster.return_value = {
+            "cluster": {"status": EcsClusterStates.DEPROVISIONING, 
"clusterArn": "my arn"}
+        }
+
+        with pytest.raises(TaskDeferred) as defer:
+            op.execute(None)
+
+        assert defer.value.trigger.cluster_arn == "my arn"
+        assert defer.value.trigger.waiter_delay == 12
+        assert defer.value.trigger.attempts == 34
+
     def test_execute_immediate_delete(self, patch_hook_waiters):
         """Test if cluster deleted during initial request."""
         op = EcsDeleteClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True)
diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py b/tests/providers/amazon/aws/triggers/test_ecs.py
new file mode 100644
index 0000000000..09b5decbe6
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_ecs.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+from botocore.exceptions import WaiterError
+
+from airflow import AirflowException
+from airflow.providers.amazon.aws.hooks.ecs import EcsHook
+from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook
+from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger
+from airflow.triggers.base import TriggerEvent
+
+
+class TestClusterWaiterTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    async def test_run_max_attempts(self, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        wait_mock.side_effect = WaiterError("name", "reason", {"clusters": [{"status": "my_status"}]})
+        a_mock.get_waiter().wait = wait_mock
+
+        max_attempts = 5
+        trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, max_attempts, None, None)
+
+        with pytest.raises(AirflowException):
+            generator = trigger.run()
+            await generator.asend(None)
+
+        assert wait_mock.call_count == max_attempts
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    async def test_run_success(self, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        a_mock.get_waiter().wait = wait_mock
+
+        trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None)
+
+        generator = trigger.run()
+        response: TriggerEvent = await generator.asend(None)
+
+        assert response.payload["status"] == "success"
+        assert response.payload["arn"] == "cluster_arn"
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    async def test_run_error(self, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        wait_mock.side_effect = WaiterError("terminal failure", "reason", {})
+        a_mock.get_waiter().wait = wait_mock
+
+        trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None)
+
+        with pytest.raises(AirflowException):
+            generator = trigger.run()
+            await generator.asend(None)
+
+
+class TestTaskDoneTrigger:
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step
+    @mock.patch.object(AwsLogsHook, "async_conn")
+    async def test_run_until_error(self, _, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        wait_mock.side_effect = [
+            WaiterError("name", "reason", {"tasks": [{"lastStatus": 
"my_status"}]}),
+            WaiterError("name", "reason", {"tasks": [{"lastStatus": 
"my_status"}]}),
+            WaiterError("terminal failure", "reason", {}),
+        ]
+        a_mock.get_waiter().wait = wait_mock
+
+        trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None)
+
+        with pytest.raises(WaiterError):
+            generator = trigger.run()
+            await generator.asend(None)
+
+        assert wait_mock.call_count == 3
+
+    @pytest.mark.asyncio
+    @mock.patch.object(EcsHook, "async_conn")
+    # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step
+    @mock.patch.object(AwsLogsHook, "async_conn")
+    async def test_run_success(self, _, client_mock):
+        a_mock = mock.MagicMock()
+        client_mock.__aenter__.return_value = a_mock
+        wait_mock = AsyncMock()
+        a_mock.get_waiter().wait = wait_mock
+
+        trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, 10, None, None)
+
+        generator = trigger.run()
+        response: TriggerEvent = await generator.asend(None)
+
+        assert response.payload["status"] == "success"
+        assert response.payload["task_arn"] == "my_task_arn"
diff --git a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
index dbda751cfb..a5598ebf55 100644
--- a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
+++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py
@@ -112,7 +112,7 @@ class TestAwsTaskLogFetcher:
             {"timestamp": 1617400367456, "message": "Second"},
             {"timestamp": 1617400467789, "message": "Third"},
         ]
-        assert [self.log_fetcher._event_to_str(event) for event in events] == (
+        assert [self.log_fetcher.event_to_str(event) for event in events] == (
             [
                 "[2021-04-02 21:51:07,123] First",
                 "[2021-04-02 21:52:47,456] Second",

