This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 28006a524c0 Add crash recovery ability to SparkSubmitOperator against 
Kubernetes (#68067)
28006a524c0 is described below

commit 28006a524c02eb261f9c0a1c57229e470f318957
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Jun 15 13:27:15 2026 +0530

    Add crash recovery ability to SparkSubmitOperator against Kubernetes 
(#68067)
---
 providers/apache/spark/docs/operators.rst          |   7 +-
 .../providers/apache/spark/hooks/spark_submit.py   |  31 +++-
 .../apache/spark/operators/spark_submit.py         |  81 +++++++--
 .../apache/spark/operators/test_spark_submit.py    | 189 +++++++++++++++++++++
 4 files changed, 287 insertions(+), 21 deletions(-)

diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index d4520578420..4d3a9a526af 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -236,6 +236,7 @@ Python Kubernetes client rather than holding 
``spark-submit`` open for the full
        conn_id="spark_k8s",
        deploy_mode="cluster",
        track_driver_via_k8s_api=True,
+       reconnect_on_retry=True,
    )
 
 **Requirements**
@@ -245,8 +246,10 @@ Python Kubernetes client rather than holding 
``spark-submit`` open for the full
   conflicts with the flag and a ``ValueError`` will be raised at task start.
 * The Airflow worker must be able to reach the Kubernetes API server and have 
permission to
   read and delete pods in the driver's namespace; otherwise pod tracking and 
cleanup will fail.
-* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh 
driver instead of
-  reconnecting to an existing one. Set ``execution_timeout`` to bound 
wall-clock time.
+* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the 
driver pod name is
+  persisted to task state before polling begins, so a worker crash and retry 
reconnects to the
+  existing pod instead of submitting a fresh one. Set 
``reconnect_on_retry=False`` to always
+  submit a fresh driver on retry.
 * Pod completion is detected from ``pod.status.phase``. If your driver pods 
have sidecar
   containers (e.g. Istio injection enabled for the driver namespace), the pod 
phase may not
   advance to ``Succeeded`` until all sidecars exit. In that case the poll loop 
will wait
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 3a19950696a..662966e4e14 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -1138,8 +1138,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                         f"returncode = {returncode}"
                     )
 
-    def _poll_k8s_driver_via_api(self) -> None:
-        """Poll the K8s driver pod phase until it reaches a terminal state."""
+    def _poll_k8s_driver_via_api(self) -> str | None:
+        """
+        Poll the K8s driver pod phase until it reaches a terminal state.
+
+        Returns the terminal phase string (e.g. ``"Succeeded"``) on normal 
completion,
+        or ``None`` if the pod vanished mid-poll (404 — likely deleted by 
``on_kill``).
+        Raises ``RuntimeError`` on failure phases or unrecoverable API errors.
+        """
         pod_name = self._kubernetes_driver_pod
         namespace = self._connection["namespace"]
         app_id = self._kubernetes_application_id or pod_name
@@ -1173,7 +1179,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                             "Driver pod %s not found (404); pod was likely 
deleted by on_kill. Exiting poll loop.",
                             pod_name,
                         )
