kacpermuda commented on code in PR #45257:
URL: https://github.com/apache/airflow/pull/45257#discussion_r1905679785
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
Review Comment:
I think we usually only use `SQLParser.normalize_sql` for the SQLJobFacet.
What is the reason for these additional replacements? If they are necessary,
could you add a comment explaining why?
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
Review Comment:
What is the result saved here? Later in the code it appears to be treated as
query IDs, but are we sure that is what we get? What if somebody submits a query
that reads a million rows? I'm asking because this looks like a place that could
add a lot of processing, even for users who do not use the OpenLineage
integration.
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
Review Comment:
Overall, this is a really long method. Could we split it into smaller, logical
parts, or otherwise refactor it? The deep indentation makes it harder to read
when there is a lot of logic inside a single `if`. Maybe those code chunks
should be separate methods?
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
+ job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+ except Exception as e:
+ self.log.error("Failed creating SQL job facet: %s", str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task="sql_facet_creation", taskNumber=None)
+ )
+
+ # Add extraction error facet if there are any errors
+ if extraction_errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=1,
+ failedTasks=len(extraction_errors),
+ errors=extraction_errors,
+ )
+ # Return only error facets for invalid URIs
+ return OperatorLineage(
+ inputs=[],
+ outputs=[],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ # Only proceed with output dataset if input was valid
+ output_dataset = None
+ if self.table_name:
+ try:
+ table_parts = self.table_name.split(".")
+ if len(table_parts) == 3: # catalog.schema.table
+ catalog, schema, table = table_parts
+ elif len(table_parts) == 2: # schema.table
+ catalog = None
+ schema, table = table_parts
+ else:
+ catalog = None
+ schema = None
+ table = self.table_name
+
+ hook = self._get_hook()
+ conn = hook.get_connection(hook.databricks_conn_id)
+ output_namespace = f"databricks://{conn.host}"
+
+ # Combine schema/table with optional catalog for final dataset
name
+ fq_name = table
+ if schema:
+ fq_name = f"{schema}.{fq_name}"
+ if catalog:
+ fq_name = f"{catalog}.{fq_name}"
+
+ output_dataset = Dataset(namespace=output_namespace,
name=fq_name)
+ except Exception as e:
+ self.log.error("Failed to construct output dataset: %s",
str(e))
+ extraction_errors.append(
+ Error(
+ errorMessage=str(e),
+ stackTrace=None,
+ task="output_dataset_construction",
+ taskNumber=None,
+ )
+ )
Review Comment:
We are not using the extraction_errors later in the code, so there is no
point in appending here. Maybe the ExtractionErrorFacet should be created at
the very end?
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
+ job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+ except Exception as e:
+ self.log.error("Failed creating SQL job facet: %s", str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task="sql_facet_creation", taskNumber=None)
+ )
+
+ # Add extraction error facet if there are any errors
+ if extraction_errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=1,
+ failedTasks=len(extraction_errors),
+ errors=extraction_errors,
+ )
+ # Return only error facets for invalid URIs
+ return OperatorLineage(
+ inputs=[],
+ outputs=[],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ # Only proceed with output dataset if input was valid
+ output_dataset = None
+ if self.table_name:
+ try:
+ table_parts = self.table_name.split(".")
+ if len(table_parts) == 3: # catalog.schema.table
+ catalog, schema, table = table_parts
+ elif len(table_parts) == 2: # schema.table
+ catalog = None
+ schema, table = table_parts
+ else:
+ catalog = None
+ schema = None
+ table = self.table_name
+
+ hook = self._get_hook()
+ conn = hook.get_connection(hook.databricks_conn_id)
+ output_namespace = f"databricks://{conn.host}"
+
+ # Combine schema/table with optional catalog for final dataset
name
+ fq_name = table
+ if schema:
+ fq_name = f"{schema}.{fq_name}"
+ if catalog:
+ fq_name = f"{catalog}.{fq_name}"
+
+ output_dataset = Dataset(namespace=output_namespace,
name=fq_name)
+ except Exception as e:
+ self.log.error("Failed to construct output dataset: %s",
str(e))
+ extraction_errors.append(
+ Error(
+ errorMessage=str(e),
+ stackTrace=None,
+ task="output_dataset_construction",
+ taskNumber=None,
+ )
+ )
+
+ # Add external query facet if we have run results
+ if hasattr(self, "_result") and self._result:
+ run_facets["externalQuery"] = ExternalQueryRunFacet(
+ externalQueryId=str(id(self._result)),
+ source=output_dataset.namespace if output_dataset else
"databricks",
+ )
+
+ return OperatorLineage(
+ inputs=input_datasets,
+ outputs=[output_dataset] if output_dataset else [],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ @staticmethod
+ def _extract_openlineage_unique_dataset_paths(
Review Comment:
Where is this method used? I don't see it.
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
+ job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+ except Exception as e:
+ self.log.error("Failed creating SQL job facet: %s", str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task="sql_facet_creation", taskNumber=None)
+ )
+
+ # Add extraction error facet if there are any errors
+ if extraction_errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=1,
+ failedTasks=len(extraction_errors),
+ errors=extraction_errors,
+ )
+ # Return only error facets for invalid URIs
+ return OperatorLineage(
+ inputs=[],
+ outputs=[],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ # Only proceed with output dataset if input was valid
+ output_dataset = None
+ if self.table_name:
+ try:
+ table_parts = self.table_name.split(".")
+ if len(table_parts) == 3: # catalog.schema.table
+ catalog, schema, table = table_parts
+ elif len(table_parts) == 2: # schema.table
+ catalog = None
+ schema, table = table_parts
+ else:
+ catalog = None
+ schema = None
+ table = self.table_name
+
+ hook = self._get_hook()
+ conn = hook.get_connection(hook.databricks_conn_id)
+ output_namespace = f"databricks://{conn.host}"
+
+ # Combine schema/table with optional catalog for final dataset
name
+ fq_name = table
+ if schema:
+ fq_name = f"{schema}.{fq_name}"
+ if catalog:
+ fq_name = f"{catalog}.{fq_name}"
Review Comment:
We are not replacing None values with anything here, so could we end up with
`None.None.table_name`?
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
+ job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+ except Exception as e:
+ self.log.error("Failed creating SQL job facet: %s", str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task="sql_facet_creation", taskNumber=None)
+ )
+
+ # Add extraction error facet if there are any errors
+ if extraction_errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=1,
+ failedTasks=len(extraction_errors),
+ errors=extraction_errors,
+ )
+ # Return only error facets for invalid URIs
+ return OperatorLineage(
+ inputs=[],
+ outputs=[],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ # Only proceed with output dataset if input was valid
+ output_dataset = None
+ if self.table_name:
+ try:
+ table_parts = self.table_name.split(".")
+ if len(table_parts) == 3: # catalog.schema.table
+ catalog, schema, table = table_parts
+ elif len(table_parts) == 2: # schema.table
+ catalog = None
+ schema, table = table_parts
+ else:
+ catalog = None
+ schema = None
+ table = self.table_name
+
+ hook = self._get_hook()
+ conn = hook.get_connection(hook.databricks_conn_id)
+ output_namespace = f"databricks://{conn.host}"
+
+ # Combine schema/table with optional catalog for final dataset
name
+ fq_name = table
+ if schema:
+ fq_name = f"{schema}.{fq_name}"
+ if catalog:
+ fq_name = f"{catalog}.{fq_name}"
+
+ output_dataset = Dataset(namespace=output_namespace,
name=fq_name)
+ except Exception as e:
+ self.log.error("Failed to construct output dataset: %s",
str(e))
+ extraction_errors.append(
+ Error(
+ errorMessage=str(e),
+ stackTrace=None,
+ task="output_dataset_construction",
+ taskNumber=None,
+ )
+ )
+
+ # Add external query facet if we have run results
+ if hasattr(self, "_result") and self._result:
Review Comment:
I think this `hasattr` is redundant, since we set `_result` ourselves in `__init__`.
##########
providers/src/airflow/providers/databricks/operators/databricks_sql.py:
##########
@@ -349,12 +360,233 @@ def _create_sql_query(self) -> str:
return sql.strip()
def execute(self, context: Context) -> Any:
- sql = self._create_sql_query()
- self.log.info("Executing: %s", sql)
+ """Execute the COPY INTO command and store the result for lineage
reporting."""
+ self._sql = self._create_sql_query()
+ self.log.info("Executing SQL: %s", self._sql)
+
hook = self._get_hook()
- hook.run(sql)
+ result = hook.run(self._sql, handler=lambda cur: cur.fetchall())
+ # Convert to list, handling the case where result might be None
+ self._result = list(result) if result is not None else []
def on_kill(self) -> None:
# NB: on_kill isn't required for this operator since query cancelling
gets
# handled in `DatabricksSqlHook.run()` method which is called in
`execute()`
...
+
+ def get_openlineage_facets_on_complete(self, task_instance):
+ """
+ Compute OpenLineage facets for the COPY INTO command.
+
+ Attempts to parse input files (from S3, GCS, Azure Blob, etc.) and
build an
+ input dataset list and an output dataset (the Delta table).
+ """
+ import re
+ from urllib.parse import urlparse
+
+ from airflow.providers.common.compat.openlineage.facet import (
+ Dataset,
+ Error,
+ ExternalQueryRunFacet,
+ ExtractionErrorRunFacet,
+ SQLJobFacet,
+ )
+ from airflow.providers.openlineage.extractors import OperatorLineage
+ from airflow.providers.openlineage.sqlparser import SQLParser
+
+ if not self._sql:
+ self.log.warning("No SQL query found, returning empty
OperatorLineage.")
+ return OperatorLineage()
+
+ input_datasets = []
+ extraction_errors = []
+ job_facets = {}
+ run_facets = {}
+
+ # Parse file_location to build the input dataset (if possible).
+ if self.file_location:
+ try:
+ parsed_uri = urlparse(self.file_location)
+ # Only process known schemes
+ if parsed_uri.scheme not in ("s3", "s3a", "s3n", "gs",
"azure", "abfss", "wasbs"):
+ raise ValueError(f"Unsupported scheme:
{parsed_uri.scheme}")
+
+ # Keep original scheme for s3/s3a/s3n
+ scheme = parsed_uri.scheme
+ namespace = f"{scheme}://{parsed_uri.netloc}"
+ path = parsed_uri.path.lstrip("/") or "/"
+ input_datasets.append(Dataset(namespace=namespace, name=path))
+ except Exception as e:
+ self.log.error("Failed to parse file_location: %s, error: %s",
self.file_location, str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task=self.file_location, taskNumber=None)
+ )
+
+ # Build SQLJobFacet
+ try:
+ normalized_sql = SQLParser.normalize_sql(self._sql)
+ normalized_sql = re.sub(r"\n+", "\n", re.sub(r" +", " ",
normalized_sql))
+ job_facets["sql"] = SQLJobFacet(query=normalized_sql)
+ except Exception as e:
+ self.log.error("Failed creating SQL job facet: %s", str(e))
+ extraction_errors.append(
+ Error(errorMessage=str(e), stackTrace=None,
task="sql_facet_creation", taskNumber=None)
+ )
+
+ # Add extraction error facet if there are any errors
+ if extraction_errors:
+ run_facets["extractionError"] = ExtractionErrorRunFacet(
+ totalTasks=1,
+ failedTasks=len(extraction_errors),
+ errors=extraction_errors,
+ )
+ # Return only error facets for invalid URIs
+ return OperatorLineage(
+ inputs=[],
+ outputs=[],
+ job_facets=job_facets,
+ run_facets=run_facets,
+ )
+
+ # Only proceed with output dataset if input was valid
+ output_dataset = None
+ if self.table_name:
+ try:
+ table_parts = self.table_name.split(".")
+ if len(table_parts) == 3: # catalog.schema.table
+ catalog, schema, table = table_parts
+ elif len(table_parts) == 2: # schema.table
+ catalog = None
+ schema, table = table_parts
+ else:
+ catalog = None
+ schema = None
+ table = self.table_name
+
+ hook = self._get_hook()
+ conn = hook.get_connection(hook.databricks_conn_id)
+ output_namespace = f"databricks://{conn.host}"
+
+ # Combine schema/table with optional catalog for final dataset
name
+ fq_name = table
+ if schema:
+ fq_name = f"{schema}.{fq_name}"
+ if catalog:
+ fq_name = f"{catalog}.{fq_name}"
+
+ output_dataset = Dataset(namespace=output_namespace,
name=fq_name)
+ except Exception as e:
+ self.log.error("Failed to construct output dataset: %s",
str(e))
+ extraction_errors.append(
+ Error(
+ errorMessage=str(e),
+ stackTrace=None,
+ task="output_dataset_construction",
+ taskNumber=None,
+ )
+ )
+
+ # Add external query facet if we have run results
+ if hasattr(self, "_result") and self._result:
+ run_facets["externalQuery"] = ExternalQueryRunFacet(
+ externalQueryId=str(id(self._result)),
Review Comment:
We are saving it as a list in `execute`, and here we are converting it to a
string (and `str(id(self._result))` is the Python object identity, not the
result content). Why is that? Is it a single query_id or multiple ones?
##########
dev/breeze/tests/test_selective_checks.py:
##########
@@ -1762,13 +1762,13 @@ def test_expected_output_push(
"skip-providers-tests": "false",
"test-groups": "['core', 'providers']",
"docs-build": "true",
- "docs-list-as-string": "apache-airflow amazon common.compat
common.io common.sql dbt.cloud ftp google mysql openlineage postgres sftp
snowflake trino",
+ "docs-list-as-string": "apache-airflow amazon common.compat
common.io common.sql databricks dbt.cloud ftp google mysql openlineage postgres
sftp snowflake trino",
"skip-pre-commits":
"check-provider-yaml-valid,flynt,identity,lint-helm-chart,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,"
"ts-compile-format-lint-ui,ts-compile-format-lint-www",
"run-kubernetes-tests": "false",
"upgrade-to-newer-dependencies": "false",
"core-test-types-list-as-string": "API Always CLI Core
Operators Other Serialization WWW",
- "providers-test-types-list-as-string": "Providers[amazon]
Providers[common.compat,common.io,common.sql,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino]
Providers[google]",
+ "providers-test-types-list-as-string": "Providers[amazon]
Providers[common.compat,common.io,common.sql,databricks,dbt.cloud,ftp,mysql,openlineage,postgres,sftp,snowflake,trino]
Providers[google]",
Review Comment:
I don't think any changes here are necessary; you are not adding any new
dependencies in the code. Try submitting the PR without them and we'll see what
happens.
##########
generated/provider_dependencies.json:
##########
Review Comment:
I don't think any changes here are necessary; you are not adding any new
dependencies in the code. Try submitting the PR without them and we'll see what
happens.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]