This is an automated email from the ASF dual-hosted git repository.

vincbeck pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a2b8696bdb Openlineage s3 to redshift operator integration (#41575)
a2b8696bdb is described below

commit a2b8696bdb6e99b27e40e9d8bce04dd2ee1789a7
Author: Artur Skarżyński <33717106+artu...@users.noreply.github.com>
AuthorDate: Mon Aug 19 16:11:54 2024 +0200

    Openlineage s3 to redshift operator integration (#41575)
    
    
    
    ---------
    
    Co-authored-by: Vincent <97131062+vincb...@users.noreply.github.com>
---
 .../amazon/aws/transfers/s3_to_redshift.py         | 101 ++++++-
 airflow/providers/amazon/aws/utils/openlineage.py  | 136 +++++++++
 .../amazon/aws/transfers/test_s3_to_redshift.py    | 308 +++++++++++++++++++++
 .../providers/amazon/aws/utils/test_openlineage.py | 168 +++++++++++
 4 files changed, 703 insertions(+), 10 deletions(-)
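
A minimal usage sketch (not part of the commit; connection ids, bucket, and cluster
values below are illustrative assumptions) showing how the operator can be configured
so that the new Redshift Data API path and the OpenLineage facet collection added in
this commit are exercised:

    from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator

    transfer = S3ToRedshiftOperator(
        task_id="s3_to_redshift",
        schema="public",
        table="my_table",
        s3_bucket="my-bucket",
        s3_key="exports/2024/*.parquet",
        redshift_conn_id="redshift_default",
        aws_conn_id="aws_default",
        # Any Redshift Data API kwargs switch execution to RedshiftDataHook
        # (use_redshift_data becomes True); omit them to use RedshiftSQLHook.
        redshift_data_api_kwargs={
            "database": "dev",
            "cluster_identifier": "my-cluster",  # assumed cluster name
            "db_user": "awsuser",
        },
    )
    # With the OpenLineage provider enabled, get_openlineage_facets_on_complete()
    # runs after execution and reports the S3 key as the input dataset and
    # database.schema.table on the Redshift cluster as the output dataset.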

diff --git a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
index 161276b33c..653885b541 100644
--- a/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
+++ b/airflow/providers/amazon/aws/transfers/s3_to_redshift.py
@@ -121,6 +121,10 @@ class S3ToRedshiftOperator(BaseOperator):
                 if arg in self.redshift_data_api_kwargs:
                     raise AirflowException(f"Cannot include param '{arg}' in Redshift Data API kwargs")
 
+    @property
+    def use_redshift_data(self):
+        return bool(self.redshift_data_api_kwargs)
+
     def _build_copy_query(
         self, copy_destination: str, credentials_block: str, region_info: str, copy_options: str
     ) -> str:
@@ -138,11 +142,11 @@ class S3ToRedshiftOperator(BaseOperator):
         if self.method not in AVAILABLE_METHODS:
             raise AirflowException(f"Method not found! Available methods: {AVAILABLE_METHODS}")
 
-        redshift_hook: RedshiftDataHook | RedshiftSQLHook
-        if self.redshift_data_api_kwargs:
-            redshift_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+        if self.use_redshift_data:
+            redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
         else:
-            redshift_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+            redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+
         conn = S3Hook.get_connection(conn_id=self.aws_conn_id) if self.aws_conn_id else None
         region_info = ""
         if conn and conn.extra_dejson.get("region", False):
@@ -167,12 +171,12 @@ class S3ToRedshiftOperator(BaseOperator):
         if self.method == "REPLACE":
             sql = ["BEGIN;", f"DELETE FROM {destination};", copy_statement, "COMMIT"]
         elif self.method == "UPSERT":
-            if isinstance(redshift_hook, RedshiftDataHook):
-                keys = self.upsert_keys or redshift_hook.get_table_primary_key(
+            if self.use_redshift_data:
+                keys = self.upsert_keys or redshift_data_hook.get_table_primary_key(
                     table=self.table, schema=self.schema, **self.redshift_data_api_kwargs
                 )
             else:
-                keys = self.upsert_keys or redshift_hook.get_table_primary_key(self.table, self.schema)
+                keys = self.upsert_keys or redshift_sql_hook.get_table_primary_key(self.table, self.schema)
             if not keys:
                 raise AirflowException(
                     f"No primary key on {self.schema}.{self.table}. Please provide keys on 'upsert_keys'"
@@ -192,8 +196,85 @@ class S3ToRedshiftOperator(BaseOperator):
             sql = copy_statement
 
         self.log.info("Executing COPY command...")
-        if isinstance(redshift_hook, RedshiftDataHook):
-            redshift_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs)
+        if self.use_redshift_data:
+            redshift_data_hook.execute_query(sql=sql, **self.redshift_data_api_kwargs)
         else:
-            redshift_hook.run(sql, autocommit=self.autocommit)
+            redshift_sql_hook.run(sql, autocommit=self.autocommit)
         self.log.info("COPY command complete...")
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement on_complete as we will query destination table."""
+        from pathlib import Path
+
+        from airflow.providers.amazon.aws.utils.openlineage import (
+            get_facets_from_redshift_table,
+            get_identity_column_lineage_facet,
+        )
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Identifier,
+            LifecycleStateChange,
+            LifecycleStateChangeDatasetFacet,
+            SymlinksDatasetFacet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+
+        if self.use_redshift_data:
+            redshift_data_hook = RedshiftDataHook(aws_conn_id=self.redshift_conn_id)
+            database = self.redshift_data_api_kwargs.get("database")
+            identifier = self.redshift_data_api_kwargs.get(
+                "cluster_identifier", self.redshift_data_api_kwargs.get("workgroup_name")
+            )
+            port = self.redshift_data_api_kwargs.get("port", "5439")
+            authority = f"{identifier}.{redshift_data_hook.region_name}:{port}"
+            output_dataset_facets = get_facets_from_redshift_table(
+                redshift_data_hook, self.table, self.redshift_data_api_kwargs, self.schema
+            )
+        else:
+            redshift_sql_hook = RedshiftSQLHook(redshift_conn_id=self.redshift_conn_id)
+            database = redshift_sql_hook.conn.schema
+            authority = redshift_sql_hook.get_openlineage_database_info(redshift_sql_hook.conn).authority
+            output_dataset_facets = get_facets_from_redshift_table(
+                redshift_sql_hook, self.table, self.redshift_data_api_kwargs, self.schema
+            )
+
+        input_dataset_facets = {}
+        if not self.column_list:
+            # If column_list is not specified, then we know that input file matches columns of output table.
+            input_dataset_facets["schema"] = output_dataset_facets["schema"]
+
+        dataset_name = self.s3_key
+        if "*" in dataset_name:
+            # If wildcard ("*") is used in s3 path, we want the name of dataset to be directory name,
+            # but we create a symlink to the full object path with wildcard.
+            input_dataset_facets["symlink"] = SymlinksDatasetFacet(
+                identifiers=[Identifier(namespace=f"s3://{self.s3_bucket}", name=dataset_name, type="file")]
+            )
+            dataset_name = Path(dataset_name).parent.as_posix()
+            if dataset_name == ".":
+                # blob path does not have leading slash, but we need root dataset name to be "/"
+                dataset_name = "/"
+
+        input_dataset = Dataset(
+            namespace=f"s3://{self.s3_bucket}",
+            name=dataset_name,
+            facets=input_dataset_facets,
+        )
+
+        output_dataset_facets["columnLineage"] = get_identity_column_lineage_facet(
+            field_names=[field.name for field in output_dataset_facets["schema"].fields],
+            input_datasets=[input_dataset],
+        )
+
+        if self.method == "REPLACE":
+            output_dataset_facets["lifecycleStateChange"] = LifecycleStateChangeDatasetFacet(
+                lifecycleStateChange=LifecycleStateChange.OVERWRITE
+            )
+
+        output_dataset = Dataset(
+            namespace=f"redshift://{authority}",
+            name=f"{database}.{self.schema}.{self.table}",
+            facets=output_dataset_facets,
+        )
+
+        return OperatorLineage(inputs=[input_dataset], outputs=[output_dataset])
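
A standalone sketch of the wildcard handling above (the s3_key value is an assumed
example, not taken from the commit): when the key contains "*", the dataset is named
after the parent directory and the full wildcard path is kept only as a symlink
identifier.

    from pathlib import Path

    s3_bucket, s3_key = "my-bucket", "exports/2024/*.parquet"  # assumed example values
    dataset_name = s3_key
    if "*" in dataset_name:
        symlink_name = dataset_name                           # full wildcard path kept as symlink
        dataset_name = Path(dataset_name).parent.as_posix()   # -> "exports/2024"
        if dataset_name == ".":
            dataset_name = "/"                                # key at bucket root maps to "/"
    print(f"s3://{s3_bucket}", dataset_name)                  # namespace and dataset name used for lineage
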
diff --git a/airflow/providers/amazon/aws/utils/openlineage.py b/airflow/providers/amazon/aws/utils/openlineage.py
new file mode 100644
index 0000000000..db472a3e46
--- /dev/null
+++ b/airflow/providers/amazon/aws/utils/openlineage.py
@@ -0,0 +1,136 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.common.compat.openlineage.facet import (
+    ColumnLineageDatasetFacet,
+    DocumentationDatasetFacet,
+    Fields,
+    InputField,
+    SchemaDatasetFacet,
+    SchemaDatasetFacetFields,
+)
+
+if TYPE_CHECKING:
+    from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+
+
+def get_facets_from_redshift_table(
+    redshift_hook: RedshiftDataHook | RedshiftSQLHook,
+    table: str,
+    redshift_data_api_kwargs: dict,
+    schema: str = "public",
+) -> dict[Any, Any]:
+    """
+    Query redshift for table metadata.
+
+    SchemaDatasetFacet and DocumentationDatasetFacet (if table has description) will be created.
+    """
+    sql = f"""
+    SELECT
+        cols.column_name,
+        cols.data_type,
+        col_des.description as column_description,
+        tbl_des.description as table_description
+    FROM
+        information_schema.columns cols
+    LEFT JOIN
+        pg_catalog.pg_description col_des
+    ON
+        cols.ordinal_position = col_des.objsubid
+        AND col_des.objoid = (SELECT oid FROM pg_class WHERE relnamespace =
+        (SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema) AND relname = cols.table_name)
+    LEFT JOIN
+        pg_catalog.pg_class tbl
+    ON
+        tbl.relname = cols.table_name
+        AND tbl.relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = cols.table_schema)
+    LEFT JOIN
+        pg_catalog.pg_description tbl_des
+    ON
+        tbl.oid = tbl_des.objoid
+        AND tbl_des.objsubid = 0
+    WHERE
+        cols.table_name = '{table}'
+        AND cols.table_schema = '{schema}';
+    """
+    if isinstance(redshift_hook, RedshiftSQLHook):
+        records = redshift_hook.get_records(sql)
+        if records:
+            table_description = records[0][-1]  # Assuming the table description is the same for all rows
+        else:
+            table_description = None
+        documentation = DocumentationDatasetFacet(description=table_description or "")
+        table_schema = SchemaDatasetFacet(
+            fields=[
+                SchemaDatasetFacetFields(name=field[0], type=field[1], description=field[2])
+                for field in records
+            ]
+        )
+    else:
+        statement_id = redshift_hook.execute_query(sql=sql, poll_interval=1, **redshift_data_api_kwargs)
+        response = redshift_hook.conn.get_statement_result(Id=statement_id)
+
+        table_schema = SchemaDatasetFacet(
+            fields=[
+                SchemaDatasetFacetFields(
+                    name=field[0]["stringValue"],
+                    type=field[1]["stringValue"],
+                    description=field[2].get("stringValue"),
+                )
+                for field in response["Records"]
+            ]
+        )
+        #  Table description will be the same for all fields, so we retrieve it from first field.
+        documentation = DocumentationDatasetFacet(
+            description=response["Records"][0][3].get("stringValue") or ""
+        )
+
+    return {"schema": table_schema, "documentation": documentation}
+
+
+def get_identity_column_lineage_facet(
+    field_names,
+    input_datasets,
+) -> ColumnLineageDatasetFacet:
+    """
+    Get column lineage facet.
+
+    Simple lineage will be created, where each source column corresponds to single destination column
+    in each input dataset and there are no transformations made.
+    """
+    if field_names and not input_datasets:
+        raise ValueError("When providing `field_names` You must provide at least one `input_dataset`.")
+
+    column_lineage_facet = ColumnLineageDatasetFacet(
+        fields={
+            field: Fields(
+                inputFields=[
+                    InputField(namespace=dataset.namespace, name=dataset.name, field=field)
+                    for dataset in input_datasets
+                ],
+                transformationType="IDENTITY",
+                transformationDescription="identical",
+            )
+            for field in field_names
+        }
+    )
+    return column_lineage_facet
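
A short usage sketch for the new helper (dataset namespace/name and field names are
assumed example values): each destination column is mapped 1:1 to the same-named
column of every input dataset, with an "IDENTITY" transformation.

    from airflow.providers.amazon.aws.utils.openlineage import get_identity_column_lineage_facet
    from airflow.providers.common.compat.openlineage.facet import Dataset

    source = Dataset(namespace="s3://my-bucket", name="exports/2024")  # assumed input dataset
    facet = get_identity_column_lineage_facet(
        field_names=["id", "name"],
        input_datasets=[source],
    )
    # facet.fields["id"].inputFields[0] points back at s3://my-bucket / exports/2024, field "id".
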
diff --git a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
index cb5ef7fdb7..f554ce8699 100644
--- a/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
+++ b/tests/providers/amazon/aws/transfers/test_s3_to_redshift.py
@@ -26,6 +26,7 @@ from boto3.session import Session
 from airflow.exceptions import AirflowException
 from airflow.models.connection import Connection
 from airflow.providers.amazon.aws.transfers.s3_to_redshift import S3ToRedshiftOperator
+from airflow.providers.common.compat.openlineage.facet import LifecycleStateChange
 from tests.test_utils.asserts import assert_equal_ignore_multiple_spaces
 
 
@@ -496,3 +497,310 @@ class TestS3ToRedshiftTransfer:
         assert access_key in actual_copy_query
         assert secret_key in actual_copy_query
         assert_equal_ignore_multiple_spaces(actual_copy_query, expected_copy_query)
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_get_openlineage_facets_on_complete_default(
+        self, mock_run, mock_session, mock_connection, mock_hook
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_connection.return_value = mock.MagicMock(
+            schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
+        )
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        # Hook called two times - on operator execution, and on querying data in redshift to fetch schema
+        assert mock_run.call_count == 2
+
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0].name == s3_key
+        assert lineage.outputs[0].name == f"database.{schema}.{table}"
+        assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"
+
+        assert lineage.outputs[0].facets.get("schema") is not None
+        assert lineage.outputs[0].facets.get("columnLineage") is not None
+
+        assert lineage.inputs[0].facets.get("schema") is not None
+        # As method was not overwrite, there should be no lifecycleStateChange facet
+        assert "lifecycleStateChange" not in lineage.outputs[0].facets
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_get_openlineage_facets_on_complete_replace(
+        self, mock_run, mock_session, mock_connection, mock_hook
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_connection.return_value = mock.MagicMock(
+            schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
+        )
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            method="REPLACE",
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+
+        assert (
+            lineage.outputs[0].facets["lifecycleStateChange"].lifecycleStateChange
+            == LifecycleStateChange.OVERWRITE
+        )
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    def test_get_openlineage_facets_on_complete_column_list(
+        self, mock_run, mock_session, mock_connection, mock_hook
+    ):
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_connection.return_value = mock.MagicMock(
+            schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
+        )
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            column_list=["column1", "column2"],
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+
+        assert lineage.outputs[0].facets.get("schema") is not None
+        assert lineage.inputs[0].facets.get("schema") is None
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
+        new_callable=mock.PropertyMock,
+    )
+    def test_get_openlineage_facets_on_complete_using_redshift_data_api(
+        self, mock_rs_region, mock_rs, mock_session, mock_connection, mock_hook
+    ):
+        """
+        Using the Redshift Data API instead of the SQL-based connection
+        """
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_hook.return_value = Connection()
+        mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+        mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+        mock_rs_region.return_value = "region"
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        # RS Data API params
+        database = "database"
+        cluster_identifier = "cluster"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+
+        op = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+            redshift_data_api_kwargs=dict(
+                database=database,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+            ),
+        )
+        op.execute(None)
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+
+        assert len(lineage.inputs) == 1
+        assert len(lineage.outputs) == 1
+        assert lineage.inputs[0].name == s3_key
+        assert lineage.outputs[0].name == f"database.{schema}.{table}"
+        assert lineage.outputs[0].namespace == "redshift://cluster.region:5439"
+
+        assert lineage.outputs[0].facets.get("schema") is not None
+        assert lineage.outputs[0].facets.get("columnLineage") is not None
+
+        assert lineage.inputs[0].facets.get("schema") is not None
+        # As method was not overwrite, there should be no lifecycleStateChange facet
+        assert "lifecycleStateChange" not in lineage.outputs[0].facets
+
+    @mock.patch("airflow.providers.amazon.aws.hooks.s3.S3Hook.get_connection")
+    @mock.patch("airflow.models.connection.Connection.get_connection_from_secrets")
+    @mock.patch("boto3.session.Session")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.run")
+    @mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+    @mock.patch(
+        "airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.region_name",
+        new_callable=mock.PropertyMock,
+    )
+    def test_get_openlineage_facets_on_complete_data_and_sql_hooks_aligned(
+        self, mock_rs_region, mock_rs, mock_run, mock_session, mock_connection, mock_hook
+    ):
+        """
+        Ensuring both supported hooks - RedshiftDataHook and RedshiftSQLHook return same lineage.
+        """
+        access_key = "aws_access_key_id"
+        secret_key = "aws_secret_access_key"
+        mock_session.return_value = Session(access_key, secret_key)
+        mock_session.return_value.access_key = access_key
+        mock_session.return_value.secret_key = secret_key
+        mock_session.return_value.token = None
+
+        mock_connection.return_value = mock.MagicMock(
+            schema="database", port=5439, host="cluster.id.region.redshift.amazonaws.com", extra_dejson={}
+        )
+        mock_hook.return_value = Connection()
+        mock_rs.execute_statement.return_value = {"Id": "STATEMENT_ID"}
+        mock_rs.describe_statement.return_value = {"Status": "FINISHED"}
+
+        mock_rs_region.return_value = "region"
+
+        schema = "schema"
+        table = "table"
+        s3_bucket = "bucket"
+        s3_key = "key"
+        copy_options = ""
+
+        # RS Data API params
+        database = "database"
+        cluster_identifier = "cluster"
+        db_user = "db_user"
+        secret_arn = "secret_arn"
+        statement_name = "statement_name"
+
+        op_rs_data = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+            redshift_data_api_kwargs=dict(
+                database=database,
+                cluster_identifier=cluster_identifier,
+                db_user=db_user,
+                secret_arn=secret_arn,
+                statement_name=statement_name,
+            ),
+        )
+        op_rs_data.execute(None)
+        rs_data_lineage = op_rs_data.get_openlineage_facets_on_complete(None)
+
+        op_rs_sql = S3ToRedshiftOperator(
+            schema=schema,
+            table=table,
+            s3_bucket=s3_bucket,
+            s3_key=s3_key,
+            copy_options=copy_options,
+            redshift_conn_id="redshift_conn_id",
+            aws_conn_id="aws_conn_id",
+            task_id="task_id",
+            dag=None,
+        )
+        op_rs_sql.execute(None)
+        rs_sql_lineage = op_rs_sql.get_openlineage_facets_on_complete(None)
+
+        assert rs_sql_lineage.inputs == rs_data_lineage.inputs
+        assert len(rs_sql_lineage.outputs) == 1
+        assert len(rs_data_lineage.outputs) == 1
+        assert rs_sql_lineage.outputs[0].facets["schema"] == rs_data_lineage.outputs[0].facets["schema"]
+        assert (
+            rs_sql_lineage.outputs[0].facets["columnLineage"]
+            == rs_data_lineage.outputs[0].facets["columnLineage"]
+        )
+        assert rs_sql_lineage.outputs[0].name == rs_data_lineage.outputs[0].name
+        assert rs_sql_lineage.outputs[0].namespace == rs_data_lineage.outputs[0].namespace
diff --git a/tests/providers/amazon/aws/utils/test_openlineage.py b/tests/providers/amazon/aws/utils/test_openlineage.py
new file mode 100644
index 0000000000..b3e820b581
--- /dev/null
+++ b/tests/providers/amazon/aws/utils/test_openlineage.py
@@ -0,0 +1,168 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.providers.amazon.aws.hooks.redshift_data import RedshiftDataHook
+from airflow.providers.amazon.aws.hooks.redshift_sql import RedshiftSQLHook
+from airflow.providers.amazon.aws.utils.openlineage import (
+    get_facets_from_redshift_table,
+    get_identity_column_lineage_facet,
+)
+from airflow.providers.common.compat.openlineage.facet import (
+    ColumnLineageDatasetFacet,
+    Dataset,
+    Fields,
+    InputField,
+)
+
+
+@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.get_records")
+def test_get_facets_from_redshift_table_sql_hook(mock_get_records):
+    mock_get_records.return_value = [
+        ("column1", "varchar", "Column 1 description", "Table description"),
+        ("column2", "int", "Column 2 description", "Table description"),
+    ]
+
+    mock_hook = RedshiftSQLHook()
+
+    result = get_facets_from_redshift_table(
+        redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={}
+    )
+
+    assert result["documentation"].description == "Table description"
+    assert len(result["schema"].fields) == 2
+    assert result["schema"].fields[0].name == "column1"
+    assert result["schema"].fields[0].type == "varchar"
+    assert result["schema"].fields[0].description == "Column 1 description"
+
+
+@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.execute_query")
+@mock.patch("airflow.providers.amazon.aws.hooks.redshift_data.RedshiftDataHook.conn")
+def test_get_facets_from_redshift_table_data_hook(mock_connection, mock_execute_query):
+    mock_execute_query.return_value = "statement_id"
+    mock_connection.get_statement_result.return_value = {
+        "Records": [
+            [
+                {"stringValue": "column1"},
+                {"stringValue": "varchar"},
+                {"stringValue": "Column 1 description"},
+                {"stringValue": "Table description"},
+            ],
+            [
+                {"stringValue": "column2"},
+                {"stringValue": "int"},
+                {"stringValue": "Column 2 description"},
+                {"stringValue": "Table description"},
+            ],
+        ]
+    }
+
+    mock_hook = RedshiftDataHook()
+
+    result = get_facets_from_redshift_table(
+        redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={}
+    )
+
+    assert result["documentation"].description == "Table description"
+    assert len(result["schema"].fields) == 2
+    assert result["schema"].fields[0].name == "column1"
+    assert result["schema"].fields[0].type == "varchar"
+    assert result["schema"].fields[0].description == "Column 1 description"
+
+
+@mock.patch("airflow.providers.amazon.aws.hooks.redshift_sql.RedshiftSQLHook.get_records")
+def test_get_facets_no_records_sql_hook(mock_get_records):
+    mock_get_records.return_value = []
+
+    mock_hook = RedshiftSQLHook()
+
+    result = get_facets_from_redshift_table(
+        redshift_hook=mock_hook, table="my_table", redshift_data_api_kwargs={}
+    )
+
+    assert result["documentation"].description == ""
+    assert len(result["schema"].fields) == 0
+
+
+def test_get_identity_column_lineage_facet_multiple_input_datasets():
+    field_names = ["field1", "field2"]
+    input_datasets = [
+        Dataset(namespace="s3://first_bucket", name="dir1"),
+        Dataset(namespace="s3://second_bucket", name="dir2"),
+    ]
+    expected_facet = ColumnLineageDatasetFacet(
+        fields={
+            "field1": Fields(
+                inputFields=[
+                    InputField(
+                        namespace="s3://first_bucket",
+                        name="dir1",
+                        field="field1",
+                    ),
+                    InputField(
+                        namespace="s3://second_bucket",
+                        name="dir2",
+                        field="field1",
+                    ),
+                ],
+                transformationType="IDENTITY",
+                transformationDescription="identical",
+            ),
+            "field2": Fields(
+                inputFields=[
+                    InputField(
+                        namespace="s3://first_bucket",
+                        name="dir1",
+                        field="field2",
+                    ),
+                    InputField(
+                        namespace="s3://second_bucket",
+                        name="dir2",
+                        field="field2",
+                    ),
+                ],
+                transformationType="IDENTITY",
+                transformationDescription="identical",
+            ),
+        }
+    )
+    result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)
+    assert result == expected_facet
+
+
+def test_get_identity_column_lineage_facet_no_field_names():
+    field_names = []
+    input_datasets = [
+        Dataset(namespace="s3://first_bucket", name="dir1"),
+        Dataset(namespace="s3://second_bucket", name="dir2"),
+    ]
+    expected_facet = ColumnLineageDatasetFacet(fields={})
+    result = get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)
+    assert result == expected_facet
+
+
+def test_get_identity_column_lineage_facet_no_input_datasets():
+    field_names = ["field1", "field2"]
+    input_datasets = []
+
+    with pytest.raises(ValueError):
+        get_identity_column_lineage_facet(field_names=field_names, input_datasets=input_datasets)