-                        return
+                        return None
                     consecutive_api_errors += 1
                     self.log.warning(
                         "ApiException polling pod %s (%d/%d): %s",
@@ -1193,6 +1199,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 phase = pod.status.phase or "Initializing"
                 self.log.info("Application status for %s (phase: %s)", app_id, 
phase)
                 if phase == "Succeeded":
+                    if pod.status.container_statuses:
+                        cs = pod.status.container_statuses[0]
+                        if cs.state and cs.state.terminated:
+                            t = cs.state.terminated
+                            self.log.info(
+                                "Container final status: exit_code=%s 
reason=%s started_at=%s finished_at=%s",
+                                t.exit_code,
+                                t.reason,
+                                t.started_at,
+                                t.finished_at,
+                            )
+                    terminal_phase = phase
                     break
                 if phase == "Failed":
                     container_state = ""
@@ -1224,7 +1242,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 else:
                     consecutive_unknown = 0
                 time.sleep(poll_interval)
-            self._delete_driver_pod()
+            # Pod deletion is best-effort cleanup. If it fails (e.g. already 
garbage collected or RBAC
+            # denied), suppress the error so terminal_phase is still returned 
and the task
+            # succeeds. Raising here would skip the task_store write and force 
an unnecessary retry.
+            with contextlib.suppress(Exception):
+                self._delete_driver_pod()
+            return terminal_phase
         finally:
             self._run_post_submit_commands()
 
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index e2b6067d42c..7ceb95b387a 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -23,13 +23,18 @@ from typing import TYPE_CHECKING, Any, cast
 import requests
 from tenacity import retry, stop_after_attempt, wait_fixed
 
-from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.apache.spark.hooks.spark_submit import 
_K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
 from airflow.providers.common.compat.openlineage.utils.spark import (
     inject_parent_job_information_into_spark_properties,
     inject_transport_information_into_spark_properties,
 )
 from airflow.providers.common.compat.sdk import BaseOperator, conf
 
+try:
+    from airflow.providers.cncf.kubernetes import kube_client
+except ImportError:
+    kube_client = None  # type: ignore[assignment]
+
 try:
     from airflow.sdk import ResumableJobMixin
 except ImportError:
@@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
     # YARN application ID, K8s driver pod name).
     external_id_key = "spark_job_id"
 
+    # Used only for k8s cluster mode. Caches the terminal pod phase (currently only 
"Succeeded") to task_store at the end of
+    # poll_until_complete. On retry, get_job_status reads this before querying 
the K8s API
+    # so that a completed job can be identified even after the driver pod is 
garbage collected.
+    _K8S_DRIVER_STATUS_KEY = "k8s_driver_status"
+
     template_fields: Sequence[str] = (
         "application",
         "conf",
@@ -269,11 +279,12 @@ class SparkSubmitOperator(ResumableJobMixin, 
BaseOperator):
             self.poll_until_complete(driver_id, context)
             return self.get_job_result(driver_id, context)
         if hook._should_track_driver_via_k8s_api():
-            # TODO: Wire into execute_resumable() via ResumableJobMixin
-            # (fill submit_job / poll_until_complete K8s stubs) to enable 
crash recovery.
-            hook.submit(self.application)
-            hook._poll_k8s_driver_via_api()
-            return
+            if self.reconnect_on_retry:
+                return self.execute_resumable(context)
+            # reconnect_on_retry=False: still submit-and-poll, just skip 
task_state persistence.
+            driver_id = self.submit_job(context)
+            self.poll_until_complete(driver_id, context)
+            return self.get_job_result(driver_id, context)
         if hook._is_yarn_cluster_mode:
             if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
                 raise ValueError(
@@ -290,9 +301,19 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                 return self.get_job_result(driver_id, context)
         hook.submit(self.application)
 
-    def submit_job(self, context: Context) -> str:
+    def submit_job(self, context: Context) -> str | None:
         if self._hook is None:
             self._hook = self._get_hook()
+        if self._hook._is_kubernetes:
+            self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false"
+            self._hook.submit(self.application)
+            pod_name = self._hook._kubernetes_driver_pod
+            namespace = self._hook._connection["namespace"]
+            if not pod_name:
+                raise RuntimeError("spark-submit did not capture a K8s driver 
pod name")
+            external_id = f"{namespace}:{pod_name}"
+            self.log.info("Spark K8s driver pod submitted: %s", external_id)
+            return external_id
         if self._hook._is_yarn_cluster_mode:
             if self._hook._conf.get("spark.yarn.submit.waitAppCompletion", 
"").strip().lower() == "true":
                 raise ValueError(
@@ -321,12 +342,24 @@ class SparkSubmitOperator(ResumableJobMixin, 
BaseOperator):
         if self._hook._is_yarn_cluster_mode:
             return self._hook.query_yarn_application_status(external_id)
         if self._hook._is_kubernetes:
-            # The K8s branches below (and in is_job_active, is_job_succeeded, 
poll_until_complete)
-            # are currently unreachable: execute_resumable is only called when 
_should_track_driver_status
-            # is True, which requires spark:// + cluster mode. They are 
scaffolding for a follow-up PR
-            # that extends ResumableJobMixin support to Kubernetes.
-            # TODO: call K8s pod status API
-            raise NotImplementedError("K8s job status not yet implemented")
+            if (task_state_store := context.get("task_state_store")) is not 
None:
+                if (cached := 
task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None:
+                    if not isinstance(cached, str):
+                        raise ValueError(f"Cached K8s driver status is not a 
string: {cached!r}")
+                    return cached
+            if kube_client is None:
+                raise RuntimeError(
+                    "apache-airflow-providers-cncf-kubernetes is required to 
query K8s pod status"
+                )
+            namespace, pod_name = self._parse_k8s_external_id(external_id)
+            try:
+                client = kube_client.get_kube_client()
+                pod = client.read_namespaced_pod(pod_name, namespace)
+                return pod.status.phase or "Pending"
+            except kube_client.ApiException as e:
+                if e.status == 404:
+                    return "NotFound"
+                raise
         scheme = self._hook._connection.get("rest_scheme", "http")
         rest_port = self._hook._connection.get("rest_port", 6066)
         # HA master URLs can look like spark://m1:7077,m2:7077 — try each host 
in order.
@@ -345,6 +378,14 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                 last_exc = e
         raise last_exc
 
+    @staticmethod
+    def _parse_k8s_external_id(external_id: str) -> tuple[str, str]:
+        """Parse a K8s external ID of the form 'namespace:pod_name' into its 
components."""
+        parts = external_id.split(":", 1)
+        if len(parts) != 2:
+            raise ValueError(f"Invalid K8s external ID format {external_id!r}; 
expected 'namespace:pod_name'")
+        return parts[0], parts[1]
+
     @retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
     def _fetch_driver_status(self, url: str, external_id: str) -> str:
         response = requests.get(url, timeout=30)
@@ -397,8 +438,18 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                 self._hook._run_post_submit_commands()
             return
         if self._hook._is_kubernetes:
-            # TODO: poll K8s pod phase until terminal
-            raise NotImplementedError("K8s poll not yet implemented")
+            if external_id is not None:
+                _, pod_name = self._parse_k8s_external_id(external_id)
+                self._hook._kubernetes_driver_pod = pod_name
+            terminal_phase = self._hook._poll_k8s_driver_via_api()
+            # Cache only when the pod actually reached Succeeded. The 
404/vanished path
+            # returns None (e.g. pod deleted by on_kill, or garbage 
collected after failure)
+            # and must not be cached; otherwise a retry would see "Succeeded" 
and skip resubmission.
+            if terminal_phase == "Succeeded" and self.reconnect_on_retry:
+                if (task_state_store := context.get("task_state_store")) is 
not None:
+                    task_state_store.set(self._K8S_DRIVER_STATUS_KEY, 
"Succeeded")
+            return
+
         self.log.info("Polling driver %s until completion", external_id)
         self._hook._driver_id = external_id
         try:
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 9aa2c206406..fa41b0af446 100644
--- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -322,6 +322,7 @@ class TestSparkSubmitOperator:
             **self._config,
         )
         mock_get_hook.return_value._should_track_driver_status = False
+        
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
         operator.execute(MagicMock())
 
         assert operator.conf == {
@@ -389,6 +390,7 @@ class TestSparkSubmitOperator:
             **self._config,
         )
         mock_get_hook.return_value._should_track_driver_status = False
+        
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
         operator.execute({"ti": mock_ti})
 
         assert operator.conf == {
@@ -428,6 +430,7 @@ class TestSparkSubmitOperator:
         )
 
         mock_get_hook.return_value._should_track_driver_status = False
+        
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
         with caplog.at_level(logging.INFO):
             operator = SparkSubmitOperator(
                 task_id="spark_submit_job",
@@ -460,6 +463,7 @@ class TestSparkSubmitOperator:
         )
 
         mock_get_hook.return_value._should_track_driver_status = False
+        
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
         with caplog.at_level(logging.INFO):
             operator = SparkSubmitOperator(
                 task_id="spark_submit_job",
@@ -892,7 +896,10 @@ class TestSparkSubmitOperatorK8sTracking:
         hook = MagicMock()
         hook._should_track_driver_status = False
         hook._should_track_driver_via_k8s_api.return_value = True
+        hook._is_kubernetes = True
+        hook._is_yarn = False
         hook._is_yarn_cluster_mode = False
+        hook._conf = {}
         return hook
 
     def test_execute_calls_submit_then_poll_when_flag_set(self):
@@ -920,3 +927,185 @@ class TestSparkSubmitOperatorK8sTracking:
 
         hook.submit.assert_called_once_with("test.jar")
         hook._poll_k8s_driver_via_api.assert_not_called()
+
+    def test_k8s_submit_job_returns_encoded_external_id(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        hook._kubernetes_driver_pod = "spark-abc-driver"
+        hook._connection = {"namespace": "mynamespace"}
+        operator._hook = hook
+
+        result = operator.submit_job(context={})
+
+        assert result == "mynamespace:spark-abc-driver"
+        assert hook._conf.get("spark.kubernetes.submission.waitAppCompletion") 
== "false"
+        hook.submit.assert_called_once_with("test.jar")
+
+    def test_k8s_submit_job_raises_when_pod_name_missing(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        hook._kubernetes_driver_pod = None
+        hook._connection = {"namespace": "mynamespace"}
+        operator._hook = hook
+
+        with pytest.raises(RuntimeError, match="did not capture a K8s driver 
pod name"):
+            operator.submit_job(context={})
+
+    def test_k8s_get_job_status_returns_k8s_driver_status(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+        task_store = FakeTaskState({"k8s_driver_status": "Succeeded"})
+
+        with 
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") 
as mock_kube:
+            result = operator.get_job_status("mynamespace:spark-abc-driver", 
{"task_state_store": task_store})
+
+        assert result == "Succeeded"
+        mock_kube.get_kube_client.assert_not_called()
+
+    def 
test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+        task_store = FakeTaskState()
+
+        mock_pod = MagicMock()
+        mock_pod.status.phase = "Running"
+
+        with 
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") 
as mock_kube:
+            
mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = 
mock_pod
+            result = operator.get_job_status("mynamespace:spark-abc-driver", 
{"task_state_store": task_store})
+
+        assert result == "Running"
+
+    def test_k8s_get_job_status_returns_pending_when_phase_is_none(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+
+        mock_pod = MagicMock()
+        mock_pod.status.phase = None
+
+        with 
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") 
as mock_kube:
+            
mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value = 
mock_pod
+            result = operator.get_job_status("mynamespace:spark-abc-driver", 
{})
+
+        assert result == "Pending"
+
+    def test_k8s_get_job_status_returns_not_found_on_404(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+
+        class FakeApiException(Exception):
+            def __init__(self, status):
+                self.status = status
+
+        with 
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") 
as mock_kube:
+            mock_kube.ApiException = FakeApiException
+            
mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = 
FakeApiException(404)
+            result = operator.get_job_status("mynamespace:spark-abc-driver", 
{})
+
+        assert result == "NotFound"
+
+    def test_k8s_get_job_status_reraises_non_404_api_exception(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+
+        class FakeApiException(Exception):
+            def __init__(self, status):
+                self.status = status
+
+        with 
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client") 
as mock_kube:
+            mock_kube.ApiException = FakeApiException
+            
mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect = 
FakeApiException(500)
+            with pytest.raises(FakeApiException):
+                operator.get_job_status("mynamespace:spark-abc-driver", {})
+
+    def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        operator._hook = hook
+
+        operator.poll_until_complete("mynamespace:spark-abc-driver", {})
+
+        assert hook._kubernetes_driver_pod == "spark-abc-driver"
+        hook._poll_k8s_driver_via_api.assert_called_once()
+
+    def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        hook._poll_k8s_driver_via_api.return_value = "Succeeded"
+        operator._hook = hook
+        task_store = FakeTaskState()
+
+        operator.poll_until_complete("mynamespace:spark-abc-driver", 
{"task_state_store": task_store})
+
+        assert task_store.get("k8s_driver_status") == "Succeeded"
+
+    def 
test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=False)
+        hook = self._make_k8s_hook()
+        hook._poll_k8s_driver_via_api.return_value = "Succeeded"
+        operator._hook = hook
+        task_store = FakeTaskState()
+
+        operator.poll_until_complete("mynamespace:spark-abc-driver", 
{"task_state_store": task_store})
+
+        assert task_store.get("k8s_driver_status") is None
+
+    def 
test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        hook._poll_k8s_driver_via_api.side_effect = RuntimeError("Spark 
application failed (phase=Failed)")
+        operator._hook = hook
+        task_store = FakeTaskState()
+
+        with pytest.raises(RuntimeError, match="phase=Failed"):
+            operator.poll_until_complete("mynamespace:spark-abc-driver", 
{"task_state_store": task_store})
+
+        assert task_store.get("k8s_driver_status") is None
+
+    def test_k8s_poll_until_complete_tolerates_absent_task_store(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        operator._hook = self._make_k8s_hook()
+
+        operator.poll_until_complete("mynamespace:spark-abc-driver", {})
+
+    @pytest.mark.skipif(
+        not AIRFLOW_V_3_3_PLUS,
+        reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
+    )
+    def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self):
+        """execute() with reconnect_on_retry=True stores the pod ID in 
task_store before polling."""
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=True)
+        hook = self._make_k8s_hook()
+        hook._kubernetes_driver_pod = "spark-abc-driver"
+        hook._connection = {"namespace": "mynamespace"}
+        operator._hook = hook
+        task_store = FakeTaskState()
+        persisted_before_poll: list[str | None] = []
+
+        def track_poll(external_id, context):
+            persisted_before_poll.append(task_store.get("spark_job_id"))
+
+        operator.poll_until_complete = track_poll
+
+        operator.execute(context={"task_state_store": task_store})
+
+        assert persisted_before_poll == ["mynamespace:spark-abc-driver"]
+
+    @pytest.mark.skipif(
+        not AIRFLOW_V_3_3_PLUS,
+        reason="ResumableJobMixin reconnect requires task_state, available in 
Airflow 3.3+",
+    )
+    def 
test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self):
+        """execute() with reconnect_on_retry=False does not write spark_job_id 
to task_store."""
+        operator = self._make_operator(track_driver_via_k8s_api=True, 
reconnect_on_retry=False)
+        hook = self._make_k8s_hook()
+        hook._kubernetes_driver_pod = "spark-abc-driver"
+        hook._connection = {"namespace": "mynamespace"}
+        operator._hook = hook
+        task_store = FakeTaskState()
+
+        operator.poll_until_complete = lambda external_id, context: None
+
+        operator.execute(context={"task_state_store": task_store})
+
+        assert task_store.get("spark_job_id") is None

Reply via email to