This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e0f21f43c6 Various fixes on ECS run task operator (#31838)
e0f21f43c6 is described below

commit e0f21f43c63b13fd48f55aa660746edc37df1458
Author: Raphaël Vandon <vand...@amazon.com>
AuthorDate: Fri Jun 16 12:22:24 2023 -0700

    Various fixes on ECS run task operator (#31838)
    
    * ECS Run Task op should not try to get logs or check the status if not waiting for completion
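
    For context, the misuse this commit guards against looks roughly like
    the sketch below: an EcsRunTaskOperator configured to fetch CloudWatch
    logs (awslogs_group set) while wait_for_completion=False, so the log
    fetcher would race a task nobody waits on. All names and values are
    placeholders, not taken from this commit. After this change, __init__
    warns once and _start_wait_check_task() returns right after starting
    the task, skipping both the log fetcher and the success check.

        from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator

        run_task = EcsRunTaskOperator(
            task_id="run_task_fire_and_forget",  # placeholder task id
            cluster="my-cluster",                # placeholder cluster
            task_definition="my-task-def",       # placeholder task definition
            launch_type="FARGATE",
            overrides={},                        # no container overrides
            awslogs_group="/ecs/my-group",       # placeholder log group
            awslogs_stream_prefix="ecs",
            # Logs are configured but nothing waits for the task, so the
            # operator now emits the "undefined behavior" warning at init.
            wait_for_completion=False,
        )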
---
 airflow/providers/amazon/aws/hooks/ecs.py          |  9 +++-
 airflow/providers/amazon/aws/operators/ecs.py      | 28 ++++++++-----
 .../operators/ecs.rst                              |  2 +-
 tests/providers/amazon/aws/operators/test_ecs.py   |  7 ++--
 tests/system/providers/amazon/aws/example_ecs.py   | 48 ++++++++++------------
 .../providers/amazon/aws/example_ecs_fargate.py    | 22 ++++++++++
 6 files changed, 75 insertions(+), 41 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/ecs.py b/airflow/providers/amazon/aws/hooks/ecs.py
index 17119d5968..94baeb0a9a 100644
--- a/airflow/providers/amazon/aws/hooks/ecs.py
+++ b/airflow/providers/amazon/aws/hooks/ecs.py
@@ -188,7 +188,14 @@ class EcsTaskLogFetcher(Thread):
         except ClientError as error:
             if error.response["Error"]["Code"] != "ResourceNotFoundException":
                 self.logger.warning("Error on retrieving Cloudwatch log 
events", error)
-
+            else:
+                self.logger.info(
+                    "Cannot find log stream yet, it can take a couple of 
seconds to show up. "
+                    "If this error persists, check that the log group and 
stream are correct: "
+                    "group: %s\tstream: %s",
+                    self.log_group,
+                    self.log_stream_name,
+                )
             yield from ()
         except ConnectionClosedError as error:
             self.logger.warning("ConnectionClosedError on retrieving 
Cloudwatch log events", error)
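
The hook change above turns a previously silent ResourceNotFoundException
into an informative message: the log stream of a freshly started task can
take a few seconds to appear, so the fetcher should keep polling rather
than treat this as an error. A standalone sketch of that polling pattern
against the CloudWatch Logs API (group and stream names are made up, not
taken from this commit):

    import time

    import boto3
    from botocore.exceptions import ClientError

    client = boto3.client("logs")
    for _ in range(10):
        try:
            events = client.get_log_events(
                logGroupName="/ecs/my-group",             # placeholder
                logStreamName="ecs/my-container/abc123",  # placeholder
            )["events"]
            break  # stream exists; events may still be empty at first
        except ClientError as error:
            if error.response["Error"]["Code"] != "ResourceNotFoundException":
                raise
            # Stream not created yet -- expected right after task start.
            time.sleep(2)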
diff --git a/airflow/providers/amazon/aws/operators/ecs.py 
b/airflow/providers/amazon/aws/operators/ecs.py
index 149ae2c11c..3a62f389da 100644
--- a/airflow/providers/amazon/aws/operators/ecs.py
+++ b/airflow/providers/amazon/aws/operators/ecs.py
@@ -480,6 +480,17 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.waiter_delay = waiter_delay
         self.waiter_max_attempts = waiter_max_attempts
 
+        if self._aws_logs_enabled() and not self.wait_for_completion:
+            self.log.warning(
+                "Trying to get logs without waiting for the task to complete 
is undefined behavior."
+            )
+
+    @staticmethod
+    def _get_ecs_task_id(task_arn: str | None) -> str | None:
+        if task_arn is None:
+            return None
+        return task_arn.split("/")[-1]
+
     @provide_session
     def execute(self, context, session=None):
         self.log.info(
@@ -506,25 +517,24 @@ class EcsRunTaskOperator(EcsBaseOperator):
 
     @AwsBaseHook.retry(should_retry_eni)
     def _start_wait_check_task(self, context):
-
         if not self.arn:
             self._start_task(context)
 
+        if not self.wait_for_completion:
+            return
+
         if self._aws_logs_enabled():
             self.log.info("Starting ECS Task Log Fetcher")
             self.task_log_fetcher = self._get_task_log_fetcher()
             self.task_log_fetcher.start()
 
             try:
-                if self.wait_for_completion:
-                    self._wait_for_task_ended()
+                self._wait_for_task_ended()
             finally:
                 self.task_log_fetcher.stop()
-
             self.task_log_fetcher.join()
         else:
-            if self.wait_for_completion:
-                self._wait_for_task_ended()
+            self._wait_for_task_ended()
 
         self._check_success_task()
 
@@ -566,8 +576,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
         self.log.info("ECS Task started: %s", response)
 
         self.arn = response["tasks"][0]["taskArn"]
-        self.ecs_task_id = self.arn.split("/")[-1]
-        self.log.info("ECS task ID is: %s", self.ecs_task_id)
+        self.log.info("ECS task ID is: %s", self._get_ecs_task_id(self.arn))
 
         if self.reattach:
             # Save the task ARN in XCom to be able to reattach it if needed
@@ -590,7 +599,6 @@ class EcsRunTaskOperator(EcsBaseOperator):
         )
         if previous_task_arn in running_tasks:
             self.arn = previous_task_arn
-            self.ecs_task_id = self.arn.split("/")[-1]
             self.log.info("Reattaching previously launched task: %s", self.arn)
         else:
             self.log.info("No active previously launched task found to 
reattach")
@@ -620,7 +628,7 @@ class EcsRunTaskOperator(EcsBaseOperator):
     def _get_task_log_fetcher(self) -> EcsTaskLogFetcher:
         if not self.awslogs_group:
             raise ValueError("must specify awslogs_group to fetch task logs")
-        log_stream_name = f"{self.awslogs_stream_prefix}/{self.ecs_task_id}"
+        log_stream_name = f"{self.awslogs_stream_prefix}/{self._get_ecs_task_id(self.arn)}"
 
         return EcsTaskLogFetcher(
             aws_conn_id=self.aws_conn_id,
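
With the ecs_task_id attribute removed, the task id is now derived on
demand from the ARN by the static helper _get_ecs_task_id, which returns
None for a None ARN and otherwise takes the last '/'-separated segment.
A quick illustration with a made-up ARN:

    # New-style ECS task ARNs end in <cluster>/<task-id>; older ones end in
    # just <task-id>. Taking the last segment covers both forms.
    arn = "arn:aws:ecs:us-east-1:012345678910:task/my-cluster/abc123def456"
    assert arn.split("/")[-1] == "abc123def456"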
diff --git a/docs/apache-airflow-providers-amazon/operators/ecs.rst b/docs/apache-airflow-providers-amazon/operators/ecs.rst
index d513485a9a..e6b4385d36 100644
--- a/docs/apache-airflow-providers-amazon/operators/ecs.rst
+++ b/docs/apache-airflow-providers-amazon/operators/ecs.rst
@@ -250,7 +250,7 @@ both can be overridden with provided values.  Raises an AirflowException with
 the failure reason if a failed state is provided and that state is reached
 before the target state.
 
-.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs.py
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_ecs_fargate.py
     :language: python
     :dedent: 4
     :start-after: [START howto_sensor_ecs_task_state]
diff --git a/tests/providers/amazon/aws/operators/test_ecs.py b/tests/providers/amazon/aws/operators/test_ecs.py
index cadaa6e329..ca23931e90 100644
--- a/tests/providers/amazon/aws/operators/test_ecs.py
+++ b/tests/providers/amazon/aws/operators/test_ecs.py
@@ -304,7 +304,10 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         wait_mock.assert_called_once_with()
         check_mock.assert_called_once_with()
         assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
-        assert self.ecs.ecs_task_id == TASK_ID
+
+    def test_task_id_parsing(self):
+        id = EcsRunTaskOperator._get_ecs_task_id(f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}")
+        assert id == TASK_ID
 
     @mock.patch.object(EcsBaseOperator, "client")
     def test_execute_with_failures(self, client_mock):
@@ -571,7 +574,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         check_mock.assert_called_once_with()
         xcom_del_mock.assert_called_once()
         assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
-        assert self.ecs.ecs_task_id == TASK_ID
 
     @pytest.mark.parametrize(
         "launch_type, tags",
@@ -620,7 +622,6 @@ class TestEcsRunTaskOperator(EcsBaseTestCase):
         check_mock.assert_called_once_with()
         xcom_del_mock.assert_called_once()
         assert self.ecs.arn == f"arn:aws:ecs:us-east-1:012345678910:task/{TASK_ID}"
-        assert self.ecs.ecs_task_id == TASK_ID
 
     @mock.patch.object(EcsBaseOperator, "client")
     @mock.patch("airflow.providers.amazon.aws.hooks.ecs.EcsTaskLogFetcher")
diff --git a/tests/system/providers/amazon/aws/example_ecs.py b/tests/system/providers/amazon/aws/example_ecs.py
index 194b070b51..be90f8c96f 100644
--- a/tests/system/providers/amazon/aws/example_ecs.py
+++ b/tests/system/providers/amazon/aws/example_ecs.py
@@ -23,7 +23,7 @@ import boto3
 from airflow import DAG
 from airflow.decorators import task
 from airflow.models.baseoperator import chain
-from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsTaskStates
+from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates
 from airflow.providers.amazon.aws.operators.ecs import (
     EcsCreateClusterOperator,
     EcsDeleteClusterOperator,
@@ -34,7 +34,6 @@ from airflow.providers.amazon.aws.operators.ecs import (
 from airflow.providers.amazon.aws.sensors.ecs import (
     EcsClusterStateSensor,
     EcsTaskDefinitionStateSensor,
-    EcsTaskStateSensor,
 )
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
@@ -67,6 +66,12 @@ def get_region():
     return boto3.session.Session().region_name
 
 
+@task(trigger_rule=TriggerRule.ALL_DONE)
+def clean_logs(group_name: str):
+    client = boto3.client("logs")
+    client.delete_log_group(logGroupName=group_name)
+
+
 with DAG(
     dag_id=DAG_ID,
     schedule="@once",
@@ -85,6 +90,7 @@ with DAG(
     asg_name = f"{env_id}-asg"
 
     aws_region = get_region()
+    log_group_name = f"/ecs_test/{env_id}"
 
     # [START howto_operator_ecs_create_cluster]
     create_cluster = EcsCreateClusterOperator(
@@ -114,7 +120,16 @@ with DAG(
                 "workingDirectory": "/usr/bin",
                 "entryPoint": ["sh", "-c"],
                 "command": ["ls"],
-            }
+                "logConfiguration": {
+                    "logDriver": "awslogs",
+                    "options": {
+                        "awslogs-group": log_group_name,
+                        "awslogs-region": aws_region,
+                        "awslogs-create-group": "true",
+                        "awslogs-stream-prefix": "ecs",
+                    },
+                },
+            },
         ],
         register_task_kwargs={
             "cpu": "256",
@@ -140,38 +155,19 @@ with DAG(
             "containerOverrides": [
                 {
                     "name": container_name,
-                    "command": ["echo", "hello", "world"],
+                    "command": ["echo hello world"],
                 },
             ],
         },
         network_configuration={"awsvpcConfiguration": {"subnets": existing_cluster_subnets}},
         # [START howto_awslogs_ecs]
-        awslogs_group="/ecs/hello-world",
+        awslogs_group=log_group_name,
         awslogs_region=aws_region,
-        awslogs_stream_prefix="ecs/hello-world-container",
+        awslogs_stream_prefix=f"ecs/{container_name}",
         # [END howto_awslogs_ecs]
-        # You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor.
-        reattach=True,
     )
     # [END howto_operator_ecs_run_task]
 
-    # EcsRunTaskOperator waits by default, setting as False to test the Sensor below.
-    run_task.wait_for_completion = False
-
-    # [START howto_sensor_ecs_task_state]
-    # By default, EcsTaskStateSensor waits until the task has started, but the
-    # demo task runs so fast that the sensor misses it.  This sensor instead
-    # demonstrates how to wait until the ECS Task has completed by providing
-    # the target_state and failure_states parameters.
-    await_task_finish = EcsTaskStateSensor(
-        task_id="await_task_finish",
-        cluster=existing_cluster_name,
-        task=run_task.output["ecs_task_arn"],
-        target_state=EcsTaskStates.STOPPED,
-        failure_states={EcsTaskStates.NONE},
-    )
-    # [END howto_sensor_ecs_task_state]
-
     # [START howto_operator_ecs_deregister_task_definition]
     deregister_task = EcsDeregisterTaskDefinitionOperator(
         task_id="deregister_task",
@@ -209,10 +205,10 @@ with DAG(
         register_task,
         await_task_definition,
         run_task,
-        await_task_finish,
         deregister_task,
         delete_cluster,
         await_delete_cluster,
+        clean_logs(log_group_name),
     )
 
     from tests.system.utils.watcher import watcher
diff --git a/tests/system/providers/amazon/aws/example_ecs_fargate.py b/tests/system/providers/amazon/aws/example_ecs_fargate.py
index 704bd91cdf..b23f85a956 100644
--- a/tests/system/providers/amazon/aws/example_ecs_fargate.py
+++ b/tests/system/providers/amazon/aws/example_ecs_fargate.py
@@ -23,7 +23,9 @@ import boto3
 from airflow import DAG
 from airflow.decorators import task
 from airflow.models.baseoperator import chain
+from airflow.providers.amazon.aws.hooks.ecs import EcsTaskStates
 from airflow.providers.amazon.aws.operators.ecs import EcsRunTaskOperator
+from airflow.providers.amazon.aws.sensors.ecs import EcsTaskStateSensor
 from airflow.utils.trigger_rule import TriggerRule
 from tests.system.providers.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder
 
@@ -120,9 +122,28 @@ with DAG(
                 "assignPublicIp": "ENABLED",
             },
         },
+        # You must set `reattach=True` in order to get ecs_task_arn if you plan to use a Sensor.
+        reattach=True,
     )
     # [END howto_operator_ecs]
 
+    # EcsRunTaskOperator waits by default, setting as False to test the Sensor below.
+    hello_world.wait_for_completion = False
+
+    # [START howto_sensor_ecs_task_state]
+    # By default, EcsTaskStateSensor waits until the task has started, but the
+    # demo task runs so fast that the sensor misses it.  This sensor instead
+    # demonstrates how to wait until the ECS Task has completed by providing
+    # the target_state and failure_states parameters.
+    await_task_finish = EcsTaskStateSensor(
+        task_id="await_task_finish",
+        cluster=cluster_name,
+        task=hello_world.output["ecs_task_arn"],
+        target_state=EcsTaskStates.STOPPED,
+        failure_states={EcsTaskStates.NONE},
+    )
+    # [END howto_sensor_ecs_task_state]
+
     chain(
         # TEST SETUP
         test_context,
@@ -130,6 +151,7 @@ with DAG(
         create_task_definition,
         # TEST BODY
         hello_world,
+        await_task_finish,
         # TEST TEARDOWN
         delete_task_definition(create_task_definition),
         delete_cluster(cluster_name),
