This is an automated email from the ASF dual-hosted git repository. rnhttr pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 44250748ebd Adding `DatabricksSQLStatementSensor` Sensor with Deferrability (#49516) 44250748ebd is described below commit 44250748ebdb1c6c940058e3416fb87de26c85e3 Author: Jake Roach <116606359+jroachgol...@users.noreply.github.com> AuthorDate: Wed May 14 21:40:11 2025 -0400 Adding `DatabricksSQLStatementSensor` Sensor with Deferrability (#49516) * Adding mixin, Sensor, and updated Operator to create a draft PR * Prepping to update branch * Adding tests for DatabricksSQLStatementSensor * Added unit tests and examples * Added documentation, pointed ignore tests on mixin * Implementing changes requested in PR * Tweaking example DAG * Adding validation, updating docs * Updated example DAG * Adding test for Databricks mixins * Updated test_project_structure.py --- .../databricks/docs/operators/sql_statements.rst | 46 +++++ providers/databricks/provider.yaml | 1 + .../providers/databricks/get_provider_info.py | 1 + .../providers/databricks/operators/databricks.py | 87 +-------- .../providers/databricks/sensors/databricks.py | 162 ++++++++++++++++ .../airflow/providers/databricks/utils/mixins.py | 193 +++++++++++++++++++ .../databricks/example_databricks_sensors.py | 14 ++ .../unit/databricks/sensors/test_databricks.py | 208 +++++++++++++++++++++ .../tests/unit/databricks/utils/test_mixins.py | 127 +++++++++++++ 9 files changed, 756 insertions(+), 83 deletions(-) diff --git a/providers/databricks/docs/operators/sql_statements.rst b/providers/databricks/docs/operators/sql_statements.rst index 73b7948a144..9d314a238da 100644 --- a/providers/databricks/docs/operators/sql_statements.rst +++ b/providers/databricks/docs/operators/sql_statements.rst @@ -55,3 +55,49 @@ An example usage of the ``DatabricksSQLStatementsOperator`` is as follows: :language: python :start-after: [START howto_operator_sql_statements] :end-before: [END howto_operator_sql_statements] + + +.. _howto/sensor:DatabricksSQLStatementsSensor: + +DatabricksSQLStatementsSensor +=============================== + +Use the :class:`~airflow.providers.databricks.sensors.databricks.DatabricksSQLStatementsSensor` to either submit a +Databricks SQL Statement to Databricks using the +`Databricks SQL Statement Execution API <https://docs.databricks.com/api/workspace/statementexecution>`_, or pass +a Statement ID to the Sensor and wait for the query to finish executing. + + +Using the Sensor +------------------ + +The ``DatabricksSQLStatementsSensor`` does one of two things. The Sensor can submit SQL statements to Databricks using +the `/api/2.0/sql/statements/ <https://docs.databricks.com/api/workspace/statementexecution/executestatement>`_ +endpoint. However, the Sensor can also take the Statement ID of an already-submitted SQL Statement and handle the +response to that execution. + +It supports configurable execution parameters such as warehouse selection, catalog, schema, and parameterized queries. +The Sensor can either synchronously poll for query completion or run in a deferrable mode for improved efficiency. + +The only required parameters for using the Sensor are: + +* One of ``statement`` (the SQL statement to execute, which can optionally be parameterized; see ``parameters``) or + ``statement_id`` (the ID of an already-submitted statement to wait on). +* ``warehouse_id`` - The SQL warehouse on which to execute the statement.
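For orientation, a minimal sketch of both modes is shown below (the connection ID, warehouse ID, statement ID, and table name are illustrative placeholders, not values from this commit):

.. code-block:: python

    from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor

    # Submit a statement and wait for it to finish
    wait_for_query = DatabricksSQLStatementsSensor(
        task_id="wait_for_query",
        databricks_conn_id="databricks_default",
        warehouse_id="my-warehouse-id",  # placeholder
        statement="select * from default.my_airflow_table",
    )

    # Or wait on a statement that was already submitted elsewhere
    wait_for_existing_statement = DatabricksSQLStatementsSensor(
        task_id="wait_for_existing_statement",
        databricks_conn_id="databricks_default",
        warehouse_id="my-warehouse-id",  # placeholder
        statement_id="01f0-placeholder-statement-id",  # placeholder
    )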
+ +All other parameters are optional and described in the documentation for ``DatabricksSQLStatementsSensor`` including +but not limited to: + +* ``catalog`` +* ``schema`` +* ``parameters`` + +Examples +-------- + +An example usage of the ``DatabricksSQLStatementsSensor`` is as follows: + +.. exampleinclude:: /../../databricks/tests/system/databricks/example_databricks_sensors.py + :language: python + :start-after: [START howto_sensor_databricks_sql_statement] + :end-before: [END howto_sensor_databricks_sql_statement] diff --git a/providers/databricks/provider.yaml b/providers/databricks/provider.yaml index 8bb7812f25f..18b51930ed7 100644 --- a/providers/databricks/provider.yaml +++ b/providers/databricks/provider.yaml @@ -143,6 +143,7 @@ triggers: sensors: - integration-name: Databricks python-modules: + - airflow.providers.databricks.sensors.databricks - airflow.providers.databricks.sensors.databricks_sql - airflow.providers.databricks.sensors.databricks_partition diff --git a/providers/databricks/src/airflow/providers/databricks/get_provider_info.py b/providers/databricks/src/airflow/providers/databricks/get_provider_info.py index aa7a31757ca..d0a5d349598 100644 --- a/providers/databricks/src/airflow/providers/databricks/get_provider_info.py +++ b/providers/databricks/src/airflow/providers/databricks/get_provider_info.py @@ -107,6 +107,7 @@ def get_provider_info(): { "integration-name": "Databricks", "python-modules": [ + "airflow.providers.databricks.sensors.databricks", "airflow.providers.databricks.sensors.databricks_sql", "airflow.providers.databricks.sensors.databricks_partition", ], diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py index abafd217009..66fdd7556b5 100644 --- a/providers/databricks/src/airflow/providers/databricks/operators/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks.py @@ -34,7 +34,6 @@ from airflow.providers.databricks.hooks.databricks import ( DatabricksHook, RunLifeCycleState, RunState, - SQLStatementState, ) from airflow.providers.databricks.operators.databricks_workflow import ( DatabricksWorkflowTaskGroup, @@ -46,9 +45,9 @@ from airflow.providers.databricks.plugins.databricks_workflow import ( ) from airflow.providers.databricks.triggers.databricks import ( DatabricksExecutionTrigger, - DatabricksSQLStatementExecutionTrigger, ) from airflow.providers.databricks.utils.databricks import normalise_json_content, validate_trigger_event +from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS if TYPE_CHECKING: @@ -978,7 +977,7 @@ class DatabricksRunNowOperator(BaseOperator): self.log.error("Error: Task: %s with invalid run_id was requested to be cancelled.", self.task_id) -class DatabricksSQLStatementsOperator(BaseOperator): +class DatabricksSQLStatementsOperator(DatabricksSQLStatementsMixin, BaseOperator): """ Submits a Databricks SQL Statement to Databricks using the api/2.0/sql/statements/ API endpoint. 
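Both the refactored operator above and the new sensor introduced below hand the same request body to ``DatabricksHook.post_sql_statement``; for reference, a sketch of that payload as built in ``execute()`` (field names are taken from the sensor code in this commit, the literal values are illustrative):

.. code-block:: python

    # Request body passed to DatabricksHook.post_sql_statement(); the keys mirror the
    # Statement Execution API's execute-statement payload. Values here are placeholders.
    json = {
        "statement": "select * from default.my_airflow_table",
        "warehouse_id": "my-warehouse-id",
        "catalog": None,  # optional
        "schema": None,  # optional
        "parameters": None,  # optional statement parameters (see the Databricks API docs)
        "wait_timeout": "0s",  # return immediately; Airflow handles the polling or deferral
    }
    statement_id = self._hook.post_sql_statement(json)  # as in the operator's and sensor's execute()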
@@ -1073,59 +1072,6 @@ class DatabricksSQLStatementsOperator(BaseOperator): caller=caller, ) - def _handle_operator_execution(self) -> None: - end_time = time.time() + self.timeout - while end_time > time.time(): - statement_state = self._hook.get_sql_statement_state(self.statement_id) - if statement_state.is_terminal: - if statement_state.is_successful: - self.log.info("%s completed successfully.", self.task_id) - return - error_message = ( - f"{self.task_id} failed with terminal state: {statement_state.state} " - f"and with the error code {statement_state.error_code} " - f"and error message {statement_state.error_message}" - ) - raise AirflowException(error_message) - - self.log.info("%s in run state: %s", self.task_id, statement_state.state) - self.log.info("Sleeping for %s seconds.", self.polling_period_seconds) - time.sleep(self.polling_period_seconds) - - self._hook.cancel_sql_statement(self.statement_id) - raise AirflowException( - f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state.state}", - ) - - def _handle_deferrable_operator_execution(self) -> None: - statement_state = self._hook.get_sql_statement_state(self.statement_id) - end_time = time.time() + self.timeout - if not statement_state.is_terminal: - if not self.statement_id: - raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.") - self.defer( - trigger=DatabricksSQLStatementExecutionTrigger( - statement_id=self.statement_id, - databricks_conn_id=self.databricks_conn_id, - end_time=end_time, - polling_period_seconds=self.polling_period_seconds, - retry_limit=self.databricks_retry_limit, - retry_delay=self.databricks_retry_delay, - retry_args=self.databricks_retry_args, - ), - method_name=DEFER_METHOD_NAME, - ) - else: - if statement_state.is_successful: - self.log.info("%s completed successfully.", self.task_id) - else: - error_message = ( - f"{self.task_id} failed with terminal state: {statement_state.state} " - f"and with the error code {statement_state.error_code} " - f"and error message {statement_state.error_message}" - ) - raise AirflowException(error_message) - def execute(self, context: Context): json = { "statement": self.statement, @@ -1146,34 +1092,9 @@ class DatabricksSQLStatementsOperator(BaseOperator): if not self.wait_for_termination: return if self.deferrable: - self._handle_deferrable_operator_execution() - else: - self._handle_operator_execution() - - def on_kill(self): - if self.statement_id: - self._hook.cancel_sql_statement(self.statement_id) - self.log.info( - "Task: %s with statement ID: %s was requested to be cancelled.", - self.task_id, - self.statement_id, - ) + self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME) # type: ignore[misc] else: - self.log.error( - "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id - ) - - def execute_complete(self, context: dict | None, event: dict): - statement_state = SQLStatementState.from_json(event["state"]) - error = event["error"] - statement_id = event["statement_id"] - - if statement_state.is_successful: - self.log.info("SQL Statement with ID %s completed successfully.", statement_id) - return - - error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}" - raise AirflowException(error_message) + self._handle_execution() # type: ignore[misc] class DatabricksTaskBaseOperator(BaseOperator, ABC): diff --git a/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py 
b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py new file mode 100644 index 00000000000..f417291f2c9 --- /dev/null +++ b/providers/databricks/src/airflow/providers/databricks/sensors/databricks.py @@ -0,0 +1,162 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +from __future__ import annotations + +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState +from airflow.providers.databricks.operators.databricks import DEFER_METHOD_NAME +from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin +from airflow.providers.databricks.version_compat import AIRFLOW_V_3_0_PLUS + +if AIRFLOW_V_3_0_PLUS: + from airflow.sdk import BaseSensorOperator +else: + from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + +XCOM_STATEMENT_ID_KEY = "statement_id" + + +class DatabricksSQLStatementsSensor(DatabricksSQLStatementsMixin, BaseSensorOperator): + """DatabricksSQLStatementsSensor.""" + + template_fields: Sequence[str] = ( + "databricks_conn_id", + "statement", + "statement_id", + ) + template_ext: Sequence[str] = (".json-tpl",) + ui_color = "#1CB1C2" + ui_fgcolor = "#fff" + + def __init__( + self, + warehouse_id: str, + *, + statement: str | None = None, + statement_id: str | None = None, + catalog: str | None = None, + schema: str | None = None, + parameters: list[dict[str, Any]] | None = None, + databricks_conn_id: str = "databricks_default", + polling_period_seconds: int = 30, + databricks_retry_limit: int = 3, + databricks_retry_delay: int = 1, + databricks_retry_args: dict[Any, Any] | None = None, + do_xcom_push: bool = True, + wait_for_termination: bool = True, + timeout: float = 3600, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + # Handle the scenario where either both statement and statement_id are set/not set + if statement and statement_id: + raise AirflowException("Cannot provide both statement and statement_id.") + if not statement and not statement_id: + raise AirflowException("One of either statement or statement_id must be provided.") + + if not warehouse_id: + raise AirflowException("warehouse_id must be provided.") + + super().__init__(**kwargs) + + self.statement = statement + self.statement_id = statement_id + self.warehouse_id = warehouse_id + self.catalog = catalog + self.schema = schema + self.parameters = parameters + self.databricks_conn_id = databricks_conn_id + self.polling_period_seconds = polling_period_seconds + 
self.databricks_retry_limit = databricks_retry_limit + self.databricks_retry_delay = databricks_retry_delay + self.databricks_retry_args = databricks_retry_args + self.wait_for_termination = wait_for_termination + self.deferrable = deferrable + self.timeout = timeout + self.do_xcom_push = do_xcom_push + + @cached_property + def _hook(self): + return self._get_hook(caller="DatabricksSQLStatementsSensor") + + def _get_hook(self, caller: str) -> DatabricksHook: + return DatabricksHook( + self.databricks_conn_id, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=caller, + ) + + def execute(self, context: Context): + if not self.statement_id: + # If no statement_id was passed in, we'll go ahead and submit the statement ourselves + json = { + "statement": self.statement, + "warehouse_id": self.warehouse_id, + "catalog": self.catalog, + "schema": self.schema, + "parameters": self.parameters, + "wait_timeout": "0s", + } + + self.statement_id = self._hook.post_sql_statement(json) + self.log.info("SQL Statement submitted with statement_id: %s", self.statement_id) + + if self.do_xcom_push and context is not None: + context["ti"].xcom_push(key=XCOM_STATEMENT_ID_KEY, value=self.statement_id) + + # If we're not waiting for the query to complete execution, return right away. However, a + # recommendation to use the DatabricksSQLStatementsOperator is logged in this case + if not self.wait_for_termination: + self.log.info( + "If setting wait_for_termination = False, consider using the DatabricksSQLStatementsOperator instead." + ) + return + + if self.deferrable: + self._handle_deferrable_execution(defer_method_name=DEFER_METHOD_NAME) # type: ignore[misc] + + def poke(self, context: Context): + """ + Handle non-deferrable Sensor execution. + + :param context: (Context) + :return: (bool) + """ + # This closely mirrors execute_complete: check the statement state and react accordingly + statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id) + + if statement_state.is_running: + self.log.info("SQL Statement with ID %s is running", self.statement_id) + return False + if statement_state.is_successful: + self.log.info("SQL Statement with ID %s completed successfully.", self.statement_id) + return True + raise AirflowException( + f"SQL Statement with ID {self.statement_id} failed with error: {statement_state.error_message}" + ) diff --git a/providers/databricks/src/airflow/providers/databricks/utils/mixins.py b/providers/databricks/src/airflow/providers/databricks/utils/mixins.py new file mode 100644 index 00000000000..97d8c14e5b6 --- /dev/null +++ b/providers/databricks/src/airflow/providers/databricks/utils/mixins.py @@ -0,0 +1,193 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
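The ``execute()``/``poke()`` split above, combined with the XCom push of the statement ID, also enables a hand-off pattern inside a DAG; a hedged sketch follows (the task IDs, warehouse ID, and Jinja pull are illustrative assumptions, and the sensor's own log message suggests using the ``DatabricksSQLStatementsOperator`` for the submit-only step):

.. code-block:: python

    # Hypothetical hand-off: the first task only submits the statement and pushes its
    # statement_id to XCom; the second task waits on that statement to finish.
    submit_statement = DatabricksSQLStatementsSensor(
        task_id="submit_statement",
        warehouse_id="my-warehouse-id",  # placeholder
        statement="select * from default.my_airflow_table",
        wait_for_termination=False,  # submit, push statement_id to XCom (key "statement_id"), and return
    )

    wait_for_statement = DatabricksSQLStatementsSensor(
        task_id="wait_for_statement",
        warehouse_id="my-warehouse-id",  # placeholder
        # statement_id is a template field, so the XCom value can be pulled with Jinja
        statement_id="{{ ti.xcom_pull(task_ids='submit_statement', key='statement_id') }}",
    )

    submit_statement >> wait_for_statement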
+# +from __future__ import annotations + +import time +from logging import Logger +from typing import ( + TYPE_CHECKING, + Any, + Protocol, +) + +from airflow.exceptions import AirflowException +from airflow.providers.databricks.hooks.databricks import DatabricksHook, SQLStatementState +from airflow.providers.databricks.triggers.databricks import DatabricksSQLStatementExecutionTrigger + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class GetHookHasFields(Protocol): + """Protocol for get_hook method.""" + + databricks_conn_id: str + databricks_retry_args: dict | None + databricks_retry_delay: int + databricks_retry_limit: int + + +class HandleExecutionHasFields(Protocol): + """Protocol for _handle_execution method.""" + + _hook: DatabricksHook + log: Logger + polling_period_seconds: int + task_id: str + timeout: int + statement_id: str + + +class HandleDeferrableExecutionHasFields(Protocol): + """Protocol for _handle_deferrable_execution method.""" + + _hook: DatabricksHook + databricks_conn_id: str + databricks_retry_args: dict[Any, Any] | None + databricks_retry_delay: int + databricks_retry_limit: int + defer: Any + log: Logger + polling_period_seconds: int + statement_id: str + task_id: str + timeout: int + + +class ExecuteCompleteHasFields(Protocol): + """Protocol for execute_complete method.""" + + statement_id: str + _hook: DatabricksHook + log: Logger + + +class OnKillHasFields(Protocol): + """Protocol for on_kill method.""" + + _hook: DatabricksHook + log: Logger + statement_id: str + task_id: str + + +class DatabricksSQLStatementsMixin: + """ + Mixin class used by both the DatabricksSQLStatementsOperator and the DatabricksSQLStatementsSensor, providing: + + - _handle_operator_execution (renamed to _handle_execution) + - _handle_deferrable_operator_execution (renamed to _handle_deferrable_execution) + - execute_complete + - on_kill + """ + + def _handle_execution(self: HandleExecutionHasFields) -> None: + """Execute a SQL statement in non-deferrable mode.""" + # Determine the time at which the task will time out. If the statement has not reached a terminal + # state by then, it is cancelled below and the task fails + end_time = time.time() + self.timeout + + while end_time > time.time(): + statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id) + + if statement_state.is_terminal: + if statement_state.is_successful: + self.log.info("%s completed successfully.", self.task_id) + return + + error_message = ( + f"{self.task_id} failed with terminal state: {statement_state.state} " + f"and with the error code {statement_state.error_code} " + f"and error message {statement_state.error_message}" + ) + raise AirflowException(error_message) + + self.log.info("%s in run state: %s", self.task_id, statement_state.state) + self.log.info("Sleeping for %s seconds.", self.polling_period_seconds) + time.sleep(self.polling_period_seconds) + + # Once the timeout is exceeded, the query is cancelled. This is an important step; if a query takes + # too long, it needs to be killed.
Otherwise, it may be the case that there are "zombie" queries running + # that are no longer being orchestrated + self._hook.cancel_sql_statement(self.statement_id) + raise AirflowException( + f"{self.task_id} timed out after {self.timeout} seconds with state: {statement_state}", + ) + + def _handle_deferrable_execution( + self: HandleDeferrableExecutionHasFields, defer_method_name: str = "execute_complete" + ) -> None: + """Execute a SQL statement in deferrable mode.""" + statement_state: SQLStatementState = self._hook.get_sql_statement_state(self.statement_id) + end_time: float = time.time() + self.timeout + + if not statement_state.is_terminal: + # If the query is still running and there is no statement_id, this is somewhat of a "zombie" + # query, and should throw an exception + if not self.statement_id: + raise AirflowException("Failed to retrieve statement_id after submitting SQL statement.") + + self.defer( + trigger=DatabricksSQLStatementExecutionTrigger( + statement_id=self.statement_id, + databricks_conn_id=self.databricks_conn_id, + end_time=end_time, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + ), + method_name=defer_method_name, + ) + + else: + if statement_state.is_successful: + self.log.info("%s completed successfully.", self.task_id) + else: + error_message = ( + f"{self.task_id} failed with terminal state: {statement_state.state} " + f"and with the error code {statement_state.error_code} " + f"and error message {statement_state.error_message}" + ) + raise AirflowException(error_message) + + def execute_complete(self: ExecuteCompleteHasFields, context: Context, event: dict): + statement_state = SQLStatementState.from_json(event["state"]) + error = event["error"] + statement_id = event["statement_id"] + + if statement_state.is_successful: + self.log.info("SQL Statement with ID %s completed successfully.", statement_id) + return + + error_message = f"SQL Statement execution failed with terminal state: {statement_state} and with the error {error}" + raise AirflowException(error_message) + + def on_kill(self: OnKillHasFields) -> None: + if self.statement_id: + self._hook.cancel_sql_statement(self.statement_id) + self.log.info( + "Task: %s with statement ID: %s was requested to be cancelled.", + self.task_id, + self.statement_id, + ) + else: + self.log.error( + "Error: Task: %s with invalid statement_id was requested to be cancelled.", self.task_id + ) diff --git a/providers/databricks/tests/system/databricks/example_databricks_sensors.py b/providers/databricks/tests/system/databricks/example_databricks_sensors.py index cf676183e6c..177ea8ce293 100644 --- a/providers/databricks/tests/system/databricks/example_databricks_sensors.py +++ b/providers/databricks/tests/system/databricks/example_databricks_sensors.py @@ -22,6 +22,7 @@ import textwrap from datetime import datetime from airflow import DAG +from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor from airflow.providers.databricks.sensors.databricks_partition import DatabricksPartitionSensor from airflow.providers.databricks.sensors.databricks_sql import DatabricksSqlSensor @@ -30,6 +31,7 @@ ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID") # [DAG name to be shown on Airflow UI] DAG_ID = "example_databricks_sensor" + with DAG( dag_id=DAG_ID, schedule="@daily", @@ -67,6 +69,18 @@ with DAG( ) # [END howto_sensor_databricks_sql] + # [START 
howto_sensor_databricks_sql_statement] + # Example of using the DatabricksSQLStatementSensor to wait for a query + # to successfully run. + sql_statement_sensor = DatabricksSQLStatementsSensor( + task_id="sql_statement_sensor_task", + databricks_conn_id=connection_id, + warehouse_id="warehouse_id", + statement="select * from default.my_airflow_table", + # deferrable=True, # For using the operator in deferrable mode + ) + # [END howto_sensor_databricks_sql_statement] + # [START howto_sensor_databricks_partition] # Example of using the Databricks Partition Sensor to check the presence # of the specified partition(s) in a table. diff --git a/providers/databricks/tests/unit/databricks/sensors/test_databricks.py b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py new file mode 100644 index 00000000000..dc2521a21bc --- /dev/null +++ b/providers/databricks/tests/unit/databricks/sensors/test_databricks.py @@ -0,0 +1,208 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException, TaskDeferred +from airflow.providers.databricks.hooks.databricks import SQLStatementState +from airflow.providers.databricks.sensors.databricks import DatabricksSQLStatementsSensor +from airflow.providers.databricks.triggers.databricks import DatabricksSQLStatementExecutionTrigger + +DEFAULT_CONN_ID = "databricks_default" +STATEMENT = "select * from test.test;" +STATEMENT_ID = "statement_id" +TASK_ID = "task_id" +WAREHOUSE_ID = "warehouse_id" + + +class TestDatabricksSQLStatementsSensor: + """ + Validate and test the functionality of the DatabricksSQLStatementsSensor. This Sensor borrows heavily + from the DatabricksSQLStatementOperator, meaning that much of the testing logic is also reused. + """ + + def test_init_statement(self): + """Test initialization for traditional use-case (statement).""" + op = DatabricksSQLStatementsSensor(task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID) + + assert op.statement == STATEMENT + assert op.warehouse_id == WAREHOUSE_ID + + def test_init_statement_id(self): + """Test initialization when a statement_id is passed, rather than a statement.""" + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, statement_id=STATEMENT_ID, warehouse_id=WAREHOUSE_ID + ) + + assert op.statement_id == STATEMENT_ID + assert op.warehouse_id == WAREHOUSE_ID + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_exec_success(self, db_mock_class): + """ + Test the execute function for non-deferrable execution. This same exact behavior is expected when the + statement itself fails, so no test_exec_failure_statement is implemented. 
+ """ + expected_json = { + "statement": STATEMENT, + "warehouse_id": WAREHOUSE_ID, + "catalog": None, + "schema": None, + "parameters": None, + "wait_timeout": "0s", + } + + op = DatabricksSQLStatementsSensor(task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID) + db_mock = db_mock_class.return_value + db_mock.post_sql_statement.return_value = STATEMENT_ID + + op.execute(None) # No context is being passed in + + db_mock_class.assert_called_once_with( + DEFAULT_CONN_ID, + retry_limit=op.databricks_retry_limit, + retry_delay=op.databricks_retry_delay, + retry_args=None, + caller="DatabricksSQLStatementsSensor", + ) + + # Since a statement is being passed in rather than a statement_id, we're asserting that the + # post_sql_statement method is called once + db_mock.post_sql_statement.assert_called_once_with(expected_json) + assert op.statement_id == STATEMENT_ID + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_on_kill(self, db_mock_class): + """ + Test the on_kill method. This is actually part of the DatabricksSQLStatementMixin, so the + test logic will match that with the same name for DatabricksSQLStatementOperator. + """ + # Behavior here will remain the same whether a statement or statement_id is passed + op = DatabricksSQLStatementsSensor(task_id=TASK_ID, statement=STATEMENT, warehouse_id=WAREHOUSE_ID) + db_mock = db_mock_class.return_value + op.statement_id = STATEMENT_ID + + # When on_kill is executed, it should call the cancel_sql_statement method + op.on_kill() + db_mock.cancel_sql_statement.assert_called_once_with(STATEMENT_ID) + + def test_wait_for_termination_is_default(self): + """Validate that the default value for wait_for_termination is True.""" + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, statement="select * from test.test;", warehouse_id=WAREHOUSE_ID + ) + + assert op.wait_for_termination + + @pytest.mark.parametrize( + argnames=("statement_state", "expected_poke_result"), + argvalues=[ + ("RUNNING", False), + ("SUCCEEDED", True), + ], + ) + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_poke(self, db_mock_class, statement_state, expected_poke_result): + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + ) + db_mock = db_mock_class.return_value + db_mock.get_sql_statement_state.return_value = SQLStatementState(statement_state) + + poke_result = op.poke(None) + + assert poke_result == expected_poke_result + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_poke_failure(self, db_mock_class): + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + ) + db_mock = db_mock_class.return_value + db_mock.get_sql_statement_state.return_value = SQLStatementState("FAILED") + + with pytest.raises(AirflowException): + op.poke(None) + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_task_deferred(self, db_mock_class): + """ + Test that the statement is successfully deferred. This behavior will remain the same whether a + statement or a statement_id is passed. 
+ """ + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + db_mock = db_mock_class.return_value + db_mock.get_sql_statement_state.return_value = SQLStatementState("RUNNING") + + with pytest.raises(TaskDeferred) as exc: + op.execute(None) + + assert isinstance(exc.value.trigger, DatabricksSQLStatementExecutionTrigger) + assert exc.value.method_name == "execute_complete" + + def test_execute_complete_success(self): + """ + Test the execute_complete function in case the Trigger has returned a successful completion event. + This method is part of the DatabricksSQLStatementsMixin. Note that this is only being tested when + in deferrable mode. + """ + event = { + "statement_id": STATEMENT_ID, + "state": SQLStatementState("SUCCEEDED").to_json(), + "error": {}, + } + + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + assert op.execute_complete(context=None, event=event) is None + + @mock.patch("airflow.providers.databricks.sensors.databricks.DatabricksHook") + def test_execute_complete_failure(self, db_mock_class): + """Test execute_complete function in case the Trigger has returned a failure completion event.""" + event = { + "statement_id": STATEMENT_ID, + "state": SQLStatementState("FAILED").to_json(), + "error": SQLStatementState( + state="FAILED", error_code="500", error_message="Something Went Wrong" + ).to_json(), + } + op = DatabricksSQLStatementsSensor( + task_id=TASK_ID, + statement=STATEMENT, + warehouse_id=WAREHOUSE_ID, + deferrable=True, + ) + + with pytest.raises(AirflowException, match="^SQL Statement execution failed with terminal state: .*"): + op.execute_complete(context=None, event=event) diff --git a/providers/databricks/tests/unit/databricks/utils/test_mixins.py b/providers/databricks/tests/unit/databricks/utils/test_mixins.py new file mode 100644 index 00000000000..95db64edfd2 --- /dev/null +++ b/providers/databricks/tests/unit/databricks/utils/test_mixins.py @@ -0,0 +1,127 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
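For reference, the shape of the trigger completion event that ``execute_complete()`` consumes, mirroring the unit tests above (the statement_id value is illustrative):

.. code-block:: python

    from airflow.providers.databricks.hooks.databricks import SQLStatementState

    success_event = {
        "statement_id": "01f0-placeholder-statement-id",
        "state": SQLStatementState("SUCCEEDED").to_json(),
        "error": {},
    }
    failure_event = {
        "statement_id": "01f0-placeholder-statement-id",
        "state": SQLStatementState("FAILED").to_json(),
        "error": SQLStatementState(
            state="FAILED", error_code="500", error_message="Something Went Wrong"
        ).to_json(),
    }
    # execute_complete(context=None, event=success_event) returns None;
    # execute_complete(context=None, event=failure_event) raises AirflowException.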
+from __future__ import annotations + +from unittest.mock import MagicMock + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin + + +class DatabricksSQLStatements(DatabricksSQLStatementsMixin): + def __init__(self): + self.databricks_conn_id = "databricks_conn_id" + self.databricks_retry_limit = 3 + self.databricks_retry_delay = 60 + self.databricks_retry_args = None + self.polling_period_seconds = 10 + self.statement_id = "statement_id" + self.task_id = "task_id" + self.timeout = 60 + + # Utilities + self._hook = MagicMock() + self.defer = MagicMock() + self.log = MagicMock() + + +@pytest.fixture +def databricks_sql_statements(): + return DatabricksSQLStatements() + + +@pytest.fixture +def terminal_success_state(): + terminal_success_state = MagicMock() + terminal_success_state.is_terminal = True + terminal_success_state.is_successful = True + return terminal_success_state + + +@pytest.fixture +def terminal_failure_state(): + terminal_fail_state = MagicMock() + terminal_fail_state.is_terminal = True + terminal_fail_state.is_successful = False + terminal_fail_state.state = "FAILED" + terminal_fail_state.error_code = "123" + terminal_fail_state.error_message = "Query failed" + return terminal_fail_state + + +class TestDatabricksSQLStatementsMixin: + """ + We'll provide tests for each of the following methods: + + - _handle_execution + - _handle_deferrable_execution + - execute_complete + - on_kill + """ + + def test_handle_execution_success(self, databricks_sql_statements, terminal_success_state): + # Test an immediate success of the SQL statement + databricks_sql_statements._hook.get_sql_statement_state.return_value = terminal_success_state + databricks_sql_statements._handle_execution() + + databricks_sql_statements._hook.cancel_sql_statement.assert_not_called() + + def test_handle_execution_failure(self, databricks_sql_statements, terminal_failure_state): + # Test an immediate failure of the SQL statement + databricks_sql_statements._hook.get_sql_statement_state.return_value = terminal_failure_state + + with pytest.raises(AirflowException): + databricks_sql_statements._handle_execution() + + databricks_sql_statements._hook.cancel_sql_statement.assert_not_called() + + def test_handle_deferrable_execution_running(self, databricks_sql_statements): + terminal_running_state = MagicMock() + terminal_running_state.is_terminal = False + + # Test an immediate success of the SQL statement + databricks_sql_statements._hook.get_sql_statement_state.return_value = terminal_running_state + databricks_sql_statements._handle_deferrable_execution() + + databricks_sql_statements.defer.assert_called_once() + + def test_handle_deferrable_execution_success(self, databricks_sql_statements, terminal_success_state): + # Test an immediate success of the SQL statement + databricks_sql_statements._hook.get_sql_statement_state.return_value = terminal_success_state + databricks_sql_statements._handle_deferrable_execution() + + databricks_sql_statements.defer.assert_not_called() + + def test_handle_deferrable_execution_failure(self, databricks_sql_statements, terminal_failure_state): + # Test an immediate failure of the SQL statement + databricks_sql_statements._hook.get_sql_statement_state.return_value = terminal_failure_state + + with pytest.raises(AirflowException): + databricks_sql_statements._handle_deferrable_execution() + + def test_execute_complete(self): + # Both the TestDatabricksSQLStatementsOperator and 
TestDatabricksSQLStatementsSensor tests implement + # a test_execute_complete_failure and test_execute_complete_success method, so we'll pass here + pass + + def test_on_kill(self): + # This test is implemented in both the TestDatabricksSQLStatementsOperator and + # TestDatabricksSQLStatementsSensor tests, so it will not be implemented here + pass
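To close the loop on how the mixin is meant to be consumed, here is a hedged sketch of a hypothetical operator wiring together the attributes the protocols above expect, in the spirit of the ``DatabricksSQLStatements`` test double; the class name, defaults, and ``execute`` body are illustrative assumptions and are not part of this commit:

.. code-block:: python

    from functools import cached_property

    from airflow.models import BaseOperator
    from airflow.providers.databricks.hooks.databricks import DatabricksHook
    from airflow.providers.databricks.utils.mixins import DatabricksSQLStatementsMixin


    class WaitForStatementOperator(DatabricksSQLStatementsMixin, BaseOperator):
        """Toy operator that only polls an existing statement_id to completion."""

        def __init__(self, statement_id: str, databricks_conn_id: str = "databricks_default", **kwargs):
            super().__init__(**kwargs)
            # Fields the mixin's protocols expect to find on `self`
            self.statement_id = statement_id
            self.databricks_conn_id = databricks_conn_id
            self.databricks_retry_limit = 3
            self.databricks_retry_delay = 60
            self.databricks_retry_args = None
            self.polling_period_seconds = 10
            self.timeout = 600

        @cached_property
        def _hook(self) -> DatabricksHook:
            return DatabricksHook(self.databricks_conn_id, caller="WaitForStatementOperator")

        def execute(self, context):
            self._handle_execution()  # poll until terminal state or timeout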