rahul-madaan commented on code in PR #45257:
URL: https://github.com/apache/airflow/pull/45257#discussion_r1929281714


##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
         return sql.strip()
 
     def execute(self, context: Context) -> Any:
-        sql = self._create_sql_query()
-        self.log.info("Executing: %s", sql)
+        """Execute the COPY INTO command and store the result for lineage reporting."""
+        self._sql = self._create_sql_query()
+        self.log.info("Executing SQL: %s", self._sql)
+
         hook = self._get_hook()
-        hook.run(sql)
+        result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+        # Convert to list, handling the case where result might be None
+        self._result = list(result) if result is not None else []
 
     def on_kill(self) -> None:
         # NB: on_kill isn't required for this operator since query cancelling gets
         # handled in `DatabricksSqlHook.run()` method which is called in `execute()`
         ...
+
+    def get_openlineage_facets_on_complete(self, task_instance):
+        """
+        Compute OpenLineage facets for the COPY INTO command.
+
+        Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and build an
+        input dataset list and an output dataset (the Delta table).
+        """
+        import re
+        from urllib.parse import urlparse
+
+        from airflow.providers.common.compat.openlineage.facet import (
+            Dataset,
+            Error,
+            ExternalQueryRunFacet,
+            ExtractionErrorRunFacet,
+            SQLJobFacet,
+        )
+        from airflow.providers.openlineage.extractors import OperatorLineage
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        if not self._sql:
+            self.log.warning("No SQL query found, returning empty OperatorLineage.")
+            return OperatorLineage()
+
+        input_datasets = []
+        extraction_errors = []
+        job_facets = {}
+        run_facets = {}
+
+        # Parse file_location to build the input dataset (if possible).
+        if self.file_location:
+            try:
+                parsed_uri = urlparse(self.file_location)
+                # Only process known schemes
+                if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs", "azure", "abfss", "wasbs"):
+                    raise ValueError(f"Unsupported scheme: {parsed_uri.scheme}")
+
+                # Keep original scheme for s3/s3a/s3n
+                scheme = parsed_uri.scheme
+                namespace = f"{scheme}://{parsed_uri.netloc}"
+                path = parsed_uri.path.lstrip("/") or "/"
+                input_datasets.append(Dataset(namespace=namespace, name=path))
+            except Exception as e:
+                self.log.error("Failed to parse file_location: %s, error: %s", self.file_location, str(e))
+                extraction_errors.append(
+                    Error(errorMessage=str(e), stackTrace=None, task=self.file_location, taskNumber=None)
+                )
+
+        # Build SQLJobFacet
+        try:
+            normalized_sql = SQLParser.normalize_sql(self._sql)
+            normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ", normalized_sql))
+            job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+        except Exception as e:
+            self.log.error("Failed creating SQL job facet: %s", str(e))
+            extraction_errors.append(
+                Error(errorMessage=str(e), stackTrace=None, task="sql_facet_creation", taskNumber=None)
+            )
+
+        # Add extraction error facet if there are any errors
+        if extraction_errors:
+            run_facets["extractionError"] = ExtractionErrorRunFacet(
+                totalTasks=1,
+                failedTasks=len(extraction_errors),
+                errors=extraction_errors,
+            )
+            # Return only error facets for invalid URIs
+            return OperatorLineage(
+                inputs=[],
+                outputs=[],
+                job_facets=job_facets,
+                run_facets=run_facets,
+            )
+
+        # Only proceed with output dataset if input was valid
+        output_dataset = None
+        if self.table_name:
+            try:
+                table_parts = self.table_name.split(".")
+                if len(table_parts) == 3:  # catalog.schema.table
+                    catalog, schema, table = table_parts
+                elif len(table_parts) == 2:  # schema.table
+                    catalog = None
+                    schema, table = table_parts
+                else:
+                    catalog = None
+                    schema = None
+                    table = self.table_name
+
+                hook = self._get_hook()
+                conn = hook.get_connection(hook.databricks_conn_id)
+                output_namespace = f"databricks://{conn.host}"
+
+                # Combine schema/table with optional catalog for final dataset name
+                fq_name = table
+                if schema:
+                    fq_name = f"{schema}.{fq_name}"
+                if catalog:
+                    fq_name = f"{catalog}.{fq_name}"
+
+                output_dataset = Dataset(namespace=output_namespace, name=fq_name)
+            except Exception as e:
+                self.log.error("Failed to construct output dataset: %s", str(e))
+                extraction_errors.append(
+                    Error(
+                        errorMessage=str(e),
+                        stackTrace=None,
+                        task="output_dataset_construction",
+                        taskNumber=None,
+                    )
+                )
+
+        # Add external query facet if we have run results
+        if hasattr(self, "_result") and self._result:
+            run_facets["externalQuery"] = ExternalQueryRunFacet(
+                externalQueryId=str(id(self._result)),
+                source=output_dataset.namespace if output_dataset else "databricks",
+            )
+
+        return OperatorLineage(
+            inputs=input_datasets,
+            outputs=[output_dataset] if output_dataset else [],
+            job_facets=job_facets,
+            run_facets=run_facets,
+        )
+
+    @staticmethod
+    def _extract_openlineage_unique_dataset_paths(

Review Comment:
   Junk method; I forgot to remove it. Apologies 😅



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to