This is an automated email from the ASF dual-hosted git repository. ferruzzi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 915f9e4060 Add GCS Requester Pays bucket support to GCSToS3Operator (#32760) 915f9e4060 is described below commit 915f9e40601fbfa3ebcf2fe82ced14191b12ab18 Author: Hank Ehly <henry.e...@gmail.com> AuthorDate: Tue Aug 1 02:33:52 2023 +0900 Add GCS Requester Pays bucket support to GCSToS3Operator (#32760) * Add requester pays bucket support to GCSToS3Operator * Update docstrings * isort * Fix failing unit tests * Fix failing test --- .../providers/amazon/aws/transfers/gcs_to_s3.py | 32 ++++++++----- airflow/providers/google/cloud/hooks/gcs.py | 52 +++++++++++++++++----- airflow/providers/google/cloud/operators/gcs.py | 8 +++- .../amazon/aws/transfers/test_gcs_to_s3.py | 6 ++- tests/providers/google/cloud/hooks/test_gcs.py | 6 +-- tests/providers/google/cloud/operators/test_gcs.py | 5 ++- .../providers/amazon/aws/example_gcs_to_s3.py | 52 ++++++++++++++++++++-- 7 files changed, 126 insertions(+), 35 deletions(-) diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py index d57de7e11e..2cdd0761a1 100644 --- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py +++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py @@ -80,6 +80,8 @@ class GCSToS3Operator(BaseOperator): on the bucket is recreated within path passed in dest_s3_key. :param match_glob: (Optional) filters objects based on the glob pattern given by the string (e.g, ``'**/*/.json'``) + :param gcp_user_project: (Optional) The identifier of the Google Cloud project to bill for this request. + Required for Requester Pays buckets. """ template_fields: Sequence[str] = ( @@ -88,6 +90,7 @@ class GCSToS3Operator(BaseOperator): "delimiter", "dest_s3_key", "google_impersonation_chain", + "gcp_user_project", ) ui_color = "#f0eee4" @@ -107,6 +110,7 @@ class GCSToS3Operator(BaseOperator): s3_acl_policy: str | None = None, keep_directory_structure: bool = True, match_glob: str | None = None, + gcp_user_project: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -130,10 +134,11 @@ class GCSToS3Operator(BaseOperator): self.s3_acl_policy = s3_acl_policy self.keep_directory_structure = keep_directory_structure self.match_glob = match_glob + self.gcp_user_project = gcp_user_project def execute(self, context: Context) -> list[str]: # list all files in an Google Cloud Storage bucket - hook = GCSHook( + gcs_hook = GCSHook( gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.google_impersonation_chain, ) @@ -145,8 +150,12 @@ class GCSToS3Operator(BaseOperator): self.prefix, ) - files = hook.list( - bucket_name=self.bucket, prefix=self.prefix, delimiter=self.delimiter, match_glob=self.match_glob + gcs_files = gcs_hook.list( + bucket_name=self.bucket, + prefix=self.prefix, + delimiter=self.delimiter, + match_glob=self.match_glob, + user_project=self.gcp_user_project, ) s3_hook = S3Hook( @@ -173,24 +182,23 @@ class GCSToS3Operator(BaseOperator): existing_files = existing_files if existing_files is not None else [] # remove the prefix for the existing files to allow the match existing_files = [file.replace(prefix, "", 1) for file in existing_files] - files = list(set(files) - set(existing_files)) + gcs_files = list(set(gcs_files) - set(existing_files)) - if files: - - for file in files: - with hook.provide_file(object_name=file, bucket_name=self.bucket) as local_tmp_file: + if gcs_files: + for file in gcs_files: + with gcs_hook.provide_file( + object_name=file, bucket_name=self.bucket, user_project=self.gcp_user_project + ) as local_tmp_file: dest_key = os.path.join(self.dest_s3_key, file) self.log.info("Saving file to %s", dest_key) - s3_hook.load_file( filename=local_tmp_file.name, key=dest_key, replace=self.replace, acl_policy=self.s3_acl_policy, ) - - self.log.info("All done, uploaded %d files to S3", len(files)) + self.log.info("All done, uploaded %d files to S3", len(gcs_files)) else: self.log.info("In sync, no files needed to be uploaded to S3") - return files + return gcs_files diff --git a/airflow/providers/google/cloud/hooks/gcs.py b/airflow/providers/google/cloud/hooks/gcs.py index a42f81c355..a01bf72259 100644 --- a/airflow/providers/google/cloud/hooks/gcs.py +++ b/airflow/providers/google/cloud/hooks/gcs.py @@ -197,7 +197,6 @@ class GCSHook(GoogleBaseHook): destination_object = destination_object or source_object if source_bucket == destination_bucket and source_object == destination_object: - raise ValueError( f"Either source/destination bucket or source/destination object must be different, " f"not both the same: bucket={source_bucket}, object={source_object}" @@ -282,6 +281,7 @@ class GCSHook(GoogleBaseHook): chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> bytes: ... @@ -294,6 +294,7 @@ class GCSHook(GoogleBaseHook): chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> str: ... @@ -305,6 +306,7 @@ class GCSHook(GoogleBaseHook): chunk_size: int | None = None, timeout: int | None = DEFAULT_TIMEOUT, num_max_attempts: int | None = 1, + user_project: str | None = None, ) -> str | bytes: """ Downloads a file from Google Cloud Storage. @@ -320,6 +322,8 @@ class GCSHook(GoogleBaseHook): :param chunk_size: Blob chunk size. :param timeout: Request timeout in seconds. :param num_max_attempts: Number of attempts to download the file. + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. """ # TODO: future improvement check file size before downloading, # to check for local space availability @@ -330,7 +334,7 @@ class GCSHook(GoogleBaseHook): try: num_file_attempts += 1 client = self.get_conn() - bucket = client.bucket(bucket_name) + bucket = client.bucket(bucket_name, user_project=user_project) blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) if filename: @@ -395,6 +399,7 @@ class GCSHook(GoogleBaseHook): object_name: str | None = None, object_url: str | None = None, dir: str | None = None, + user_project: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Downloads the file to a temporary directory and returns a file handle. @@ -406,13 +411,20 @@ class GCSHook(GoogleBaseHook): :param object_name: The object to fetch. :param object_url: File reference url. Must start with "gs: //" :param dir: The tmp sub directory to download the file to. (passed to NamedTemporaryFile) + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. :return: File handler """ if object_name is None: raise ValueError("Object name can not be empty") _, _, file_name = object_name.rpartition("/") with NamedTemporaryFile(suffix=file_name, dir=dir) as tmp_file: - self.download(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name) + self.download( + bucket_name=bucket_name, + object_name=object_name, + filename=tmp_file.name, + user_project=user_project, + ) tmp_file.flush() yield tmp_file @@ -423,6 +435,7 @@ class GCSHook(GoogleBaseHook): bucket_name: str = PROVIDE_BUCKET, object_name: str | None = None, object_url: str | None = None, + user_project: str | None = None, ) -> Generator[IO[bytes], None, None]: """ Creates temporary file, returns a file handle and uploads the files content on close. @@ -433,6 +446,8 @@ class GCSHook(GoogleBaseHook): :param bucket_name: The bucket to fetch from. :param object_name: The object to fetch. :param object_url: File reference url. Must start with "gs: //" + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. :return: File handler """ if object_name is None: @@ -442,7 +457,12 @@ class GCSHook(GoogleBaseHook): with NamedTemporaryFile(suffix=file_name) as tmp_file: yield tmp_file tmp_file.flush() - self.upload(bucket_name=bucket_name, object_name=object_name, filename=tmp_file.name) + self.upload( + bucket_name=bucket_name, + object_name=object_name, + filename=tmp_file.name, + user_project=user_project, + ) def upload( self, @@ -458,6 +478,7 @@ class GCSHook(GoogleBaseHook): num_max_attempts: int = 1, metadata: dict | None = None, cache_control: str | None = None, + user_project: str | None = None, ) -> None: """ Uploads a local file or file data as string or bytes to Google Cloud Storage. @@ -474,6 +495,8 @@ class GCSHook(GoogleBaseHook): :param num_max_attempts: Number of attempts to try to upload the file. :param metadata: The metadata to be uploaded with the file. :param cache_control: Cache-Control metadata field. + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. """ def _call_with_retry(f: Callable[[], None]) -> None: @@ -506,7 +529,7 @@ class GCSHook(GoogleBaseHook): continue client = self.get_conn() - bucket = client.bucket(bucket_name) + bucket = client.bucket(bucket_name, user_project=user_project) blob = bucket.blob(blob_name=object_name, chunk_size=chunk_size) if metadata: @@ -596,7 +619,6 @@ class GCSHook(GoogleBaseHook): """ blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: - if not ts.tzinfo: ts = ts.replace(tzinfo=timezone.utc) self.log.info("Verify object date: %s > %s", blob_update_time, ts) @@ -618,7 +640,6 @@ class GCSHook(GoogleBaseHook): """ blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: - if not min_ts.tzinfo: min_ts = min_ts.replace(tzinfo=timezone.utc) if not max_ts.tzinfo: @@ -639,7 +660,6 @@ class GCSHook(GoogleBaseHook): """ blob_update_time = self.get_blob_update_time(bucket_name, object_name) if blob_update_time is not None: - if not ts.tzinfo: ts = ts.replace(tzinfo=timezone.utc) self.log.info("Verify object date: %s < %s", blob_update_time, ts) @@ -681,16 +701,18 @@ class GCSHook(GoogleBaseHook): self.log.info("Blob %s deleted.", object_name) - def delete_bucket(self, bucket_name: str, force: bool = False) -> None: + def delete_bucket(self, bucket_name: str, force: bool = False, user_project: str | None = None) -> None: """ Delete a bucket object from the Google Cloud Storage. :param bucket_name: name of the bucket which will be deleted :param force: false not allow to delete non empty bucket, set force=True allows to delete non empty bucket + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. """ client = self.get_conn() - bucket = client.bucket(bucket_name) + bucket = client.bucket(bucket_name, user_project=user_project) self.log.info("Deleting %s bucket", bucket_name) try: @@ -707,6 +729,7 @@ class GCSHook(GoogleBaseHook): prefix: str | List[str] | None = None, delimiter: str | None = None, match_glob: str | None = None, + user_project: str | None = None, ): """ List all objects from the bucket with the given a single prefix or multiple prefixes. @@ -718,6 +741,8 @@ class GCSHook(GoogleBaseHook): :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv') :param match_glob: (Optional) filters objects based on the glob pattern given by the string (e.g, ``'**/*/.json'``). + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. :return: a stream of object names matching the filtering criteria """ if delimiter and delimiter != "/": @@ -739,6 +764,7 @@ class GCSHook(GoogleBaseHook): prefix=prefix_item, delimiter=delimiter, match_glob=match_glob, + user_project=user_project, ) ) else: @@ -750,6 +776,7 @@ class GCSHook(GoogleBaseHook): prefix=prefix, delimiter=delimiter, match_glob=match_glob, + user_project=user_project, ) ) return objects @@ -762,6 +789,7 @@ class GCSHook(GoogleBaseHook): prefix: str | None = None, delimiter: str | None = None, match_glob: str | None = None, + user_project: str | None = None, ) -> List: """ List all objects from the bucket with the give string prefix in name. @@ -773,10 +801,12 @@ class GCSHook(GoogleBaseHook): :param delimiter: (Deprecated) filters objects based on the delimiter (for e.g '.csv') :param match_glob: (Optional) filters objects based on the glob pattern given by the string (e.g, ``'**/*/.json'``). + :param user_project: The identifier of the Google Cloud project to bill for the request. + Required for Requester Pays buckets. :return: a stream of object names matching the filtering criteria """ client = self.get_conn() - bucket = client.bucket(bucket_name) + bucket = client.bucket(bucket_name, user_project=user_project) ids = [] page_token = None diff --git a/airflow/providers/google/cloud/operators/gcs.py b/airflow/providers/google/cloud/operators/gcs.py index fd73af42fb..9b95032b42 100644 --- a/airflow/providers/google/cloud/operators/gcs.py +++ b/airflow/providers/google/cloud/operators/gcs.py @@ -301,7 +301,6 @@ class GCSDeleteObjectsOperator(GoogleCloudBaseOperator): impersonation_chain: str | Sequence[str] | None = None, **kwargs, ) -> None: - self.bucket_name = bucket_name self.objects = objects self.prefix = prefix @@ -875,12 +874,15 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator): If set as a sequence, the identities from the list must grant Service Account Token Creator IAM role to the directly preceding identity, with first account from the list granting this role to the originating account (templated). + :param user_project: (Optional) The identifier of the project to bill for this request. + Required for Requester Pays buckets. """ template_fields: Sequence[str] = ( "bucket_name", "gcp_conn_id", "impersonation_chain", + "user_project", ) def __init__( @@ -890,6 +892,7 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator): force: bool = True, gcp_conn_id: str = "google_cloud_default", impersonation_chain: str | Sequence[str] | None = None, + user_project: str | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -898,10 +901,11 @@ class GCSDeleteBucketOperator(GoogleCloudBaseOperator): self.force: bool = force self.gcp_conn_id = gcp_conn_id self.impersonation_chain = impersonation_chain + self.user_project = user_project def execute(self, context: Context) -> None: hook = GCSHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) - hook.delete_bucket(bucket_name=self.bucket_name, force=self.force) + hook.delete_bucket(bucket_name=self.bucket_name, force=self.force, user_project=self.user_project) class GCSSynchronizeBucketsOperator(GoogleCloudBaseOperator): diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py index 5e64f167ba..9d5c497de1 100644 --- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py +++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py @@ -69,7 +69,11 @@ class TestGCSToS3Operator: operator.execute(None) mock_hook.return_value.list.assert_called_once_with( - bucket_name=GCS_BUCKET, delimiter=None, match_glob=f"**/*{DELIMITER}", prefix=PREFIX + bucket_name=GCS_BUCKET, + delimiter=None, + match_glob=f"**/*{DELIMITER}", + prefix=PREFIX, + user_project=None, ) @mock.patch("airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook") diff --git a/tests/providers/google/cloud/hooks/test_gcs.py b/tests/providers/google/cloud/hooks/test_gcs.py index 61ce7a4162..4f1839fc42 100644 --- a/tests/providers/google/cloud/hooks/test_gcs.py +++ b/tests/providers/google/cloud/hooks/test_gcs.py @@ -503,7 +503,7 @@ class TestGCSHook: self.gcs_hook.delete_bucket(bucket_name=test_bucket) - mock_service.return_value.bucket.assert_called_once_with(test_bucket) + mock_service.return_value.bucket.assert_called_once_with(test_bucket, user_project=None) mock_service.return_value.bucket.return_value.delete.assert_called_once() @mock.patch(GCS_STRING.format("GCSHook.get_conn")) @@ -514,7 +514,7 @@ class TestGCSHook: test_bucket = "test bucket" with caplog.at_level(logging.INFO): self.gcs_hook.delete_bucket(bucket_name=test_bucket) - mock_service.return_value.bucket.assert_called_once_with(test_bucket) + mock_service.return_value.bucket.assert_called_once_with(test_bucket, user_project=None) mock_service.return_value.bucket.return_value.delete.assert_called_once() assert "Bucket test bucket not exist" in caplog.text @@ -784,7 +784,7 @@ class TestGCSHook: fhandle.write() mock_upload.assert_called_once_with( - bucket_name=test_bucket, object_name=test_object, filename=test_file + bucket_name=test_bucket, object_name=test_object, filename=test_file, user_project=None ) mock_temp_file.assert_has_calls( [ diff --git a/tests/providers/google/cloud/operators/test_gcs.py b/tests/providers/google/cloud/operators/test_gcs.py index 4ceaa5292b..815cad300d 100644 --- a/tests/providers/google/cloud/operators/test_gcs.py +++ b/tests/providers/google/cloud/operators/test_gcs.py @@ -196,7 +196,6 @@ class TestGCSFileTransformOperator: @mock.patch("airflow.providers.google.cloud.operators.gcs.subprocess") @mock.patch("airflow.providers.google.cloud.operators.gcs.GCSHook") def test_execute(self, mock_hook, mock_subprocess, mock_tempfile): - source_bucket = TEST_BUCKET source_object = "test.txt" destination_bucket = TEST_BUCKET + "-dest" @@ -416,7 +415,9 @@ class TestGCSDeleteBucketOperator: operator = GCSDeleteBucketOperator(task_id=TASK_ID, bucket_name=TEST_BUCKET) operator.execute(None) - mock_hook.return_value.delete_bucket.assert_called_once_with(bucket_name=TEST_BUCKET, force=True) + mock_hook.return_value.delete_bucket.assert_called_once_with( + bucket_name=TEST_BUCKET, force=True, user_project=None + ) class TestGoogleCloudStorageSync: diff --git a/tests/system/providers/amazon/aws/example_gcs_to_s3.py b/tests/system/providers/amazon/aws/example_gcs_to_s3.py index c0182f2d09..68db86f82a 100644 --- a/tests/system/providers/amazon/aws/example_gcs_to_s3.py +++ b/tests/system/providers/amazon/aws/example_gcs_to_s3.py @@ -19,13 +19,25 @@ from __future__ import annotations from datetime import datetime from airflow import DAG +from airflow.decorators import task from airflow.models.baseoperator import chain -from airflow.providers.amazon.aws.operators.s3 import S3CreateBucketOperator, S3DeleteBucketOperator +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3DeleteBucketOperator, +) from airflow.providers.amazon.aws.transfers.gcs_to_s3 import GCSToS3Operator +from airflow.providers.google.cloud.hooks.gcs import GCSHook +from airflow.providers.google.cloud.operators.gcs import ( + GCSCreateBucketOperator, + GCSDeleteBucketOperator, +) from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder -sys_test_context_task = SystemTestContextBuilder().build() +# Externally fetched variables: +GCP_PROJECT_ID = "GCP_PROJECT_ID" + +sys_test_context_task = SystemTestContextBuilder().add_variable(GCP_PROJECT_ID).build() DAG_ID = "example_gcs_to_s3" @@ -38,18 +50,40 @@ with DAG( ) as dag: test_context = sys_test_context_task() env_id = test_context["ENV_ID"] + gcp_user_project = test_context[GCP_PROJECT_ID] s3_bucket = f"{env_id}-gcs-to-s3-bucket" s3_key = f"{env_id}-gcs-to-s3-key" create_s3_bucket = S3CreateBucketOperator(task_id="create_s3_bucket", bucket_name=s3_bucket) + gcs_bucket = f"{env_id}-gcs-to-s3-bucket" + gcs_key = f"{env_id}-gcs-to-s3-key" + + create_gcs_bucket = GCSCreateBucketOperator( + task_id="create_gcs_bucket", + bucket_name=gcs_bucket, + resource={"billing": {"requesterPays": True}}, + project_id=gcp_user_project, + ) + + @task + def upload_gcs_file(bucket_name: str, object_name: str, user_project: str): + hook = GCSHook() + with hook.provide_file_and_upload( + bucket_name=bucket_name, + object_name=object_name, + user_project=user_project, + ) as temp_file: + temp_file.write(b"test") + # [START howto_transfer_gcs_to_s3] gcs_to_s3 = GCSToS3Operator( task_id="gcs_to_s3", - bucket=s3_bucket, - dest_s3_key=s3_key, + bucket=gcs_bucket, + dest_s3_key=f"s3://{s3_bucket}/{s3_key}", replace=True, + gcp_user_project=gcp_user_project, ) # [END howto_transfer_gcs_to_s3] @@ -60,14 +94,24 @@ with DAG( trigger_rule=TriggerRule.ALL_DONE, ) + delete_gcs_bucket = GCSDeleteBucketOperator( + task_id="delete_gcs_bucket", + bucket_name=gcs_bucket, + trigger_rule=TriggerRule.ALL_DONE, + user_project=gcp_user_project, + ) + chain( # TEST SETUP test_context, + create_gcs_bucket, + upload_gcs_file(gcs_bucket, gcs_key, gcp_user_project), create_s3_bucket, # TEST BODY gcs_to_s3, # TEST TEARDOWN delete_s3_bucket, + delete_gcs_bucket, ) from tests.system.utils.watcher import watcher