This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 7acd9cb5350 Add session-level query tags to Databricks SQL operators
(#66895)
7acd9cb5350 is described below
commit 7acd9cb53505e0fde250a487524bbfb3793036ba
Author: Nguyễn Ngọc Thành <[email protected]>
AuthorDate: Tue May 26 13:34:29 2026 +0700
Add session-level query tags to Databricks SQL operators (#66895)
---
.../providers/databricks/hooks/databricks_sql.py | 64 +++++++-
.../databricks/operators/databricks_sql.py | 59 +++++++-
.../databricks/sensors/databricks_partition.py | 2 +-
.../providers/databricks/sensors/databricks_sql.py | 2 +-
.../unit/databricks/hooks/test_databricks_sql.py | 145 +++++++++++++++++-
.../databricks/operators/test_databricks_sql.py | 165 ++++++++++++++++++++-
6 files changed, 425 insertions(+), 12 deletions(-)
diff --git
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
index 021142395b2..ffc089607c3 100644
---
a/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
+++
b/providers/databricks/src/airflow/providers/databricks/hooks/databricks_sql.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import logging
import threading
from collections import namedtuple
from collections.abc import Callable, Iterable, Mapping, Sequence
@@ -51,6 +52,8 @@ if TYPE_CHECKING:
T = TypeVar("T")
+log = logging.getLogger(__name__)
+
def create_timeout_thread(
cur, execution_timeout: timedelta | None
@@ -71,6 +74,35 @@ def create_timeout_thread(
return timer, timeout_event
def _format_query_tag_value(value: object) -> str:
    """
    Escape special characters and truncate a single query tag value.

    Databricks ``QUERY_TAGS`` uses ``key:value`` pairs delimited by commas, so
    backslash, comma and colon inside *values* must be escaped. Values are also
    capped at 128 characters before escaping to keep the overall tag string
    within reasonable bounds (so the escaped result may be longer than 128).

    :param value: The tag value. Non-string values (e.g. an int try number or a
        templated value that rendered to a non-string) are converted with
        ``str()`` first.
    :return: The truncated, escaped string form of *value*.
    """
    # Always work on the string form: previously a short non-str value reached
    # ``.replace`` unconverted and raised AttributeError.
    text = str(value)
    if len(text) > 128:
        log.warning(
            "Query tag value truncated to 128 characters (original length %d): %r", len(text), text[:128]
        )
        text = text[:128]
    # Backslash must be escaped first so the backslashes added for "," and ":"
    # are not themselves doubled.
    return text.replace("\\", "\\\\").replace(",", "\\,").replace(":", "\\:")
+
+
def _format_query_tags(tags: dict[str, str | None]) -> str:
    """
    Serialize a query-tags mapping into the ``key:value,key:value`` string used by ``QUERY_TAGS``.

    Pairs are emitted in the mapping's iteration order. Keys are written
    verbatim; each value goes through ``_format_query_tag_value``. Pairs whose
    value is ``None`` are skipped, so an empty or all-``None`` mapping yields ``""``.
    """
    # NOTE(review): keys are not escaped; a key containing "," or ":" would
    # corrupt the serialized string. Confirm whether keys need escaping too.
    pairs: list[str] = []
    for key, value in tags.items():
        if value is None:
            continue
        pairs.append(f"{key}:{_format_query_tag_value(value)}")
    return ",".join(pairs)
+
+
class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
"""
Hook to interact with Databricks SQL.
@@ -88,6 +120,10 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
on every request
:param catalog: An optional initial catalog to use. Requires DBR version
9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
+ :param query_tags: An optional dict of query tags to attach to every SQL
statement executed by
+ this hook. Tags are injected via the ``QUERY_TAGS`` Databricks
session parameter so they
+ appear in ``system.query.history``. Any existing ``QUERY_TAGS``
already present in
+ *session_configuration* are preserved and the new tags are appended.
:param kwargs: Additional parameters internal to Databricks SQL Connector
parameters
"""
@@ -104,6 +140,7 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
catalog: str | None = None,
schema: str | None = None,
caller: str = "DatabricksSqlHook",
+ query_tags: dict[str, str | None] | None = None,
**kwargs,
) -> None:
super().__init__(databricks_conn_id, caller=caller)
@@ -118,6 +155,7 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
self.schema = schema
self.additional_params = kwargs
self.query_ids: list[str] = []
+ self.query_tags = query_tags
def _get_extra_config(self) -> dict[str, Any | None]:
extra_params = copy(self.databricks_conn.extra_dejson)
@@ -169,20 +207,32 @@ class DatabricksSqlHook(BaseDatabricksHook, DbApiHook):
if not self.session_config:
self.session_config =
self.databricks_conn.extra_dejson.get("session_configuration")
+ # session_configuration (including QUERY_TAGS) is applied only when
opening a new
+ # connection; changing query_tags after the first get_conn() call has
no effect.
if not self._sql_conn or prev_token != new_token:
if self._sql_conn: # close already existing connection
self._sql_conn.close()
+ session_config: dict[str, str] = dict(self.session_config) if
self.session_config else {}
+ if self.query_tags:
+ tags_str = _format_query_tags(self.query_tags)
+ existing = session_config.get("QUERY_TAGS", "")
+ session_config["QUERY_TAGS"] = f"{existing},{tags_str}" if
existing else tags_str
+
+ connect_kwargs = {
+ "schema": self.schema,
+ "catalog": self.catalog,
+ "session_configuration": session_config or None,
+ "http_headers": self.http_headers,
+ "_user_agent_entry": self.user_agent_value,
+ **self._get_extra_config(),
+ **self.additional_params,
+ }
+
self._sql_conn = sql.connect(
self.host,
self._http_path,
self._token,
- schema=self.schema,
- catalog=self.catalog,
- session_configuration=self.session_config,
- http_headers=self.http_headers,
- _user_agent_entry=self.user_agent_value,
- **self._get_extra_config(),
- **self.additional_params,
+ **connect_kwargs,
)
if self._sql_conn is None:
diff --git
a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index b50c434d04c..f72514f2488 100644
---
a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++
b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -46,6 +46,24 @@ _IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")
def _get_airflow_query_tags(context: Context) -> dict[str, str | None]:
    """
    Return task-instance metadata from *context* as a query-tags dict.

    Reads ``dag_id``, ``task_id``, ``run_id``, ``try_number`` and ``map_index``
    from ``context["ti"]`` and stores each under an ``airflow_``-prefixed key.
    Non-``None`` values are converted with ``str()``; ``None`` stays ``None``.
    Returns an empty dict when the context carries no task instance.
    """
    ti = context.get("ti")
    if ti is None:
        return {}

    tags: dict[str, str | None] = {}
    for attr in ("dag_id", "task_id", "run_id", "try_number", "map_index"):
        raw = getattr(ti, attr)
        tags[f"airflow_{attr}"] = str(raw) if raw is not None else None
    return tags
+
+
class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
Executes SQL code in a Databricks SQL endpoint or a Databricks cluster.
@@ -68,6 +86,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
:param session_configuration: An optional dictionary of Spark session
parameters. Defaults to None.
If not specified, it could be specified in the Databricks connection's
extra parameters.
:param client_parameters: Additional parameters internal to Databricks SQL
Connector parameters
+ :param query_tags: Optional dictionary of query tags to attach to
Databricks SQL queries.
+ :param include_airflow_query_tags: If True, add Airflow DAG/task/run
metadata as query tags.
:param http_headers: An optional list of (k, v) pairs that will be set as
HTTP headers on every request.
(templated)
:param catalog: An optional initial catalog to use. Requires DBR version
9.0+ (templated)
@@ -93,6 +113,7 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
"http_headers",
"databricks_conn_id",
"_gcs_impersonation_chain",
+ "query_tags",
}
| set(SQLExecuteQueryOperator.template_fields)
)
@@ -115,6 +136,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
output_format: str = "csv",
csv_params: dict[str, Any] | None = None,
client_parameters: dict[str, Any] | None = None,
+ query_tags: dict[str, str | None] | None = None,
+ include_airflow_query_tags: bool = True,
gcp_conn_id: str = "google_cloud_default",
gcs_impersonation_chain: str | Sequence[str] | None = None,
**kwargs,
@@ -132,6 +155,8 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
self.http_headers = http_headers
self.catalog = catalog
self.schema = schema
+ self.query_tags = query_tags or {}
+ self.include_airflow_query_tags = include_airflow_query_tags
self._gcp_conn_id = gcp_conn_id
self._gcs_impersonation_chain = gcs_impersonation_chain
@@ -303,6 +328,20 @@ class DatabricksSqlOperator(SQLExecuteQueryOperator):
return list(zip(descriptions, results))
def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
    """
    Build the query tags to attach to this operator's Databricks SQL session.

    Airflow metadata tags are gathered first, but only when
    ``include_airflow_query_tags`` is set and a context is supplied. The
    user-supplied ``query_tags`` are then layered on top, so they win on key
    collisions.

    :param context: The task execution context, or ``None``.
    :return: The merged tags, or ``None`` when nothing was collected.
    """
    if self.include_airflow_query_tags and context is not None:
        merged: dict[str, str | None] = dict(_get_airflow_query_tags(context))
    else:
        merged = {}
    merged.update(self.query_tags)
    return merged if merged else None
+
def execute(self, context: Context) -> Any:
    """
    Attach the resolved query tags to the hook, then run the SQL as usual.

    The tags are assigned to the hook returned by ``get_db_hook()`` before
    delegating to ``SQLExecuteQueryOperator.execute``.
    """
    # NOTE(review): relies on get_db_hook() returning the same (cached) hook
    # instance that super().execute() uses, and on no connection having been
    # opened yet; otherwise the tags would not reach the session. Confirm.
    resolved_tags = self._get_query_tags(context)
    self.get_db_hook().query_tags = resolved_tags
    return super().execute(context)
+
COPY_INTO_APPROVED_FORMATS = ["CSV", "JSON", "AVRO", "ORC", "PARQUET", "TEXT",
"BINARYFILE"]
@@ -335,6 +374,8 @@ class DatabricksCopyIntoOperator(BaseOperator):
:param catalog: An optional initial catalog to use. Requires DBR version
9.0+
:param schema: An optional initial schema to use. Requires DBR version 9.0+
:param client_parameters: Additional parameters internal to Databricks SQL
Connector parameters
+ :param query_tags: Optional dictionary of query tags to attach to
Databricks SQL queries.
+ :param include_airflow_query_tags: If True, add Airflow DAG/task/run
metadata as query tags.
:param files: optional list of files to import. Can't be specified
together with ``pattern``. (templated)
:param pattern: optional regex string to match file names to import.
Can't be specified together with ``files``.
@@ -355,6 +396,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
"files",
"table_name",
"databricks_conn_id",
+ "query_tags",
)
def __init__(
@@ -381,9 +423,11 @@ class DatabricksCopyIntoOperator(BaseOperator):
force_copy: bool | None = None,
copy_options: dict[str, str] | None = None,
validate: bool | int | None = None,
+ query_tags: dict[str, str | None] | None = None,
+ include_airflow_query_tags: bool = True,
**kwargs,
) -> None:
- """Create a new ``DatabricksSqlOperator``."""
+ """Create a new ``DatabricksCopyIntoOperator``."""
super().__init__(**kwargs)
if files is not None and pattern is not None:
raise AirflowException("Only one of 'pattern' or 'files' should be
specified")
@@ -413,6 +457,8 @@ class DatabricksCopyIntoOperator(BaseOperator):
self._validate = validate
self._http_headers = http_headers
self._client_parameters = client_parameters or {}
+ self.query_tags = query_tags or {}
+ self.include_airflow_query_tags = include_airflow_query_tags
if force_copy is not None:
self._copy_options["force"] = "true" if force_copy else "false"
self._sql: str | None = None
@@ -514,10 +560,21 @@ FILEFORMAT = {self._file_format}
"""
return sql.strip()
def _get_query_tags(self, context: Context) -> dict[str, str | None] | None:
    """
    Build the query tags to attach to the ``COPY INTO`` session.

    Airflow metadata tags are gathered first, but only when
    ``include_airflow_query_tags`` is set and a context is supplied. The
    user-supplied ``query_tags`` are then layered on top, so they win on key
    collisions.

    :param context: The task execution context, or ``None``.
    :return: The merged tags, or ``None`` when nothing was collected.
    """
    if self.include_airflow_query_tags and context is not None:
        merged: dict[str, str | None] = dict(_get_airflow_query_tags(context))
    else:
        merged = {}
    merged.update(self.query_tags)
    return merged if merged else None
+
def execute(self, context: Context) -> Any:
    """
    Render the ``COPY INTO`` statement and run it through the Databricks SQL hook.

    The rendered SQL is stored on ``self._sql`` and logged. The hook gets the
    resolved query tags before the statement is run.
    """
    statement = self._create_sql_query()
    self._sql = statement
    self.log.info("Executing: %s", statement)
    hook = self._get_hook()
    tags = self._get_query_tags(context)
    hook.query_tags = tags
    hook.run(statement)
def on_kill(self) -> None:
diff --git
a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
index b4501ef1d43..2036dca97c3 100644
---
a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
+++
b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_partition.py
@@ -130,7 +130,7 @@ class DatabricksPartitionSensor(BaseSensorOperator):
self.http_headers,
self.catalog,
self.schema,
- self.caller,
+ caller=self.caller,
**self.client_parameters,
**self.hook_params,
)
diff --git
a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
index 68a44a20fec..6ab52df67b5 100644
---
a/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
+++
b/providers/databricks/src/airflow/providers/databricks/sensors/databricks_sql.py
@@ -109,7 +109,7 @@ class DatabricksSqlSensor(BaseSensorOperator):
self.http_headers,
self.catalog,
self.schema,
- self.caller,
+ caller=self.caller,
**self.client_parameters,
**self.hook_params,
)
diff --git
a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
index f3f053d443c..cd3c00e2839 100644
--- a/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
+++ b/providers/databricks/tests/unit/databricks/hooks/test_databricks_sql.py
@@ -32,7 +32,12 @@ from databricks.sql.types import Row
from airflow.models import Connection
from airflow.providers.common.compat.sdk import AirflowException,
AirflowOptionalProviderFeatureException
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
-from airflow.providers.databricks.hooks.databricks_sql import
DatabricksSqlHook, create_timeout_thread
+from airflow.providers.databricks.hooks.databricks_sql import (
+ DatabricksSqlHook,
+ _format_query_tag_value,
+ _format_query_tags,
+ create_timeout_thread,
+)
TASK_ID = "databricks-sql-operator"
DEFAULT_CONN_ID = "databricks_default"
@@ -792,3 +797,141 @@ class TestGetSqlEndpointByName:
hook = DatabricksSqlHook(sql_endpoint_name="Test")
with pytest.raises(RuntimeError, match="Can't list Databricks SQL
warehouses"):
hook._get_sql_endpoint_by_name("Test")
+
+
@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_passes_query_tags_via_session_configuration(mock_connect, mock_get_requests):
    """query_tags must be injected into session_configuration['QUERY_TAGS'], not sql.connect(query_tags=)."""
    hook = DatabricksSqlHook(
        databricks_conn_id=DEFAULT_CONN_ID,
        http_path=HTTP_PATH,
        query_tags={"airflow_dag_id": "dag_1", "airflow_task_id": "task_1"},
    )

    hook.get_conn()

    mock_connect.assert_called_once()
    connect_kwargs = mock_connect.call_args.kwargs
    session_cfg = connect_kwargs["session_configuration"]
    assert session_cfg is not None
    assert "QUERY_TAGS" in session_cfg
    query_tags_str = session_cfg["QUERY_TAGS"]
    assert "airflow_dag_id:dag_1" in query_tags_str
    assert "airflow_task_id:task_1" in query_tags_str
    # The tags must not leak into sql.connect() as a separate keyword argument.
    assert "query_tags" not in connect_kwargs
+
+
@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_merges_query_tags_with_existing_session_configuration(mock_connect, mock_get_requests):
    """Existing QUERY_TAGS in session_configuration must be preserved and new tags appended."""
    hook = DatabricksSqlHook(
        databricks_conn_id=DEFAULT_CONN_ID,
        http_path=HTTP_PATH,
        session_configuration={"QUERY_TAGS": "existing_tag:existing_value"},
        query_tags={"airflow_dag_id": "dag_1"},
    )

    hook.get_conn()

    mock_connect.assert_called_once()
    session_cfg = mock_connect.call_args.kwargs["session_configuration"]
    query_tags_str = session_cfg["QUERY_TAGS"]
    # Both the pre-existing tag and the hook-level tag must survive the merge.
    assert "existing_tag:existing_value" in query_tags_str
    assert "airflow_dag_id:dag_1" in query_tags_str
+
+
@mock.patch("airflow.providers.databricks.hooks.databricks_sql.sql.connect")
def test_get_conn_no_query_tags(mock_connect, mock_get_requests):
    """When no query_tags are provided, session_configuration should not gain a QUERY_TAGS key."""
    hook = DatabricksSqlHook(
        databricks_conn_id=DEFAULT_CONN_ID,
        http_path=HTTP_PATH,
    )

    hook.get_conn()

    mock_connect.assert_called_once()
    session_cfg = mock_connect.call_args.kwargs.get("session_configuration")
    # Either no session configuration at all, or one without an injected QUERY_TAGS entry.
    assert session_cfg is None or "QUERY_TAGS" not in session_cfg
+
+
class TestFormatQueryTags:
    """Unit tests for the QUERY_TAGS serialization helpers."""

    def test_simple_values(self):
        # Pairs are joined with "," in dict insertion order.
        assert _format_query_tags({"dag_id": "my_dag", "task_id": "my_task"}) == "dag_id:my_dag,task_id:my_task"

    def test_none_values_omitted(self):
        assert _format_query_tags({"dag_id": "my_dag", "map_index": None}) == "dag_id:my_dag"

    def test_empty_dict_returns_empty_string(self):
        assert _format_query_tags({}) == ""

    def test_all_none_values_returns_empty_string(self):
        assert _format_query_tags({"map_index": None}) == ""

    def test_value_escaping_comma(self):
        assert _format_query_tag_value("a,b") == "a\\,b"

    def test_value_escaping_colon(self):
        assert _format_query_tag_value("a:b") == "a\\:b"

    def test_value_escaping_backslash(self):
        assert _format_query_tag_value("a\\b") == "a\\\\b"

    def test_value_truncated_at_128_chars(self):
        assert _format_query_tag_value("x" * 200) == "x" * 128

    def test_value_truncated_before_escaping(self):
        # Truncation applies to the raw value, so escaping can push the result past 128 characters.
        assert _format_query_tag_value("," * 200) == "\\," * 128

    def test_format_query_tags_escapes_values(self):
        tags = {"airflow_dag_id": "dag:1", "airflow_run_id": "run,2"}
        assert _format_query_tags(tags) == "airflow_dag_id:dag\\:1,airflow_run_id:run\\,2"
+
+
class TestDatabricksSqlHookQueryTagsParamOrder:
    """Check that ``query_tags`` sits after ``caller`` so positional callers keep working."""

    # Patched so the hook can be initialised without resolving a real connection.
    _BASE_INIT_TARGET = "airflow.providers.databricks.hooks.databricks_sql.BaseDatabricksHook.__init__"

    def test_query_tags_keyword_sets_field(self):
        """A ``query_tags`` keyword argument is stored on the hook instance."""
        with patch(self._BASE_INIT_TARGET, return_value=None) as base_init:
            hook = DatabricksSqlHook.__new__(DatabricksSqlHook)
            DatabricksSqlHook.__init__(hook, DEFAULT_CONN_ID, query_tags={"key": "val"})
        assert hook.query_tags == {"key": "val"}
        # The default caller name is what gets forwarded to BaseDatabricksHook.__init__.
        assert base_init.call_args.kwargs.get("caller") == "DatabricksSqlHook"

    def test_caller_positional_not_confused_with_query_tags(self):
        """A caller passed as the 8th positional argument must not land in ``query_tags``."""
        # Positional order: databricks_conn_id, http_path, sql_endpoint_name,
        # session_configuration, http_headers, catalog, schema, caller.
        positional = (DEFAULT_CONN_ID, None, None, None, None, None, None, "CustomCaller")
        with patch(self._BASE_INIT_TARGET, return_value=None) as base_init:
            hook = DatabricksSqlHook.__new__(DatabricksSqlHook)
            DatabricksSqlHook.__init__(hook, *positional)
        # caller reaches BaseDatabricksHook.__init__ while query_tags keeps its default.
        assert base_init.call_args.kwargs.get("caller") == "CustomCaller"
        assert hook.query_tags is None
diff --git
a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
index e216c56bea2..bfd6d89437b 100644
---
a/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
+++
b/providers/databricks/tests/unit/databricks/operators/test_databricks_sql.py
@@ -20,13 +20,17 @@ from __future__ import annotations
import json
import os
from collections import namedtuple
+from unittest import mock
from unittest.mock import patch
import pytest
from databricks.sql.types import Row
from airflow.providers.common.sql.hooks.handlers import fetch_all_handler
-from airflow.providers.databricks.operators.databricks_sql import
DatabricksSqlOperator
+from airflow.providers.databricks.operators.databricks_sql import (
+ DatabricksSqlOperator,
+ _get_airflow_query_tags,
+)
DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
@@ -453,3 +457,162 @@ def test_parse_gcs_path():
bucket, object_name =
op._parse_gcs_path("gs://my-bucket/path/to/file.parquet")
assert bucket == "my-bucket"
assert object_name == "path/to/file.parquet"
+
+
class TestDatabricksSqlOperatorQueryTags:
    """Tests for query tags support in DatabricksSqlOperator."""

    @staticmethod
    def _mock_task_instance(dag_id, task_id, run_id, try_number, map_index):
        """Build a TaskInstance stand-in exposing only the attributes used for query tags."""
        ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"])
        ti.dag_id = dag_id
        ti.task_id = task_id
        ti.run_id = run_id
        ti.try_number = try_number
        ti.map_index = map_index
        return ti

    def test_get_airflow_query_tags_returns_empty_dict_without_task_instance(self):
        """_get_airflow_query_tags must return {} when context has no 'ti' key."""
        assert _get_airflow_query_tags({}) == {}

    def test_get_query_tags_with_none_context_returns_custom_tags_only(self):
        """When context is None, only custom tags are returned (no Airflow tags)."""
        op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", query_tags={"custom_tag": "custom_value"})
        assert op._get_query_tags(None) == {"custom_tag": "custom_value"}

    def test_get_query_tags_with_none_context_and_no_custom_tags_returns_none(self):
        """When context is None and no custom tags, None is returned."""
        op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1")
        assert op._get_query_tags(None) is None

    def test_get_query_tags_with_disabled_airflow_tags(self):
        """When include_airflow_query_tags=False, only custom tags are returned."""
        op = DatabricksSqlOperator(
            task_id=TASK_ID,
            sql="SELECT 1",
            query_tags={"custom_tag": "val"},
            include_airflow_query_tags=False,
        )
        assert op._get_query_tags({"ti": object()}) == {"custom_tag": "val"}

    def test_get_query_tags_with_airflow_context(self):
        """When context is provided and include_airflow_query_tags=True, Airflow tags are included."""
        op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", query_tags={"custom_tag": "custom_value"})
        context = {"ti": self._mock_task_instance("test_dag", "test_task", "test_run", 1, -1)}

        result = op._get_query_tags(context)

        assert result is not None
        expected = {
            "airflow_dag_id": "test_dag",
            "airflow_task_id": "test_task",
            "airflow_run_id": "test_run",
            "airflow_try_number": "1",
            "airflow_map_index": "-1",
            "custom_tag": "custom_value",
        }
        for key, value in expected.items():
            assert result[key] == value

    def test_execute_sets_query_tags_on_hook(self):
        """execute() sets query_tags on the hook before delegating to SQLExecuteQueryOperator."""
        with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as hook_cls:
            fake_hook = hook_cls.return_value
            fake_hook.run.return_value = []
            fake_hook.descriptions = [[]]

            op = DatabricksSqlOperator(
                task_id=TASK_ID,
                sql="SELECT 1",
                query_tags={"env": "test"},
                include_airflow_query_tags=False,
            )
            op.execute(None)

            assert fake_hook.query_tags == {"env": "test"}

    def test_custom_tags_override_airflow_tags_on_key_collision(self):
        """Custom query_tags override Airflow tags when the same key is used."""
        op = DatabricksSqlOperator(task_id=TASK_ID, sql="SELECT 1", query_tags={"airflow_dag_id": "overridden"})
        context = {"ti": self._mock_task_instance("original_dag", "task", "run", 1, -1)}

        result = op._get_query_tags(context)

        assert result is not None
        assert result["airflow_dag_id"] == "overridden"
+
+
class TestDatabricksCopyIntoOperatorQueryTags:
    """Tests for query tags support in DatabricksCopyIntoOperator."""

    def _make_op(self, **kwargs):
        """Create a CSV COPY INTO operator; extra keyword arguments are forwarded."""
        from airflow.providers.databricks.operators.databricks_sql import DatabricksCopyIntoOperator

        base_kwargs = {
            "task_id": TASK_ID,
            "table_name": "test_table",
            "file_location": "s3://bucket/path",
            "file_format": "CSV",
        }
        return DatabricksCopyIntoOperator(**base_kwargs, **kwargs)

    @staticmethod
    def _mock_task_instance(dag_id, task_id, run_id, try_number, map_index):
        """Build a TaskInstance stand-in exposing only the attributes used for query tags."""
        ti = mock.MagicMock(spec=["dag_id", "task_id", "run_id", "try_number", "map_index"])
        ti.dag_id = dag_id
        ti.task_id = task_id
        ti.run_id = run_id
        ti.try_number = try_number
        ti.map_index = map_index
        return ti

    def test_get_query_tags_with_none_context_returns_custom_tags_only(self):
        op = self._make_op(query_tags={"custom": "value"})
        assert op._get_query_tags(None) == {"custom": "value"}

    def test_get_query_tags_with_none_context_and_no_custom_tags_returns_none(self):
        op = self._make_op()
        assert op._get_query_tags(None) is None

    def test_get_query_tags_with_airflow_context(self):
        op = self._make_op(query_tags={"env": "staging"})
        context = {"ti": self._mock_task_instance("copy_dag", "copy_task", "run_1", 2, 0)}

        result = op._get_query_tags(context)

        assert result is not None
        assert result["airflow_dag_id"] == "copy_dag"
        assert result["airflow_task_id"] == "copy_task"
        assert result["env"] == "staging"

    def test_execute_sets_query_tags_on_hook(self):
        with patch("airflow.providers.databricks.operators.databricks_sql.DatabricksSqlHook") as hook_cls:
            fake_hook = hook_cls.return_value
            op = self._make_op(query_tags={"env": "prod"}, include_airflow_query_tags=False)
            op.execute(None)

            assert fake_hook.query_tags == {"env": "prod"}