This is an automated email from the ASF dual-hosted git repository. kamilbregula pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push: new ef8617e Support google-cloud-tasks>=2.0.0 (#13347) ef8617e is described below commit ef8617ec9d6e4b7c433a29bd388f5102a7a17c11 Author: Kamil Breguła <mik-...@users.noreply.github.com> AuthorDate: Thu Jan 14 12:18:49 2021 +0100 Support google-cloud-tasks>=2.0.0 (#13347) --- airflow/providers/google/ADDITIONAL_INFO.md | 4 +- airflow/providers/google/cloud/hooks/tasks.py | 118 +++++++++++---------- airflow/providers/google/cloud/operators/tasks.py | 39 ++++--- setup.py | 2 +- tests/providers/google/cloud/hooks/test_tasks.py | 86 +++++++-------- .../providers/google/cloud/operators/test_tasks.py | 65 ++++++++++-- 6 files changed, 176 insertions(+), 138 deletions(-) diff --git a/airflow/providers/google/ADDITIONAL_INFO.md b/airflow/providers/google/ADDITIONAL_INFO.md index 800703b..c696e1b 100644 --- a/airflow/providers/google/ADDITIONAL_INFO.md +++ b/airflow/providers/google/ADDITIONAL_INFO.md @@ -32,10 +32,10 @@ Details are covered in the UPDATING.md files for each library, but there are som | [``google-cloud-automl``](https://pypi.org/project/google-cloud-automl/) | ``>=0.4.0,<2.0.0`` | ``>=2.1.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-automl/blob/master/UPGRADING.md) | | [``google-cloud-bigquery-datatransfer``](https://pypi.org/project/google-cloud-bigquery-datatransfer/) | ``>=0.4.0,<2.0.0`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-bigquery-datatransfer/blob/master/UPGRADING.md) | | [``google-cloud-datacatalog``](https://pypi.org/project/google-cloud-datacatalog/) | ``>=0.5.0,<0.8`` | ``>=3.0.0,<4.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-datacatalog/blob/master/UPGRADING.md) | +| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) | | 
[``google-cloud-os-login``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-oslogin/blob/master/UPGRADING.md) | | [``google-cloud-pubsub``](https://pypi.org/project/google-cloud-pubsub/) | ``>=1.0.0,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-pubsub/blob/master/UPGRADING.md) | -| [``google-cloud-kms``](https://pypi.org/project/google-cloud-os-login/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-kms/blob/master/UPGRADING.md) | - +| [``google-cloud-tasks``](https://pypi.org/project/google-cloud-tasks/) | ``>=1.2.1,<2.0.0`` | ``>=2.0.0,<3.0.0`` | [`UPGRADING.md`](https://github.com/googleapis/python-tasks/blob/master/UPGRADING.md) | ### The field names use the snake_case convention diff --git a/airflow/providers/google/cloud/hooks/tasks.py b/airflow/providers/google/cloud/hooks/tasks.py index 1c3223d..633f227 100644 --- a/airflow/providers/google/cloud/hooks/tasks.py +++ b/airflow/providers/google/cloud/hooks/tasks.py @@ -21,11 +21,13 @@ This module contains a CloudTasksHook which allows you to connect to Google Cloud Tasks service, performing actions to queues or tasks. 
""" + from typing import Dict, List, Optional, Sequence, Tuple, Union from google.api_core.retry import Retry -from google.cloud.tasks_v2 import CloudTasksClient, enums -from google.cloud.tasks_v2.types import FieldMask, Queue, Task +from google.cloud.tasks_v2 import CloudTasksClient +from google.cloud.tasks_v2.types import Queue, Task +from google.protobuf.field_mask_pb2 import FieldMask from airflow.exceptions import AirflowException from airflow.providers.google.common.hooks.base_google import GoogleBaseHook @@ -120,20 +122,19 @@ class CloudTasksHook(GoogleBaseHook): client = self.get_conn() if queue_name: - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" if isinstance(task_queue, Queue): task_queue.name = full_queue_name elif isinstance(task_queue, dict): task_queue['name'] = full_queue_name else: raise AirflowException('Unable to set queue_name.') - full_location_path = CloudTasksClient.location_path(project_id, location) + full_location_path = f"projects/{project_id}/locations/{location}" return client.create_queue( - parent=full_location_path, - queue=task_queue, + request={'parent': full_location_path, 'queue': task_queue}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) @GoogleBaseHook.fallback_to_default_project_id @@ -167,7 +168,7 @@ class CloudTasksHook(GoogleBaseHook): :param update_mask: A mast used to specify which fields of the queue are being updated. If empty, then all fields will be updated. If a dict is provided, it must be of the same form as the protobuf message. - :type update_mask: dict or google.cloud.tasks_v2.types.FieldMask + :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. 
:type retry: google.api_core.retry.Retry @@ -182,7 +183,7 @@ class CloudTasksHook(GoogleBaseHook): client = self.get_conn() if queue_name and location: - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" if isinstance(task_queue, Queue): task_queue.name = full_queue_name elif isinstance(task_queue, dict): @@ -190,11 +191,10 @@ class CloudTasksHook(GoogleBaseHook): else: raise AirflowException('Unable to set queue_name.') return client.update_queue( - queue=task_queue, - update_mask=update_mask, + request={'queue': task_queue, 'update_mask': update_mask}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) @GoogleBaseHook.fallback_to_default_project_id @@ -230,8 +230,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.get_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" + return client.get_queue( + request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) @GoogleBaseHook.fallback_to_default_project_id def list_queues( @@ -270,14 +272,12 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_location_path = CloudTasksClient.location_path(project_id, location) + full_location_path = f"projects/{project_id}/locations/{location}" queues = client.list_queues( - parent=full_location_path, - filter_=results_filter, - page_size=page_size, + request={'parent': full_location_path, 'filter': results_filter, 'page_size': page_size}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return list(queues) @@ -313,8 +313,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_queue_name = 
CloudTasksClient.queue_path(project_id, location, queue_name) - client.delete_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" + client.delete_queue( + request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) @GoogleBaseHook.fallback_to_default_project_id def purge_queue( @@ -349,8 +351,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.purge_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" + return client.purge_queue( + request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) @GoogleBaseHook.fallback_to_default_project_id def pause_queue( @@ -385,8 +389,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.pause_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" + return client.pause_queue( + request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) @GoogleBaseHook.fallback_to_default_project_id def resume_queue( @@ -421,8 +427,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) - return client.resume_queue(name=full_queue_name, retry=retry, timeout=timeout, metadata=metadata) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" + return client.resume_queue( + request={'name': full_queue_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) 
@GoogleBaseHook.fallback_to_default_project_id def create_task( @@ -432,7 +440,7 @@ class CloudTasksHook(GoogleBaseHook): task: Union[Dict, Task], project_id: str, task_name: Optional[str] = None, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, @@ -455,7 +463,7 @@ class CloudTasksHook(GoogleBaseHook): :type task_name: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.enums.Task.View + :type response_view: google.cloud.tasks_v2.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -470,21 +478,21 @@ class CloudTasksHook(GoogleBaseHook): client = self.get_conn() if task_name: - full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + full_task_name = ( + f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" + ) if isinstance(task, Task): task.name = full_task_name elif isinstance(task, dict): task['name'] = full_task_name else: raise AirflowException('Unable to set task_name.') - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" return client.create_task( - parent=full_queue_name, - task=task, - response_view=response_view, + request={'parent': full_queue_name, 'task': task, 'response_view': response_view}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) @GoogleBaseHook.fallback_to_default_project_id @@ -494,7 +502,7 @@ class CloudTasksHook(GoogleBaseHook): queue_name: str, task_name: str, project_id: str, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: 
Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, @@ -513,7 +521,7 @@ class CloudTasksHook(GoogleBaseHook): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.enums.Task.View + :type response_view: google.cloud.tasks_v2.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -527,13 +535,12 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" return client.get_task( - name=full_task_name, - response_view=response_view, + request={'name': full_task_name, 'response_view': response_view}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) @GoogleBaseHook.fallback_to_default_project_id @@ -542,7 +549,7 @@ class CloudTasksHook(GoogleBaseHook): location: str, queue_name: str, project_id: str, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, page_size: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, @@ -560,7 +567,7 @@ class CloudTasksHook(GoogleBaseHook): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.enums.Task.View + :type response_view: google.cloud.tasks_v2.Task.View :param page_size: (Optional) The maximum number of resources contained in the underlying API response. 
:type page_size: int @@ -576,14 +583,12 @@ class CloudTasksHook(GoogleBaseHook): :rtype: list[google.cloud.tasks_v2.types.Task] """ client = self.get_conn() - full_queue_name = CloudTasksClient.queue_path(project_id, location, queue_name) + full_queue_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}" tasks = client.list_tasks( - parent=full_queue_name, - response_view=response_view, - page_size=page_size, + request={'parent': full_queue_name, 'response_view': response_view, 'page_size': page_size}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) return list(tasks) @@ -622,8 +627,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) - client.delete_task(name=full_task_name, retry=retry, timeout=timeout, metadata=metadata) + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" + client.delete_task( + request={'name': full_task_name}, retry=retry, timeout=timeout, metadata=metadata or () + ) @GoogleBaseHook.fallback_to_default_project_id def run_task( @@ -632,7 +639,7 @@ class CloudTasksHook(GoogleBaseHook): queue_name: str, task_name: str, project_id: str, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[Sequence[Tuple[str, str]]] = None, @@ -651,7 +658,7 @@ class CloudTasksHook(GoogleBaseHook): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.enums.Task.View + :type response_view: google.cloud.tasks_v2.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. 
:type retry: google.api_core.retry.Retry @@ -665,11 +672,10 @@ class CloudTasksHook(GoogleBaseHook): """ client = self.get_conn() - full_task_name = CloudTasksClient.task_path(project_id, location, queue_name, task_name) + full_task_name = f"projects/{project_id}/locations/{location}/queues/{queue_name}/tasks/{task_name}" return client.run_task( - name=full_task_name, - response_view=response_view, + request={'name': full_task_name, 'response_view': response_view}, retry=retry, timeout=timeout, - metadata=metadata, + metadata=metadata or (), ) diff --git a/airflow/providers/google/cloud/operators/tasks.py b/airflow/providers/google/cloud/operators/tasks.py index 7867d29..2834b32 100644 --- a/airflow/providers/google/cloud/operators/tasks.py +++ b/airflow/providers/google/cloud/operators/tasks.py @@ -25,9 +25,8 @@ from typing import Dict, Optional, Sequence, Tuple, Union from google.api_core.exceptions import AlreadyExists from google.api_core.retry import Retry -from google.cloud.tasks_v2 import enums -from google.cloud.tasks_v2.types import FieldMask, Queue, Task -from google.protobuf.json_format import MessageToDict +from google.cloud.tasks_v2.types import Queue, Task +from google.protobuf.field_mask_pb2 import FieldMask from airflow.models import BaseOperator from airflow.providers.google.cloud.hooks.tasks import CloudTasksHook @@ -136,7 +135,7 @@ class CloudTasksQueueCreateOperator(BaseOperator): metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksQueueUpdateOperator(BaseOperator): @@ -159,7 +158,7 @@ class CloudTasksQueueUpdateOperator(BaseOperator): :param update_mask: A mast used to specify which fields of the queue are being updated. If empty, then all fields will be updated. If a dict is provided, it must be of the same form as the protobuf message. 
- :type update_mask: dict or google.cloud.tasks_v2.types.FieldMask + :type update_mask: dict or google.protobuf.field_mask_pb2.FieldMask :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. :type retry: google.api_core.retry.Retry @@ -237,7 +236,7 @@ class CloudTasksQueueUpdateOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksQueueGetOperator(BaseOperator): @@ -320,7 +319,7 @@ class CloudTasksQueueGetOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksQueuesListOperator(BaseOperator): @@ -408,7 +407,7 @@ class CloudTasksQueuesListOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return [MessageToDict(q) for q in queues] + return [Queue.to_dict(q) for q in queues] class CloudTasksQueueDeleteOperator(BaseOperator): @@ -571,7 +570,7 @@ class CloudTasksQueuePurgeOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksQueuePauseOperator(BaseOperator): @@ -654,7 +653,7 @@ class CloudTasksQueuePauseOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksQueueResumeOperator(BaseOperator): @@ -737,7 +736,7 @@ class CloudTasksQueueResumeOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(queue) + return Queue.to_dict(queue) class CloudTasksTaskCreateOperator(BaseOperator): @@ -803,7 +802,7 @@ class CloudTasksTaskCreateOperator(BaseOperator): task: Union[Dict, Task], project_id: Optional[str] = None, task_name: Optional[str] = None, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: Optional[Retry] = None, timeout: Optional[float] = 
None, metadata: Optional[MetaData] = None, @@ -840,7 +839,7 @@ class CloudTasksTaskCreateOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(task) + return Task.to_dict(task) class CloudTasksTaskGetOperator(BaseOperator): @@ -900,7 +899,7 @@ class CloudTasksTaskGetOperator(BaseOperator): queue_name: str, task_name: str, project_id: Optional[str] = None, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[MetaData] = None, @@ -935,7 +934,7 @@ class CloudTasksTaskGetOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(task) + return Task.to_dict(task) class CloudTasksTasksListOperator(BaseOperator): @@ -994,7 +993,7 @@ class CloudTasksTasksListOperator(BaseOperator): location: str, queue_name: str, project_id: Optional[str] = None, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, page_size: Optional[int] = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, @@ -1030,7 +1029,7 @@ class CloudTasksTasksListOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return [MessageToDict(t) for t in tasks] + return [Task.to_dict(t) for t in tasks] class CloudTasksTaskDeleteOperator(BaseOperator): @@ -1134,7 +1133,7 @@ class CloudTasksTaskRunOperator(BaseOperator): :type project_id: str :param response_view: (Optional) This field specifies which subset of the Task will be returned. - :type response_view: google.cloud.tasks_v2.enums.Task.View + :type response_view: google.cloud.tasks_v2.Task.View :param retry: (Optional) A retry object used to retry requests. If None is specified, requests will not be retried. 
:type retry: google.api_core.retry.Retry @@ -1176,7 +1175,7 @@ class CloudTasksTaskRunOperator(BaseOperator): queue_name: str, task_name: str, project_id: Optional[str] = None, - response_view: Optional[enums.Task.View] = None, + response_view: Optional = None, retry: Optional[Retry] = None, timeout: Optional[float] = None, metadata: Optional[MetaData] = None, @@ -1211,4 +1210,4 @@ class CloudTasksTaskRunOperator(BaseOperator): timeout=self.timeout, metadata=self.metadata, ) - return MessageToDict(task) + return Task.to_dict(task) diff --git a/setup.py b/setup.py index 6a5c07e..16a4e2a 100644 --- a/setup.py +++ b/setup.py @@ -300,7 +300,7 @@ google = [ 'google-cloud-spanner>=1.10.0,<2.0.0', 'google-cloud-speech>=0.36.3,<2.0.0', 'google-cloud-storage>=1.30,<2.0.0', - 'google-cloud-tasks>=1.2.1,<2.0.0', + 'google-cloud-tasks>=2.0.0,<3.0.0', 'google-cloud-texttospeech>=0.4.0,<2.0.0', 'google-cloud-translate>=1.5.0,<2.0.0', 'google-cloud-videointelligence>=1.7.0,<2.0.0', diff --git a/tests/providers/google/cloud/hooks/test_tasks.py b/tests/providers/google/cloud/hooks/test_tasks.py index 8be6686..6504595 100644 --- a/tests/providers/google/cloud/hooks/test_tasks.py +++ b/tests/providers/google/cloud/hooks/test_tasks.py @@ -72,11 +72,10 @@ class TestCloudTasksHook(unittest.TestCase): self.assertIs(result, API_RESPONSE) get_conn.return_value.create_queue.assert_called_once_with( - parent=FULL_LOCATION_PATH, - queue=Queue(name=FULL_QUEUE_PATH), + request=dict(parent=FULL_LOCATION_PATH, queue=Queue(name=FULL_QUEUE_PATH)), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch( @@ -94,11 +93,10 @@ class TestCloudTasksHook(unittest.TestCase): self.assertIs(result, API_RESPONSE) get_conn.return_value.update_queue.assert_called_once_with( - queue=Queue(name=FULL_QUEUE_PATH, state=3), - update_mask=None, + request=dict(queue=Queue(name=FULL_QUEUE_PATH, state=3), update_mask=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch( @@ 
-111,30 +109,28 @@ class TestCloudTasksHook(unittest.TestCase): self.assertIs(result, API_RESPONSE) get_conn.return_value.get_queue.assert_called_once_with( - name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.list_queues.return_value": API_RESPONSE}, # type: ignore + **{"return_value.list_queues.return_value": [Queue(name=FULL_QUEUE_PATH)]}, # type: ignore ) def test_list_queues(self, get_conn): result = self.hook.list_queues(location=LOCATION, project_id=PROJECT_ID) - self.assertEqual(result, list(API_RESPONSE)) + self.assertEqual(result, [Queue(name=FULL_QUEUE_PATH)]) get_conn.return_value.list_queues.assert_called_once_with( - parent=FULL_LOCATION_PATH, - filter_=None, - page_size=None, + request=dict(parent=FULL_LOCATION_PATH, filter=None, page_size=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.delete_queue.return_value": API_RESPONSE}, # type: ignore + **{"return_value.delete_queue.return_value": None}, # type: ignore ) def test_delete_queue(self, get_conn): result = self.hook.delete_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) @@ -142,51 +138,51 @@ class TestCloudTasksHook(unittest.TestCase): self.assertEqual(result, None) get_conn.return_value.delete_queue.assert_called_once_with( - name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.purge_queue.return_value": API_RESPONSE}, # type: ignore + **{"return_value.purge_queue.return_value": Queue(name=FULL_QUEUE_PATH)}, # type: ignore ) def test_purge_queue(self, get_conn): result = 
self.hook.purge_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Queue(name=FULL_QUEUE_PATH)) get_conn.return_value.purge_queue.assert_called_once_with( - name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.pause_queue.return_value": API_RESPONSE}, # type: ignore + **{"return_value.pause_queue.return_value": Queue(name=FULL_QUEUE_PATH)}, # type: ignore ) def test_pause_queue(self, get_conn): result = self.hook.pause_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Queue(name=FULL_QUEUE_PATH)) get_conn.return_value.pause_queue.assert_called_once_with( - name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.resume_queue.return_value": API_RESPONSE}, # type: ignore + **{"return_value.resume_queue.return_value": Queue(name=FULL_QUEUE_PATH)}, # type: ignore ) def test_resume_queue(self, get_conn): result = self.hook.resume_queue(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Queue(name=FULL_QUEUE_PATH)) get_conn.return_value.resume_queue.assert_called_once_with( - name=FULL_QUEUE_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_QUEUE_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.create_task.return_value": API_RESPONSE}, # type: ignore + **{"return_value.create_task.return_value": Task(name=FULL_TASK_PATH)}, # type: ignore 
) def test_create_task(self, get_conn): result = self.hook.create_task( @@ -197,20 +193,18 @@ class TestCloudTasksHook(unittest.TestCase): task_name=TASK_NAME, ) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Task(name=FULL_TASK_PATH)) get_conn.return_value.create_task.assert_called_once_with( - parent=FULL_QUEUE_PATH, - task=Task(name=FULL_TASK_PATH), - response_view=None, + request=dict(parent=FULL_QUEUE_PATH, task=Task(name=FULL_TASK_PATH), response_view=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.get_task.return_value": API_RESPONSE}, # type: ignore + **{"return_value.get_task.return_value": Task(name=FULL_TASK_PATH)}, # type: ignore ) def test_get_task(self, get_conn): result = self.hook.get_task( @@ -220,37 +214,34 @@ class TestCloudTasksHook(unittest.TestCase): project_id=PROJECT_ID, ) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Task(name=FULL_TASK_PATH)) get_conn.return_value.get_task.assert_called_once_with( - name=FULL_TASK_PATH, - response_view=None, + request=dict(name=FULL_TASK_PATH, response_view=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.list_tasks.return_value": API_RESPONSE}, # type: ignore + **{"return_value.list_tasks.return_value": [Task(name=FULL_TASK_PATH)]}, # type: ignore ) def test_list_tasks(self, get_conn): result = self.hook.list_tasks(location=LOCATION, queue_name=QUEUE_ID, project_id=PROJECT_ID) - self.assertEqual(result, list(API_RESPONSE)) + self.assertEqual(result, [Task(name=FULL_TASK_PATH)]) get_conn.return_value.list_tasks.assert_called_once_with( - parent=FULL_QUEUE_PATH, - response_view=None, - page_size=None, + request=dict(parent=FULL_QUEUE_PATH, response_view=None, page_size=None), retry=None, timeout=None, - metadata=None, + 
metadata=(), ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.delete_task.return_value": API_RESPONSE}, # type: ignore + **{"return_value.delete_task.return_value": None}, # type: ignore ) def test_delete_task(self, get_conn): result = self.hook.delete_task( @@ -263,12 +254,12 @@ class TestCloudTasksHook(unittest.TestCase): self.assertEqual(result, None) get_conn.return_value.delete_task.assert_called_once_with( - name=FULL_TASK_PATH, retry=None, timeout=None, metadata=None + request=dict(name=FULL_TASK_PATH), retry=None, timeout=None, metadata=() ) @mock.patch( "airflow.providers.google.cloud.hooks.tasks.CloudTasksHook.get_conn", - **{"return_value.run_task.return_value": API_RESPONSE}, # type: ignore + **{"return_value.run_task.return_value": Task(name=FULL_TASK_PATH)}, # type: ignore ) def test_run_task(self, get_conn): result = self.hook.run_task( @@ -278,12 +269,11 @@ class TestCloudTasksHook(unittest.TestCase): project_id=PROJECT_ID, ) - self.assertEqual(result, API_RESPONSE) + self.assertEqual(result, Task(name=FULL_TASK_PATH)) get_conn.return_value.run_task.assert_called_once_with( - name=FULL_TASK_PATH, - response_view=None, + request=dict(name=FULL_TASK_PATH, response_view=None), retry=None, timeout=None, - metadata=None, + metadata=(), ) diff --git a/tests/providers/google/cloud/operators/test_tasks.py b/tests/providers/google/cloud/operators/test_tasks.py index b7e886d..ed76911 100644 --- a/tests/providers/google/cloud/operators/test_tasks.py +++ b/tests/providers/google/cloud/operators/test_tasks.py @@ -57,7 +57,7 @@ class TestCloudTasksQueueCreate(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -81,7 +81,7 @@ class TestCloudTasksQueueUpdate(unittest.TestCase): result = 
operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -106,7 +106,7 @@ class TestCloudTasksQueueGet(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -129,7 +129,7 @@ class TestCloudTasksQueuesList(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual([{'name': FULL_QUEUE_PATH}], result) + self.assertEqual([{'name': FULL_QUEUE_PATH, 'state': 0}], result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -176,7 +176,7 @@ class TestCloudTasksQueuePurge(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -199,7 +199,7 @@ class TestCloudTasksQueuePause(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -222,7 +222,7 @@ class TestCloudTasksQueueResume(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'name': FULL_QUEUE_PATH}, result) + self.assertEqual({'name': FULL_QUEUE_PATH, 'state': 0}, result) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -247,7 +247,16 @@ class TestCloudTasksTaskCreate(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'appEngineHttpRequest': {}}, result) + 
self.assertEqual( + { + 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''}, + 'dispatch_count': 0, + 'name': '', + 'response_count': 0, + 'view': 0, + }, + result, + ) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -275,7 +284,16 @@ class TestCloudTasksTaskGet(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'appEngineHttpRequest': {}}, result) + self.assertEqual( + { + 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''}, + 'dispatch_count': 0, + 'name': '', + 'response_count': 0, + 'view': 0, + }, + result, + ) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -300,7 +318,23 @@ class TestCloudTasksTasksList(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual([{'appEngineHttpRequest': {}}], result) + self.assertEqual( + [ + { + 'app_engine_http_request': { + 'body': '', + 'headers': {}, + 'http_method': 0, + 'relative_uri': '', + }, + 'dispatch_count': 0, + 'name': '', + 'response_count': 0, + 'view': 0, + } + ], + result, + ) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None, @@ -353,7 +387,16 @@ class TestCloudTasksTaskRun(unittest.TestCase): result = operator.execute(context=None) - self.assertEqual({'appEngineHttpRequest': {}}, result) + self.assertEqual( + { + 'app_engine_http_request': {'body': '', 'headers': {}, 'http_method': 0, 'relative_uri': ''}, + 'dispatch_count': 0, + 'name': '', + 'response_count': 0, + 'view': 0, + }, + result, + ) mock_hook.assert_called_once_with( gcp_conn_id=GCP_CONN_ID, impersonation_chain=None,