This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 72d09a677f Use a waiter in `AthenaHook` (#31942)
72d09a677f is described below

commit 72d09a677fea22b51dbf20f3b12bae6b3c1e4792
Author: Raphaël Vandon <vand...@amazon.com>
AuthorDate: Fri Jun 23 14:20:37 2023 -0700

    Use a waiter in `AthenaHook` (#31942)
    
    * Use custom waiters for Emr Serverless operators
    Update unit tests
    ---------
    
    Co-authored-by: Syed Hussain <syeda...@amazon.com>
    Co-authored-by: Vincent <97131062+vincb...@users.noreply.github.com>
---
 airflow/providers/amazon/aws/hooks/athena.py       | 80 ++++++++++------------
 airflow/providers/amazon/aws/operators/athena.py   |  5 +-
 airflow/providers/amazon/aws/sensors/athena.py     |  4 +-
 airflow/providers/amazon/aws/waiters/athena.json   | 30 ++++++++
 tests/providers/amazon/aws/hooks/test_athena.py    | 12 ++--
 .../providers/amazon/aws/operators/test_athena.py  | 63 ++---------------
 6 files changed, 81 insertions(+), 113 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/athena.py b/airflow/providers/amazon/aws/hooks/athena.py
index f68eee9355..b0d1878507 100644
--- a/airflow/providers/amazon/aws/hooks/athena.py
+++ b/airflow/providers/amazon/aws/hooks/athena.py
@@ -24,12 +24,14 @@ This module contains AWS Athena hook.
 """
 from __future__ import annotations
 
-from time import sleep
+import warnings
 from typing import Any
 
 from botocore.paginate import PageIterator
 
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
 
 
 class AthenaHook(AwsBaseHook):
@@ -38,8 +40,7 @@ class AthenaHook(AwsBaseHook):
     Provide thick wrapper around
     :external+boto3:py:class:`boto3.client("athena") <Athena.Client>`.
 
-    :param sleep_time: Time (in seconds) to wait between two consecutive calls
-        to check query status on Athena.
+    :param sleep_time: obsolete, please use the parameter of 
`poll_query_status` method instead
     :param log_query: Whether to log athena query and other execution params
         when it's executed. Defaults to *True*.
 
@@ -65,9 +66,20 @@ class AthenaHook(AwsBaseHook):
         "CANCELLED",
     )
 
-    def __init__(self, *args: Any, sleep_time: int = 30, log_query: bool = 
True, **kwargs: Any) -> None:
+    def __init__(
+        self, *args: Any, sleep_time: int | None = None, log_query: bool = 
True, **kwargs: Any
+    ) -> None:
         super().__init__(client_type="athena", *args, **kwargs)  # type: ignore
-        self.sleep_time = sleep_time
+        if sleep_time is not None:
+            self.sleep_time = sleep_time
+            warnings.warn(
+                "The `sleep_time` parameter of the Athena hook is deprecated, "
+                "please pass this parameter to the poll_query_status method 
instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+        else:
+            self.sleep_time = 30  # previous default value
         self.log_query = log_query
 
     def run_query(
@@ -229,51 +241,31 @@ class AthenaHook(AwsBaseHook):
         return paginator.paginate(**result_params)
 
     def poll_query_status(
-        self,
-        query_execution_id: str,
-        max_polling_attempts: int | None = None,
+        self, query_execution_id: str, max_polling_attempts: int | None = 
None, sleep_time: int | None = None
     ) -> str | None:
         """Poll the state of a submitted query until it reaches final state.
 
         :param query_execution_id: ID of submitted athena query
-        :param max_polling_attempts: Number of times to poll for query state
-            before function exits
+        :param max_polling_attempts: Number of times to poll for query state 
before function exits
+        :param sleep_time: Time (in seconds) to wait between two consecutive 
query status checks.
         :return: One of the final states
         """
-        try_number = 1
-        final_query_state = None  # Query state when query reaches final state 
or max_polling_attempts reached
-        while True:
-            query_state = self.check_query_status(query_execution_id)
-            if query_state is None:
-                self.log.info(
-                    "Query execution id: %s, trial %s: Invalid query state. 
Retrying again",
-                    query_execution_id,
-                    try_number,
-                )
-            elif query_state in self.TERMINAL_STATES:
-                self.log.info(
-                    "Query execution id: %s, trial %s: Query execution 
completed. Final state is %s",
-                    query_execution_id,
-                    try_number,
-                    query_state,
-                )
-                final_query_state = query_state
-                break
-            else:
-                self.log.info(
-                    "Query execution id: %s, trial %s: Query is still in 
non-terminal state - %s",
-                    query_execution_id,
-                    try_number,
-                    query_state,
-                )
-            if (
-                max_polling_attempts and try_number >= max_polling_attempts
-            ):  # Break loop if max_polling_attempts reached
-                final_query_state = query_state
-                break
-            try_number += 1
-            sleep(self.sleep_time)
-        return final_query_state
+        try:
+            wait(
+                waiter=self.get_waiter("query_complete"),
+                waiter_delay=sleep_time or self.sleep_time,
+                max_attempts=max_polling_attempts or 120,
+                args={"QueryExecutionId": query_execution_id},
+                failure_message=f"Error while waiting for query 
{query_execution_id} to complete",
+                status_message=f"Query execution id: {query_execution_id}, "
+                f"Query is still in non-terminal state",
+                status_args=["QueryExecution.Status.State"],
+            )
+        except AirflowException as error:
+            # this function does not raise errors to keep previous behavior.
+            self.log.warning(error)
+        finally:
+            return self.check_query_status(query_execution_id)
 
     def get_output_location(self, query_execution_id: str) -> str:
         """Get the output location of the query results in S3 URI format.
diff --git a/airflow/providers/amazon/aws/operators/athena.py b/airflow/providers/amazon/aws/operators/athena.py
index 1bd1a97be2..612e563ce6 100644
--- a/airflow/providers/amazon/aws/operators/athena.py
+++ b/airflow/providers/amazon/aws/operators/athena.py
@@ -88,7 +88,7 @@ class AthenaOperator(BaseOperator):
     @cached_property
     def hook(self) -> AthenaHook:
         """Create and return an AthenaHook."""
-        return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time, 
log_query=self.log_query)
+        return AthenaHook(self.aws_conn_id, log_query=self.log_query)
 
     def execute(self, context: Context) -> str | None:
         """Run Presto Query on Athena."""
@@ -104,6 +104,7 @@ class AthenaOperator(BaseOperator):
         query_status = self.hook.poll_query_status(
             self.query_execution_id,
             max_polling_attempts=self.max_polling_attempts,
+            sleep_time=self.sleep_time,
         )
 
         if query_status in AthenaHook.FAILURE_STATES:
@@ -139,4 +140,4 @@ class AthenaOperator(BaseOperator):
                     self.log.info(
                         "Polling Athena for query with id %s to reach final 
state", self.query_execution_id
                     )
-                    self.hook.poll_query_status(self.query_execution_id)
+                    self.hook.poll_query_status(self.query_execution_id, 
sleep_time=self.sleep_time)
diff --git a/airflow/providers/amazon/aws/sensors/athena.py b/airflow/providers/amazon/aws/sensors/athena.py
index f67fb3ff9a..599341092e 100644
--- a/airflow/providers/amazon/aws/sensors/athena.py
+++ b/airflow/providers/amazon/aws/sensors/athena.py
@@ -76,7 +76,7 @@ class AthenaSensor(BaseSensorOperator):
         self.max_retries = max_retries
 
     def poke(self, context: Context) -> bool:
-        state = self.hook.poll_query_status(self.query_execution_id, 
self.max_retries)
+        state = self.hook.poll_query_status(self.query_execution_id, 
self.max_retries, self.sleep_time)
 
         if state in self.FAILURE_STATES:
             raise AirflowException("Athena sensor failed")
@@ -88,4 +88,4 @@ class AthenaSensor(BaseSensorOperator):
     @cached_property
     def hook(self) -> AthenaHook:
         """Create and return an AthenaHook."""
-        return AthenaHook(self.aws_conn_id, sleep_time=self.sleep_time)
+        return AthenaHook(self.aws_conn_id)
diff --git a/airflow/providers/amazon/aws/waiters/athena.json b/airflow/providers/amazon/aws/waiters/athena.json
new file mode 100644
index 0000000000..db68ce32f4
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/athena.json
@@ -0,0 +1,30 @@
+{
+    "version": 2,
+    "waiters": {
+        "query_complete": {
+            "operation": "GetQueryExecution",
+            "delay": 30,
+            "maxAttempts": 120,
+            "acceptors": [
+                {
+                    "expected": "SUCCEEDED",
+                    "matcher": "path",
+                    "state": "success",
+                    "argument": "QueryExecution.Status.State"
+                },
+                {
+                    "expected": "FAILED",
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": "QueryExecution.Status.State"
+                },
+                {
+                    "expected": "CANCELLED",
+                    "matcher": "path",
+                    "state": "failure",
+                    "argument": "QueryExecution.Status.State"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/hooks/test_athena.py b/tests/providers/amazon/aws/hooks/test_athena.py
index a65470acea..05ed6e9e30 100644
--- a/tests/providers/amazon/aws/hooks/test_athena.py
+++ b/tests/providers/amazon/aws/hooks/test_athena.py
@@ -49,11 +49,10 @@ MOCK_QUERY_EXECUTION_OUTPUT = {
 
 class TestAthenaHook:
     def setup_method(self):
-        self.athena = AthenaHook(sleep_time=0)
+        self.athena = AthenaHook()
 
     def test_init(self):
         assert self.athena.aws_conn_id == "aws_default"
-        assert self.athena.sleep_time == 0
 
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_query_without_token(self, mock_conn):
@@ -104,7 +103,7 @@ class TestAthenaHook:
     @mock.patch.object(AthenaHook, "log")
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_query_no_log_query(self, mock_conn, log):
-        athena_hook_no_log_query = AthenaHook(sleep_time=0, log_query=False)
+        athena_hook_no_log_query = AthenaHook(log_query=False)
         athena_hook_no_log_query.run_query(
             query=MOCK_DATA["query"],
             query_context=mock_query_context,
@@ -176,7 +175,9 @@ class TestAthenaHook:
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_poll_query_when_final(self, mock_conn):
         mock_conn.return_value.get_query_execution.return_value = 
MOCK_SUCCEEDED_QUERY_EXECUTION
-        result = 
self.athena.poll_query_status(query_execution_id=MOCK_DATA["query_execution_id"])
+        result = self.athena.poll_query_status(
+            query_execution_id=MOCK_DATA["query_execution_id"], sleep_time=0
+        )
         mock_conn.return_value.get_query_execution.assert_called_once()
         assert result == "SUCCEEDED"
 
@@ -184,8 +185,7 @@ class TestAthenaHook:
     def test_hook_poll_query_with_timeout(self, mock_conn):
         mock_conn.return_value.get_query_execution.return_value = 
MOCK_RUNNING_QUERY_EXECUTION
         result = self.athena.poll_query_status(
-            query_execution_id=MOCK_DATA["query_execution_id"],
-            max_polling_attempts=1,
+            query_execution_id=MOCK_DATA["query_execution_id"], 
max_polling_attempts=1, sleep_time=0
         )
         mock_conn.return_value.get_query_execution.assert_called_once()
         assert result == "RUNNING"
diff --git a/tests/providers/amazon/aws/operators/test_athena.py b/tests/providers/amazon/aws/operators/test_athena.py
index e7b945d2a4..cfc7869768 100644
--- a/tests/providers/amazon/aws/operators/test_athena.py
+++ b/tests/providers/amazon/aws/operators/test_athena.py
@@ -71,8 +71,6 @@ class TestAthenaOperator:
         assert self.athena.client_request_token == 
MOCK_DATA["client_request_token"]
         assert self.athena.sleep_time == 0
 
-        assert self.athena.hook.sleep_time == 0
-
     @mock.patch.object(AthenaHook, "check_query_status", 
side_effect=("SUCCEEDED",))
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
@@ -90,11 +88,7 @@ class TestAthenaOperator:
     @mock.patch.object(
         AthenaHook,
         "check_query_status",
-        side_effect=(
-            "RUNNING",
-            "RUNNING",
-            "SUCCEEDED",
-        ),
+        side_effect="SUCCEEDED",
     )
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
@@ -107,39 +101,9 @@ class TestAthenaOperator:
             MOCK_DATA["client_request_token"],
             MOCK_DATA["workgroup"],
         )
-        assert mock_check_query_status.call_count == 3
-
-    @mock.patch.object(
-        AthenaHook,
-        "check_query_status",
-        side_effect=(
-            None,
-            None,
-        ),
-    )
-    @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
-    @mock.patch.object(AthenaHook, "get_conn")
-    def test_hook_run_failed_query_with_none(self, mock_conn, mock_run_query, 
mock_check_query_status):
-        with pytest.raises(Exception):
-            self.athena.execute({})
-        mock_run_query.assert_called_once_with(
-            MOCK_DATA["query"],
-            query_context,
-            result_configuration,
-            MOCK_DATA["client_request_token"],
-            MOCK_DATA["workgroup"],
-        )
-        assert mock_check_query_status.call_count == 3
 
     @mock.patch.object(AthenaHook, "get_state_change_reason")
-    @mock.patch.object(
-        AthenaHook,
-        "check_query_status",
-        side_effect=(
-            "RUNNING",
-            "FAILED",
-        ),
-    )
+    @mock.patch.object(AthenaHook, "check_query_status", return_value="FAILED")
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_failure_query(
@@ -154,18 +118,9 @@ class TestAthenaOperator:
             MOCK_DATA["client_request_token"],
             MOCK_DATA["workgroup"],
         )
-        assert mock_check_query_status.call_count == 2
         assert mock_get_state_change_reason.call_count == 1
 
-    @mock.patch.object(
-        AthenaHook,
-        "check_query_status",
-        side_effect=(
-            "RUNNING",
-            "RUNNING",
-            "CANCELLED",
-        ),
-    )
+    @mock.patch.object(AthenaHook, "check_query_status", 
return_value="CANCELLED")
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_cancelled_query(self, mock_conn, mock_run_query, 
mock_check_query_status):
@@ -178,17 +133,8 @@ class TestAthenaOperator:
             MOCK_DATA["client_request_token"],
             MOCK_DATA["workgroup"],
         )
-        assert mock_check_query_status.call_count == 3
 
-    @mock.patch.object(
-        AthenaHook,
-        "check_query_status",
-        side_effect=(
-            "RUNNING",
-            "RUNNING",
-            "RUNNING",
-        ),
-    )
+    @mock.patch.object(AthenaHook, "check_query_status", 
return_value="RUNNING")
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)
     @mock.patch.object(AthenaHook, "get_conn")
     def test_hook_run_failed_query_with_max_tries(self, mock_conn, 
mock_run_query, mock_check_query_status):
@@ -201,7 +147,6 @@ class TestAthenaOperator:
             MOCK_DATA["client_request_token"],
             MOCK_DATA["workgroup"],
         )
-        assert mock_check_query_status.call_count == 3
 
     @mock.patch.object(AthenaHook, "check_query_status", 
side_effect=("SUCCEEDED",))
     @mock.patch.object(AthenaHook, "run_query", return_value=ATHENA_QUERY_ID)

Reply via email to