This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 00867ec8665 fix: Change DataprocAsyncHook parent class to
GoogleBaseAsyncHook (#52981)
00867ec8665 is described below
commit 00867ec8665d0205c7b2ef88e717eb23ac5a1f35
Author: Geonwoo Kim <[email protected]>
AuthorDate: Sat Jul 19 21:14:03 2025 +0900
fix: Change DataprocAsyncHook parent class to GoogleBaseAsyncHook (#52981)
* fix: Change DataprocAsyncHook parent class to GoogleBaseAsyncHook
* fix: hook test_dataproc
* fix: trigger test code
* fix: replace caplog with mock_log
* fix: remove test_async_cluster_trigger_run_returns_error_event mock_auth
* fix: change pos of parameter and add fallback decorator
---
.../providers/google/cloud/hooks/dataproc.py | 138 ++++++-----
.../tests/unit/google/cloud/hooks/test_dataproc.py | 106 ++++++---
.../unit/google/cloud/triggers/test_dataproc.py | 252 +++++++++++----------
3 files changed, 278 insertions(+), 218 deletions(-)
diff --git
a/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
b/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
index 6cc57128bab..872673aae6b 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
@@ -47,7 +47,7 @@ from google.cloud.dataproc_v1 import (
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import
GoogleBaseAsyncHook, GoogleBaseHook
from airflow.version import version as airflow_version
if TYPE_CHECKING:
@@ -1269,7 +1269,7 @@ class DataprocHook(GoogleBaseHook):
return all([word in error_msg for word in key_words])
-class DataprocAsyncHook(GoogleBaseHook):
+class DataprocAsyncHook(GoogleBaseAsyncHook):
"""
Asynchronous interaction with Google Cloud Dataproc APIs.
@@ -1277,6 +1277,8 @@ class DataprocAsyncHook(GoogleBaseHook):
keyword arguments rather than positional.
"""
+ sync_hook_class = DataprocHook
+
def __init__(
self,
gcp_conn_id: str = "google_cloud_default",
@@ -1286,53 +1288,90 @@ class DataprocAsyncHook(GoogleBaseHook):
super().__init__(gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain, **kwargs)
self._cached_client: JobControllerAsyncClient | None = None
- def get_cluster_client(self, region: str | None = None) ->
ClusterControllerAsyncClient:
+ async def get_cluster_client(self, region: str | None = None) ->
ClusterControllerAsyncClient:
"""Create a ClusterControllerAsyncClient."""
client_options = None
if region and region != "global":
client_options =
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
+ sync_hook = await self.get_sync_hook()
return ClusterControllerAsyncClient(
- credentials=self.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
+ credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
)
- def get_template_client(self, region: str | None = None) ->
WorkflowTemplateServiceAsyncClient:
+ async def get_template_client(self, region: str | None = None) ->
WorkflowTemplateServiceAsyncClient:
"""Create a WorkflowTemplateServiceAsyncClient."""
client_options = None
if region and region != "global":
client_options =
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
+ sync_hook = await self.get_sync_hook()
return WorkflowTemplateServiceAsyncClient(
- credentials=self.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
+ credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
)
- def get_job_client(self, region: str | None = None) ->
JobControllerAsyncClient:
+ async def get_job_client(self, region: str | None = None) ->
JobControllerAsyncClient:
"""Create a JobControllerAsyncClient."""
if self._cached_client is None:
client_options = None
if region and region != "global":
client_options =
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
+ sync_hook = await self.get_sync_hook()
self._cached_client = JobControllerAsyncClient(
- credentials=self.get_credentials(),
+ credentials=sync_hook.get_credentials(),
client_info=CLIENT_INFO,
client_options=client_options,
)
return self._cached_client
- def get_batch_client(self, region: str | None = None) ->
BatchControllerAsyncClient:
+ async def get_batch_client(self, region: str | None = None) ->
BatchControllerAsyncClient:
"""Create a BatchControllerAsyncClient."""
client_options = None
if region and region != "global":
client_options =
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
+ sync_hook = await self.get_sync_hook()
return BatchControllerAsyncClient(
- credentials=self.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
+ credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO,
client_options=client_options
)
- def get_operations_client(self, region: str) -> OperationsClient:
+ async def get_operations_client(self, region: str) -> OperationsClient:
"""Create a OperationsClient."""
- return
self.get_template_client(region=region).transport.operations_client
+ template_client = await self.get_template_client(region=region)
+ return template_client.transport.operations_client
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ async def get_cluster(
+ self,
+ region: str,
+ cluster_name: str,
+ project_id: str,
+ retry: AsyncRetry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> Cluster:
+ """
+ Get a cluster.
+
+ :param region: Cloud Dataproc region in which to handle the request.
+ :param cluster_name: Name of the cluster to get.
+ :param project_id: Google Cloud project ID that the cluster belongs to.
+ :param retry: A retry object used to retry requests. If *None*,
requests
+ will not be retried.
+ :param timeout: The amount of time, in seconds, to wait for the request
+ to complete. If *retry* is specified, the timeout applies to each
+ individual attempt.
+ :param metadata: Additional metadata that is provided to the method.
+ """
+ client = await self.get_cluster_client(region=region)
+ result = await client.get_cluster(
+ request={"project_id": project_id, "region": region,
"cluster_name": cluster_name},
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
@GoogleBaseHook.fallback_to_default_project_id
async def create_cluster(
@@ -1390,7 +1429,7 @@ class DataprocAsyncHook(GoogleBaseHook):
cluster["config"] = cluster_config # type: ignore
cluster["labels"] = labels # type: ignore
- client = self.get_cluster_client(region=region)
+ client = await self.get_cluster_client(region=region)
result = await client.create_cluster(
request={
"project_id": project_id,
@@ -1435,7 +1474,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_cluster_client(region=region)
+ client = await self.get_cluster_client(region=region)
result = await client.delete_cluster(
request={
"project_id": project_id,
@@ -1483,7 +1522,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_cluster_client(region=region)
+ client = await self.get_cluster_client(region=region)
result = await client.diagnose_cluster(
request={
"project_id": project_id,
@@ -1500,38 +1539,6 @@ class DataprocAsyncHook(GoogleBaseHook):
)
return result
- @GoogleBaseHook.fallback_to_default_project_id
- async def get_cluster(
- self,
- region: str,
- cluster_name: str,
- project_id: str,
- retry: AsyncRetry | _MethodDefault = DEFAULT,
- timeout: float | None = None,
- metadata: Sequence[tuple[str, str]] = (),
- ) -> Cluster:
- """
- Get the resource representation for a cluster in a project.
-
- :param project_id: Google Cloud project ID that the cluster belongs to.
- :param region: Cloud Dataproc region to handle the request.
- :param cluster_name: The cluster name.
- :param retry: A retry object used to retry requests. If *None*,
requests
- will not be retried.
- :param timeout: The amount of time, in seconds, to wait for the request
- to complete. If *retry* is specified, the timeout applies to each
- individual attempt.
- :param metadata: Additional metadata that is provided to the method.
- """
- client = self.get_cluster_client(region=region)
- result = await client.get_cluster(
- request={"project_id": project_id, "region": region,
"cluster_name": cluster_name},
- retry=retry,
- timeout=timeout,
- metadata=metadata,
- )
- return result
-
@GoogleBaseHook.fallback_to_default_project_id
async def list_clusters(
self,
@@ -1561,7 +1568,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_cluster_client(region=region)
+ client = await self.get_cluster_client(region=region)
result = await client.list_clusters(
request={"project_id": project_id, "region": region, "filter":
filter_, "page_size": page_size},
retry=retry,
@@ -1638,7 +1645,7 @@ class DataprocAsyncHook(GoogleBaseHook):
"""
if region is None:
raise TypeError("missing 1 required keyword argument: 'region'")
- client = self.get_cluster_client(region=region)
+ client = await self.get_cluster_client(region=region)
operation = await client.update_cluster(
request={
"project_id": project_id,
@@ -1680,10 +1687,8 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- if region is None:
- raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
- client = self.get_template_client(region)
+ client = await self.get_template_client(region)
parent = f"projects/{project_id}/regions/{region}"
return await client.create_workflow_template(
request={"parent": parent, "template": template}, retry=retry,
timeout=timeout, metadata=metadata
@@ -1725,10 +1730,8 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- if region is None:
- raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
- client = self.get_template_client(region)
+ client = await self.get_template_client(region)
name =
f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}"
operation = await client.instantiate_workflow_template(
request={"name": name, "version": version, "request_id":
request_id, "parameters": parameters},
@@ -1767,10 +1770,8 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- if region is None:
- raise TypeError("missing 1 required keyword argument: 'region'")
metadata = metadata or ()
- client = self.get_template_client(region)
+ client = await self.get_template_client(region)
parent = f"projects/{project_id}/regions/{region}"
operation = await client.instantiate_inline_workflow_template(
request={"parent": parent, "template": template, "request_id":
request_id},
@@ -1781,7 +1782,8 @@ class DataprocAsyncHook(GoogleBaseHook):
return operation
async def get_operation(self, region, operation_name):
- return await
self.get_operations_client(region).get_operation(name=operation_name)
+ operations_client = await self.get_operations_client(region)
+ return await operations_client.get_operation(name=operation_name)
@GoogleBaseHook.fallback_to_default_project_id
async def get_job(
@@ -1806,9 +1808,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- if region is None:
- raise TypeError("missing 1 required keyword argument: 'region'")
- client = self.get_job_client(region=region)
+ client = await self.get_job_client(region=region)
job = await client.get_job(
request={"project_id": project_id, "region": region, "job_id":
job_id},
retry=retry,
@@ -1845,9 +1845,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- if region is None:
- raise TypeError("missing 1 required keyword argument: 'region'")
- client = self.get_job_client(region=region)
+ client = await self.get_job_client(region=region)
return await client.submit_job(
request={"project_id": project_id, "region": region, "job": job,
"request_id": request_id},
retry=retry,
@@ -1878,7 +1876,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_job_client(region=region)
+ client = await self.get_job_client(region=region)
job = await client.cancel_job(
request={"project_id": project_id, "region": region, "job_id":
job_id},
@@ -1920,7 +1918,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_batch_client(region)
+ client = await self.get_batch_client(region)
parent = f"projects/{project_id}/regions/{region}"
result = await client.create_batch(
@@ -1959,7 +1957,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_batch_client(region)
+ client = await self.get_batch_client(region)
name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
await client.delete_batch(
@@ -1994,7 +1992,7 @@ class DataprocAsyncHook(GoogleBaseHook):
individual attempt.
:param metadata: Additional metadata that is provided to the method.
"""
- client = self.get_batch_client(region)
+ client = await self.get_batch_client(region)
name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
result = await client.get_batch(
@@ -2039,7 +2037,7 @@ class DataprocAsyncHook(GoogleBaseHook):
:param filter: Result filters as specified in ListBatchesRequest
:param order_by: How to order results as specified in
ListBatchesRequest
"""
- client = self.get_batch_client(region)
+ client = await self.get_batch_client(region)
parent = f"projects/{project_id}/regions/{region}"
result = await client.list_batches(
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
b/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
index 1e0349d8bfd..786008640df 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
@@ -601,82 +601,124 @@ class TestDataprocAsyncHook:
with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"),
new=mock_init):
self.hook = DataprocAsyncHook(gcp_conn_id="test")
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
- def test_get_cluster_client(self, mock_client, mock_get_credentials):
- self.hook.get_cluster_client(region=GCP_LOCATION)
+ async def test_get_cluster_client(self, mock_client, mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_cluster_client(region=GCP_LOCATION)
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=None,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
- def test_get_cluster_client_region(self, mock_client,
mock_get_credentials):
- self.hook.get_cluster_client(region="region1")
+ async def test_get_cluster_client_region(self, mock_client,
mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_cluster_client(region="region1")
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=ANY,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
- def test_get_template_client_global(self, mock_client,
mock_get_credentials):
- _ = self.hook.get_template_client()
+ async def test_get_template_client_global(self, mock_client,
mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ _ = await self.hook.get_template_client()
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=None,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
- def test_get_template_client_region(self, mock_client,
mock_get_credentials):
- _ = self.hook.get_template_client(region="region1")
+ async def test_get_template_client_region(self, mock_client,
mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ _ = await self.hook.get_template_client(region="region1")
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=ANY,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
- def test_get_job_client(self, mock_client, mock_get_credentials):
- self.hook.get_job_client(region=GCP_LOCATION)
+ async def test_get_job_client(self, mock_client, mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_job_client(region=GCP_LOCATION)
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=None,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
- def test_get_job_client_region(self, mock_client, mock_get_credentials):
- self.hook.get_job_client(region="region1")
+ async def test_get_job_client_region(self, mock_client,
mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_job_client(region="region1")
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=ANY,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
- def test_get_batch_client(self, mock_client, mock_get_credentials):
- self.hook.get_batch_client(region=GCP_LOCATION)
+ async def test_get_batch_client(self, mock_client, mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_batch_client(region=GCP_LOCATION)
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
+ credentials=mock_sync_hook.get_credentials.return_value,
client_info=CLIENT_INFO,
client_options=None,
)
- @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+ @pytest.mark.asyncio
+ @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
@mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
- def test_get_batch_client_region(self, mock_client, mock_get_credentials):
- self.hook.get_batch_client(region="region1")
+ async def test_get_batch_client_region(self, mock_client,
mock_get_sync_hook):
+ mock_sync_hook = mock.MagicMock()
+ mock_get_sync_hook.return_value = mock_sync_hook
+ mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+ await self.hook.get_batch_client(region="region1")
mock_client.assert_called_once_with(
- credentials=mock_get_credentials.return_value,
client_info=CLIENT_INFO, client_options=ANY
+ credentials=mock_sync_hook.get_credentials.return_value,
+ client_info=CLIENT_INFO,
+ client_options=ANY,
)
@pytest.mark.asyncio
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index 40048a4e5d7..03cf2986301 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -18,7 +18,6 @@ from __future__ import annotations
import asyncio
import contextlib
-import logging
from asyncio import CancelledError, Future, sleep
from unittest import mock
@@ -170,56 +169,22 @@ class TestDataprocClusterTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
- async def
test_async_cluster_triggers_on_success_should_execute_successfully(
- self, mock_hook, cluster_trigger, async_get_cluster
- ):
- mock_hook.return_value = async_get_cluster(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- cluster_name=TEST_CLUSTER_NAME,
- status=ClusterStatus(state=ClusterStatus.State.RUNNING),
- )
-
- generator = cluster_trigger.run()
- actual_event = await generator.asend(None)
-
- expected_event = TriggerEvent(
- {
- "cluster_name": TEST_CLUSTER_NAME,
- "cluster_state": ClusterStatus.State.RUNNING,
- "cluster": actual_event.payload["cluster"],
- }
- )
- assert expected_event == actual_event
-
- @pytest.mark.db_test
- @pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
- @mock.patch(
-
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
- return_value=asyncio.Future(),
- )
- @mock.patch("google.auth.default")
+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ @mock.patch.object(DataprocClusterTrigger, "log")
async def test_async_cluster_trigger_run_returns_error_event(
- self, mock_auth, mock_delete_cluster, mock_get_cluster,
cluster_trigger, async_get_cluster, caplog
+ self, mock_log, mock_get_async_hook, cluster_trigger
):
- mock_credentials = mock.MagicMock()
- mock_credentials.universe_domain = "googleapis.com"
-
- mock_auth.return_value = (mock_credentials, "project-id")
-
- mock_delete_cluster.return_value = asyncio.Future()
- mock_delete_cluster.return_value.set_result(None)
+ # Mock delete_cluster to return a Future
+ mock_delete_future = asyncio.Future()
+ mock_delete_future.set_result(None)
+ mock_get_async_hook.return_value.delete_cluster.return_value =
mock_delete_future
- mock_get_cluster.return_value = async_get_cluster(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- cluster_name=TEST_CLUSTER_NAME,
- status=ClusterStatus(state=ClusterStatus.State.ERROR),
- )
+ mock_cluster = mock.MagicMock()
+ mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR)
- caplog.set_level(logging.INFO)
+ future = asyncio.Future()
+ future.set_result(mock_cluster)
+ mock_get_async_hook.return_value.get_cluster.return_value = future
trigger_event = None
async for event in cluster_trigger.run():
@@ -230,31 +195,28 @@ class TestDataprocClusterTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
- async def test_cluster_run_loop_is_still_running(
- self, mock_hook, cluster_trigger, caplog, async_get_cluster
- ):
- mock_hook.return_value = async_get_cluster(
- project_id=TEST_PROJECT_ID,
- region=TEST_REGION,
- cluster_name=TEST_CLUSTER_NAME,
- status=ClusterStatus(state=ClusterStatus.State.CREATING),
- )
+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ @mock.patch.object(DataprocClusterTrigger, "log")
+ async def test_cluster_run_loop_is_still_running(self, mock_log,
mock_get_async_hook, cluster_trigger):
+ mock_cluster = mock.MagicMock()
+ mock_cluster.status = ClusterStatus(state=ClusterStatus.State.CREATING)
- caplog.set_level(logging.INFO)
+ future = asyncio.Future()
+ future.set_result(mock_cluster)
+ mock_get_async_hook.return_value.get_cluster.return_value = future
task = asyncio.create_task(cluster_trigger.run().__anext__())
await asyncio.sleep(0.5)
assert not task.done()
- assert f"Current state is: {ClusterStatus.State.CREATING}."
- assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+ mock_log.info.assert_called()
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+ @mock.patch.object(DataprocClusterTrigger, "log")
async def test_cluster_trigger_cancellation_handling(
- self, mock_get_sync_hook, mock_get_async_hook, caplog
+ self, mock_log, mock_get_sync_hook, mock_get_async_hook
):
cluster =
Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
mock_get_async_hook.return_value.get_cluster.return_value =
asyncio.Future()
@@ -288,8 +250,7 @@ class TestDataprocClusterTrigger:
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)
- assert "Deleting cluster" in caplog.text
- assert "Deleted cluster" in caplog.text
+ mock_log.info.assert_called()
else:
mock_delete_cluster.assert_not_called()
except Exception as e:
@@ -297,19 +258,24 @@ class TestDataprocClusterTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
- async def test_fetch_cluster_status(self, mock_get_cluster,
cluster_trigger, async_get_cluster):
- mock_get_cluster.return_value = async_get_cluster(
- status=ClusterStatus(state=ClusterStatus.State.RUNNING)
- )
+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ async def test_fetch_cluster_status(self, mock_get_async_hook,
cluster_trigger):
+ mock_cluster = mock.MagicMock()
+ mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
+
+ future = asyncio.Future()
+ future.set_result(mock_cluster)
+ mock_get_async_hook.return_value.get_cluster.return_value = future
+
cluster = await cluster_trigger.fetch_cluster()
assert cluster.status.state == ClusterStatus.State.RUNNING, "The
cluster state should be RUNNING"
@pytest.mark.db_test
@pytest.mark.asyncio
-
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
- async def test_delete_when_error_occurred(self, mock_delete_cluster,
cluster_trigger):
+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ @mock.patch.object(DataprocClusterTrigger, "log")
+ async def test_delete_when_error_occurred(self, mock_log,
mock_get_async_hook, cluster_trigger):
mock_cluster = mock.MagicMock(spec=Cluster)
type(mock_cluster).status = mock.PropertyMock(
return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
@@ -317,31 +283,32 @@ class TestDataprocClusterTrigger:
mock_delete_future = asyncio.Future()
mock_delete_future.set_result(None)
- mock_delete_cluster.return_value = mock_delete_future
+ mock_get_async_hook.return_value.delete_cluster.return_value =
mock_delete_future
cluster_trigger.delete_on_error = True
await cluster_trigger.delete_when_error_occurred(mock_cluster)
- mock_delete_cluster.assert_called_once_with(
+
mock_get_async_hook.return_value.delete_cluster.assert_called_once_with(
region=cluster_trigger.region,
cluster_name=cluster_trigger.cluster_name,
project_id=cluster_trigger.project_id,
)
- mock_delete_cluster.reset_mock()
+ mock_get_async_hook.return_value.delete_cluster.reset_mock()
cluster_trigger.delete_on_error = False
await cluster_trigger.delete_when_error_occurred(mock_cluster)
- mock_delete_cluster.assert_not_called()
+ mock_get_async_hook.return_value.delete_cluster.assert_not_called()
@pytest.mark.asyncio
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel")
+ @mock.patch.object(DataprocClusterTrigger, "log")
async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
- self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook,
cluster_trigger
+ self, mock_log, mock_safe_to_cancel, mock_get_sync_hook,
mock_get_async_hook, cluster_trigger
):
"""Test the trigger's cancellation behavior when it is not safe to
cancel."""
mock_safe_to_cancel.return_value = False
@@ -366,6 +333,31 @@ class TestDataprocClusterTrigger:
assert mock_delete_cluster.call_count == 0
mock_delete_cluster.assert_not_called()
+ @pytest.mark.db_test
+ @pytest.mark.asyncio
+
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+ async def
test_async_cluster_triggers_on_success_should_execute_successfully(
+ self, mock_get_async_hook, cluster_trigger
+ ):
+ mock_cluster = mock.MagicMock()
+ mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
+
+ future = asyncio.Future()
+ future.set_result(mock_cluster)
+ mock_get_async_hook.return_value.get_cluster.return_value = future
+
+ generator = cluster_trigger.run()
+ actual_event = await generator.asend(None)
+
+ expected_event = TriggerEvent(
+ {
+ "cluster_name": TEST_CLUSTER_NAME,
+ "cluster_state": ClusterStatus.State.RUNNING,
+ "cluster": actual_event.payload["cluster"],
+ }
+ )
+ assert expected_event == actual_event
+
class TestDataprocBatchTrigger:
def
test_async_create_batch_trigger_serialization_should_execute_successfully(self,
batch_trigger):
@@ -387,17 +379,21 @@ class TestDataprocBatchTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
     async def test_async_create_batch_trigger_triggers_on_success_should_execute_successfully(
-        self, mock_hook, batch_trigger, async_get_batch
+        self, mock_get_async_hook, batch_trigger
     ):
"""
Tests the DataprocBatchTrigger only fires once the batch execution
reaches a successful state.
"""
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
-        )
+ mock_batch = mock.MagicMock()
+ mock_batch.state = Batch.State.SUCCEEDED
+ mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+ future = asyncio.Future()
+ future.set_result(mock_batch)
+ mock_get_async_hook.return_value.get_batch.return_value = future
expected_event = TriggerEvent(
{
@@ -413,13 +409,17 @@ class TestDataprocBatchTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
     async def test_async_create_batch_trigger_run_returns_failed_event(
-        self, mock_hook, batch_trigger, async_get_batch
+        self, mock_get_async_hook, batch_trigger
     ):
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.FAILED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
-        )
+ mock_batch = mock.MagicMock()
+ mock_batch.state = Batch.State.FAILED
+ mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+ future = asyncio.Future()
+ future.set_result(mock_batch)
+ mock_get_async_hook.return_value.get_batch.return_value = future
expected_event = TriggerEvent(
{
@@ -435,11 +435,15 @@ class TestDataprocBatchTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
-    async def test_create_batch_run_returns_cancelled_event(self, mock_hook, batch_trigger, async_get_batch):
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID, state_message=TEST_BATCH_STATE_MESSAGE
-        )
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
+    async def test_create_batch_run_returns_cancelled_event(self, mock_get_async_hook, batch_trigger):
+ mock_batch = mock.MagicMock()
+ mock_batch.state = Batch.State.CANCELLED
+ mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+ future = asyncio.Future()
+ future.set_result(mock_batch)
+ mock_get_async_hook.return_value.get_batch.return_value = future
expected_event = TriggerEvent(
{
@@ -455,20 +459,21 @@ class TestDataprocBatchTrigger:
@pytest.mark.db_test
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
-    async def test_create_batch_run_loop_is_still_running(
-        self, mock_hook, batch_trigger, caplog, async_get_batch
-    ):
-        mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING)
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
+    @mock.patch.object(DataprocBatchTrigger, "log")
+    async def test_create_batch_run_loop_is_still_running(self, mock_log, mock_get_async_hook, batch_trigger):
+ mock_batch = mock.MagicMock()
+ mock_batch.state = Batch.State.RUNNING
- caplog.set_level(logging.INFO)
+ future = asyncio.Future()
+ future.set_result(mock_batch)
+ mock_get_async_hook.return_value.get_batch.return_value = future
task = asyncio.create_task(batch_trigger.run().__anext__())
await asyncio.sleep(0.5)
assert not task.done()
- assert f"Current state is: {Batch.State.RUNNING}"
- assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+ mock_log.info.assert_called()
class TestDataprocOperationTrigger:
@@ -486,13 +491,19 @@ class TestDataprocOperationTrigger:
}
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
     async def test_async_operation_triggers_on_success_should_execute_successfully(
-        self, mock_hook, operation_trigger, async_get_operation
+        self, mock_get_async_hook, operation_trigger
     ):
-        mock_hook.return_value.get_operation.return_value = async_get_operation(
-            name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="")
-        )
+ mock_operation = mock.MagicMock()
+ mock_operation.name = TEST_OPERATION_NAME
+ mock_operation.done = True
+ mock_operation.response = {}
+ mock_operation.error = Status(message="")
+
+ future = asyncio.Future()
+ future.set_result(mock_operation)
+ mock_get_async_hook.return_value.get_operation.return_value = future
expected_event = TriggerEvent(
{
@@ -506,17 +517,20 @@ class TestDataprocOperationTrigger:
assert expected_event == actual_event
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
     async def test_async_diagnose_operation_triggers_on_success_should_execute_successfully(
-        self, mock_hook, diagnose_operation_trigger, async_get_operation
+        self, mock_get_async_hook, diagnose_operation_trigger
     ):
         gcs_uri = "gs://test-tarball-gcs-dir-bucket"
-        mock_hook.return_value.get_operation.return_value = async_get_operation(
-            name=TEST_OPERATION_NAME,
-            done=True,
-            response=Any(value=gcs_uri.encode("utf-8")),
-            error=Status(message=""),
-        )
+ mock_operation = mock.MagicMock()
+ mock_operation.name = TEST_OPERATION_NAME
+ mock_operation.done = True
+ mock_operation.response = Any(value=gcs_uri.encode("utf-8"))
+ mock_operation.error = Status(message="")
+
+ future = asyncio.Future()
+ future.set_result(mock_operation)
+ mock_get_async_hook.return_value.get_operation.return_value = future
expected_event = TriggerEvent(
{
@@ -529,11 +543,17 @@ class TestDataprocOperationTrigger:
assert expected_event == actual_event
@pytest.mark.asyncio
-    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
-    async def test_async_operation_triggers_on_error(self, mock_hook, operation_trigger, async_get_operation):
-        mock_hook.return_value.get_operation.return_value = async_get_operation(
-            name=TEST_OPERATION_NAME, done=True, response={}, error=Status(message="test_error")
-        )
+    @mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
+    async def test_async_operation_triggers_on_error(self, mock_get_async_hook, operation_trigger):
+ mock_operation = mock.MagicMock()
+ mock_operation.name = TEST_OPERATION_NAME
+ mock_operation.done = True
+ mock_operation.response = {}
+ mock_operation.error = Status(message="test_error")
+
+ future = asyncio.Future()
+ future.set_result(mock_operation)
+ mock_get_async_hook.return_value.get_operation.return_value = future
expected_event = TriggerEvent(
{