This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9563dc573b add deferrable mode to RedshiftDataOperator (#36586)
9563dc573b is described below

commit 9563dc573bc53b2c84640c88371b62cccdd811ff
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Fri Jan 19 02:35:43 2024 +0800

    add deferrable mode to RedshiftDataOperator (#36586)
    
    * feat(providers/amazon): add deferrable mode to RedshiftDataOperator
    
    * test(providers/amazon): add test case to RedshiftDataHook async methods
    
    * test(providers/amazon): add test case to RedshiftDataOperator when deferrable = True
    
    * refactor(providers/amazon): extract common operator initialization as deferrable_operator fixture
    
    * refactor(providers/amazon): rename region as region_name
    
    * feat(providers/amazon): add verify and botocore_config as suggested
    
    * refactor(providers/amazon): use async_conn from aws hook and add missing await
    
    * feat(providers/amazon): make RedshiftDataTrigger.hook a cached_property
    
    * refactor(providers/amazon): unify how the async and sync versions of check_query_is_finished are implemented
    
    * style(providers/amazon): fix mypy failure
    
    * fix(providers/amazon): fix async_conn call
---
 .../providers/amazon/aws/hooks/redshift_data.py    |  83 ++++++++---
 .../amazon/aws/operators/redshift_data.py          |  55 +++++++-
 .../providers/amazon/aws/triggers/redshift_data.py | 113 +++++++++++++++
 airflow/providers/amazon/provider.yaml             |   1 +
 .../amazon/aws/hooks/test_redshift_data.py         |  61 +++++++-
 .../amazon/aws/operators/test_redshift_data.py     | 115 +++++++++++++++-
 .../amazon/aws/triggers/test_redshift_data.py      | 153 +++++++++++++++++++++
 7 files changed, 560 insertions(+), 21 deletions(-)
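For readers following along, here is a minimal sketch of how the new deferrable mode might be used in a DAG. The DAG id, schedule, cluster identifier, database, and SQL below are illustrative placeholders and are not part of this commit; only the deferrable and poll_interval parameters come from the change itself.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    with DAG(
        dag_id="example_redshift_data_deferrable",  # placeholder DAG id
        start_date=datetime(2024, 1, 1),
        schedule=None,
    ) as dag:
        run_query = RedshiftDataOperator(
            task_id="run_query",
            database="dev",                             # placeholder database
            cluster_identifier="my-redshift-cluster",   # placeholder cluster
            sql="SELECT 1;",                            # placeholder query
            deferrable=True,    # new in this commit: polling moves to the triggerer
            poll_interval=10,
        )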

diff --git a/airflow/providers/amazon/aws/hooks/redshift_data.py b/airflow/providers/amazon/aws/hooks/redshift_data.py
index f7df0fd744..538e5cee96 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_data.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_data.py
@@ -26,6 +26,21 @@ from airflow.providers.amazon.aws.utils import trim_none_values
 
 if TYPE_CHECKING:
     from mypy_boto3_redshift_data import RedshiftDataAPIServiceClient  # noqa
+    from mypy_boto3_redshift_data.type_defs import DescribeStatementResponseTypeDef
+
+FINISHED_STATE = "FINISHED"
+FAILED_STATE = "FAILED"
+ABORTED_STATE = "ABORTED"
+FAILURE_STATES = {FAILED_STATE, ABORTED_STATE}
+RUNNING_STATES = {"PICKED", "STARTED", "SUBMITTED"}
+
+
+class RedshiftDataQueryFailedError(ValueError):
+    """Raise an error that redshift data query failed."""
+
+
+class RedshiftDataQueryAbortedError(ValueError):
+    """Raise an error that redshift data query was aborted."""
 
 
 class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
@@ -108,27 +123,40 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
 
         return statement_id
 
-    def wait_for_results(self, statement_id, poll_interval):
+    def wait_for_results(self, statement_id: str, poll_interval: int) -> str:
         while True:
             self.log.info("Polling statement %s", statement_id)
-            resp = self.conn.describe_statement(
-                Id=statement_id,
-            )
-            status = resp["Status"]
-            if status == "FINISHED":
-                num_rows = resp.get("ResultRows")
-                if num_rows is not None:
-                    self.log.info("Processed %s rows", num_rows)
-                return status
-            elif status in ("FAILED", "ABORTED"):
-                raise ValueError(
-                    f"Statement {statement_id!r} terminated with status 
{status}. "
-                    f"Response details: {pformat(resp)}"
-                )
-            else:
-                self.log.info("Query %s", status)
+            is_finished = self.check_query_is_finished(statement_id)
+            if is_finished:
+                return FINISHED_STATE
+
             time.sleep(poll_interval)
 
+    def check_query_is_finished(self, statement_id: str) -> bool:
+        """Check whether query finished, raise exception is failed."""
+        resp = self.conn.describe_statement(Id=statement_id)
+        return self.parse_statement_response(resp)
+
+    def parse_statement_response(self, resp: DescribeStatementResponseTypeDef) -> bool:
+        """Parse the response of describe_statement."""
+        status = resp["Status"]
+        if status == FINISHED_STATE:
+            num_rows = resp.get("ResultRows")
+            if num_rows is not None:
+                self.log.info("Processed %s rows", num_rows)
+            return True
+        elif status in FAILURE_STATES:
+            exception_cls = (
+                RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError
+            )
+            raise exception_cls(
+                f"Statement {resp['Id']} terminated with status {status}. "
+                f"Response details: {pformat(resp)}"
+            )
+
+        self.log.info("Query status: %s", status)
+        return False
+
     def get_table_primary_key(
         self,
         table: str,
@@ -201,3 +229,24 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]):
                 break
 
         return pk_columns or None
+
+    async def is_still_running(self, statement_id: str) -> bool:
+        """Async function to check whether the query is still running.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            desc = await client.describe_statement(Id=statement_id)
+            return desc["Status"] in RUNNING_STATES
+
+    async def check_query_is_finished_async(self, statement_id: str) -> bool:
+        """Async function to check statement is finished.
+
+        It takes a statement_id, makes an async connection to Redshift Data to get the query
+        status by statement_id, and returns whether the query has finished.
+
+        :param statement_id: the UUID of the statement
+        """
+        async with self.async_conn as client:
+            resp = await client.describe_statement(Id=statement_id)
+            return self.parse_statement_response(resp)
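As a rough illustration of how the new sync and async hook paths fit together, the snippet below polls a statement outside of Airflow. It assumes a valid aws_default connection; the statement id is a placeholder and this snippet is not part of the commit.

    import asyncio

    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook

    async def wait_until_done(statement_id: str, poll_interval: int = 10) -> bool:
        # Mirrors what the new trigger does: keep sleeping while the statement is
        # still running, then let the shared response parser decide between
        # returning True (FINISHED) and raising on FAILED/ABORTED.
        hook = RedshiftDataHook(aws_conn_id="aws_default")
        while await hook.is_still_running(statement_id):
            await asyncio.sleep(poll_interval)
        return await hook.check_query_is_finished_async(statement_id)

    # asyncio.run(wait_until_done("your-statement-id"))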
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index b454ad76ec..71ee82069e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,10 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -92,6 +95,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         poll_interval: int = 10,
         return_sql_result: bool = False,
         workgroup_name: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -114,11 +118,17 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             )
         self.return_sql_result = return_sql_result
         self.statement_id: str | None = None
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift."""
         self.log.info("Executing statement: %s", self.sql)
 
+        # Set wait_for_completion to False so that it waits for the status in the deferred task.
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            self.wait_for_completion = False
+
         self.statement_id = self.hook.execute_query(
             database=self.database,
             sql=self.sql,
@@ -129,10 +139,27 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             secret_arn=self.secret_arn,
             statement_name=self.statement_name,
             with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             poll_interval=self.poll_interval,
         )
 
+        if self.deferrable:
+            is_finished = self.hook.check_query_is_finished(self.statement_id)
+            if not is_finished:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=RedshiftDataTrigger(
+                        statement_id=self.statement_id,
+                        task_id=self.task_id,
+                        poll_interval=self.poll_interval,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        verify=self.verify,
+                        botocore_config=self.botocore_config,
+                    ),
+                    method_name="execute_complete",
+                )
+
         if self.return_sql_result:
             result = self.hook.conn.get_statement_result(Id=self.statement_id)
             self.log.debug("Statement result: %s", result)
@@ -140,6 +167,30 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         else:
             return self.statement_id
 
+    def execute_complete(
+        self, context: Context, event: dict[str, Any] | None = None
+    ) -> GetStatementResultResponseTypeDef | str:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.info(err_msg)
+            raise AirflowException(err_msg)
+
+        if event["status"] == "error":
+            msg = f"context: {context}, error message: {event['message']}"
+            raise AirflowException(msg)
+
+        statement_id = event["statement_id"]
+        if not statement_id:
+            raise AirflowException("statement_id should not be empty.")
+
+        self.log.info("%s completed successfully.", self.task_id)
+        if self.return_sql_result:
+            result = self.hook.conn.get_statement_result(Id=statement_id)
+            self.log.debug("Statement result: %s", result)
+            return result
+
+        return statement_id
+
     def on_kill(self) -> None:
         """Cancel the submitted redshift query."""
         if self.statement_id:
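The operator above and the trigger added below communicate through a plain dict event. As a sketch, here are the two payload shapes handled by execute_complete, with field names taken from this diff and values that are purely illustrative:

    # Success: execute_complete logs completion and returns the statement_id
    # (or fetches the result when return_sql_result=True).
    success_event = {"status": "success", "statement_id": "a1b2c3d4-example"}

    # Failure: execute_complete raises AirflowException using the message.
    error_event = {
        "status": "error",
        "statement_id": "a1b2c3d4-example",
        "message": "run_query failed",
    }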
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
new file mode 100644
index 0000000000..2d0ecbc594
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+    ABORTED_STATE,
+    FAILED_STATE,
+    RedshiftDataHook,
+    RedshiftDataQueryAbortedError,
+    RedshiftDataQueryFailedError,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RedshiftDataTrigger(BaseTrigger):
+    """
+    RedshiftDataTrigger is fired as a deferred class with params to run the task in the triggerer.
+
+    :param statement_id: the UUID of the statement
+    :param task_id: task ID of the Dag
+    :param poll_interval:  polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for redshift
+    :param region_name: aws region to use
+    """
+
+    def __init__(
+        self,
+        statement_id: str,
+        task_id: str,
+        poll_interval: int,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
+    ):
+        super().__init__()
+        self.statement_id = statement_id
+        self.task_id = task_id
+        self.poll_interval = poll_interval
+
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serializes RedshiftDataTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
+            {
+                "statement_id": self.statement_id,
+                "task_id": self.task_id,
+                "aws_conn_id": self.aws_conn_id,
+                "poll_interval": self.poll_interval,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> RedshiftDataHook:
+        return RedshiftDataHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        try:
+            while await self.hook.is_still_running(self.statement_id):
+                await asyncio.sleep(self.poll_interval)
+
+            is_finished = await self.hook.check_query_is_finished_async(self.statement_id)
+            if is_finished:
+                response = {"status": "success", "statement_id": 
self.statement_id}
+            else:
+                response = {
+                    "status": "error",
+                    "statement_id": self.statement_id,
+                    "message": f"{self.task_id} failed",
+                }
+            yield TriggerEvent(response)
+        except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as error:
+            response = {
+                "status": "error",
+                "statement_id": self.statement_id,
+                "message": str(error),
+                "type": FAILED_STATE if isinstance(error, 
RedshiftDataQueryFailedError) else ABORTED_STATE,
+            }
+            yield TriggerEvent(response)
+        except Exception as error:
+            yield TriggerEvent({"status": "error", "statement_id": 
self.statement_id, "message": str(error)})
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 1b90089db2..bcbb5c18e3 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -621,6 +621,7 @@ triggers:
   - integration-name: Amazon Redshift
     python-modules:
       - airflow.providers.amazon.aws.triggers.redshift_cluster
+      - airflow.providers.amazon.aws.triggers.redshift_data
   - integration-name: Amazon SageMaker
     python-modules:
       - airflow.providers.amazon.aws.triggers.sagemaker
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index cc174a872c..126585b432 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -22,7 +22,11 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+    RedshiftDataHook,
+    RedshiftDataQueryAbortedError,
+    RedshiftDataQueryFailedError,
+)
 
 SQL = "sql"
 DATABASE = "database"
@@ -292,3 +296,58 @@ class TestRedshiftDataHook:
                 wait_for_completion=True,
             )
             assert "Processed " not in caplog.text
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "describe_statement_response, expected_result",
+        [
+            ({"Status": "PICKED"}, True),
+            ({"Status": "STARTED"}, True),
+            ({"Status": "SUBMITTED"}, True),
+            ({"Status": "FINISHED"}, False),
+            ({"Status": "FAILED"}, False),
+            ({"Status": "ABORTED"}, False),
+        ],
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result):
+        hook = RedshiftDataHook()
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+        response = await hook.is_still_running("uuid")
+        assert response == expected_result
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+    async def test_check_query_is_finished_async(self, mock_is_still_running, mock_conn):
+        hook = RedshiftDataHook()
+        mock_is_still_running.return_value = False
+        mock_conn.describe_statement = mock.AsyncMock()
+        mock_conn.__aenter__.return_value.describe_statement.return_value = {
+            "Id": "uuid",
+            "Status": "FINISHED",
+        }
+        is_finished = await hook.check_query_is_finished_async(statement_id="uuid")
+        assert is_finished is True
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "describe_statement_response, expected_exception",
+        (
+            (
+                {"Id": "uuid", "Status": "FAILED", "QueryString": "select 1", 
"Error": "Test error"},
+                RedshiftDataQueryFailedError,
+            ),
+            ({"Id": "uuid", "Status": "ABORTED"}, 
RedshiftDataQueryAbortedError),
+        ),
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+    async def test_check_query_is_finished_async_exception(
+        self, mock_is_still_running, mock_conn, describe_statement_response, expected_exception
+    ):
+        hook = RedshiftDataHook()
+        mock_is_still_running.return_value = False
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+        with pytest.raises(expected_exception):
+            await hook.check_query_is_finished_async(statement_id="uuid")
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 4b921b7142..fa22c98218 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -21,8 +21,9 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
 
 CONN_ID = "aws_conn_test"
 TASK_ID = "task_id"
@@ -31,6 +32,32 @@ DATABASE = "database"
 STATEMENT_ID = "statement_id"
 
 
+@pytest.fixture
+def deferrable_operator():
+    cluster_identifier = "cluster_identifier"
+    db_user = "db_user"
+    secret_arn = "secret_arn"
+    statement_name = "statement_name"
+    parameters = [{"name": "id", "value": "1"}]
+    poll_interval = 5
+
+    operator = RedshiftDataOperator(
+        aws_conn_id=CONN_ID,
+        task_id=TASK_ID,
+        sql=SQL,
+        database=DATABASE,
+        cluster_identifier=cluster_identifier,
+        db_user=db_user,
+        secret_arn=secret_arn,
+        statement_name=statement_name,
+        parameters=parameters,
+        wait_for_completion=False,
+        poll_interval=poll_interval,
+        deferrable=True,
+    )
+    return operator
+
+
 class TestRedshiftDataOperator:
     def test_init(self):
         op = RedshiftDataOperator(
@@ -202,3 +229,89 @@ class TestRedshiftDataOperator:
         mock_conn.get_statement_result.assert_called_once_with(
             Id=STATEMENT_ID,
         )
+
+    @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
+        return_value=True,
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer):
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = None
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        poll_interval = 5
+
+        operator = RedshiftDataOperator(
+            aws_conn_id=CONN_ID,
+            task_id=TASK_ID,
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            wait_for_completion=False,
+            poll_interval=poll_interval,
+            deferrable=True,
+        )
+        operator.execute(None)
+
+        assert not mock_defer.called
+        mock_exec_query.assert_called_once_with(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            with_event=False,
+            wait_for_completion=False,
+            poll_interval=poll_interval,
+        )
+
+    # @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
+        return_value=False,
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator):
+        with pytest.raises(TaskDeferred) as exc:
+            deferrable_operator.execute(None)
+
+        assert isinstance(exc.value.trigger, RedshiftDataTrigger)
+
+    def test_execute_complete_failure(self, deferrable_operator):
+        """Tests that an AirflowException is raised in case of error event"""
+        with pytest.raises(AirflowException):
+            deferrable_operator.execute_complete(
+                context=None, event={"status": "error", "message": "test 
failure message"}
+            )
+
+    def test_execute_complete_exception(self, deferrable_operator):
+        """Tests that an AirflowException is raised in case of empty event"""
+        with pytest.raises(AirflowException) as exc:
+            deferrable_operator.execute_complete(context=None, event=None)
+        assert exc.value.args[0] == "Trigger error: event is None"
+
+    def test_execute_complete(self, deferrable_operator):
+        """Asserts that logging occurs as expected"""
+
+        deferrable_operator.statement_id = "uuid"
+
+        with mock.patch.object(deferrable_operator.log, "info") as 
mock_log_info:
+            assert (
+                deferrable_operator.execute_complete(
+                    context=None,
+                    event={"status": "success", "message": "Job completed", 
"statement_id": "uuid"},
+                )
+                == "uuid"
+            )
+        mock_log_info.assert_called_with("%s completed successfully.", TASK_ID)
diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py
new file mode 100644
index 0000000000..49c0862af2
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+    ABORTED_STATE,
+    FAILED_STATE,
+    RedshiftDataQueryAbortedError,
+    RedshiftDataQueryFailedError,
+)
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
+from airflow.triggers.base import TriggerEvent
+
+TEST_CONN_ID = "aws_default"
+TEST_TASK_ID = "123"
+POLL_INTERVAL = 4.0
+
+
+class TestRedshiftDataTrigger:
+    def test_redshift_data_trigger_serialization(self):
+        """
+        Asserts that the RedshiftDataTrigger correctly serializes its arguments
+        and classpath.
+        """
+        trigger = RedshiftDataTrigger(
+            statement_id=[],
+            task_id=TEST_TASK_ID,
+            aws_conn_id=TEST_CONN_ID,
+            poll_interval=POLL_INTERVAL,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger"
+        assert kwargs == {
+            "statement_id": [],
+            "task_id": TEST_TASK_ID,
+            "poll_interval": POLL_INTERVAL,
+            "aws_conn_id": TEST_CONN_ID,
+            "region_name": None,
+            "botocore_config": None,
+            "verify": None,
+        }
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "return_value, response",
+        [
+            (
+                True,
+                TriggerEvent({"status": "success", "statement_id": "uuid"}),
+            ),
+            (
+                False,
+                TriggerEvent(
+                    {"status": "error", "message": f"{TEST_TASK_ID} failed", 
"statement_id": "uuid"}
+                ),
+            ),
+        ],
+    )
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+    )
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running",
+        return_value=False,
+    )
+    async def test_redshift_data_trigger_run(
+        self, mocked_is_still_running, mock_check_query_is_finished_async, return_value, response
+    ):
+        """
+        Tests that RedshiftDataTrigger only fires once the query execution reaches a successful state.
+        """
+        mock_check_query_is_finished_async.return_value = return_value
+        trigger = RedshiftDataTrigger(
+            statement_id="uuid",
+            task_id=TEST_TASK_ID,
+            poll_interval=POLL_INTERVAL,
+            aws_conn_id=TEST_CONN_ID,
+        )
+        generator = trigger.run()
+        actual = await generator.asend(None)
+        assert response == actual
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "raised_exception, expected_response",
+        [
+            (
+                RedshiftDataQueryFailedError("Failed"),
+                {
+                    "status": "error",
+                    "statement_id": "uuid",
+                    "message": "Failed",
+                    "type": FAILED_STATE,
+                },
+            ),
+            (
+                RedshiftDataQueryAbortedError("Aborted"),
+                {
+                    "status": "error",
+                    "statement_id": "uuid",
+                    "message": "Aborted",
+                    "type": ABORTED_STATE,
+                },
+            ),
+            (
+                Exception(f"{TEST_TASK_ID} failed"),
+                {"status": "error", "statement_id": "uuid", "message": 
f"{TEST_TASK_ID} failed"},
+            ),
+        ],
+    )
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async"
+    )
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running",
+        return_value=False,
+    )
+    async def test_redshift_data_trigger_exception(
+        self, mocked_is_still_running, mock_check_query_is_finished_async, raised_exception, expected_response
+    ):
+        """
+        Test that RedshiftDataTrigger fires the correct event in case of an error.
+        """
+        mock_check_query_is_finished_async.side_effect = raised_exception
+
+        trigger = RedshiftDataTrigger(
+            statement_id="uuid",
+            task_id=TEST_TASK_ID,
+            poll_interval=POLL_INTERVAL,
+            aws_conn_id=TEST_CONN_ID,
+        )
+        task = [i async for i in trigger.run()]
+        assert len(task) == 1
+        assert TriggerEvent(expected_response) in task
