This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13e9a0d0d5 Fix credentials error for S3ToGCSOperator trigger (#37518)
13e9a0d0d5 is described below

commit 13e9a0d0d562238afe03830a33dd52cf5cea5d12
Author: korolkevichm <52798297+korolkev...@users.noreply.github.com>
AuthorDate: Mon Apr 1 15:56:39 2024 +0300

    Fix credentials error for S3ToGCSOperator trigger (#37518)
    
    * Fix credentials error for S3ToGCSOperator trigger
    
    * fix: safely create StorageTransferServiceAsyncClient() with connection credentials
    
    * fix: test, style, some bugs
---
 .../google/cloud/hooks/cloud_storage_transfer_service.py | 13 +++++++++----
 airflow/providers/google/cloud/transfers/s3_to_gcs.py    |  1 +
 .../cloud/triggers/cloud_storage_transfer_service.py     | 16 ++++++++++++++--
 .../hooks/test_cloud_storage_transfer_service_async.py   | 10 +++++++---
 .../triggers/test_cloud_storage_transfer_service.py      |  7 ++++++-
 5 files changed, 37 insertions(+), 10 deletions(-)

diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
index e8f9cbc93d..966735e9c2 100644
--- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py
@@ -45,6 +45,7 @@ from googleapiclient.discovery import Resource, build
 from googleapiclient.errors import HttpError
 
 from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning
+from airflow.providers.google.common.consts import CLIENT_INFO
 from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook
 
 if TYPE_CHECKING:
@@ -508,14 +509,18 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         self.project_id = project_id
         self._client: StorageTransferServiceAsyncClient | None = None
 
-    def get_conn(self) -> StorageTransferServiceAsyncClient:
+    async def get_conn(self) -> StorageTransferServiceAsyncClient:
         """
         Return async connection to the Storage Transfer Service.
 
         :return: Google Storage Transfer asynchronous client.
         """
         if not self._client:
-            self._client = StorageTransferServiceAsyncClient()
+            credentials = (await self.get_sync_hook()).get_credentials()
+            self._client = StorageTransferServiceAsyncClient(
+                credentials=credentials,
+                client_info=CLIENT_INFO,
+            )
         return self._client
 
     async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager:
@@ -525,7 +530,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         :param job_names: (Required) List of names of the jobs to be fetched.
         :return: Object that yields Transfer jobs.
         """
-        client = self.get_conn()
+        client = await self.get_conn()
         jobs_list_request = ListTransferJobsRequest(
             filter=json.dumps({"project_id": self.project_id, "job_names": job_names})
         )
@@ -540,7 +545,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook):
         """
         latest_operation_name = job.latest_operation_name
         if latest_operation_name:
-            client = self.get_conn()
+            client = await self.get_conn()
             response_operation = await client.transport.operations_client.get_operation(latest_operation_name)
             operation = TransferOperation.deserialize(response_operation.metadata.value)
             return operation
diff --git a/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/airflow/providers/google/cloud/transfers/s3_to_gcs.py
index b2b306f74b..8c935603d2 100644
--- a/airflow/providers/google/cloud/transfers/s3_to_gcs.py
+++ b/airflow/providers/google/cloud/transfers/s3_to_gcs.py
@@ -269,6 +269,7 @@ class S3ToGCSOperator(S3ListOperator):
         self.defer(
             trigger=CloudStorageTransferServiceCreateJobsTrigger(
                 project_id=gcs_hook.project_id,
+                gcp_conn_id=self.gcp_conn_id,
                 job_names=job_names,
                 poll_interval=self.poll_interval,
             ),
diff --git a/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py
index cd0440cee2..f5ab93242e 100644
--- a/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py
+++ b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py
@@ -37,11 +37,19 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
     :param job_names: List of transfer jobs names.
     :param project_id: GCP project id.
     :param poll_interval: Interval in seconds between polls.
+    :param gcp_conn_id: The connection ID used to connect to Google Cloud.
     """
 
-    def __init__(self, job_names: list[str], project_id: str | None = None, poll_interval: int = 10) -> None:
+    def __init__(
+        self,
+        job_names: list[str],
+        project_id: str | None = None,
+        poll_interval: int = 10,
+        gcp_conn_id: str = "google_cloud_default",
+    ) -> None:
         super().__init__()
         self.project_id = project_id
+        self.gcp_conn_id = gcp_conn_id
         self.job_names = job_names
         self.poll_interval = poll_interval
 
@@ -53,6 +61,7 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
                 "project_id": self.project_id,
                 "job_names": self.job_names,
                 "poll_interval": self.poll_interval,
+                "gcp_conn_id": self.gcp_conn_id,
             },
         )
 
@@ -117,4 +126,7 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger):
             await asyncio.sleep(self.poll_interval)
 
     def get_async_hook(self) -> CloudDataTransferServiceAsyncHook:
-        return CloudDataTransferServiceAsyncHook(project_id=self.project_id)
+        return CloudDataTransferServiceAsyncHook(
+            project_id=self.project_id,
+            gcp_conn_id=self.gcp_conn_id,
+        )
diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
index a27d233fd1..e05bacbbd2 100644
--- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
+++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py
@@ -42,16 +42,20 @@ def hook_async():
 
 
 class TestCloudDataTransferServiceAsyncHook:
+    @pytest.mark.asyncio
+    @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn")
     @mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient")
-    def test_get_conn(self, mock_async_client):
+    async def test_get_conn(self, mock_async_client, mock_get_conn):
         expected_value = "Async Hook"
         mock_async_client.return_value = expected_value
+        mock_get_conn.return_value = expected_value
 
         hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID)
-        conn_0 = hook.get_conn()
+
+        conn_0 = await hook.get_conn()
         assert conn_0 == expected_value
 
-        conn_1 = hook.get_conn()
+        conn_1 = await hook.get_conn()
         assert conn_1 == expected_value
         assert id(conn_0) == id(conn_1)
 
diff --git a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
index 2072138954..072c6a5d7d 100644
--- a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
+++ b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service impo
 from airflow.triggers.base import TriggerEvent
 
 PROJECT_ID = "test-project"
+GCP_CONN_ID = "google-cloud-default-id"
 JOB_0 = "test-job-0"
 JOB_1 = "test-job-1"
 JOB_NAMES = [JOB_0, JOB_1]
@@ -51,7 +52,10 @@ ASYNC_HOOK_CLASS_PATH = (
 @pytest.fixture(scope="session")
 def trigger():
     return CloudStorageTransferServiceCreateJobsTrigger(
-        project_id=PROJECT_ID, job_names=JOB_NAMES, poll_interval=POLL_INTERVAL
+        project_id=PROJECT_ID,
+        job_names=JOB_NAMES,
+        poll_interval=POLL_INTERVAL,
+        gcp_conn_id=GCP_CONN_ID,
     )
 
 
@@ -80,6 +84,7 @@ class TestCloudStorageTransferServiceCreateJobsTrigger:
             "project_id": PROJECT_ID,
             "job_names": JOB_NAMES,
             "poll_interval": POLL_INTERVAL,
+            "gcp_conn_id": GCP_CONN_ID,
         }
 
     def test_get_async_hook(self, trigger):

Reply via email to