This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 415e076761 Deferrable mode for ECS operators (#31881) 415e076761 is described below commit 415e0767616121854b6a29b3e44387f708cdf81e Author: Raphaël Vandon <vand...@amazon.com> AuthorDate: Fri Jun 23 10:13:13 2023 -0700 Deferrable mode for ECS operators (#31881) --- airflow/providers/amazon/aws/operators/ecs.py | 160 ++++++++++++++--- airflow/providers/amazon/aws/triggers/ecs.py | 198 +++++++++++++++++++++ .../providers/amazon/aws/utils/task_log_fetcher.py | 5 +- airflow/providers/amazon/provider.yaml | 3 + tests/providers/amazon/aws/operators/test_ecs.py | 74 +++++++- tests/providers/amazon/aws/triggers/test_ecs.py | 123 +++++++++++++ .../amazon/aws/utils/test_task_log_fetcher.py | 2 +- 7 files changed, 532 insertions(+), 33 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index bc8c4b70d7..2c2e93af35 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -35,6 +35,11 @@ from airflow.providers.amazon.aws.hooks.ecs import ( EcsHook, should_retry_eni, ) +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.providers.amazon.aws.triggers.ecs import ( + ClusterWaiterTrigger, + TaskDoneTrigger, +) from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.helpers import prune_dict from airflow.utils.session import provide_session @@ -67,6 +72,15 @@ class EcsBaseOperator(BaseOperator): """Must overwrite in child classes.""" raise NotImplementedError("Please implement execute() in subclass") + def _complete_exec_with_cluster_desc(self, context, event=None): + """To be used as trigger callback for operators that return the cluster description.""" + if event["status"] != "success": + raise AirflowException(f"Error while waiting for operation on cluster to complete: {event}") + cluster_arn = event.get("arn") + # 
We cannot get the cluster definition from the waiter on success, so we have to query it here. + details = self.hook.conn.describe_clusters(clusters=[cluster_arn])["clusters"][0] + return details + class EcsCreateClusterOperator(EcsBaseOperator): """ @@ -84,9 +98,17 @@ class EcsCreateClusterOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) """ - template_fields: Sequence[str] = ("cluster_name", "create_cluster_kwargs", "wait_for_completion") + template_fields: Sequence[str] = ( + "cluster_name", + "create_cluster_kwargs", + "wait_for_completion", + "deferrable", + ) def __init__( self, @@ -94,8 +116,9 @@ class EcsCreateClusterOperator(EcsBaseOperator): cluster_name: str, create_cluster_kwargs: dict | None = None, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 15, + waiter_max_attempts: int = 60, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -104,6 +127,7 @@ class EcsCreateClusterOperator(EcsBaseOperator): self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): self.log.info( @@ -119,6 +143,21 @@ class EcsCreateClusterOperator(EcsBaseOperator): # In some circumstances the ECS Cluster is created immediately, # and there is no reason to wait for completion. 
self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) + elif self.deferrable: + self.defer( + trigger=ClusterWaiterTrigger( + waiter_name="cluster_active", + cluster_arn=cluster_details["clusterArn"], + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="_complete_exec_with_cluster_desc", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), + ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_active") waiter.wait( @@ -148,17 +187,21 @@ class EcsDeleteClusterOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ - template_fields: Sequence[str] = ("cluster_name", "wait_for_completion") + template_fields: Sequence[str] = ("cluster_name", "wait_for_completion", "deferrable") def __init__( self, *, cluster_name: str, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 15, + waiter_max_attempts: int = 60, + deferrable: bool = False, **kwargs, ) -> None: super().__init__(**kwargs) @@ -166,6 +209,7 @@ class EcsDeleteClusterOperator(EcsBaseOperator): self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable def execute(self, context: Context): self.log.info("Deleting cluster %r.", self.cluster_name) @@ -174,9 +218,24 @@ class EcsDeleteClusterOperator(EcsBaseOperator): cluster_state = cluster_details.get("status") if cluster_state == EcsClusterStates.INACTIVE: - # In some circumstances the ECS Cluster is deleted immediately, - # so there is no reason to wait for completion. + # if the cluster doesn't have capacity providers that are associated with it, + # the deletion is instantaneous, and we don't need to wait for it. self.log.info("Cluster %r in state: %r.", self.cluster_name, cluster_state) + elif self.deferrable: + self.defer( + trigger=ClusterWaiterTrigger( + waiter_name="cluster_inactive", + cluster_arn=cluster_details["clusterArn"], + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + ), + method_name="_complete_exec_with_cluster_desc", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. 
yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), + ) elif self.wait_for_completion: waiter = self.hook.get_waiter("cluster_inactive") waiter.wait( @@ -347,6 +406,7 @@ class EcsRunTaskOperator(EcsBaseOperator): finished. :param awslogs_fetch_interval: the interval that the ECS task log fetcher should wait in between each Cloudwatch logs fetches. + If deferrable is set to True, that parameter is ignored and waiter_delay is used instead. :param quota_retry: Config if and how to retry the launch of a new ECS task, to handle transient errors. :param reattach: If set to True, will check if the task previously launched by the task_instance @@ -361,6 +421,9 @@ class EcsRunTaskOperator(EcsBaseOperator): if not set then the default waiter value will be used. :param waiter_max_attempts: The maximum number of attempts to be made, if not set then the default waiter value will be used. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. 
+ (default: False) """ ui_color = "#f0ede4" @@ -384,6 +447,7 @@ class EcsRunTaskOperator(EcsBaseOperator): "reattach", "number_logs_exception", "wait_for_completion", + "deferrable", ) template_fields_renderers = { "overrides": "json", @@ -416,8 +480,9 @@ class EcsRunTaskOperator(EcsBaseOperator): reattach: bool = False, number_logs_exception: int = 10, wait_for_completion: bool = True, - waiter_delay: int | None = None, - waiter_max_attempts: int | None = None, + waiter_delay: int = 6, + waiter_max_attempts: int = 100, + deferrable: bool = False, **kwargs, ): super().__init__(**kwargs) @@ -451,6 +516,7 @@ class EcsRunTaskOperator(EcsBaseOperator): self.wait_for_completion = wait_for_completion self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable if self._aws_logs_enabled() and not self.wait_for_completion: self.log.warning( @@ -473,7 +539,35 @@ class EcsRunTaskOperator(EcsBaseOperator): if self.reattach: self._try_reattach_task(context) - self._start_wait_check_task(context) + self._start_wait_task(context) + + self._after_execution(session) + + if self.do_xcom_push and self.task_log_fetcher: + return self.task_log_fetcher.get_last_log_message() + else: + return None + + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error in task execution: {event}") + self.arn = event["task_arn"] # restore arn to its updated value, needed for next steps + self._after_execution() + if self._aws_logs_enabled(): + # same behavior as non-deferrable mode, return last line of logs of the task. 
+ logs_client = AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).conn + one_log = logs_client.get_log_events( + logGroupName=self.awslogs_group, + logStreamName=self._get_logs_stream_name(), + startFromHead=False, + limit=1, + ) + if len(one_log["events"]) > 0: + return one_log["events"][0]["message"] + + @provide_session + def _after_execution(self, session=None): + self._check_success_task() self.log.info("ECS Task has been successfully executed") @@ -482,16 +576,29 @@ class EcsRunTaskOperator(EcsBaseOperator): # as we can't reattach it anymore self._xcom_del(session, self.REATTACH_XCOM_TASK_ID_TEMPLATE.format(task_id=self.task_id)) - if self.do_xcom_push and self.task_log_fetcher: - return self.task_log_fetcher.get_last_log_message() - - return None - @AwsBaseHook.retry(should_retry_eni) - def _start_wait_check_task(self, context): + def _start_wait_task(self, context): if not self.arn: self._start_task(context) + if self.deferrable: + self.defer( + trigger=TaskDoneTrigger( + cluster=self.cluster, + task_arn=self.arn, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + region=self.region, + log_group=self.awslogs_group, + log_stream=self._get_logs_stream_name(), + ), + method_name="execute_complete", + # timeout is set to ensure that if a trigger dies, the timeout does not restart + # 60 seconds is added to allow the trigger to exit gracefully (i.e. 
yield TriggerEvent) + timeout=timedelta(seconds=self.waiter_max_attempts * self.waiter_delay + 60), + ) + if not self.wait_for_completion: return @@ -508,8 +615,6 @@ class EcsRunTaskOperator(EcsBaseOperator): else: self._wait_for_task_ended() - self._check_success_task() - def _xcom_del(self, session, task_id): session.query(XCom).filter(XCom.dag_id == self.dag_id, XCom.task_id == task_id).delete() @@ -584,12 +689,10 @@ class EcsRunTaskOperator(EcsBaseOperator): waiter.wait( cluster=self.cluster, tasks=[self.arn], - WaiterConfig=prune_dict( - { - "Delay": self.waiter_delay, - "MaxAttempts": self.waiter_max_attempts, - } - ), + WaiterConfig={ + "Delay": self.waiter_delay, + "MaxAttempts": self.waiter_max_attempts, + }, ) return @@ -597,20 +700,23 @@ class EcsRunTaskOperator(EcsBaseOperator): def _aws_logs_enabled(self): return self.awslogs_group and self.awslogs_stream_prefix + def _get_logs_stream_name(self) -> str: + return f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}" + def _get_task_log_fetcher(self) -> AwsTaskLogFetcher: if not self.awslogs_group: raise ValueError("must specify awslogs_group to fetch task logs") - log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}" return AwsTaskLogFetcher( aws_conn_id=self.aws_conn_id, region_name=self.awslogs_region, log_group=self.awslogs_group, - log_stream_name=log_stream_name, + log_stream_name=self._get_logs_stream_name(), fetch_interval=self.awslogs_fetch_interval, logger=self.log, ) + @AwsBaseHook.retry(should_retry_eni) def _check_success_task(self) -> None: if not self.client or not self.arn: return diff --git a/airflow/providers/amazon/aws/triggers/ecs.py b/airflow/providers/amazon/aws/triggers/ecs.py new file mode 100644 index 0000000000..8ba8350588 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/ecs.py @@ -0,0 +1,198 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +from typing import Any, AsyncIterator + +from botocore.exceptions import ClientError, WaiterError + +from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher +from airflow.providers.amazon.aws.utils.waiter_with_logging import async_wait +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class ClusterWaiterTrigger(BaseTrigger): + """ + Polls the status of a cluster using a given waiter. Can be used to poll for an active or inactive cluster. + + :param waiter_name: Name of the waiter to use, for instance 'cluster_active' or 'cluster_inactive' + :param cluster_arn: ARN of the cluster to watch. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The number of times to ping for status. + Will fail after that many unsuccessful attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region: The AWS region where the cluster is located. 
+ """ + + def __init__( + self, + waiter_name: str, + cluster_arn: str, + waiter_delay: int | None, + waiter_max_attempts: int | None, + aws_conn_id: str | None, + region: str | None, + ): + self.cluster_arn = cluster_arn + self.waiter_name = waiter_name + self.waiter_delay = waiter_delay if waiter_delay is not None else 15 # written like this to allow 0 + self.attempts = waiter_max_attempts or 999999999 + self.aws_conn_id = aws_conn_id + self.region = region + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "waiter_name": self.waiter_name, + "cluster_arn": self.cluster_arn, + "waiter_delay": self.waiter_delay, + "waiter_max_attempts": self.attempts, + "aws_conn_id": self.aws_conn_id, + "region": self.region, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as client: + waiter = client.get_waiter(self.waiter_name) + await async_wait( + waiter, + self.waiter_delay, + self.attempts, + {"clusters": [self.cluster_arn]}, + "error when checking cluster status", + "Status of cluster", + ["clusters[].status"], + ) + yield TriggerEvent({"status": "success", "arn": self.cluster_arn}) + + +class TaskDoneTrigger(BaseTrigger): + """ + Waits for an ECS task to be done, while eventually polling logs. + + :param cluster: short name or full ARN of the cluster where the task is running. + :param task_arn: ARN of the task to watch. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The number of times to ping for status. + Will fail after that many unsuccessful attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param region: The AWS region where the cluster is located. 
+ """ + + def __init__( + self, + cluster: str, + task_arn: str, + waiter_delay: int, + waiter_max_attempts: int, + aws_conn_id: str | None, + region: str | None, + log_group: str | None = None, + log_stream: str | None = None, + ): + self.cluster = cluster + self.task_arn = task_arn + + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.aws_conn_id = aws_conn_id + self.region = region + + self.log_group = log_group + self.log_stream = log_stream + + def serialize(self) -> tuple[str, dict[str, Any]]: + return ( + self.__class__.__module__ + "." + self.__class__.__qualname__, + { + "cluster": self.cluster, + "task_arn": self.task_arn, + "waiter_delay": self.waiter_delay, + "waiter_max_attempts": self.waiter_max_attempts, + "aws_conn_id": self.aws_conn_id, + "region": self.region, + "log_group": self.log_group, + "log_stream": self.log_stream, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + # fmt: off + async with EcsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as ecs_client,\ + AwsLogsHook(aws_conn_id=self.aws_conn_id, region_name=self.region).async_conn as logs_client: + # fmt: on + waiter = ecs_client.get_waiter("tasks_stopped") + logs_token = None + while self.waiter_max_attempts >= 1: + self.waiter_max_attempts = self.waiter_max_attempts - 1 + try: + await waiter.wait( + cluster=self.cluster, tasks=[self.task_arn], WaiterConfig={"MaxAttempts": 1} + ) + break # we reach this point only if the waiter met a success criteria + except WaiterError as error: + if "terminal failure" in str(error): + raise + self.log.info("Status of the task is %s", error.last_response["tasks"][0]["lastStatus"]) + await asyncio.sleep(int(self.waiter_delay)) + finally: + if self.log_group and self.log_stream: + logs_token = await self._forward_logs(logs_client, logs_token) + + yield TriggerEvent({"status": "success", "task_arn": self.task_arn}) + + async def _forward_logs(self, logs_client, next_token: str 
| None = None) -> str | None: + """ + Reads logs from the cloudwatch stream and prints them to the task logs. + :return: the token to pass to the next iteration to resume where we started. + """ + while True: + if next_token is not None: + token_arg: dict[str, str] = {"nextToken": next_token} + else: + token_arg = {} + try: + response = await logs_client.get_log_events( + logGroupName=self.log_group, + logStreamName=self.log_stream, + startFromHead=True, + **token_arg, + ) + except ClientError as ce: + if ce.response["Error"]["Code"] == "ResourceNotFoundException": + self.log.info( + "Tried to get logs from stream %s in group %s but it didn't exist (yet). " + "Will try again.", + self.log_stream, + self.log_group, + ) + return None + raise + + events = response["events"] + for log_event in events: + self.log.info(AwsTaskLogFetcher.event_to_str(log_event)) + + if len(events) == 0 or next_token == response["nextForwardToken"]: + return response["nextForwardToken"] + next_token = response["nextForwardToken"] diff --git a/airflow/providers/amazon/aws/utils/task_log_fetcher.py b/airflow/providers/amazon/aws/utils/task_log_fetcher.py index 97b43a67b2..22a5e5f2a1 100644 --- a/airflow/providers/amazon/aws/utils/task_log_fetcher.py +++ b/airflow/providers/amazon/aws/utils/task_log_fetcher.py @@ -62,7 +62,7 @@ class AwsTaskLogFetcher(Thread): time.sleep(self.fetch_interval.total_seconds()) log_events = self._get_log_events(continuation_token) for log_event in log_events: - self.logger.info(self._event_to_str(log_event)) + self.logger.info(self.event_to_str(log_event)) def _get_log_events(self, skip_token: AwsLogsHook.ContinuationToken | None = None) -> Generator: if skip_token is None: @@ -87,7 +87,8 @@ class AwsTaskLogFetcher(Thread): self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error) yield from () - def _event_to_str(self, event: dict) -> str: + @staticmethod + def event_to_str(event: dict) -> str: event_dt = 
datetime.utcfromtimestamp(event["timestamp"] / 1000.0) formatted_event_dt = event_dt.strftime("%Y-%m-%d %H:%M:%S,%f")[:-3] message = event["message"] diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 3680915dc3..223bc553ef 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -532,6 +532,9 @@ triggers: - integration-name: Amazon Elastic Kubernetes Service (EKS) python-modules: - airflow.providers.amazon.aws.triggers.eks + - integration-name: Amazon ECS + python-modules: + - airflow.providers.amazon.aws.triggers.ecs transfers: - source-integration-name: Amazon DynamoDB diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index 8c99a02ce8..b89ea59c56 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -20,13 +20,14 @@ from __future__ import annotations import sys from copy import deepcopy from unittest import mock +from unittest.mock import MagicMock, PropertyMock import boto3 import pytest -from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred from airflow.providers.amazon.aws.exceptions import EcsOperatorError, EcsTaskFailToStart -from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook from airflow.providers.amazon.aws.operators.ecs import ( DEFAULT_CONN_ID, EcsBaseOperator, @@ -36,6 +37,7 @@ from airflow.providers.amazon.aws.operators.ecs import ( EcsRegisterTaskDefinitionOperator, EcsRunTaskOperator, ) +from airflow.providers.amazon.aws.triggers.ecs import TaskDoneTrigger from airflow.providers.amazon.aws.utils.task_log_fetcher import AwsTaskLogFetcher from airflow.utils.types import NOTSET @@ -186,6 +188,7 @@ class 
TestEcsRunTaskOperator(EcsBaseTestCase): "reattach", "number_logs_exception", "wait_for_completion", + "deferrable", ) @pytest.mark.parametrize( @@ -343,7 +346,7 @@ class TestEcsRunTaskOperator(EcsBaseTestCase): self.ecs._wait_for_task_ended() client_mock.get_waiter.assert_called_once_with("tasks_stopped") client_mock.get_waiter.return_value.wait.assert_called_once_with( - cluster="c", tasks=["arn"], WaiterConfig={} + cluster="c", tasks=["arn"], WaiterConfig={"Delay": 6, "MaxAttempts": 100} ) assert sys.maxsize == client_mock.get_waiter.return_value.config.max_attempts @@ -654,6 +657,31 @@ class TestEcsRunTaskOperator(EcsBaseTestCase): self.ecs.do_xcom_push = False assert self.ecs.execute(None) is None + @mock.patch.object(EcsRunTaskOperator, "client") + def test_with_defer(self, client_mock): + self.ecs.deferrable = True + + client_mock.run_task.return_value = RESPONSE_WITHOUT_FAILURES + + with pytest.raises(TaskDeferred) as deferred: + self.ecs.execute(None) + + assert isinstance(deferred.value.trigger, TaskDoneTrigger) + assert deferred.value.trigger.task_arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}" + + @mock.patch.object(EcsRunTaskOperator, "client", new_callable=PropertyMock) + @mock.patch.object(EcsRunTaskOperator, "_xcom_del") + def test_execute_complete(self, xcom_del_mock: MagicMock, client_mock): + event = {"status": "success", "task_arn": "my_arn"} + self.ecs.reattach = True + + self.ecs.execute_complete(None, event) + + # task gets described to assert its success + client_mock().describe_tasks.assert_called_once_with(cluster="c", tasks=["my_arn"]) + # if reattach mode, xcom value is deleted on success + xcom_del_mock.assert_called_once() + class TestEcsCreateClusterOperator(EcsBaseTestCase): @pytest.mark.parametrize("waiter_delay, waiter_max_attempts", WAITERS_TEST_CASES) @@ -680,6 +708,26 @@ class TestEcsCreateClusterOperator(EcsBaseTestCase): mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, 
WaiterConfig=expected_waiter_config) assert result is not None + @mock.patch.object(EcsCreateClusterOperator, "client") + def test_execute_deferrable(self, mock_client: MagicMock): + op = EcsCreateClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + mock_client.create_cluster.return_value = { + "cluster": {"status": EcsClusterStates.PROVISIONING, "clusterArn": "my arn"} + } + + with pytest.raises(TaskDeferred) as defer: + op.execute(None) + + assert defer.value.trigger.cluster_arn == "my arn" + assert defer.value.trigger.waiter_delay == 12 + assert defer.value.trigger.attempts == 34 + def test_execute_immediate_create(self, patch_hook_waiters): """Test if cluster created during initial request.""" op = EcsCreateClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True) @@ -725,6 +773,26 @@ class TestEcsDeleteClusterOperator(EcsBaseTestCase): mocked_waiters.wait.assert_called_once_with(clusters=mock.ANY, WaiterConfig=expected_waiter_config) assert result is not None + @mock.patch.object(EcsDeleteClusterOperator, "client") + def test_execute_deferrable(self, mock_client: MagicMock): + op = EcsDeleteClusterOperator( + task_id="task", + cluster_name=CLUSTER_NAME, + deferrable=True, + waiter_delay=12, + waiter_max_attempts=34, + ) + mock_client.delete_cluster.return_value = { + "cluster": {"status": EcsClusterStates.DEPROVISIONING, "clusterArn": "my arn"} + } + + with pytest.raises(TaskDeferred) as defer: + op.execute(None) + + assert defer.value.trigger.cluster_arn == "my arn" + assert defer.value.trigger.waiter_delay == 12 + assert defer.value.trigger.attempts == 34 + def test_execute_immediate_delete(self, patch_hook_waiters): """Test if cluster deleted during initial request.""" op = EcsDeleteClusterOperator(task_id="task", cluster_name=CLUSTER_NAME, wait_for_completion=True) diff --git a/tests/providers/amazon/aws/triggers/test_ecs.py 
b/tests/providers/amazon/aws/triggers/test_ecs.py new file mode 100644 index 0000000000..09b5decbe6 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_ecs.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest +from botocore.exceptions import WaiterError + +from airflow import AirflowException +from airflow.providers.amazon.aws.hooks.ecs import EcsHook +from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook +from airflow.providers.amazon.aws.triggers.ecs import ClusterWaiterTrigger, TaskDoneTrigger +from airflow.triggers.base import TriggerEvent + + +class TestClusterWaiterTrigger: + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_max_attempts(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError("name", "reason", {"clusters": [{"status": "my_status"}]}) + a_mock.get_waiter().wait = wait_mock + + max_attempts = 5 + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, max_attempts, None, None) + + with pytest.raises(AirflowException): 
+ generator = trigger.run() + await generator.asend(None) + + assert wait_mock.call_count == max_attempts + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_success(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + a_mock.get_waiter().wait = wait_mock + + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) + + generator = trigger.run() + response: TriggerEvent = await generator.asend(None) + + assert response.payload["status"] == "success" + assert response.payload["arn"] == "cluster_arn" + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + async def test_run_error(self, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = WaiterError("terminal failure", "reason", {}) + a_mock.get_waiter().wait = wait_mock + + trigger = ClusterWaiterTrigger("my_waiter", "cluster_arn", 0, 5, None, None) + + with pytest.raises(AirflowException): + generator = trigger.run() + await generator.asend(None) + + +class TestTaskDoneTrigger: + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step + @mock.patch.object(AwsLogsHook, "async_conn") + async def test_run_until_error(self, _, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + wait_mock.side_effect = [ + WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]}), + WaiterError("name", "reason", {"tasks": [{"lastStatus": "my_status"}]}), + WaiterError("terminal failure", "reason", {}), + ] + a_mock.get_waiter().wait = wait_mock + + trigger = TaskDoneTrigger("cluster", "task_arn", 0, 10, None, None) + + with pytest.raises(WaiterError): + generator = trigger.run() + await generator.asend(None) + + assert 
wait_mock.call_count == 3 + + @pytest.mark.asyncio + @mock.patch.object(EcsHook, "async_conn") + # this mock is only necessary to avoid a "No module named 'aiobotocore'" error in the LatestBoto CI step + @mock.patch.object(AwsLogsHook, "async_conn") + async def test_run_success(self, _, client_mock): + a_mock = mock.MagicMock() + client_mock.__aenter__.return_value = a_mock + wait_mock = AsyncMock() + a_mock.get_waiter().wait = wait_mock + + trigger = TaskDoneTrigger("cluster", "my_task_arn", 0, 10, None, None) + + generator = trigger.run() + response: TriggerEvent = await generator.asend(None) + + assert response.payload["status"] == "success" + assert response.payload["task_arn"] == "my_task_arn" diff --git a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py index dbda751cfb..a5598ebf55 100644 --- a/tests/providers/amazon/aws/utils/test_task_log_fetcher.py +++ b/tests/providers/amazon/aws/utils/test_task_log_fetcher.py @@ -112,7 +112,7 @@ class TestAwsTaskLogFetcher: {"timestamp": 1617400367456, "message": "Second"}, {"timestamp": 1617400467789, "message": "Third"}, ] - assert [self.log_fetcher._event_to_str(event) for event in events] == ( + assert [self.log_fetcher.event_to_str(event) for event in events] == ( [ "[2021-04-02 21:51:07,123] First", "[2021-04-02 21:52:47,456] Second",