This is an automated email from the ASF dual-hosted git repository. onikolas pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 72d09a677f Use a waiter in `AthenaHook` (#31942) 72d09a677f is described below commit 72d09a677fea22b51dbf20f3b12bae6b3c1e4792 Author: Raphaël Vandon <vand...@amazon.com> AuthorDate: Fri Jun 23 14:20:37 2023 -0700 Use a waiter in `AthenaHook` (#31942) * Use custom waiters for Emr Serverless operators Update unit tests --------- Co-authored-by: Syed Hussain <syeda...@amazon.com> Co-authored-by: Vincent <97131062+vincb...@users.noreply.github.com> --- airflow/providers/amazon/aws/hooks/athena.py | 80 ++++++++++------------ airflow/providers/amazon/aws/operators/athena.py | 5 +- airflow/providers/amazon/aws/sensors/athena.py | 4 +- airflow/providers/amazon/aws/waiters/athena.json | 30 ++++++++ tests/providers/amazon/aws/hooks/test_athena.py | 12 ++-- .../providers/amazon/aws/operators/test_athena.py | 63 ++--------------- 6 files changed, 81 insertions(+), 113 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py index f68eee9355..b0d1878507 100644 --- a/airflow/providers/amazon/aws/hooks/athena.py +++ b/airflow/providers/amazon/aws/hooks/athena.py @@ -24,12 +24,14 @@ This module contains AWS Athena hook. """ from __future__ import annotations -from time import sleep +import warnings from typing import Any from botocore.paginate import PageIterator +from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +from airflow.providers.amazon.aws.utils.waiter_with_logging import wait class AthenaHook(AwsBaseHook): @@ -38,8 +40,7 @@ class AthenaHook(AwsBaseHook): Provide thick wrapper around :external+boto3:py:class:`boto3.client("athena") <Athena.Client>`. - :param sleep_time: Time (in seconds) to wait between two consecutive calls - to check query status on Athena. 
+ :param sleep_time: obsolete, please use the parameter of `poll_query_status` method instead :param log_query: Whether to log athena query and other execution params when it's executed. Defaults to *True*. @@ -65,9 +66,20 @@ class AthenaHook(AwsBaseHook): "CANCELLED", ) - def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = True, **kwargs: Any) -> None: + def __init__( + self, *args: Any, sleep_time: int | None = None, log_query: bool = True, **kwargs: Any + ) -> None: super().__init__(client_type="athena", *args, **kwargs) # type: ignore - self.sleep_time = sleep_time + if sleep_time is not None: + self.sleep_time = sleep_time + warnings.warn( + "The `sleep_time` parameter of the Athena hook is deprecated, " + "please pass this parameter to the poll_query_status method instead.", + AirflowProviderDeprecationWarning, + stacklevel=2, + ) + else: + self.sleep_time = 30 # previous default value self.log_query = log_query def run_query( @@ -229,51 +241,31 @@ class AthenaHook(AwsBaseHook): return paginator.paginate(**result_params) def poll_query_status( - self, - query_execution_id: str, - max_polling_attempts: int | None = None, + self, query_execution_id: str, max_polling_attempts: int | None = None, sleep_time: int | None = None ) -> str | None: """Poll the state of a submitted query until it reaches final state. :param query_execution_id: ID of submitted athena query - :param max_polling_attempts: Number of times to poll for query state - before function exits + :param max_polling_attempts: Number of times to poll for query state before function exits + :param sleep_time: Time (in seconds) to wait between two consecutive query status checks. 
:return: One of the final states """ - try_number = 1 - final_query_state = None # Query state when query reaches final state or max_polling_attempts reached - while True: - query_state = self.check_query_status(query_execution_id) - if query_state is None: - self.log.info( - "Query execution id: %s, trial %s: Invalid query state. Retrying again", - query_execution_id, - try_number, - ) - elif query_state in self.TERMINAL_STATES: - self.log.info( - "Query execution id: %s, trial %s: Query execution completed. Final state is %s", - query_execution_id, - try_number, - query_state, - ) - final_query_state = query_state - break - else: - self.log.info( - "Query execution id: %s, trial %s: Query is still in non-terminal state - %s", - query_execution_id, - try_number, - query_state, - ) - if ( - max_polling_attempts and try_number >= max_polling_attempts - ): # Break loop if max_polling_attempts reached - final_query_state = query_state - break - try_number += 1 - sleep(self.sleep_time) - return final_query_state + try: + wait( + waiter=self.get_waiter("query_complete"), + waiter_delay=sleep_time or self.sleep_time, + max_attempts=max_polling_attempts or 120, + args={"QueryExecutionId": query_execution_id}, + failure_message=f"Error while waiting for query {query_execution_id} to complete", + status_message=f"Query execution id: {query_execution_id}, " + f"Query is still in non-terminal state", + status_args=["QueryExecution.Status.State"], + ) + except AirflowException as error: + # this function does not raise errors to keep previous behavior. + self.log.warning(error) + finally: + return self.check_query_status(query_execution_id) def get_output_location(self, query_execution_id: str) -> str: """Get the output location of the query results in S3 URI format. 
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py index 1bd1a97be2..612e563ce6 100644 --- a/airflow/providers/amazon/aws/operators/athena.py +++ b/airflow/providers/amazon/aws/operators/athena.py @@ -88,7 +88,7 @@ class AthenaOperator(BaseOperator): @cached_property def hook(self) -> AthenaHook: """Create and return an AthenaHook.""" - return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time, log_query=self.log_query) + return AthenaHook(self.aws_conn_id, log_query=self.log_query) def execute(self, context: Context) -> str | None: """Run Presto Query on Athena.""" @@ -104,6 +104,7 @@ class AthenaOperator(BaseOperator): query_status = self.hook.poll_query_status( self.query_execution_id, max_polling_attempts=self.max_polling_attempts, + sleep_time=self.sleep_time, ) if query_status in AthenaHook.FAILURE_STATES: @@ -139,4 +140,4 @@ class AthenaOperator(BaseOperator): self.log.info( "Polling Athena for query with id %s to reach final state", self.query_execution_id ) - self.hook.poll_query_status(self.query_execution_id) + self.hook.poll_query_status(self.query_execution_id, sleep_time=self.sleep_time) diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py index f67fb3ff9a..599341092e 100644 --- a/airflow/providers/amazon/aws/sensors/athena.py +++ b/airflow/providers/amazon/aws/sensors/athena.py @@ -76,7 +76,7 @@ class AthenaSensor(BaseSensorOperator): self.max_retries = max_retries def poke(self, context: Context) -> bool: - state = self.hook.poll_query_status(self.query_execution_id, self.max_retries) + state = self.hook.poll_query_status(self.query_execution_id, self.max_retries, self.sleep_time) if state in self.FAILURE_STATES: raise AirflowException("Athena sensor failed") @@ -88,4 +88,4 @@ class AthenaSensor(BaseSensorOperator): @cached_property def hook(self) -> AthenaHook: """Create and return an AthenaHook.""" - return 
AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time) + return AthenaHook(self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/waiters/athena.json b/airflow/providers/amazon/aws/waiters/athena.json new file mode 100644 index 0000000000..db68ce32f4 --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/athena.json @@ -0,0 +1,30 @@ +{ + "version": 2, + "waiters": { + "query_complete": { + "operation": "GetQueryExecution", + "delay": 30, + "maxAttempts": 120, + "acceptors": [ + { + "expected": "SUCCEEDED", + "matcher": "path", + "state": "success", + "argument": "QueryExecution.Status.State" + }, + { + "expected": "FAILED", + "matcher": "path", + "state": "failure", + "argument": "QueryExecution.Status.State" + }, + { + "expected": "CANCELLED", + "matcher": "path", + "state": "failure", + "argument": "QueryExecution.Status.State" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py index a65470acea..05ed6e9e30 100644 --- a/tests/providers/amazon/aws/hooks/test_athena.py +++ b/tests/providers/amazon/aws/hooks/test_athena.py @@ -49,11 +49,10 @@ MOCK_QUERY_EXECUTION_OUTPUT = { class TestAthenaHook: def setup_method(self): - self.athena = AthenaHook(sleep_time=0) + self.athena = AthenaHook() def test_init(self): assert self.athena.aws_conn_id == "aws_default" - assert self.athena.sleep_time == 0 @mock.patch.object(AthenaHook, "get_conn") def test_hook_run_query_without_token(self, mock_conn): @@ -104,7 +103,7 @@ class TestAthenaHook: @mock.patch.object(AthenaHook, "log") @mock.patch.object(AthenaHook, "get_conn") def test_hook_run_query_no_log_query(self, mock_conn, log): - athena_hook_no_log_query = AthenaHook(sleep_time=0, log_query=False) + athena_hook_no_log_query = AthenaHook(log_query=False) athena_hook_no_log_query.run_query( query=MOCK_DATA["query"], query_context=mock_query_context, @@ -176,7 +175,9 @@ class TestAthenaHook: @mock.patch.object(AthenaHook, "get_conn") def 
test_hook_poll_query_when_final(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_SUCCEEDED_QUERY_EXECUTION - result = self.athena.poll_query_status(query_execution_id=MOCK_DATA["query_execution_id"]) + result = self.athena.poll_query_status( + query_execution_id=MOCK_DATA["query_execution_id"], sleep_time=0 + ) mock_conn.return_value.get_query_execution.assert_called_once() assert result == "SUCCEEDED" @@ -184,8 +185,7 @@ class TestAthenaHook: def test_hook_poll_query_with_timeout(self, mock_conn): mock_conn.return_value.get_query_execution.return_value = MOCK_RUNNING_QUERY_EXECUTION result = self.athena.poll_query_status( - query_execution_id=MOCK_DATA["query_execution_id"], - max_polling_attempts=1, + query_execution_id=MOCK_DATA["query_execution_id"], max_polling_attempts=1, sleep_time=0 ) mock_conn.return_value.get_query_execution.assert_called_once() assert result == "RUNNING" diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py index e7b945d2a4..cfc7869768 100644 --- a/tests/providers/amazon/aws/operators/test_athena.py +++ b/tests/providers/amazon/aws/operators/test_athena.py @@ -71,8 +71,6 @@ class TestAthenaOperator: assert self.athena.client_request_token == MOCK_DATA["client_request_token"] assert self.athena.sleep_time == 0 - assert self.athena.hook.sleep_time == 0 - @mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",)) @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn") @@ -90,11 +88,7 @@ class TestAthenaOperator: @mock.patch.object( AthenaHook, "check_query_status", - side_effect=( - "RUNNING", - "RUNNING", - "SUCCEEDED", - ), + side_effect="SUCCEEDED", ) @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn") @@ -107,39 +101,9 @@ class TestAthenaOperator: MOCK_DATA["client_request_token"], 
MOCK_DATA["workgroup"], ) - assert mock_check_query_status.call_count == 3 - - @mock.patch.object( - AthenaHook, - "check_query_status", - side_effect=( - None, - None, - ), - ) - @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) - @mock.patch.object(AthenaHook, "get_conn") - def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, mock_check_query_status): - with pytest.raises(Exception): - self.athena.execute({}) - mock_run_query.assert_called_once_with( - MOCK_DATA["query"], - query_context, - result_configuration, - MOCK_DATA["client_request_token"], - MOCK_DATA["workgroup"], - ) - assert mock_check_query_status.call_count == 3 @mock.patch.object(AthenaHook, "get_state_change_reason") - @mock.patch.object( - AthenaHook, - "check_query_status", - side_effect=( - "RUNNING", - "FAILED", - ), - ) + @mock.patch.object(AthenaHook, "check_query_status", return_value="FAILED") @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn") def test_hook_run_failure_query( @@ -154,18 +118,9 @@ class TestAthenaOperator: MOCK_DATA["client_request_token"], MOCK_DATA["workgroup"], ) - assert mock_check_query_status.call_count == 2 assert mock_get_state_change_reason.call_count == 1 - @mock.patch.object( - AthenaHook, - "check_query_status", - side_effect=( - "RUNNING", - "RUNNING", - "CANCELLED", - ), - ) + @mock.patch.object(AthenaHook, "check_query_status", return_value="CANCELLED") @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn") def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, mock_check_query_status): @@ -178,17 +133,8 @@ class TestAthenaOperator: MOCK_DATA["client_request_token"], MOCK_DATA["workgroup"], ) - assert mock_check_query_status.call_count == 3 - @mock.patch.object( - AthenaHook, - "check_query_status", - side_effect=( - "RUNNING", - "RUNNING", - "RUNNING", - ), - ) + 
@mock.patch.object(AthenaHook, "check_query_status", return_value="RUNNING") @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID) @mock.patch.object(AthenaHook, "get_conn") def test_hook_run_failed_query_with_max_tries(self, mock_conn, mock_run_query, mock_check_query_status): @@ -201,7 +147,6 @@ class TestAthenaOperator: MOCK_DATA["client_request_token"], MOCK_DATA["workgroup"], ) - assert mock_check_query_status.call_count == 3 @mock.patch.object(AthenaHook, "check_query_status", side_effect=("SUCCEEDED",)) @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)