This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 4a8a803d678 Refactor non-deferrable execution flow to ensure cluster 
state reconciliation runs after creation completes. (#61951)
4a8a803d678 is described below

commit 4a8a803d678b1b3126b51af974c3bc8e51044ce6
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 10 20:19:52 2026 +0000

    Refactor non-deferrable execution flow to ensure cluster state reconciliation runs after creation completes. (#61951)
    
    – Extract reconciliation logic into `_reconcile_cluster_state()`
    – Ensure DELETING state waits for deletion and re-creates the cluster
    – Ensure CREATING state is fully reconciled before returning
    – Handle STOPPED state via restart path
    – Raise explicit exception if cluster is not found after LRO completion
    – Return reconciled cluster to avoid stale state
    
    Update and extend unit tests to cover reconciliation scenarios in the non-deferrable path (CREATING, DELETING, STOPPED, ERROR, and timeout cases).
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../providers/google/cloud/operators/dataproc.py   | 134 +++++++++------
 .../unit/google/cloud/operators/test_dataproc.py   | 180 ++++++++++++++++++++-
 2 files changed, 258 insertions(+), 56 deletions(-)

diff --git a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 7b08af16a98..6dc6e0ab775 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -811,11 +811,47 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
         self.log.info("Cluster created.")
         return Cluster.to_dict(cluster)
 
+    def _reconcile_cluster_state(self, hook: DataprocHook, cluster: Cluster) -> Cluster:
+
+        if cluster.status.state == cluster.status.State.CREATING:
+            self.log.info("Cluster %s is in CREATING state.", self.cluster_name)
+
+            cluster = self._wait_for_cluster_in_creating_state(hook)
+            self._handle_error_state(hook, cluster)
+        elif cluster.status.state == cluster.status.State.DELETING:
+            self.log.info("Cluster %s is in DELETING state.", self.cluster_name)
+
+            self._wait_for_cluster_in_deleting_state(hook)
+
+            self.log.info("Attempting to re-create cluster: %s", self.cluster_name)
+
+            operation = self._create_cluster(hook)
+            hook.wait_for_operation(
+                timeout=self.timeout,
+                result_retry=self.retry,
+                operation=operation,
+            )
+            cluster = self._get_cluster(hook)
+
+            self._handle_error_state(hook, cluster)
+        elif cluster.status.state == cluster.status.State.STOPPED:
+            self.log.info("Cluster %s is in STOPPED state.", self.cluster_name)
+
+            self.log.info("Attempting to re-start cluster: %s", self.cluster_name)
+
+            # _start_cluster waits for the operation to complete.
+            self._start_cluster(hook)
+
+            cluster = self._get_cluster(hook)
+
+        return cluster
+
     def execute(self, context: Context) -> dict:
-        self.log.info("Creating cluster: %s", self.cluster_name)
+
+        self.log.info("Attempting to create cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, 
impersonation_chain=self.impersonation_chain)
 
-        # Save data required to display extra link no matter what the cluster status will be
+        # Save data required to display extra link regardless of the cluster status.
         project_id = self.project_id or hook.project_id
         if project_id:
             DataprocClusterLink.persist(
@@ -826,37 +862,40 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
             )
 
         try:
-            # First try to create a new cluster
             operation = self._create_cluster(hook)
-            if not self.deferrable and type(operation) is not str:
-                cluster = hook.wait_for_operation(
-                    timeout=self.timeout, result_retry=self.retry, operation=operation
+
+            if not self.deferrable and not isinstance(operation, str):
+                hook.wait_for_operation(
+                    timeout=self.timeout,
+                    result_retry=self.retry,
+                    operation=operation,
                 )
-                self.log.info("Cluster created.")
-                return Cluster.to_dict(cluster)
-            cluster = hook.get_cluster(
-                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
-            )
-            if cluster.status.state == cluster.status.State.RUNNING:
-                self.log.info("Cluster created.")
-                return Cluster.to_dict(cluster)
-            self.defer(
-                trigger=DataprocClusterTrigger(
-                    cluster_name=self.cluster_name,
-                    project_id=self.project_id,
-                    region=self.region,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval_seconds=self.polling_interval_seconds,
-                    delete_on_error=self.delete_on_error,
-                ),
-                method_name="execute_complete",
-            )
+
+            # Fetch current state.
+            cluster = self._get_cluster(hook)
+
+            if self.deferrable:
+                if cluster.status.state != cluster.status.State.RUNNING:
+                    self.defer(
+                        trigger=DataprocClusterTrigger(
+                            cluster_name=self.cluster_name,
+                            project_id=self.project_id,
+                            region=self.region,
+                            gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain,
+                            polling_interval_seconds=self.polling_interval_seconds,
+                            delete_on_error=self.delete_on_error,
+                        ),
+                        method_name="execute_complete",
+                    )
+                    return
+
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
-            self.log.info("Cluster already exists.")
+            self.log.info("Cluster %s already exists.", self.cluster_name)
             cluster = self._get_cluster(hook)
+
         except DataprocResourceIsNotReadyError as resource_not_ready_error:
             if self.num_retries_if_resource_is_not_ready:
                 attempt = self.num_retries_if_resource_is_not_ready
@@ -876,39 +915,28 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                 self._delete_cluster(hook)
                 self._wait_for_cluster_in_deleting_state(hook)
             raise resource_not_ready_error
-        except AirflowException as ae:
-            # There still could be a cluster created here in an ERROR state which
-            # should be deleted immediately rather than consuming another retry attempt
-            # (assuming delete_on_error is true (default))
-            # This reduces overall the number of task attempts from 3 to 2 to successful cluster creation
-            # assuming the underlying GCE issues have resolved within that window. Users can configure
-            # a higher number of retry attempts in powers of two with 30s-60s wait interval
+
+        except AirflowException as outer_airflow_exception:
+            # A cluster may have been created but entered ERROR state.
+            # If delete_on_error is enabled, delete it immediately so that
+            # the next retry attempt starts from a clean state.
             try:
                 cluster = self._get_cluster(hook)
                 self._handle_error_state(hook, cluster)
-            except AirflowException as ae_inner:
-                # We could get any number of failures here, including cluster not found and we
-                # can just ignore to ensure we surface the original cluster create failure
-                self.log.exception(ae_inner)
+            except AirflowException as inner_airflow_exception:
+                # Cleanup logic may raise secondary exceptions (e.g., cluster not found).
+                # Suppress those so that the original cluster creation failure is surfaced.
+                self.log.exception(inner_airflow_exception)
             finally:
-                raise ae
+                raise outer_airflow_exception
 
-        # Check if cluster is not in ERROR state
+        # Check if cluster is not in ERROR state.
         self._handle_error_state(hook, cluster)
-        if cluster.status.state == cluster.status.State.CREATING:
-            # Wait for cluster to be created
-            cluster = self._wait_for_cluster_in_creating_state(hook)
-            self._handle_error_state(hook, cluster)
-        elif cluster.status.state == cluster.status.State.DELETING:
-            # Wait for cluster to be deleted
-            self._wait_for_cluster_in_deleting_state(hook)
-            # Create new cluster
-            cluster = self._create_cluster(hook)
-            self._handle_error_state(hook, cluster)
-        elif cluster.status.state == cluster.status.State.STOPPED:
-            # if the cluster exists and already stopped, then start the cluster
-            self._start_cluster(hook)
 
+        # If cluster is not in RUNNING state, reconcile.
+        cluster = self._reconcile_cluster_state(hook, cluster)
+
+        self.log.info("Cluster %s is RUNNING.", self.cluster_name)
         return Cluster.to_dict(cluster)
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 3db5f497e4d..d284c875548 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -826,7 +826,7 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
         # Test whether xcom push occurs before create cluster is called
         self.extra_links_manager_mock.assert_has_calls(expected_calls, 
any_order=False)
 
-        to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
+        to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
         if AIRFLOW_V_3_0_PLUS:
             self.mock_ti.xcom_push.assert_called_once_with(
                 key="dataproc_cluster",
@@ -881,7 +881,7 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
         # Test whether xcom push occurs before create cluster is called
         self.extra_links_manager_mock.assert_has_calls(expected_calls, 
any_order=False)
 
-        to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
+        to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
         if AIRFLOW_V_3_0_PLUS:
             self.mock_ti.xcom_push.assert_called_once_with(
                 key="dataproc_cluster",
@@ -1015,7 +1015,7 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
 
         mock_create_cluster.side_effect = [AlreadyExists("test"), 
cluster_running]
         mock_generator.return_value = [0]
-        mock_get_cluster.side_effect = [cluster_deleting, NotFound("test")]
+        mock_get_cluster.side_effect = [cluster_deleting, NotFound("test"), cluster_running]
 
         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
@@ -1035,6 +1035,180 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
 
         to_dict_mock.assert_called_once_with(cluster_running)
 
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_deleting_state"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_recreates_when_deleted_during_creation(
+        self,
+        mock_hook,
+        mock_get_cluster,
+        mock_wait_for_deleting,
+        to_dict_mock,
+    ):
+        mock_hook.return_value.wait_for_operation.return_value = None
+
+        # First invocation of get_cluster should return cluster in DELETING state.
+        cluster_deleting = mock.MagicMock()
+        cluster_deleting.status.state = cluster_deleting.status.State.DELETING
+
+        # Re-creation should return cluster in RUNNING state.
+        cluster_running = mock.MagicMock()
+        cluster_running.status.state = cluster_running.status.State.RUNNING
+
+        mock_get_cluster.side_effect = [
+            cluster_deleting,
+            cluster_running,
+        ]
+
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CONFIG,
+            deferrable=False,
+        )
+
+        op.execute(context=mock.MagicMock())
+
+        # Ensure re-creation path is traversed.
+        assert mock_wait_for_deleting.called
+        assert mock_hook.return_value.create_cluster.call_count == 2
+
+        to_dict_mock.assert_called_once_with(cluster_running)
+
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_deleting_state"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_deleting_timeout_raises(
+        self,
+        mock_hook,
+        mock_get_cluster,
+        mock_wait_for_deleting,
+    ):
+        mock_hook.return_value.wait_for_operation.return_value = None
+
+        cluster_deleting = mock.MagicMock()
+        cluster_deleting.status.state = cluster_deleting.status.State.DELETING
+
+        mock_get_cluster.return_value = cluster_deleting
+        mock_wait_for_deleting.side_effect = AirflowException("Timeout")
+
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CONFIG,
+            deferrable=False,
+        )
+
+        with pytest.raises(AirflowException):
+            op.execute(context=mock.MagicMock())
+
+        # Ensure no re-creation is attempted.
+        assert mock_hook.return_value.create_cluster.call_count == 1
+
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_creating_state"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_waits_when_still_creating(
+        self,
+        mock_hook,
+        mock_get_cluster,
+        mock_wait_for_creating,
+        to_dict_mock,
+    ):
+        mock_hook.return_value.wait_for_operation.return_value = None
+
+        cluster_creating = mock.MagicMock()
+        cluster_creating.status.state = cluster_creating.status.State.CREATING
+
+        cluster_running = mock.MagicMock()
+        cluster_running.status.state = cluster_running.status.State.RUNNING
+
+        mock_get_cluster.return_value = cluster_creating
+        mock_wait_for_creating.return_value = cluster_running
+
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CONFIG,
+            deferrable=False,
+        )
+
+        op.execute(context=mock.MagicMock())
+
+        mock_wait_for_creating.assert_called_once()
+        to_dict_mock.assert_called_once_with(cluster_running)
+
+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._start_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_stopped_cluster_restarts(
+        self,
+        mock_hook,
+        mock_get_cluster,
+        mock_start_cluster,
+        to_dict_mock,
+    ):
+        mock_hook.return_value.wait_for_operation.return_value = None
+
+        cluster_stopped = mock.MagicMock()
+        cluster_stopped.status.state = cluster_stopped.status.State.STOPPED
+
+        mock_get_cluster.return_value = cluster_stopped
+
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CONFIG,
+            deferrable=False,
+        )
+
+        op.execute(context=mock.MagicMock())
+
+        mock_start_cluster.assert_called_once_with(mock_hook.return_value)
+        to_dict_mock.assert_called_once_with(cluster_stopped)
+
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._handle_error_state"))
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    def test_execute_error_state_after_wait_for_completion(
+        self,
+        mock_hook,
+        mock_get_cluster,
+        mock_handle_error,
+    ):
+        mock_hook.return_value.wait_for_operation.return_value = None
+
+        cluster_error = mock.MagicMock()
+        cluster_error.status.state = cluster_error.status.State.ERROR
+
+        mock_get_cluster.return_value = cluster_error
+        mock_handle_error.side_effect = AirflowException("Cluster error")
+
+        op = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            cluster_config=CONFIG,
+            deferrable=False,
+        )
+
+        with pytest.raises(AirflowException):
+            op.execute(context=mock.MagicMock())
+
+        mock_handle_error.assert_called_once()
+
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
     def test_create_execute_call_defer_method(self, mock_trigger_hook, mock_hook):

Reply via email to