This is an automated email from the ASF dual-hosted git repository.

turbaszek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
     new 309788e  Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
309788e is described below

commit 309788e5e2023c598095a4ee00df417d94b6a5df
Author: Tomek Urbaszek <turbas...@gmail.com>
AuthorDate: Mon Jan 18 17:49:19 2021 +0100

    Refactor DataprocOperators to support google-cloud-dataproc 2.0 (#13256)
---
 airflow/providers/google/ADDITIONAL_INFO.md         |   2 +
 airflow/providers/google/cloud/hooks/dataproc.py    | 104 ++++++++---------
 .../providers/google/cloud/operators/dataproc.py    |  30 +++--
 airflow/providers/google/cloud/sensors/dataproc.py  |  12 +-
 setup.py                                            |   2 +-
 .../providers/google/cloud/hooks/test_dataproc.py   | 129 ++++++++++++---------
 .../google/cloud/operators/test_dataproc.py         |  14 ++-
 .../google/cloud/sensors/test_dataproc.py           |   8 +-
 8 files changed, 157 insertions(+), 144 deletions(-)

diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md
index c696e1b..16a6683 100644
--- a/airflow/providers/google/ADDITIONAL_INFO.md
+++ b/airflow/providers/google/ADDITIONAL_INFO.md
@@ -32,11 +32,13 @@ Details are covered in the UPDATING.md files for each library, but there are som
 | [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) |
 | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) |
 | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) |
+| [``google-cloud-dataproc``](https://pypi.org/project/google-cloud-dataproc/) | ``>=1.0.1,<2.0.0`` | ``>=2.2.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-dataproc/blob/master/UPGRADING.md) |
 | [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) |
 | [``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) |
 | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) |
 | [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) |
+
 ### The field names use the snake_case convention

 If your DAG uses an object from the above mentioned libraries passed by XCom, it is necessary to update the naming convention of the fields that are read. Previously, the fields used the CamelSnake convention, now the snake_case convention is used.
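For DAG authors, the snake_case note above boils down to renaming dictionary keys when reading Dataproc objects out of XCom. A minimal sketch, assuming a hypothetical upstream task id of "create_cluster" (not part of this commit):

    def read_cluster_name(ti):
        # Cluster dict pushed to XCom by DataprocCreateClusterOperator.
        cluster = ti.xcom_pull(task_ids="create_cluster")  # hypothetical task id
        # google-cloud-dataproc < 2.0 serialized proto fields in lowerCamelCase,
        # e.g. cluster["clusterName"]; >= 2.0 produces snake_case keys.
        return cluster["cluster_name"]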
diff --git a/airflow/providers/google/cloud/hooks/dataproc.py b/airflow/providers/google/cloud/hooks/dataproc.py
index 12d5941..35d4786 100644
--- a/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/airflow/providers/google/cloud/hooks/dataproc.py
@@ -26,18 +26,16 @@ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union
 from google.api_core.exceptions import ServerError
 from google.api_core.retry import Retry
 from google.cloud.dataproc_v1beta2 import (  # pylint: disable=no-name-in-module
-    ClusterControllerClient,
-    JobControllerClient,
-    WorkflowTemplateServiceClient,
-)
-from google.cloud.dataproc_v1beta2.types import (  # pylint: disable=no-name-in-module
     Cluster,
-    Duration,
-    FieldMask,
+    ClusterControllerClient,
     Job,
+    JobControllerClient,
     JobStatus,
     WorkflowTemplate,
+    WorkflowTemplateServiceClient,
 )
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask

 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
@@ -291,10 +289,12 @@ class DataprocHook(GoogleBaseHook):
         client = self.get_cluster_client(location=region)
         result = client.create_cluster(
-            project_id=project_id,
-            region=region,
-            cluster=cluster,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster': cluster,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -340,11 +340,13 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.delete_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
-            cluster_uuid=cluster_uuid,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': region,
+                'cluster_name': cluster_name,
+                'cluster_uuid': cluster_uuid,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -382,9 +384,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         operation = client.diagnose_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -423,9 +423,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.get_cluster(
-            project_id=project_id,
-            region=region,
-            cluster_name=cluster_name,
+            request={'project_id': project_id, 'region': region, 'cluster_name': cluster_name},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -467,10 +465,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=region)
         result = client.list_clusters(
-            project_id=project_id,
-            region=region,
-            filter_=filter_,
-            page_size=page_size,
+            request={'project_id': project_id, 'region': region, 'filter': filter_, 'page_size': page_size},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -551,13 +546,15 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_cluster_client(location=location)
         operation = client.update_cluster(
-            project_id=project_id,
-            region=location,
-            cluster_name=cluster_name,
-            cluster=cluster,
-            update_mask=update_mask,
-            graceful_decommission_timeout=graceful_decommission_timeout,
-            request_id=request_id,
+            request={
+                'project_id': project_id,
+                'region': location,
+                'cluster_name': cluster_name,
+                'cluster': cluster,
+                'update_mask': update_mask,
+                'graceful_decommission_timeout': graceful_decommission_timeout,
+                'request_id': request_id,
+            },
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -593,10 +590,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         return client.create_workflow_template(
-            parent=parent, template=template, retry=retry, timeout=timeout, metadata=metadata
+            request={'parent': parent, 'template': template}, retry=retry, timeout=timeout, metadata=metadata
         )

     @GoogleBaseHook.fallback_to_default_project_id
@@ -643,13 +641,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        name = client.workflow_template_path(project_id, location, template_name)
+        name = f'projects/{project_id}/regions/{location}/workflowTemplates/{template_name}'
         operation = client.instantiate_workflow_template(
-            name=name,
-            version=version,
-            parameters=parameters,
-            request_id=request_id,
+            request={'name': name, 'version': version, 'request_id': request_id, 'parameters': parameters},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -690,12 +686,11 @@ class DataprocHook(GoogleBaseHook):
         :param metadata: Additional metadata that is provided to the method.
         :type metadata: Sequence[Tuple[str, str]]
         """
+        metadata = metadata or ()
         client = self.get_template_client(location)
-        parent = client.region_path(project_id, location)
+        parent = f'projects/{project_id}/regions/{location}'
         operation = client.instantiate_inline_workflow_template(
-            parent=parent,
-            template=template,
-            request_id=request_id,
+            request={'parent': parent, 'template': template, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -722,19 +717,19 @@ class DataprocHook(GoogleBaseHook):
         """
         state = None
         start = time.monotonic()
-        while state not in (JobStatus.ERROR, JobStatus.DONE, JobStatus.CANCELLED):
+        while state not in (JobStatus.State.ERROR, JobStatus.State.DONE, JobStatus.State.CANCELLED):
             if timeout and start + timeout < time.monotonic():
                 raise AirflowException(f"Timeout: dataproc job {job_id} is not ready after {timeout}s")
             time.sleep(wait_time)
             try:
-                job = self.get_job(location=location, job_id=job_id, project_id=project_id)
+                job = self.get_job(project_id=project_id, location=location, job_id=job_id)
                 state = job.status.state
             except ServerError as err:
                 self.log.info("Retrying. Dataproc API returned server error when waiting for job: %s", err)

-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        if state == JobStatus.CANCELLED:
+        if state == JobStatus.State.CANCELLED:
             raise AirflowException(f'Job was cancelled:\n{job}')

     @GoogleBaseHook.fallback_to_default_project_id
@@ -767,9 +762,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_job_client(location=location)
         job = client.get_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -812,10 +805,7 @@ class DataprocHook(GoogleBaseHook):
         """
         client = self.get_job_client(location=location)
         return client.submit_job(
-            project_id=project_id,
-            region=location,
-            job=job,
-            request_id=request_id,
+            request={'project_id': project_id, 'region': location, 'job': job, 'request_id': request_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
@@ -884,9 +874,7 @@ class DataprocHook(GoogleBaseHook):
         client = self.get_job_client(location=location)

         job = client.cancel_job(
-            project_id=project_id,
-            region=location,
-            job_id=job_id,
+            request={'project_id': project_id, 'region': location, 'job_id': job_id},
             retry=retry,
             timeout=timeout,
             metadata=metadata,
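The pattern repeated throughout the hook above is the google-cloud-dataproc 2.x calling convention: client methods take a single `request` mapping (or request proto) alongside `retry`/`timeout`/`metadata`, instead of flat keyword arguments, and resource paths such as `parent` and `name` are built as plain f-strings rather than via `region_path()` helpers. A rough sketch of the change against the raw client, assuming application-default credentials and placeholder project/region/cluster names:

    from google.cloud.dataproc_v1beta2 import ClusterControllerClient

    client = ClusterControllerClient()

    # 1.x style (no longer accepted):
    #   client.get_cluster(project_id="my-project", region="us-central1",
    #                      cluster_name="my-cluster")

    # 2.x style: one request dict; note that list_clusters likewise takes
    # a plain 'filter' key instead of the old `filter_` keyword argument.
    cluster = client.get_cluster(
        request={
            "project_id": "my-project",    # placeholder
            "region": "us-central1",       # placeholder
            "cluster_name": "my-cluster",  # placeholder
        }
    )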
diff --git a/airflow/providers/google/cloud/operators/dataproc.py b/airflow/providers/google/cloud/operators/dataproc.py
index 839a624..a7d1379 100644
--- a/airflow/providers/google/cloud/operators/dataproc.py
+++ b/airflow/providers/google/cloud/operators/dataproc.py
@@ -17,7 +17,6 @@
 # under the License.
 #
 """This module contains Google Dataproc operators."""
-# pylint: disable=C0302

 import inspect
 import ntpath
@@ -31,12 +30,9 @@ from typing import Dict, List, Optional, Sequence, Set, Tuple, Union

 from google.api_core.exceptions import AlreadyExists, NotFound
 from google.api_core.retry import Retry, exponential_sleep_generator
-from google.cloud.dataproc_v1beta2.types import (  # pylint: disable=no-name-in-module
-    Cluster,
-    Duration,
-    FieldMask,
-)
-from google.protobuf.json_format import MessageToDict
+from google.cloud.dataproc_v1beta2 import Cluster  # pylint: disable=no-name-in-module
+from google.protobuf.duration_pb2 import Duration
+from google.protobuf.field_mask_pb2 import FieldMask

 from airflow.exceptions import AirflowException
 from airflow.models import BaseOperator
@@ -562,7 +558,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         )

     def _handle_error_state(self, hook: DataprocHook, cluster: Cluster) -> None:
-        if cluster.status.state != cluster.status.ERROR:
+        if cluster.status.state != cluster.status.State.ERROR:
             return
         self.log.info("Cluster is in ERROR state")
         gcs_uri = hook.diagnose_cluster(
@@ -590,7 +586,7 @@ class DataprocCreateClusterOperator(BaseOperator):
         time_left = self.timeout
         cluster = self._get_cluster(hook)
         for time_to_sleep in exponential_sleep_generator(initial=10, maximum=120):
-            if cluster.status.state != cluster.status.CREATING:
+            if cluster.status.state != cluster.status.State.CREATING:
                 break
             if time_left < 0:
                 raise AirflowException(f"Cluster {self.cluster_name} is still CREATING state, aborting")
@@ -613,18 +609,18 @@ class DataprocCreateClusterOperator(BaseOperator):

         # Check if cluster is not in ERROR state
         self._handle_error_state(hook, cluster)
-        if cluster.status.state == cluster.status.CREATING:
+        if cluster.status.state == cluster.status.State.CREATING:
             # Wait for cluster to be be created
             cluster = self._wait_for_cluster_in_creating_state(hook)
             self._handle_error_state(hook, cluster)
-        elif cluster.status.state == cluster.status.DELETING:
+        elif cluster.status.state == cluster.status.State.DELETING:
             # Wait for cluster to be deleted
             self._wait_for_cluster_in_deleting_state(hook)
             # Create new cluster
             cluster = self._create_cluster(hook)
             self._handle_error_state(hook, cluster)

-        return MessageToDict(cluster)
+        return Cluster.to_dict(cluster)


 class DataprocScaleClusterOperator(BaseOperator):
@@ -1855,7 +1851,7 @@ class DataprocSubmitJobOperator(BaseOperator):
     :type wait_timeout: int
     """

-    template_fields = ('project_id', 'location', 'job', 'impersonation_chain')
+    template_fields = ('project_id', 'location', 'job', 'impersonation_chain', 'request_id')
     template_fields_renderers = {"job": "json"}

     @apply_defaults
@@ -1941,14 +1937,14 @@ class DataprocUpdateClusterOperator(BaseOperator):
         example, to change the number of workers in a cluster to 5, the ``update_mask`` parameter
         would be specified as ``config.worker_config.num_instances``, and the ``PATCH`` request body
         would specify the new value. If a dict is provided, it must be of the same form as the protobuf
        message
-        :class:`~google.cloud.dataproc_v1beta2.types.FieldMask`
-    :type update_mask: Union[Dict, google.cloud.dataproc_v1beta2.types.FieldMask]
+        :class:`~google.protobuf.field_mask_pb2.FieldMask`
+    :type update_mask: Union[Dict, google.protobuf.field_mask_pb2.FieldMask]
     :param graceful_decommission_timeout: Optional. Timeout for graceful YARN decommissioning. Graceful
         decommissioning allows removing nodes from the cluster without interrupting jobs in progress.
         Timeout specifies how long to wait for jobs in progress to finish before forcefully removing nodes
         (and potentially interrupting jobs). Default timeout is 0 (for forceful decommission), and the
         maximum allowed timeout is 1 day.
-    :type graceful_decommission_timeout: Union[Dict, google.cloud.dataproc_v1beta2.types.Duration]
+    :type graceful_decommission_timeout: Union[Dict, google.protobuf.duration_pb2.Duration]
     :param request_id: Optional. A unique id used to identify the request. If the server receives two
         ``UpdateClusterRequest`` requests with the same id, then the second request will be ignored and the
         first ``google.longrunning.Operation`` created and stored in the backend is returned.
@@ -1974,7 +1970,7 @@ class DataprocUpdateClusterOperator(BaseOperator):
     :type impersonation_chain: Union[str, Sequence[str]]
     """

-    template_fields = ('impersonation_chain',)
+    template_fields = ('impersonation_chain', 'cluster_name')

     @apply_defaults
     def __init__(  # pylint: disable=too-many-arguments
diff --git a/airflow/providers/google/cloud/sensors/dataproc.py b/airflow/providers/google/cloud/sensors/dataproc.py
index 1777a22..93656df 100644
--- a/airflow/providers/google/cloud/sensors/dataproc.py
+++ b/airflow/providers/google/cloud/sensors/dataproc.py
@@ -65,14 +65,18 @@ class DataprocJobSensor(BaseSensorOperator):
         job = hook.get_job(job_id=self.dataproc_job_id, location=self.location, project_id=self.project_id)
         state = job.status.state

-        if state == JobStatus.ERROR:
+        if state == JobStatus.State.ERROR:
             raise AirflowException(f'Job failed:\n{job}')
-        elif state in {JobStatus.CANCELLED, JobStatus.CANCEL_PENDING, JobStatus.CANCEL_STARTED}:
+        elif state in {
+            JobStatus.State.CANCELLED,
+            JobStatus.State.CANCEL_PENDING,
+            JobStatus.State.CANCEL_STARTED,
+        }:
             raise AirflowException(f'Job was cancelled:\n{job}')
-        elif JobStatus.DONE == state:
+        elif JobStatus.State.DONE == state:
             self.log.debug("Job %s completed successfully.", self.dataproc_job_id)
             return True
-        elif JobStatus.ATTEMPT_FAILURE == state:
+        elif JobStatus.State.ATTEMPT_FAILURE == state:
             self.log.debug("Job %s attempt has failed.", self.dataproc_job_id)

         self.log.info("Waiting for job %s to complete.", self.dataproc_job_id)
diff --git a/setup.py b/setup.py
index eba4a7a..da29cd1 100644
--- a/setup.py
+++ b/setup.py
@@ -286,7 +286,7 @@ google = [
     'google-cloud-bigtable>=1.0.0,<2.0.0',
     'google-cloud-container>=0.1.1,<2.0.0',
     'google-cloud-datacatalog>=3.0.0,<4.0.0',
-    'google-cloud-dataproc>=1.0.1,<2.0.0',
+    'google-cloud-dataproc>=2.2.0,<3.0.0',
     'google-cloud-dlp>=0.11.0,<2.0.0',
     'google-cloud-kms>=2.0.0,<3.0.0',
     'google-cloud-language>=1.1.1,<2.0.0',
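Two further 2.x changes surface in the operator and sensor hunks above: job and cluster states moved into nested `State` enums (`JobStatus.State.DONE` rather than `JobStatus.DONE`), and proto-plus messages serialize through the class method `Cluster.to_dict(cluster)` rather than `MessageToDict`. A small illustrative sketch, not taken from the commit itself:

    from google.cloud.dataproc_v1beta2 import Cluster, JobStatus

    def job_is_terminal(state: "JobStatus.State") -> bool:
        # The enum is nested one level deeper than in the 1.x client.
        return state in (
            JobStatus.State.DONE,
            JobStatus.State.ERROR,
            JobStatus.State.CANCELLED,
        )

    # Serialization now goes through the proto-plus class method, which is
    # what produces the snake_case keys mentioned in ADDITIONAL_INFO.md.
    cluster_dict = Cluster.to_dict(Cluster(cluster_name="example"))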
diff --git a/tests/providers/google/cloud/hooks/test_dataproc.py b/tests/providers/google/cloud/hooks/test_dataproc.py
index d09c91e..6842acc 100644
--- a/tests/providers/google/cloud/hooks/test_dataproc.py
+++ b/tests/providers/google/cloud/hooks/test_dataproc.py
@@ -20,7 +20,7 @@ import unittest
 from unittest import mock

 import pytest
-from google.cloud.dataproc_v1beta2.types import JobStatus  # pylint: disable=no-name-in-module
+from google.cloud.dataproc_v1beta2 import JobStatus  # pylint: disable=no-name-in-module

 from airflow.exceptions import AirflowException
 from airflow.providers.google.cloud.hooks.dataproc import DataprocHook, DataProcJobBuilder
@@ -43,8 +43,6 @@ CLUSTER = {
     "project_id": GCP_PROJECT,
 }

-PARENT = "parent"
-NAME = "name"
 BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
 DATAPROC_STRING = "airflow.providers.google.cloud.hooks.dataproc.{}"

@@ -113,11 +111,13 @@ class TestDataprocHook(unittest.TestCase):
         )
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.create_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster=CLUSTER,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster=CLUSTER,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -127,12 +127,14 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.delete_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.delete_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
-            cluster_uuid=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+                cluster_uuid=None,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -142,9 +144,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.diagnose_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.diagnose_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -156,9 +160,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.get_cluster(project_id=GCP_PROJECT, region=GCP_LOCATION, cluster_name=CLUSTER_NAME)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.get_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster_name=CLUSTER_NAME,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster_name=CLUSTER_NAME,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -171,10 +177,12 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.list_clusters(project_id=GCP_PROJECT, region=GCP_LOCATION, filter_=filter_)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.list_clusters.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            filter_=filter_,
-            page_size=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                filter=filter_,
+                page_size=None,
+            ),
             metadata=None,
             retry=None,
             timeout=None,
@@ -192,14 +200,16 @@ class TestDataprocHook(unittest.TestCase):
         )
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.update_cluster.assert_called_once_with(
-            project_id=GCP_PROJECT,
-            region=GCP_LOCATION,
-            cluster=CLUSTER,
-            cluster_name=CLUSTER_NAME,
-            update_mask=update_mask,
-            graceful_decommission_timeout=None,
+            request=dict(
+                project_id=GCP_PROJECT,
+                region=GCP_LOCATION,
+                cluster=CLUSTER,
+                cluster_name=CLUSTER_NAME,
+                update_mask=update_mask,
+                graceful_decommission_timeout=None,
+                request_id=None,
+            ),
             metadata=None,
-            request_id=None,
             retry=None,
             timeout=None,
         )
@@ -207,44 +217,45 @@ class TestDataprocHook(unittest.TestCase):
     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_create_workflow_template(self, mock_client):
         template = {"test": "test"}
-        mock_client.return_value.region_path.return_value = PARENT
+        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
         self.hook.create_workflow_template(location=GCP_LOCATION, template=template, project_id=GCP_PROJECT)
-        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
         mock_client.return_value.create_workflow_template.assert_called_once_with(
-            parent=PARENT, template=template, retry=None, timeout=None, metadata=None
+            request=dict(parent=parent, template=template), retry=None, timeout=None, metadata=()
         )

     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_instantiate_workflow_template(self, mock_client):
         template_name = "template_name"
-        mock_client.return_value.workflow_template_path.return_value = NAME
+        name = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}/workflowTemplates/{template_name}'
         self.hook.instantiate_workflow_template(
             location=GCP_LOCATION, template_name=template_name, project_id=GCP_PROJECT
         )
-        mock_client.return_value.workflow_template_path.assert_called_once_with(
-            GCP_PROJECT, GCP_LOCATION, template_name
-        )
         mock_client.return_value.instantiate_workflow_template.assert_called_once_with(
-            name=NAME, version=None, parameters=None, request_id=None, retry=None, timeout=None, metadata=None
+            request=dict(name=name, version=None, parameters=None, request_id=None),
+            retry=None,
+            timeout=None,
+            metadata=(),
         )

     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_template_client"))
     def test_instantiate_inline_workflow_template(self, mock_client):
         template = {"test": "test"}
-        mock_client.return_value.region_path.return_value = PARENT
+        parent = f'projects/{GCP_PROJECT}/regions/{GCP_LOCATION}'
         self.hook.instantiate_inline_workflow_template(
             location=GCP_LOCATION, template=template, project_id=GCP_PROJECT
         )
-        mock_client.return_value.region_path.assert_called_once_with(GCP_PROJECT, GCP_LOCATION)
         mock_client.return_value.instantiate_inline_workflow_template.assert_called_once_with(
-            parent=PARENT, template=template, request_id=None, retry=None, timeout=None, metadata=None
+            request=dict(parent=parent, template=template, request_id=None),
+            retry=None,
+            timeout=None,
+            metadata=(),
         )

     @mock.patch(DATAPROC_STRING.format("DataprocHook.get_job"))
     def test_wait_for_job(self, mock_get_job):
         mock_get_job.side_effect = [
-            mock.MagicMock(status=mock.MagicMock(state=JobStatus.RUNNING)),
-            mock.MagicMock(status=mock.MagicMock(state=JobStatus.ERROR)),
+            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.RUNNING)),
+            mock.MagicMock(status=mock.MagicMock(state=JobStatus.State.ERROR)),
         ]
         with pytest.raises(AirflowException):
             self.hook.wait_for_job(job_id=JOB_ID, location=GCP_LOCATION, project_id=GCP_PROJECT, wait_time=0)
@@ -259,9 +270,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.get_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.get_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region=GCP_LOCATION,
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -272,10 +285,12 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.submit_job(location=GCP_LOCATION, job=JOB, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.submit_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job=JOB,
-            project_id=GCP_PROJECT,
-            request_id=None,
+            request=dict(
+                region=GCP_LOCATION,
+                job=JOB,
+                project_id=GCP_PROJECT,
+                request_id=None,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -297,9 +312,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.cancel_job(location=GCP_LOCATION, job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location=GCP_LOCATION)
         mock_client.return_value.cancel_job.assert_called_once_with(
-            region=GCP_LOCATION,
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region=GCP_LOCATION,
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
@@ -311,9 +328,11 @@ class TestDataprocHook(unittest.TestCase):
         self.hook.cancel_job(job_id=JOB_ID, project_id=GCP_PROJECT)
         mock_client.assert_called_once_with(location='global')
         mock_client.return_value.cancel_job.assert_called_once_with(
-            region='global',
-            job_id=JOB_ID,
-            project_id=GCP_PROJECT,
+            request=dict(
+                region='global',
+                job_id=JOB_ID,
+                project_id=GCP_PROJECT,
+            ),
             retry=None,
             timeout=None,
             metadata=None,
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 8c06ef7..e1c712e 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -217,8 +217,9 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
             assert_warning("Default region value", warnings)
         assert op_default_region.region == 'global'

+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute(self, mock_hook):
+    def test_execute(self, mock_hook, to_dict_mock):
         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
             region=GCP_LOCATION,
@@ -246,9 +247,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(mock_hook().create_cluster().result())

+    @mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
-    def test_execute_if_cluster_exists(self, mock_hook):
+    def test_execute_if_cluster_exists(self, mock_hook, to_dict_mock):
         mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
         mock_hook.return_value.get_cluster.return_value.status.state = 0
         op = DataprocCreateClusterOperator(
@@ -286,6 +289,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
             timeout=TIMEOUT,
             metadata=METADATA,
         )
+        to_dict_mock.assert_called_once_with(mock_hook.return_value.get_cluster.return_value)

     @mock.patch(DATAPROC_PATH.format("DataprocHook"))
     def test_execute_if_cluster_exists_do_not_use(self, mock_hook):
@@ -313,7 +317,7 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
         mock_hook.return_value.create_cluster.side_effect = [AlreadyExists("test")]
         cluster_status = mock_hook.return_value.get_cluster.return_value.status
         cluster_status.state = 0
-        cluster_status.ERROR = 0
+        cluster_status.State.ERROR = 0

         op = DataprocCreateClusterOperator(
             task_id=TASK_ID,
@@ -348,11 +352,11 @@ class TestDataprocClusterCreateOperator(unittest.TestCase):
     ):
         cluster = mock.MagicMock()
         cluster.status.state = 0
-        cluster.status.DELETING = 0
+        cluster.status.State.DELETING = 0  # pylint: disable=no-member

         cluster2 = mock.MagicMock()
         cluster2.status.state = 0
-        cluster2.status.ERROR = 0
+        cluster2.status.State.ERROR = 0  # pylint: disable=no-member

         mock_create_cluster.side_effect = [AlreadyExists("test"), cluster2]
         mock_generator.return_value = [0]
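The test updates follow mechanically from the hook change: assertions now expect one `request` mapping instead of separate keyword arguments. A condensed sketch of the assertion pattern outside the Airflow test harness, with throwaway values:

    from unittest import mock

    client = mock.MagicMock()
    client.get_job(request={"project_id": "p", "region": "r", "job_id": "j"})

    # The whole request mapping is asserted as a single argument now.
    client.get_job.assert_called_once_with(
        request={"project_id": "p", "region": "r", "job_id": "j"}
    )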
"job_id" mock_hook.return_value.get_job.return_value = job @@ -88,7 +88,7 @@ class TestDataprocJobSensor(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_wait(self, mock_hook): - job = self.create_job(JobStatus.RUNNING) + job = self.create_job(JobStatus.State.RUNNING) job_id = "job_id" mock_hook.return_value.get_job.return_value = job @@ -109,7 +109,7 @@ class TestDataprocJobSensor(unittest.TestCase): @mock.patch(DATAPROC_PATH.format("DataprocHook")) def test_cancelled(self, mock_hook): - job = self.create_job(JobStatus.CANCELLED) + job = self.create_job(JobStatus.State.CANCELLED) job_id = "job_id" mock_hook.return_value.get_job.return_value = job