hankehly commented on code in PR #27076:
URL: https://github.com/apache/airflow/pull/27076#discussion_r996526630
##########
airflow/providers/amazon/aws/operators/rds.py:
##########
@@ -672,6 +682,132 @@ def execute(self, context: Context) -> str:
return json.dumps(delete_db_instance, default=str)
class RdsStartDbOperator(RdsBaseOperator):
    """
    Starts an RDS DB instance / cluster

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:RdsStartDbOperator`

    :param db_identifier: The AWS identifier of the DB to start
    :param db_type: Type of the DB - either "instance" or "cluster" (default: "instance")
    :param aws_conn_id: The Airflow connection used for AWS credentials. (default: "aws_default")
    :param wait_for_completion: If True, waits for DB to start. (default: True)
    """

    template_fields = ("db_identifier", "db_type")

    def __init__(
        self,
        *,
        db_identifier: str,
        db_type: str = "instance",
        aws_conn_id: str = "aws_default",
        wait_for_completion: bool = True,
        **kwargs,
    ):
        super().__init__(aws_conn_id=aws_conn_id, **kwargs)
        self.db_identifier = db_identifier
        self.db_type = db_type
        self.wait_for_completion = wait_for_completion

    def execute(self, context: Context) -> str:
        """
        Start the DB and, when ``wait_for_completion`` is set, wait until it is available.

        :return: The response of the start call serialized as JSON; values that are not
            JSON-serializable are converted with ``str``.
        """
        # ``db_type`` is a template field, so it is only converted to ``RdsDbType`` here,
        # after templates have been rendered. From this point on it holds an enum member.
        self.db_type = RdsDbType(self.db_type)
        start_db_response = self._start_db()
        if self.wait_for_completion:
            self._wait_until_db_available()
        return json.dumps(start_db_response, default=str)

    def _start_db(self):
        """Send the start request for the instance or cluster and return the raw API response."""
        # Log the enum's ``.value`` ("instance"/"cluster"); interpolating the member itself
        # would print its enum repr instead of the user-facing DB type.
        self.log.info("Starting DB %s '%s'", self.db_type.value, self.db_identifier)
        if self.db_type == RdsDbType.INSTANCE:
            return self.hook.conn.start_db_instance(DBInstanceIdentifier=self.db_identifier)
        return self.hook.conn.start_db_cluster(DBClusterIdentifier=self.db_identifier)

    def _wait_until_db_available(self):
        """Block on the matching boto waiter until the DB reports the 'available' state."""
        self.log.info("Waiting for DB %s to reach 'available' state", self.db_type.value)
        if self.db_type == RdsDbType.INSTANCE:
            self.hook.conn.get_waiter("db_instance_available").wait(DBInstanceIdentifier=self.db_identifier)
        else:
            self.hook.conn.get_waiter("db_cluster_available").wait(DBClusterIdentifier=self.db_identifier)
+
+class RdsStopDbOperator(RdsBaseOperator):
+ """
+ Stops an RDS DB instance / cluster
+
+ .. seealso::
+ For more information on how to use this operator, take a look at the
guide:
+ :ref:`howto/operator:RdsStopDbOperator`
+
+ :param db_identifier: The AWS identifier of the DB to stop
+ :param db_type: Type of the DB - either "instance" or "cluster" (default:
"instance")
+ :param db_snapshot_identifier: The instance identifier of the DB Snapshot
to create before
+ stopping the DB instance. The default value (None) skips snapshot
creation. This
+ parameter is ignored when ``db_type`` is "cluster"
+ :param aws_conn_id: The Airflow connection used for AWS credentials.
(default: "aws_default")
+ :param wait_for_completion: If True, waits for DB to stop. (default: True)
+ """
+
+ template_fields = ("db_identifier", "db_type")
+
+ def __init__(
+ self,
+ *,
+ db_identifier: str,
+ db_type: str = "instance",
+ db_snapshot_identifier: str = None,
+ aws_conn_id: str = "aws_default",
+ wait_for_completion: bool = True,
+ **kwargs,
+ ):
+ super().__init__(aws_conn_id=aws_conn_id, **kwargs)
+ self.db_identifier = db_identifier
+ self.db_type = db_type
+ self.db_snapshot_identifier = db_snapshot_identifier
+ self.wait_for_completion = wait_for_completion
+
+ def execute(self, context: Context) -> str:
+ self.db_type = RdsDbType(self.db_type)
+ stop_db_response = self._stop_db()
+ if self.wait_for_completion:
+ self._wait_until_db_stopped()
+ return json.dumps(stop_db_response, default=str)
+
+ def _stop_db(self):
+ self.log.info(f"Stopping DB {self.db_type} '{self.db_identifier}'")
+ if self.db_type == RdsDbType.INSTANCE:
+ conn_params = {"DBInstanceIdentifier": self.db_identifier}
+ # The db snapshot parameter is optional, but the AWS SDK raises an
exception
+ # if passed a null value. Only set snapshot id if value is present.
+ if self.db_snapshot_identifier:
+ conn_params["DBSnapshotIdentifier"] =
self.db_snapshot_identifier
+ response = self.hook.conn.stop_db_instance(**conn_params)
+ else:
+ if self.db_snapshot_identifier:
+ self.log.warning(
+ "'db_snapshot_identifier' does not apply to db clusters.
Remove it to silence this warning."
+ )
+ response =
self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.db_identifier)
+ return response
+
+ def _wait_until_db_stopped(self):
Review Comment:
It's used around lines 802-808. (Could there be a misunderstanding?)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]