This is an automated email from the ASF dual-hosted git repository. vincbeck pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 4d5e05b28b Add deferrable mode to RedshiftClusterSensor (#36550) 4d5e05b28b is described below commit 4d5e05b28b99d8a06c20f00c93efa90002f8d401 Author: Wei Lee <weilee...@gmail.com> AuthorDate: Tue Jan 9 05:21:00 2024 +0800 Add deferrable mode to RedshiftClusterSensor (#36550) --- .../amazon/aws/sensors/redshift_cluster.py | 44 +++++++- .../amazon/aws/triggers/redshift_cluster.py | 56 +++++++++- tests/always/test_project_structure.py | 1 - .../amazon/aws/sensors/test_redshift_cluster.py | 58 ++++++++++ .../amazon/aws/triggers/test_redshift_cluster.py | 119 +++++++++++++++++++++ 5 files changed, 272 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/redshift_cluster.py b/airflow/providers/amazon/aws/sensors/redshift_cluster.py index 5b649bf78a..cd63bb5e1f 100644 --- a/airflow/providers/amazon/aws/sensors/redshift_cluster.py +++ b/airflow/providers/amazon/aws/sensors/redshift_cluster.py @@ -16,13 +16,16 @@ # under the License. from __future__ import annotations +from datetime import timedelta from functools import cached_property -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence from deprecated import deprecated -from airflow.exceptions import AirflowProviderDeprecationWarning +from airflow.configuration import conf +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, AirflowSkipException from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger from airflow.sensors.base import BaseSensorOperator if TYPE_CHECKING: @@ -39,6 +42,7 @@ class RedshiftClusterSensor(BaseSensorOperator): :param cluster_identifier: The identifier for the cluster being pinged. :param target_status: The cluster status desired. + :param deferrable: Run operator in the deferrable mode. """ template_fields: Sequence[str] = ("cluster_identifier", "target_status") @@ -49,14 +53,16 @@ class RedshiftClusterSensor(BaseSensorOperator): cluster_identifier: str, target_status: str = "available", aws_conn_id: str = "aws_default", + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs, ): super().__init__(**kwargs) self.cluster_identifier = cluster_identifier self.target_status = target_status self.aws_conn_id = aws_conn_id + self.deferrable = deferrable - def poke(self, context: Context): + def poke(self, context: Context) -> bool: current_status = self.hook.cluster_status(self.cluster_identifier) self.log.info( "Poked cluster %s for status '%s', found status '%s'", @@ -66,6 +72,38 @@ class RedshiftClusterSensor(BaseSensorOperator): ) return current_status == self.target_status + def execute(self, context: Context) -> None: + if not self.deferrable: + super().execute(context=context) + elif not self.poke(context): + self.defer( + timeout=timedelta(seconds=self.timeout), + trigger=RedshiftClusterTrigger( + aws_conn_id=self.aws_conn_id, + cluster_identifier=self.cluster_identifier, + target_status=self.target_status, + poke_interval=self.poke_interval, + ), + method_name="execute_complete", + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> None: + if event is None: + err_msg = "Trigger error: event is None" + self.log.error(err_msg) + raise AirflowException(err_msg) + + status = event["status"] + if status == "error": + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = f"{event['status']}: {event['message']}" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) + elif status == "success": + self.log.info("%s completed successfully.", self.task_id) + self.log.info("Cluster Identifier %s is in %s state", self.cluster_identifier, self.target_status) + @deprecated(reason="use `hook` property instead.", category=AirflowProviderDeprecationWarning) def get_hook(self) -> RedshiftHook: """Create and return a RedshiftHook.""" diff --git a/airflow/providers/amazon/aws/triggers/redshift_cluster.py b/airflow/providers/amazon/aws/triggers/redshift_cluster.py index 64a636c8db..456d9df303 100644 --- a/airflow/providers/amazon/aws/triggers/redshift_cluster.py +++ b/airflow/providers/amazon/aws/triggers/redshift_cluster.py @@ -16,12 +16,14 @@ # under the License. from __future__ import annotations +import asyncio import warnings -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, AsyncIterator from airflow.exceptions import AirflowProviderDeprecationWarning -from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook +from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftAsyncHook, RedshiftHook from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger +from airflow.triggers.base import BaseTrigger, TriggerEvent if TYPE_CHECKING: from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook @@ -262,3 +264,53 @@ class RedshiftDeleteClusterTrigger(AwsBaseWaiterTrigger): def hook(self) -> AwsGenericHook: return RedshiftHook(aws_conn_id=self.aws_conn_id) + + +class RedshiftClusterTrigger(BaseTrigger): + """ + RedshiftClusterTrigger is fired as deferred class with params to run the task in trigger worker. + + :param aws_conn_id: Reference to AWS connection id for redshift + :param cluster_identifier: unique identifier of a cluster + :param target_status: Reference to the status which needs to be checked + :param poke_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + aws_conn_id: str, + cluster_identifier: str, + target_status: str, + poke_interval: float, + ): + super().__init__() + self.aws_conn_id = aws_conn_id + self.cluster_identifier = cluster_identifier + self.target_status = target_status + self.poke_interval = poke_interval + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes RedshiftClusterTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger", + { + "aws_conn_id": self.aws_conn_id, + "cluster_identifier": self.cluster_identifier, + "target_status": self.target_status, + "poke_interval": self.poke_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + """Simple async function run until the cluster status match the target status.""" + try: + hook = RedshiftAsyncHook(aws_conn_id=self.aws_conn_id) + while True: + res = await hook.cluster_status(self.cluster_identifier) + if (res["status"] == "success" and res["cluster_state"] == self.target_status) or res[ + "status" + ] == "error": + yield TriggerEvent(res) + await asyncio.sleep(self.poke_interval) + except Exception as e: + yield TriggerEvent({"status": "error", "message": str(e)}) diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py index 3b1b8a1e97..c88387919e 100644 --- a/tests/always/test_project_structure.py +++ b/tests/always/test_project_structure.py @@ -75,7 +75,6 @@ class TestProjectStructure: "tests/providers/amazon/aws/triggers/test_glue_crawler.py", "tests/providers/amazon/aws/triggers/test_lambda_function.py", "tests/providers/amazon/aws/triggers/test_rds.py", - "tests/providers/amazon/aws/triggers/test_redshift_cluster.py", "tests/providers/amazon/aws/triggers/test_step_function.py", "tests/providers/amazon/aws/utils/test_rds.py", "tests/providers/amazon/aws/utils/test_sagemaker.py", diff --git a/tests/providers/amazon/aws/sensors/test_redshift_cluster.py b/tests/providers/amazon/aws/sensors/test_redshift_cluster.py index 07c6ad67ee..04d4ca55fd 100644 --- a/tests/providers/amazon/aws/sensors/test_redshift_cluster.py +++ b/tests/providers/amazon/aws/sensors/test_redshift_cluster.py @@ -16,10 +16,27 @@ # under the License. from __future__ import annotations +from unittest import mock + import boto3 +import pytest from moto import mock_redshift +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.sensors.redshift_cluster import RedshiftClusterSensor +from airflow.providers.amazon.aws.triggers.redshift_cluster import RedshiftClusterTrigger + +MODULE = "airflow.providers.amazon.aws.sensors.redshift_cluster" + + +@pytest.fixture +def deferrable_op(): + return RedshiftClusterSensor( + task_id="test_cluster_sensor", + cluster_identifier="test_cluster", + target_status="available", + deferrable=True, + ) class TestRedshiftClusterSensor: @@ -75,3 +92,44 @@ class TestRedshiftClusterSensor: ) assert op.poke({}) + + @mock.patch(f"{MODULE}.RedshiftClusterSensor.defer") + @mock.patch(f"{MODULE}.RedshiftClusterSensor.poke", return_value=True) + def test_execute_finish_before_deferred( + self, + mock_poke, + mock_defer, + deferrable_op, + ): + """Assert task is not deferred when it receives a finish status before deferring""" + + deferrable_op.execute({}) + assert not mock_defer.called + + @mock.patch(f"{MODULE}.RedshiftClusterSensor.poke", return_value=False) + def test_execute(self, mock_poke, deferrable_op): + """Test RedshiftClusterSensor that a task with wildcard=True + is deferred and an RedshiftClusterTrigger will be fired when executed method is called""" + + with pytest.raises(TaskDeferred) as exc: + deferrable_op.execute(None) + assert isinstance( + exc.value.trigger, RedshiftClusterTrigger + ), "Trigger is not a RedshiftClusterTrigger" + + def test_redshift_sensor_async_execute_failure(self, deferrable_op): + """Test RedshiftClusterSensor with an AirflowException is raised in case of error event""" + + with pytest.raises(AirflowException): + deferrable_op.execute_complete( + context=None, event={"status": "error", "message": "test failure message"} + ) + + def test_redshift_sensor_async_execute_complete(self, deferrable_op): + """Asserts that logging occurs as expected""" + + with mock.patch.object(deferrable_op.log, "info") as mock_log_info: + deferrable_op.execute_complete( + context=None, event={"status": "success", "cluster_state": "available"} + ) + mock_log_info.assert_called_with("Cluster Identifier %s is in %s state", "test_cluster", "available") diff --git a/tests/providers/amazon/aws/triggers/test_redshift_cluster.py b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py new file mode 100644 index 0000000000..5d5cc2c424 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_redshift_cluster.py @@ -0,0 +1,119 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import asyncio +from unittest import mock + +import pytest + +from airflow.providers.amazon.aws.triggers.redshift_cluster import ( + RedshiftClusterTrigger, +) +from airflow.triggers.base import TriggerEvent + +POLLING_PERIOD_SECONDS = 1.0 + + +class TestRedshiftClusterTrigger: + def test_redshift_cluster_sensor_trigger_serialization(self): + """ + Asserts that the RedshiftClusterTrigger correctly serializes its arguments + and classpath. + """ + trigger = RedshiftClusterTrigger( + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + target_status="available", + poke_interval=POLLING_PERIOD_SECONDS, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.redshift_cluster.RedshiftClusterTrigger" + assert kwargs == { + "aws_conn_id": "test_redshift_conn_id", + "cluster_identifier": "mock_cluster_identifier", + "target_status": "available", + "poke_interval": POLLING_PERIOD_SECONDS, + } + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status") + async def test_redshift_cluster_sensor_trigger_success(self, mock_cluster_status): + """ + Test RedshiftClusterTrigger with the success status + """ + expected_result = {"status": "success", "cluster_state": "available"} + + mock_cluster_status.return_value = expected_result + trigger = RedshiftClusterTrigger( + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + target_status="available", + poke_interval=POLLING_PERIOD_SECONDS, + ) + + generator = trigger.run() + actual = await generator.asend(None) + assert TriggerEvent(expected_result) == actual + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "expected_result", + [ + ({"status": "success", "cluster_state": "Resuming"}), + ], + ) + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status") + async def test_redshift_cluster_sensor_trigger_resuming_status( + self, mock_cluster_status, expected_result + ): + """Test RedshiftClusterTrigger with the success status""" + mock_cluster_status.return_value = expected_result + trigger = RedshiftClusterTrigger( + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + target_status="available", + poke_interval=POLLING_PERIOD_SECONDS, + ) + task = asyncio.create_task(trigger.run().__anext__()) + await asyncio.sleep(0.5) + + # TriggerEvent was returned + assert task.done() is False + + # Prevents error when task is destroyed while in "pending" state + asyncio.get_event_loop().stop() + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftAsyncHook.cluster_status") + async def test_redshift_cluster_sensor_trigger_exception(self, mock_cluster_status): + """Test RedshiftClusterTrigger with exception""" + mock_cluster_status.side_effect = Exception("Test exception") + trigger = RedshiftClusterTrigger( + aws_conn_id="test_redshift_conn_id", + cluster_identifier="mock_cluster_identifier", + target_status="available", + poke_interval=POLLING_PERIOD_SECONDS, + ) + + task = [i async for i in trigger.run()] + # since we use return as soon as we yield the trigger event + # at any given point there should be one trigger event returned to the task + # so we validate for length of task to be 1 + assert len(task) == 1 + assert TriggerEvent({"status": "error", "message": "Test exception"}) in task