This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d5e05b28b Add deferrable mode to RedshiftClusterSensor (#36550)
4d5e05b28b is described below

commit 4d5e05b28b99d8a06c20f00c93efa90002f8d401
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Tue Jan 9 05:21:00 2024 +0800

    Add deferrable mode to RedshiftClusterSensor (#36550)
---
 .../amazon/aws/sensors/redshift_cluster.py         |  44 +++++++-
 .../amazon/aws/triggers/redshift_cluster.py        |  56 +++++++++-
 tests/always/test_project_structure.py             |   1 -
 .../amazon/aws/sensors/test_redshift_cluster.py    |  58 ++++++++++
 .../amazon/aws/triggers/test_redshift_cluster.py   | 119 +++++++++++++++++++++
 5 files changed, 272 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py 
b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
index 5b649bf78a..cd63bb5e1f 100644
--- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py
@@ -16,13 +16,16 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import timedelta
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
 
 from deprecated import deprecated
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning, AirflowSkipException
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
+from airflow.providers.amazon.aws.triggers.redshift_cluster import 
RedshiftClusterTrigger
 from airflow.sensors.base import BaseSensorOperator
 
 if TYPE_CHECKING:
@@ -39,6 +42,7 @@ class RedshiftClusterSensor(BaseSensorOperator):
 
     :param cluster_identifier: The identifier for the cluster being pinged.
     :param target_status: The cluster status desired.
+    :param deferrable: Poke once, then defer to RedshiftClusterTrigger if the target status is not reached.
     """
 
     template_fields: Sequence[str] = ("cluster_identifier", "target_status")
@@ -49,14 +53,16 @@ class RedshiftClusterSensor(BaseSensorOperator):
         cluster_identifier: str,
         target_status: str = "available",
         aws_conn_id: str = "aws_default",
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.target_status = target_status
         self.aws_conn_id = aws_conn_id
+        self.deferrable = deferrable
 
-    def poke(self, context: Context):
+    def poke(self, context: Context) -> bool:
         current_status = self.hook.cluster_status(self.cluster_identifier)
         self.log.info(
             "Poked cluster %s for status '%s', found status '%s'",
@@ -66,6 +72,38 @@ class RedshiftClusterSensor(BaseSensorOperator):
         )
         return current_status == self.target_status
 
+    def execute(self, context: Context) -> None:
+        if not self.deferrable:
+            super().execute(context=context)
+        elif not self.poke(context):  # poke once; defer only if the target status is not reached yet
+            self.defer(
+                timeout=timedelta(seconds=self.timeout),
+                trigger=RedshiftClusterTrigger(
+                    aws_conn_id=self.aws_conn_id,
+                    cluster_identifier=self.cluster_identifier,
+                    target_status=self.target_status,
+                    poke_interval=self.poke_interval,
+                ),
+                method_name="execute_complete",
+            )
+
+    def execute_complete(self, context: Context, event: dict[str, Any] | None 
= None) -> None:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.error(err_msg)
+            raise AirflowException(err_msg)
+
+        status = event["status"]
+        if status == "error":
+            # TODO: remove this if block when min_airflow_version is set to 
higher than 2.7.1
+            message = f"{event['status']}: {event['message']}"
+            if self.soft_fail:
+                raise AirflowSkipException(message)
+            raise AirflowException(message)
+        elif status == "success":
+            self.log.info("%s completed successfully.", self.task_id)
+            self.log.info("Cluster Identifier %s is in %s state", 
self.cluster_identifier, self.target_status)
+
     @deprecated(reason="use `hook` property instead.", 
category=AirflowProviderDeprecationWarning)
     def get_hook(self) -> RedshiftHook:
         """Create and return a RedshiftHook."""
diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py 
b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
index 64a636c8db..456d9df303 100644
--- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py
@@ -16,12 +16,14 @@
 # under the License.
 from __future__ import annotations
 
+import asyncio
 import warnings
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, AsyncIterator
 
 from airflow.exceptions import AirflowProviderDeprecationWarning
-from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
+from airflow.providers.amazon.aws.hooks.redshift_cluster import 
RedshiftAsyncHook, RedshiftHook
 from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+from airflow.triggers.base import BaseTrigger, TriggerEvent
 
 if TYPE_CHECKING:
     from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
@@ -262,3 +264,53 @@ class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger):
 
     def hook(self) -> AwsGenericHook:
         return RedshiftHook(aws_conn_id=self.aws_conn_id)
+
+
+class RedshiftClusterTrigger(BaseTrigger):
+    """
+    Poll a Redshift cluster until it reaches ``target_status`` or the status check reports an error.
+
+    :param aws_conn_id: Airflow connection id passed to ``RedshiftAsyncHook``
+    :param cluster_identifier: identifier of the Redshift cluster to poll
+    :param target_status: cluster state that ends polling with a success event
+    :param poke_interval: seconds to sleep between two status checks
+    """
+
+    def __init__(
+        self,
+        aws_conn_id: str,
+        cluster_identifier: str,
+        target_status: str,
+        poke_interval: float,
+    ):
+        super().__init__()
+        self.aws_conn_id = aws_conn_id
+        self.cluster_identifier = cluster_identifier
+        self.target_status = target_status
+        self.poke_interval = poke_interval
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes RedshiftClusterTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger",
+            {
+                "aws_conn_id": self.aws_conn_id,
+                "cluster_identifier": self.cluster_identifier,
+                "target_status": self.target_status,
+                "poke_interval": self.poke_interval,
+            },
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """Poll the cluster status, yield exactly one event, then stop."""
+        try:
+            hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id)
+            while True:
+                res = await hook.cluster_status(self.cluster_identifier)
+                status = res["status"]
+                if status == "error" or (status == "success" and res["cluster_state"] == self.target_status):
+                    yield TriggerEvent(res)
+                    return  # stop polling once the event has been emitted
+                await asyncio.sleep(self.poke_interval)
+        except Exception as e:
+            yield TriggerEvent({"status": "error", "message": str(e)})
diff --git a/tests/always/test_project_structure.py 
b/tests/always/test_project_structure.py
index 3b1b8a1e97..c88387919e 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -75,7 +75,6 @@ class TestProjectStructure:
             "tests/providers/amazon/aws/triggers/test_glue_crawler.py",
             "tests/providers/amazon/aws/triggers/test_lambda_function.py",
             "tests/providers/amazon/aws/triggers/test_rds.py",
-            "tests/providers/amazon/aws/triggers/test_redshift_cluster.py",
             "tests/providers/amazon/aws/triggers/test_step_function.py",
             "tests/providers/amazon/aws/utils/test_rds.py",
             "tests/providers/amazon/aws/utils/test_sagemaker.py",
diff --git a/tests/providers/amazon/aws/sensors/test_redshift_cluster.py 
b/tests/providers/amazon/aws/sensors/test_redshift_cluster.py
index 07c6ad67ee..04d4ca55fd 100644
--- a/tests/providers/amazon/aws/sensors/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/sensors/test_redshift_cluster.py
@@ -16,10 +16,27 @@
 # under the License.
 from __future__ import annotations
 
+from unittest import mock
+
 import boto3
+import pytest
 from moto import mock_redshift
 
+from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.sensors.redshift_cluster import 
RedshiftClusterSensor
+from airflow.providers.amazon.aws.triggers.redshift_cluster import 
RedshiftClusterTrigger
+
+MODULE = "airflow.providers.amazon.aws.sensors.redshift_cluster"
+
+
+@pytest.fixture
+def deferrable_op():
+    return RedshiftClusterSensor(
+        task_id="test_cluster_sensor",
+        cluster_identifier="test_cluster",
+        target_status="available",
+        deferrable=True,
+    )
 
 
 class TestRedshiftClusterSensor:
@@ -75,3 +92,44 @@ class TestRedshiftClusterSensor:
         )
 
         assert op.poke({})
+
+    @mock.patch(f"{MODULE}.RedshiftClusterSensor.defer")
+    @mock.patch(f"{MODULE}.RedshiftClusterSensor.poke", return_value=True)
+    def test_execute_finish_before_deferred(
+        self,
+        mock_poke,
+        mock_defer,
+        deferrable_op,
+    ):
+        """Assert the task is not deferred when the first poke already finds the target status."""
+
+        deferrable_op.execute({})
+        assert not mock_defer.called
+
+    @mock.patch(f"{MODULE}.RedshiftClusterSensor.poke", return_value=False)
+    def test_execute(self, mock_poke, deferrable_op):
+        """Test that RedshiftClusterSensor defers with a RedshiftClusterTrigger
+        when the first poke does not find the target status."""
+
+        with pytest.raises(TaskDeferred) as exc:
+            deferrable_op.execute(None)
+        assert isinstance(
+            exc.value.trigger, RedshiftClusterTrigger
+        ), "Trigger is not a RedshiftClusterTrigger"
+
+    def test_redshift_sensor_async_execute_failure(self, deferrable_op):
+        """Test that execute_complete raises AirflowException for an error event."""
+
+        with pytest.raises(AirflowException):
+            deferrable_op.execute_complete(
+                context=None, event={"status": "error", "message": "test 
failure message"}
+            )
+
+    def test_redshift_sensor_async_execute_complete(self, deferrable_op):
+        """Test that execute_complete logs the cluster state for a success event."""
+
+        with mock.patch.object(deferrable_op.log, "info") as mock_log_info:
+            deferrable_op.execute_complete(
+                context=None, event={"status": "success", "cluster_state": 
"available"}
+            )
+        mock_log_info.assert_called_with("Cluster Identifier %s is in %s 
state", "test_cluster", "available")
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py 
b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
new file mode 100644
index 0000000000..5d5cc2c424
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.redshift_cluster import (
+    RedshiftClusterTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+POLLING_PERIOD_SECONDS = 1.0
+
+
+class TestRedshiftClusterTrigger:
+    def test_redshift_cluster_sensor_trigger_serialization(self):
+        """
+        Asserts that the RedshiftClusterTrigger correctly serializes its 
arguments
+        and classpath.
+        """
+        trigger = RedshiftClusterTrigger(
+            aws_conn_id="test_redshift_conn_id",
+            cluster_identifier="mock_cluster_identifier",
+            target_status="available",
+            poke_interval=POLLING_PERIOD_SECONDS,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger"
+        assert kwargs == {
+            "aws_conn_id": "test_redshift_conn_id",
+            "cluster_identifier": "mock_cluster_identifier",
+            "target_status": "available",
+            "poke_interval": POLLING_PERIOD_SECONDS,
+        }
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status")
+    async def test_redshift_cluster_sensor_trigger_success(self, 
mock_cluster_status):
+        """
+        Test RedshiftClusterTrigger with the success status
+        """
+        expected_result = {"status": "success", "cluster_state": "available"}
+
+        mock_cluster_status.return_value = expected_result
+        trigger = RedshiftClusterTrigger(
+            aws_conn_id="test_redshift_conn_id",
+            cluster_identifier="mock_cluster_identifier",
+            target_status="available",
+            poke_interval=POLLING_PERIOD_SECONDS,
+        )
+
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert TriggerEvent(expected_result) == actual
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "expected_result",
+        [
+            ({"status": "success", "cluster_state": "Resuming"}),
+        ],
+    )
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status")
+    async def test_redshift_cluster_sensor_trigger_resuming_status(
+        self, mock_cluster_status, expected_result
+    ):
+        """Test RedshiftClusterTrigger keeps polling while the cluster is not in the target status."""
+        mock_cluster_status.return_value = expected_result
+        trigger = RedshiftClusterTrigger(
+            aws_conn_id="test_redshift_conn_id",
+            cluster_identifier="mock_cluster_identifier",
+            target_status="available",
+            poke_interval=POLLING_PERIOD_SECONDS,
+        )
+        task = asyncio.create_task(trigger.run().__anext__())
+        await asyncio.sleep(0.5)
+
+        # No TriggerEvent has been yielded yet: the trigger is still polling
+        assert task.done() is False
+
+        # Prevents error when task is destroyed while in "pending" state
+        asyncio.get_event_loop().stop()
+
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status")
+    async def test_redshift_cluster_sensor_trigger_exception(self, 
mock_cluster_status):
+        """Test RedshiftClusterTrigger with exception"""
+        mock_cluster_status.side_effect = Exception("Test exception")
+        trigger = RedshiftClusterTrigger(
+            aws_conn_id="test_redshift_conn_id",
+            cluster_identifier="mock_cluster_identifier",
+            target_status="available",
+            poke_interval=POLLING_PERIOD_SECONDS,
+        )
+
+        task = [i async for i in trigger.run()]
+        # the exception is caught inside run(), which yields a single error event
+        # and then finishes, so exactly one event is expected; validate the
+        # length of the collected events
+        assert len(task) == 1
+        assert TriggerEvent({"status": "error", "message": "Test exception"}) 
in task

Reply via email to