This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2063e141e4 Handle transient state errors in 
`RedshiftResumeClusterOperator` and `RedshiftPauseClusterOperator` (#27276)
2063e141e4 is described below

commit 2063e141e445d2567149154cdf90955a941e45b4
Author: Syed Hussaain <[email protected]>
AuthorDate: Thu Nov 17 00:16:25 2022 -0800

    Handle transient state errors in `RedshiftResumeClusterOperator` and 
`RedshiftPauseClusterOperator` (#27276)
    
    * Modify RedshiftPauseClusterOperator and RedshiftResumeClusterOperator to 
attempt to pause and resume multiple times to avoid edge cases of state changes
---
 .../providers/amazon/aws/hooks/redshift_cluster.py | 15 +++-
 .../amazon/aws/operators/redshift_cluster.py       | 49 ++++++++----
 .../amazon/aws/operators/test_redshift_cluster.py  | 86 ++++++++++++++++------
 .../providers/amazon/aws/example_redshift.py       | 35 ++++-----
 4 files changed, 128 insertions(+), 57 deletions(-)

diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py 
b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 43d7993af7..d85929d062 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -16,6 +16,7 @@
 # under the License.
 from __future__ import annotations
 
+import warnings
 from typing import Any, Sequence
 
 from botocore.exceptions import ClientError
@@ -157,16 +158,24 @@ class RedshiftHook(AwsBaseHook):
         )
         return response["Snapshot"] if response["Snapshot"] else None
 
-    def get_cluster_snapshot_status(self, snapshot_identifier: str, 
cluster_identifier: str):
+    def get_cluster_snapshot_status(self, snapshot_identifier: str, 
cluster_identifier: str | None = None):
         """
         Return Redshift cluster snapshot status. If cluster snapshot not found 
return ``None``
 
         :param snapshot_identifier: A unique identifier for the snapshot that 
you are requesting
-        :param cluster_identifier: The unique identifier of the cluster the 
snapshot was created from
+        :param cluster_identifier: (deprecated) The unique identifier of the 
cluster
+            the snapshot was created from
         """
+        if cluster_identifier:
+            warnings.warn(
+                "Parameter `cluster_identifier` is deprecated. "
+                "This option will be removed in a future version.",
+                DeprecationWarning,
+                stacklevel=2,
+            )
+
         try:
             response = self.get_conn().describe_cluster_snapshots(
-                ClusterIdentifier=cluster_identifier,
                 SnapshotIdentifier=snapshot_identifier,
             )
             snapshot = response.get("Snapshots")[0]
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py 
b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 39515b7137..2da0fbf23a 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -367,7 +367,6 @@ class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
     def get_status(self) -> str:
         return self.redshift_hook.get_cluster_snapshot_status(
             snapshot_identifier=self.snapshot_identifier,
-            cluster_identifier=self.cluster_identifier,
         )
 
 
@@ -397,15 +396,27 @@ class RedshiftResumeClusterOperator(BaseOperator):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        # These parameters are added to address an issue with the boto3 API 
where the API
+        # prematurely reports the cluster as available to receive requests. 
This causes the cluster
+        # to reject initial attempts to resume the cluster despite reporting 
the correct state.
+        self._attempts = 10
+        self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
-        cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-        if cluster_state == "paused":
-            self.log.info("Starting Redshift cluster %s", 
self.cluster_identifier)
-            
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
-        else:
-            raise Exception(f"Unable to resume cluster - cluster state is 
{cluster_state}")
+
+        while self._attempts >= 1:
+            try:
+                
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+                return
+            except 
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._attempts = self._attempts - 1
+
+                if self._attempts > 0:
+                    self.log.error("Unable to resume cluster. %d attempts 
remaining.", self._attempts)
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
 
 
 class RedshiftPauseClusterOperator(BaseOperator):
@@ -434,15 +445,27 @@ class RedshiftPauseClusterOperator(BaseOperator):
         super().__init__(**kwargs)
         self.cluster_identifier = cluster_identifier
         self.aws_conn_id = aws_conn_id
+        # These parameters are added to address an issue with the boto3 API 
where the API
+        # prematurely reports the cluster as available to receive requests. 
This causes the cluster
+        # to reject initial attempts to pause the cluster despite reporting 
the correct state.
+        self._attempts = 10
+        self._attempt_interval = 15
 
     def execute(self, context: Context):
         redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
-        cluster_state = 
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
-        if cluster_state == "available":
-            self.log.info("Pausing Redshift cluster %s", 
self.cluster_identifier)
-            
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
-        else:
-            raise Exception(f"Unable to pause cluster - cluster state is 
{cluster_state}")
+
+        while self._attempts >= 1:
+            try:
+                
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+                return
+            except 
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+                self._attempts = self._attempts - 1
+
+                if self._attempts > 0:
+                    self.log.error("Unable to pause cluster. %d attempts 
remaining.", self._attempts)
+                    time.sleep(self._attempt_interval)
+                else:
+                    raise error
 
 
 class RedshiftDeleteClusterOperator(BaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py 
b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index c10daa9a70..5a9322c9c9 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -18,6 +18,7 @@ from __future__ import annotations
 
 from unittest import mock
 
+import boto3
 import pytest
 
 from airflow.exceptions import AirflowException
@@ -172,7 +173,6 @@ class TestRedshiftDeleteClusterSnapshotOperator:
         )
 
         mock_get_cluster_snapshot_status.assert_called_once_with(
-            cluster_identifier="test_cluster",
             snapshot_identifier="test_snapshot",
         )
 
@@ -205,26 +205,47 @@ class TestResumeClusterOperator:
         assert redshift_operator.cluster_identifier == "test_cluster"
         assert redshift_operator.aws_conn_id == "aws_conn_test"
 
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_resume_cluster_is_called_when_cluster_is_paused(self, 
mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "paused"
+    def test_resume_cluster_is_called_when_cluster_is_paused(self, 
mock_get_conn):
         redshift_operator = RedshiftResumeClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", 
aws_conn_id="aws_conn_test"
         )
         redshift_operator.execute(None)
         
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
 
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_resume_cluster_not_called_when_cluster_is_not_paused(self, 
mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "available"
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+        exception = 
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.resume_cluster.side_effect = [exception, exception, True]
         redshift_operator = RedshiftResumeClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", 
aws_conn_id="aws_conn_test"
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
         )
-        with pytest.raises(Exception):
+        redshift_operator.execute(None)
+        assert mock_conn.resume_cluster.call_count == 3
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_resume_cluster_multiple_attempts_fail(self, mock_sleep, 
mock_conn):
+        exception = 
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.resume_cluster.side_effect = exception
+
+        redshift_operator = RedshiftResumeClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+        with pytest.raises(returned_exception):
             redshift_operator.execute(None)
-        mock_get_conn.return_value.resume_cluster.assert_not_called()
+        assert mock_conn.resume_cluster.call_count == 10
 
 
 class TestPauseClusterOperator:
@@ -236,26 +257,49 @@ class TestPauseClusterOperator:
         assert redshift_operator.cluster_identifier == "test_cluster"
         assert redshift_operator.aws_conn_id == "aws_conn_test"
 
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
     
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_pause_cluster_is_called_when_cluster_is_available(self, 
mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "available"
+    def test_pause_cluster_is_called_when_cluster_is_available(self, 
mock_get_conn):
         redshift_operator = RedshiftPauseClusterOperator(
             task_id="task_test", cluster_identifier="test_cluster", 
aws_conn_id="aws_conn_test"
         )
         redshift_operator.execute(None)
         
mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
 
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
-    def test_pause_cluster_not_called_when_cluster_is_not_available(self, 
mock_get_conn, mock_cluster_status):
-        mock_cluster_status.return_value = "paused"
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+        exception = 
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.pause_cluster.side_effect = [exception, exception, True]
+
         redshift_operator = RedshiftPauseClusterOperator(
-            task_id="task_test", cluster_identifier="test_cluster", 
aws_conn_id="aws_conn_test"
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
+        )
+
+        redshift_operator.execute(None)
+        assert mock_conn.pause_cluster.call_count == 3
+
+    
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+    @mock.patch("time.sleep", return_value=None)
+    def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
+        exception = 
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+        returned_exception = type(exception)
+
+        mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+        mock_conn.pause_cluster.side_effect = exception
+
+        redshift_operator = RedshiftPauseClusterOperator(
+            task_id="task_test",
+            cluster_identifier="test_cluster",
+            aws_conn_id="aws_conn_test",
         )
-        with pytest.raises(Exception):
+        with pytest.raises(returned_exception):
             redshift_operator.execute(None)
-        mock_get_conn.return_value.pause_cluster.assert_not_called()
+        assert mock_conn.pause_cluster.call_count == 10
 
 
 class TestDeleteClusterOperator:
diff --git a/tests/system/providers/amazon/aws/example_redshift.py 
b/tests/system/providers/amazon/aws/example_redshift.py
index 0b8b573c89..5a355f15f4 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -26,7 +26,6 @@ from airflow import DAG, settings
 from airflow.decorators import task
 from airflow.models import Connection
 from airflow.models.baseoperator import chain
-from airflow.operators.python import get_current_context
 from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
 from airflow.providers.amazon.aws.operators.redshift_cluster import (
     RedshiftCreateClusterOperator,
@@ -90,8 +89,7 @@ def setup_security_group(sec_group_name: str, ip_permissions: 
list[dict]):
     client.authorize_security_group_ingress(
         GroupId=security_group["GroupId"], GroupName=sec_group_name, 
IpPermissions=ip_permissions
     )
-    ti = get_current_context()["ti"]
-    ti.xcom_push(key="security_group_id", value=security_group["GroupId"])
+    return security_group["GroupId"]
 
 
 @task(trigger_rule=TriggerRule.ALL_DONE)
@@ -99,17 +97,6 @@ def delete_security_group(sec_group_id: str, sec_group_name: 
str):
     boto3.client("ec2").delete_security_group(GroupId=sec_group_id, 
GroupName=sec_group_name)
 
 
-@task
-def await_cluster_snapshot(cluster_identifier):
-    waiter = boto3.client("redshift").get_waiter("snapshot_available")
-    waiter.wait(
-        ClusterIdentifier=cluster_identifier,
-        WaiterConfig={
-            "MaxAttempts": 100,
-        },
-    )
-
-
 with DAG(
     dag_id=DAG_ID,
     start_date=datetime(2021, 1, 1),
@@ -130,7 +117,7 @@ with DAG(
     create_cluster = RedshiftCreateClusterOperator(
         task_id="create_cluster",
         cluster_identifier=redshift_cluster_identifier,
-        vpc_security_group_ids=[set_up_sg["security_group_id"]],
+        vpc_security_group_ids=[set_up_sg],
         publicly_accessible=True,
         cluster_type="single-node",
         node_type="dc2.large",
@@ -145,7 +132,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         target_status="available",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
     # [END howto_sensor_redshift_cluster]
 
@@ -161,6 +148,14 @@ with DAG(
     )
     # [END howto_operator_redshift_create_cluster_snapshot]
 
+    wait_cluster_available_before_pause = RedshiftClusterSensor(
+        task_id="wait_cluster_available_before_pause",
+        cluster_identifier=redshift_cluster_identifier,
+        target_status="available",
+        poke_interval=15,
+        timeout=60 * 15,
+    )
+
     # [START howto_operator_redshift_pause_cluster]
     pause_cluster = RedshiftPauseClusterOperator(
         task_id="pause_cluster",
@@ -173,7 +168,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         target_status="paused",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
 
     # [START howto_operator_redshift_resume_cluster]
@@ -188,7 +183,7 @@ with DAG(
         cluster_identifier=redshift_cluster_identifier,
         target_status="available",
         poke_interval=15,
-        timeout=60 * 30,
+        timeout=60 * 15,
     )
 
     set_up_connection = create_connection(conn_id_name, 
cluster_id=redshift_cluster_identifier)
@@ -269,7 +264,7 @@ with DAG(
     # [END howto_operator_redshift_delete_cluster_snapshot]
 
     delete_sg = delete_security_group(
-        sec_group_id=set_up_sg["security_group_id"],
+        sec_group_id=set_up_sg,
         sec_group_name=sg_name,
     )
     chain(
@@ -280,7 +275,7 @@ with DAG(
         create_cluster,
         wait_cluster_available,
         create_cluster_snapshot,
-        await_cluster_snapshot(redshift_cluster_identifier),
+        wait_cluster_available_before_pause,
         pause_cluster,
         wait_cluster_paused,
         resume_cluster,

Reply via email to