This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new f0467c9fd6 Use a AwsBaseWaiterTrigger-based trigger in EmrAddStepsOperator deferred mode (#34216)
f0467c9fd6 is described below

commit f0467c9fd65e7146b44fc8f9fccb9ad750592371
Author: Pavel Yermalovich <[email protected]>
AuthorDate: Mon Sep 11 13:00:04 2023 +0200

    Use a AwsBaseWaiterTrigger-based trigger in EmrAddStepsOperator deferred mode (#34216)
---
 airflow/providers/amazon/aws/operators/emr.py      |  12 +-
 airflow/providers/amazon/aws/triggers/emr.py       |  91 ++++---------
 airflow/providers/amazon/aws/waiters/emr.json      |  31 +++++
 tests/providers/amazon/aws/hooks/test_emr.py       |   1 +
 tests/providers/amazon/aws/triggers/test_emr.py    |   8 ++
 .../amazon/aws/triggers/test_emr_trigger.py        | 144 ---------------------
 .../amazon/aws/waiters/test_custom_waiters.py      |  73 +++++++++++
 7 files changed, 147 insertions(+), 213 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py 
b/airflow/providers/amazon/aws/operators/emr.py
index 1bf2375a16..77e0167c21 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -100,8 +100,8 @@ class EmrAddStepsOperator(BaseOperator):
         aws_conn_id: str = "aws_default",
         steps: list[dict] | str | None = None,
         wait_for_completion: bool = False,
-        waiter_delay: int | None = 30,
-        waiter_max_attempts: int | None = 60,
+        waiter_delay: int = 30,
+        waiter_max_attempts: int = 60,
         execution_role_arn: str | None = None,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
@@ -172,8 +172,8 @@ class EmrAddStepsOperator(BaseOperator):
                     job_flow_id=job_flow_id,
                     step_ids=step_ids,
                     aws_conn_id=self.aws_conn_id,
-                    max_attempts=self.waiter_max_attempts,
-                    poll_interval=self.waiter_delay,
+                    waiter_max_attempts=self.waiter_max_attempts,
+                    waiter_delay=self.waiter_delay,
                 ),
                 method_name="execute_complete",
             )
@@ -182,10 +182,10 @@ class EmrAddStepsOperator(BaseOperator):
 
     def execute_complete(self, context, event=None):
         if event["status"] != "success":
-            raise AirflowException(f"Error resuming cluster: {event}")
+            raise AirflowException(f"Error while running steps: {event}")
         else:
             self.log.info("Steps completed successfully")
-        return event["step_ids"]
+        return event["value"]
 
 
 class EmrStartNotebookExecutionOperator(BaseOperator):
diff --git a/airflow/providers/amazon/aws/triggers/emr.py 
b/airflow/providers/amazon/aws/triggers/emr.py
index a928255fe0..32f9049155 100644
--- a/airflow/providers/amazon/aws/triggers/emr.py
+++ b/airflow/providers/amazon/aws/triggers/emr.py
@@ -16,90 +16,55 @@
 # under the License.
 from __future__ import annotations
 
-import asyncio
 import warnings
-from typing import TYPE_CHECKING, Any
-
-from botocore.exceptions import WaiterError
+from typing import TYPE_CHECKING
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, 
EmrServerlessHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
-from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
 
 
-class EmrAddStepsTrigger(BaseTrigger):
+class EmrAddStepsTrigger(AwsBaseWaiterTrigger):
     """
-    Asynchronously poll the boto3 API and wait for the steps to finish 
executing.
+    Poll for the status of EMR steps until they reach terminal state.
+
+    :param job_flow_id: job_flow_id which contains the steps to check the 
state of
+    :param step_ids: steps to check the state of
+    :param waiter_delay: polling period in seconds to check for the status
+    :param waiter_max_attempts: The maximum number of attempts to be made
+    :param aws_conn_id: Reference to AWS connection id
 
-    :param job_flow_id: The id of the job flow.
-    :param step_ids: The id of the steps being waited upon.
-    :param poll_interval: The amount of time in seconds to wait between 
attempts.
-    :param max_attempts: The maximum number of attempts to be made.
-    :param aws_conn_id: The Airflow connection used for AWS credentials.
     """
 
     def __init__(
         self,
         job_flow_id: str,
         step_ids: list[str],
-        aws_conn_id: str,
-        max_attempts: int | None,
-        poll_interval: int | None,
+        waiter_delay: int,
+        waiter_max_attempts: int,
+        aws_conn_id: str = "aws_default",
     ):
