This is an automated email from the ASF dual-hosted git repository.

onikolas pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 9e159fc48d Add OpenLineage support to S3Operators - Copy, Delete and Create Object (#35796)
9e159fc48d is described below

commit 9e159fc48dd774aa09358801c17d6da217052f8a
Author: Kacper Muda <mudakac...@gmail.com>
AuthorDate: Wed Nov 22 19:51:55 2023 +0100

    Add OpenLineage support to S3Operators - Copy, Delete and Create Object (#35796)
---
 airflow/providers/amazon/aws/operators/s3.py    |  86 ++++++++++++++
 tests/providers/amazon/aws/operators/test_s3.py | 146 ++++++++++++++++++++++++
 2 files changed, 232 insertions(+)
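
In short, S3CopyObjectOperator and S3CreateObjectOperator gain get_openlineage_facets_on_start(), while S3DeleteObjectsOperator gains get_openlineage_facets_on_complete(), since its keys are only resolved in execute(). A minimal sketch of exercising the copy operator's new hook directly; the bucket and key names are illustrative, not part of the commit, and the amazon plus openlineage provider packages are assumed to be installed:

    from airflow.providers.amazon.aws.operators.s3 import S3CopyObjectOperator

    # Illustrative bucket/key names; the lineage hook itself only parses them
    # via S3Hook.get_s3_bucket_key and makes no AWS calls.
    op = S3CopyObjectOperator(
        task_id="copy_example",
        source_bucket_name="src-bucket",
        source_bucket_key="raw/data.csv",
        dest_bucket_name="dst-bucket",
        dest_bucket_key="clean/data.csv",
    )

    # Returns an OperatorLineage with one input and one output Dataset,
    # namespaced "s3://<bucket>" and named by the object key.
    lineage = op.get_openlineage_facets_on_start()
    print(lineage.inputs[0].namespace, lineage.inputs[0].name)    # s3://src-bucket raw/data.csv
    print(lineage.outputs[0].namespace, lineage.outputs[0].name)  # s3://dst-bucket clean/data.csv
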

diff --git a/airflow/providers/amazon/aws/operators/s3.py b/airflow/providers/amazon/aws/operators/s3.py
index 956feddb4a..068f73e622 100644
--- a/airflow/providers/amazon/aws/operators/s3.py
+++ b/airflow/providers/amazon/aws/operators/s3.py
@@ -321,6 +321,33 @@ class S3CopyObjectOperator(BaseOperator):
             self.acl_policy,
         )
 
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        dest_bucket_name, dest_bucket_key = S3Hook.get_s3_bucket_key(
+            self.dest_bucket_name, self.dest_bucket_key, "dest_bucket_name", "dest_bucket_key"
+        )
+
+        source_bucket_name, source_bucket_key = S3Hook.get_s3_bucket_key(
+            self.source_bucket_name, self.source_bucket_key, "source_bucket_name", "source_bucket_key"
+        )
+
+        input_dataset = Dataset(
+            namespace=f"s3://{source_bucket_name}",
+            name=source_bucket_key,
+        )
+        output_dataset = Dataset(
+            namespace=f"s3://{dest_bucket_name}",
+            name=dest_bucket_key,
+        )
+
+        return OperatorLineage(
+            inputs=[input_dataset],
+            outputs=[output_dataset],
+        )
+
 
 class S3CreateObjectOperator(BaseOperator):
     """
@@ -409,6 +436,22 @@ class S3CreateObjectOperator(BaseOperator):
         else:
             s3_hook.load_bytes(self.data, s3_key, s3_bucket, self.replace, self.encrypt, self.acl_policy)
 
+    def get_openlineage_facets_on_start(self):
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        bucket, key = S3Hook.get_s3_bucket_key(self.s3_bucket, self.s3_key, "dest_bucket", "dest_key")
+
+        output_dataset = Dataset(
+            namespace=f"s3://{bucket}",
+            name=key,
+        )
+
+        return OperatorLineage(
+            outputs=[output_dataset],
+        )
+
 
 class S3DeleteObjectsOperator(BaseOperator):
     """
@@ -462,6 +505,8 @@ class S3DeleteObjectsOperator(BaseOperator):
         self.aws_conn_id = aws_conn_id
         self.verify = verify
 
+        self._keys: str | list[str] = ""
+
         if not exactly_one(prefix is None, keys is None):
             raise AirflowException("Either keys or prefix should be set.")
 
@@ -476,6 +521,47 @@ class S3DeleteObjectsOperator(BaseOperator):
         keys = self.keys or s3_hook.list_keys(bucket_name=self.bucket, prefix=self.prefix)
         if keys:
             s3_hook.delete_objects(bucket=self.bucket, keys=keys)
+            self._keys = keys
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement _on_complete because object keys are resolved in 
execute()."""
+        from openlineage.client.facet import (
+            LifecycleStateChange,
+            LifecycleStateChangeDatasetFacet,
+            LifecycleStateChangeDatasetFacetPreviousIdentifier,
+        )
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if not self._keys:
+            return OperatorLineage()
+
+        keys = self._keys
+        if isinstance(keys, str):
+            keys = [keys]
+
+        bucket_url = f"s3://{self.bucket}"
+        input_datasets = [
+            Dataset(
+                namespace=bucket_url,
+                name=key,
+                facets={
+                    "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                        lifecycleStateChange=LifecycleStateChange.DROP.value,
+                        previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
+                            namespace=bucket_url,
+                            name=key,
+                        ),
+                    )
+                },
+            )
+            for key in keys
+        ]
+
+        return OperatorLineage(
+            inputs=input_datasets,
+        )
 
 
 class S3FileTransformOperator(BaseOperator):
diff --git a/tests/providers/amazon/aws/operators/test_s3.py b/tests/providers/amazon/aws/operators/test_s3.py
index 3bfd238a97..80a4b645d4 100644
--- a/tests/providers/amazon/aws/operators/test_s3.py
+++ b/tests/providers/amazon/aws/operators/test_s3.py
@@ -28,6 +28,12 @@ from unittest import mock
 import boto3
 import pytest
 from moto import mock_s3
+from openlineage.client.facet import (
+    LifecycleStateChange,
+    LifecycleStateChangeDatasetFacet,
+    LifecycleStateChangeDatasetFacetPreviousIdentifier,
+)
+from openlineage.client.run import Dataset
 
 from airflow.exceptions import AirflowException
 from airflow.providers.amazon.aws.hooks.s3 import S3Hook
@@ -44,6 +50,7 @@ from airflow.providers.amazon.aws.operators.s3 import (
     S3ListPrefixesOperator,
     S3PutBucketTaggingOperator,
 )
+from airflow.providers.openlineage.extractors import OperatorLineage
 
 BUCKET_NAME = os.environ.get("BUCKET_NAME", "test-airflow-bucket")
 S3_KEY = "test-airflow-key"
@@ -409,6 +416,55 @@ class TestS3CopyObjectOperator:
         # the object found should be consistent with dest_key specified earlier
         assert objects_in_dest_bucket["Contents"][0]["Key"] == self.dest_key
 
+    def test_get_openlineage_facets_on_start_combination_1(self):
+        expected_input = Dataset(
+            namespace=f"s3://{self.source_bucket}",
+            name=self.source_key,
+        )
+        expected_output = Dataset(
+            namespace=f"s3://{self.dest_bucket}",
+            name=self.dest_key,
+        )
+
+        op = S3CopyObjectOperator(
+            task_id="test",
+            source_bucket_name=self.source_bucket,
+            source_bucket_key=self.source_key,
+            dest_bucket_name=self.dest_bucket,
+            dest_bucket_key=self.dest_key,
+        )
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0] == expected_input
+        assert lineage.outputs[0] == expected_output
+
+    def test_get_openlineage_facets_on_start_combination_2(self):
+        expected_input = Dataset(
+            namespace=f"s3://{self.source_bucket}",
+            name=self.source_key,
+        )
+        expected_output = Dataset(
+            namespace=f"s3://{self.dest_bucket}",
+            name=self.dest_key,
+        )
+
+        source_key_s3_url = f"s3://{self.source_bucket}/{self.source_key}"
+        dest_key_s3_url = f"s3://{self.dest_bucket}/{self.dest_key}"
+
+        op = S3CopyObjectOperator(
+            task_id="test",
+            source_bucket_key=source_key_s3_url,
+            dest_bucket_key=dest_key_s3_url,
+        )
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0] == expected_input
+        assert lineage.outputs[0] == expected_output
+
 
 @mock_s3
 class TestS3DeleteObjectsOperator:
@@ -575,6 +631,82 @@ class TestS3DeleteObjectsOperator:
         # the object found should be consistent with dest_key specified earlier
         assert objects_in_dest_bucket["Contents"][0]["Key"] == key_of_test
 
+    @pytest.mark.parametrize("keys", ("path/data.txt", ["path/data.txt"]))
+    @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
+    def test_get_openlineage_facets_on_complete_single_object(self, mock_hook, keys):
+        bucket = "testbucket"
+        expected_input = Dataset(
+            namespace=f"s3://{bucket}",
+            name="path/data.txt",
+            facets={
+                "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                    lifecycleStateChange=LifecycleStateChange.DROP.value,
+                    previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
+                        namespace=f"s3://{bucket}",
+                        name="path/data.txt",
+                    ),
+                )
+            },
+        )
+
+        op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 1
+        assert lineage.inputs[0] == expected_input
+
+    @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
+    def test_get_openlineage_facets_on_complete_multiple_objects(self, mock_hook):
+        bucket = "testbucket"
+        keys = ["path/data1.txt", "path/data2.txt"]
+        expected_inputs = [
+            Dataset(
+                namespace=f"s3://{bucket}",
+                name="path/data1.txt",
+                facets={
+                    "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                        lifecycleStateChange=LifecycleStateChange.DROP.value,
+                        previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
+                            namespace=f"s3://{bucket}",
+                            name="path/data1.txt",
+                        ),
+                    )
+                },
+            ),
+            Dataset(
+                namespace=f"s3://{bucket}",
+                name="path/data2.txt",
+                facets={
+                    "lifecycleStateChange": LifecycleStateChangeDatasetFacet(
+                        lifecycleStateChange=LifecycleStateChange.DROP.value,
+                        previousIdentifier=LifecycleStateChangeDatasetFacetPreviousIdentifier(
+                            namespace=f"s3://{bucket}",
+                            name="path/data2.txt",
+                        ),
+                    )
+                },
+            ),
+        ]
+
+        op = S3DeleteObjectsOperator(task_id="test_task_s3_delete_single_object", bucket=bucket, keys=keys)
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 2
+        assert lineage.inputs == expected_inputs
+
+    @pytest.mark.parametrize("keys", ("", []))
+    @mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook")
+    def test_get_openlineage_facets_on_complete_no_objects(self, mock_hook, keys):
+        op = S3DeleteObjectsOperator(
+            task_id="test_task_s3_delete_single_object", bucket="testbucket", 
keys=keys
+        )
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert lineage == OperatorLineage()
+
 
 class TestS3CreateObjectOperator:
     @mock.patch.object(S3Hook, "load_string")
@@ -614,3 +746,17 @@ class TestS3CreateObjectOperator:
         operator.execute(None)
 
         mock_load_string.assert_called_once_with(data, S3_KEY, BUCKET_NAME, False, False, None, None, None)
+
+    @pytest.mark.parametrize(("bucket", "key"), (("bucket", "file.txt"), 
(None, "s3://bucket/file.txt")))
+    def test_get_openlineage_facets_on_start(self, bucket, key):
+        expected_output = Dataset(
+            namespace="s3://bucket",
+            name="file.txt",
+        )
+
+        op = S3CreateObjectOperator(task_id="test", s3_bucket=bucket, s3_key=key, data="test")
+
+        lineage = op.get_openlineage_facets_on_start()
+        assert len(lineage.inputs) == 0
+        assert len(lineage.outputs) == 1
+        assert lineage.outputs[0] == expected_output
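
For completeness, a sketch of the delete operator's on-complete lineage, mirroring the tests in this patch. S3Hook is mocked so no AWS credentials are needed; the task_id, bucket, and keys below are illustrative, not taken from the commit:

    from unittest import mock

    from airflow.providers.amazon.aws.operators.s3 import S3DeleteObjectsOperator

    op = S3DeleteObjectsOperator(
        task_id="delete_example",
        bucket="my-bucket",
        keys=["tmp/a.txt", "tmp/b.txt"],
    )

    # execute() resolves and deletes the keys; patching S3Hook avoids real S3 calls.
    with mock.patch("airflow.providers.amazon.aws.operators.s3.S3Hook"):
        op.execute(None)

    # Each deleted key becomes an input Dataset carrying a lifecycleStateChange=DROP facet.
    lineage = op.get_openlineage_facets_on_complete(None)
    for ds in lineage.inputs:
        facet = ds.facets["lifecycleStateChange"]
        print(ds.namespace, ds.name, facet.lifecycleStateChange)  # s3://my-bucket tmp/a.txt DROP
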
