This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new daad2c4bca6 Use K8s API to track Spark on K8s instead of JVM based
spark-submit (#67715)
daad2c4bca6 is described below
commit daad2c4bca6d7945c2c7a4baf6f4daa1c7ae6bdc
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jun 4 16:16:42 2026 +0530
Use K8s API to track Spark on K8s instead of JVM based spark-submit (#67715)
---
providers/apache/spark/docs/operators.rst | 37 ++++
.../providers/apache/spark/hooks/spark_submit.py | 216 ++++++++++++++++++---
.../apache/spark/operators/spark_submit.py | 17 ++
.../unit/apache/spark/hooks/test_spark_submit.py | 205 ++++++++++++++++++-
.../apache/spark/operators/test_spark_submit.py | 41 +++-
5 files changed, 485 insertions(+), 31 deletions(-)
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index f20d389811e..6bdd4bbcdc7 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -215,6 +215,43 @@ See :doc:`connections/spark-submit` for how to configure
these fields.
Crash recovery in cluster mode requires Airflow 3.3+ (``task_state``
support). On earlier
versions the operator falls back to the previous behavior of always
submitting fresh.
+Tracking driver status via Kubernetes API
+""""""""""""""""""""""""""""""""""""""""""
+
+When running in Kubernetes cluster mode, ``spark-submit`` blocks for the
duration of the job.
+During that time the JVM process does nothing but poll the pod phase, yet it
holds heap space for
+the entire duration. This is not ideal for long-running jobs, especially when
the driver is idle
+for long periods (e.g. waiting for data or user input).
+
+Set ``track_driver_via_k8s_api=True`` to have the operator track the driver
pod status via the
+Python Kubernetes client rather than holding ``spark-submit`` open for the
full job duration:
+
+.. code-block:: python
+
+ from airflow.providers.apache.spark.operators.spark_submit import
SparkSubmitOperator
+
+ run_spark = SparkSubmitOperator(
+ task_id="run_spark",
+ application="local:///opt/spark/examples/jars/spark-examples.jar",
+ conn_id="spark_k8s",
+ deploy_mode="cluster",
+ track_driver_via_k8s_api=True,
+ )
+
+**Requirements**
+
+* The Spark connection ``master`` must be ``k8s://...`` and ``deploy_mode``
must be ``cluster``.
+* Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your
``conf`` — this
+ conflicts with ``track_driver_via_k8s_api`` and a ``ValueError`` is raised at task start.
+* The Airflow worker must be able to reach the Kubernetes API server and have
permission to
+ read and delete pods in the driver's namespace; otherwise pod tracking and
cleanup will fail.
+* This path bypasses ``ResumableJobMixin``, so Airflow retries submit a fresh
driver instead of
+ reconnecting to an existing one. Set ``execution_timeout`` to bound
wall-clock time.
+* Pod completion is detected from ``pod.status.phase``. If your driver pods
have sidecar
+ containers (e.g. Istio injection enabled for the driver namespace), the pod
phase may not
+ advance to ``Succeeded`` until all sidecars exit. In that case the poll loop
will wait
+ indefinitely — set ``execution_timeout`` as a hard bound.
+
YARN ResourceManager API tracking
"""""""""""""""""""""""""""""""""
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 7306563a078..7cf1f3248ad 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -52,6 +52,8 @@ if TYPE_CHECKING:
DEFAULT_SPARK_BINARY = "spark-submit"
ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit",
"spark3-submit"]
+_K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion"
+
class SparkSubmitHook(BaseHook, LoggingMixin):
"""
@@ -90,11 +92,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param name: Name of the job (default airflow-spark)
:param num_executors: Number of executors to launch
:param status_poll_interval: Seconds to wait between polls of driver
status in cluster
- mode. Used both by the Spark standalone driver-status tracker and (when
- ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
- polling loop. The YARN ResourceManager REST API polling loop uses at
- least 10 seconds to avoid flooding the ResourceManager on long-running
- jobs (Default: 1).
+ mode (Default: 1). Controls three polling loops — each enforces its
own minimum:
+
+ - Spark standalone driver-status tracker (no minimum)
+ - YARN ResourceManager REST API, when ``yarn_track_via_rm_api=True``
(10s minimum)
+ - Kubernetes API, when ``track_driver_via_k8s_api=True`` (20s minimum)
:param application_args: Arguments for the application being submitted
:param env_vars: Environment variables for spark-submit. It
supports yarn and k8s mode too.
@@ -114,6 +116,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
job finishes (on both success and on_kill). Useful for cleaning up
sidecars such
as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``).
Each command
is executed via the shell; failures produce a warning but do not fail
the task.
+ :param track_driver_via_k8s_api: If True (when master is Kubernetes and
+ ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once
the
+ driver pod has been created, then poll the Kubernetes API for the pod
phase
+ until the application reaches a terminal state. The polling interval is
+ controlled by ``status_poll_interval`` with a 20-second minimum. This
frees
+ the worker from holding the long-lived submit JVM (~500 MB). Defaults
to
+ ``False``.
:param yarn_track_via_rm_api: If True (when master is YARN and
``deploy_mode``
is ``cluster``), release the ``spark-submit`` JVM once the application
has
been submitted to YARN, then poll the YARN ResourceManager REST API
@@ -257,6 +266,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
*,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
+ track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
) -> None:
@@ -302,12 +312,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
f"{self._connection['master']} specified by kubernetes
dependencies are not installed!"
)
+ self._track_driver_via_k8s_api = track_driver_via_k8s_api
self._should_track_driver_status =
self._resolve_should_track_driver_status()
self._driver_id: str | None = None
self._driver_status: str | None = None
self._spark_exit_code: int | None = None
self._env: dict[str, Any] | None = None
self._post_submit_commands: list[str] = list(post_submit_commands) if
post_submit_commands else []
+ self._post_submit_commands_done: bool = False
self._yarn_track_via_rm_api = yarn_track_via_rm_api
self._yarn_rm_auth = yarn_rm_auth
# Cached after first successful resolution so the polling loop in
@@ -326,6 +338,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"""
return "spark://" in self._connection["master"] and
self._connection["deploy_mode"] == "cluster"
+ def _should_track_driver_via_k8s_api(self) -> bool:
+ return (
+ self._track_driver_via_k8s_api
+ and self._is_kubernetes
+ and self._connection["deploy_mode"] == "cluster"
+ )
+
+ def _validate_track_driver_via_k8s_api_config(self) -> None:
+ if not self._is_kubernetes:
+ raise ValueError(
+ "`track_driver_via_k8s_api=True` requires Spark master to be
Kubernetes (k8s://...)."
+ )
+ if self._connection["deploy_mode"] != "cluster":
+ raise ValueError(
+ "`track_driver_via_k8s_api=True` requires
`deploy_mode='cluster'`; "
+ f"got deploy_mode={self._connection['deploy_mode']!r}."
+ )
+ if not self._connection.get("namespace"):
+ raise ValueError(
+ "`track_driver_via_k8s_api=True` requires a namespace; "
+ "set it in the connection extra as `namespace` or via
`spark.kubernetes.namespace` in conf."
+ )
+ if str(self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "")).lower() ==
"true":
+ raise ValueError(
+ f"`track_driver_via_k8s_api=True` is incompatible with "
+ f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your
conf or set it to 'false'."
+ )
+
def _should_track_yarn_application_via_rm_api(self) -> bool:
"""Return whether this submit should switch to YARN RM REST API
polling."""
return self._yarn_track_via_rm_api and self._is_yarn and
self._connection["deploy_mode"] == "cluster"
@@ -581,6 +621,10 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
if self._connection["deploy_mode"]:
args += ["--deploy-mode", self._connection["deploy_mode"]]
+ if self._should_track_driver_via_k8s_api():
+ if _K8S_WAIT_APP_COMPLETION_CONF not in self._conf:
+ args += ["--conf", f"{_K8S_WAIT_APP_COMPLETION_CONF}=false"]
+
return args
def _build_spark_submit_command(self, application: str) -> list[str]:
@@ -665,7 +709,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
Called after the Spark job finishes (success or on_kill). Typical use
case
is killing sidecars like Istio that don't shut down automatically.
Failures are logged as warnings and never raise.
+ Guaranteed to run at most once per hook instance even if called from
both
+ the poll-loop finally and on_kill (e.g. after a SIGTERM).
"""
+ if self._post_submit_commands_done:
+ return
+ self._post_submit_commands_done = True
for cmd in self._post_submit_commands:
self.log.debug("Running post-submit command: %s", cmd)
try:
@@ -719,8 +768,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# Check spark-submit return code. In Kubernetes mode, also check the
value
# of exit code in the log, as it may differ.
+ # When polling via K8s API, spark-submit exits after pod creation
(waitAppCompletion=false)
+ # so _spark_exit_code is never set by the JVM watcher — skip that
check entirely.
try:
- if returncode or (self._is_kubernetes and self._spark_exit_code !=
0):
+ if returncode or (
+ self._is_kubernetes
+ and not self._should_track_driver_via_k8s_api()
+ and self._spark_exit_code != 0
+ ):
if self._is_kubernetes:
raise AirflowException(
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}.
Error code is: {returncode}. "
@@ -744,10 +799,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"No driver id is known: something went wrong when
executing the spark submit command"
)
finally:
- # In cluster mode with driver tracking, the operator calls
poll_until_complete
- # after submit() returns, so post_submit_commands are deferred
there to preserve
- # the "runs after job finishes" contract. In all other modes, run
them here.
- if not self._should_track_driver_status:
+ # K8s-API tracking defers post-submit commands to
_poll_k8s_driver_via_api's finally
+ # block so they run once after the driver reaches a terminal
state. Spark cluster-mode
+ # driver tracking defers them to poll_until_complete for the same
reason. All other
+ # modes run them here, immediately after spark-submit exits.
+ if not self._should_track_driver_status and not
self._should_track_driver_via_k8s_api():
self._run_post_submit_commands()
return self._driver_id
@@ -778,10 +834,17 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
# If we run Kubernetes cluster mode, we want to extract the driver
pod id
# from the logs so we can kill the application when we stop it
unexpectedly
elif self._is_kubernetes:
+ # Two log formats exist across Spark versions:
+ # "pod name: <name>-driver" and "submission ID
spark:<name>-driver"
match_driver_pod = re.search(r"\s*pod name:
((.+?)-([a-z0-9]+)-driver$)", line)
if match_driver_pod:
self._kubernetes_driver_pod = match_driver_pod.group(1)
self.log.info("Identified spark driver pod: %s",
self._kubernetes_driver_pod)
+ if not self._kubernetes_driver_pod:
+ match_submission_id = re.search(r"submission ID
spark:(.+?-driver)", line)
+ if match_submission_id:
+ self._kubernetes_driver_pod =
match_submission_id.group(1)
+ self.log.info("Identified spark driver pod: %s",
self._kubernetes_driver_pod)
match_application_id = re.search(r"\s*spark-app-selector ->
(spark-([a-z0-9]+)), ", line)
if match_application_id:
@@ -1047,6 +1110,96 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
f"returncode = {returncode}"
)
+ def _poll_k8s_driver_via_api(self) -> None:
+ """Poll the K8s driver pod phase until it reaches a terminal state."""
+ pod_name = self._kubernetes_driver_pod
+ namespace = self._connection["namespace"]
+ app_id = self._kubernetes_application_id or pod_name
+
+ client = kube_client.get_kube_client()
+ poll_interval = max(self._status_poll_interval, 20)
+ if poll_interval != self._status_poll_interval:
+ self.log.info(
+ "status_poll_interval=%ds is below the 20s minimum for K8s API
polling; using 20s.",
+ self._status_poll_interval,
+ )
+ # Mirror `missed_job_status_reports` / `max_missed_job_status_reports`
from
+ # `_start_driver_status_tracking`: tolerate transient failures before
giving up.
+ consecutive_unknown = 0
+ max_consecutive_unknown = 3
+ consecutive_api_errors = 0
+ max_consecutive_api_errors = 3
+ consecutive_pending = 0
+ pending_warn_threshold = 10
+
+ try:
+ if not pod_name:
+ raise ValueError("K8s driver pod name not set; cannot poll
status.")
+ while True:
+ try:
+ pod = client.read_namespaced_pod(pod_name, namespace)
+ consecutive_api_errors = 0
+ except kube_client.ApiException as e:
+ if e.status == 404:
+ self.log.info(
+ "Driver pod %s not found (404); pod was likely
deleted by on_kill. Exiting poll loop.",
+ pod_name,
+ )
+ return
+ consecutive_api_errors += 1
+ self.log.warning(
+ "ApiException polling pod %s (%d/%d): %s",
+ pod_name,
+ consecutive_api_errors,
+ max_consecutive_api_errors,
+ e,
+ )
+ if consecutive_api_errors >= max_consecutive_api_errors:
+ raise RuntimeError(
+ f"K8s API unreachable after
{consecutive_api_errors} consecutive errors "
+ f"while polling {app_id}; giving up."
+ ) from e
+ time.sleep(poll_interval)
+ continue
+
+ phase = pod.status.phase or "Initializing"
+ self.log.info("Application status for %s (phase: %s)", app_id,
phase)
+ if phase == "Succeeded":
+ break
+ if phase == "Failed":
+ container_state = ""
+ if pod.status.container_statuses:
+ cs = pod.status.container_statuses[0]
+ if cs.state and cs.state.terminated:
+ container_state = f"
exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}"
+ raise RuntimeError(f"Spark application {app_id} failed
(phase=Failed{container_state})")
+ if phase == "Pending":
+ consecutive_pending += 1
+ if consecutive_pending == pending_warn_threshold:
+ self.log.warning(
+ "Driver pod %s has been Pending for %d polls
(~%ds); "
+ "it may be unschedulable. Continuing to wait — set
execution_timeout to bound wait time.",
+ pod_name,
+ consecutive_pending,
+ consecutive_pending * poll_interval,
+ )
+ else:
+ consecutive_pending = 0
+
+ if phase == "Unknown":
+ consecutive_unknown += 1
+ if consecutive_unknown >= max_consecutive_unknown:
+ raise RuntimeError(
+ f"Spark application {app_id} reported Unknown
phase "
+ f"{consecutive_unknown} times consecutively;
giving up."
+ )
+ else:
+ consecutive_unknown = 0
+ time.sleep(poll_interval)
+ self._delete_driver_pod()
+ finally:
+ self._run_post_submit_commands()
+
def _build_spark_driver_kill_command(self) -> list[str]:
"""
Construct the spark-submit command to kill a driver.
@@ -1067,6 +1220,25 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
return connection_cmd
+ def _delete_driver_pod(self) -> None:
+ """Delete the Kubernetes driver pod, logging a warning on failure."""
+ import kubernetes
+
+ self.log.info("Deleting driver pod %s on Kubernetes",
self._kubernetes_driver_pod)
+ try:
+ client = kube_client.get_kube_client()
+ client.delete_namespaced_pod(
+ self._kubernetes_driver_pod,
+ self._connection["namespace"],
+ body=kubernetes.client.V1DeleteOptions(),
+ pretty=True,
+ )
+ self.log.info("Deleted driver pod %s", self._kubernetes_driver_pod)
+ except kube_client.ApiException:
+ self.log.exception(
+ "Exception when attempting to delete driver pod %s",
self._kubernetes_driver_pod
+ )
+
def on_kill(self) -> None:
"""Kill Spark submit command."""
self.log.debug("Kill Command is being called")
@@ -1080,6 +1252,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"Spark driver %s killed with return code: %s",
self._driver_id, driver_kill.wait()
)
+ if self._should_track_driver_via_k8s_api() and
self._kubernetes_driver_pod:
+ # spark-submit exits early under waitAppCompletion=false, so
_submit_sp.poll() is
+ # not None during the poll loop — the deletion block below is
skipped on kill.
+ self._delete_driver_pod()
+
if self._submit_sp and self._submit_sp.poll() is None:
self.log.info("Sending kill signal to %s",
self._connection["spark_binary"])
self._submit_sp.kill()
@@ -1111,24 +1288,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.info("YARN app killed with return code: %s",
yarn_kill.wait())
if self._kubernetes_driver_pod:
- self.log.info("Killing pod %s on Kubernetes",
self._kubernetes_driver_pod)
-
- # Currently only instantiate Kubernetes client for killing a
spark pod.
- try:
- import kubernetes
-
- client = kube_client.get_kube_client()
- api_response = client.delete_namespaced_pod(
- self._kubernetes_driver_pod,
- self._connection["namespace"],
- body=kubernetes.client.V1DeleteOptions(),
- pretty=True,
- )
-
- self.log.info("Spark on K8s killed with response: %s",
api_response)
-
- except kube_client.ApiException:
- self.log.exception("Exception when attempting to kill
Spark on K8s")
+ self._delete_driver_pod()
# Opt-in REST kill path — uses the same RM endpoint as polling, no
# `yarn` CLI dependency on the worker. Independent of `_submit_sp`
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index ea7b4a8e4ef..5321dbbb8be 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -113,6 +113,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
on keytab for Kerberos login
:param post_submit_commands: Optional list of shell commands to run after
the Spark job finishes.
Useful for cleaning up sidecars such as Istio. Failures produce a
warning but do not fail the task.
+ :param track_driver_via_k8s_api: If True (when master is Kubernetes and
``deploy_mode``
+ is ``cluster``), release the ``spark-submit`` JVM once the driver pod
has been
+ created, then poll the Kubernetes API for the pod phase until the
application
+ reaches a terminal state. The polling interval is controlled by
+ ``status_poll_interval`` with a 20-second minimum. This frees the
worker from
+ holding the long-lived submit JVM. Defaults to ``False``.
:param yarn_track_via_rm_api: If True (when master is YARN and
``deploy_mode``
is ``cluster``), release the ``spark-submit`` JVM once the application
has
been submitted to YARN, then poll the YARN ResourceManager REST API
@@ -188,6 +194,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
reconnect_on_retry: bool = True,
+ track_driver_via_k8s_api: bool = False,
yarn_track_via_rm_api: bool = False,
yarn_rm_auth: AuthBase | None = None,
openlineage_inject_parent_job_info: bool = conf.getboolean(
@@ -236,6 +243,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
self._yarn_rm_auth = yarn_rm_auth
self.reconnect_on_retry = reconnect_on_retry
+ self._track_driver_via_k8s_api = track_driver_via_k8s_api
self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
self._openlineage_inject_transport_info =
openlineage_inject_transport_info
@@ -251,6 +259,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
if self._hook is None:
self._hook = self._get_hook()
hook = self._hook
+ if self._track_driver_via_k8s_api:
+ hook._validate_track_driver_via_k8s_api_config()
if hook._should_track_driver_status:
if self.reconnect_on_retry:
return self.execute_resumable(context)
@@ -258,6 +268,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
driver_id = self.submit_job(context)
self.poll_until_complete(driver_id, context)
return self.get_job_result(driver_id, context)
+ if hook._should_track_driver_via_k8s_api():
+ # TODO: Wire into execute_resumable() via ResumableJobMixin
+ # (fill submit_job / poll_until_complete K8s stubs) to enable
crash recovery.
+ hook.submit(self.application)
+ hook._poll_k8s_driver_via_api()
+ return
hook.submit(self.application)
def submit_job(self, context: Context) -> str:
@@ -402,6 +418,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
post_submit_commands=self.post_submit_commands,
+ track_driver_via_k8s_api=self._track_driver_via_k8s_api,
yarn_track_via_rm_api=self._yarn_track_via_rm_api,
yarn_rm_auth=self._yarn_rm_auth,
)
diff --git
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index 301f720443d..f4a610a9408 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -24,11 +24,14 @@ from pathlib import Path
from types import ModuleType
from unittest.mock import MagicMock, call, mock_open, patch
+import kubernetes
import pytest
import requests
+from kubernetes.client import V1Pod, V1PodStatus
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.cncf.kubernetes import kube_client
from airflow.providers.common.compat.sdk import AirflowException
@@ -101,6 +104,14 @@ class TestSparkSubmitHook:
extra='{"deploy-mode": "client", "namespace": "mynamespace"}',
)
)
+ create_connection_without_db(
+ Connection(
+ conn_id="spark_k8s_cluster_no_namespace",
+ conn_type="spark",
+ host="k8s://https://k8s-master",
+ extra='{"deploy-mode": "cluster"}',
+ )
+ )
create_connection_without_db(
Connection(conn_id="spark_default_mesos", conn_type="spark",
host="mesos://host", port=5050)
)
@@ -930,6 +941,18 @@ class TestSparkSubmitHook:
# Then
assert hook._spark_exit_code == 999
+ def test_process_spark_submit_log_k8s_submission_id_format(self):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster")
+ log_lines = [
+ "INFO Client: Deployed Spark application arrow-spark with
application ID "
+ "spark-1e22d65826b74ac2927249b0e607ed54 and submission ID "
+ "spark:arrow-spark-c8e2e29e73db9c93-driver into Kubernetes",
+ ]
+
+ hook._process_spark_submit_log(log_lines)
+
+ assert hook._kubernetes_driver_pod ==
"arrow-spark-c8e2e29e73db9c93-driver"
+
def test_process_spark_client_mode_submit_log_k8s(self):
# Given
hook = SparkSubmitHook(conn_id="spark_k8s_client")
@@ -1146,6 +1169,29 @@ class TestSparkSubmitHook:
"spark-pi-edf2ace37be7353a958b38733a12f8e6-driver", "mynamespace",
**kwargs
)
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def
test_on_kill_deletes_pod_when_k8s_api_tracking_and_submit_sp_already_exited(self,
mock_get_client):
+ """on_kill must delete the driver pod when K8s-API tracking is active
even if spark-submit
+ has already exited.
+ """
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+ hook._submit_sp = MagicMock()
+ # spark-submit already exited
+ hook._submit_sp.poll.return_value = 0
+
+ mock_client = mock_get_client.return_value
+
+ hook.on_kill()
+
+ mock_client.delete_namespaced_pod.assert_called_once_with(
+ "spark-app-abc-driver",
+ "mynamespace",
+ body=kubernetes.client.V1DeleteOptions(),
+ pretty=True,
+ )
+
@pytest.mark.parametrize(
("command", "expected"),
[
@@ -1347,7 +1393,6 @@ class TestSparkSubmitHook:
def test_run_post_submit_commands_success(self, mock_run):
"""Test that post_submit_commands are run with shell=False and
shlex.split."""
import subprocess
- from unittest.mock import MagicMock
mock_result = MagicMock(spec=subprocess.CompletedProcess)
mock_result.returncode = 0
@@ -1375,7 +1420,6 @@ class TestSparkSubmitHook:
def test_run_post_submit_commands_nonzero_exit_warns(self, mock_run):
"""Test that a non-zero exit code logs a warning but does not raise."""
import subprocess
- from unittest.mock import MagicMock
mock_result = MagicMock(spec=subprocess.CompletedProcess)
mock_result.returncode = 1
@@ -1412,6 +1456,163 @@ class TestSparkSubmitHook:
hook = SparkSubmitHook(conn_id="")
assert hook._post_submit_commands == []
+ @pytest.mark.parametrize(
+ ("conn_id", "flag", "expected"),
+ [
+ ("spark_k8s_cluster", False, False),
+ ("spark_k8s_cluster", True, True),
+ ("spark_k8s_client", True, False),
+ ],
+ )
+ def test_should_track_driver_via_k8s_api(self, conn_id, flag, expected):
+ hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=flag)
+ assert hook._should_track_driver_via_k8s_api() is expected
+
+ @pytest.mark.parametrize(
+ ("conn_id", "match"),
+ [
+ ("spark_yarn_cluster", "requires Spark master to be Kubernetes"),
+ ("spark_k8s_client", "requires `deploy_mode='cluster'`"),
+ ("spark_k8s_cluster_no_namespace", "requires a namespace"),
+ ],
+ )
+ def test_validate_track_driver_via_k8s_api_raises(self, conn_id, match):
+ hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=True)
+ with pytest.raises(ValueError, match=match):
+ hook._validate_track_driver_via_k8s_api_config()
+
+ def
test_validate_track_driver_via_k8s_api_raises_on_conflicting_user_conf(self):
+ hook = SparkSubmitHook(
+ conn_id="spark_k8s_cluster",
+ track_driver_via_k8s_api=True,
+ conf={"spark.kubernetes.submission.waitAppCompletion": "true"},
+ )
+ with pytest.raises(ValueError, match="incompatible
with.*waitAppCompletion=true"):
+ hook._validate_track_driver_via_k8s_api_config()
+
+ def test_conf_injection_adds_wait_app_completion(self):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ cmd = hook._build_spark_submit_command("app.jar")
+ conf_pairs = [cmd[i + 1] for i, v in enumerate(cmd) if v == "--conf"]
+ assert "spark.kubernetes.submission.waitAppCompletion=false" in
conf_pairs
+
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_succeeds(self, mock_get_client):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ running_pod = V1Pod(status=V1PodStatus(phase="Running"))
+ succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded"))
+ mock_client.read_namespaced_pod.side_effect = [running_pod,
succeeded_pod]
+
+ with patch.object(hook, "_run_post_submit_commands"):
+ hook._poll_k8s_driver_via_api()
+
+ assert mock_client.delete_namespaced_pod.call_args.args[:2] ==
("spark-app-abc-driver", "mynamespace")
+
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_raises_on_failed(self, mock_get_client):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ failed_pod = V1Pod(status=V1PodStatus(phase="Failed"))
+ mock_client.read_namespaced_pod.return_value = failed_pod
+
+ with pytest.raises(RuntimeError, match="phase=Failed"):
+ hook._poll_k8s_driver_via_api()
+
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_raises_after_consecutive_unknown(self,
mock_get_client):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ mock_client.read_namespaced_pod.return_value =
V1Pod(status=V1PodStatus(phase="Unknown"))
+
+ with patch("time.sleep"), pytest.raises(RuntimeError, match="Unknown
phase"):
+ hook._poll_k8s_driver_via_api()
+
+ # assert that it was polled exactly 3 times (the consecutive-Unknown limit)
before raising
+ assert mock_client.read_namespaced_pod.call_count == 3
+
+ @patch("time.sleep")
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_tolerates_transient_api_errors(self,
mock_get_client, _):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ api_error = kube_client.ApiException(status=500, reason="Internal
Server Error")
+ succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded"))
+ mock_client.read_namespaced_pod.side_effect = [api_error, api_error,
succeeded_pod]
+
+ with patch.object(hook, "_run_post_submit_commands"):
+ hook._poll_k8s_driver_via_api()
+
+ assert mock_client.read_namespaced_pod.call_count == 3
+
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_post_submit_commands_run_exactly_once_on_k8s_path(self,
mock_get_client):
+ """_run_post_submit_commands must fire exactly once: in
_poll_k8s_driver_via_api finally."""
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ mock_client.read_namespaced_pod.return_value =
V1Pod(status=V1PodStatus(phase="Succeeded"))
+
+ with patch.object(hook, "_run_post_submit_commands") as mock_cmd:
+ hook._poll_k8s_driver_via_api()
+
+ mock_cmd.assert_called_once()
+
+ @patch("time.sleep")
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_raises_after_consecutive_api_errors(self,
mock_get_client, _):
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ api_error = kube_client.ApiException(status=500, reason="Internal
Server Error")
+ mock_client.read_namespaced_pod.side_effect = api_error
+
+ with pytest.raises(RuntimeError, match="K8s API unreachable"):
+ hook._poll_k8s_driver_via_api()
+
+ assert mock_client.read_namespaced_pod.call_count == 3
+
+ @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+ def test_poll_k8s_driver_exits_cleanly_on_404(self, mock_get_client):
+ """404 from read_namespaced_pod means pod was deleted by on_kill —
should return cleanly, not raise."""
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
track_driver_via_k8s_api=True)
+ hook._kubernetes_driver_pod = "spark-app-abc-driver"
+ hook._kubernetes_application_id = "spark-abc"
+
+ mock_client = mock_get_client.return_value
+ mock_client.read_namespaced_pod.side_effect =
kube_client.ApiException(status=404, reason="Not Found")
+
+ hook._poll_k8s_driver_via_api()
+
+ mock_client.delete_namespaced_pod.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.run")
+ def test_run_post_submit_commands_runs_only_once(self, mock_run):
+ """Calling _run_post_submit_commands twice must execute commands
exactly once."""
+ mock_run.return_value = MagicMock(returncode=0, stdout="")
+ hook = SparkSubmitHook(conn_id="spark_k8s_cluster",
post_submit_commands=["echo done"])
+
+ hook._run_post_submit_commands()
+ hook._run_post_submit_commands()
+
+ mock_run.assert_called_once()
+
_YARN_LOG_LINES = [
"INFO Client: Requesting a new application from cluster with 1
NodeManagers",
"INFO Client: Uploading resource file:/tmp/lib.zip -> "
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index e18de1dd576..47cada84ce6 100644
---
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -733,4 +733,43 @@ class TestSparkSubmitOperatorResumable:
with pytest.raises(RuntimeError, match="FAILED"):
operator.poll_until_complete("driver-001", {})
- assert post_submit_called, "_run_post_submit_commands must be called
even on driver failure"
+
+class TestSparkSubmitOperatorK8sTracking:
+ def setup_method(self):
+ args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+ self.dag = DAG("test_k8s_tracking_dag", schedule=None,
default_args=args)
+
+ def _make_operator(self, **kwargs):
+ return SparkSubmitOperator(task_id="test", dag=self.dag,
application="test.jar", **kwargs)
+
+ def _make_k8s_hook(self):
+ hook = MagicMock()
+ hook._should_track_driver_status = False
+ hook._should_track_driver_via_k8s_api.return_value = True
+ return hook
+
+ def test_execute_calls_submit_then_poll_when_flag_set(self):
+ operator = self._make_operator(track_driver_via_k8s_api=True)
+ hook = self._make_k8s_hook()
+ operator._hook = hook
+ call_order = []
+ hook.submit.side_effect = lambda *a, **kw: call_order.append("submit")
+ hook._poll_k8s_driver_via_api.side_effect = lambda:
call_order.append("poll")
+
+ operator.execute(context={})
+
+ hook.submit.assert_called_once_with("test.jar")
+ hook._poll_k8s_driver_via_api.assert_called_once()
+ assert call_order == ["submit", "poll"]
+
+ def test_execute_falls_through_to_plain_submit_when_flag_off(self):
+ operator = self._make_operator(track_driver_via_k8s_api=False)
+ hook = MagicMock()
+ hook._should_track_driver_status = False
+ hook._should_track_driver_via_k8s_api.return_value = False
+ operator._hook = hook
+
+ operator.execute(context={})
+
+ hook.submit.assert_called_once_with("test.jar")
+ hook._poll_k8s_driver_via_api.assert_not_called()