This is an automated email from the ASF dual-hosted git repository.

pankaj pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 574102fd29 [FEAT] adds repair run functionality for databricks (#36601)
574102fd29 is described below

commit 574102fd291930ed45262a40fb7033a122152541
Author: gaurav7261 <142777151+gaurav7...@users.noreply.github.com>
AuthorDate: Thu Jan 11 22:24:47 2024 +0530

    [FEAT] adds repair run functionality for databricks (#36601)
    
    * [FEAT] adds repair run functionality for databricks
    
    * [FIX] added latest repair run and test cases
    
    * [FIX] comma typo
    
    * [FIX] check for DatabricksRunNowOperator instance before doing repair run
    
    * [FIX] fixed static checks
    
    * [FIX] fixed static checks
    
    * Update airflow/providers/databricks/hooks/databricks.py
    
    Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is>
    
    * [FIX]  type annotations
    
    * [FIX] change from log.warn to log.warning
    
    * Update airflow/providers/databricks/operators/databricks.py
    
    Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is>
    
    * [FIX] CI Static check
    
    ---------
    
    Co-authored-by: GauravM <gau...@ip-192-168-0-100.ap-south-1.compute.internal>
    Co-authored-by: GauravM <gau...@ip-192-168-0-101.ap-south-1.compute.internal>
    Co-authored-by: GauravM <gau...@ip-10-20-1-171.ap-south-1.compute.internal>
    Co-authored-by: Andrey Anshin <andrey.ans...@taragol.is>
---
 airflow/providers/databricks/hooks/databricks.py   | 15 ++++-
 .../providers/databricks/operators/databricks.py   | 17 +++++
 .../providers/databricks/hooks/test_databricks.py  | 73 ++++++++++++++++++++++
 3 files changed, 103 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/databricks/hooks/databricks.py b/airflow/providers/databricks/hooks/databricks.py
index b39e3d622c..bc3bd90209 100644
--- a/airflow/providers/databricks/hooks/databricks.py
+++ b/airflow/providers/databricks/hooks/databricks.py
@@ -519,13 +519,24 @@ class DatabricksHook(BaseDatabricksHook):
         json = {"run_id": run_id}
         self._do_api_call(DELETE_RUN_ENDPOINT, json)
 
-    def repair_run(self, json: dict) -> None:
+    def repair_run(self, json: dict) -> int:
         """
         Re-run one or more tasks.
 
         :param json: repair a job run.
         """
-        self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+        response = self._do_api_call(REPAIR_RUN_ENDPOINT, json)
+        return response["repair_id"]
+
+    def get_latest_repair_id(self, run_id: int) -> int | None:
+        """Get latest repair id if any exist for run_id else None."""
+        json = {"run_id": run_id, "include_history": True}
+        response = self._do_api_call(GET_RUN_ENDPOINT, json)
+        repair_history = response["repair_history"]
+        if len(repair_history) == 1:
+            return None
+        else:
+            return repair_history[-1]["id"]
 
     def get_cluster_state(self, cluster_id: str) -> ClusterState:
         """
diff --git a/airflow/providers/databricks/operators/databricks.py b/airflow/providers/databricks/operators/databricks.py
index edea8b4e59..5d8b62643f 100644
--- a/airflow/providers/databricks/operators/databricks.py
+++ b/airflow/providers/databricks/operators/databricks.py
@@ -88,6 +88,19 @@ def _handle_databricks_operator_execution(operator, hook, log, context) -> None:
                             f"{operator.task_id} failed with terminal state: 
{run_state} "
                             f"and with the error {run_state.state_message}"
                         )
+                    if isinstance(operator, DatabricksRunNowOperator) and 
operator.repair_run:
+                        operator.repair_run = False
+                        log.warning(
+                            "%s but since repair run is set, repairing the run 
with all failed tasks",
+                            error_message,
+                        )
+
+                        latest_repair_id = 
hook.get_latest_repair_id(operator.run_id)
+                        repair_json = {"run_id": operator.run_id, 
"rerun_all_failed_tasks": True}
+                        if latest_repair_id is not None:
+                            repair_json["latest_repair_id"] = latest_repair_id
+                        operator.json["latest_repair_id"] = 
hook.repair_run(operator, repair_json)
+                        _handle_databricks_operator_execution(operator, hook, 
log, context)
                     raise AirflowException(error_message)
 
             else:
@@ -623,6 +636,7 @@ class DatabricksRunNowOperator(BaseOperator):
         - ``jar_params``
         - ``spark_submit_params``
         - ``idempotency_token``
+        - ``repair_run``
 
     :param job_id: the job_id of the existing Databricks job.
         This field will be templated.
@@ -711,6 +725,7 @@ class DatabricksRunNowOperator(BaseOperator):
     :param do_xcom_push: Whether we should push run_id and run_page_url to xcom.
     :param wait_for_termination: if we should wait for termination of the job run. ``True`` by default.
     :param deferrable: Run operator in the deferrable mode.
+    :param repair_run: Repair the Databricks run in case of failure. Doesn't work in deferrable mode.
     """
 
     # Used in airflow.models.BaseOperator
@@ -741,6 +756,7 @@ class DatabricksRunNowOperator(BaseOperator):
         do_xcom_push: bool = True,
         wait_for_termination: bool = True,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        repair_run: bool = False,
         **kwargs,
     ) -> None:
         """Create a new ``DatabricksRunNowOperator``."""
@@ -753,6 +769,7 @@ class DatabricksRunNowOperator(BaseOperator):
         self.databricks_retry_args = databricks_retry_args
         self.wait_for_termination = wait_for_termination
         self.deferrable = deferrable
+        self.repair_run = repair_run
 
         if job_id is not None:
             self.json["job_id"] = job_id
diff --git a/tests/providers/databricks/hooks/test_databricks.py b/tests/providers/databricks/hooks/test_databricks.py
index 1baaab1fea..c9004e7175 100644
--- a/tests/providers/databricks/hooks/test_databricks.py
+++ b/tests/providers/databricks/hooks/test_databricks.py
@@ -683,6 +683,79 @@ class TestDatabricksHook:
             timeout=self.hook.timeout_seconds,
         )
 
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_negative_get_latest_repair_id(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "job_id": JOB_ID,
+            "run_id": RUN_ID,
+            "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+            "repair_history": [  # only the ORIGINAL attempt: the run was never repaired
+                {
+                    "type": "ORIGINAL",
+                    "start_time": 1704528798059,
+                    "end_time": 1704529026679,
+                    "state": {
+                        "life_cycle_state": "RUNNING",
+                        "result_state": "RUNNING",
+                        "state_message": "dummy",
+                        "user_cancelled_or_timedout": "false",
+                    },
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                }
+            ],
+        }
+        latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+        assert latest_repair_id is None  # no REPAIR entry -> no repair id
+
+    @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
+    def test_positive_get_latest_repair_id(self, mock_requests):
+        mock_requests.codes.ok = 200
+        mock_requests.get.return_value.json.return_value = {
+            "job_id": JOB_ID,
+            "run_id": RUN_ID,
+            "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+            "repair_history": [  # ORIGINAL attempt followed by two REPAIR entries, newest last
+                {
+                    "type": "ORIGINAL",
+                    "start_time": 1704528798059,
+                    "end_time": 1704529026679,
+                    "state": {
+                        "life_cycle_state": "TERMINATED",
+                        "result_state": "CANCELED",
+                        "state_message": "dummy_original",
+                        "user_cancelled_or_timedout": "false",
+                    },
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+                {
+                    "type": "REPAIR",
+                    "start_time": 1704530276423,
+                    "end_time": 1704530363736,
+                    "state": {
+                        "life_cycle_state": "TERMINATED",
+                        "result_state": "CANCELED",
+                        "state_message": "dummy_repair_1",
+                        "user_cancelled_or_timedout": "true",
+                    },
+                    "id": 108607572123234,
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+                {
+                    "type": "REPAIR",
+                    "start_time": 1704531464690,
+                    "end_time": 1704531481590,
+                    "state": {"life_cycle_state": "RUNNING", "result_state": "RUNNING"},
+                    "id": 52532060060836,
+                    "task_run_ids": [396529700633015, 1111270934390307],
+                },
+            ],
+        }
+        latest_repair_id = self.hook.get_latest_repair_id(RUN_ID)
+
+        assert latest_repair_id == 52532060060836  # "id" of the last REPAIR entry
+
     @mock.patch("airflow.providers.databricks.hooks.databricks_base.requests")
     def test_get_cluster_state(self, mock_requests):
         """

Reply via email to