This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2063e141e4 Handle transient state errors in
`RedshiftResumeClusterOperator` and `RedshiftPauseClusterOperator` (#27276)
2063e141e4 is described below
commit 2063e141e445d2567149154cdf90955a941e45b4
Author: Syed Hussaain <[email protected]>
AuthorDate: Thu Nov 17 00:16:25 2022 -0800
Handle transient state errors in `RedshiftResumeClusterOperator` and
`RedshiftPauseClusterOperator` (#27276)
* Modify RedshiftPauseClusterOperator and RedshiftResumeClusterOperator to
attempt to pause and resume multiple times to avoid edge cases of state changes
---
.../providers/amazon/aws/hooks/redshift_cluster.py | 15 +++-
.../amazon/aws/operators/redshift_cluster.py | 49 ++++++++----
.../amazon/aws/operators/test_redshift_cluster.py | 86 ++++++++++++++++------
.../providers/amazon/aws/example_redshift.py | 35 ++++-----
4 files changed, 128 insertions(+), 57 deletions(-)
diff --git a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
index 43d7993af7..d85929d062 100644
--- a/airflow/providers/amazon/aws/hooks/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/hooks/redshift_cluster.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import warnings
from typing import Any, Sequence
from botocore.exceptions import ClientError
@@ -157,16 +158,24 @@ class RedshiftHook(AwsBaseHook):
)
return response["Snapshot"] if response["Snapshot"] else None
- def get_cluster_snapshot_status(self, snapshot_identifier: str,
cluster_identifier: str):
+ def get_cluster_snapshot_status(self, snapshot_identifier: str,
cluster_identifier: str | None = None):
"""
Return Redshift cluster snapshot status. If cluster snapshot not found
return ``None``
:param snapshot_identifier: A unique identifier for the snapshot that
you are requesting
- :param cluster_identifier: The unique identifier of the cluster the
snapshot was created from
+ :param cluster_identifier: (deprecated) The unique identifier of the cluster
+ the snapshot was created from
"""
+ if cluster_identifier:
+ warnings.warn(
+ "Parameter `cluster_identifier` is deprecated."
+ "This option will be removed in a future version.",
+ DeprecationWarning,
+ stacklevel=2,
+ )
+
try:
response = self.get_conn().describe_cluster_snapshots(
- ClusterIdentifier=cluster_identifier,
SnapshotIdentifier=snapshot_identifier,
)
snapshot = response.get("Snapshots")[0]
diff --git a/airflow/providers/amazon/aws/operators/redshift_cluster.py
b/airflow/providers/amazon/aws/operators/redshift_cluster.py
index 39515b7137..2da0fbf23a 100644
--- a/airflow/providers/amazon/aws/operators/redshift_cluster.py
+++ b/airflow/providers/amazon/aws/operators/redshift_cluster.py
@@ -367,7 +367,6 @@ class RedshiftDeleteClusterSnapshotOperator(BaseOperator):
def get_status(self) -> str:
return self.redshift_hook.get_cluster_snapshot_status(
snapshot_identifier=self.snapshot_identifier,
- cluster_identifier=self.cluster_identifier,
)
@@ -397,15 +396,27 @@ class RedshiftResumeClusterOperator(BaseOperator):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ # These parameters are added to address an issue with the boto3 API where the API
+ # prematurely reports the cluster as available to receive requests. This causes the cluster
+ # to reject initial attempts to resume the cluster despite reporting the correct state.
+ self._attempts = 10
+ self._attempt_interval = 15
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
- cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == "paused":
- self.log.info("Starting Redshift cluster %s",
self.cluster_identifier)
-
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
- else:
- raise Exception(f"Unable to resume cluster - cluster state is
{cluster_state}")
+
+ while self._attempts >= 1:
+ try:
+
redshift_hook.get_conn().resume_cluster(ClusterIdentifier=self.cluster_identifier)
+ return
+ except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._attempts = self._attempts - 1
+
+ if self._attempts > 0:
+ self.log.error("Unable to resume cluster. %d attempts
remaining.", self._attempts)
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
class RedshiftPauseClusterOperator(BaseOperator):
@@ -434,15 +445,27 @@ class RedshiftPauseClusterOperator(BaseOperator):
super().__init__(**kwargs)
self.cluster_identifier = cluster_identifier
self.aws_conn_id = aws_conn_id
+ # These parameters are added to address an issue with the boto3 API where the API
+ # prematurely reports the cluster as available to receive requests. This causes the cluster
+ # to reject initial attempts to pause the cluster despite reporting the correct state.
+ self._attempts = 10
+ self._attempt_interval = 15
def execute(self, context: Context):
redshift_hook = RedshiftHook(aws_conn_id=self.aws_conn_id)
- cluster_state =
redshift_hook.cluster_status(cluster_identifier=self.cluster_identifier)
- if cluster_state == "available":
- self.log.info("Pausing Redshift cluster %s",
self.cluster_identifier)
-
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
- else:
- raise Exception(f"Unable to pause cluster - cluster state is
{cluster_state}")
+
+ while self._attempts >= 1:
+ try:
+
redshift_hook.get_conn().pause_cluster(ClusterIdentifier=self.cluster_identifier)
+ return
+ except
redshift_hook.get_conn().exceptions.InvalidClusterStateFault as error:
+ self._attempts = self._attempts - 1
+
+ if self._attempts > 0:
+ self.log.error("Unable to pause cluster. %d attempts
remaining.", self._attempts)
+ time.sleep(self._attempt_interval)
+ else:
+ raise error
class RedshiftDeleteClusterOperator(BaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
index c10daa9a70..5a9322c9c9 100644
--- a/tests/providers/amazon/aws/operators/test_redshift_cluster.py
+++ b/tests/providers/amazon/aws/operators/test_redshift_cluster.py
@@ -18,6 +18,7 @@ from __future__ import annotations
from unittest import mock
+import boto3
import pytest
from airflow.exceptions import AirflowException
@@ -172,7 +173,6 @@ class TestRedshiftDeleteClusterSnapshotOperator:
)
mock_get_cluster_snapshot_status.assert_called_once_with(
- cluster_identifier="test_cluster",
snapshot_identifier="test_snapshot",
)
@@ -205,26 +205,47 @@ class TestResumeClusterOperator:
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
- def test_resume_cluster_is_called_when_cluster_is_paused(self,
mock_get_conn, mock_cluster_status):
- mock_cluster_status.return_value = "paused"
+ def test_resume_cluster_is_called_when_cluster_is_paused(self,
mock_get_conn):
redshift_operator = RedshiftResumeClusterOperator(
task_id="task_test", cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.resume_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
- def test_resume_cluster_not_called_when_cluster_is_not_paused(self,
mock_get_conn, mock_cluster_status):
- mock_cluster_status.return_value = "available"
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch("time.sleep", return_value=None)
+ def test_resume_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+ exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+ returned_exception = type(exception)
+
+ mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+ mock_conn.resume_cluster.side_effect = [exception, exception, True]
redshift_operator = RedshiftResumeClusterOperator(
- task_id="task_test", cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test"
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ aws_conn_id="aws_conn_test",
)
- with pytest.raises(Exception):
+ redshift_operator.execute(None)
+ assert mock_conn.resume_cluster.call_count == 3
+
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch("time.sleep", return_value=None)
+ def test_resume_cluster_multiple_attempts_fail(self, mock_sleep,
mock_conn):
+ exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+ returned_exception = type(exception)
+
+ mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+ mock_conn.resume_cluster.side_effect = exception
+
+ redshift_operator = RedshiftResumeClusterOperator(
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ aws_conn_id="aws_conn_test",
+ )
+ with pytest.raises(returned_exception):
redshift_operator.execute(None)
- mock_get_conn.return_value.resume_cluster.assert_not_called()
+ assert mock_conn.resume_cluster.call_count == 10
class TestPauseClusterOperator:
@@ -236,26 +257,49 @@ class TestPauseClusterOperator:
assert redshift_operator.cluster_identifier == "test_cluster"
assert redshift_operator.aws_conn_id == "aws_conn_test"
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
- def test_pause_cluster_is_called_when_cluster_is_available(self,
mock_get_conn, mock_cluster_status):
- mock_cluster_status.return_value = "available"
+ def test_pause_cluster_is_called_when_cluster_is_available(self,
mock_get_conn):
redshift_operator = RedshiftPauseClusterOperator(
task_id="task_test", cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test"
)
redshift_operator.execute(None)
mock_get_conn.return_value.pause_cluster.assert_called_once_with(ClusterIdentifier="test_cluster")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.cluster_status")
-
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.get_conn")
- def test_pause_cluster_not_called_when_cluster_is_not_available(self,
mock_get_conn, mock_cluster_status):
- mock_cluster_status.return_value = "paused"
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch("time.sleep", return_value=None)
+ def test_pause_cluster_multiple_attempts(self, mock_sleep, mock_conn):
+ exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+ returned_exception = type(exception)
+
+ mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+ mock_conn.pause_cluster.side_effect = [exception, exception, True]
+
redshift_operator = RedshiftPauseClusterOperator(
- task_id="task_test", cluster_identifier="test_cluster",
aws_conn_id="aws_conn_test"
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ aws_conn_id="aws_conn_test",
+ )
+
+ redshift_operator.execute(None)
+ assert mock_conn.pause_cluster.call_count == 3
+
+
@mock.patch("airflow.providers.amazon.aws.hooks.redshift_cluster.RedshiftHook.conn")
+ @mock.patch("time.sleep", return_value=None)
+ def test_pause_cluster_multiple_attempts_fail(self, mock_sleep, mock_conn):
+ exception =
boto3.client("redshift").exceptions.InvalidClusterStateFault({}, "test")
+ returned_exception = type(exception)
+
+ mock_conn.exceptions.InvalidClusterStateFault = returned_exception
+ mock_conn.pause_cluster.side_effect = exception
+
+ redshift_operator = RedshiftPauseClusterOperator(
+ task_id="task_test",
+ cluster_identifier="test_cluster",
+ aws_conn_id="aws_conn_test",
)
- with pytest.raises(Exception):
+ with pytest.raises(returned_exception):
redshift_operator.execute(None)
- mock_get_conn.return_value.pause_cluster.assert_not_called()
+ assert mock_conn.pause_cluster.call_count == 10
class TestDeleteClusterOperator:
diff --git a/tests/system/providers/amazon/aws/example_redshift.py
b/tests/system/providers/amazon/aws/example_redshift.py
index 0b8b573c89..5a355f15f4 100644
--- a/tests/system/providers/amazon/aws/example_redshift.py
+++ b/tests/system/providers/amazon/aws/example_redshift.py
@@ -26,7 +26,6 @@ from airflow import DAG, settings
from airflow.decorators import task
from airflow.models import Connection
from airflow.models.baseoperator import chain
-from airflow.operators.python import get_current_context
from airflow.providers.amazon.aws.hooks.redshift_cluster import RedshiftHook
from airflow.providers.amazon.aws.operators.redshift_cluster import (
RedshiftCreateClusterOperator,
@@ -90,8 +89,7 @@ def setup_security_group(sec_group_name: str, ip_permissions:
list[dict]):
client.authorize_security_group_ingress(
GroupId=security_group["GroupId"], GroupName=sec_group_name,
IpPermissions=ip_permissions
)
- ti = get_current_context()["ti"]
- ti.xcom_push(key="security_group_id", value=security_group["GroupId"])
+ return security_group["GroupId"]
@task(trigger_rule=TriggerRule.ALL_DONE)
@@ -99,17 +97,6 @@ def delete_security_group(sec_group_id: str, sec_group_name:
str):
boto3.client("ec2").delete_security_group(GroupId=sec_group_id,
GroupName=sec_group_name)
-@task
-def await_cluster_snapshot(cluster_identifier):
- waiter = boto3.client("redshift").get_waiter("snapshot_available")
- waiter.wait(
- ClusterIdentifier=cluster_identifier,
- WaiterConfig={
- "MaxAttempts": 100,
- },
- )
-
-
with DAG(
dag_id=DAG_ID,
start_date=datetime(2021, 1, 1),
@@ -130,7 +117,7 @@ with DAG(
create_cluster = RedshiftCreateClusterOperator(
task_id="create_cluster",
cluster_identifier=redshift_cluster_identifier,
- vpc_security_group_ids=[set_up_sg["security_group_id"]],
+ vpc_security_group_ids=[set_up_sg],
publicly_accessible=True,
cluster_type="single-node",
node_type="dc2.large",
@@ -145,7 +132,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
target_status="available",
poke_interval=15,
- timeout=60 * 30,
+ timeout=60 * 15,
)
# [END howto_sensor_redshift_cluster]
@@ -161,6 +148,14 @@ with DAG(
)
# [END howto_operator_redshift_create_cluster_snapshot]
+ wait_cluster_available_before_pause = RedshiftClusterSensor(
+ task_id="wait_cluster_available_before_pause",
+ cluster_identifier=redshift_cluster_identifier,
+ target_status="available",
+ poke_interval=15,
+ timeout=60 * 15,
+ )
+
# [START howto_operator_redshift_pause_cluster]
pause_cluster = RedshiftPauseClusterOperator(
task_id="pause_cluster",
@@ -173,7 +168,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
target_status="paused",
poke_interval=15,
- timeout=60 * 30,
+ timeout=60 * 15,
)
# [START howto_operator_redshift_resume_cluster]
@@ -188,7 +183,7 @@ with DAG(
cluster_identifier=redshift_cluster_identifier,
target_status="available",
poke_interval=15,
- timeout=60 * 30,
+ timeout=60 * 15,
)
set_up_connection = create_connection(conn_id_name,
cluster_id=redshift_cluster_identifier)
@@ -269,7 +264,7 @@ with DAG(
# [END howto_operator_redshift_delete_cluster_snapshot]
delete_sg = delete_security_group(
- sec_group_id=set_up_sg["security_group_id"],
+ sec_group_id=set_up_sg,
sec_group_name=sg_name,
)
chain(
@@ -280,7 +275,7 @@ with DAG(
create_cluster,
wait_cluster_available,
create_cluster_snapshot,
- await_cluster_snapshot(redshift_cluster_identifier),
+ wait_cluster_available_before_pause,
pause_cluster,
wait_cluster_paused,
resume_cluster,