o-nikolas commented on code in PR #68946:
URL: https://github.com/apache/airflow/pull/68946#discussion_r3469648406


##########
providers/amazon/src/airflow/providers/amazon/aws/operators/s3.py:
##########
@@ -385,6 +385,161 @@ def get_openlineage_facets_on_start(self):
         )
 
 
+class S3CopyPrefixOperator(AwsBaseOperator[S3Hook]):
+    """
+    Creates a copy of all objects under a prefix already stored in S3.
+
+    Note: the S3 connection used here needs to have access to both
+    source and destination bucket/prefix.
+
+    .. seealso::
+        For more information on how to use this operator, take a look at the 
guide:
+        :ref:`howto/operator:S3CopyPrefixOperator`
+
+    :param source_bucket_prefix: The prefix in the source bucket. (templated)
+        It can be either full s3:// style url or relative path from root level.
+        When it's specified as a full s3:// url, please omit 
source_bucket_name.
+    :param dest_bucket_prefix: The prefix in the destination to copy to. 
(templated)
+        The convention to specify `dest_bucket_prefix` is the same as 
`source_bucket_prefix`.
+    :param source_bucket_name: Name of the S3 bucket where the source objects 
are in. (templated)
+        It should be omitted when `source_bucket_prefix` is provided as a full 
s3:// url.
+    :param dest_bucket_name: Name of the S3 bucket to where the objects are 
copied. (templated)
+        It should be omitted when `dest_bucket_prefix` is provided as a full 
s3:// url.
+    :param kms_key_id: The ARN, id or alias of the AWS KMS key to use for 
encrypting the destination object.
+        Required if using KMS-based server-side encryption with a non-default 
key. (templated)
+    :param kms_encryption_type: Type of KMS encryption to use for the object.
+        Can be either "aws:kms" (standard KMS) or "aws:kms:dsse" 
(double-shielded KMS).
+    :param continue_on_failure: If False, stop and fail the task on the first 
copy error.
+        If True, try to copy every object in the prefix and then fail the task 
on any error.
+        Default is False.
+    :param acl_policy: String specifying the canned ACL policy for the file 
being
+        uploaded to the S3 bucket.
+    :param meta_data_directive: Whether to `COPY` the metadata from the source 
object or `REPLACE` it with
+        metadata that's provided in the request.
+    :param aws_conn_id: The Airflow connection used for AWS credentials.
+        If this is ``None`` or empty then the default boto3 behaviour is used. 
If
+        running Airflow in a distributed manner and aws_conn_id is None or
+        empty, then default boto3 configuration would be used (and must be
+        maintained on each worker node).
+    :param region_name: AWS region_name. If not specified then the default 
boto3 behaviour is used.
+    :param verify: Whether or not to verify SSL certificates. See:
+        
https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html
+    :param botocore_config: Configuration dictionary (key-values) for botocore 
client. See:
+        
https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
+    """
+
+    template_fields: Sequence[str] = aws_template_fields(
+        "source_bucket_prefix",
+        "dest_bucket_prefix",
+        "source_bucket_name",
+        "dest_bucket_name",
+        "kms_key_id",
+    )
+    aws_hook_class = S3Hook
+
+    def __init__(
+        self,
+        *,
+        source_bucket_prefix: str,
+        dest_bucket_prefix: str,
+        source_bucket_name: str | None = None,
+        dest_bucket_name: str | None = None,
+        kms_key_id: str | None = None,
+        kms_encryption_type: str | None = None,
+        continue_on_failure: bool = False,
+        acl_policy: str | None = None,
+        meta_data_directive: str | None = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.source_bucket_prefix = source_bucket_prefix
+        self.dest_bucket_prefix = dest_bucket_prefix
+        self.source_bucket_name = source_bucket_name
+        self.dest_bucket_name = dest_bucket_name
+        self.kms_key_id = kms_key_id
+        self.kms_encryption_type = kms_encryption_type
+        self.continue_on_failure = continue_on_failure
+        self.acl_policy = acl_policy
+        self.meta_data_directive = meta_data_directive
+
+    def execute(self, context: Context):
+        source_bucket_name, source_bucket_prefix = self.hook.get_s3_bucket_key(
+            self.source_bucket_name, self.source_bucket_prefix, 
"source_bucket_name", "source_bucket_prefix"
+        )
+
+        dest_bucket_name, dest_bucket_prefix = self.hook.get_s3_bucket_key(
+            self.dest_bucket_name, self.dest_bucket_prefix, 
"dest_bucket_name", "dest_bucket_prefix"
+        )
+
+        s3_client = self.hook.get_conn()
+
+        paginator = s3_client.get_paginator("list_objects_v2")
+        pages = paginator.paginate(
+            Bucket=source_bucket_name,
+            Prefix=source_bucket_prefix,
+        )
+
+        copied_object_count = 0
+        failed_object_count = 0
+        for page in pages:
+            if "Contents" in page:
+                for obj in page["Contents"]:
+                    source_key = obj["Key"]
+                    dest_key = dest_bucket_prefix + 
source_key[len(source_bucket_prefix) :]
+
+                    try:
+                        self.hook.copy_object(
+                            source_bucket_key=source_key,
+                            dest_bucket_key=dest_key,
+                            source_bucket_name=source_bucket_name,
+                            dest_bucket_name=dest_bucket_name,
+                            kms_key_id=self.kms_key_id,
+                            kms_encryption_type=self.kms_encryption_type,
+                            acl_policy=self.acl_policy,
+                            meta_data_directive=self.meta_data_directive,
+                        )
+
+                        copied_object_count += 1
+                    except Exception as e:
+                        if self.continue_on_failure:
+                            self.log.error("Failed to copy %s: %s", 
source_key, e)
+                            failed_object_count += 1
+                        else:
+                            raise RuntimeError(f"Failed to copy {source_key}: 
{e}")

Review Comment:
   Maybe use `raise RuntimeError(...) from e` so that the original traceback is preserved.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_s3.py:
##########
@@ -658,6 +659,247 @@ def fake_copy_object(
         assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key
 
 
+class TestS3CopyPrefixOperator:
+    def setup_method(self):
+        self.source_bucket = "source-bucket"
+        self.source_prefix = "data/logs/"
+        self.dest_bucket = "dest-bucket"
+        self.dest_prefix = "backup/logs/"
+
+        self.source_s3_url = f"s3://{self.source_bucket}/{self.source_prefix}"
+        self.dest_s3_url = f"s3://{self.dest_bucket}/{self.dest_prefix}"
+
+    @staticmethod
+    def _create_s3_client():
+        return boto3.client("s3", region_name="us-east-1")
+
+    def _create_buckets(self, s3_client):
+        s3_client.create_bucket(Bucket=self.source_bucket)
+        s3_client.create_bucket(Bucket=self.dest_bucket)
+
+    def _upload_test_objects(self, s3_client, keys):
+        for key in keys:
+            s3_client.upload_fileobj(Bucket=self.source_bucket, Key=key, 
Fileobj=BytesIO(b"test-content"))
+
+    @mock_aws
+    def test_s3_copy_prefix_basic(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+        self._upload_test_objects(
+            s3_client,
+            [
+                f"{self.source_prefix}file1.txt",
+                f"{self.source_prefix}file2.txt",
+                f"{self.source_prefix}subdir/file3.txt",
+            ],
+        )
+
+        dest_objects = s3_client.list_objects_v2(Bucket=self.dest_bucket, 
Prefix=self.dest_prefix)
+        assert "Contents" not in dest_objects
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+        )
+
+        op.execute(None)
+
+        dest_objects = s3_client.list_objects_v2(Bucket=self.dest_bucket, 
Prefix=self.dest_prefix)
+        assert len(dest_objects["Contents"]) == 3
+
+        copied_keys = [obj["Key"] for obj in dest_objects["Contents"]]
+        assert "backup/logs/file1.txt" in copied_keys
+        assert "backup/logs/file2.txt" in copied_keys
+        assert "backup/logs/subdir/file3.txt" in copied_keys
+
+    @mock_aws
+    def test_s3_copy_prefix_selective_copying(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+        self._upload_test_objects(
+            s3_client,
+            [
+                f"{self.source_prefix}file1.txt",
+                f"{self.source_prefix}subdir/file2.txt",
+                "data/metrics/file3.txt",
+                "other/logs/file4.txt",
+                "archive/data/file5.txt",
+            ],
+        )
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_selective",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+        )
+
+        op.execute(None)
+
+        dest_objects = s3_client.list_objects_v2(Bucket=self.dest_bucket, 
Prefix=self.dest_prefix)
+        assert len(dest_objects["Contents"]) == 2
+
+        copied_keys = [obj["Key"] for obj in dest_objects["Contents"]]
+        assert "backup/logs/file1.txt" in copied_keys
+        assert "backup/logs/subdir/file2.txt" in copied_keys
+
+    @mock_aws
+    def test_s3_copy_prefix_s3_urls(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+        self._upload_test_objects(
+            s3_client, [f"{self.source_prefix}file1.txt", 
f"{self.source_prefix}file2.txt"]
+        )
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_urls",
+            source_bucket_prefix=self.source_s3_url,
+            dest_bucket_prefix=self.dest_s3_url,
+        )
+
+        op.execute(None)
+
+        dest_objects = s3_client.list_objects_v2(Bucket=self.dest_bucket, 
Prefix=self.dest_prefix)
+        assert len(dest_objects["Contents"]) == 2
+
+    def test_invalid_combination_bucket_with_s3_url(self):
+        op = S3CopyPrefixOperator(
+            task_id="test_invalid",
+            source_bucket_name=self.source_bucket,
+            
source_bucket_prefix=f"s3://{self.source_bucket}/{self.source_prefix}",
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+        )
+
+        with pytest.raises(TypeError, match="should be a relative path"):
+            op.execute(None)
+
+    @mock_aws
+    def test_s3_copy_prefix_same_bucket(self):
+        s3_client = self._create_s3_client()
+        s3_client.create_bucket(Bucket=self.source_bucket)
+        self._upload_test_objects(s3_client, 
[f"{self.source_prefix}file1.txt"])
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_same_bucket",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.source_bucket,
+            dest_bucket_prefix="archive/logs/",
+        )
+
+        op.execute(None)
+
+        all_objects = s3_client.list_objects_v2(Bucket=self.source_bucket)
+        keys = [obj["Key"] for obj in all_objects["Contents"]]
+        assert f"{self.source_prefix}file1.txt" in keys
+        assert "archive/logs/file1.txt" in keys
+
+    @mock_aws
+    def test_s3_copy_prefix_empty_result(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_empty",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+        )
+
+        op.execute(None)
+
+        dest_objects = s3_client.list_objects_v2(Bucket=self.dest_bucket, 
Prefix=self.dest_prefix)
+        assert "Contents" not in dest_objects
+
+    @mock_aws
+    def test_continue_on_failure_false(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+        self._upload_test_objects(s3_client, 
[f"{self.source_prefix}file1.txt"])
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_fail_fast",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+            continue_on_failure=False,
+        )
+
+        with mock.patch.object(op.hook, "copy_object", 
side_effect=Exception("Copy failed")):
+            with pytest.raises(
+                RuntimeError, match=f"Failed to copy 
{self.source_prefix}file1.txt: Copy failed"
+            ):
+                op.execute(None)
+
+    @mock_aws
+    def test_continue_on_failure_true(self):
+        s3_client = self._create_s3_client()
+        self._create_buckets(s3_client)
+        self._upload_test_objects(
+            s3_client, [f"{self.source_prefix}file1.txt", 
f"{self.source_prefix}file2.txt"]
+        )
+
+        op = S3CopyPrefixOperator(
+            task_id="test_copy_prefix_continue",
+            source_bucket_name=self.source_bucket,
+            source_bucket_prefix=self.source_prefix,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_prefix=self.dest_prefix,
+            continue_on_failure=True,
+        )
+
+        def mock_copy_object(*args, **kwargs):
+            if "file1.txt" in kwargs.get("source_bucket_key", ""):
+                raise Exception("Copy failed for file1")
+            return None
+
+        with mock.patch.object(op.hook, "copy_object", 
side_effect=mock_copy_object):
+            with pytest.raises(RuntimeError, match=r"Failed to copy 1 
object\(s\)"):
+                op.execute(None)

Review Comment:
   You assert that the exception is raised, but not that copying continues after
the failure. It is worth also verifying that file2 is still copied.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to