This is an automated email from the ASF dual-hosted git repository.

rnhttr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 44250748ebd Adding `DatabricksSQLStatementSensor` Sensor with Deferrability (#49516)
44250748ebd is described below

commit 44250748ebdb1c6c940058e3416fb87de26c85e3
Author: Jake Roach <116606359+jroachgol...@users.noreply.github.com>
AuthorDate: Wed May 14 21:40:11 2025 -0400

    Adding `DatabricksSQLStatementSensor` Sensor with Deferrability (#49516)
    
    * Adding mixin, Sensor, and updated Operator to create a draft PR
    
    * Prepping to update branch
    
    * Adding tests for DatabricksSQLStatementSensor
    
    * Added unit tests and examples
    
    * Added documentation, pointed ignore tests on mixin
    
    * Implementing changes requested in PR
    
    * Tweaking example DAG
    
    * Adding validation, updating docs
    
    * Updated example DAG
    
    * Adding test for Databricks mixins
    
    * Updated test_project_structure.py
---
 .../databricks/docs/operators/sql_statements.rst   |  46 +++++
 providers/databricks/provider.yaml                 |   1 +
 .../providers/databricks/get_provider_info.py      |   1 +
 .../providers/databricks/operators/databricks.py   |  87 +--------
 .../providers/databricks/sensors/databricks.py     | 162 ++++++++++++++++
 .../airflow/providers/databricks/utils/mixins.py   | 193 +++++++++++++++++++
 .../databricks/example_databricks_sensors.py       |  14 ++
 .../unit/databricks/sensors/test_databricks.py     | 208 +++++++++++++++++++++
 .../tests/unit/databricks/utils/test_mixins.py     | 127 +++++++++++++
 9 files changed, 756 insertions(+), 83 deletions(-)

diff --git a/providers/databricks/docs/operators/sql_statements.rst b/providers/databricks/docs/operators/sql_statements.rst
index 73b7948a144..9d314a238da 100644
--- a/providers/databricks/docs/operators/sql_statements.rst
+++ b/providers/databricks/docs/operators/sql_statements.rst
@@ -55,3 +55,49 @@ An example usage of the ``DatabricksSQLStatementsOperator`` is as follows:
     :language: python
     :start-after: [START howto_operator_sql_statements]
     :end-before: [END howto_operator_sql_statements]
+
+
+.. _howto/sensor:DatabricksSQLStatementsSensor:
+
+DatabricksSQLStatementsSensor
+===============================
+
+Use the :class:`~airflow.providers.databricks.sensors.databricks.DatabricksSQLStatementsSensor` to either submit a
+Databricks SQL Statement to Databricks using the
+`Databricks SQL Statement Execution API <https://docs.databricks.com/api/workspace/statementexecution>`_, or pass
+a Statement ID to the Sensor and wait for the query to terminate execution.
+
+
+Using the Sensor
+------------------
+
+The ``DatabricksSQLStatementsSensor`` can be used in one of two ways. It can submit a SQL statement to Databricks using
+the `/api/2.0/sql/statements/ <https://docs.databricks.com/api/workspace/statementexecution/executestatement>`_
+endpoint and wait for it to complete. Alternatively, it can take the Statement ID of an already-submitted SQL Statement
+and handle the response to that execution.
+
+It supports configurable execution parameters such as warehouse selection, catalog, schema, and parameterized queries.
+The Sensor can either synchronously poll for query completion or run in a deferrable mode for improved efficiency.
+
+The only required parameters for using the Sensor are:
+
+* One of ``statement`` or ``statement_id`` - The SQL statement to execute, or the ID of a statement that has already
+  been submitted. The statement can optionally be parameterized; see ``parameters``.
+* ``warehouse_id`` - The ID of the Databricks SQL warehouse on which to execute the statement.
+
+All other parameters are optional and are described in the documentation for ``DatabricksSQLStatementsSensor``,
+including but not limited to:
+
+* ``catalog``
+* ``schema``
+* ``parameters``
+
+Examples
+--------
+
+An example usage of the ``DatabricksSQLStatementsSensor`` is as follows:
+
+.. exampleinclude:: /../../databricks/tests/system/databricks/example_databricks_sensors.py
+    :language: python
+    :start-after: [START howto_sensor_databricks_sql_statement]
+    :end-before: [END howto_sensor_databricks_sql_statement]
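
For reference, a minimal sketch of the second usage path described in the new documentation, waiting on a
statement that was submitted elsewhere, might look like the following. The connection ID, warehouse ID, and
statement ID are placeholder values, not taken from this commit:

    from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor

    # Wait for a previously submitted SQL statement to reach a terminal state. In this mode the
    # sensor only polls the statement's state; it does not submit anything itself.
    wait_for_statement = DatabricksSQLStatementsSensor(
        task_id="wait_for_statement",
        databricks_conn_id="databricks_default",
        warehouse_id="my-warehouse-id",            # placeholder
        statement_id="my-existing-statement-id",   # placeholder: ID returned by an earlier submission
        deferrable=True,                           # optional: hand the wait off to the triggerer
    )
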
diff --git a/providers/databricks/provider.yaml b/providers/databricks/provider.yaml
index 8bb7812f25f..18b51930ed7 100644
--- a/providers/databricks/provider.yaml
+++ b/providers/databricks/provider.yaml
@@ -143,6 +143,7 @@ triggers:
 sensors:
   - integration-name: Databricks
     python-modules:
+      - airflow.providers.databricks.sensors.databricks
       - airflow.providers.databricks.sensors.databricks_sql
       - airflow.providers.databricks.sensors.databricks_partition
 
diff --git a/providers/databricks/src/airflow/providers/databricks/get_provider_info.py b/providers/databricks/src/airflow/providers/databricks/get_provider_info.py
index aa7a31757ca..d0a5d349598 100644
--- a/providers/databricks/src/airflow/providers/databricks/get_provider_info.py
+++ b/providers/databricks/src/airflow/providers/databricks/get_provider_info.py
@@ -107,6 +107,7 @@ def get_provider_info():
             {
                 "integration-name": "Databricks",
                 "python-modules": [
+                    "airflow.providers.databricks.sensors.databricks",
                     "airflow.providers.databricks.sensors.databricks_sql",
                     
"airflow.providers.databricks.sensors.databricks_partition",
                 ],
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
index abafd217009..66fdd7556b5 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py
@@ -34,7 +34,6 @@ from airflow.providers.databricks.hooks.databricks import (
     DatabricksHook,
     RunLifeCycleState,
     RunState,
-    SQLStatementState,
 )
 from airflow.providers.databricks.operators.databricks_workflow import (
     DatabricksWorkflowTaskGroup,
@@ -46,9 +45,9 @@ from airflow.providers.databricks.plugins.databricks_workflow import (
 )
 from airflow.providers.databricks.triggers.databricks import (
     DatabricksExecutionTrigger,
-    DatabricksSQLStatementExecutionTrigger,
 )
 from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event
+from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
 from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
 
 if TYPE_CHECKING:
@@ -978,7 +977,7 @@ class DatabricksRunNowOperator(BaseOperator):
             self.log.error("Error: Task: %s with invalid run_id was requested 
to be cancelled.", self.task_id)
 
 
-class DatabricksSQLStatementsOperator(BaseOperator):
+class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator):
     """
     Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint.
 
@@ -1073,59 +1072,6 @@ class DatabricksSQLStatementsOperator(BaseOperator):
             caller=caller,
         )
 
-    def _handle_operator_execution(self) -> None:
-        end_time = time.time() + self.timeout
-        while end_time > time.time():
-            statement_state = 
self._hook.get_sql_statement_state(self.statement_id)
-            if statement_state.is_terminal:
-                if statement_state.is_successful:
-                    self.log.info("%s completed successfully.", self.task_id)
-                    return
-                error_message = (
-                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
-                    f"and with the error code {statement_state.error_code} "
-                    f"and error message {statement_state.error_message}"
-                )
-                raise AirflowException(error_message)
-
-            self.log.info("%s in run state: %s", self.task_id, 
statement_state.state)
-            self.log.info("Sleeping for %s seconds.", 
self.polling_period_seconds)
-            time.sleep(self.polling_period_seconds)
-
-        self._hook.cancel_sql_statement(self.statement_id)
-        raise AirflowException(
-            f"{self.task_id} timed out after {self.timeout} seconds with 
state: {statement_state.state}",
-        )
-
-    def _handle_deferrable_operator_execution(self) -> None:
-        statement_state = self._hook.get_sql_statement_state(self.statement_id)
-        end_time = time.time() + self.timeout
-        if not statement_state.is_terminal:
-            if not self.statement_id:
-                raise AirflowException("Failed to retrieve statement_id after 
submitting SQL statement.")
-            self.defer(
-                trigger=DatabricksSQLStatementExecutionTrigger(
-                    statement_id=self.statement_id,
-                    databricks_conn_id=self.databricks_conn_id,
-                    end_time=end_time,
-                    polling_period_seconds=self.polling_period_seconds,
-                    retry_limit=self.databricks_retry_limit,
-                    retry_delay=self.databricks_retry_delay,
-                    retry_args=self.databricks_retry_args,
-                ),
-                method_name=DEFER_METHOD_NAME,
-            )
-        else:
-            if statement_state.is_successful:
-                self.log.info("%s completed successfully.", self.task_id)
-            else:
-                error_message = (
-                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
-                    f"and with the error code {statement_state.error_code} "
-                    f"and error message {statement_state.error_message}"
-                )
-                raise AirflowException(error_message)
-
     def execute(self, context: Context):
         json = {
             "statement": self.statement,
@@ -1146,34 +1092,9 @@ class DatabricksSQLStatementsOperator(BaseOperator):
         if not self.wait_for_termination:
             return
         if self.deferrable:
-            self._handle_deferrable_operator_execution()
-        else:
-            self._handle_operator_execution()
-
-    def on_kill(self):
-        if self.statement_id:
-            self._hook.cancel_sql_statement(self.statement_id)
-            self.log.info(
-                "Task: %s with statement ID: %s was requested to be 
cancelled.",
-                self.task_id,
-                self.statement_id,
-            )
+            self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME)  # type: ignore[misc]
         else:
-            self.log.error(
-                "Error: Task: %s with invalid statement_id was requested to be 
cancelled.", self.task_id
-            )
-
-    def execute_complete(self, context: dict | None, event: dict):
-        statement_state = SQLStatementState.from_json(event["state"])
-        error = event["error"]
-        statement_id = event["statement_id"]
-
-        if statement_state.is_successful:
-            self.log.info("SQL Statement with ID %s completed successfully.", 
statement_id)
-            return
-
-        error_message = f"SQL Statement execution failed with terminal state: 
{statement_state} and with the error {error}"
-        raise AirflowException(error_message)
+            self._handle_execution()  # type: ignore[misc]
 
 
 class DatabricksTaskBaseOperator(BaseOperator, ABC):
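
Since the operator's public interface is unchanged by this refactor, a task definition like the following
(a minimal sketch with placeholder connection, warehouse, and statement values) continues to work, now
backed by the shared mixin:

    from airflow.providers.databricks.operators.databricks import DatabricksSQLStatementsOperator

    run_statement = DatabricksSQLStatementsOperator(
        task_id="run_statement",
        databricks_conn_id="databricks_default",
        warehouse_id="my-warehouse-id",   # placeholder
        statement="select 1",             # placeholder statement
        deferrable=True,  # polling is handed to DatabricksSQLStatementExecutionTrigger instead of a worker
    )
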
diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py
new file mode 100644
index 00000000000..f417291f2c9
--- /dev/null
+++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py
@@ -0,0 +1,162 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.configuration import conf
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
+from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME
+from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
+from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS
+
+if AIRFLOW_V_3_0_PLUS:
+    from airflow.sdk import BaseSensorOperator
+else:
+    from airflow.sensors.base import BaseSensorOperator
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+XCOM_STATEMENT_ID_KEY = "statement_id"
+
+
+class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOperator):
+    """Submit or monitor a Databricks SQL statement and wait for it to reach a terminal state."""
+
+    template_fields: Sequence[str] = (
+        "databricks_conn_id",
+        "statement",
+        "statement_id",
+    )
+    template_ext: Sequence[str] = (".json-tpl",)
+    ui_color = "#1CB1C2"
+    ui_fgcolor = "#fff"
+
+    def __init__(
+        self,
+        warehouse_id: str,
+        *,
+        statement: str | None = None,
+        statement_id: str | None = None,
+        catalog: str | None = None,
+        schema: str | None = None,
+        parameters: list[dict[str, Any]] | None = None,
+        databricks_conn_id: str = "databricks_default",
+        polling_period_seconds: int = 30,
+        databricks_retry_limit: int = 3,
+        databricks_retry_delay: int = 1,
+        databricks_retry_args: dict[Any, Any] | None = None,
+        do_xcom_push: bool = True,
+        wait_for_termination: bool = True,
+        timeout: float = 3600,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ):
+        # Validate that exactly one of statement and statement_id is provided
+        if statement and statement_id:
+            raise AirflowException("Cannot provide both statement and statement_id.")
+        if not statement and not statement_id:
+            raise AirflowException("One of either statement or statement_id must be provided.")
+
+        if not warehouse_id:
+            raise AirflowException("warehouse_id must be provided.")
+
+        super().__init__(**kwargs)
+
+        self.statement = statement
+        self.statement_id = statement_id
+        self.warehouse_id = warehouse_id
+        self.catalog = catalog
+        self.schema = schema
+        self.parameters = parameters
+        self.databricks_conn_id = databricks_conn_id
+        self.polling_period_seconds = polling_period_seconds
+        self.databricks_retry_limit = databricks_retry_limit
+        self.databricks_retry_delay = databricks_retry_delay
+        self.databricks_retry_args = databricks_retry_args
+        self.wait_for_termination = wait_for_termination
+        self.deferrable = deferrable
+        self.timeout = timeout
+        self.do_xcom_push = do_xcom_push
+
+    @cached_property
+    def _hook(self):
+        return self._get_hook(caller="DatabricksSQLStatementsSensor")
+
+    def _get_hook(self, caller: str) -> DatabricksHook:
+        return DatabricksHook(
+            self.databricks_conn_id,
+            retry_limit=self.databricks_retry_limit,
+            retry_delay=self.databricks_retry_delay,
+            retry_args=self.databricks_retry_args,
+            caller=caller,
+        )
+
+    def execute(self, context: Context):
+        if not self.statement_id:
+            # No statement_id was provided, so submit the statement ourselves
+            json = {
+                "statement": self.statement,
+                "warehouse_id": self.warehouse_id,
+                "catalog": self.catalog,
+                "schema": self.schema,
+                "parameters": self.parameters,
+                "wait_timeout": "0s",
+            }
+
+            self.statement_id = self._hook.post_sql_statement(json)
+            self.log.info("SQL Statement submitted with statement_id: %s", 
self.statement_id)
+
+        if self.do_xcom_push and context is not None:
+            context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, 
value=self.statement_id)
+
+        # If we are not waiting for the query to complete, return immediately. In that case, a
+        # recommendation to use the DatabricksSQLStatementsOperator is logged instead.
+        if not self.wait_for_termination:
+            self.log.info(
+                "If setting wait_for_termination = False, consider using the DatabricksSQLStatementsOperator instead."
+            )
+            return
+
+        if self.deferrable:
+            self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME)  # type: ignore[misc]
+
+    def poke(self, context: Context):
+        """
+        Handle non-deferrable Sensor execution.
+
+        :param context: (Context)
+        :return: (bool)
+        """
+        # This closely mirrors the logic in execute_complete
+        statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+
+        if statement_state.is_running:
+            self.log.info("SQL Statement with ID %s is running", self.statement_id)
+            return False
+        if statement_state.is_successful:
+            self.log.info("SQL Statement with ID %s completed successfully.", self.statement_id)
+            return True
+        raise AirflowException(
+            f"SQL Statement with ID {self.statement_id} failed with error: {statement_state.error_message}"
+        )
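
As a usage note, when the sensor submits the statement itself it pushes the resulting statement ID to XCom
under the key "statement_id" (XCOM_STATEMENT_ID_KEY), as long as do_xcom_push is left at its default of
True. A minimal sketch of a downstream task picking it up follows; the task ID matches the example DAG in
this commit, the reporting task is illustrative only, and the Airflow 2 import path for
get_current_context is assumed:

    from airflow.decorators import task
    from airflow.operators.python import get_current_context

    @task
    def report_statement_id() -> None:
        # Pull the statement ID the sensor pushed after submitting the query.
        ti = get_current_context()["ti"]
        statement_id = ti.xcom_pull(task_ids="sql_statement_sensor_task", key="statement_id")
        print(f"Databricks SQL statement submitted: {statement_id}")

    sql_statement_sensor >> report_statement_id()
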
diff --git a/providers/databricks/src/airflow/providers/databricks/utils/mixins.py b/providers/databricks/src/airflow/providers/databricks/utils/mixins.py
new file mode 100644
index 00000000000..97d8c14e5b6
--- /dev/null
+++ b/providers/databricks/src/airflow/providers/databricks/utils/mixins.py
@@ -0,0 +1,193 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+from __future__ import annotations
+
+import time
+from logging import Logger
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Protocol,
+)
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState
+from airflow.providers.databricks.triggers.databricks import DatabricksSQLStatementExecutionTrigger
+
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
+
+class GetHookHasFields(Protocol):
+    """Protocol for get_hook method."""
+
+    databricks_conn_id: str
+    databricks_retry_args: dict | None
+    databricks_retry_delay: int
+    databricks_retry_limit: int
+
+
+class HandleExecutionHasFields(Protocol):
+    """Protocol for _handle_execution method."""
+
+    _hook: DatabricksHook
+    log: Logger
+    polling_period_seconds: int
+    task_id: str
+    timeout: int
+    statement_id: str
+
+
+class HandleDeferrableExecutionHasFields(Protocol):
+    """Protocol for _handle_deferrable_execution method."""
+
+    _hook: DatabricksHook
+    databricks_conn_id: str
+    databricks_retry_args: dict[Any, Any] | None
+    databricks_retry_delay: int
+    databricks_retry_limit: int
+    defer: Any
+    log: Logger
+    polling_period_seconds: int
+    statement_id: str
+    task_id: str
+    timeout: int
+
+
+class ExecuteCompleteHasFields(Protocol):
+    """Protocol for execute_complete method."""
+
+    statement_id: str
+    _hook: DatabricksHook
+    log: Logger
+
+
+class OnKillHasFields(Protocol):
+    """Protocol for on_kill method."""
+
+    _hook: DatabricksHook
+    log: Logger
+    statement_id: str
+    task_id: str
+
+
+class DatabricksSQLStatementsMixin:
+    """
+    Mixin used by both the DatabricksSQLStatementsOperator and the DatabricksSQLStatementsSensor, providing:
+
+        - _handle_execution (renamed from _handle_operator_execution)
+        - _handle_deferrable_execution (renamed from _handle_deferrable_operator_execution)
+        - execute_complete
+        - on_kill
+    """
+
+    def _handle_execution(self: HandleExecutionHasFields) -> None:
+        """Execute a SQL statement in non-deferrable mode."""
+        # Determine the time at which the task will time out. Note that statement_state is only
+        # assigned inside the while-loop below
+        end_time = time.time() + self.timeout
+
+        while end_time > time.time():
+            statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+
+            if statement_state.is_terminal:
+                if statement_state.is_successful:
+                    self.log.info("%s completed successfully.", self.task_id)
+                    return
+
+                error_message = (
+                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+            self.log.info("%s in run state: %s", self.task_id, 
statement_state.state)
+            self.log.info("Sleeping for %s seconds.", 
self.polling_period_seconds)
+            time.sleep(self.polling_period_seconds)
+
+        # Once the timeout is exceeded, the query is cancelled. This is an important step; if a query takes
+        # too long, it needs to be killed. Otherwise, "zombie" queries that are no longer being orchestrated
+        # may be left running
+        self._hook.cancel_sql_statement(self.statement_id)
+        raise AirflowException(
+            f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state}",
+        )
+
+    def _handle_deferrable_execution(
+        self: HandleDeferrableExecutionHasFields, defer_method_name: str = "execute_complete"
+    ) -> None:
+        """Execute a SQL statement in deferrable mode."""
+        statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id)
+        end_time: float = time.time() + self.timeout
+
+        if not statement_state.is_terminal:
+            # If the query is still running and there is no statement_id, this is somewhat of a "zombie"
+            # query, and should throw an exception
+            if not self.statement_id:
+                raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.")
+
+            self.defer(
+                trigger=DatabricksSQLStatementExecutionTrigger(
+                    statement_id=self.statement_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    end_time=end_time,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                ),
+                method_name=defer_method_name,
+            )
+
+        else:
+            if statement_state.is_successful:
+                self.log.info("%s completed successfully.", self.task_id)
+            else:
+                error_message = (
+                    f"{self.task_id} failed with terminal state: 
{statement_state.state} "
+                    f"and with the error code {statement_state.error_code} "
+                    f"and error message {statement_state.error_message}"
+                )
+                raise AirflowException(error_message)
+
+    def execute_complete(self: ExecuteCompleteHasFields, context: Context, event: dict):
+        statement_state = SQLStatementState.from_json(event["state"])
+        error = event["error"]
+        statement_id = event["statement_id"]
+
+        if statement_state.is_successful:
+            self.log.info("SQL Statement with ID %s completed successfully.", 
statement_id)
+            return
+
+        error_message = f"SQL Statement execution failed with terminal state: 
{statement_state} and with the error {error}"
+        raise AirflowException(error_message)
+
+    def on_kill(self: OnKillHasFields) -> None:
+        if self.statement_id:
+            self._hook.cancel_sql_statement(self.statement_id)
+            self.log.info(
+                "Task: %s with statement ID: %s was requested to be 
cancelled.",
+                self.task_id,
+                self.statement_id,
+            )
+        else:
+            self.log.error(
+                "Error: Task: %s with invalid statement_id was requested to be 
cancelled.", self.task_id
+            )
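
The mixin only relies on the attributes spelled out in the Protocol classes above, so any operator- or
sensor-like class that defines them can reuse the polling logic. A minimal sketch follows; the class and
its attribute values are illustrative, not part of this commit:

    from airflow.models import BaseOperator
    from airflow.providers.databricks.hooks.databricks import DatabricksHook
    from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin

    class MyStatementWaiter(DatabricksSQLStatementsMixin, BaseOperator):
        """Illustrative operator that waits on an existing statement_id via the mixin's polling logic."""

        def __init__(self, statement_id: str, databricks_conn_id: str = "databricks_default", **kwargs):
            super().__init__(**kwargs)
            self.statement_id = statement_id
            self.databricks_conn_id = databricks_conn_id
            # Attributes that _handle_execution / _handle_deferrable_execution expect (see the Protocols).
            self.polling_period_seconds = 30
            self.timeout = 3600
            self.databricks_retry_limit = 3
            self.databricks_retry_delay = 1
            self.databricks_retry_args = None

        @property
        def _hook(self) -> DatabricksHook:
            return DatabricksHook(self.databricks_conn_id, caller="MyStatementWaiter")

        def execute(self, context):
            self._handle_execution()  # poll until the statement reaches a terminal state
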
diff --git a/providers/databricks/tests/system/databricks/example_databricks_sensors.py b/providers/databricks/tests/system/databricks/example_databricks_sensors.py
index cf676183e6c..177ea8ce293 100644
--- a/providers/databricks/tests/system/databricks/example_databricks_sensors.py
+++ b/providers/databricks/tests/system/databricks/example_databricks_sensors.py
@@ -22,6 +22,7 @@ import textwrap
 from datetime import datetime
 
 from airflow import DAG
+from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor
 from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor
 from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor
 
@@ -30,6 +31,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
 # [DAG name to be shown on Airflow UI]
 DAG_ID = "example_databricks_sensor"
 
+
 with DAG(
     dag_id=DAG_ID,
     schedule="@daily",
@@ -67,6 +69,18 @@ with DAG(
     )
     # [END howto_sensor_databricks_sql]
 
+    # [START howto_sensor_databricks_sql_statement]
+    # Example of using the DatabricksSQLStatementsSensor to wait for a query
+    # to successfully run.
+    sql_statement_sensor = DatabricksSQLStatementsSensor(
+        task_id="sql_statement_sensor_task",
+        databricks_conn_id=connection_id,
+        warehouse_id="warehouse_id",
+        statement="select * from default.my_airflow_table",
+        # deferrable=True, # For using the sensor in deferrable mode
+    )
+    # [END howto_sensor_databricks_sql_statement]
+
     # [START howto_sensor_databricks_partition]
     # Example of using the Databricks Partition Sensor to check the presence
     # of the specified partition(s) in a table.
diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
new file mode 100644
index 00000000000..dc2521a21bc
--- /dev/null
+++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py
@@ -0,0 +1,208 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+
+import pytest
+
+from airflow.exceptions import AirflowException, TaskDeferred
+from airflow.providers.databricks.hooks.databricks import SQLStatementState
+from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor
+from airflow.providers.databricks.triggers.databricks import DatabricksSQLStatementExecutionTrigger
+
+DEFAULT_CONN_ID = "databricks_default"
+STATEMENT = "select * from test.test;"
+STATEMENT_ID = "statement_id"
+TASK_ID = "task_id"
+WAREHOUSE_ID = "warehouse_id"
+
+
+class TestDatabricksSQLStatementsSensor:
+    """
+    Validate and test the functionality of the DatabricksSQLStatementsSensor. This Sensor borrows heavily
+    from the DatabricksSQLStatementsOperator, meaning that much of the testing logic is also reused.
+    """
+
+    def test_init_statement(self):
+        """Test initialization for traditional use-case (statement)."""
+        op = DatabricksSQLStatementsSensor(task_id=TASK_ID, 
statement=STATEMENT, warehouse_id=WAREHOUSE_ID)
+
+        assert op.statement == STATEMENT
+        assert op.warehouse_id == WAREHOUSE_ID
+
+    def test_init_statement_id(self):
+        """Test initialization when a statement_id is passed, rather than a 
statement."""
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID, statement_id=STATEMENT_ID, 
warehouse_id=WAREHOUSE_ID
+        )
+
+        assert op.statement_id == STATEMENT_ID
+        assert op.warehouse_id == WAREHOUSE_ID
+
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_exec_success(self, db_mock_class):
+        """
+        Test the execute function for non-deferrable execution. This same 
exact behavior is expected when the
+        statement itself fails, so no test_exec_failure_statement is 
implemented.
+        """
+        expected_json = {
+            "statement": STATEMENT,
+            "warehouse_id": WAREHOUSE_ID,
+            "catalog": None,
+            "schema": None,
+            "parameters": None,
+            "wait_timeout": "0s",
+        }
+
+        op = DatabricksSQLStatementsSensor(task_id=TASK_ID, 
statement=STATEMENT, warehouse_id=WAREHOUSE_ID)
+        db_mock = db_mock_class.return_value
+        db_mock.post_sql_statement.return_value = STATEMENT_ID
+
+        op.execute(None)  # No context is being passed in
+
+        db_mock_class.assert_called_once_with(
+            DEFAULT_CONN_ID,
+            retry_limit=op.databricks_retry_limit,
+            retry_delay=op.databricks_retry_delay,
+            retry_args=None,
+            caller="DatabricksSQLStatementsSensor",
+        )
+
+        # Since a statement is being passed in rather than a statement_id, 
we're asserting that the
+        # post_sql_statement method is called once
+        db_mock.post_sql_statement.assert_called_once_with(expected_json)
+        assert op.statement_id == STATEMENT_ID
+
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_on_kill(self, db_mock_class):
+        """
+        Test the on_kill method. This is actually part of the DatabricksSQLStatementsMixin, so the
+        test logic will match the test of the same name for the DatabricksSQLStatementsOperator.
+        """
+        # Behavior here will remain the same whether a statement or 
statement_id is passed
+        op = DatabricksSQLStatementsSensor(task_id=TASK_ID, 
statement=STATEMENT, warehouse_id=WAREHOUSE_ID)
+        db_mock = db_mock_class.return_value
+        op.statement_id = STATEMENT_ID
+
+        # When on_kill is executed, it should call the cancel_sql_statement 
method
+        op.on_kill()
+        db_mock.cancel_sql_statement.assert_called_once_with(STATEMENT_ID)
+
+    def test_wait_for_termination_is_default(self):
+        """Validate that the default value for wait_for_termination is True."""
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID, statement="select * from test.test;", 
warehouse_id=WAREHOUSE_ID
+        )
+
+        assert op.wait_for_termination
+
+    @pytest.mark.parametrize(
+        argnames=("statement_state", "expected_poke_result"),
+        argvalues=[
+            ("RUNNING", False),
+            ("SUCCEEDED", True),
+        ],
+    )
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_poke(self, db_mock_class, statement_state, expected_poke_result):
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.get_sql_statement_state.return_value = 
SQLStatementState(statement_state)
+
+        poke_result = op.poke(None)
+
+        assert poke_result == expected_poke_result
+
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_poke_failure(self, db_mock_class):
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.get_sql_statement_state.return_value = 
SQLStatementState("FAILED")
+
+        with pytest.raises(AirflowException):
+            op.poke(None)
+
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_execute_task_deferred(self, db_mock_class):
+        """
+        Test that the statement is successfully deferred. This behavior will 
remain the same whether a
+        statement or a statement_id is passed.
+        """
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+            deferrable=True,
+        )
+        db_mock = db_mock_class.return_value
+        db_mock.get_sql_statement_state.return_value = 
SQLStatementState("RUNNING")
+
+        with pytest.raises(TaskDeferred) as exc:
+            op.execute(None)
+
+        assert isinstance(exc.value.trigger, 
DatabricksSQLStatementExecutionTrigger)
+        assert exc.value.method_name == "execute_complete"
+
+    def test_execute_complete_success(self):
+        """
+        Test the execute_complete function in case the Trigger has returned a 
successful completion event.
+        This method is part of the DatabricksSQLStatementsMixin. Note that 
this is only being tested when
+        in deferrable mode.
+        """
+        event = {
+            "statement_id": STATEMENT_ID,
+            "state": SQLStatementState("SUCCEEDED").to_json(),
+            "error": {},
+        }
+
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+            deferrable=True,
+        )
+        assert op.execute_complete(context=None, event=event) is None
+
+    
@mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook")
+    def test_execute_complete_failure(self, db_mock_class):
+        """Test execute_complete function in case the Trigger has returned a 
failure completion event."""
+        event = {
+            "statement_id": STATEMENT_ID,
+            "state": SQLStatementState("FAILED").to_json(),
+            "error": SQLStatementState(
+                state="FAILED", error_code="500", error_message="Something 
Went Wrong"
+            ).to_json(),
+        }
+        op = DatabricksSQLStatementsSensor(
+            task_id=TASK_ID,
+            statement=STATEMENT,
+            warehouse_id=WAREHOUSE_ID,
+            deferrable=True,
+        )
+
+        with pytest.raises(AirflowException, match="^SQL Statement execution 
failed with terminal state: .*"):
+            op.execute_complete(context=None, event=event)
diff --git a/providers/databricks/tests/unit/databricks/utils/test_mixins.py b/providers/databricks/tests/unit/databricks/utils/test_mixins.py
new file mode 100644
index 00000000000..95db64edfd2
--- /dev/null
+++ b/providers/databricks/tests/unit/databricks/utils/test_mixins.py
@@ -0,0 +1,127 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.exceptions import AirflowException
+from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin
+
+
+class DatabricksSQLStatements(DatabricksSQLStatementsMixin):
+    def __init__(self):
+        self.databricks_conn_id = "databricks_conn_id"
+        self.databricks_retry_limit = 3
+        self.databricks_retry_delay = 60
+        self.databricks_retry_args = None
+        self.polling_period_seconds = 10
+        self.statement_id = "statement_id"
+        self.task_id = "task_id"
+        self.timeout = 60
+
+        # Utilities
+        self._hook = MagicMock()
+        self.defer = MagicMock()
+        self.log = MagicMock()
+
+
+@pytest.fixture
+def databricks_sql_statements():
+    return DatabricksSQLStatements()
+
+
+@pytest.fixture
+def terminal_success_state():
+    terminal_success_state = MagicMock()
+    terminal_success_state.is_terminal = True
+    terminal_success_state.is_successful = True
+    return terminal_success_state
+
+
+@pytest.fixture
+def terminal_failure_state():
+    terminal_fail_state = MagicMock()
+    terminal_fail_state.is_terminal = True
+    terminal_fail_state.is_successful = False
+    terminal_fail_state.state = "FAILED"
+    terminal_fail_state.error_code = "123"
+    terminal_fail_state.error_message = "Query failed"
+    return terminal_fail_state
+
+
+class TestDatabricksSQLStatementsMixin:
+    """
+    We'll provide tests for each of the following methods:
+
+    - _handle_execution
+    - _handle_deferrable_execution
+    - execute_complete
+    - on_kill
+    """
+
+    def test_handle_execution_success(self, databricks_sql_statements, 
terminal_success_state):
+        # Test an immediate success of the SQL statement
+        databricks_sql_statements._hook.get_sql_statement_state.return_value = 
terminal_success_state
+        databricks_sql_statements._handle_execution()
+
+        
databricks_sql_statements._hook.cancel_sql_statement.assert_not_called()
+
+    def test_handle_execution_failure(self, databricks_sql_statements, 
terminal_failure_state):
+        # Test an immediate failure of the SQL statement
+        databricks_sql_statements._hook.get_sql_statement_state.return_value = 
terminal_failure_state
+
+        with pytest.raises(AirflowException):
+            databricks_sql_statements._handle_execution()
+
+        
databricks_sql_statements._hook.cancel_sql_statement.assert_not_called()
+
+    def test_handle_deferrable_execution_running(self, databricks_sql_statements):
+        running_state = MagicMock()
+        running_state.is_terminal = False
+
+        # Test a statement that is still running, which should cause the task to defer
+        databricks_sql_statements._hook.get_sql_statement_state.return_value = running_state
+        databricks_sql_statements._handle_deferrable_execution()
+
+        databricks_sql_statements.defer.assert_called_once()
+
+    def test_handle_deferrable_execution_success(self, 
databricks_sql_statements, terminal_success_state):
+        # Test an immediate success of the SQL statement
+        databricks_sql_statements._hook.get_sql_statement_state.return_value = 
terminal_success_state
+        databricks_sql_statements._handle_deferrable_execution()
+
+        databricks_sql_statements.defer.assert_not_called()
+
+    def test_handle_deferrable_execution_failure(self, 
databricks_sql_statements, terminal_failure_state):
+        # Test an immediate failure of the SQL statement
+        databricks_sql_statements._hook.get_sql_statement_state.return_value = 
terminal_failure_state
+
+        with pytest.raises(AirflowException):
+            databricks_sql_statements._handle_deferrable_execution()
+
+    def test_execute_complete(self):
+        # Both the TestDatabricksSQLStatementsOperator and 
TestDatabricksSQLStatementsSensor tests implement
+        # a test_execute_complete_failure and test_execute_complete_success 
method, so we'll pass here
+        pass
+
+    def test_on_kill(self):
+        # This test is implemented in both the 
TestDatabricksSQLStatementsOperator and
+        # TestDatabricksSQLStatementsSensor tests, so it will not be 
implemented here
+        pass

