This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e07a42e69d Check cluster state before defer Dataproc operators to 
trigger (#36892)
e07a42e69d is described below

commit e07a42e69d1ab472c4da991fca5782990607ebe0
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Mon Jan 22 14:32:00 2024 +0800

    Check cluster state before defer Dataproc operators to trigger (#36892)
    
    While operating a Dataproc cluster in deferrable mode, the condition might 
already be met (created, deleted, updated) before we defer the task into the 
trigger. This PR intends to check the cluster status before deferring the task 
to the trigger.
    ---------
    
    Co-authored-by: Pankaj Koti <pankajkoti...@gmail.com>
---
 .../providers/google/cloud/operators/dataproc.py   |  63 ++++++---
 .../google/cloud/operators/test_dataproc.py        | 146 ++++++++++++++++++++-
 2 files changed, 185 insertions(+), 24 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py 
b/airflow/providers/google/cloud/operators/dataproc.py
index 306e0dc03d..b14121139d 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -721,6 +721,7 @@ class 
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
     def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
+
         # Save data required to display extra link no matter what the cluster 
status will be
         project_id = self.project_id or hook.project_id
         if project_id:
@@ -731,6 +732,7 @@ class 
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                 project_id=project_id,
                 region=self.region,
             )
+
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
@@ -741,17 +743,24 @@ class 
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                 self.log.info("Cluster created.")
                 return Cluster.to_dict(cluster)
             else:
-                self.defer(
-                    trigger=DataprocClusterTrigger(
-                        cluster_name=self.cluster_name,
-                        project_id=self.project_id,
-                        region=self.region,
-                        gcp_conn_id=self.gcp_conn_id,
-                        impersonation_chain=self.impersonation_chain,
-                        polling_interval_seconds=self.polling_interval_seconds,
-                    ),
-                    method_name="execute_complete",
+                cluster = hook.get_cluster(
+                    project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
                 )
+                if cluster.status.state == cluster.status.State.RUNNING:
+                    self.log.info("Cluster created.")
+                    return Cluster.to_dict(cluster)
+                else:
+                    self.defer(
+                        trigger=DataprocClusterTrigger(
+                            cluster_name=self.cluster_name,
+                            project_id=self.project_id,
+                            region=self.region,
+                            gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain,
+                            
polling_interval_seconds=self.polling_interval_seconds,
+                        ),
+                        method_name="execute_complete",
+                    )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -1016,6 +1025,16 @@ class 
DataprocDeleteClusterOperator(GoogleCloudBaseOperator):
             hook.wait_for_operation(timeout=self.timeout, 
result_retry=self.retry, operation=operation)
             self.log.info("Cluster deleted.")
         else:
+            try:
+                hook.get_cluster(
+                    project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
+                )
+            except NotFound:
+                self.log.info("Cluster deleted.")
+                return
+            except Exception as e:
+                raise AirflowException(str(e))
+
             end_time: float = time.time() + self.timeout
             self.defer(
                 trigger=DataprocDeleteClusterTrigger(
@@ -2480,17 +2499,21 @@ class 
DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, 
result_retry=self.retry, operation=operation)
         else:
-            self.defer(
-                trigger=DataprocClusterTrigger(
-                    cluster_name=self.cluster_name,
-                    project_id=self.project_id,
-                    region=self.region,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval_seconds=self.polling_interval_seconds,
-                ),
-                method_name="execute_complete",
+            cluster = hook.get_cluster(
+                project_id=self.project_id, region=self.region, 
cluster_name=self.cluster_name
             )
+            if cluster.status.state != cluster.status.State.RUNNING:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         self.log.info("Updated %s cluster.", self.cluster_name)
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py 
b/tests/providers/google/cloud/operators/test_dataproc.py
index 59a9c1008c..00f45ca8b3 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -23,7 +23,8 @@ from unittest.mock import MagicMock, Mock, call
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
-from google.cloud.dataproc_v1 import Batch, JobStatus
+from google.cloud import dataproc
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
 
 from airflow.exceptions import (
     AirflowException,
@@ -579,7 +580,7 @@ class TestsClusterGenerator:
         assert CONFIG_WITH_FLEX_MIG == cluster
 
 
-class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
+class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
             op = DataprocCreateClusterOperator(
@@ -883,6 +884,54 @@ class 
TestDataprocClusterCreateOperator(DataprocClusterTestBase):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, 
mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.create_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            request_id=None,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            virtual_cluster_config=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
@@ -1100,6 +1149,47 @@ class TestDataprocClusterDeleteOperator:
         assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, 
mock_trigger_hook, mock_hook, mock_defer):
+        mock_hook.return_value.create_cluster.return_value = None
+        mock_hook.return_value.get_cluster.side_effect = NotFound("test")
+        operator = DataprocDeleteClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            request_id=REQUEST_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster_uuid=None,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1240,8 +1330,8 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
-    
@mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
+    @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
     def test_dataproc_operator_execute_async_done_before_defer(self, 
mock_submit_job, mock_defer, mock_hook):
         mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
         job_status = mock_hook.return_value.get_job.return_value.status
@@ -1498,6 +1588,54 @@ class 
TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, 
mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            
status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.update_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": 
"600s"},
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.update_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": 
"600s"},
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag

Reply via email to