This is an automated email from the ASF dual-hosted git repository. eladkal pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 13e9a0d0d5 Fix credentials error for S3ToGCSOperator trigger (#37518) 13e9a0d0d5 is described below commit 13e9a0d0d562238afe03830a33dd52cf5cea5d12 Author: korolkevichm <52798297+korolkev...@users.noreply.github.com> AuthorDate: Mon Apr 1 15:56:39 2024 +0300 Fix credentials error for S3ToGCSOperator trigger (#37518) * Fix credentials error for S3ToGCSOperator trigger * fix: safe create StorageTransferServiceAsyncClient() * fix: test, style, some bugs --- .../google/cloud/hooks/cloud_storage_transfer_service.py | 13 +++++++++---- airflow/providers/google/cloud/transfers/s3_to_gcs.py | 1 + .../cloud/triggers/cloud_storage_transfer_service.py | 16 ++++++++++++++-- .../hooks/test_cloud_storage_transfer_service_async.py | 10 +++++++--- .../triggers/test_cloud_storage_transfer_service.py | 7 ++++++- 5 files changed, 37 insertions(+), 10 deletions(-) diff --git a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py index e8f9cbc93d..966735e9c2 100644 --- a/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/hooks/cloud_storage_transfer_service.py @@ -45,6 +45,7 @@ from googleapiclient.discovery import Resource, build from googleapiclient.errors import HttpError from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning +from airflow.providers.google.common.consts import CLIENT_INFO from airflow.providers.google.common.hooks.base_google import GoogleBaseAsyncHook, GoogleBaseHook if TYPE_CHECKING: @@ -508,14 +509,18 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook): self.project_id = project_id self._client: StorageTransferServiceAsyncClient | None = None - def get_conn(self) -> StorageTransferServiceAsyncClient: + async def get_conn(self) -> StorageTransferServiceAsyncClient: """ Return async connection to the Storage Transfer Service. :return: Google Storage Transfer asynchronous client. """ if not self._client: - self._client = StorageTransferServiceAsyncClient() + credentials = (await self.get_sync_hook()).get_credentials() + self._client = StorageTransferServiceAsyncClient( + credentials=credentials, + client_info=CLIENT_INFO, + ) return self._client async def get_jobs(self, job_names: list[str]) -> ListTransferJobsAsyncPager: @@ -525,7 +530,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook): :param job_names: (Required) List of names of the jobs to be fetched. :return: Object that yields Transfer jobs. """ - client = self.get_conn() + client = await self.get_conn() jobs_list_request = ListTransferJobsRequest( filter=json.dumps({"project_id": self.project_id, "job_names": job_names}) ) @@ -540,7 +545,7 @@ class CloudDataTransferServiceAsyncHook(GoogleBaseAsyncHook): """ latest_operation_name = job.latest_operation_name if latest_operation_name: - client = self.get_conn() + client = await self.get_conn() response_operation = await client.transport.operations_client.get_operation(latest_operation_name) operation = TransferOperation.deserialize(response_operation.metadata.value) return operation diff --git a/airflow/providers/google/cloud/transfers/s3_to_gcs.py b/airflow/providers/google/cloud/transfers/s3_to_gcs.py index b2b306f74b..8c935603d2 100644 --- a/airflow/providers/google/cloud/transfers/s3_to_gcs.py +++ b/airflow/providers/google/cloud/transfers/s3_to_gcs.py @@ -269,6 +269,7 @@ class S3ToGCSOperator(S3ListOperator): self.defer( trigger=CloudStorageTransferServiceCreateJobsTrigger( project_id=gcs_hook.project_id, + gcp_conn_id=self.gcp_conn_id, job_names=job_names, poll_interval=self.poll_interval, ), diff --git a/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py index cd0440cee2..f5ab93242e 100644 --- a/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py +++ b/airflow/providers/google/cloud/triggers/cloud_storage_transfer_service.py @@ -37,11 +37,19 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger): :param job_names: List of transfer jobs names. :param project_id: GCP project id. :param poll_interval: Interval in seconds between polls. + :param gcp_conn_id: The connection ID used to connect to Google Cloud. """ - def __init__(self, job_names: list[str], project_id: str | None = None, poll_interval: int = 10) -> None: + def __init__( + self, + job_names: list[str], + project_id: str | None = None, + poll_interval: int = 10, + gcp_conn_id: str = "google_cloud_default", + ) -> None: super().__init__() self.project_id = project_id + self.gcp_conn_id = gcp_conn_id self.job_names = job_names self.poll_interval = poll_interval @@ -53,6 +61,7 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger): "project_id": self.project_id, "job_names": self.job_names, "poll_interval": self.poll_interval, + "gcp_conn_id": self.gcp_conn_id, }, ) @@ -117,4 +126,7 @@ class CloudStorageTransferServiceCreateJobsTrigger(BaseTrigger): await asyncio.sleep(self.poll_interval) def get_async_hook(self) -> CloudDataTransferServiceAsyncHook: - return CloudDataTransferServiceAsyncHook(project_id=self.project_id) + return CloudDataTransferServiceAsyncHook( + project_id=self.project_id, + gcp_conn_id=self.gcp_conn_id, + ) diff --git a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py index a27d233fd1..e05bacbbd2 100644 --- a/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py +++ b/tests/providers/google/cloud/hooks/test_cloud_storage_transfer_service_async.py @@ -42,16 +42,20 @@ def hook_async(): class TestCloudDataTransferServiceAsyncHook: + @pytest.mark.asyncio + @mock.patch(f"{TRANSFER_HOOK_PATH}.CloudDataTransferServiceAsyncHook.get_conn") @mock.patch(f"{TRANSFER_HOOK_PATH}.StorageTransferServiceAsyncClient") - def test_get_conn(self, mock_async_client): + async def test_get_conn(self, mock_async_client, mock_get_conn): expected_value = "Async Hook" mock_async_client.return_value = expected_value + mock_get_conn.return_value = expected_value hook = CloudDataTransferServiceAsyncHook(project_id=TEST_PROJECT_ID) - conn_0 = hook.get_conn() + + conn_0 = await hook.get_conn() assert conn_0 == expected_value - conn_1 = hook.get_conn() + conn_1 = await hook.get_conn() assert conn_1 == expected_value assert id(conn_0) == id(conn_1) diff --git a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py index 2072138954..072c6a5d7d 100644 --- a/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py +++ b/tests/providers/google/cloud/triggers/test_cloud_storage_transfer_service.py @@ -32,6 +32,7 @@ from airflow.providers.google.cloud.triggers.cloud_storage_transfer_service impo from airflow.triggers.base import TriggerEvent PROJECT_ID = "test-project" +GCP_CONN_ID = "google-cloud-default-id" JOB_0 = "test-job-0" JOB_1 = "test-job-1" JOB_NAMES = [JOB_0, JOB_1] @@ -51,7 +52,10 @@ ASYNC_HOOK_CLASS_PATH = ( @pytest.fixture(scope="session") def trigger(): return CloudStorageTransferServiceCreateJobsTrigger( - project_id=PROJECT_ID, job_names=JOB_NAMES, poll_interval=POLL_INTERVAL + project_id=PROJECT_ID, + job_names=JOB_NAMES, + poll_interval=POLL_INTERVAL, + gcp_conn_id=GCP_CONN_ID, ) @@ -80,6 +84,7 @@ class TestCloudStorageTransferServiceCreateJobsTrigger: "project_id": PROJECT_ID, "job_names": JOB_NAMES, "poll_interval": POLL_INTERVAL, + "gcp_conn_id": GCP_CONN_ID, } def test_get_async_hook(self, trigger):