This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new e0f21f43c6 Various fixes on ECS run task operator (#31838) e0f21f43c6 is described below commit e0f21f43c63b13fd48f55aa660746edc37df1458 Author: Raphaël Vandon <vand...@amazon.com> AuthorDate: Fri Jun 16 12:22:24 2023 -0700 Various fixes on ECS run task operator (#31838) * ECS Run Task op should not try to get logs or check the status if not waiting for completion --- airflow/providers/amazon/aws/hooks/ecs.py | 9 +++- airflow/providers/amazon/aws/operators/ecs.py | 28 ++++++++----- .../operators/ecs.rst | 2 +- tests/providers/amazon/aws/operators/test_ecs.py | 7 ++-- tests/system/providers/amazon/aws/example_ecs.py | 48 ++++++++++------------ .../providers/amazon/aws/example_ecs_fargate.py | 22 ++++++++++ 6 files changed, 75 insertions(+), 41 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py index 17119d5968..94baeb0a9a 100644 --- a/airflow/providers/amazon/aws/hooks/ecs.py +++ b/airflow/providers/amazon/aws/hooks/ecs.py @@ -188,7 +188,14 @@ class EcsTaskLogFetcher(Thread): except ClientError as error: if error.response["Error"]["Code"] != "ResourceNotFoundException": self.logger.warning("Error on retrieving Cloudwatch log events", error) - + else: + self.logger.info( + "Cannot find log stream yet, it can take a couple of seconds to show up. 
" + "If this error persists, check that the log group and stream are correct: " + "group: %s\tstream: %s", + self.log_group, + self.log_stream_name, + ) yield from () except ConnectionClosedError as error: self.logger.warning("ConnectionClosedError on retrieving Cloudwatch log events", error) diff --git a/airflow/providers/amazon/aws/operators/ecs.py b/airflow/providers/amazon/aws/operators/ecs.py index 149ae2c11c..3a62f389da 100644 --- a/airflow/providers/amazon/aws/operators/ecs.py +++ b/airflow/providers/amazon/aws/operators/ecs.py @@ -480,6 +480,17 @@ class EcsRunTaskOperator(EcsBaseOperator): self.waiter_delay = waiter_delay self.waiter_max_attempts = waiter_max_attempts + if self._aws_logs_enabled() and not self.wait_for_completion: + self.log.warning( + "Trying to get logs without waiting for the task to complete is undefined behavior." + ) + + @staticmethod + def _get_ecs_task_id(task_arn: str | None) -> str | None: + if task_arn is None: + return None + return task_arn.split("/")[-1] + @provide_session def execute(self, context, session=None): self.log.info( @@ -506,25 +517,24 @@ class EcsRunTaskOperator(EcsBaseOperator): @AwsBaseHook.retry(should_retry_eni) def _start_wait_check_task(self, context): - if not self.arn: self._start_task(context) + if not self.wait_for_completion: + return + if self._aws_logs_enabled(): self.log.info("Starting ECS Task Log Fetcher") self.task_log_fetcher = self._get_task_log_fetcher() self.task_log_fetcher.start() try: - if self.wait_for_completion: - self._wait_for_task_ended() + self._wait_for_task_ended() finally: self.task_log_fetcher.stop() - self.task_log_fetcher.join() else: - if self.wait_for_completion: - self._wait_for_task_ended() + self._wait_for_task_ended() self._check_success_task() @@ -566,8 +576,7 @@ class EcsRunTaskOperator(EcsBaseOperator): self.log.info("ECS Task started: %s", response) self.arn = response["tasks"][0]["taskArn"] - self.ecs_task_id = self.arn.split("/")[-1] - self.log.info("ECS task ID is: 
%s", self.ecs_task_id) + self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn)) if self.reattach: # Save the task ARN in XCom to be able to reattach it if needed @@ -590,7 +599,6 @@ class EcsRunTaskOperator(EcsBaseOperator): ) if previous_task_arn in running_tasks: self.arn = previous_task_arn - self.ecs_task_id = self.arn.split("/")[-1] self.log.info("Reattaching previously launched task: %s", self.arn) else: self.log.info("No active previously launched task found to reattach") @@ -620,7 +628,7 @@ class EcsRunTaskOperator(EcsBaseOperator): def _get_task_log_fetcher(self) -> EcsTaskLogFetcher: if not self.awslogs_group: raise ValueError("must specify awslogs_group to fetch task logs") - log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}" + log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}" return EcsTaskLogFetcher( aws_conn_id=self.aws_conn_id, diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst b/docs/apache-airflow-providers-amazon/operators/ecs.rst index d513485a9a..e6b4385d36 100644 --- a/docs/apache-airflow-providers-amazon/operators/ecs.rst +++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst @@ -250,7 +250,7 @@ both can be overridden with provided values. Raises an AirflowException with the failure reason if a failed state is provided and that state is reached before the target state. -.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs.py +.. 
exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs_fargate.py :language: python :dedent: 4 :start-after: [START howto_sensor_ecs_task_state] diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py index cadaa6e329..ca23931e90 100644 --- a/tests/providers/amazon/aws/operators/test_ecs.py +++ b/tests/providers/amazon/aws/operators/test_ecs.py @@ -304,7 +304,10 @@ class TestEcsRunTaskOperator(EcsBaseTestCase): wait_mock.assert_called_once_with() check_mock.assert_called_once_with() assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}" - assert self.ecs.ecs_task_id == TASK_ID + + def test_task_id_parsing(self): + id = EcsRunTaskOperator._get_ecs_task_id(f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}") + assert id == TASK_ID @mock.patch.object(EcsBaseOperator, "client") def test_execute_with_failures(self, client_mock): @@ -571,7 +574,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase): check_mock.assert_called_once_with() xcom_del_mock.assert_called_once() assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}" - assert self.ecs.ecs_task_id == TASK_ID @pytest.mark.parametrize( "launch_type, tags", @@ -620,7 +622,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase): check_mock.assert_called_once_with() xcom_del_mock.assert_called_once() assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}" - assert self.ecs.ecs_task_id == TASK_ID @mock.patch.object(EcsBaseOperator, "client") @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher") diff --git a/tests/system/providers/amazon/aws/example_ecs.py b/tests/system/providers/amazon/aws/example_ecs.py index 194b070b51..be90f8c96f 100644 --- a/tests/system/providers/amazon/aws/example_ecs.py +++ b/tests/system/providers/amazon/aws/example_ecs.py @@ -23,7 +23,7 @@ import boto3 from airflow import DAG from airflow.decorators import task from airflow.models.baseoperator import 
chain -from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsTaskStates +from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates from airflow.providers.amazon.aws.operators.ecs import ( EcsCreateClusterOperator, EcsDeleteClusterOperator, @@ -34,7 +34,6 @@ from airflow.providers.amazon.aws.operators.ecs import ( from airflow.providers.amazon.aws.sensors.ecs import ( EcsClusterStateSensor, EcsTaskDefinitionStateSensor, - EcsTaskStateSensor, ) from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder @@ -67,6 +66,12 @@ def get_region(): return boto3.session.Session().region_name +@task(trigger_rule=TriggerRule.ALL_DONE) +def clean_logs(group_name: str): + client = boto3.client("logs") + client.delete_log_group(logGroupName=group_name) + + with DAG( dag_id=DAG_ID, schedule="@once", @@ -85,6 +90,7 @@ with DAG( asg_name = f"{env_id}-asg" aws_region = get_region() + log_group_name = f"/ecs_test/{env_id}" # [START howto_operator_ecs_create_cluster] create_cluster = EcsCreateClusterOperator( @@ -114,7 +120,16 @@ with DAG( "workingDirectory": "/usr/bin", "entryPoint": ["sh", "-c"], "command": ["ls"], - } + "logConfiguration": { + "logDriver": "awslogs", + "options": { + "awslogs-group": log_group_name, + "awslogs-region": aws_region, + "awslogs-create-group": "true", + "awslogs-stream-prefix": "ecs", + }, + }, + }, ], register_task_kwargs={ "cpu": "256", @@ -140,38 +155,19 @@ with DAG( "containerOverrides": [ { "name": container_name, - "command": ["echo", "hello", "world"], + "command": ["echo hello world"], }, ], }, network_configuration={"awsvpcConfiguration": {"subnets": existing_cluster_subnets}}, # [START howto_awslogs_ecs] - awslogs_group="/ecs/hello-world", + awslogs_group=log_group_name, awslogs_region=aws_region, - awslogs_stream_prefix="ecs/hello-world-container", + awslogs_stream_prefix=f"ecs/{container_name}", # [END howto_awslogs_ecs] - # You must set 
`reattach=True` in order to get ecs_task_arn if you plan to use a Sensor. - reattach=True, ) # [END howto_operator_ecs_run_task] - # EcsRunTaskOperator waits by default, setting as False to test the Sensor below. - run_task.wait_for_completion = False - - # [START howto_sensor_ecs_task_state] - # By default, EcsTaskStateSensor waits until the task has started, but the - # demo task runs so fast that the sensor misses it. This sensor instead - # demonstrates how to wait until the ECS Task has completed by providing - # the target_state and failure_states parameters. - await_task_finish = EcsTaskStateSensor( - task_id="await_task_finish", - cluster=existing_cluster_name, - task=run_task.output["ecs_task_arn"], - target_state=EcsTaskStates.STOPPED, - failure_states={EcsTaskStates.NONE}, - ) - # [END howto_sensor_ecs_task_state] - # [START howto_operator_ecs_deregister_task_definition] deregister_task = EcsDeregisterTaskDefinitionOperator( task_id="deregister_task", @@ -209,10 +205,10 @@ with DAG( register_task, await_task_definition, run_task, - await_task_finish, deregister_task, delete_cluster, await_delete_cluster, + clean_logs(log_group_name), ) from tests.system.utils.watcher import watcher diff --git a/tests/system/providers/amazon/aws/example_ecs_fargate.py b/tests/system/providers/amazon/aws/example_ecs_fargate.py index 704bd91cdf..b23f85a956 100644 --- a/tests/system/providers/amazon/aws/example_ecs_fargate.py +++ b/tests/system/providers/amazon/aws/example_ecs_fargate.py @@ -23,7 +23,9 @@ import boto3 from airflow import DAG from airflow.decorators import task from airflow.models.baseoperator import chain +from airflow.providers.amazon.aws.hooks.ecs import EcsTaskStates from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator +from airflow.providers.amazon.aws.sensors.ecs import EcsTaskStateSensor from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder @@ 
-120,9 +122,28 @@ with DAG( "assignPublicIp": "ENABLED", }, }, + # You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor. + reattach=True, ) # [END howto_operator_ecs] + # EcsRunTaskOperator waits by default, setting as False to test the Sensor below. + hello_world.wait_for_completion = False + + # [START howto_sensor_ecs_task_state] + # By default, EcsTaskStateSensor waits until the task has started, but the + # demo task runs so fast that the sensor misses it. This sensor instead + # demonstrates how to wait until the ECS Task has completed by providing + # the target_state and failure_states parameters. + await_task_finish = EcsTaskStateSensor( + task_id="await_task_finish", + cluster=cluster_name, + task=hello_world.output["ecs_task_arn"], + target_state=EcsTaskStates.STOPPED, + failure_states={EcsTaskStates.NONE}, + ) + # [END howto_sensor_ecs_task_state] + chain( # TEST SETUP test_context, @@ -130,6 +151,7 @@ with DAG( create_task_definition, # TEST BODY hello_world, + await_task_finish, # TEST TEARDOWN delete_task_definition(create_task_definition), delete_cluster(cluster_name),