-        self.job_flow_id = job_flow_id
-        self.step_ids = step_ids
-        self.aws_conn_id = aws_conn_id
-        self.max_attempts = max_attempts
-        self.poll_interval = poll_interval
-
-    def serialize(self) -> tuple[str, dict[str, Any]]:
-        return (
-            "airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger",
-            {
-                "job_flow_id": str(self.job_flow_id),
-                "step_ids": self.step_ids,
-                "poll_interval": str(self.poll_interval),
-                "max_attempts": str(self.max_attempts),
-                "aws_conn_id": str(self.aws_conn_id),
-            },
+        super().__init__(
+            serialized_fields={"job_flow_id": job_flow_id, "step_ids": 
step_ids},
+            waiter_name="steps_wait_for_terminal",
+            waiter_args={"ClusterId": job_flow_id, "StepIds": step_ids},
+            failure_message=f"Error while waiting for steps {step_ids} to 
complete",
+            status_message=f"Step ids: {step_ids}, Steps are still in 
non-terminal state",
+            status_queries=[
+                "Steps[].Status.State",
+                "Steps[].Status.FailureDetails",
+            ],
+            return_value=step_ids,
+            waiter_delay=waiter_delay,
+            waiter_max_attempts=waiter_max_attempts,
+            aws_conn_id=aws_conn_id,
         )
 
-    async def run(self):
-        self.hook = EmrHook(aws_conn_id=self.aws_conn_id)
-        async with self.hook.async_conn as client:
-            for step_id in self.step_ids:
-                waiter = client.get_waiter("step_complete")
-                for attempt in range(1, 1 + self.max_attempts):
-                    try:
-                        await waiter.wait(
-                            ClusterId=self.job_flow_id,
-                            StepId=step_id,
-                            WaiterConfig={
-                                "Delay": int(self.poll_interval),
-                                "MaxAttempts": 1,
-                            },
-                        )
-                        break
-                    except WaiterError as error:
-                        if "terminal failure" in str(error):
-                            yield TriggerEvent(
-                                {"status": "failure", "message": f"Step 
{step_id} failed: {error}"}
-                            )
-                            break
-                        self.log.info(
-                            "Status of step is %s - %s",
-                            error.last_response["Step"]["Status"]["State"],
-                            
error.last_response["Step"]["Status"]["StateChangeReason"],
-                        )
-                        await asyncio.sleep(int(self.poll_interval))
-        if attempt >= int(self.max_attempts):
-            yield TriggerEvent({"status": "failure", "message": "Steps failed: 
max attempts reached"})
-        else:
-            yield TriggerEvent({"status": "success", "message": "Steps 
completed", "step_ids": self.step_ids})
+    def hook(self) -> AwsGenericHook:
+        return EmrHook(aws_conn_id=self.aws_conn_id)
 
 
 class EmrCreateJobFlowTrigger(AwsBaseWaiterTrigger):
diff --git a/airflow/providers/amazon/aws/waiters/emr.json 
b/airflow/providers/amazon/aws/waiters/emr.json
index 33a90c7751..91c902eed6 100644
--- a/airflow/providers/amazon/aws/waiters/emr.json
+++ b/airflow/providers/amazon/aws/waiters/emr.json
@@ -125,6 +125,37 @@
                     "state": "failure"
                 }
             ]
+        },
+        "steps_wait_for_terminal": {
+            "operation": "ListSteps",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "Steps[].Status.State",
+                    "expected": "COMPLETED",
+                    "state": "success"
+                },
+                {
+                    "matcher": "pathAny",
+                    "argument": "Steps[].Status.State",
+                    "expected": "CANCELLED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAny",
+                    "argument": "Steps[].Status.State",
+                    "expected": "FAILED",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAny",
+                    "argument": "Steps[].Status.State",
+                    "expected": "INTERRUPTED",
+                    "state": "failure"
+                }
+            ]
         }
     }
 }
