This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e4663f34c add deferrable support to `DatabricksNotebookOperator` (#39295)
1e4663f34c is described below

commit 1e4663f34c2fb42b87cf75e4776650620eb2baa4
Author: Kalyan <kalyan.be...@live.com>
AuthorDate: Tue May 14 19:48:17 2024 +0530

    add deferrable support to `DatabricksNotebookOperator` (#39295)
    
    related: #39178
    
    This PR intends to make DatabricksNotebookOperator deferrable
---
 .../providers/databricks/hooks/databricks_base.py  |  1 +
 .../providers/databricks/operators/databricks.py   | 40 +++++++++++++++---
 .../providers/databricks/triggers/databricks.py    |  2 +
 .../databricks/operators/test_databricks.py        | 48 +++++++++++++++++++++-
 4 files changed, 83 insertions(+), 8 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks_base.py b/airflow/providers/databricks/hooks/databricks_base.py
index 32316d49bb..2dee924f61 100644
--- a/airflow/providers/databricks/hooks/databricks_base.py
+++ b/airflow/providers/databricks/hooks/databricks_base.py
@@ -80,6 +80,7 @@ class BaseDatabricksHook(BaseHook):
     :param retry_delay: The number of seconds to wait between retries (it
         might be a floating point number).
     :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
+    :param caller: The name of the operator that is calling the hook.
     """
 
     conn_name_attr: str = "databricks_conn_id"
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index 0d819e1b70..7ae802db10 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -167,7 +167,7 @@ def _handle_deferrable_databricks_operator_completion(event: dict, log: Logger)
 
     error_message = f"Job run failed with terminal state: {run_state} and with the errors {errors}"
 
-    if event["repair_run"]:
+    if event.get("repair_run"):
         log.warning(
             "%s but since repair run is set, repairing the run with all failed tasks",
             error_message,
@@ -923,9 +923,11 @@ class DatabricksNotebookOperator(BaseOperator):
     :param databricks_retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
     :param databricks_conn_id: The name of the Airflow connection to use.
+    :param deferrable: Run operator in the deferrable mode.
     """
 
     template_fields = ("notebook_params",)
+    CALLER = "DatabricksNotebookOperator"
 
     def __init__(
         self,
@@ -942,6 +944,7 @@ class DatabricksNotebookOperator(BaseOperator):
         databricks_retry_args: dict[Any, Any] | None = None,
         wait_for_termination: bool = True,
         databricks_conn_id: str = "databricks_default",
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
         **kwargs: Any,
     ):
         self.notebook_path = notebook_path
@@ -958,11 +961,12 @@ class DatabricksNotebookOperator(BaseOperator):
         self.wait_for_termination = wait_for_termination
         self.databricks_conn_id = databricks_conn_id
         self.databricks_run_id: int | None = None
+        self.deferrable = deferrable
         super().__init__(**kwargs)
 
     @cached_property
     def _hook(self) -> DatabricksHook:
-        return self._get_hook(caller="DatabricksNotebookOperator")
+        return self._get_hook(caller=self.CALLER)
 
     def _get_hook(self, caller: str) -> DatabricksHook:
         return DatabricksHook(
@@ -970,7 +974,7 @@ class DatabricksNotebookOperator(BaseOperator):
             retry_limit=self.databricks_retry_limit,
             retry_delay=self.databricks_retry_delay,
             retry_args=self.databricks_retry_args,
-            caller=caller,
+            caller=self.CALLER,
         )
 
     def _get_task_timeout_seconds(self) -> int:
@@ -1041,6 +1045,19 @@ class DatabricksNotebookOperator(BaseOperator):
         run = self._hook.get_run(self.databricks_run_id)
         run_state = RunState(**run["state"])
         self.log.info("Current state of the job: %s", run_state.life_cycle_state)
+        if self.deferrable and not run_state.is_terminal:
+            return self.defer(
+                trigger=DatabricksExecutionTrigger(
+                    run_id=self.databricks_run_id,
+                    databricks_conn_id=self.databricks_conn_id,
+                    polling_period_seconds=self.polling_period_seconds,
+                    retry_limit=self.databricks_retry_limit,
+                    retry_delay=self.databricks_retry_delay,
+                    retry_args=self.databricks_retry_args,
+                    caller=self.CALLER,
+                ),
+                method_name=DEFER_METHOD_NAME,
+            )
         while not run_state.is_terminal:
             time.sleep(self.polling_period_seconds)
             run = self._hook.get_run(self.databricks_run_id)
@@ -1056,9 +1073,7 @@ class DatabricksNotebookOperator(BaseOperator):
             )
         if not run_state.is_successful:
             raise AirflowException(
-                "Task failed. Final state %s. Reason: %s",
-                run_state.result_state,
-                run_state.state_message,
+                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
             )
         self.log.info("Task succeeded. Final state %s.", run_state.result_state)
 
@@ -1066,3 +1081,16 @@ class DatabricksNotebookOperator(BaseOperator):
         self.launch_notebook_job()
         if self.wait_for_termination:
             self.monitor_databricks_job()
+
+    def execute_complete(self, context: dict | None, event: dict) -> None:
+        run_state = RunState.from_json(event["run_state"])
+        if run_state.life_cycle_state != "TERMINATED":
+            raise AirflowException(
+                f"Databricks job failed with state {run_state.life_cycle_state}. "
+                f"Message: {run_state.state_message}"
+            )
+        if not run_state.is_successful:
+            raise AirflowException(
+                f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}"
+            )
+        self.log.info("Task succeeded. Final state %s.", run_state.result_state)
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index d20202fdca..55845fc6f7 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -48,6 +48,7 @@ class DatabricksExecutionTrigger(BaseTrigger):
         retry_args: dict[Any, Any] | None = None,
         run_page_url: str | None = None,
         repair_run: bool = False,
+        caller: str = "DatabricksExecutionTrigger",
     ) -> None:
         super().__init__()
         self.run_id = run_id
@@ -63,6 +64,7 @@ class DatabricksExecutionTrigger(BaseTrigger):
             retry_limit=self.retry_limit,
             retry_delay=self.retry_delay,
             retry_args=retry_args,
+            caller=caller,
         )
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 64b9ba985c..d6e7eb3892 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1865,6 +1865,50 @@ class TestDatabricksNotebookOperator:
         operator.launch_notebook_job.assert_called_once()
         operator.monitor_databricks_job.assert_not_called()
 
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}}
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(TaskDeferred) as exec_info:
+            operator.monitor_databricks_job()
+        assert isinstance(
+            exec_info.value.trigger, DatabricksExecutionTrigger
+        ), "Trigger is not a DatabricksExecutionTrigger"
+        assert exec_info.value.method_name == "execute_complete"
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable_early_termination(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {
+            "state": {
+                "life_cycle_state": "TERMINATED",
+                "result_state": "FAILED",
+                "state_message": "FAILURE",
+            }
+        }
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(AirflowException) as exec_info:
+            operator.monitor_databricks_job()
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exec_info.value)
+
     
    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
    def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
         mock_databricks_hook.return_value.get_run.return_value = {
@@ -1896,10 +1940,10 @@ class TestDatabricksNotebookOperator:
 
         operator.databricks_run_id = 12345
 
-        exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
         with pytest.raises(AirflowException) as exc_info:
             operator.monitor_databricks_job()
-        assert exception_message in str(exc_info.value)
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exc_info.value)
 
     
    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_launch_notebook_job(self, mock_databricks_hook):

Reply via email to