This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3dc99d8a28 feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
3dc99d8a28 is described below

commit 3dc99d8a285aaadeb83797e691c9f6ec93ff9c93
Author: Kacper Muda <mudakac...@gmail.com>
AuthorDate: Mon Jan 8 13:02:46 2024 +0100

    feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
---
 .../snowflake/transfers/copy_into_snowflake.py     | 163 +++++++++++++++++++-
 .../transfers/test_copy_into_snowflake.py          | 168 ++++++++++++++++++++-
 2 files changed, 327 insertions(+), 4 deletions(-)
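
For context, a minimal sketch of how the operator is typically wired into a DAG; the task id, connection id, table, schema, stage and file format below are placeholders for illustration, not values taken from this commit. With this change, OpenLineage input/output datasets and run facets are derived automatically from the rows returned by the COPY INTO command.

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    # Placeholder task definition; lineage is collected via
    # get_openlineage_facets_on_complete() after execute() runs the COPY INTO.
    copy_into_table = CopyFromExternalStageToSnowflakeOperator(
        task_id="copy_into_table",
        snowflake_conn_id="snowflake_default",
        table="my_table",
        schema="my_schema",
        stage="my_stage",
        file_format="(type = 'CSV')",
    )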

diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
index 10071add1a..342d5dc35a 100644
--- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py
+++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
@@ -108,8 +108,12 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         self.copy_options = copy_options
         self.validation_mode = validation_mode
 
+        self.hook: SnowflakeHook | None = None
+        self._sql: str | None = None
+        self._result: list[dict[str, Any]] = []
+
     def execute(self, context: Any) -> None:
-        snowflake_hook = SnowflakeHook(
+        self.hook = SnowflakeHook(
             snowflake_conn_id=self.snowflake_conn_id,
             warehouse=self.warehouse,
             database=self.database,
@@ -127,7 +131,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         if self.columns_array:
             into = f"{into}({', '.join(self.columns_array)})"
 
-        sql = f"""
+        self._sql = f"""
         COPY INTO {into}
              FROM  @{self.stage}/{self.prefix or ""}
         {"FILES=(" + ",".join(map(enclose_param, self.files)) + ")" if 
self.files else ""}
@@ -137,5 +141,158 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         {self.validation_mode or ""}
         """
         self.log.info("Executing COPY command...")
-        snowflake_hook.run(sql=sql, autocommit=self.autocommit)
+        self._result = self.hook.run(  # type: ignore # mypy does not work well with return_dictionaries=True
+            sql=self._sql,
+            autocommit=self.autocommit,
+            handler=lambda x: x.fetchall(),
+            return_dictionaries=True,
+        )
         self.log.info("COPY command completed")
+
+    @staticmethod
+    def _extract_openlineage_unique_dataset_paths(
+        query_result: list[dict[str, Any]],
+    ) -> tuple[list[tuple[str, str]], list[str]]:
+        """Extracts and returns unique OpenLineage dataset paths and file paths that failed to be parsed.
+
+        Each row in the results is expected to have a 'file' field, which is a URI.
+        The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples.
+        Additionally, it captures any URIs that cannot be parsed or processed
+        and returns them in a separate error list.
+
+        For Azure, Snowflake has a unique way of representing URI:
+            azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
+        that is transformed by this function to a Dataset with more universal naming convention:
+            Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described at
+        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
+
+        :param query_result: A list of dictionaries, each containing a 'file' key with a URI value.
+        :return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path,
+         and the second contains any URIs that cannot be parsed or processed correctly.
+
+        >>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
+
+        >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', 'dir3')], [])
+
+        >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', '/')], [])
+
+        >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, {"file": "s3://bucket/a.csv"}]
+        >>> method(results)
+        ([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
+
+        >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": "gcs://bucket/dir/dir2/a.txt"}]
+        >>> method(results)
+        ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
+
+        >>> results = [
+        ...     {"file": "s3://bucket/dir/file.csv"},
+        ...     {"file": "azure://my_account.something_new.windows.net/azure_container"},
+        ... ]
+        >>> method(results)
+        ([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container'])
+        """
+        import re
+        from pathlib import Path
+        from urllib.parse import urlparse
+
+        azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
+        extraction_error_files = []
+        unique_dataset_paths = set()
+
+        for row in query_result:
+            uri = urlparse(row["file"])
+            if uri.scheme == "azure":
+                match = re.fullmatch(azure_regex, row["file"])
+                if not match:
+                    extraction_error_files.append(row["file"])
+                    continue
+                account_name, container_name, name = match.groups()
+                namespace = f"wasbs://{container_name}@{account_name}"
+            else:
+                namespace = f"{uri.scheme}://{uri.netloc}"
+                name = uri.path.lstrip("/")
+
+            name = Path(name).parent.as_posix()
+            if name in ("", "."):
+                name = "/"
+
+            unique_dataset_paths.add((namespace, name))
+
+        return sorted(unique_dataset_paths), sorted(extraction_error_files)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement _on_complete because we rely on return value of a query."""
+        import re
+
+        from openlineage.client.facet import (
+            ExternalQueryRunFacet,
+            ExtractionError,
+            ExtractionErrorRunFacet,
+            SqlJobFacet,
+        )
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            return OperatorLineage()
+
+        query_results = self._result or []
+        # If no files were uploaded we get [{"status": "0 files were uploaded..."}]
+        if len(query_results) == 1 and query_results[0].get("status"):
+            query_results = []
+        unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths(
+            query_results
+        )
+        input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths]
+
+        run_facets = {}
+        if extraction_error_files:
+            self.log.debug(
+                f"Unable to extract Dataset namespace and name "
+                f"for the following files: `{extraction_error_files}`."
+            )
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=len(query_results),
+                failedTasks=len(extraction_error_files),
+                errors=[
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and name.",
+                        stackTrace=None,
+                        task=file_uri,
+                        taskNumber=None,
+                    )
+                    for file_uri in extraction_error_files
+                ],
+            )
+
+        connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr)))
+        database_info = self.hook.get_openlineage_database_info(connection)
+
+        dest_name = self.table
+        schema = self.hook.get_openlineage_default_schema()
+        database = database_info.database
+        if schema:
+            dest_name = f"{schema}.{dest_name}"
+            if database:
+                dest_name = f"{database}.{dest_name}"
+
+        snowflake_namespace = SQLParser.create_namespace(database_info)
+        query = SQLParser.normalize_sql(self._sql)
+        query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))
+
+        run_facets["externalQuery"] = ExternalQueryRunFacet(
+            externalQueryId=self.hook.query_ids[0], source=snowflake_namespace
+        )
+
+        return OperatorLineage(
+            inputs=input_datasets,
+            outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)],
+            job_facets={"sql": SqlJobFacet(query=query)},
+            run_facets=run_facets,
+        )
diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
index 76268d077d..27e02dc41c 100644
--- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
+++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
@@ -16,8 +16,20 @@
 # under the License.
 from __future__ import annotations
 
+from typing import Callable
 from unittest import mock
 
+from openlineage.client.facet import (
+    ExternalQueryRunFacet,
+    ExtractionError,
+    ExtractionErrorRunFacet,
+    SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+from pytest import mark
+
+from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator
 
 
@@ -62,4 +74,158 @@ class TestCopyFromExternalStageToSnowflake:
         validation_mode
         """
 
-        mock_hook.return_value.run.assert_called_once_with(sql=sql, autocommit=True)
+        mock_hook.return_value.run.assert_called_once_with(
+            sql=sql, autocommit=True, return_dictionaries=True, handler=mock.ANY
+        )
+
+        handler = mock_hook.return_value.run.mock_calls[0].kwargs.get("handler")
+        assert isinstance(handler, Callable)
+
+    @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+    def test_get_openlineage_facets_on_complete(self, mock_hook):
+        mock_hook().run.return_value = [
+            {"file": "s3://aws_bucket_name/dir1/file.csv"},
+            {"file": "s3://aws_bucket_name_2"},
+            {"file": "gcs://gcs_bucket_name/dir2/file.csv"},
+            {"file": "gcs://gcs_bucket_name_2"},
+            {"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
+            {"file": "azure://my_account.blob.core.windows.net/azure_container_2"},
+        ]
+        mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+            scheme="snowflake_scheme", authority="authority", database="actual_database"
+        )
+        mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+        mock_hook().query_ids = ["query_id_123"]
+
+        expected_inputs = [
+            Dataset(namespace="gcs://gcs_bucket_name", name="dir2"),
+            Dataset(namespace="gcs://gcs_bucket_name_2", name="/"),
+            Dataset(namespace="s3://aws_bucket_name", name="dir1"),
+            Dataset(namespace="s3://aws_bucket_name_2", name="/"),
+            Dataset(namespace="wasbs://azure_container@my_account", name="dir3"),
+            Dataset(namespace="wasbs://azure_container_2@my_account", name="/"),
+        ]
+        expected_outputs = [
+            Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+        ]
+        expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+
+        op = CopyFromExternalStageToSnowflakeOperator(
+            task_id="test",
+            table="table",
+            stage="stage",
+            database="",
+            schema="schema",
+            file_format="CSV",
+        )
+        op.execute(None)
+        result = op.get_openlineage_facets_on_complete(None)
+        assert result == OperatorLineage(
+            inputs=expected_inputs,
+            outputs=expected_outputs,
+            run_facets={
+                "externalQuery": ExternalQueryRunFacet(
+                    externalQueryId="query_id_123", source="snowflake_scheme://authority"
+                )
+            },
+            job_facets={"sql": SqlJobFacet(query=expected_sql)},
+        )
+
+    @mark.parametrize("rows", (None, []))
+    @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+    def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, rows):
+        mock_hook().run.return_value = rows
+        mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+            scheme="snowflake_scheme", authority="authority", database="actual_database"
+        )
+        mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+        mock_hook().query_ids = ["query_id_123"]
+
+        expected_outputs = [
+            Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+        ]
+        expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+
+        op = CopyFromExternalStageToSnowflakeOperator(
+            task_id="test",
+            table="table",
+            stage="stage",
+            database="",
+            schema="schema",
+            file_format="CSV",
+        )
+        op.execute(None)
+        result = op.get_openlineage_facets_on_complete(None)
+        assert result == OperatorLineage(
+            inputs=[],
+            outputs=expected_outputs,
+            run_facets={
+                "externalQuery": ExternalQueryRunFacet(
+                    externalQueryId="query_id_123", source="snowflake_scheme://authority"
+                )
+            },
+            job_facets={"sql": SqlJobFacet(query=expected_sql)},
+        )
+
+    @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+    def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hook):
+        mock_hook().run.return_value = [
+            {"file": "s3://aws_bucket_name/dir1/file.csv"},
+            {"file": "gs://gcp_bucket_name/dir2/file.csv"},
+            {"file": "azure://my_account.weird-url.net/azure_container/dir3/file.csv"},
+            {"file": "azure://my_account.another_weird-url.net/con/file.csv"},
+        ]
+        mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+            scheme="snowflake_scheme", authority="authority", database="actual_database"
+        )
+        mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+        mock_hook().query_ids = ["query_id_123"]
+
+        expected_inputs = [
+            Dataset(namespace="gs://gcp_bucket_name", name="dir2"),
+            Dataset(namespace="s3://aws_bucket_name", name="dir1"),
+        ]
+        expected_outputs = [
+            Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table")
+        ]
+        expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV"""
+        expected_run_facets = {
+            "extractionError": ExtractionErrorRunFacet(
+                totalTasks=4,
+                failedTasks=2,
+                errors=[
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and name.",
+                        stackTrace=None,
+                        task="azure://my_account.another_weird-url.net/con/file.csv",
+                        taskNumber=None,
+                    ),
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and name.",
+                        stackTrace=None,
+                        task="azure://my_account.weird-url.net/azure_container/dir3/file.csv",
+                        taskNumber=None,
+                    ),
+                ],
+            ),
+            "externalQuery": ExternalQueryRunFacet(
+                externalQueryId="query_id_123", source="snowflake_scheme://authority"
+            ),
+        }
+
+        op = CopyFromExternalStageToSnowflakeOperator(
+            task_id="test",
+            table="table",
+            stage="stage",
+            database="",
+            schema="schema",
+            file_format="CSV",
+        )
+        op.execute(None)
+        result = op.get_openlineage_facets_on_complete(None)
+        assert result == OperatorLineage(
+            inputs=expected_inputs,
+            outputs=expected_outputs,
+            run_facets=expected_run_facets,
+            job_facets={"sql": SqlJobFacet(query=expected_sql)},
+        )
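
As a quick illustration (not part of the commit), the new static helper can be called directly to see how COPY INTO result rows are mapped to OpenLineage dataset paths; the rows below are made-up examples mirroring the doctests and tests above.

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    # Hypothetical COPY INTO result rows; only the "file" key is inspected.
    rows = [
        {"file": "s3://bucket/dir/file.csv"},
        {"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
    ]
    paths, errors = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths(rows)
    # paths  -> [('s3://bucket', 'dir'), ('wasbs://azure_container@my_account', 'dir3')]
    # errors -> []

Files are collapsed to their parent directory and deduplicated, so a COPY that loads many files from one prefix yields a single input dataset; any Azure URI that does not match the blob.core.windows.net pattern is returned in the error list and surfaced through the extractionError run facet.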
