This is an automated email from the ASF dual-hosted git repository.
kamilbregula pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/master by this push:
new dd98b21 Add acl_policy parameter to GCSToS3Operator (#10804) (#10829)
dd98b21 is described below
commit dd98b21494ff6036242b63268140abe1294b3657
Author: amaterasu-coder <[email protected]>
AuthorDate: Tue Oct 6 13:09:01 2020 +0200
Add acl_policy parameter to GCSToS3Operator (#10804) (#10829)
---
.../providers/amazon/aws/transfers/gcs_to_s3.py | 9 +++++-
.../amazon/aws/transfers/test_gcs_to_s3.py | 32 ++++++++++++++++++++++
2 files changed, 40 insertions(+), 1 deletion(-)
diff --git a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
index 9b00e33..695087e 100644
--- a/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
+++ b/airflow/providers/amazon/aws/transfers/gcs_to_s3.py
@@ -82,6 +82,9 @@ class GCSToS3Operator(BaseOperator):
Service Account Token Creator IAM role to the directly preceding
identity, with first
account from the list granting this role to the originating account
(templated).
:type google_impersonation_chain: Union[str, Sequence[str]]
+ :param s3_acl_policy: Optional. The string to specify the canned ACL policy
for the
+ object to be uploaded in S3.
+ :type s3_acl_policy: str
"""
template_fields: Iterable[str] = (
@@ -109,6 +112,7 @@ class GCSToS3Operator(BaseOperator):
replace=False,
google_impersonation_chain: Optional[Union[str, Sequence[str]]] = None,
dest_s3_extra_args: Optional[Dict] = None,
+ s3_acl_policy: Optional[str] = None,
**kwargs,
):
super().__init__(**kwargs)
@@ -133,6 +137,7 @@ class GCSToS3Operator(BaseOperator):
self.replace = replace
self.google_impersonation_chain = google_impersonation_chain
self.dest_s3_extra_args = dest_s3_extra_args or {}
+ self.s3_acl_policy = s3_acl_policy
def execute(self, context):
# list all files in a Google Cloud Storage bucket
@@ -177,7 +182,9 @@ class GCSToS3Operator(BaseOperator):
dest_key = self.dest_s3_key + file
self.log.info("Saving file to %s", dest_key)
- s3_hook.load_bytes(file_bytes, key=dest_key,
replace=self.replace)
+ s3_hook.load_bytes(
+ file_bytes, key=dest_key, replace=self.replace,
acl_policy=self.s3_acl_policy
+ )
self.log.info("All done, uploaded %d files to S3", len(files))
else:
diff --git a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
index fd84874..2c18653 100644
--- a/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
+++ b/tests/providers/amazon/aws/transfers/test_gcs_to_s3.py
@@ -34,6 +34,7 @@ DELIMITER = '.csv'
PREFIX = 'TEST'
S3_BUCKET = 's3://bucket/'
MOCK_FILES = ["TEST1.csv", "TEST2.csv", "TEST3.csv"]
+S3_ACL_POLICY = "private-read"
class TestGCSToS3Operator(unittest.TestCase):
@@ -240,3 +241,34 @@ class TestGCSToS3Operator(unittest.TestCase):
s3_mock_hook.assert_called_once_with(
aws_conn_id='aws_default', extra_args={'ContentLanguage':
'value'}, verify=None
)
+
+ # Test6: s3_acl_policy parameter is set
+ @mock_s3
+ @mock.patch('airflow.providers.google.cloud.operators.gcs.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.transfers.gcs_to_s3.GCSHook')
+ @mock.patch('airflow.providers.amazon.aws.hooks.s3.S3Hook.load_bytes')
+ def test_execute_with_s3_acl_policy(self, mock_load_bytes, mock_gcs_hook,
mock_gcs_hook2):
+ mock_gcs_hook.return_value.list.return_value = MOCK_FILES
+ mock_gcs_hook.return_value.download.return_value = b"testing"
+ mock_gcs_hook2.return_value.list.return_value = MOCK_FILES
+
+ operator = GCSToS3Operator(
+ task_id=TASK_ID,
+ bucket=GCS_BUCKET,
+ prefix=PREFIX,
+ delimiter=DELIMITER,
+ dest_aws_conn_id="aws_default",
+ dest_s3_key=S3_BUCKET,
+ replace=False,
+ s3_acl_policy=S3_ACL_POLICY,
+ )
+
+ # Create dest bucket without files
+ hook = S3Hook(aws_conn_id='airflow_gcs_test')
+ bucket = hook.get_bucket('bucket')
+ bucket.create()
+
+ operator.execute(None)
+
+ # Make sure the acl_policy parameter is passed to the upload method
+ self.assertEqual(mock_load_bytes.call_args.kwargs['acl_policy'],
S3_ACL_POLICY)