mobuchowski commented on code in PR #36535:
URL: https://github.com/apache/airflow/pull/36535#discussion_r1444441436


##########
airflow/providers/snowflake/transfers/copy_into_snowflake.py:
##########
@@ -137,5 +141,158 @@ def execute(self, context: Any) -> None:
         {self.validation_mode or ""}
         """
         self.log.info("Executing COPY command...")
-        snowflake_hook.run(sql=sql, autocommit=self.autocommit)
+        self._result = self.hook.run(  # type: ignore # mypy does not work 
well with return_dictionaries=True
+            sql=self._sql,
+            autocommit=self.autocommit,
+            handler=lambda x: x.fetchall(),
+            return_dictionaries=True,
+        )
         self.log.info("COPY command completed")
+
+    @staticmethod
+    def _extract_openlineage_unique_dataset_paths(
+        query_result: list[dict[str, Any]],
+    ) -> tuple[list[tuple[str, str]], list[str]]:
+        """Extracts and returns unique OpenLineage dataset paths and file 
paths that failed to be parsed.
+
+        Each row in the results is expected to have a 'file' field, which is a 
URI.
+        The function parses these URIs and constructs a set of unique 
OpenLineage (namespace, name) tuples.
+        Additionally, it captures any URIs that cannot be parsed or processed
+        and returns them in a separate error list.
+
+        For Azure, Snowflake has a unique way of representing URI:
+            
azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
+        that is transformed by this function to a Dataset with more universal 
naming convention:
+            Dataset(namespace="wasbs://container_name@account_name", 
name="path/to"), as described at
+        
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
+
+        :param query_result: A list of dictionaries, each containing a 'file' 
key with a URI value.
+        :return: Two lists - the first is a sorted list of tuples, each 
representing a unique dataset path,
+         and the second contains any URIs that cannot be parsed or processed 
correctly.
+
+        >>> method = 
CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
+
+        >>> results = [{"file": 
"azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', 'dir3')], [])
+
+        >>> results = [{"file": 
"azure://my_account.blob.core.windows.net/azure_container"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', '/')], [])
+
+        >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, 
{"file": "s3://bucket/a.csv"}]
+        >>> method(results)
+        ([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
+
+        >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": 
"gcs://bucket/dir/dir2/a.txt"}]
+        >>> method(results)
+        ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
+
+        >>> results = [
+        ...     {"file": "s3://bucket/dir/file.csv"},
+        ...     {"file": 
"azure://my_account.something_new.windows.net/azure_container"},
+        ... ]
+        >>> method(results)
+        ([('s3://bucket', 'dir')], 
['azure://my_account.something_new.windows.net/azure_container'])
+        """
+        import re
+        from pathlib import Path
+        from urllib.parse import urlparse
+
+        azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
+        extraction_error_files = []
+        unique_dataset_paths = set()
+
+        for row in query_result:
+            uri = urlparse(row["file"])
+            if uri.scheme == "azure":
+                match = re.fullmatch(azure_regex, row["file"])
+                if not match:
+                    extraction_error_files.append(row["file"])
+                    continue
+                account_name, container_name, name = match.groups()
+                namespace = f"wasbs://{container_name}@{account_name}"
+            else:
+                namespace = f"{uri.scheme}://{uri.netloc}"
+                name = uri.path.lstrip("/")
+
+            name = Path(name).parent.as_posix()
+            if name in ("", "."):
+                name = "/"
+
+            unique_dataset_paths.add((namespace, name))
+
+        return sorted(unique_dataset_paths), sorted(extraction_error_files)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement _on_complete because we rely on return value of a 
query."""
+        import re
+
+        from openlineage.client.facet import (
+            ExternalQueryRunFacet,
+            ExtractionError,
+            ExtractionErrorRunFacet,
+            SqlJobFacet,
+        )
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            return OperatorLineage()
+
+        query_results = self._result or []
+        # If no files were uploaded we get [{"status": "0 files were 
uploaded..."}]
+        if len(query_results) == 1 and query_results[0].get("status"):
+            query_results = []
+        unique_dataset_paths, extraction_error_files = 
self._extract_openlineage_unique_dataset_paths(
+            query_results
+        )
+        input_datasets = [Dataset(namespace=namespace, name=name) for 
namespace, name in unique_dataset_paths]
+
+        run_facets = {}
+        if extraction_error_files:
+            self.log.debug(
+                f"Unable to extract Dataset namespace and name "
+                f"for the following files: `{extraction_error_files}`."
+            )
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=len(query_results),
+                failedTasks=len(extraction_error_files),
+                errors=[
+                    ExtractionError(

Review Comment:
   👍 



##########
airflow/providers/snowflake/transfers/copy_into_snowflake.py:
##########
@@ -137,5 +141,158 @@ def execute(self, context: Any) -> None:
         {self.validation_mode or ""}
         """
         self.log.info("Executing COPY command...")
-        snowflake_hook.run(sql=sql, autocommit=self.autocommit)
+        self._result = self.hook.run(  # type: ignore # mypy does not work 
well with return_dictionaries=True
+            sql=self._sql,
+            autocommit=self.autocommit,
+            handler=lambda x: x.fetchall(),
+            return_dictionaries=True,
+        )
         self.log.info("COPY command completed")
+
+    @staticmethod
+    def _extract_openlineage_unique_dataset_paths(
+        query_result: list[dict[str, Any]],
+    ) -> tuple[list[tuple[str, str]], list[str]]:
+        """Extracts and returns unique OpenLineage dataset paths and file 
paths that failed to be parsed.
+
+        Each row in the results is expected to have a 'file' field, which is a 
URI.
+        The function parses these URIs and constructs a set of unique 
OpenLineage (namespace, name) tuples.
+        Additionally, it captures any URIs that cannot be parsed or processed
+        and returns them in a separate error list.
+
+        For Azure, Snowflake has a unique way of representing URI:
+            
azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
+        that is transformed by this function to a Dataset with more universal 
naming convention:
+            Dataset(namespace="wasbs://container_name@account_name", 
name="path/to"), as described at
+        
https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
+
+        :param query_result: A list of dictionaries, each containing a 'file' 
key with a URI value.
+        :return: Two lists - the first is a sorted list of tuples, each 
representing a unique dataset path,
+         and the second contains any URIs that cannot be parsed or processed 
correctly.
+
+        >>> method = 
CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
+
+        >>> results = [{"file": 
"azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', 'dir3')], [])
+
+        >>> results = [{"file": 
"azure://my_account.blob.core.windows.net/azure_container"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', '/')], [])
+
+        >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, 
{"file": "s3://bucket/a.csv"}]
+        >>> method(results)
+        ([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
+
+        >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": 
"gcs://bucket/dir/dir2/a.txt"}]
+        >>> method(results)
+        ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
+
+        >>> results = [
+        ...     {"file": "s3://bucket/dir/file.csv"},
+        ...     {"file": 
"azure://my_account.something_new.windows.net/azure_container"},
+        ... ]
+        >>> method(results)
+        ([('s3://bucket', 'dir')], 
['azure://my_account.something_new.windows.net/azure_container'])
+        """
+        import re
+        from pathlib import Path
+        from urllib.parse import urlparse
+
+        azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
+        extraction_error_files = []
+        unique_dataset_paths = set()
+
+        for row in query_result:
+            uri = urlparse(row["file"])
+            if uri.scheme == "azure":
+                match = re.fullmatch(azure_regex, row["file"])
+                if not match:
+                    extraction_error_files.append(row["file"])
+                    continue
+                account_name, container_name, name = match.groups()
+                namespace = f"wasbs://{container_name}@{account_name}"
+            else:
+                namespace = f"{uri.scheme}://{uri.netloc}"
+                name = uri.path.lstrip("/")
+
+            name = Path(name).parent.as_posix()
+            if name in ("", "."):
+                name = "/"
+
+            unique_dataset_paths.add((namespace, name))
+
+        return sorted(unique_dataset_paths), sorted(extraction_error_files)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement _on_complete because we rely on return value of a 
query."""
+        import re
+
+        from openlineage.client.facet import (
+            ExternalQueryRunFacet,
+            ExtractionError,
+            ExtractionErrorRunFacet,
+            SqlJobFacet,
+        )
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            return OperatorLineage()
+
+        query_results = self._result or []
+        # If no files were uploaded we get [{"status": "0 files were 
uploaded..."}]
+        if len(query_results) == 1 and query_results[0].get("status"):
+            query_results = []
+        unique_dataset_paths, extraction_error_files = 
self._extract_openlineage_unique_dataset_paths(
+            query_results
+        )
+        input_datasets = [Dataset(namespace=namespace, name=name) for 
namespace, name in unique_dataset_paths]
+
+        run_facets = {}
+        if extraction_error_files:
+            self.log.debug(
+                f"Unable to extract Dataset namespace and name "
+                f"for the following files: `{extraction_error_files}`."
+            )
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=len(query_results),
+                failedTasks=len(extraction_error_files),
+                errors=[
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and 
name.",
+                        stackTrace=None,
+                        task=file_uri,
+                        taskNumber=None,
+                    )
+                    for file_uri in extraction_error_files
+                ],
+            )
+
+        connection = self.hook.get_connection(getattr(self.hook, 
str(self.hook.conn_name_attr)))
+        database_info = self.hook.get_openlineage_database_info(connection)
+
+        dest_name = self.table
+        schema = self.hook.get_openlineage_default_schema()
+        database = database_info.database
+        if schema:
+            dest_name = f"{schema}.{dest_name}"
+            if database:
+                dest_name = f"{database}.{dest_name}"
+
+        snowflake_namespace = SQLParser.create_namespace(database_info)
+        query = SQLParser.normalize_sql(self._sql)
+        query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))

Review Comment:
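   Why do we need to do this extra whitespace collapsing (repeated spaces and newlines via `re.sub`) after `SQLParser.normalize_sql`?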



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to