diff --git a/tests/providers/amazon/aws/hooks/test_emr.py 
b/tests/providers/amazon/aws/hooks/test_emr.py
index c68b25fbdf..b9864e84db 100644
--- a/tests/providers/amazon/aws/hooks/test_emr.py
+++ b/tests/providers/amazon/aws/hooks/test_emr.py
@@ -39,6 +39,7 @@ class TestEmrHook:
             "notebook_running",
             "notebook_stopped",
             "step_wait_for_terminal",
+            "steps_wait_for_terminal",
         ]
 
         assert sorted(hook.list_waiters()) == sorted([*official_waiters, 
*custom_waiters])
diff --git a/tests/providers/amazon/aws/triggers/test_emr.py 
b/tests/providers/amazon/aws/triggers/test_emr.py
index f83ad8ccd0..5a1369e89d 100644
--- a/tests/providers/amazon/aws/triggers/test_emr.py
+++ b/tests/providers/amazon/aws/triggers/test_emr.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import pytest
 
 from airflow.providers.amazon.aws.triggers.emr import (
+    EmrAddStepsTrigger,
     EmrContainerTrigger,
     EmrCreateJobFlowTrigger,
     EmrStepSensorTrigger,
@@ -40,6 +41,13 @@ class TestEmrTriggers:
     @pytest.mark.parametrize(
         "trigger",
         [
+            EmrAddStepsTrigger(
+                job_flow_id=TEST_JOB_FLOW_ID,
+                step_ids=["my_step1", "my_step2"],
+                aws_conn_id=TEST_AWS_CONN_ID,
+                waiter_delay=TEST_POLL_INTERVAL,
+                waiter_max_attempts=TEST_MAX_ATTEMPTS,
+            ),
             EmrCreateJobFlowTrigger(
                 job_flow_id=TEST_JOB_FLOW_ID,
                 aws_conn_id=TEST_AWS_CONN_ID,
diff --git a/tests/providers/amazon/aws/triggers/test_emr_trigger.py 
b/tests/providers/amazon/aws/triggers/test_emr_trigger.py
index fe28f86edb..187a948fb5 100644
--- a/tests/providers/amazon/aws/triggers/test_emr_trigger.py
+++ b/tests/providers/amazon/aws/triggers/test_emr_trigger.py
@@ -16,21 +16,14 @@
 # under the License.
 from __future__ import annotations
 
-from unittest import mock
-from unittest.mock import AsyncMock
-
 import pytest
-from botocore.exceptions import WaiterError
 
-from airflow.providers.amazon.aws.hooks.emr import EmrHook
 from airflow.providers.amazon.aws.triggers.emr import (
-    EmrAddStepsTrigger,
     EmrContainerTrigger,
     EmrCreateJobFlowTrigger,
     EmrStepSensorTrigger,
     EmrTerminateJobFlowTrigger,
 )
-from airflow.triggers.base import TriggerEvent
 
 TEST_JOB_FLOW_ID = "test_job_flow_id"
 TEST_STEP_IDS = ["step1", "step2"]
@@ -39,143 +32,6 @@ TEST_MAX_ATTEMPTS = 10
 TEST_POLL_INTERVAL = 10
 
 
-class TestEmrAddStepsTrigger:
-    def test_emr_add_steps_trigger_serialize(self):
-        emr_add_steps_trigger = EmrAddStepsTrigger(
-            job_flow_id=TEST_JOB_FLOW_ID,
-            step_ids=TEST_STEP_IDS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            max_attempts=TEST_MAX_ATTEMPTS,
-            poll_interval=TEST_POLL_INTERVAL,
-        )
-        class_path, args = emr_add_steps_trigger.serialize()
-        assert class_path == 
"airflow.providers.amazon.aws.triggers.emr.EmrAddStepsTrigger"
-        assert args["job_flow_id"] == TEST_JOB_FLOW_ID
-        assert args["step_ids"] == TEST_STEP_IDS
-        assert args["poll_interval"] == str(TEST_POLL_INTERVAL)
-        assert args["max_attempts"] == str(TEST_MAX_ATTEMPTS)
-        assert args["aws_conn_id"] == TEST_AWS_CONN_ID
-
-    @pytest.mark.asyncio
-    @mock.patch.object(EmrHook, "async_conn")
-    async def test_emr_add_steps_trigger_run(self, mock_async_conn):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        a_mock.get_waiter().wait = AsyncMock()
-
-        emr_add_steps_trigger = EmrAddStepsTrigger(
-            job_flow_id=TEST_JOB_FLOW_ID,
-            step_ids=TEST_STEP_IDS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            max_attempts=TEST_MAX_ATTEMPTS,
-            poll_interval=TEST_POLL_INTERVAL,
-        )
-
-        generator = emr_add_steps_trigger.run()
-        response = await generator.asend(None)
-
-        assert response == TriggerEvent(
-            {"status": "success", "message": "Steps completed", "step_ids": 
TEST_STEP_IDS}
-        )
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(EmrHook, "async_conn")
-    async def test_emr_add_steps_trigger_run_multiple_attempts(self, 
mock_async_conn, mock_sleep):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        error = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"Step": {"Status": {"State": "Running", 
"StateChangeReason": "test_reason"}}},
-        )
-        a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True, error, error, True])
-        mock_sleep.return_value = True
-
-        emr_add_steps_trigger = EmrAddStepsTrigger(
-            job_flow_id=TEST_JOB_FLOW_ID,
-            step_ids=TEST_STEP_IDS,
-            aws_conn_id=TEST_AWS_CONN_ID,
-            max_attempts=TEST_MAX_ATTEMPTS,
-            poll_interval=TEST_POLL_INTERVAL,
-        )
-
-        generator = emr_add_steps_trigger.run()
-        response = await generator.asend(None)
-
-        assert a_mock.get_waiter().wait.call_count == 6
-        assert response == TriggerEvent(
-            {"status": "success", "message": "Steps completed", "step_ids": 
TEST_STEP_IDS}
-        )
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(EmrHook, "async_conn")
-    async def test_emr_add_steps_trigger_run_attempts_exceeded(self, 
mock_async_conn, mock_sleep):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        error = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"Step": {"Status": {"State": "Running", 
"StateChangeReason": "test_reason"}}},
-        )
-        a_mock.get_waiter().wait.side_effect = AsyncMock(side_effect=[error, 
error, True])
-        mock_sleep.return_value = True
-
-        emr_add_steps_trigger = EmrAddStepsTrigger(
-            job_flow_id=TEST_JOB_FLOW_ID,
-            step_ids=[TEST_STEP_IDS[0]],
-            aws_conn_id=TEST_AWS_CONN_ID,
-            max_attempts=2,
-            poll_interval=TEST_POLL_INTERVAL,
-        )
-
-        generator = emr_add_steps_trigger.run()
-        response = await generator.asend(None)
-
-        assert a_mock.get_waiter().wait.call_count == 2
-        assert response == TriggerEvent(
-            {"status": "failure", "message": "Steps failed: max attempts 
reached"}
-        )
-
-    @pytest.mark.asyncio
-    @mock.patch("asyncio.sleep")
-    @mock.patch.object(EmrHook, "async_conn")
-    async def test_emr_add_steps_trigger_run_attempts_failed(self, 
mock_async_conn, mock_sleep):
-        a_mock = mock.MagicMock()
-        mock_async_conn.__aenter__.return_value = a_mock
-        error_running = WaiterError(
-            name="test_name",
-            reason="test_reason",
-            last_response={"Step": {"Status": {"State": "Running", 
"StateChangeReason": "test_reason"}}},
-        )
-        error_failed = WaiterError(
-            name="test_name",
-            reason="Waiter encountered a terminal failure state:",
-            last_response={"Step": {"Status": {"State": "FAILED", 
"StateChangeReason": "test_reason"}}},
-        )
-        a_mock.get_waiter().wait.side_effect = AsyncMock(
-            side_effect=[error_running, error_running, error_failed]
-        )
-        mock_sleep.return_value = True
-
-        emr_add_steps_trigger = EmrAddStepsTrigger(
-            job_flow_id=TEST_JOB_FLOW_ID,
-            step_ids=[TEST_STEP_IDS[0]],
-            aws_conn_id=TEST_AWS_CONN_ID,
-            max_attempts=TEST_MAX_ATTEMPTS,
-            poll_interval=TEST_POLL_INTERVAL,
-        )
-
-        generator = emr_add_steps_trigger.run()
-        response = await generator.asend(None)
-
-        assert a_mock.get_waiter().wait.call_count == 3
-        assert response == TriggerEvent(
-            {"status": "failure", "message": f"Step {TEST_STEP_IDS[0]} failed: 
{error_failed}"}
-        )
-
-
 class TestEmrTriggers:
     @pytest.mark.parametrize(
         "trigger",
diff --git a/tests/providers/amazon/aws/waiters/test_custom_waiters.py 
b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
index 229f6ce377..19f9296b6a 100644
--- a/tests/providers/amazon/aws/waiters/test_custom_waiters.py
+++ b/tests/providers/amazon/aws/waiters/test_custom_waiters.py
@@ -18,6 +18,7 @@
 from __future__ import annotations
 
 import json
+from typing import Sequence
 from unittest import mock
 
 import boto3
@@ -31,6 +32,7 @@ from airflow.providers.amazon.aws.hooks.batch_client import 
BatchClientHook
 from airflow.providers.amazon.aws.hooks.dynamodb import DynamoDBHook
 from airflow.providers.amazon.aws.hooks.ecs import EcsClusterStates, EcsHook, 
EcsTaskDefinitionStates
 from airflow.providers.amazon.aws.hooks.eks import EksHook
+from airflow.providers.amazon.aws.hooks.emr import EmrHook
 from airflow.providers.amazon.aws.waiters.base_waiter import BaseBotoWaiter
 
 
@@ -351,3 +353,74 @@ class TestCustomBatchServiceWaiters:
 
         with pytest.raises(WaiterError, match="Waiter encountered a terminal 
failure state"):
             waiter.wait(jobs=[self.JOB_ID], WaiterConfig={"Delay": 0.01, 
"MaxAttempts": 2})
+
+
+class TestCustomEmrServiceWaiters:
+    """Test waiters from ``amazon/aws/waiters/emr.json``."""
+
+    JOBFLOW_ID = "test_jobflow_id"
+    STEP_ID1 = "test_step_id_1"
+    STEP_ID2 = "test_step_id_2"
+
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, monkeypatch):
+        self.client = boto3.client("emr", region_name="eu-west-3")
+        monkeypatch.setattr(EmrHook, "conn", self.client)
+
+    @pytest.fixture
+    def mock_list_steps(self):
+        """Mock ``EmrHook.Client.list_steps`` method."""
+        with mock.patch.object(self.client, "list_steps") as m:
+            yield m
+
+    def test_service_waiters(self):
+        hook_waiters = EmrHook(aws_conn_id=None).list_waiters()
+        assert "steps_wait_for_terminal" in hook_waiters
+
+    @staticmethod
+    def list_steps(step_records: Sequence[tuple[str, str]]):
+        """
+        Helper function to generate minimal ListSteps response.
+        https://docs.aws.amazon.com/emr/latest/APIReference/API_ListSteps.html
+        """
+        return {
+            "Steps": [
+                {
+                    "Id": step_record[0],
+                    "Status": {
+                        "State": step_record[1],
+                    },
+                }
+                for step_record in step_records
+            ],
+        }
+
+    def test_steps_succeeded(self, mock_list_steps):
+        """Test steps succeeded"""
+        mock_list_steps.side_effect = [
+            self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, 
"RUNNING")]),
+            self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, 
"COMPLETED")]),
+            self.list_steps([(self.STEP_ID1, "COMPLETED"), (self.STEP_ID2, 
"COMPLETED")]),
+        ]
+        waiter = 
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+        waiter.wait(
+            ClusterId=self.JOBFLOW_ID,
+            StepIds=[self.STEP_ID1, self.STEP_ID2],
+            WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+        )
+
+    def test_steps_failed(self, mock_list_steps):
+        """Test steps failed"""
+        mock_list_steps.side_effect = [
+            self.list_steps([(self.STEP_ID1, "PENDING"), (self.STEP_ID2, 
"RUNNING")]),
+            self.list_steps([(self.STEP_ID1, "RUNNING"), (self.STEP_ID2, 
"COMPLETED")]),
+            self.list_steps([(self.STEP_ID1, "FAILED"), (self.STEP_ID2, 
"COMPLETED")]),
+        ]
+        waiter = 
EmrHook(aws_conn_id=None).get_waiter("steps_wait_for_terminal")
+
+        with pytest.raises(WaiterError, match="Waiter encountered a terminal 
failure state"):
+            waiter.wait(
+                ClusterId=self.JOBFLOW_ID,
+                StepIds=[self.STEP_ID1, self.STEP_ID2],
+                WaiterConfig={"Delay": 0.01, "MaxAttempts": 3},
+            )

Reply via email to