This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 4a8a803d678 Refactor non-deferrable execution flow to ensure cluster
state reconciliation runs after creation completes. (#61951)
4a8a803d678 is described below
commit 4a8a803d678b1b3126b51af974c3bc8e51044ce6
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 10 20:19:52 2026 +0000
Refactor non-deferrable execution flow to ensure cluster state
reconciliation runs after creation completes. (#61951)
- Extract reconciliation logic into `_reconcile_cluster_state()`
- Ensure DELETING state waits for deletion and re-creates the cluster
- Ensure CREATING state is fully reconciled before returning
- Handle STOPPED state via restart path
- Raise explicit exception if cluster is not found after LRO completion
- Return reconciled cluster to avoid stale state
Update and extend unit tests to cover reconciliation scenarios in the
non-deferrable path (CREATING, DELETING, STOPPED, ERROR, and timeout cases).
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../providers/google/cloud/operators/dataproc.py | 134 +++++++++------
.../unit/google/cloud/operators/test_dataproc.py | 180 ++++++++++++++++++++-
2 files changed, 258 insertions(+), 56 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
index 7b08af16a98..6dc6e0ab775 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/dataproc.py
@@ -811,11 +811,47 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
self.log.info("Cluster created.")
return Cluster.to_dict(cluster)
+ def _reconcile_cluster_state(self, hook: DataprocHook, cluster: Cluster)
-> Cluster:
+
+ if cluster.status.state == cluster.status.State.CREATING:
+ self.log.info("Cluster %s is in CREATING state.",
self.cluster_name)
+
+ cluster = self._wait_for_cluster_in_creating_state(hook)
+ self._handle_error_state(hook, cluster)
+ elif cluster.status.state == cluster.status.State.DELETING:
+ self.log.info("Cluster %s is in DELETING state.",
self.cluster_name)
+
+ self._wait_for_cluster_in_deleting_state(hook)
+
+ self.log.info("Attempting to re-create cluster: %s",
self.cluster_name)
+
+ operation = self._create_cluster(hook)
+ hook.wait_for_operation(
+ timeout=self.timeout,
+ result_retry=self.retry,
+ operation=operation,
+ )
+ cluster = self._get_cluster(hook)
+
+ self._handle_error_state(hook, cluster)
+ elif cluster.status.state == cluster.status.State.STOPPED:
+ self.log.info("Cluster %s is in STOPPED state.", self.cluster_name)
+
+ self.log.info("Attempting to re-start cluster: %s",
self.cluster_name)
+
+ # _start_cluster waits for the operation to complete.
+ self._start_cluster(hook)
+
+ cluster = self._get_cluster(hook)
+
+ return cluster
+
def execute(self, context: Context) -> dict:
- self.log.info("Creating cluster: %s", self.cluster_name)
+
+ self.log.info("Attempting to create cluster: %s", self.cluster_name)
hook = DataprocHook(gcp_conn_id=self.gcp_conn_id,
impersonation_chain=self.impersonation_chain)
- # Save data required to display extra link no matter what the cluster
status will be
+ # Save data required to display extra link regardless of the cluster
status.
project_id = self.project_id or hook.project_id
if project_id:
DataprocClusterLink.persist(
@@ -826,37 +862,40 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
)
try:
- # First try to create a new cluster
operation = self._create_cluster(hook)
- if not self.deferrable and type(operation) is not str:
- cluster = hook.wait_for_operation(
- timeout=self.timeout, result_retry=self.retry,
operation=operation
+
+ if not self.deferrable and not isinstance(operation, str):
+ hook.wait_for_operation(
+ timeout=self.timeout,
+ result_retry=self.retry,
+ operation=operation,
)
- self.log.info("Cluster created.")
- return Cluster.to_dict(cluster)
- cluster = hook.get_cluster(
- project_id=self.project_id, region=self.region,
cluster_name=self.cluster_name
- )
- if cluster.status.state == cluster.status.State.RUNNING:
- self.log.info("Cluster created.")
- return Cluster.to_dict(cluster)
- self.defer(
- trigger=DataprocClusterTrigger(
- cluster_name=self.cluster_name,
- project_id=self.project_id,
- region=self.region,
- gcp_conn_id=self.gcp_conn_id,
- impersonation_chain=self.impersonation_chain,
- polling_interval_seconds=self.polling_interval_seconds,
- delete_on_error=self.delete_on_error,
- ),
- method_name="execute_complete",
- )
+
+ # Fetch current state.
+ cluster = self._get_cluster(hook)
+
+ if self.deferrable:
+ if cluster.status.state != cluster.status.State.RUNNING:
+ self.defer(
+ trigger=DataprocClusterTrigger(
+ cluster_name=self.cluster_name,
+ project_id=self.project_id,
+ region=self.region,
+ gcp_conn_id=self.gcp_conn_id,
+ impersonation_chain=self.impersonation_chain,
+
polling_interval_seconds=self.polling_interval_seconds,
+ delete_on_error=self.delete_on_error,
+ ),
+ method_name="execute_complete",
+ )
+ return
+
except AlreadyExists:
if not self.use_if_exists:
raise
- self.log.info("Cluster already exists.")
+ self.log.info("Cluster %s already exists.", self.cluster_name)
cluster = self._get_cluster(hook)
+
except DataprocResourceIsNotReadyError as resource_not_ready_error:
if self.num_retries_if_resource_is_not_ready:
attempt = self.num_retries_if_resource_is_not_ready
@@ -876,39 +915,28 @@ class
DataprocCreateClusterOperator(GoogleCloudBaseOperator):
self._delete_cluster(hook)
self._wait_for_cluster_in_deleting_state(hook)
raise resource_not_ready_error
- except AirflowException as ae:
- # There still could be a cluster created here in an ERROR state
which
- # should be deleted immediately rather than consuming another
retry attempt
- # (assuming delete_on_error is true (default))
- # This reduces overall the number of task attempts from 3 to 2 to
successful cluster creation
- # assuming the underlying GCE issues have resolved within that
window. Users can configure
- # a higher number of retry attempts in powers of two with 30s-60s
wait interval
+
+ except AirflowException as outer_airflow_exception:
+ # A cluster may have been created but entered ERROR state.
+ # If delete_on_error is enabled, delete it immediately so that
+ # the next retry attempt starts from a clean state.
try:
cluster = self._get_cluster(hook)
self._handle_error_state(hook, cluster)
- except AirflowException as ae_inner:
- # We could get any number of failures here, including cluster
not found and we
- # can just ignore to ensure we surface the original cluster
create failure
- self.log.exception(ae_inner)
+ except AirflowException as inner_airflow_exception:
+ # Cleanup logic may raise secondary exceptions (e.g., cluster
not found).
+ # Suppress those so that the original cluster creation failure
is surfaced.
+ self.log.exception(inner_airflow_exception)
finally:
- raise ae
+ raise outer_airflow_exception
- # Check if cluster is not in ERROR state
+ # Check if cluster is not in ERROR state.
self._handle_error_state(hook, cluster)
- if cluster.status.state == cluster.status.State.CREATING:
- # Wait for cluster to be created
- cluster = self._wait_for_cluster_in_creating_state(hook)
- self._handle_error_state(hook, cluster)
- elif cluster.status.state == cluster.status.State.DELETING:
- # Wait for cluster to be deleted
- self._wait_for_cluster_in_deleting_state(hook)
- # Create new cluster
- cluster = self._create_cluster(hook)
- self._handle_error_state(hook, cluster)
- elif cluster.status.state == cluster.status.State.STOPPED:
- # if the cluster exists and already stopped, then start the cluster
- self._start_cluster(hook)
+ # If cluster is not in RUNNING state, reconcile.
+ cluster = self._reconcile_cluster_state(hook, cluster)
+
+ self.log.info("Cluster %s is RUNNING.", self.cluster_name)
return Cluster.to_dict(cluster)
def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
index 3db5f497e4d..d284c875548 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_dataproc.py
@@ -826,7 +826,7 @@ class
TestDataprocCreateClusterOperator(DataprocClusterTestBase):
# Test whether xcom push occurs before create cluster is called
self.extra_links_manager_mock.assert_has_calls(expected_calls,
any_order=False)
- to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
+
to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
if AIRFLOW_V_3_0_PLUS:
self.mock_ti.xcom_push.assert_called_once_with(
key="dataproc_cluster",
@@ -881,7 +881,7 @@ class
TestDataprocCreateClusterOperator(DataprocClusterTestBase):
# Test whether xcom push occurs before create cluster is called
self.extra_links_manager_mock.assert_has_calls(expected_calls,
any_order=False)
- to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
+
to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)
if AIRFLOW_V_3_0_PLUS:
self.mock_ti.xcom_push.assert_called_once_with(
key="dataproc_cluster",
@@ -1015,7 +1015,7 @@ class
TestDataprocCreateClusterOperator(DataprocClusterTestBase):
mock_create_cluster.side_effect = [AlreadyExists("test"),
cluster_running]
mock_generator.return_value = [0]
- mock_get_cluster.side_effect = [cluster_deleting, NotFound("test")]
+ mock_get_cluster.side_effect = [cluster_deleting, NotFound("test"),
cluster_running]
op = DataprocCreateClusterOperator(
task_id=TASK_ID,
@@ -1035,6 +1035,180 @@ class
TestDataprocCreateClusterOperator(DataprocClusterTestBase):
to_dict_mock.assert_called_once_with(cluster_running)
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_deleting_state"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_recreates_when_deleted_during_creation(
+ self,
+ mock_hook,
+ mock_get_cluster,
+ mock_wait_for_deleting,
+ to_dict_mock,
+ ):
+ mock_hook.return_value.wait_for_operation.return_value = None
+
+ # First invocation of get_cluster should return cluster in DELETING
state.
+ cluster_deleting = mock.MagicMock()
+ cluster_deleting.status.state = cluster_deleting.status.State.DELETING
+
+ # Re-creation should return cluster in RUNNING state.
+ cluster_running = mock.MagicMock()
+ cluster_running.status.state = cluster_running.status.State.RUNNING
+
+ mock_get_cluster.side_effect = [
+ cluster_deleting,
+ cluster_running,
+ ]
+
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CONFIG,
+ deferrable=False,
+ )
+
+ op.execute(context=mock.MagicMock())
+
+ # Ensure re-creation path is traversed.
+ assert mock_wait_for_deleting.called
+ assert mock_hook.return_value.create_cluster.call_count == 2
+
+ to_dict_mock.assert_called_once_with(cluster_running)
+
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_deleting_state"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_deleting_timeout_raises(
+ self,
+ mock_hook,
+ mock_get_cluster,
+ mock_wait_for_deleting,
+ ):
+ mock_hook.return_value.wait_for_operation.return_value = None
+
+ cluster_deleting = mock.MagicMock()
+ cluster_deleting.status.state = cluster_deleting.status.State.DELETING
+
+ mock_get_cluster.return_value = cluster_deleting
+ mock_wait_for_deleting.side_effect = AirflowException("Timeout")
+
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CONFIG,
+ deferrable=False,
+ )
+
+ with pytest.raises(AirflowException):
+ op.execute(context=mock.MagicMock())
+
+ # Ensure no re-creation is attempted.
+ assert mock_hook.return_value.create_cluster.call_count == 1
+
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._wait_for_cluster_in_creating_state"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_waits_when_still_creating(
+ self,
+ mock_hook,
+ mock_get_cluster,
+ mock_wait_for_creating,
+ to_dict_mock,
+ ):
+ mock_hook.return_value.wait_for_operation.return_value = None
+
+ cluster_creating = mock.MagicMock()
+ cluster_creating.status.state = cluster_creating.status.State.CREATING
+
+ cluster_running = mock.MagicMock()
+ cluster_running.status.state = cluster_running.status.State.RUNNING
+
+ mock_get_cluster.return_value = cluster_creating
+ mock_wait_for_creating.return_value = cluster_running
+
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CONFIG,
+ deferrable=False,
+ )
+
+ op.execute(context=mock.MagicMock())
+
+ mock_wait_for_creating.assert_called_once()
+ to_dict_mock.assert_called_once_with(cluster_running)
+
+ @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._start_cluster"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_stopped_cluster_restarts(
+ self,
+ mock_hook,
+ mock_get_cluster,
+ mock_start_cluster,
+ to_dict_mock,
+ ):
+ mock_hook.return_value.wait_for_operation.return_value = None
+
+ cluster_stopped = mock.MagicMock()
+ cluster_stopped.status.state = cluster_stopped.status.State.STOPPED
+
+ mock_get_cluster.return_value = cluster_stopped
+
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CONFIG,
+ deferrable=False,
+ )
+
+ op.execute(context=mock.MagicMock())
+
+ mock_start_cluster.assert_called_once_with(mock_hook.return_value)
+ to_dict_mock.assert_called_once_with(cluster_stopped)
+
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._handle_error_state"))
+
@mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator._get_cluster"))
+ @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+ def test_execute_error_state_after_wait_for_completion(
+ self,
+ mock_hook,
+ mock_get_cluster,
+ mock_handle_error,
+ ):
+ mock_hook.return_value.wait_for_operation.return_value = None
+
+ cluster_error = mock.MagicMock()
+ cluster_error.status.state = cluster_error.status.State.ERROR
+
+ mock_get_cluster.return_value = cluster_error
+ mock_handle_error.side_effect = AirflowException("Cluster error")
+
+ op = DataprocCreateClusterOperator(
+ task_id=TASK_ID,
+ region=GCP_REGION,
+ project_id=GCP_PROJECT,
+ cluster_name=CLUSTER_NAME,
+ cluster_config=CONFIG,
+ deferrable=False,
+ )
+
+ with pytest.raises(AirflowException):
+ op.execute(context=mock.MagicMock())
+
+ mock_handle_error.assert_called_once()
+
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
def test_create_execute_call_defer_method(self, mock_trigger_hook,
mock_hook):