This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 9563dc573b  add deferrable mode to RedshiftDataOperator (#36586)
9563dc573b is described below

commit 9563dc573bc53b2c84640c88371b62cccdd811ff
Author:     Wei Lee <weilee...@gmail.com>
AuthorDate: Fri Jan 19 02:35:43 2024 +0800

    add deferrable mode to RedshiftDataOperator (#36586)

    * feat(providers/amazon): add deferrable mode to RedshiftDataOperator
    * test(providers/amazon): add test case to RedshiftDataHook async methods
    * test(providers/amazon): add test case to RedshiftDataOperator when deferrable = True
    * refactor(providers/amazon): extract common operator initialization as deferrable_operator fixture
    * refactor(providers/amazon): rename region as region_name
    * feat(providers/amazon): add verify and botocore_config as suggested
    * refactor(providers/amazon): use async_conn from aws hook and add missing await
    * feat(providers/amazon): make RedshiftDataTrigger.hook a cached_property
    * refactor(providers/amazon): unify how the async and sync versions of check_query_is_finished are implemented
    * style(providers/amazon): fix mypy failure
    * fix(providers/amazon): fix async_conn call
---
 .../providers/amazon/aws/hooks/redshift_data.py    |  83 ++++++++---
 .../amazon/aws/operators/redshift_data.py          |  55 +++++++-
 .../providers/amazon/aws/triggers/redshift_data.py | 113 +++++++++++++++
 airflow/providers/amazon/provider.yaml             |   1 +
 .../amazon/aws/hooks/test_redshift_data.py         |  61 +++++++-
 .../amazon/aws/operators/test_redshift_data.py     | 115 +++++++++++++++-
 .../amazon/aws/triggers/test_redshift_data.py      | 153 +++++++++++++++++++++
 7 files changed, 560 insertions(+), 21 deletions(-)
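For context, a minimal usage sketch of the new flag. The DAG id, schedule, database, SQL, and cluster name below are illustrative values, not part of the commit; only the RedshiftDataOperator parameters come from the diff:

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator

    with DAG("redshift_data_deferrable_demo", start_date=datetime(2024, 1, 1), schedule=None) as dag:
        RedshiftDataOperator(
            task_id="run_query",
            database="dev",                           # hypothetical database
            sql="SELECT 1;",                          # hypothetical statement
            cluster_identifier="redshift-cluster-1",  # hypothetical cluster
            poll_interval=10,
            # New in this commit: free the worker slot while the query runs;
            # the triggerer process polls the statement status instead.
            deferrable=True,
        )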
" - f"Response details: {pformat(resp)}" - ) - else: - self.log.info("Query %s", status) + is_finished = self.check_query_is_finished(statement_id) + if is_finished: + return FINISHED_STATE + time.sleep(poll_interval) + def check_query_is_finished(self, statement_id: str) -> bool: + """Check whether query finished, raise exception is failed.""" + resp = self.conn.describe_statement(Id=statement_id) + return self.parse_statement_resposne(resp) + + def parse_statement_resposne(self, resp: DescribeStatementResponseTypeDef) -> bool: + """Parse the response of describe_statement.""" + status = resp["Status"] + if status == FINISHED_STATE: + num_rows = resp.get("ResultRows") + if num_rows is not None: + self.log.info("Processed %s rows", num_rows) + return True + elif status in FAILURE_STATES: + exception_cls = ( + RedshiftDataQueryFailedError if status == FAILED_STATE else RedshiftDataQueryAbortedError + ) + raise exception_cls( + f"Statement {resp['Id']} terminated with status {status}. " + f"Response details: {pformat(resp)}" + ) + + self.log.info("Query status: %s", status) + return False + def get_table_primary_key( self, table: str, @@ -201,3 +229,24 @@ class RedshiftDataHook(AwsGenericHook["RedshiftDataAPIServiceClient"]): break return pk_columns or None + + async def is_still_running(self, statement_id: str) -> bool: + """Async function to check whether the query is still running. + + :param statement_id: the UUID of the statement + """ + async with self.async_conn as client: + desc = await client.describe_statement(Id=statement_id) + return desc["Status"] in RUNNING_STATES + + async def check_query_is_finished_async(self, statement_id: str) -> bool: + """Async function to check statement is finished. + + It takes statement_id, makes async connection to redshift data to get the query status + by statement_id and returns the query status. + + :param statement_id: the UUID of the statement + """ + async with self.async_conn as client: + resp = await client.describe_statement(Id=statement_id) + return self.parse_statement_resposne(resp) diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py index b454ad76ec..71ee82069e 100644 --- a/airflow/providers/amazon/aws/operators/redshift_data.py +++ b/airflow/providers/amazon/aws/operators/redshift_data.py @@ -17,10 +17,13 @@ # under the License. 
diff --git a/airflow/providers/amazon/aws/operators/redshift_data.py b/airflow/providers/amazon/aws/operators/redshift_data.py
index b454ad76ec..71ee82069e 100644
--- a/airflow/providers/amazon/aws/operators/redshift_data.py
+++ b/airflow/providers/amazon/aws/operators/redshift_data.py
@@ -17,10 +17,13 @@
 # under the License.
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
 from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
 from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
 
 if TYPE_CHECKING:
@@ -92,6 +95,7 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         poll_interval: int = 10,
         return_sql_result: bool = False,
         workgroup_name: str | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs,
     ) -> None:
         super().__init__(**kwargs)
@@ -114,11 +118,17 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             )
         self.return_sql_result = return_sql_result
         self.statement_id: str | None = None
+        self.deferrable = deferrable
 
     def execute(self, context: Context) -> GetStatementResultResponseTypeDef | str:
         """Execute a statement against Amazon Redshift."""
         self.log.info("Executing statement: %s", self.sql)
 
+        # Set wait_for_completion to False so that it waits for the status in the deferred task.
+        wait_for_completion = self.wait_for_completion
+        if self.deferrable and self.wait_for_completion:
+            wait_for_completion = False
+
         self.statement_id = self.hook.execute_query(
             database=self.database,
             sql=self.sql,
@@ -129,10 +139,27 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
             secret_arn=self.secret_arn,
             statement_name=self.statement_name,
             with_event=self.with_event,
-            wait_for_completion=self.wait_for_completion,
+            wait_for_completion=wait_for_completion,
             poll_interval=self.poll_interval,
         )
 
+        if self.deferrable:
+            is_finished = self.hook.check_query_is_finished(self.statement_id)
+            if not is_finished:
+                self.defer(
+                    timeout=self.execution_timeout,
+                    trigger=RedshiftDataTrigger(
+                        statement_id=self.statement_id,
+                        task_id=self.task_id,
+                        poll_interval=self.poll_interval,
+                        aws_conn_id=self.aws_conn_id,
+                        region_name=self.region_name,
+                        verify=self.verify,
+                        botocore_config=self.botocore_config,
+                    ),
+                    method_name="execute_complete",
+                )
+
         if self.return_sql_result:
             result = self.hook.conn.get_statement_result(Id=self.statement_id)
             self.log.debug("Statement result: %s", result)
@@ -140,6 +167,30 @@ class RedshiftDataOperator(AwsBaseOperator[RedshiftDataHook]):
         else:
             return self.statement_id
 
+    def execute_complete(
+        self, context: Context, event: dict[str, Any] | None = None
+    ) -> GetStatementResultResponseTypeDef | str:
+        if event is None:
+            err_msg = "Trigger error: event is None"
+            self.log.info(err_msg)
+            raise AirflowException(err_msg)
+
+        if event["status"] == "error":
+            msg = f"context: {context}, error message: {event['message']}"
+            raise AirflowException(msg)
+
+        statement_id = event["statement_id"]
+        if not statement_id:
+            raise AirflowException("statement_id should not be empty.")
+
+        self.log.info("%s completed successfully.", self.task_id)
+        if self.return_sql_result:
+            result = self.hook.conn.get_statement_result(Id=statement_id)
+            self.log.debug("Statement result: %s", result)
+            return result
+
+        return statement_id
+
     def on_kill(self) -> None:
         """Cancel the submitted redshift query."""
         if self.statement_id:
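The deferral is a two-step handoff: execute() submits the statement and defers only if it has not already finished, and the triggerer later resumes the task by calling execute_complete() with the event yielded by RedshiftDataTrigger (added in the next file). A sketch of the event shapes that execute_complete() handles; "uuid" and the message text are illustrative values:

    # Success: execute_complete returns the statement_id
    # (or the fetched rows when return_sql_result=True).
    success_event = {"status": "success", "statement_id": "uuid"}

    # Error: execute_complete raises AirflowException with the message.
    error_event = {"status": "error", "statement_id": "uuid", "message": "task_id failed"}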
diff --git a/airflow/providers/amazon/aws/triggers/redshift_data.py b/airflow/providers/amazon/aws/triggers/redshift_data.py
new file mode 100644
index 0000000000..2d0ecbc594
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/redshift_data.py
@@ -0,0 +1,113 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import asyncio
+from functools import cached_property
+from typing import Any, AsyncIterator
+
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+    ABORTED_STATE,
+    FAILED_STATE,
+    RedshiftDataHook,
+    RedshiftDataQueryAbortedError,
+    RedshiftDataQueryFailedError,
+)
+from airflow.triggers.base import BaseTrigger, TriggerEvent
+
+
+class RedshiftDataTrigger(BaseTrigger):
+    """
+    RedshiftDataTrigger is fired as a deferred class with params to run the task in the triggerer.
+
+    :param statement_id: the UUID of the statement
+    :param task_id: task ID of the DAG task
+    :param poll_interval: polling period in seconds to check for the status
+    :param aws_conn_id: AWS connection ID for Redshift
+    :param region_name: AWS region to use
+    """
+
+    def __init__(
+        self,
+        statement_id: str,
+        task_id: str,
+        poll_interval: int,
+        aws_conn_id: str | None = "aws_default",
+        region_name: str | None = None,
+        verify: bool | str | None = None,
+        botocore_config: dict | None = None,
+    ):
+        super().__init__()
+        self.statement_id = statement_id
+        self.task_id = task_id
+        self.poll_interval = poll_interval
+
+        self.aws_conn_id = aws_conn_id
+        self.region_name = region_name
+        self.verify = verify
+        self.botocore_config = botocore_config
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Serialize RedshiftDataTrigger arguments and classpath."""
+        return (
+            "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger",
+            {
+                "statement_id": self.statement_id,
+                "task_id": self.task_id,
+                "aws_conn_id": self.aws_conn_id,
+                "poll_interval": self.poll_interval,
+                "region_name": self.region_name,
+                "verify": self.verify,
+                "botocore_config": self.botocore_config,
+            },
+        )
+
+    @cached_property
+    def hook(self) -> RedshiftDataHook:
+        return RedshiftDataHook(
+            aws_conn_id=self.aws_conn_id,
+            region_name=self.region_name,
+            verify=self.verify,
+            config=self.botocore_config,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        try:
+            while await self.hook.is_still_running(self.statement_id):
+                await asyncio.sleep(self.poll_interval)
+
+            is_finished = await self.hook.check_query_is_finished_async(self.statement_id)
+            if is_finished:
+                response = {"status": "success", "statement_id": self.statement_id}
+            else:
+                response = {
+                    "status": "error",
+                    "statement_id": self.statement_id,
+                    "message": f"{self.task_id} failed",
+                }
+            yield TriggerEvent(response)
+        except (RedshiftDataQueryFailedError, RedshiftDataQueryAbortedError) as error:
+            response = {
+                "status": "error",
+                "statement_id": self.statement_id,
+                "message": str(error),
+                "type": FAILED_STATE if isinstance(error, RedshiftDataQueryFailedError) else ABORTED_STATE,
+            }
+            yield TriggerEvent(response)
+        except Exception as error:
+            yield TriggerEvent({"status": "error", "statement_id": self.statement_id, "message": str(error)})
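Outside of a running triggerer, the new trigger can be driven directly with asyncio for experimentation, which is also how the tests below consume it. A sketch; the statement ID is hypothetical and valid AWS credentials are assumed:

    import asyncio

    from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger


    async def main():
        trigger = RedshiftDataTrigger(
            statement_id="01234567-89ab-cdef-0123-456789abcdef",  # hypothetical
            task_id="run_query",
            poll_interval=10,
            aws_conn_id="aws_default",
        )
        # run() yields exactly one TriggerEvent: success, or error with a message.
        async for event in trigger.run():
            print(event)


    asyncio.run(main())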
diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml
index 1b90089db2..bcbb5c18e3 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -621,6 +621,7 @@ triggers:
   - integration-name: Amazon Redshift
     python-modules:
       - airflow.providers.amazon.aws.triggers.redshift_cluster
+      - airflow.providers.amazon.aws.triggers.redshift_data
   - integration-name: Amazon SageMaker
     python-modules:
       - airflow.providers.amazon.aws.triggers.sagemaker
diff --git a/tests/providers/amazon/aws/hooks/test_redshift_data.py b/tests/providers/amazon/aws/hooks/test_redshift_data.py
index cc174a872c..126585b432 100644
--- a/tests/providers/amazon/aws/hooks/test_redshift_data.py
+++ b/tests/providers/amazon/aws/hooks/test_redshift_data.py
@@ -22,7 +22,11 @@ from unittest import mock
 
 import pytest
 
-from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_data import (
+    RedshiftDataHook,
+    RedshiftDataQueryAbortedError,
+    RedshiftDataQueryFailedError,
+)
 
 SQL = "sql"
 DATABASE = "database"
@@ -292,3 +296,58 @@ class TestRedshiftDataHook:
             wait_for_completion=True,
         )
         assert "Processed " not in caplog.text
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "describe_statement_response, expected_result",
+        [
+            ({"Status": "PICKED"}, True),
+            ({"Status": "STARTED"}, True),
+            ({"Status": "SUBMITTED"}, True),
+            ({"Status": "FINISHED"}, False),
+            ({"Status": "FAILED"}, False),
+            ({"Status": "ABORTED"}, False),
+        ],
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    async def test_is_still_running(self, mock_conn, describe_statement_response, expected_result):
+        hook = RedshiftDataHook()
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+        response = await hook.is_still_running("uuid")
+        assert response == expected_result
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+    async def test_check_query_is_finished_async(self, mock_is_still_running, mock_conn):
+        hook = RedshiftDataHook()
+        mock_is_still_running.return_value = False
+        mock_conn.describe_statement = mock.AsyncMock()
+        mock_conn.__aenter__.return_value.describe_statement.return_value = {
+            "Id": "uuid",
+            "Status": "FINISHED",
+        }
+        is_finished = await hook.check_query_is_finished_async(statement_id="uuid")
+        assert is_finished is True
+
+    @pytest.mark.asyncio
+    @pytest.mark.parametrize(
+        "describe_statement_response, expected_exception",
+        (
+            (
+                {"Id": "uuid", "Status": "FAILED", "QueryString": "select 1", "Error": "Test error"},
+                RedshiftDataQueryFailedError,
+            ),
+            ({"Id": "uuid", "Status": "ABORTED"}, RedshiftDataQueryAbortedError),
+        ),
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.async_conn")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running")
+    async def test_check_query_is_finished_async_exception(
+        self, mock_is_still_running, mock_conn, describe_statement_response, expected_exception
+    ):
+        hook = RedshiftDataHook()
+        mock_is_still_running.return_value = False
+        mock_conn.__aenter__.return_value.describe_statement.return_value = describe_statement_response
+        with pytest.raises(expected_exception):
+            await hook.check_query_is_finished_async(statement_id="uuid")
diff --git a/tests/providers/amazon/aws/operators/test_redshift_data.py b/tests/providers/amazon/aws/operators/test_redshift_data.py
index 4b921b7142..fa22c98218 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_data.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_data.py
@@ -21,8 +21,9 @@ from unittest import mock
 
 import pytest
 
-from airflow.exceptions import AirflowProviderDeprecationWarning
+from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning, TaskDeferred
 from airflow.providers.amazon.aws.operators.redshift_data import RedshiftDataOperator
+from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger
 
 CONN_ID = "aws_conn_test"
 TASK_ID = "task_id"
@@ -31,6 +32,32 @@ DATABASE = "database"
 STATEMENT_ID = "statement_id"
 
 
+@pytest.fixture
+def deferrable_operator():
+    cluster_identifier = "cluster_identifier"
+    db_user = "db_user"
+    secret_arn = "secret_arn"
+    statement_name = "statement_name"
+    parameters = [{"name": "id", "value": "1"}]
+    poll_interval = 5
+
+    operator = RedshiftDataOperator(
+        aws_conn_id=CONN_ID,
+        task_id=TASK_ID,
+        sql=SQL,
+        database=DATABASE,
+        cluster_identifier=cluster_identifier,
+        db_user=db_user,
+        secret_arn=secret_arn,
+        statement_name=statement_name,
+        parameters=parameters,
+        wait_for_completion=False,
+        poll_interval=poll_interval,
+        deferrable=True,
+    )
+    return operator
+
+
 class TestRedshiftDataOperator:
     def test_init(self):
         op = RedshiftDataOperator(
@@ -202,3 +229,89 @@ class TestRedshiftDataOperator:
         mock_conn.get_statement_result.assert_called_once_with(
             Id=STATEMENT_ID,
         )
+
+    @mock.patch("airflow.providers.amazon.aws.operators.redshift_data.RedshiftDataOperator.defer")
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished",
+        return_value=True,
+    )
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+    def test_execute_finished_before_defer(self, mock_exec_query, check_query_is_finished, mock_defer):
+        cluster_identifier = "cluster_identifier"
+        workgroup_name = None
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+        parameters = [{"name": "id", "value": "1"}]
+        poll_interval = 5
+
+        operator = RedshiftDataOperator(
+            aws_conn_id=CONN_ID,
+            task_id=TASK_ID,
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            wait_for_completion=False,
+            poll_interval=poll_interval,
+            deferrable=True,
+        )
+        operator.execute(None)
+
+        assert not mock_defer.called
+        mock_exec_query.assert_called_once_with(
+            sql=SQL,
+            database=DATABASE,
+            cluster_identifier=cluster_identifier,
+            workgroup_name=workgroup_name,
+            db_user=db_user,
+            secret_arn=secret_arn,
+            statement_name=statement_name,
+            parameters=parameters,
+            with_event=False,
+            wait_for_completion=False,
+            poll_interval=poll_interval,
+        )
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query") + def test_execute_defer(self, mock_exec_query, check_query_is_finished, deferrable_operator): + with pytest.raises(TaskDeferred) as exc: + deferrable_operator.execute(None) + + assert isinstance(exc.value.trigger, RedshiftDataTrigger) + + def test_execute_complete_failure(self, deferrable_operator): + """Tests that an AirflowException is raised in case of error event""" + with pytest.raises(AirflowException): + deferrable_operator.execute_complete( + context=None, event={"status": "error", "message": "test failure message"} + ) + + def test_execute_complete_exception(self, deferrable_operator): + """Tests that an AirflowException is raised in case of empty event""" + with pytest.raises(AirflowException) as exc: + deferrable_operator.execute_complete(context=None, event=None) + assert exc.value.args[0] == "Did not receive valid event from the trigerrer" + + def test_execute_complete(self, deferrable_operator): + """Asserts that logging occurs as expected""" + + deferrable_operator.statement_id = "uuid" + + with mock.patch.object(deferrable_operator.log, "info") as mock_log_info: + assert ( + deferrable_operator.execute_complete( + context=None, + event={"status": "success", "message": "Job completed", "statement_id": "uuid"}, + ) + == "uuid" + ) + mock_log_info.assert_called_with("%s completed successfully.", TASK_ID) diff --git a/tests/providers/amazon/aws/triggers/test_redshift_data.py b/tests/providers/amazon/aws/triggers/test_redshift_data.py new file mode 100644 index 0000000000..49c0862af2 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_redshift_data.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.providers.amazon.aws.hooks.redshift_data import ( + ABORTED_STATE, + FAILED_STATE, + RedshiftDataQueryAbortedError, + RedshiftDataQueryFailedError, +) +from airflow.providers.amazon.aws.triggers.redshift_data import RedshiftDataTrigger +from airflow.triggers.base import TriggerEvent + +TEST_CONN_ID = "aws_default" +TEST_TASK_ID = "123" +POLL_INTERVAL = 4.0 + + +class TestRedshiftDataTrigger: + def test_redshift_data_trigger_serialization(self): + """ + Asserts that the RedshiftDataTrigger correctly serializes its arguments + and classpath. 
+ """ + trigger = RedshiftDataTrigger( + statement_id=[], + task_id=TEST_TASK_ID, + aws_conn_id=TEST_CONN_ID, + poll_interval=POLL_INTERVAL, + ) + classpath, kwargs = trigger.serialize() + assert classpath == "airflow.providers.amazon.aws.triggers.redshift_data.RedshiftDataTrigger" + assert kwargs == { + "statement_id": [], + "task_id": TEST_TASK_ID, + "poll_interval": POLL_INTERVAL, + "aws_conn_id": TEST_CONN_ID, + "region_name": None, + "botocore_config": None, + "verify": None, + } + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "return_value, response", + [ + ( + True, + TriggerEvent({"status": "success", "statement_id": "uuid"}), + ), + ( + False, + TriggerEvent( + {"status": "error", "message": f"{TEST_TASK_ID} failed", "statement_id": "uuid"} + ), + ), + ], + ) + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async" + ) + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running", + return_value=False, + ) + async def test_redshift_data_trigger_run( + self, mocked_is_still_running, mock_check_query_is_finised_async, return_value, response + ): + """ + Tests that RedshiftDataTrigger only fires once the query execution reaches a successful state. + """ + mock_check_query_is_finised_async.return_value = return_value + trigger = RedshiftDataTrigger( + statement_id="uuid", + task_id=TEST_TASK_ID, + poll_interval=POLL_INTERVAL, + aws_conn_id=TEST_CONN_ID, + ) + generator = trigger.run() + actual = await generator.asend(None) + assert response == actual + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "raised_exception, expected_response", + [ + ( + RedshiftDataQueryFailedError("Failed"), + { + "status": "error", + "statement_id": "uuid", + "message": "Failed", + "type": FAILED_STATE, + }, + ), + ( + RedshiftDataQueryAbortedError("Aborted"), + { + "status": "error", + "statement_id": "uuid", + "message": "Aborted", + "type": ABORTED_STATE, + }, + ), + ( + Exception(f"{TEST_TASK_ID} failed"), + {"status": "error", "statement_id": "uuid", "message": f"{TEST_TASK_ID} failed"}, + ), + ], + ) + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.check_query_is_finished_async" + ) + @mock.patch( + "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.is_still_running", + return_value=False, + ) + async def test_redshift_data_trigger_exception( + self, mocked_is_still_running, mock_check_query_is_finised_async, raised_exception, expected_response + ): + """ + Test that RedshiftDataTrigger fires the correct event in case of an error. + """ + mock_check_query_is_finised_async.side_effect = raised_exception + + trigger = RedshiftDataTrigger( + statement_id="uuid", + task_id=TEST_TASK_ID, + poll_interval=POLL_INTERVAL, + aws_conn_id=TEST_CONN_ID, + ) + task = [i async for i in trigger.run()] + assert len(task) == 1 + assert TriggerEvent(expected_response) in task