kacpermuda commented on code in PR #36535: URL: https://github.com/apache/airflow/pull/36535#discussion_r1444446974
########## airflow/providers/snowflake/transfers/copy_into_snowflake.py: ########## @@ -137,5 +141,158 @@ def execute(self, context: Any) -> None: {self.validation_mode or ""} """ self.log.info("Executing COPY command...") - snowflake_hook.run(sql=sql, autocommit=self.autocommit) + self._result = self.hook.run( # type: ignore # mypy does not work well with return_dictionaries=True + sql=self._sql, + autocommit=self.autocommit, + handler=lambda x: x.fetchall(), + return_dictionaries=True, + ) self.log.info("COPY command completed") + + @staticmethod + def _extract_openlineage_unique_dataset_paths( + query_result: list[dict[str, Any]], + ) -> tuple[list[tuple[str, str]], list[str]]: + """Extracts and returns unique OpenLineage dataset paths and file paths that failed to be parsed. + + Each row in the results is expected to have a 'file' field, which is a URI. + The function parses these URIs and constructs a set of unique OpenLineage (namespace, name) tuples. + Additionally, it captures any URIs that cannot be parsed or processed + and returns them in a separate error list. + + For Azure, Snowflake has a unique way of representing URI: + azure://<account_name>.blob.core.windows.net/<container_name>/path/to/file.csv + that is transformed by this function to a Dataset with more universal naming convention: + Dataset(namespace="wasbs://container_name@account_name", name="path/to"), as described at + https://github.com/OpenLineage/OpenLineage/blob/main/spec/Naming.md#wasbs-azure-blob-storage + + :param query_result: A list of dictionaries, each containing a 'file' key with a URI value. + :return: Two lists - the first is a sorted list of tuples, each representing a unique dataset path, + and the second contains any URIs that cannot be parsed or processed correctly. 
+ + >>> method = CopyFromExternalStageToSnowflakeOperator._extract_openlineage_unique_dataset_paths + + >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container/dir3/file.csv"}] + >>> method(results) + ([('wasbs://azure_container@my_account', 'dir3')], []) + + >>> results = [{"file": "azure://my_account.blob.core.windows.net/azure_container"}] + >>> method(results) + ([('wasbs://azure_container@my_account', '/')], []) + + >>> results = [{"file": "s3://bucket"}, {"file": "gcs://bucket/"}, {"file": "s3://bucket/a.csv"}] + >>> method(results) + ([('gcs://bucket', '/'), ('s3://bucket', '/')], []) + + >>> results = [{"file": "s3://bucket/dir/file.csv"}, {"file": "gcs://bucket/dir/dir2/a.txt"}] + >>> method(results) + ([('gcs://bucket', 'dir/dir2'), ('s3://bucket', 'dir')], []) + + >>> results = [ + ... {"file": "s3://bucket/dir/file.csv"}, + ... {"file": "azure://my_account.something_new.windows.net/azure_container"}, + ... ] + >>> method(results) + ([('s3://bucket', 'dir')], ['azure://my_account.something_new.windows.net/azure_container']) + """ + import re + from pathlib import Path + from urllib.parse import urlparse + + azure_regex = r"azure:\/\/(\w+)?\.blob.core.windows.net\/(\w+)\/?(.*)?" 
+ extraction_error_files = [] + unique_dataset_paths = set() + + for row in query_result: + uri = urlparse(row["file"]) + if uri.scheme == "azure": + match = re.fullmatch(azure_regex, row["file"]) + if not match: + extraction_error_files.append(row["file"]) + continue + account_name, container_name, name = match.groups() + namespace = f"wasbs://{container_name}@{account_name}" + else: + namespace = f"{uri.scheme}://{uri.netloc}" + name = uri.path.lstrip("/") + + name = Path(name).parent.as_posix() + if name in ("", "."): + name = "/" + + unique_dataset_paths.add((namespace, name)) + + return sorted(unique_dataset_paths), sorted(extraction_error_files) + + def get_openlineage_facets_on_complete(self, task_instance): + """Implement _on_complete because we rely on return value of a query.""" + import re + + from openlineage.client.facet import ( + ExternalQueryRunFacet, + ExtractionError, + ExtractionErrorRunFacet, + SqlJobFacet, + ) + from openlineage.client.run import Dataset + + from airflow.providers.openlineage.extractors import OperatorLineage + from airflow.providers.openlineage.sqlparser import SQLParser + + if not self._sql: + return OperatorLineage() + + query_results = self._result or [] + # If no files were uploaded we get [{"status": "0 files were uploaded..."}] + if len(query_results) == 1 and query_results[0].get("status"): + query_results = [] + unique_dataset_paths, extraction_error_files = self._extract_openlineage_unique_dataset_paths( + query_results + ) + input_datasets = [Dataset(namespace=namespace, name=name) for namespace, name in unique_dataset_paths] + + run_facets = {} + if extraction_error_files: + self.log.debug( + f"Unable to extract Dataset namespace and name " + f"for the following files: `{extraction_error_files}`." 
+ ) + run_facets["extractionError"] = ExtractionErrorRunFacet( + totalTasks=len(query_results), + failedTasks=len(extraction_error_files), + errors=[ + ExtractionError( + errorMessage="Unable to extract Dataset namespace and name.", + stackTrace=None, + task=file_uri, + taskNumber=None, + ) + for file_uri in extraction_error_files + ], + ) + + connection = self.hook.get_connection(getattr(self.hook, str(self.hook.conn_name_attr))) + database_info = self.hook.get_openlineage_database_info(connection) + + dest_name = self.table + schema = self.hook.get_openlineage_default_schema() + database = database_info.database + if schema: + dest_name = f"{schema}.{dest_name}" + if database: + dest_name = f"{database}.{dest_name}" + + snowflake_namespace = SQLParser.create_namespace(database_info) + query = SQLParser.normalize_sql(self._sql) + query = re.sub(r"\n+", "\n", re.sub(r" +", " ", query)) Review Comment: This is a purely aesthetic step. The query created in the operator looks fine at the point where the code is written, but after being parsed it may contain unnecessary spaces and newlines (depending on the params provided). By removing those, I made sure that what goes into the OL event is as short and simple as possible. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org