This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 83ca61a501 Fix `RdsStopDbOperator` operator in deferrable mode (#41059)
83ca61a501 is described below

commit 83ca61a501d755669fc83b1ad9038d0ca9d600ad
Author: Vincent <97131062+vincb...@users.noreply.github.com>
AuthorDate: Fri Jul 26 16:52:25 2024 -0400

    Fix `RdsStopDbOperator` operator in deferrable mode (#41059)
---
 airflow/providers/amazon/aws/hooks/rds.py        |   6 +-
 airflow/providers/amazon/aws/operators/rds.py    |  37 ++--
 airflow/providers/amazon/aws/waiters/rds.json    | 253 +++++++++++++++++++++++
 tests/providers/amazon/aws/hooks/test_rds.py     |  24 +--
 tests/providers/amazon/aws/operators/test_rds.py |  14 +-
 5 files changed, 280 insertions(+), 54 deletions(-)
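
For illustration only (not part of the patch): a minimal sketch of how the fixed operator
might be invoked in deferrable mode. Parameter names are taken from the diff below; the DAG
settings, the "my-aurora-cluster" identifier, and the deferrable flag (assumed to follow the
usual AWS provider pattern) are illustrative assumptions, not part of this commit.

    from datetime import datetime

    from airflow import DAG
    from airflow.providers.amazon.aws.operators.rds import RdsStopDbOperator

    with DAG(dag_id="example_rds_stop", start_date=datetime(2024, 7, 1), schedule=None, catchup=False):
        RdsStopDbOperator(
            task_id="stop_db_cluster",
            db_identifier="my-aurora-cluster",  # hypothetical cluster identifier
            db_type="cluster",                  # stop an Aurora cluster rather than an instance
            deferrable=True,                    # assumed flag; defers the wait instead of blocking a worker
            waiter_delay=30,                    # seconds between status checks
            waiter_max_attempts=60,             # give up after roughly 30 minutes
        )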

diff --git a/airflow/providers/amazon/aws/hooks/rds.py b/airflow/providers/amazon/aws/hooks/rds.py
index 8219a37757..588d78c782 100644
--- a/airflow/providers/amazon/aws/hooks/rds.py
+++ b/airflow/providers/amazon/aws/hooks/rds.py
@@ -259,7 +259,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_instance_state(db_instance_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_instance_{target_state}")  # 
type: ignore
             wait(
                 waiter=waiter,
@@ -272,7 +272,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             )
         else:
             self._wait_for_state(poke, target_state, check_interval, max_attempts)
-            self.log.info("DB cluster snapshot '%s' reached the '%s' state", 
db_instance_id, target_state)
+            self.log.info("DB cluster '%s' reached the '%s' state", 
db_instance_id, target_state)
 
     def get_db_cluster_state(self, db_cluster_id: str) -> str:
         """
@@ -310,7 +310,7 @@ class RdsHook(AwsGenericHook["RDSClient"]):
             return self.get_db_cluster_state(db_cluster_id)
 
         target_state = target_state.lower()
-        if target_state in ("available", "deleted"):
+        if target_state in ("available", "deleted", "stopped"):
             waiter = self.conn.get_waiter(f"db_cluster_{target_state}")  # 
type: ignore
             waiter.wait(
                 DBClusterIdentifier=db_cluster_id,
diff --git a/airflow/providers/amazon/aws/operators/rds.py b/airflow/providers/amazon/aws/operators/rds.py
index a2f35b5081..f37c698d87 100644
--- a/airflow/providers/amazon/aws/operators/rds.py
+++ b/airflow/providers/amazon/aws/operators/rds.py
@@ -36,6 +36,7 @@ from airflow.providers.amazon.aws.utils import validate_execute_complete_event
 from airflow.providers.amazon.aws.utils.rds import RdsDbType
 from airflow.providers.amazon.aws.utils.tags import format_tags
 from airflow.providers.amazon.aws.utils.waiter_with_logging import wait
+from airflow.utils.helpers import prune_dict
 
 if TYPE_CHECKING:
     from mypy_boto3_rds.type_defs import TagTypeDef
@@ -782,7 +783,7 @@ class RdsStartDbOperator(RdsBaseOperator):
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
                     response=start_db_response,
-                    db_type=RdsDbType.INSTANCE,
+                    db_type=self.db_type,
                 ),
                 method_name="execute_complete",
             )
@@ -881,12 +882,25 @@ class RdsStopDbOperator(RdsBaseOperator):
                     aws_conn_id=self.aws_conn_id,
                     region_name=self.region_name,
                     response=stop_db_response,
-                    db_type=RdsDbType.INSTANCE,
+                    db_type=self.db_type,
                 ),
                 method_name="execute_complete",
             )
         elif self.wait_for_completion:
-            self._wait_until_db_stopped()
+            waiter = self.hook.get_waiter(f"db_{self.db_type.value}_stopped")
+            waiter_key = (
+                "DBInstanceIdentifier" if self.db_type == RdsDbType.INSTANCE 
else "DBClusterIdentifier"
+            )
+            kwargs = {waiter_key: self.db_identifier}
+            waiter.wait(
+                WaiterConfig=prune_dict(
+                    {
+                        "Delay": self.waiter_delay,
+                        "MaxAttempts": self.waiter_max_attempts,
+                    }
+                ),
+                **kwargs,
+            )
         return json.dumps(stop_db_response, default=str)
 
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str:
@@ -915,23 +929,6 @@ class RdsStopDbOperator(RdsBaseOperator):
             response = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
         return response
 
-    def _wait_until_db_stopped(self):
-        self.log.info("Waiting for DB %s to reach 'stopped' state", 
self.db_type.value)
-        if self.db_type == RdsDbType.INSTANCE:
-            self.hook.wait_for_db_instance_state(
-                self.db_identifier,
-                target_state="stopped",
-                check_interval=self.waiter_delay,
-                max_attempts=self.waiter_max_attempts,
-            )
-        else:
-            self.hook.wait_for_db_cluster_state(
-                self.db_identifier,
-                target_state="stopped",
-                check_interval=self.waiter_delay,
-                max_attempts=self.waiter_max_attempts,
-            )
-
 
 __all__ = [
     "RdsCreateDbSnapshotOperator",
diff --git a/airflow/providers/amazon/aws/waiters/rds.json b/airflow/providers/amazon/aws/waiters/rds.json
new file mode 100644
index 0000000000..78c56da53f
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/rds.json
@@ -0,0 +1,253 @@
+{
+    "version": 2,
+    "waiters": {
+        "db_instance_stopped": {
+            "operation": "DescribeDBInstances",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "available",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "backing-up",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "creating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "delete-precheck",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": 
"inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-network",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-option-group",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-parameters",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "incompatible-restore",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "insufficient-capacity",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "maintenance",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "modifying",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "rebooting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "renaming",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "restore-error",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "starting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "stopping",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "storage-full",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "upgrading",
+                    "state": "retry"
+                }
+            ]
+        },
+        "db_cluster_stopped": {
+            "operation": "DescribeDBClusters",
+            "delay": 30,
+            "maxAttempts": 60,
+            "acceptors": [
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBClusters[].Status",
+                    "expected": "stopped",
+                    "state": "success"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "available",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "backing-up",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "cloning-failed",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "creating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "deleting",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "failing-over",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "inaccessible-encryption-credentials",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": 
"inaccessible-encryption-credentials-recoverable",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "maintenance",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "migrating",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "migration-failed",
+                    "state": "failure"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "modifying",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "renaming",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "starting",
+                    "state": "retry"
+                },
+                {
+                    "matcher": "pathAll",
+                    "argument": "DBInstances[].DBInstanceStatus",
+                    "expected": "stopping",
+                    "state": "retry"
+                }
+            ]
+        }
+    }
+}
diff --git a/tests/providers/amazon/aws/hooks/test_rds.py b/tests/providers/amazon/aws/hooks/test_rds.py
index b2668febfb..77159b4e14 100644
--- a/tests/providers/amazon/aws/hooks/test_rds.py
+++ b/tests/providers/amazon/aws/hooks/test_rds.py
@@ -150,7 +150,7 @@ class TestRdsHook:
 
     def test_wait_for_db_instance_state_boto_waiters(self, rds_hook: RdsHook, db_instance_id: str):
         """Checks that the DB instance waiter uses AWS boto waiters where 
possible"""
-        for state in ("available", "deleted"):
+        for state in ("available", "deleted", "stopped"):
             with patch.object(rds_hook.conn, "get_waiter") as mock:
                 rds_hook.wait_for_db_instance_state(db_instance_id, target_state=state, **self.waiter_args)
                 mock.assert_called_once_with(f"db_instance_{state}")
@@ -161,16 +161,6 @@ class TestRdsHook:
                     },
                 )
 
-    def test_wait_for_db_instance_state_custom_waiter(self, rds_hook: RdsHook, db_instance_id: str):
-        """Checks that the DB instance waiter uses custom wait logic when AWS boto waiters aren't available"""
-        with patch.object(rds_hook, "_wait_for_state") as mock:
-            rds_hook.wait_for_db_instance_state(db_instance_id, target_state="stopped", **self.waiter_args)
-            mock.assert_called_once()
-
-        with patch.object(rds_hook, "get_db_instance_state", 
return_value="stopped") as mock:
-            rds_hook.wait_for_db_instance_state(db_instance_id, 
target_state="stopped", **self.waiter_args)
-            mock.assert_called_once_with(db_instance_id)
-
     def test_get_db_cluster_state(self, rds_hook: RdsHook, db_cluster_id: str):
         response = rds_hook.conn.describe_db_clusters(DBClusterIdentifier=db_cluster_id)
         state_expected = response["DBClusters"][0]["Status"]
@@ -179,7 +169,7 @@ class TestRdsHook:
 
     def test_wait_for_db_cluster_state_boto_waiters(self, rds_hook: RdsHook, db_cluster_id: str):
         """Checks that the DB cluster waiter uses AWS boto waiters where possible"""
-        for state in ("available", "deleted"):
+        for state in ("available", "deleted", "stopped"):
             with patch.object(rds_hook.conn, "get_waiter") as mock:
                 rds_hook.wait_for_db_cluster_state(db_cluster_id, target_state=state, **self.waiter_args)
                 mock.assert_called_once_with(f"db_cluster_{state}")
@@ -191,16 +181,6 @@ class TestRdsHook:
                     },
                 )
 
-    def test_wait_for_db_cluster_state_custom_waiter(self, rds_hook: RdsHook, db_cluster_id: str):
-        """Checks that the DB cluster waiter uses custom wait logic when AWS boto waiters aren't available"""
-        with patch.object(rds_hook, "_wait_for_state") as mock_wait_for_state:
-            rds_hook.wait_for_db_cluster_state(db_cluster_id, target_state="stopped", **self.waiter_args)
-            mock_wait_for_state.assert_called_once()
-
-        with patch.object(rds_hook, "get_db_cluster_state", 
return_value="stopped") as mock:
-            rds_hook.wait_for_db_cluster_state(db_cluster_id, 
target_state="stopped", **self.waiter_args)
-            mock.assert_called_once_with(db_cluster_id)
-
     def test_get_db_snapshot_state(self, rds_hook: RdsHook, db_snapshot_id: str):
         response = rds_hook.conn.describe_db_snapshots(DBSnapshotIdentifier=db_snapshot_id)
         state_expected = response["DBSnapshots"][0]["Status"]
diff --git a/tests/providers/amazon/aws/operators/test_rds.py b/tests/providers/amazon/aws/operators/test_rds.py
index fd464019dd..651db53d42 100644
--- a/tests/providers/amazon/aws/operators/test_rds.py
+++ b/tests/providers/amazon/aws/operators/test_rds.py
@@ -813,8 +813,7 @@ class TestRdsStopDbOperator:
         del cls.hook
 
     @mock_aws
-    @patch.object(RdsHook, "wait_for_db_instance_state")
-    def test_stop_db_instance(self, mock_await_status):
+    def test_stop_db_instance(self):
         _create_db_instance(self.hook)
         stop_db_instance = RdsStopDbOperator(task_id="test_stop_db_instance", db_identifier=DB_INSTANCE_NAME)
         _patch_hook_get_connection(stop_db_instance.hook)
@@ -822,11 +821,10 @@ class TestRdsStopDbOperator:
         result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
         status = result["DBInstances"][0]["DBInstanceStatus"]
         assert status == "stopped"
-        mock_await_status.assert_called()
 
     @mock_aws
-    @patch.object(RdsHook, "wait_for_db_instance_state")
-    def test_stop_db_instance_no_wait(self, mock_await_status):
+    @patch.object(RdsHook, "get_waiter")
+    def test_stop_db_instance_no_wait(self, mock_get_waiter):
         _create_db_instance(self.hook)
         stop_db_instance = RdsStopDbOperator(
             task_id="test_stop_db_instance_no_wait", 
db_identifier=DB_INSTANCE_NAME, wait_for_completion=False
@@ -836,7 +834,7 @@ class TestRdsStopDbOperator:
         result = self.hook.conn.describe_db_instances(DBInstanceIdentifier=DB_INSTANCE_NAME)
         status = result["DBInstances"][0]["DBInstanceStatus"]
         assert status == "stopped"
-        mock_await_status.assert_not_called()
+        mock_get_waiter.assert_not_called()
 
     @mock.patch.object(RdsHook, "conn")
     def test_deferred(self, conn_mock):
@@ -872,8 +870,7 @@ class TestRdsStopDbOperator:
         assert len(instance_snapshots) == 1
 
     @mock_aws
-    @patch.object(RdsHook, "wait_for_db_cluster_state")
-    def test_stop_db_cluster(self, mock_await_status):
+    def test_stop_db_cluster(self):
         _create_db_cluster(self.hook)
         stop_db_cluster = RdsStopDbOperator(
             task_id="test_stop_db_cluster", db_identifier=DB_CLUSTER_NAME, 
db_type="cluster"
@@ -884,7 +881,6 @@ class TestRdsStopDbOperator:
         describe_result = self.hook.conn.describe_db_clusters(DBClusterIdentifier=DB_CLUSTER_NAME)
         status = describe_result["DBClusters"][0]["Status"]
         assert status == "stopped"
-        mock_await_status.assert_called()
 
     @mock_aws
     def test_stop_db_cluster_create_snapshot_logs_warning_message(self, caplog):
