This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git

The following commit(s) were added to refs/heads/main by this push:
     new 3dc99d8a28 feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
3dc99d8a28 is described below

commit 3dc99d8a285aaadeb83797e691c9f6ec93ff9c93
Author: Kacper Muda <mudakac...@gmail.com>
AuthorDate: Mon Jan 8 13:02:46 2024 +0100

    feat: Add openlineage support for CopyFromExternalStageToSnowflakeOperator (#36535)
---
 .../snowflake/transfers/copy_into_snowflake.py     | 163 +++++++++++++++++++-
 .../transfers/test_copy_into_snowflake.py          | 168 ++++++++++++++++++++-
 2 files changed, 327 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/snowflake/transfers/copy_into_snowflake.py b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
index 10071add1a..342d5dc35a 100644
--- a/airflow/providers/snowflake/transfers/copy_into_snowflake.py
+++ b/airflow/providers/snowflake/transfers/copy_into_snowflake.py
@@ -108,8 +108,12 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         self.copy_options = copy_options
         self.validation_mode = validation_mode

+        self.hook: SnowflakeHook | None = None
+        self._sql: str | None = None
+        self._result: list[dict[str, Any]] = []
+
     def execute(self, context: Any) -> None:
-        snowflake_hook = SnowflakeHook(
+        self.hook = SnowflakeHook(
             snowflake_conn_id=self.snowflake_conn_id,
             warehouse=self.warehouse,
             database=self.database,
@@ -127,7 +131,7 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         if self.columns_array:
             into = f"{into}({', '.join(self.columns_array)})"

-        sql = f"""
+        self._sql = f"""
         COPY INTO {into}
              FROM @{self.stage}/{self.prefix or ""}
         {"FILES=(" + ",".join(map(enclose_param, self.files)) + ")" if self.files else ""}
@@ -137,5 +141,158 @@ class CopyFromExternalStageToSnowflakeOperator(BaseOperator):
         {self.validation_mode or ""}
         """
         self.log.info("Executing COPY command...")
-        snowflake_hook.run(sql=sql, autocommit=self.autocommit)
+        self._result = self.hook.run(  # type: ignore # mypy does not work well with return_dictionaries=True
+            sql=self._sql,
+            autocommit=self.autocommit,
+            handler=lambda x: x.fetchall(),
+            return_dictionaries=True,
+        )
         self.log.info("COPY command completed")
+
+    @staticmethod
+    def _extract_openlineage_unique_dataset_paths(
+        query_result: list[dict[str, Any]],
+    ) -> tuple[list[tuple[str, str]], list[str]]:
+        """Extracts and returns unique OpenLineage dataset paths and file paths that failed to be parsed.
+
+        Each row in the results is expected to have a 'file' field, which is a URI.
+        The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples.
+        Additionally, it captures any URIs that cannot be parsed or processed
+        and returns them in a separate error list.
+
+        For Azure, Snowflake has a unique way of representing URI:
+            azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv
+        that is transformed by this function to a Dataset with more universal naming convention:
+            Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described at
+        https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage
+
+        :param query_result: A list of dictionaries, each containing a 'file' key with a URI value.
+        :return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path,
+            and the second contains any URIs that cannot be parsed or processed correctly.
+
+        >>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths
+
+        >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', 'dir3')], [])
+
+        >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container"}]
+        >>> method(results)
+        ([('wasbs://azure_container@my_account', '/')], [])
+
+        >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, {"file": "s3://bucket/a.csv"}]
+        >>> method(results)
+        ([('gcs://bucket', '/'), ('s3://bucket', '/')], [])
+
+        >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": "gcs://bucket/dir/dir2/a.txt"}]
+        >>> method(results)
+        ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], [])
+
+        >>> results = [
+        ...     {"file": "s3://bucket/dir/file.csv"},
+        ...     {"file": "azure://my_account.something_new.windows.net/azure_container"},
+        ... ]
+        >>> method(results)
+        ([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container'])
+        """
+        import re
+        from pathlib import Path
+        from urllib.parse import urlparse
+
+        azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?"
+        extraction_error_files = []
+        unique_dataset_paths = set()
+
+        for row in query_result:
+            uri = urlparse(row["file"])
+            if uri.scheme == "azure":
+                match = re.fullmatch(azure_regex, row["file"])
+                if not match:
+                    extraction_error_files.append(row["file"])
+                    continue
+                account_name, container_name, name = match.groups()
+                namespace = f"wasbs://{container_name}@{account_name}"
+            else:
+                namespace = f"{uri.scheme}://{uri.netloc}"
+                name = uri.path.lstrip("/")
+
+            name = Path(name).parent.as_posix()
+            if name in ("", "."):
+                name = "/"
+
+            unique_dataset_paths.add((namespace, name))
+
+        return sorted(unique_dataset_paths), sorted(extraction_error_files)
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """Implement _on_complete because we rely on return value of a query."""
+        import re
+
+        from openlineage.client.facet import (
+            ExternalQueryRunFacet,
+            ExtractionError,
+            ExtractionErrorRunFacet,
+            SqlJobFacet,
+        )
+        from openlineage.client.run import Dataset
+
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            return OperatorLineage()
+
+        query_results = self._result or []
+        # If no files were uploaded we get [{"status": "0 files were uploaded..."}]
+        if len(query_results) == 1 and query_results[0].get("status"):
+            query_results = []
+        unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths(
+            query_results
+        )
+        input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths]
+
+        run_facets = {}
+        if extraction_error_files:
+            self.log.debug(
+                f"Unable to extract Dataset namespace and name "
+                f"for the following files: `{extraction_error_files}`."
+            )
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=len(query_results),
+                failedTasks=len(extraction_error_files),
+                errors=[
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and name.",
+                        stackTrace=None,
+                        task=file_uri,
+                        taskNumber=None,
+                    )
+                    for file_uri in extraction_error_files
+                ],
+            )
+
+        connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr)))
+        database_info = self.hook.get_openlineage_database_info(connection)
+
+        dest_name = self.table
+        schema = self.hook.get_openlineage_default_schema()
+        database = database_info.database
+        if schema:
+            dest_name = f"{schema}.{dest_name}"
+        if database:
+            dest_name = f"{database}.{dest_name}"
+
+        snowflake_namespace = SQLParser.create_namespace(database_info)
+        query = SQLParser.normalize_sql(self._sql)
+        query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query))
+
+        run_facets["externalQuery"] = ExternalQueryRunFacet(
+            externalQueryId=self.hook.query_ids[0], source=snowflake_namespace
+        )
+
+        return OperatorLineage(
+            inputs=input_datasets,
+            outputs=[Dataset(namespace=snowflake_namespace, name=dest_name)],
+            job_facets={"sql": SqlJobFacet(query=query)},
+            run_facets=run_facets,
+        )
diff --git a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
index 76268d077d..27e02dc41c 100644
--- a/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
+++ b/tests/providers/snowflake/transfers/test_copy_into_snowflake.py
@@ -16,8 +16,20 @@
 # under the License.
 from __future__ import annotations

+from typing import Callable
 from unittest import mock

+from openlineage.client.facet import (
+    ExternalQueryRunFacet,
+    ExtractionError,
+    ExtractionErrorRunFacet,
+    SqlJobFacet,
+)
+from openlineage.client.run import Dataset
+from pytest import mark
+
+from airflow.providers.openlineage.extractors import OperatorLineage
+from airflow.providers.openlineage.sqlparser import DatabaseInfo
 from airflow.providers.snowflake.transfers.copy_into_snowflake import CopyFromExternalStageToSnowflakeOperator


@@ -62,4 +74,158 @@ class TestCopyFromExternalStageToSnowflake:
             validation_mode
         """

-        mock_hook.return_value.run.assert_called_once_with(sql=sql, autocommit=True)
+        mock_hook.return_value.run.assert_called_once_with(
+            sql=sql, autocommit=True, return_dictionaries=True, handler=mock.ANY
+        )
+
+        handler = mock_hook.return_value.run.mock_calls[0].kwargs.get("handler")
+        assert isinstance(handler, Callable)
+
+    @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook")
+    def test_get_openlineage_facets_on_complete(self, mock_hook):
+        mock_hook().run.return_value = [
+            {"file": "s3://aws_bucket_name/dir1/file.csv"},
+            {"file": "s3://aws_bucket_name_2"},
+            {"file": "gcs://gcs_bucket_name/dir2/file.csv"},
+            {"file": "gcs://gcs_bucket_name_2"},
+            {"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"},
+            {"file": "azure://my_account.blob.core.windows.net/azure_container_2"},
+        ]
+        mock_hook().get_openlineage_database_info.return_value = DatabaseInfo(
+            scheme="snowflake_scheme", authority="authority", database="actual_database"
+        )
+        mock_hook().get_openlineage_default_schema.return_value = "actual_schema"
+        mock_hook().query_ids = ["query_id_123"]
+
+        expected_inputs = [
+            Dataset(namespace="gcs://gcs_bucket_name", name="dir2"),
+            Dataset(namespace="gcs://gcs_bucket_name_2", name="/"),
+            Dataset(namespace="s3://aws_bucket_name", name="dir1"),
Dataset(namespace="s3://aws_bucket_name_2", name="/"), + Dataset(namespace="wasbs://azure_container@my_account", name="dir3"), + Dataset(namespace="wasbs://azure_container_2@my_account", name="/"), + ] + expected_outputs = [ + Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table") + ] + expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV""" + + op = CopyFromExternalStageToSnowflakeOperator( + task_id="test", + table="table", + stage="stage", + database="", + schema="schema", + file_format="CSV", + ) + op.execute(None) + result = op.get_openlineage_facets_on_complete(None) + assert result == OperatorLineage( + inputs=expected_inputs, + outputs=expected_outputs, + run_facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query_id_123", source="snowflake_scheme://authority" + ) + }, + job_facets={"sql": SqlJobFacet(query=expected_sql)}, + ) + + @mark.parametrize("rows", (None, [])) + @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook") + def test_get_openlineage_facets_on_complete_with_empty_inputs(self, mock_hook, rows): + mock_hook().run.return_value = rows + mock_hook().get_openlineage_database_info.return_value = DatabaseInfo( + scheme="snowflake_scheme", authority="authority", database="actual_database" + ) + mock_hook().get_openlineage_default_schema.return_value = "actual_schema" + mock_hook().query_ids = ["query_id_123"] + + expected_outputs = [ + Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table") + ] + expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV""" + + op = CopyFromExternalStageToSnowflakeOperator( + task_id="test", + table="table", + stage="stage", + database="", + schema="schema", + file_format="CSV", + ) + op.execute(None) + result = op.get_openlineage_facets_on_complete(None) + assert result == OperatorLineage( + inputs=[], + outputs=expected_outputs, + run_facets={ + "externalQuery": ExternalQueryRunFacet( + externalQueryId="query_id_123", source="snowflake_scheme://authority" + ) + }, + job_facets={"sql": SqlJobFacet(query=expected_sql)}, + ) + + @mock.patch("airflow.providers.snowflake.transfers.copy_into_snowflake.SnowflakeHook") + def test_get_openlineage_facets_on_complete_unsupported_azure_uri(self, mock_hook): + mock_hook().run.return_value = [ + {"file": "s3://aws_bucket_name/dir1/file.csv"}, + {"file": "gs://gcp_bucket_name/dir2/file.csv"}, + {"file": "azure://my_account.weird-url.net/azure_container/dir3/file.csv"}, + {"file": "azure://my_account.another_weird-url.net/con/file.csv"}, + ] + mock_hook().get_openlineage_database_info.return_value = DatabaseInfo( + scheme="snowflake_scheme", authority="authority", database="actual_database" + ) + mock_hook().get_openlineage_default_schema.return_value = "actual_schema" + mock_hook().query_ids = ["query_id_123"] + + expected_inputs = [ + Dataset(namespace="gs://gcp_bucket_name", name="dir2"), + Dataset(namespace="s3://aws_bucket_name", name="dir1"), + ] + expected_outputs = [ + Dataset(namespace="snowflake_scheme://authority", name="actual_database.actual_schema.table") + ] + expected_sql = """COPY INTO schema.table\n FROM @stage/\n FILE_FORMAT=CSV""" + expected_run_facets = { + "extractionError": ExtractionErrorRunFacet( + totalTasks=4, + failedTasks=2, + errors=[ + ExtractionError( + errorMessage="Unable to extract Dataset namespace and name.", + stackTrace=None, + task="azure://my_account.another_weird-url.net/con/file.csv", + 
+                        taskNumber=None,
+                    ),
+                    ExtractionError(
+                        errorMessage="Unable to extract Dataset namespace and name.",
+                        stackTrace=None,
+                        task="azure://my_account.weird-url.net/azure_container/dir3/file.csv",
+                        taskNumber=None,
+                    ),
+                ],
+            ),
+            "externalQuery": ExternalQueryRunFacet(
+                externalQueryId="query_id_123", source="snowflake_scheme://authority"
+            ),
+        }
+
+        op = CopyFromExternalStageToSnowflakeOperator(
+            task_id="test",
+            table="table",
+            stage="stage",
+            database="",
+            schema="schema",
+            file_format="CSV",
+        )
+        op.execute(None)
+        result = op.get_openlineage_facets_on_complete(None)
+        assert result == OperatorLineage(
+            inputs=expected_inputs,
+            outputs=expected_outputs,
+            run_facets=expected_run_facets,
+            job_facets={"sql": SqlJobFacet(query=expected_sql)},
+        )
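
For anyone who wants to try the new lineage extraction outside Airflow's test suite, the helper added in this commit is a static method and can be called directly. Below is a minimal sketch, assuming a provider build that includes this change; the bucket, container, and file names are made up for illustration:

    from airflow.providers.snowflake.transfers.copy_into_snowflake import (
        CopyFromExternalStageToSnowflakeOperator,
    )

    # Rows shaped like the output of SnowflakeHook.run(..., return_dictionaries=True)
    # for a COPY INTO query; the URIs here are illustrative only.
    rows = [
        {"file": "s3://example-bucket/landing/2024/01/data.csv"},
        {"file": "azure://exampleaccount.blob.core.windows.net/raw/events/part-0.csv"},
    ]

    paths, errors = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths(rows)
    # paths  -> [('s3://example-bucket', 'landing/2024/01'), ('wasbs://raw@exampleaccount', 'events')]
    # errors -> []
    print(paths, errors)

URIs that do not match the expected Azure layout end up in the second list and are reported through the ExtractionErrorRunFacet rather than failing the task.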