This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 28006a524c0 Add crash recovery ability to SparkSubmitOperator against
Kubernetes (#68067)
28006a524c0 is described below
commit 28006a524c02eb261f9c0a1c57229e470f318957
Author: Amogh Desai <[email protected]>
AuthorDate: Mon Jun 15 13:27:15 2026 +0530
Add crash recovery ability to SparkSubmitOperator against Kubernetes
(#68067)
---
providers/apache/spark/docs/operators.rst | 7 +-
.../providers/apache/spark/hooks/spark_submit.py | 31 +++-
.../apache/spark/operators/spark_submit.py | 81 +++++++--
.../apache/spark/operators/test_spark_submit.py | 189 +++++++++++++++++++++
4 files changed, 287 insertions(+), 21 deletions(-)
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index d4520578420..4d3a9a526af 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -236,6 +236,7 @@ Python Kubernetes client rather than holding
``spark-submit`` open for the full
conn_id="spark_k8s",
deploy_mode="cluster",
track_driver_via_k8s_api=True,
+ reconnect_on_retry=True,
)
**Requirements**
@@ -245,8 +246,10 @@ Python Kubernetes client rather than holding
``spark-submit`` open for the full
conflicts with the flag and a ``ValueError`` will be raised at task start.
* The Airflow worker must be able to reach the Kubernetes API server and have
permission to
read and delete pods in the driver's namespace; otherwise pod tracking and
cleanup will fail.
-* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh
driver instead of
- reconnecting to an existing one. Set ``execution_timeout`` to bound
wall-clock time.
+* Set ``reconnect_on_retry=True`` (the default) to enable crash recovery: the
driver pod name is
+ persisted to task state before polling begins, so a worker crash and retry
reconnects to the
+ existing pod instead of submitting a fresh one. Set
``reconnect_on_retry=False`` to always
+ submit a fresh driver on retry.
* Pod completion is detected from ``pod.status.phase``. If your driver pods
have sidecar
containers (e.g. Istio injection enabled for the driver namespace), the pod
phase may not
advance to ``Succeeded`` until all sidecars exit. In that case the poll loop
will wait
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 3a19950696a..662966e4e14 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -1138,8 +1138,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
f"returncode = {returncode}"
)
- def _poll_k8s_driver_via_api(self) -> None:
- """Poll the K8s driver pod phase until it reaches a terminal state."""
+ def _poll_k8s_driver_via_api(self) -> str | None:
+ """
+ Poll the K8s driver pod phase until it reaches a terminal state.
+
+ Returns the terminal phase string (e.g. ``"Succeeded"``) on normal
completion,
+ or ``None`` if the pod vanished mid-poll (404 — likely deleted by
``on_kill``).
+ Raises ``RuntimeError`` on failure phases or unrecoverable API errors.
+ """
pod_name = self._kubernetes_driver_pod
namespace = self._connection["namespace"]
app_id = self._kubernetes_application_id or pod_name
@@ -1173,7 +1179,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"Driver pod %s not found (404); pod was likely
deleted by on_kill. Exiting poll loop.",
pod_name,
)
- return
+ return None
consecutive_api_errors += 1
self.log.warning(
"ApiException polling pod %s (%d/%d): %s",
@@ -1193,6 +1199,18 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
phase = pod.status.phase or "Initializing"
self.log.info("Application status for %s (phase: %s)", app_id,
phase)
if phase == "Succeeded":
+ if pod.status.container_statuses:
+ cs = pod.status.container_statuses[0]
+ if cs.state and cs.state.terminated:
+ t = cs.state.terminated
+ self.log.info(
+ "Container final status: exit_code=%s
reason=%s started_at=%s finished_at=%s",
+ t.exit_code,
+ t.reason,
+ t.started_at,
+ t.finished_at,
+ )
+ terminal_phase = phase
break
if phase == "Failed":
container_state = ""
@@ -1224,7 +1242,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
else:
consecutive_unknown = 0
time.sleep(poll_interval)
- self._delete_driver_pod()
+ # Pod deletion is best-effort cleanup. If it fails (e.g. already
garbage collected or RBAC
+ # denied), suppress the error so terminal_phase is still returned
and the task
+ # succeeds. Raising here would skip the task_store write and force
an unnecessary retry.
+ with contextlib.suppress(Exception):
+ self._delete_driver_pod()
+ return terminal_phase
finally:
self._run_post_submit_commands()
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index e2b6067d42c..7ceb95b387a 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -23,13 +23,18 @@ from typing import TYPE_CHECKING, Any, cast
import requests
from tenacity import retry, stop_after_attempt, wait_fixed
-from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.apache.spark.hooks.spark_submit import
_K8S_WAIT_APP_COMPLETION_CONF, SparkSubmitHook
from airflow.providers.common.compat.openlineage.utils.spark import (
inject_parent_job_information_into_spark_properties,
inject_transport_information_into_spark_properties,
)
from airflow.providers.common.compat.sdk import BaseOperator, conf
+try:
+ from airflow.providers.cncf.kubernetes import kube_client
+except ImportError:
+ kube_client = None # type: ignore[assignment]
+
try:
from airflow.sdk import ResumableJobMixin
except ImportError:
@@ -140,6 +145,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
# YARN application ID, K8s driver pod name).
external_id_key = "spark_job_id"
+ # Used only for k8s cluster mode. Caches the terminal pod phase to task_store
at the end of
+ # poll_until_complete (only "Succeeded" is written there). On retry,
get_job_status reads this before querying the K8s API
+ # so that a completed job can be identified even after the driver pod is
garbage collected.
+ _K8S_DRIVER_STATUS_KEY = "k8s_driver_status"
+
template_fields: Sequence[str] = (
"application",
"conf",
@@ -269,11 +279,12 @@ class SparkSubmitOperator(ResumableJobMixin,
BaseOperator):
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
if hook._should_track_driver_via_k8s_api():
- # TODO: Wire into execute_resumable() via ResumableJobMixin
- # (fill submit_job / poll_until_complete K8s stubs) to enable
crash recovery.
- hook.submit(self.application)
- hook._poll_k8s_driver_via_api()
- return
+ if self.reconnect_on_retry:
+ return self.execute_resumable(context)
+ # reconnect_on_retry=False: still submit-and-poll, just skip
task_state persistence.
+ driver_id = self.submit_job(context)
+ self.poll_until_complete(driver_id, context)
+ return self.get_job_result(driver_id, context)
if hook._is_yarn_cluster_mode:
if self.reconnect_on_retry and not hook._yarn_track_via_rm_api:
raise ValueError(
@@ -290,9 +301,19 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
return self.get_job_result(driver_id, context)
hook.submit(self.application)
- def submit_job(self, context: Context) -> str:
+ def submit_job(self, context: Context) -> str | None:
if self._hook is None:
self._hook = self._get_hook()
+ if self._hook._is_kubernetes:
+ self._hook._conf[_K8S_WAIT_APP_COMPLETION_CONF] = "false"
+ self._hook.submit(self.application)
+ pod_name = self._hook._kubernetes_driver_pod
+ namespace = self._hook._connection["namespace"]
+ if not pod_name:
+ raise RuntimeError("spark-submit did not capture a K8s driver
pod name")
+ external_id = f"{namespace}:{pod_name}"
+ self.log.info("Spark K8s driver pod submitted: %s", external_id)
+ return external_id
if self._hook._is_yarn_cluster_mode:
if self._hook._conf.get("spark.yarn.submit.waitAppCompletion",
"").strip().lower() == "true":
raise ValueError(
@@ -321,12 +342,24 @@ class SparkSubmitOperator(ResumableJobMixin,
BaseOperator):
if self._hook._is_yarn_cluster_mode:
return self._hook.query_yarn_application_status(external_id)
if self._hook._is_kubernetes:
- # The K8s branches below (and in is_job_active, is_job_succeeded,
poll_until_complete)
- # are currently unreachable: execute_resumable is only called when
_should_track_driver_status
- # is True, which requires spark:// + cluster mode. They are
scaffolding for a follow-up PR
- # that extends ResumableJobMixin support to Kubernetes.
- # TODO: call K8s pod status API
- raise NotImplementedError("K8s job status not yet implemented")
+ if (task_state_store := context.get("task_state_store")) is not
None:
+ if (cached :=
task_state_store.get(self._K8S_DRIVER_STATUS_KEY)) is not None:
+ if not isinstance(cached, str):
+ raise ValueError(f"Cached K8s driver status is not a
string: {cached!r}")
+ return cached
+ if kube_client is None:
+ raise RuntimeError(
+ "apache-airflow-providers-cncf-kubernetes is required to
query K8s pod status"
+ )
+ namespace, pod_name = self._parse_k8s_external_id(external_id)
+ try:
+ client = kube_client.get_kube_client()
+ pod = client.read_namespaced_pod(pod_name, namespace)
+ return pod.status.phase or "Pending"
+ except kube_client.ApiException as e:
+ if e.status == 404:
+ return "NotFound"
+ raise
scheme = self._hook._connection.get("rest_scheme", "http")
rest_port = self._hook._connection.get("rest_port", 6066)
# HA master URLs can look like spark://m1:7077,m2:7077 — try each host
in order.
@@ -345,6 +378,14 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
last_exc = e
raise last_exc
+ @staticmethod
+ def _parse_k8s_external_id(external_id: str) -> tuple[str, str]:
+ """Parse a K8s external ID of the form 'namespace:pod_name' into its
components."""
+ parts = external_id.split(":", 1)
+ if len(parts) != 2:
+ raise ValueError(f"Invalid K8s external ID format {external_id!r};
expected 'namespace:pod_name'")
+ return parts[0], parts[1]
+
@retry(stop=stop_after_attempt(3), wait=wait_fixed(1), reraise=True)
def _fetch_driver_status(self, url: str, external_id: str) -> str:
response = requests.get(url, timeout=30)
@@ -397,8 +438,18 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
self._hook._run_post_submit_commands()
return
if self._hook._is_kubernetes:
- # TODO: poll K8s pod phase until terminal
- raise NotImplementedError("K8s poll not yet implemented")
+ if external_id is not None:
+ _, pod_name = self._parse_k8s_external_id(external_id)
+ self._hook._kubernetes_driver_pod = pod_name
+ terminal_phase = self._hook._poll_k8s_driver_via_api()
+ # Cache only when the pod actually reached Succeeded. The
404/vanished path
+ # returns None (e.g. pod deleted by on_kill, or garbage
collected after failure)
+ # and must not be cached; otherwise a retry would see "Succeeded"
and skip resubmission.
+ if terminal_phase == "Succeeded" and self.reconnect_on_retry:
+ if (task_state_store := context.get("task_state_store")) is
not None:
+ task_state_store.set(self._K8S_DRIVER_STATUS_KEY,
"Succeeded")
+ return
+
self.log.info("Polling driver %s until completion", external_id)
self._hook._driver_id = external_id
try:
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 9aa2c206406..fa41b0af446 100644
---
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -322,6 +322,7 @@ class TestSparkSubmitOperator:
**self._config,
)
mock_get_hook.return_value._should_track_driver_status = False
+
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
operator.execute(MagicMock())
assert operator.conf == {
@@ -389,6 +390,7 @@ class TestSparkSubmitOperator:
**self._config,
)
mock_get_hook.return_value._should_track_driver_status = False
+
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
operator.execute({"ti": mock_ti})
assert operator.conf == {
@@ -428,6 +430,7 @@ class TestSparkSubmitOperator:
)
mock_get_hook.return_value._should_track_driver_status = False
+
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
@@ -460,6 +463,7 @@ class TestSparkSubmitOperator:
)
mock_get_hook.return_value._should_track_driver_status = False
+
mock_get_hook.return_value._should_track_driver_via_k8s_api.return_value = False
with caplog.at_level(logging.INFO):
operator = SparkSubmitOperator(
task_id="spark_submit_job",
@@ -892,7 +896,10 @@ class TestSparkSubmitOperatorK8sTracking:
hook = MagicMock()
hook._should_track_driver_status = False
hook._should_track_driver_via_k8s_api.return_value = True
+ hook._is_kubernetes = True
+ hook._is_yarn = False
hook._is_yarn_cluster_mode = False
+ hook._conf = {}
return hook
def test_execute_calls_submit_then_poll_when_flag_set(self):
@@ -920,3 +927,185 @@ class TestSparkSubmitOperatorK8sTracking:
hook.submit.assert_called_once_with("test.jar")
hook._poll_k8s_driver_via_api.assert_not_called()
+
+ def test_k8s_submit_job_returns_encoded_external_id(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ hook._kubernetes_driver_pod = "spark-abc-driver"
+ hook._connection = {"namespace": "mynamespace"}
+ operator._hook = hook
+
+ result = operator.submit_job(context={})
+
+ assert result == "mynamespace:spark-abc-driver"
+ assert hook._conf.get("spark.kubernetes.submission.waitAppCompletion")
== "false"
+ hook.submit.assert_called_once_with("test.jar")
+
+ def test_k8s_submit_job_raises_when_pod_name_missing(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ hook._kubernetes_driver_pod = None
+ hook._connection = {"namespace": "mynamespace"}
+ operator._hook = hook
+
+ with pytest.raises(RuntimeError, match="did not capture a K8s driver
pod name"):
+ operator.submit_job(context={})
+
+ def test_k8s_get_job_status_returns_k8s_driver_status(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+ task_store = FakeTaskState({"k8s_driver_status": "Succeeded"})
+
+ with
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client")
as mock_kube:
+ result = operator.get_job_status("mynamespace:spark-abc-driver",
{"task_state_store": task_store})
+
+ assert result == "Succeeded"
+ mock_kube.get_kube_client.assert_not_called()
+
+ def
test_k8s_get_job_status_queries_k8s_api_when_no_k8s_driver_status(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+ task_store = FakeTaskState()
+
+ mock_pod = MagicMock()
+ mock_pod.status.phase = "Running"
+
+ with
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client")
as mock_kube:
+
mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value =
mock_pod
+ result = operator.get_job_status("mynamespace:spark-abc-driver",
{"task_state_store": task_store})
+
+ assert result == "Running"
+
+ def test_k8s_get_job_status_returns_pending_when_phase_is_none(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+
+ mock_pod = MagicMock()
+ mock_pod.status.phase = None
+
+ with
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client")
as mock_kube:
+
mock_kube.get_kube_client.return_value.read_namespaced_pod.return_value =
mock_pod
+ result = operator.get_job_status("mynamespace:spark-abc-driver",
{})
+
+ assert result == "Pending"
+
+ def test_k8s_get_job_status_returns_not_found_on_404(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+
+ class FakeApiException(Exception):
+ def __init__(self, status):
+ self.status = status
+
+ with
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client")
as mock_kube:
+ mock_kube.ApiException = FakeApiException
+
mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect =
FakeApiException(404)
+ result = operator.get_job_status("mynamespace:spark-abc-driver",
{})
+
+ assert result == "NotFound"
+
+ def test_k8s_get_job_status_reraises_non_404_api_exception(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+
+ class FakeApiException(Exception):
+ def __init__(self, status):
+ self.status = status
+
+ with
mock.patch("airflow.providers.apache.spark.operators.spark_submit.kube_client")
as mock_kube:
+ mock_kube.ApiException = FakeApiException
+
mock_kube.get_kube_client.return_value.read_namespaced_pod.side_effect =
FakeApiException(500)
+ with pytest.raises(FakeApiException):
+ operator.get_job_status("mynamespace:spark-abc-driver", {})
+
+ def test_k8s_poll_until_complete_sets_pod_name_and_calls_poll_api(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ operator._hook = hook
+
+ operator.poll_until_complete("mynamespace:spark-abc-driver", {})
+
+ assert hook._kubernetes_driver_pod == "spark-abc-driver"
+ hook._poll_k8s_driver_via_api.assert_called_once()
+
+ def test_k8s_poll_until_complete_writes_succeeded_to_task_store(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ hook._poll_k8s_driver_via_api.return_value = "Succeeded"
+ operator._hook = hook
+ task_store = FakeTaskState()
+
+ operator.poll_until_complete("mynamespace:spark-abc-driver",
{"task_state_store": task_store})
+
+ assert task_store.get("k8s_driver_status") == "Succeeded"
+
+ def
test_k8s_polling_does_not_write_task_store_when_reconnect_disabled(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=False)
+ hook = self._make_k8s_hook()
+ hook._poll_k8s_driver_via_api.return_value = "Succeeded"
+ operator._hook = hook
+ task_store = FakeTaskState()
+
+ operator.poll_until_complete("mynamespace:spark-abc-driver",
{"task_state_store": task_store})
+
+ assert task_store.get("k8s_driver_status") is None
+
+ def
test_k8s_poll_until_complete_does_not_cache_and_reraises_on_failure(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ hook._poll_k8s_driver_via_api.side_effect = RuntimeError("Spark
application failed (phase=Failed)")
+ operator._hook = hook
+ task_store = FakeTaskState()
+
+ with pytest.raises(RuntimeError, match="phase=Failed"):
+ operator.poll_until_complete("mynamespace:spark-abc-driver",
{"task_state_store": task_store})
+
+ assert task_store.get("k8s_driver_status") is None
+
+ def test_k8s_poll_until_complete_tolerates_absent_task_store(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ operator._hook = self._make_k8s_hook()
+
+ operator.poll_until_complete("mynamespace:spark-abc-driver", {})
+
+ @pytest.mark.skipif(
+ not AIRFLOW_V_3_3_PLUS,
+ reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
+ )
+ def test_k8s_execute_persists_pod_id_when_reconnect_on_retry(self):
+ """execute() with reconnect_on_retry=True stores the pod ID in
task_store before polling."""
+ operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=True)
+ hook = self._make_k8s_hook()
+ hook._kubernetes_driver_pod = "spark-abc-driver"
+ hook._connection = {"namespace": "mynamespace"}
+ operator._hook = hook
+ task_store = FakeTaskState()
+ persisted_before_poll: list[str | None] = []
+
+ def track_poll(external_id, context):
+ persisted_before_poll.append(task_store.get("spark_job_id"))
+
+ operator.poll_until_complete = track_poll
+
+ operator.execute(context={"task_state_store": task_store})
+
+ assert persisted_before_poll == ["mynamespace:spark-abc-driver"]
+
+ @pytest.mark.skipif(
+ not AIRFLOW_V_3_3_PLUS,
+ reason="ResumableJobMixin reconnect requires task_state, available in
Airflow 3.3+",
+ )
+ def
test_k8s_execute_reconnect_on_retry_false_does_not_persist_pod_id(self):
+ """execute() with reconnect_on_retry=False does not write spark_job_id
to task_store."""
+ operator = self._make_operator(track_driver_via_k8s_api=True,
reconnect_on_retry=False)
+ hook = self._make_k8s_hook()
+ hook._kubernetes_driver_pod = "spark-abc-driver"
+ hook._connection = {"namespace": "mynamespace"}
+ operator._hook = hook
+ task_store = FakeTaskState()
+
+ operator.poll_until_complete = lambda external_id, context: None
+
+ operator.execute(context={"task_state_store": task_store})
+
+ assert task_store.get("spark_job_id") is None