This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new e07a42e69d Check cluster state before defer Dataproc operators to trigger (#36892)
e07a42e69d is described below

commit e07a42e69d1ab472c4da991fca5782990607ebe0
Author: Wei Lee <weilee...@gmail.com>
AuthorDate: Mon Jan 22 14:32:00 2024 +0800

    Check cluster state before defer Dataproc operators to trigger (#36892)

    While operating a Dataproc cluster in deferrable mode, the condition might
    already be met (created, deleted, updated) before the task is deferred to
    the trigger. This PR checks the cluster status before deferring the task
    to the trigger.

    ---------

    Co-authored-by: Pankaj Koti <pankajkoti...@gmail.com>
---
 .../providers/google/cloud/operators/dataproc.py   |  63 ++++++---
 .../google/cloud/operators/test_dataproc.py        | 146 ++++++++++++++++++++-
 2 files changed, 185 insertions(+), 24 deletions(-)

diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 306e0dc03d..b14121139d 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -721,6 +721,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
     def execute(self, context: Context) -> dict:
         self.log.info("Creating cluster: %s", self.cluster_name)
         hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+
         # Save data required to display extra link no matter what the cluster status will be
         project_id = self.project_id or hook.project_id
         if project_id:
@@ -731,6 +732,7 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                 project_id=project_id,
                 region=self.region,
             )
+
         try:
             # First try to create a new cluster
             operation = self._create_cluster(hook)
@@ -741,17 +743,24 @@ class DataprocCreateClusterOperator(GoogleCloudBaseOperator):
                 self.log.info("Cluster created.")
                 return Cluster.to_dict(cluster)
             else:
-                self.defer(
-                    trigger=DataprocClusterTrigger(
-                        cluster_name=self.cluster_name,
-                        project_id=self.project_id,
-                        region=self.region,
-                        gcp_conn_id=self.gcp_conn_id,
-                        impersonation_chain=self.impersonation_chain,
-                        polling_interval_seconds=self.polling_interval_seconds,
-                    ),
-                    method_name="execute_complete",
+                cluster = hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
                 )
+                if cluster.status.state == cluster.status.State.RUNNING:
+                    self.log.info("Cluster created.")
+                    return Cluster.to_dict(cluster)
+                else:
+                    self.defer(
+                        trigger=DataprocClusterTrigger(
+                            cluster_name=self.cluster_name,
+                            project_id=self.project_id,
+                            region=self.region,
+                            gcp_conn_id=self.gcp_conn_id,
+                            impersonation_chain=self.impersonation_chain,
+                            polling_interval_seconds=self.polling_interval_seconds,
+                        ),
+                        method_name="execute_complete",
+                    )
         except AlreadyExists:
             if not self.use_if_exists:
                 raise
@@ -1016,6 +1025,16 @@ class DataprocDeleteClusterOperator(GoogleCloudBaseOperator):
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
             self.log.info("Cluster deleted.")
         else:
+            try:
+                hook.get_cluster(
+                    project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
+                )
+            except NotFound:
+                self.log.info("Cluster deleted.")
+                return
+            except Exception as e:
+                raise AirflowException(str(e))
+
             end_time: float = time.time() + self.timeout
             self.defer(
                 trigger=DataprocDeleteClusterTrigger(
@@ -2480,17 +2499,21 @@ class DataprocUpdateClusterOperator(GoogleCloudBaseOperator):
         if not self.deferrable:
             hook.wait_for_operation(timeout=self.timeout, result_retry=self.retry, operation=operation)
         else:
-            self.defer(
-                trigger=DataprocClusterTrigger(
-                    cluster_name=self.cluster_name,
-                    project_id=self.project_id,
-                    region=self.region,
-                    gcp_conn_id=self.gcp_conn_id,
-                    impersonation_chain=self.impersonation_chain,
-                    polling_interval_seconds=self.polling_interval_seconds,
-                ),
-                method_name="execute_complete",
+            cluster = hook.get_cluster(
+                project_id=self.project_id, region=self.region, cluster_name=self.cluster_name
             )
+            if cluster.status.state != cluster.status.State.RUNNING:
+                self.defer(
+                    trigger=DataprocClusterTrigger(
+                        cluster_name=self.cluster_name,
+                        project_id=self.project_id,
+                        region=self.region,
+                        gcp_conn_id=self.gcp_conn_id,
+                        impersonation_chain=self.impersonation_chain,
+                        polling_interval_seconds=self.polling_interval_seconds,
+                    ),
+                    method_name="execute_complete",
+                )
         self.log.info("Updated %s cluster.", self.cluster_name)
 
     def execute_complete(self, context: Context, event: dict[str, Any]) -> Any:
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 59a9c1008c..00f45ca8b3 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -23,7 +23,8 @@ from unittest.mock import MagicMock, Mock, call
 import pytest
 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry
-from google.cloud.dataproc_v1 import Batch, JobStatus
+from google.cloud import dataproc
+from google.cloud.dataproc_v1 import Batch, Cluster, JobStatus
 
 from airflow.exceptions import (
     AirflowException,
@@ -579,7 +580,7 @@ class TestsClusterGenerator:
         assert CONFIG_WITH_FLEX_MIG == cluster
 
 
-class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
+class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
     def test_deprecation_warning(self):
         with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
             op = DataprocCreateClusterOperator(
@@ -883,6 +884,54 @@ class TestDataprocClusterCreateOperator(DataprocClusterTestBase):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocCreateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_create_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.create_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocCreateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            delete_on_error=True,
+            metadata=METADATA,
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+        assert not mock_defer.called
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.create_cluster.assert_called_once_with(
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_config=CONFIG,
+            request_id=None,
+            labels=LABELS,
+            cluster_name=CLUSTER_NAME,
+            virtual_cluster_config=None,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag
@@ -1100,6 +1149,47 @@ class TestDataprocClusterDeleteOperator:
         assert isinstance(exc.value.trigger, DataprocDeleteClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocDeleteClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_delete_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        mock_hook.return_value.create_cluster.return_value = None
+        mock_hook.return_value.get_cluster.side_effect = NotFound("test")
+        operator = DataprocDeleteClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            project_id=GCP_PROJECT,
+            cluster_name=CLUSTER_NAME,
+            request_id=REQUEST_ID,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+
+        mock_hook.return_value.delete_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster_uuid=None,
+            request_id=REQUEST_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 class TestDataprocSubmitJobOperator(DataprocJobTestBase):
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1240,8 +1330,8 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocSubmitJobOperator.defer")
-    @mock.patch("airflow.providers.google.cloud.operators.dataproc.DataprocHook.submit_job")
+    @mock.patch(DATAPROC_PATH.format("DataprocSubmitJobOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook.submit_job"))
     def test_dataproc_operator_execute_async_done_before_defer(self, mock_submit_job, mock_defer, mock_hook):
         mock_submit_job.return_value.reference.job_id = TEST_JOB_ID
         job_status = mock_hook.return_value.get_job.return_value.status
@@ -1498,6 +1588,54 @@ class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
         assert isinstance(exc.value.trigger, DataprocClusterTrigger)
         assert exc.value.method_name == GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
 
+    @mock.patch(DATAPROC_PATH.format("DataprocUpdateClusterOperator.defer"))
+    @mock.patch(DATAPROC_PATH.format("DataprocHook"))
+    @mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
+    def test_update_execute_call_finished_before_defer(self, mock_trigger_hook, mock_hook, mock_defer):
+        cluster = Cluster(
+            cluster_name="test_cluster",
+            status=dataproc.ClusterStatus(state=dataproc.ClusterStatus.State.RUNNING),
+        )
+        mock_hook.return_value.update_cluster.return_value = cluster
+        mock_hook.return_value.get_cluster.return_value = cluster
+        operator = DataprocUpdateClusterOperator(
+            task_id=TASK_ID,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            project_id=GCP_PROJECT,
+            gcp_conn_id=GCP_CONN_ID,
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+            impersonation_chain=IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+
+        operator.execute(mock.MagicMock())
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=GCP_CONN_ID,
+            impersonation_chain=IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.update_cluster.assert_called_once_with(
+            project_id=GCP_PROJECT,
+            region=GCP_REGION,
+            cluster_name=CLUSTER_NAME,
+            cluster=CLUSTER,
+            update_mask=UPDATE_MASK,
+            request_id=REQUEST_ID,
+            graceful_decommission_timeout={"graceful_decommission_timeout": "600s"},
+            retry=RETRY,
+            timeout=TIMEOUT,
+            metadata=METADATA,
+        )
+        mock_hook.return_value.wait_for_operation.assert_not_called()
+        assert not mock_defer.called
+
 
 @pytest.mark.db_test
 @pytest.mark.need_serialized_dag