This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 00867ec8665 fix: Change DataprocAsyncHook parent class to 
GoogleBaseAsyncHook (#52981)
00867ec8665 is described below

commit 00867ec8665d0205c7b2ef88e717eb23ac5a1f35
Author: Geonwoo Kim <[email protected]>
AuthorDate: Sat Jul 19 21:14:03 2025 +0900

    fix: Change DataprocAsyncHook parent class to GoogleBaseAsyncHook (#52981)
    
    * fix: Change DataprocAsyncHook parent class to GoogleBaseAsyncHook
    
    * fix: hook test_dataproc
    
    * fix: trigger test code
    
    * fix: replace caplog with mock_log
    
    * fix: remove test_async_cluster_trigger_run_returns_error_event mock_auth
    
    * fix: change position of parameter and add fallback decorator
---
 .../providers/google/cloud/hooks/dataproc.py       | 138 ++++++-----
 .../tests/unit/google/cloud/hooks/test_dataproc.py | 106 ++++++---
 .../unit/google/cloud/triggers/test_dataproc.py    | 252 +++++++++++----------
 3 files changed, 278 insertions(+), 218 deletions(-)

diff --git 
a/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py 
b/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
index 6cc57128bab..872673aae6b 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/dataproc.py
@@ -47,7 +47,7 @@ from google.cloud.dataproc_v1 import (
 
 from airflow.exceptions import AirflowException
 from airflow.providers.google.common.consts import CLIENT_INFO
-from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
+from airflow.providers.google.common.hooks.base_google import 
GoogleBaseAsyncHook, GoogleBaseHook
 from airflow.version import version as airflow_version
 
 if TYPE_CHECKING:
@@ -1269,7 +1269,7 @@ class DataprocHook(GoogleBaseHook):
         return all([word in error_msg for word in key_words])
 
 
-class DataprocAsyncHook(GoogleBaseHook):
+class DataprocAsyncHook(GoogleBaseAsyncHook):
     """
     Asynchronous interaction with Google Cloud Dataproc APIs.
 
@@ -1277,6 +1277,8 @@ class DataprocAsyncHook(GoogleBaseHook):
     keyword arguments rather than positional.
     """
 
+    sync_hook_class = DataprocHook
+
     def __init__(
         self,
         gcp_conn_id: str = "google_cloud_default",
@@ -1286,53 +1288,90 @@ class DataprocAsyncHook(GoogleBaseHook):
         super().__init__(gcp_conn_id=gcp_conn_id, 
impersonation_chain=impersonation_chain, **kwargs)
         self._cached_client: JobControllerAsyncClient | None = None
 
-    def get_cluster_client(self, region: str | None = None) -> 
ClusterControllerAsyncClient:
+    async def get_cluster_client(self, region: str | None = None) -> 
ClusterControllerAsyncClient:
         """Create a ClusterControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = 
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return ClusterControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
         )
 
-    def get_template_client(self, region: str | None = None) -> 
WorkflowTemplateServiceAsyncClient:
+    async def get_template_client(self, region: str | None = None) -> 
WorkflowTemplateServiceAsyncClient:
         """Create a WorkflowTemplateServiceAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = 
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return WorkflowTemplateServiceAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
         )
 
-    def get_job_client(self, region: str | None = None) -> 
JobControllerAsyncClient:
+    async def get_job_client(self, region: str | None = None) -> 
JobControllerAsyncClient:
         """Create a JobControllerAsyncClient."""
         if self._cached_client is None:
             client_options = None
             if region and region != "global":
                 client_options = 
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+            sync_hook = await self.get_sync_hook()
             self._cached_client = JobControllerAsyncClient(
-                credentials=self.get_credentials(),
+                credentials=sync_hook.get_credentials(),
                 client_info=CLIENT_INFO,
                 client_options=client_options,
             )
         return self._cached_client
 
-    def get_batch_client(self, region: str | None = None) -> 
BatchControllerAsyncClient:
+    async def get_batch_client(self, region: str | None = None) -> 
BatchControllerAsyncClient:
         """Create a BatchControllerAsyncClient."""
         client_options = None
         if region and region != "global":
             client_options = 
ClientOptions(api_endpoint=f"{region}-dataproc.googleapis.com:443")
 
+        sync_hook = await self.get_sync_hook()
         return BatchControllerAsyncClient(
-            credentials=self.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
+            credentials=sync_hook.get_credentials(), client_info=CLIENT_INFO, 
client_options=client_options
         )
 
-    def get_operations_client(self, region: str) -> OperationsClient:
+    async def get_operations_client(self, region: str) -> OperationsClient:
         """Create a OperationsClient."""
-        return 
self.get_template_client(region=region).transport.operations_client
+        template_client = await self.get_template_client(region=region)
+        return template_client.transport.operations_client
+
+    @GoogleBaseHook.fallback_to_default_project_id
+    async def get_cluster(
+        self,
+        region: str,
+        cluster_name: str,
+        project_id: str,
+        retry: AsyncRetry | _MethodDefault = DEFAULT,
+        timeout: float | None = None,
+        metadata: Sequence[tuple[str, str]] = (),
+    ) -> Cluster:
+        """
+        Get a cluster.
+
+        :param region: Cloud Dataproc region in which to handle the request.
+        :param cluster_name: Name of the cluster to get.
+        :param project_id: Google Cloud project ID that the cluster belongs to.
+        :param retry: A retry object used to retry requests. If *None*, 
requests
+            will not be retried.
+        :param timeout: The amount of time, in seconds, to wait for the request
+            to complete. If *retry* is specified, the timeout applies to each
+            individual attempt.
+        :param metadata: Additional metadata that is provided to the method.
+        """
+        client = await self.get_cluster_client(region=region)
+        result = await client.get_cluster(
+            request={"project_id": project_id, "region": region, 
"cluster_name": cluster_name},
+            retry=retry,
+            timeout=timeout,
+            metadata=metadata,
+        )
+        return result
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def create_cluster(
@@ -1390,7 +1429,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             cluster["config"] = cluster_config  # type: ignore
             cluster["labels"] = labels  # type: ignore
 
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.create_cluster(
             request={
                 "project_id": project_id,
@@ -1435,7 +1474,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.delete_cluster(
             request={
                 "project_id": project_id,
@@ -1483,7 +1522,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.diagnose_cluster(
             request={
                 "project_id": project_id,
@@ -1500,38 +1539,6 @@ class DataprocAsyncHook(GoogleBaseHook):
         )
         return result
 
-    @GoogleBaseHook.fallback_to_default_project_id
-    async def get_cluster(
-        self,
-        region: str,
-        cluster_name: str,
-        project_id: str,
-        retry: AsyncRetry | _MethodDefault = DEFAULT,
-        timeout: float | None = None,
-        metadata: Sequence[tuple[str, str]] = (),
-    ) -> Cluster:
-        """
-        Get the resource representation for a cluster in a project.
-
-        :param project_id: Google Cloud project ID that the cluster belongs to.
-        :param region: Cloud Dataproc region to handle the request.
-        :param cluster_name: The cluster name.
-        :param retry: A retry object used to retry requests. If *None*, 
requests
-            will not be retried.
-        :param timeout: The amount of time, in seconds, to wait for the request
-            to complete. If *retry* is specified, the timeout applies to each
-            individual attempt.
-        :param metadata: Additional metadata that is provided to the method.
-        """
-        client = self.get_cluster_client(region=region)
-        result = await client.get_cluster(
-            request={"project_id": project_id, "region": region, 
"cluster_name": cluster_name},
-            retry=retry,
-            timeout=timeout,
-            metadata=metadata,
-        )
-        return result
-
     @GoogleBaseHook.fallback_to_default_project_id
     async def list_clusters(
         self,
@@ -1561,7 +1568,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         result = await client.list_clusters(
             request={"project_id": project_id, "region": region, "filter": 
filter_, "page_size": page_size},
             retry=retry,
@@ -1638,7 +1645,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         """
         if region is None:
             raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_cluster_client(region=region)
+        client = await self.get_cluster_client(region=region)
         operation = await client.update_cluster(
             request={
                 "project_id": project_id,
@@ -1680,10 +1687,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         return await client.create_workflow_template(
             request={"parent": parent, "template": template}, retry=retry, 
timeout=timeout, metadata=metadata
@@ -1725,10 +1730,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         name = 
f"projects/{project_id}/regions/{region}/workflowTemplates/{template_name}"
         operation = await client.instantiate_workflow_template(
             request={"name": name, "version": version, "request_id": 
request_id, "parameters": parameters},
@@ -1767,10 +1770,8 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
         metadata = metadata or ()
-        client = self.get_template_client(region)
+        client = await self.get_template_client(region)
         parent = f"projects/{project_id}/regions/{region}"
         operation = await client.instantiate_inline_workflow_template(
             request={"parent": parent, "template": template, "request_id": 
request_id},
@@ -1781,7 +1782,8 @@ class DataprocAsyncHook(GoogleBaseHook):
         return operation
 
     async def get_operation(self, region, operation_name):
-        return await 
self.get_operations_client(region).get_operation(name=operation_name)
+        operations_client = await self.get_operations_client(region)
+        return await operations_client.get_operation(name=operation_name)
 
     @GoogleBaseHook.fallback_to_default_project_id
     async def get_job(
@@ -1806,9 +1808,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         job = await client.get_job(
             request={"project_id": project_id, "region": region, "job_id": 
job_id},
             retry=retry,
@@ -1845,9 +1845,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        if region is None:
-            raise TypeError("missing 1 required keyword argument: 'region'")
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
         return await client.submit_job(
             request={"project_id": project_id, "region": region, "job": job, 
"request_id": request_id},
             retry=retry,
@@ -1878,7 +1876,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_job_client(region=region)
+        client = await self.get_job_client(region=region)
 
         job = await client.cancel_job(
             request={"project_id": project_id, "region": region, "job_id": 
job_id},
@@ -1920,7 +1918,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"
 
         result = await client.create_batch(
@@ -1959,7 +1957,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
 
         await client.delete_batch(
@@ -1994,7 +1992,7 @@ class DataprocAsyncHook(GoogleBaseHook):
             individual attempt.
         :param metadata: Additional metadata that is provided to the method.
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         name = f"projects/{project_id}/locations/{region}/batches/{batch_id}"
 
         result = await client.get_batch(
@@ -2039,7 +2037,7 @@ class DataprocAsyncHook(GoogleBaseHook):
         :param filter: Result filters as specified in ListBatchesRequest
         :param order_by: How to order results as specified in 
ListBatchesRequest
         """
-        client = self.get_batch_client(region)
+        client = await self.get_batch_client(region)
         parent = f"projects/{project_id}/regions/{region}"
 
         result = await client.list_batches(
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
index 1e0349d8bfd..786008640df 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_dataproc.py
@@ -601,82 +601,124 @@ class TestDataprocAsyncHook:
         with mock.patch(BASE_STRING.format("GoogleBaseHook.__init__"), 
new=mock_init):
             self.hook = DataprocAsyncHook(gcp_conn_id="test")
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
-    def test_get_cluster_client(self, mock_client, mock_get_credentials):
-        self.hook.get_cluster_client(region=GCP_LOCATION)
+    async def test_get_cluster_client(self, mock_client, mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_cluster_client(region=GCP_LOCATION)
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=None,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("ClusterControllerAsyncClient"))
-    def test_get_cluster_client_region(self, mock_client, 
mock_get_credentials):
-        self.hook.get_cluster_client(region="region1")
+    async def test_get_cluster_client_region(self, mock_client, 
mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_cluster_client(region="region1")
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=ANY,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
-    def test_get_template_client_global(self, mock_client, 
mock_get_credentials):
-        _ = self.hook.get_template_client()
+    async def test_get_template_client_global(self, mock_client, 
mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        _ = await self.hook.get_template_client()
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=None,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("WorkflowTemplateServiceAsyncClient"))
-    def test_get_template_client_region(self, mock_client, 
mock_get_credentials):
-        _ = self.hook.get_template_client(region="region1")
+    async def test_get_template_client_region(self, mock_client, 
mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        _ = await self.hook.get_template_client(region="region1")
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=ANY,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
-    def test_get_job_client(self, mock_client, mock_get_credentials):
-        self.hook.get_job_client(region=GCP_LOCATION)
+    async def test_get_job_client(self, mock_client, mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_job_client(region=GCP_LOCATION)
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=None,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("JobControllerAsyncClient"))
-    def test_get_job_client_region(self, mock_client, mock_get_credentials):
-        self.hook.get_job_client(region="region1")
+    async def test_get_job_client_region(self, mock_client, 
mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_job_client(region="region1")
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=ANY,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
-    def test_get_batch_client(self, mock_client, mock_get_credentials):
-        self.hook.get_batch_client(region=GCP_LOCATION)
+    async def test_get_batch_client(self, mock_client, mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_batch_client(region=GCP_LOCATION)
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value,
+            credentials=mock_sync_hook.get_credentials.return_value,
             client_info=CLIENT_INFO,
             client_options=None,
         )
 
-    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_credentials"))
+    @pytest.mark.asyncio
+    @mock.patch(DATAPROC_STRING.format("DataprocAsyncHook.get_sync_hook"))
     @mock.patch(DATAPROC_STRING.format("BatchControllerAsyncClient"))
-    def test_get_batch_client_region(self, mock_client, mock_get_credentials):
-        self.hook.get_batch_client(region="region1")
+    async def test_get_batch_client_region(self, mock_client, 
mock_get_sync_hook):
+        mock_sync_hook = mock.MagicMock()
+        mock_get_sync_hook.return_value = mock_sync_hook
+        mock_sync_hook.get_credentials.return_value = mock.MagicMock()
+
+        await self.hook.get_batch_client(region="region1")
         mock_client.assert_called_once_with(
-            credentials=mock_get_credentials.return_value, 
client_info=CLIENT_INFO, client_options=ANY
+            credentials=mock_sync_hook.get_credentials.return_value,
+            client_info=CLIENT_INFO,
+            client_options=ANY,
         )
 
     @pytest.mark.asyncio
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py 
b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
index 40048a4e5d7..03cf2986301 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_dataproc.py
@@ -18,7 +18,6 @@ from __future__ import annotations
 
 import asyncio
 import contextlib
-import logging
 from asyncio import CancelledError, Future, sleep
 from unittest import mock
 
@@ -170,56 +169,22 @@ class TestDataprocClusterTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
-    async def 
test_async_cluster_triggers_on_success_should_execute_successfully(
-        self, mock_hook, cluster_trigger, async_get_cluster
-    ):
-        mock_hook.return_value = async_get_cluster(
-            project_id=TEST_PROJECT_ID,
-            region=TEST_REGION,
-            cluster_name=TEST_CLUSTER_NAME,
-            status=ClusterStatus(state=ClusterStatus.State.RUNNING),
-        )
-
-        generator = cluster_trigger.run()
-        actual_event = await generator.asend(None)
-
-        expected_event = TriggerEvent(
-            {
-                "cluster_name": TEST_CLUSTER_NAME,
-                "cluster_state": ClusterStatus.State.RUNNING,
-                "cluster": actual_event.payload["cluster"],
-            }
-        )
-        assert expected_event == actual_event
-
-    @pytest.mark.db_test
-    @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
-    @mock.patch(
-        
"airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster",
-        return_value=asyncio.Future(),
-    )
-    @mock.patch("google.auth.default")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    @mock.patch.object(DataprocClusterTrigger, "log")
     async def test_async_cluster_trigger_run_returns_error_event(
-        self, mock_auth, mock_delete_cluster, mock_get_cluster, 
cluster_trigger, async_get_cluster, caplog
+        self, mock_log, mock_get_async_hook, cluster_trigger
     ):
-        mock_credentials = mock.MagicMock()
-        mock_credentials.universe_domain = "googleapis.com"
-
-        mock_auth.return_value = (mock_credentials, "project-id")
-
-        mock_delete_cluster.return_value = asyncio.Future()
-        mock_delete_cluster.return_value.set_result(None)
+        # Mock delete_cluster to return a Future
+        mock_delete_future = asyncio.Future()
+        mock_delete_future.set_result(None)
+        mock_get_async_hook.return_value.delete_cluster.return_value = 
mock_delete_future
 
-        mock_get_cluster.return_value = async_get_cluster(
-            project_id=TEST_PROJECT_ID,
-            region=TEST_REGION,
-            cluster_name=TEST_CLUSTER_NAME,
-            status=ClusterStatus(state=ClusterStatus.State.ERROR),
-        )
+        mock_cluster = mock.MagicMock()
+        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.ERROR)
 
-        caplog.set_level(logging.INFO)
+        future = asyncio.Future()
+        future.set_result(mock_cluster)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
 
         trigger_event = None
         async for event in cluster_trigger.run():
@@ -230,31 +195,28 @@ class TestDataprocClusterTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
-    async def test_cluster_run_loop_is_still_running(
-        self, mock_hook, cluster_trigger, caplog, async_get_cluster
-    ):
-        mock_hook.return_value = async_get_cluster(
-            project_id=TEST_PROJECT_ID,
-            region=TEST_REGION,
-            cluster_name=TEST_CLUSTER_NAME,
-            status=ClusterStatus(state=ClusterStatus.State.CREATING),
-        )
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    @mock.patch.object(DataprocClusterTrigger, "log")
+    async def test_cluster_run_loop_is_still_running(self, mock_log, 
mock_get_async_hook, cluster_trigger):
+        mock_cluster = mock.MagicMock()
+        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.CREATING)
 
-        caplog.set_level(logging.INFO)
+        future = asyncio.Future()
+        future.set_result(mock_cluster)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
 
         task = asyncio.create_task(cluster_trigger.run().__anext__())
         await asyncio.sleep(0.5)
 
         assert not task.done()
-        assert f"Current state is: {ClusterStatus.State.CREATING}."
-        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+        mock_log.info.assert_called()
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
+    @mock.patch.object(DataprocClusterTrigger, "log")
     async def test_cluster_trigger_cancellation_handling(
-        self, mock_get_sync_hook, mock_get_async_hook, caplog
+        self, mock_log, mock_get_sync_hook, mock_get_async_hook
     ):
         cluster = 
Cluster(status=ClusterStatus(state=ClusterStatus.State.RUNNING))
         mock_get_async_hook.return_value.get_cluster.return_value = 
asyncio.Future()
@@ -288,8 +250,7 @@ class TestDataprocClusterTrigger:
                     cluster_name=cluster_trigger.cluster_name,
                     project_id=cluster_trigger.project_id,
                 )
-                assert "Deleting cluster" in caplog.text
-                assert "Deleted cluster" in caplog.text
+                mock_log.info.assert_called()
             else:
                 mock_delete_cluster.assert_not_called()
         except Exception as e:
@@ -297,19 +258,24 @@ class TestDataprocClusterTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_cluster")
-    async def test_fetch_cluster_status(self, mock_get_cluster, 
cluster_trigger, async_get_cluster):
-        mock_get_cluster.return_value = async_get_cluster(
-            status=ClusterStatus(state=ClusterStatus.State.RUNNING)
-        )
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    async def test_fetch_cluster_status(self, mock_get_async_hook, 
cluster_trigger):
+        mock_cluster = mock.MagicMock()
+        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
+
+        future = asyncio.Future()
+        future.set_result(mock_cluster)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
+
         cluster = await cluster_trigger.fetch_cluster()
 
         assert cluster.status.state == ClusterStatus.State.RUNNING, "The 
cluster state should be RUNNING"
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.delete_cluster")
-    async def test_delete_when_error_occurred(self, mock_delete_cluster, 
cluster_trigger):
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    @mock.patch.object(DataprocClusterTrigger, "log")
+    async def test_delete_when_error_occurred(self, mock_log, 
mock_get_async_hook, cluster_trigger):
         mock_cluster = mock.MagicMock(spec=Cluster)
         type(mock_cluster).status = mock.PropertyMock(
             return_value=mock.MagicMock(state=ClusterStatus.State.ERROR)
@@ -317,31 +283,32 @@ class TestDataprocClusterTrigger:
 
         mock_delete_future = asyncio.Future()
         mock_delete_future.set_result(None)
-        mock_delete_cluster.return_value = mock_delete_future
+        mock_get_async_hook.return_value.delete_cluster.return_value = 
mock_delete_future
 
         cluster_trigger.delete_on_error = True
 
         await cluster_trigger.delete_when_error_occurred(mock_cluster)
 
-        mock_delete_cluster.assert_called_once_with(
+        
mock_get_async_hook.return_value.delete_cluster.assert_called_once_with(
             region=cluster_trigger.region,
             cluster_name=cluster_trigger.cluster_name,
             project_id=cluster_trigger.project_id,
         )
 
-        mock_delete_cluster.reset_mock()
+        mock_get_async_hook.return_value.delete_cluster.reset_mock()
         cluster_trigger.delete_on_error = False
 
         await cluster_trigger.delete_when_error_occurred(mock_cluster)
 
-        mock_delete_cluster.assert_not_called()
+        mock_get_async_hook.return_value.delete_cluster.assert_not_called()
 
     @pytest.mark.asyncio
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_sync_hook")
     
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.safe_to_cancel")
+    @mock.patch.object(DataprocClusterTrigger, "log")
     async def test_cluster_trigger_run_cancelled_not_safe_to_cancel(
-        self, mock_safe_to_cancel, mock_get_sync_hook, mock_get_async_hook, 
cluster_trigger
+        self, mock_log, mock_safe_to_cancel, mock_get_sync_hook, 
mock_get_async_hook, cluster_trigger
     ):
         """Test the trigger's cancellation behavior when it is not safe to 
cancel."""
         mock_safe_to_cancel.return_value = False
@@ -366,6 +333,31 @@ class TestDataprocClusterTrigger:
         assert mock_delete_cluster.call_count == 0
         mock_delete_cluster.assert_not_called()
 
+    @pytest.mark.db_test
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocClusterTrigger.get_async_hook")
+    async def 
test_async_cluster_triggers_on_success_should_execute_successfully(
+        self, mock_get_async_hook, cluster_trigger
+    ):
+        mock_cluster = mock.MagicMock()
+        mock_cluster.status = ClusterStatus(state=ClusterStatus.State.RUNNING)
+
+        future = asyncio.Future()
+        future.set_result(mock_cluster)
+        mock_get_async_hook.return_value.get_cluster.return_value = future
+
+        generator = cluster_trigger.run()
+        actual_event = await generator.asend(None)
+
+        expected_event = TriggerEvent(
+            {
+                "cluster_name": TEST_CLUSTER_NAME,
+                "cluster_state": ClusterStatus.State.RUNNING,
+                "cluster": actual_event.payload["cluster"],
+            }
+        )
+        assert expected_event == actual_event
+
 
 class TestDataprocBatchTrigger:
     def 
test_async_create_batch_trigger_serialization_should_execute_successfully(self, 
batch_trigger):
@@ -387,17 +379,21 @@ class TestDataprocBatchTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
     async def 
test_async_create_batch_trigger_triggers_on_success_should_execute_successfully(
-        self, mock_hook, batch_trigger, async_get_batch
+        self, mock_get_async_hook, batch_trigger
     ):
         """
         Tests the DataprocBatchTrigger only fires once the batch execution 
reaches a successful state.
         """
 
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.SUCCEEDED, batch_id=TEST_BATCH_ID, 
state_message=TEST_BATCH_STATE_MESSAGE
-        )
+        mock_batch = mock.MagicMock()
+        mock_batch.state = Batch.State.SUCCEEDED
+        mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+        future = asyncio.Future()
+        future.set_result(mock_batch)
+        mock_get_async_hook.return_value.get_batch.return_value = future
 
         expected_event = TriggerEvent(
             {
@@ -413,13 +409,17 @@ class TestDataprocBatchTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
     async def test_async_create_batch_trigger_run_returns_failed_event(
-        self, mock_hook, batch_trigger, async_get_batch
+        self, mock_get_async_hook, batch_trigger
     ):
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.FAILED, batch_id=TEST_BATCH_ID, 
state_message=TEST_BATCH_STATE_MESSAGE
-        )
+        mock_batch = mock.MagicMock()
+        mock_batch.state = Batch.State.FAILED
+        mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+        future = asyncio.Future()
+        future.set_result(mock_batch)
+        mock_get_async_hook.return_value.get_batch.return_value = future
 
         expected_event = TriggerEvent(
             {
@@ -435,11 +435,15 @@ class TestDataprocBatchTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
-    async def test_create_batch_run_returns_cancelled_event(self, mock_hook, 
batch_trigger, async_get_batch):
-        mock_hook.return_value = async_get_batch(
-            state=Batch.State.CANCELLED, batch_id=TEST_BATCH_ID, 
state_message=TEST_BATCH_STATE_MESSAGE
-        )
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
+    async def test_create_batch_run_returns_cancelled_event(self, 
mock_get_async_hook, batch_trigger):
+        mock_batch = mock.MagicMock()
+        mock_batch.state = Batch.State.CANCELLED
+        mock_batch.state_message = TEST_BATCH_STATE_MESSAGE
+
+        future = asyncio.Future()
+        future.set_result(mock_batch)
+        mock_get_async_hook.return_value.get_batch.return_value = future
 
         expected_event = TriggerEvent(
             {
@@ -455,20 +459,21 @@ class TestDataprocBatchTrigger:
 
     @pytest.mark.db_test
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.hooks.dataproc.DataprocAsyncHook.get_batch")
-    async def test_create_batch_run_loop_is_still_running(
-        self, mock_hook, batch_trigger, caplog, async_get_batch
-    ):
-        mock_hook.return_value = async_get_batch(state=Batch.State.RUNNING)
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBatchTrigger.get_async_hook")
+    @mock.patch.object(DataprocBatchTrigger, "log")
+    async def test_create_batch_run_loop_is_still_running(self, mock_log, 
mock_get_async_hook, batch_trigger):
+        mock_batch = mock.MagicMock()
+        mock_batch.state = Batch.State.RUNNING
 
-        caplog.set_level(logging.INFO)
+        future = asyncio.Future()
+        future.set_result(mock_batch)
+        mock_get_async_hook.return_value.get_batch.return_value = future
 
         task = asyncio.create_task(batch_trigger.run().__anext__())
         await asyncio.sleep(0.5)
 
         assert not task.done()
-        assert f"Current state is: {Batch.State.RUNNING}"
-        assert f"Sleeping for {TEST_POLL_INTERVAL} seconds."
+        mock_log.info.assert_called()
 
 
 class TestDataprocOperationTrigger:
@@ -486,13 +491,19 @@ class TestDataprocOperationTrigger:
         }
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
     async def 
test_async_operation_triggers_on_success_should_execute_successfully(
-        self, mock_hook, operation_trigger, async_get_operation
+        self, mock_get_async_hook, operation_trigger
     ):
-        mock_hook.return_value.get_operation.return_value = 
async_get_operation(
-            name=TEST_OPERATION_NAME, done=True, response={}, 
error=Status(message="")
-        )
+        mock_operation = mock.MagicMock()
+        mock_operation.name = TEST_OPERATION_NAME
+        mock_operation.done = True
+        mock_operation.response = {}
+        mock_operation.error = Status(message="")
+
+        future = asyncio.Future()
+        future.set_result(mock_operation)
+        mock_get_async_hook.return_value.get_operation.return_value = future
 
         expected_event = TriggerEvent(
             {
@@ -506,17 +517,20 @@ class TestDataprocOperationTrigger:
         assert expected_event == actual_event
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
     async def 
test_async_diagnose_operation_triggers_on_success_should_execute_successfully(
-        self, mock_hook, diagnose_operation_trigger, async_get_operation
+        self, mock_get_async_hook, diagnose_operation_trigger
     ):
         gcs_uri = "gs://test-tarball-gcs-dir-bucket"
-        mock_hook.return_value.get_operation.return_value = 
async_get_operation(
-            name=TEST_OPERATION_NAME,
-            done=True,
-            response=Any(value=gcs_uri.encode("utf-8")),
-            error=Status(message=""),
-        )
+        mock_operation = mock.MagicMock()
+        mock_operation.name = TEST_OPERATION_NAME
+        mock_operation.done = True
+        mock_operation.response = Any(value=gcs_uri.encode("utf-8"))
+        mock_operation.error = Status(message="")
+
+        future = asyncio.Future()
+        future.set_result(mock_operation)
+        mock_get_async_hook.return_value.get_operation.return_value = future
 
         expected_event = TriggerEvent(
             {
@@ -529,11 +543,17 @@ class TestDataprocOperationTrigger:
         assert expected_event == actual_event
 
     @pytest.mark.asyncio
-    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocBaseTrigger.get_async_hook")
-    async def test_async_operation_triggers_on_error(self, mock_hook, 
operation_trigger, async_get_operation):
-        mock_hook.return_value.get_operation.return_value = 
async_get_operation(
-            name=TEST_OPERATION_NAME, done=True, response={}, 
error=Status(message="test_error")
-        )
+    
@mock.patch("airflow.providers.google.cloud.triggers.dataproc.DataprocOperationTrigger.get_async_hook")
+    async def test_async_operation_triggers_on_error(self, 
mock_get_async_hook, operation_trigger):
+        mock_operation = mock.MagicMock()
+        mock_operation.name = TEST_OPERATION_NAME
+        mock_operation.done = True
+        mock_operation.response = {}
+        mock_operation.error = Status(message="test_error")
+
+        future = asyncio.Future()
+        future.set_result(mock_operation)
+        mock_get_async_hook.return_value.get_operation.return_value = future
 
         expected_event = TriggerEvent(
             {


Reply via email to