This is an automated email from the ASF dual-hosted git repository.

pankajkoti pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 1e4663f34c  add deferrable support to `DatabricksNotebookOperator` (#39295)
1e4663f34c is described below

commit 1e4663f34c2fb42b87cf75e4776650620eb2baa4
Author: Kalyan <kalyan.be...@live.com>
AuthorDate: Tue May 14 19:48:17 2024 +0530

    add deferrable support to `DatabricksNotebookOperator` (#39295)

    related: #39178

    This PR intends to make DatabricksNotebookOperator deferrable
---
 .../providers/databricks/hooks/databricks_base.py |  1 +
 .../providers/databricks/operators/databricks.py  | 40 +++++++++++++++---
 .../providers/databricks/triggers/databricks.py   |  2 +
 .../databricks/operators/test_databricks.py       | 48 +++++++++++++++++++++-
 4 files changed, 83 insertions(+), 8 deletions(-)
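For context, `deferrable=True` makes the operator submit the notebook run and then suspend itself, handing polling off to a trigger in the triggerer process instead of holding a worker slot. A minimal usage sketch (the DAG id, schedule, notebook path, and connection id are placeholders, and cluster configuration is omitted):

    # Hypothetical DAG; all identifiers below are illustrative.
    import pendulum

    from airflow import DAG
    from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

    with DAG(
        dag_id="example_databricks_notebook",
        start_date=pendulum.datetime(2024, 1, 1, tz="UTC"),
        schedule=None,
    ):
        run_notebook = DatabricksNotebookOperator(
            task_id="run_notebook",
            notebook_path="/Shared/example_notebook",  # placeholder workspace path
            source="WORKSPACE",
            databricks_conn_id="databricks_default",
            wait_for_termination=True,
            # Free the worker slot while the Databricks run is in flight.
            deferrable=True,
        )

Because the default is `conf.getboolean("operators", "default_deferrable", fallback=False)`, deployments can also opt in globally via the `operators.default_deferrable` config option instead of setting the flag per task.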
""" template_fields = ("notebook_params",) + CALLER = "DatabricksNotebookOperator" def __init__( self, @@ -942,6 +944,7 @@ class DatabricksNotebookOperator(BaseOperator): databricks_retry_args: dict[Any, Any] | None = None, wait_for_termination: bool = True, databricks_conn_id: str = "databricks_default", + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), **kwargs: Any, ): self.notebook_path = notebook_path @@ -958,11 +961,12 @@ class DatabricksNotebookOperator(BaseOperator): self.wait_for_termination = wait_for_termination self.databricks_conn_id = databricks_conn_id self.databricks_run_id: int | None = None + self.deferrable = deferrable super().__init__(**kwargs) @cached_property def _hook(self) -> DatabricksHook: - return self._get_hook(caller="DatabricksNotebookOperator") + return self._get_hook(caller=self.CALLER) def _get_hook(self, caller: str) -> DatabricksHook: return DatabricksHook( @@ -970,7 +974,7 @@ class DatabricksNotebookOperator(BaseOperator): retry_limit=self.databricks_retry_limit, retry_delay=self.databricks_retry_delay, retry_args=self.databricks_retry_args, - caller=caller, + caller=self.CALLER, ) def _get_task_timeout_seconds(self) -> int: @@ -1041,6 +1045,19 @@ class DatabricksNotebookOperator(BaseOperator): run = self._hook.get_run(self.databricks_run_id) run_state = RunState(**run["state"]) self.log.info("Current state of the job: %s", run_state.life_cycle_state) + if self.deferrable and not run_state.is_terminal: + return self.defer( + trigger=DatabricksExecutionTrigger( + run_id=self.databricks_run_id, + databricks_conn_id=self.databricks_conn_id, + polling_period_seconds=self.polling_period_seconds, + retry_limit=self.databricks_retry_limit, + retry_delay=self.databricks_retry_delay, + retry_args=self.databricks_retry_args, + caller=self.CALLER, + ), + method_name=DEFER_METHOD_NAME, + ) while not run_state.is_terminal: time.sleep(self.polling_period_seconds) run = self._hook.get_run(self.databricks_run_id) @@ -1056,9 +1073,7 @@ class DatabricksNotebookOperator(BaseOperator): ) if not run_state.is_successful: raise AirflowException( - "Task failed. Final state %s. Reason: %s", - run_state.result_state, - run_state.state_message, + f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}" ) self.log.info("Task succeeded. Final state %s.", run_state.result_state) @@ -1066,3 +1081,16 @@ class DatabricksNotebookOperator(BaseOperator): self.launch_notebook_job() if self.wait_for_termination: self.monitor_databricks_job() + + def execute_complete(self, context: dict | None, event: dict) -> None: + run_state = RunState.from_json(event["run_state"]) + if run_state.life_cycle_state != "TERMINATED": + raise AirflowException( + f"Databricks job failed with state {run_state.life_cycle_state}. " + f"Message: {run_state.state_message}" + ) + if not run_state.is_successful: + raise AirflowException( + f"Task failed. Final state {run_state.result_state}. Reason: {run_state.state_message}" + ) + self.log.info("Task succeeded. 
diff --git a/airflow/providers/databricks/triggers/databricks.py b/airflow/providers/databricks/triggers/databricks.py
index d20202fdca..55845fc6f7 100644
--- a/airflow/providers/databricks/triggers/databricks.py
+++ b/airflow/providers/databricks/triggers/databricks.py
@@ -48,6 +48,7 @@ class DatabricksExecutionTrigger(BaseTrigger):
         retry_args: dict[Any, Any] | None = None,
         run_page_url: str | None = None,
         repair_run: bool = False,
+        caller: str = "DatabricksExecutionTrigger",
     ) -> None:
         super().__init__()
         self.run_id = run_id
@@ -63,6 +64,7 @@
             retry_limit=self.retry_limit,
             retry_delay=self.retry_delay,
             retry_args=retry_args,
+            caller=caller,
         )
 
     def serialize(self) -> tuple[str, dict[str, Any]]:
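The trigger change simply threads the new `caller` argument through to its hook. A sketch of the hand-off, mirroring the `defer()` call in the operator above (the run id and numeric settings are placeholders):

    from airflow.providers.databricks.triggers.databricks import DatabricksExecutionTrigger

    trigger = DatabricksExecutionTrigger(
        run_id=12345,  # placeholder run id
        databricks_conn_id="databricks_default",
        polling_period_seconds=30,
        retry_limit=3,
        retry_delay=1,
        retry_args=None,
        caller="DatabricksNotebookOperator",  # attributes hook calls to the calling operator
    )
    # serialize() returns the classpath plus the kwargs the triggerer uses to
    # re-instantiate the trigger in its own process.
    classpath, kwargs = trigger.serialize()

Note that, per the diffstat, only these two lines change in the triggers module; `serialize()` itself is untouched, so a trigger rehydrated on the triggerer side would fall back to the default `caller`.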
diff --git a/tests/providers/databricks/operators/test_databricks.py b/tests/providers/databricks/operators/test_databricks.py
index 64b9ba985c..d6e7eb3892 100644
--- a/tests/providers/databricks/operators/test_databricks.py
+++ b/tests/providers/databricks/operators/test_databricks.py
@@ -1865,6 +1865,50 @@
         operator.launch_notebook_job.assert_called_once()
         operator.monitor_databricks_job.assert_not_called()
 
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {"state": {"life_cycle_state": "PENDING"}}
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(TaskDeferred) as exec_info:
+            operator.monitor_databricks_job()
+        assert isinstance(
+            exec_info.value.trigger, DatabricksExecutionTrigger
+        ), "Trigger is not a DatabricksExecutionTrigger"
+        assert exec_info.value.method_name == "execute_complete"
+
+    @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
+    def test_execute_with_deferrable_early_termination(self, mock_databricks_hook):
+        mock_databricks_hook.return_value.get_run.return_value = {
+            "state": {
+                "life_cycle_state": "TERMINATED",
+                "result_state": "FAILED",
+                "state_message": "FAILURE",
+            }
+        }
+        operator = DatabricksNotebookOperator(
+            task_id="test_task",
+            notebook_path="test_path",
+            source="test_source",
+            databricks_conn_id="test_conn_id",
+            wait_for_termination=True,
+            deferrable=True,
+        )
+        operator.databricks_run_id = 12345
+
+        with pytest.raises(AirflowException) as exec_info:
+            operator.monitor_databricks_job()
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exec_info.value)
+
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_monitor_databricks_job_successful_raises_no_exception(self, mock_databricks_hook):
         mock_databricks_hook.return_value.get_run.return_value = {
@@ -1896,10 +1940,10 @@
 
         operator.databricks_run_id = 12345
 
-        exception_message = "'Task failed. Final state %s. Reason: %s', 'FAILED', 'FAILURE'"
         with pytest.raises(AirflowException) as exc_info:
             operator.monitor_databricks_job()
-        assert exception_message in str(exc_info.value)
+        exception_message = "Task failed. Final state FAILED. Reason: FAILURE"
+        assert exception_message == str(exc_info.value)
 
     @mock.patch("airflow.providers.databricks.operators.databricks.DatabricksHook")
     def test_launch_notebook_job(self, mock_databricks_hook):
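The new tests cover the deferral and early-termination paths of `monitor_databricks_job`; the hunks above do not exercise `execute_complete` directly. A hypothetical follow-up test (not part of this commit), reusing the shape of the tests above:

    import pytest

    from airflow.exceptions import AirflowException
    from airflow.providers.databricks.operators.databricks import DatabricksNotebookOperator

    def test_execute_complete_failure_raises():
        operator = DatabricksNotebookOperator(
            task_id="test_task",
            notebook_path="test_path",
            source="test_source",
            databricks_conn_id="test_conn_id",
            deferrable=True,
        )
        # Event shape assumed from execute_complete(): a JSON-serialized RunState.
        event = {
            "run_state": '{"life_cycle_state": "TERMINATED", "result_state": "FAILED",'
            ' "state_message": "FAILURE"}'
        }
        with pytest.raises(AirflowException) as exc_info:
            operator.execute_complete(context=None, event=event)
        assert str(exc_info.value) == "Task failed. Final state FAILED. Reason: FAILURE"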