This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new daad2c4bca6 Use K8s API to track Spark on K8s instead of JVM based 
spark-submit (#67715)
daad2c4bca6 is described below

commit daad2c4bca6d7945c2c7a4baf6f4daa1c7ae6bdc
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jun 4 16:16:42 2026 +0530

    Use K8s API to track Spark on K8s instead of JVM based spark-submit (#67715)
---
 providers/apache/spark/docs/operators.rst          |  37 ++++
 .../providers/apache/spark/hooks/spark_submit.py   | 216 ++++++++++++++++++---
 .../apache/spark/operators/spark_submit.py         |  17 ++
 .../unit/apache/spark/hooks/test_spark_submit.py   | 205 ++++++++++++++++++-
 .../apache/spark/operators/test_spark_submit.py    |  41 +++-
 5 files changed, 485 insertions(+), 31 deletions(-)

diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index f20d389811e..6bdd4bbcdc7 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -215,6 +215,43 @@ See :doc:`connections/spark-submit` for how to configure 
these fields.
     Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` 
support). On earlier
     versions the operator falls back to the previous behavior of always 
submitting fresh.
 
+Tracking driver status via Kubernetes API
+""""""""""""""""""""""""""""""""""""""""""
+
+When running in Kubernetes cluster mode, ``spark-submit`` blocks for the 
duration of the job.
+The JVM process does nothing but poll the pod phase, yet it holds heap space for
+the entire duration. This is not ideal for long-running jobs, especially when 
the driver is idle
+for long periods (e.g. waiting for data or user input).
+
+Set ``track_driver_via_k8s_api=True`` to have the operator track the driver 
pod status via the
+Python Kubernetes client rather than holding ``spark-submit`` open for the 
full job duration:
+
+.. code-block:: python
+
+   from airflow.providers.apache.spark.operators.spark_submit import 
SparkSubmitOperator
+
+   run_spark = SparkSubmitOperator(
+       task_id="run_spark",
+       application="local:///opt/spark/examples/jars/spark-examples.jar",
+       conn_id="spark_k8s",
+       deploy_mode="cluster",
+       track_driver_via_k8s_api=True,
+   )
+
+**Requirements**
+
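+* The Spark connection ``master`` must be ``k8s://...``, ``deploy_mode`` must be ``cluster``, and a
+  namespace must be set (connection extra ``namespace`` or ``spark.kubernetes.namespace`` in ``conf``).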
+* Do not set ``spark.kubernetes.submission.waitAppCompletion=true`` in your 
``conf`` — this
+  conflicts with the flag and a ``ValueError`` will be raised at task start.
+* The Airflow worker must be able to reach the Kubernetes API server and have 
permission to
+  read and delete pods in the driver's namespace; otherwise pod tracking and 
cleanup will fail.
+* This path bypasses ``ResumableJobMixin``, so an Airflow retry submits a fresh driver instead of
+  reconnecting to the existing one. Set ``execution_timeout`` to bound wall-clock time.
+* Pod completion is detected from ``pod.status.phase``. If your driver pods 
have sidecar
+  containers (e.g. Istio injection enabled for the driver namespace), the pod 
phase may not
+  advance to ``Succeeded`` until all sidecars exit. In that case the poll loop 
will wait
+  indefinitely — set ``execution_timeout`` as a hard bound.
+
 YARN ResourceManager API tracking
 """""""""""""""""""""""""""""""""
 
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 7306563a078..7cf1f3248ad 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -52,6 +52,8 @@ if TYPE_CHECKING:
 DEFAULT_SPARK_BINARY = "spark-submit"
 ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", 
"spark3-submit"]
 
+_K8S_WAIT_APP_COMPLETION_CONF = "spark.kubernetes.submission.waitAppCompletion"
+
 
 class SparkSubmitHook(BaseHook, LoggingMixin):
     """
@@ -90,11 +92,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :param name: Name of the job (default airflow-spark)
     :param num_executors: Number of executors to launch
     :param status_poll_interval: Seconds to wait between polls of driver 
status in cluster
-        mode. Used both by the Spark standalone driver-status tracker and (when
-        ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
-        polling loop. The YARN ResourceManager REST API polling loop uses at
-        least 10 seconds to avoid flooding the ResourceManager on long-running
-        jobs (Default: 1).
+        mode (Default: 1). Controls three polling loops — each enforces its 
own minimum:
+
+        - Spark standalone driver-status tracker (no minimum)
+        - YARN ResourceManager REST API, when ``yarn_track_via_rm_api=True`` 
(10s minimum)
+        - Kubernetes API, when ``track_driver_via_k8s_api=True`` (20s minimum)
     :param application_args: Arguments for the application being submitted
     :param env_vars: Environment variables for spark-submit. It
         supports yarn and k8s mode too.
@@ -114,6 +116,13 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         job finishes (on both success and on_kill). Useful for cleaning up 
sidecars such
         as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). 
Each command
         is executed via the shell; failures produce a warning but do not fail 
the task.
+    :param track_driver_via_k8s_api: If True (when master is Kubernetes and
+        ``deploy_mode`` is ``cluster``), release the ``spark-submit`` JVM once 
the
+        driver pod has been created, then poll the Kubernetes API for the pod 
phase
+        until the application reaches a terminal state. The polling interval is
+        controlled by ``status_poll_interval`` with a 20-second minimum. This 
frees
+        the worker from holding the long-lived submit JVM (~500 MB). Defaults 
to
+        ``False``.
     :param yarn_track_via_rm_api: If True (when master is YARN and 
``deploy_mode``
         is ``cluster``), release the ``spark-submit`` JVM once the application 
has
         been submitted to YARN, then poll the YARN ResourceManager REST API
@@ -257,6 +266,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         *,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
+        track_driver_via_k8s_api: bool = False,
         yarn_track_via_rm_api: bool = False,
         yarn_rm_auth: AuthBase | None = None,
     ) -> None:
@@ -302,12 +312,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 f"{self._connection['master']} specified by kubernetes 
dependencies are not installed!"
             )
 
+        self._track_driver_via_k8s_api = track_driver_via_k8s_api
         self._should_track_driver_status = 
self._resolve_should_track_driver_status()
         self._driver_id: str | None = None
         self._driver_status: str | None = None
         self._spark_exit_code: int | None = None
         self._env: dict[str, Any] | None = None
         self._post_submit_commands: list[str] = list(post_submit_commands) if 
post_submit_commands else []
+        self._post_submit_commands_done: bool = False
         self._yarn_track_via_rm_api = yarn_track_via_rm_api
         self._yarn_rm_auth = yarn_rm_auth
         # Cached after first successful resolution so the polling loop in
@@ -326,6 +338,34 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         """
         return "spark://" in self._connection["master"] and 
self._connection["deploy_mode"] == "cluster"
 
+    def _should_track_driver_via_k8s_api(self) -> bool:
+        return (
+            self._track_driver_via_k8s_api
+            and self._is_kubernetes
+            and self._connection["deploy_mode"] == "cluster"
+        )
+
+    def _validate_track_driver_via_k8s_api_config(self) -> None:
+        if not self._is_kubernetes:
+            raise ValueError(
+                "`track_driver_via_k8s_api=True` requires Spark master to be 
Kubernetes (k8s://...)."
+            )
+        if self._connection["deploy_mode"] != "cluster":
+            raise ValueError(
+                "`track_driver_via_k8s_api=True` requires 
`deploy_mode='cluster'`; "
+                f"got deploy_mode={self._connection['deploy_mode']!r}."
+            )
+        if not self._connection.get("namespace"):
+            raise ValueError(
+                "`track_driver_via_k8s_api=True` requires a namespace; "
+                "set it in the connection extra as `namespace` or via 
`spark.kubernetes.namespace` in conf."
+            )
+        if str(self._conf.get(_K8S_WAIT_APP_COMPLETION_CONF, "")).lower() == 
"true":
+            raise ValueError(
+                f"`track_driver_via_k8s_api=True` is incompatible with "
+                f"`{_K8S_WAIT_APP_COMPLETION_CONF}=true`; remove it from your 
conf or set it to 'false'."
+            )
+
     def _should_track_yarn_application_via_rm_api(self) -> bool:
         """Return whether this submit should switch to YARN RM REST API 
polling."""
         return self._yarn_track_via_rm_api and self._is_yarn and 
self._connection["deploy_mode"] == "cluster"
@@ -581,6 +621,10 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         if self._connection["deploy_mode"]:
             args += ["--deploy-mode", self._connection["deploy_mode"]]
 
+        if self._should_track_driver_via_k8s_api():
+            if _K8S_WAIT_APP_COMPLETION_CONF not in self._conf:
+                args += ["--conf", f"{_K8S_WAIT_APP_COMPLETION_CONF}=false"]
+
         return args
 
     def _build_spark_submit_command(self, application: str) -> list[str]:
@@ -665,7 +709,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         Called after the Spark job finishes (success or on_kill). Typical use 
case
         is killing sidecars like Istio that don't shut down automatically.
         Failures are logged as warnings and never raise.
+        Guaranteed to run at most once per hook instance even if called from 
both
+        the poll-loop finally and on_kill (e.g. after a SIGTERM).
         """
+        if self._post_submit_commands_done:
+            return
+        self._post_submit_commands_done = True
         for cmd in self._post_submit_commands:
             self.log.debug("Running post-submit command: %s", cmd)
             try:
@@ -719,8 +768,14 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
         # Check spark-submit return code. In Kubernetes mode, also check the 
value
         # of exit code in the log, as it may differ.
+        # When polling via K8s API, spark-submit exits after pod creation 
(waitAppCompletion=false)
+        # so _spark_exit_code is never set by the JVM watcher — skip that 
check entirely.
         try:
-            if returncode or (self._is_kubernetes and self._spark_exit_code != 
0):
+            if returncode or (
+                self._is_kubernetes
+                and not self._should_track_driver_via_k8s_api()
+                and self._spark_exit_code != 0
+            ):
                 if self._is_kubernetes:
                     raise AirflowException(
                         f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. 
Error code is: {returncode}. "
@@ -744,10 +799,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     "No driver id is known: something went wrong when 
executing the spark submit command"
                 )
         finally:
-            # In cluster mode with driver tracking, the operator calls 
poll_until_complete
-            # after submit() returns, so post_submit_commands are deferred 
there to preserve
-            # the "runs after job finishes" contract. In all other modes, run 
them here.
-            if not self._should_track_driver_status:
+            # K8s-API tracking defers post-submit commands to 
_poll_k8s_driver_via_api's finally
+            # block so they run once after the driver reaches a terminal 
state. Spark cluster-mode
+            # driver tracking defers them to poll_until_complete for the same 
reason. All other
+            # modes run them here, immediately after spark-submit exits.
+            if not self._should_track_driver_status and not 
self._should_track_driver_via_k8s_api():
                 self._run_post_submit_commands()
 
         return self._driver_id
@@ -778,10 +834,17 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             # If we run Kubernetes cluster mode, we want to extract the driver 
pod id
             # from the logs so we can kill the application when we stop it 
unexpectedly
             elif self._is_kubernetes:
+                # Two log formats exist across Spark versions:
+                # "pod name: <name>-driver" and "submission ID 
spark:<name>-driver"
                 match_driver_pod = re.search(r"\s*pod name: 
((.+?)-([a-z0-9]+)-driver$)", line)
                 if match_driver_pod:
                     self._kubernetes_driver_pod = match_driver_pod.group(1)
                     self.log.info("Identified spark driver pod: %s", 
self._kubernetes_driver_pod)
+                if not self._kubernetes_driver_pod:
+                    match_submission_id = re.search(r"submission ID 
spark:(.+?-driver)", line)
+                    if match_submission_id:
+                        self._kubernetes_driver_pod = 
match_submission_id.group(1)
+                        self.log.info("Identified spark driver pod: %s", 
self._kubernetes_driver_pod)
 
                 match_application_id = re.search(r"\s*spark-app-selector -> 
(spark-([a-z0-9]+)), ", line)
                 if match_application_id:
@@ -1047,6 +1110,96 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                         f"returncode = {returncode}"
                     )
 
+    def _poll_k8s_driver_via_api(self) -> None:
+        """Poll the K8s driver pod phase until it reaches a terminal state."""
+        pod_name = self._kubernetes_driver_pod
+        namespace = self._connection["namespace"]
+        app_id = self._kubernetes_application_id or pod_name
+
+        client = kube_client.get_kube_client()
+        poll_interval = max(self._status_poll_interval, 20)
+        if poll_interval != self._status_poll_interval:
+            self.log.info(
+                "status_poll_interval=%ds is below the 20s minimum for K8s API 
polling; using 20s.",
+                self._status_poll_interval,
+            )
+        # Mirror `missed_job_status_reports` / `max_missed_job_status_reports` 
from
+        # `_start_driver_status_tracking`: tolerate transient failures before 
giving up.
+        consecutive_unknown = 0
+        max_consecutive_unknown = 3
+        consecutive_api_errors = 0
+        max_consecutive_api_errors = 3
+        consecutive_pending = 0
+        pending_warn_threshold = 10
+
+        try:
+            if not pod_name:
+                raise ValueError("K8s driver pod name not set; cannot poll 
status.")
+            while True:
+                try:
+                    pod = client.read_namespaced_pod(pod_name, namespace)
+                    consecutive_api_errors = 0
+                except kube_client.ApiException as e:
+                    if e.status == 404:
+                        self.log.info(
+                            "Driver pod %s not found (404); pod was likely 
deleted by on_kill. Exiting poll loop.",
+                            pod_name,
+                        )
+                        return
+                    consecutive_api_errors += 1
+                    self.log.warning(
+                        "ApiException polling pod %s (%d/%d): %s",
+                        pod_name,
+                        consecutive_api_errors,
+                        max_consecutive_api_errors,
+                        e,
+                    )
+                    if consecutive_api_errors >= max_consecutive_api_errors:
+                        raise RuntimeError(
+                            f"K8s API unreachable after 
{consecutive_api_errors} consecutive errors "
+                            f"while polling {app_id}; giving up."
+                        ) from e
+                    time.sleep(poll_interval)
+                    continue
+
+                phase = pod.status.phase or "Initializing"
+                self.log.info("Application status for %s (phase: %s)", app_id, 
phase)
+                if phase == "Succeeded":
+                    break
+                if phase == "Failed":
+                    container_state = ""
+                    if pod.status.container_statuses:
+                        cs = pod.status.container_statuses[0]
+                        if cs.state and cs.state.terminated:
+                            container_state = f" 
exit_code={cs.state.terminated.exit_code} reason={cs.state.terminated.reason}"
+                    raise RuntimeError(f"Spark application {app_id} failed 
(phase=Failed{container_state})")
+                if phase == "Pending":
+                    consecutive_pending += 1
+                    if consecutive_pending == pending_warn_threshold:
+                        self.log.warning(
+                            "Driver pod %s has been Pending for %d polls 
(~%ds); "
+                            "it may be unschedulable. Continuing to wait — set 
execution_timeout to bound wait time.",
+                            pod_name,
+                            consecutive_pending,
+                            consecutive_pending * poll_interval,
+                        )
+                else:
+                    consecutive_pending = 0
+
+                if phase == "Unknown":
+                    consecutive_unknown += 1
+                    if consecutive_unknown >= max_consecutive_unknown:
+                        raise RuntimeError(
+                            f"Spark application {app_id} reported Unknown 
phase "
+                            f"{consecutive_unknown} times consecutively; 
giving up."
+                        )
+                else:
+                    consecutive_unknown = 0
+                time.sleep(poll_interval)
+            self._delete_driver_pod()
+        finally:
+            self._run_post_submit_commands()
+
     def _build_spark_driver_kill_command(self) -> list[str]:
         """
         Construct the spark-submit command to kill a driver.
@@ -1067,6 +1220,25 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
         return connection_cmd
 
+    def _delete_driver_pod(self) -> None:
+        """Delete the Kubernetes driver pod, logging a warning on failure."""
+        import kubernetes
+
+        self.log.info("Deleting driver pod %s on Kubernetes", 
self._kubernetes_driver_pod)
+        try:
+            client = kube_client.get_kube_client()
+            client.delete_namespaced_pod(
+                self._kubernetes_driver_pod,
+                self._connection["namespace"],
+                body=kubernetes.client.V1DeleteOptions(),
+                pretty=True,
+            )
+            self.log.info("Deleted driver pod %s", self._kubernetes_driver_pod)
+        except kube_client.ApiException:
+            self.log.exception(
+                "Exception when attempting to delete driver pod %s", 
self._kubernetes_driver_pod
+            )
+
     def on_kill(self) -> None:
         """Kill Spark submit command."""
         self.log.debug("Kill Command is being called")
@@ -1080,6 +1252,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     "Spark driver %s killed with return code: %s", 
self._driver_id, driver_kill.wait()
                 )
 
+        if self._should_track_driver_via_k8s_api() and 
self._kubernetes_driver_pod:
+            # spark-submit exits early under waitAppCompletion=false, so 
_submit_sp.poll() is
+            # not None during the poll loop — the deletion block below is 
skipped on kill.
+            self._delete_driver_pod()
+
         if self._submit_sp and self._submit_sp.poll() is None:
             self.log.info("Sending kill signal to %s", 
self._connection["spark_binary"])
             self._submit_sp.kill()
@@ -1111,24 +1288,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     self.log.info("YARN app killed with return code: %s", 
yarn_kill.wait())
 
             if self._kubernetes_driver_pod:
-                self.log.info("Killing pod %s on Kubernetes", 
self._kubernetes_driver_pod)
-
-                # Currently only instantiate Kubernetes client for killing a 
spark pod.
-                try:
-                    import kubernetes
-
-                    client = kube_client.get_kube_client()
-                    api_response = client.delete_namespaced_pod(
-                        self._kubernetes_driver_pod,
-                        self._connection["namespace"],
-                        body=kubernetes.client.V1DeleteOptions(),
-                        pretty=True,
-                    )
-
-                    self.log.info("Spark on K8s killed with response: %s", 
api_response)
-
-                except kube_client.ApiException:
-                    self.log.exception("Exception when attempting to kill 
Spark on K8s")
+                self._delete_driver_pod()
 
         # Opt-in REST kill path — uses the same RM endpoint as polling, no
         # `yarn` CLI dependency on the worker. Independent of `_submit_sp`
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index ea7b4a8e4ef..5321dbbb8be 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -113,6 +113,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                            on keytab for Kerberos login
     :param post_submit_commands: Optional list of shell commands to run after 
the Spark job finishes.
         Useful for cleaning up sidecars such as Istio. Failures produce a 
warning but do not fail the task.
+    :param track_driver_via_k8s_api: If True (when master is Kubernetes and 
``deploy_mode``
+        is ``cluster``), release the ``spark-submit`` JVM once the driver pod 
has been
+        created, then poll the Kubernetes API for the pod phase until the 
application
+        reaches a terminal state. The polling interval is controlled by
+        ``status_poll_interval`` with a 20-second minimum. This frees the 
worker from
+        holding the long-lived submit JVM. Defaults to ``False``.
     :param yarn_track_via_rm_api: If True (when master is YARN and 
``deploy_mode``
         is ``cluster``), release the ``spark-submit`` JVM once the application 
has
         been submitted to YARN, then poll the YARN ResourceManager REST API
@@ -188,6 +194,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
         reconnect_on_retry: bool = True,
+        track_driver_via_k8s_api: bool = False,
         yarn_track_via_rm_api: bool = False,
         yarn_rm_auth: AuthBase | None = None,
         openlineage_inject_parent_job_info: bool = conf.getboolean(
@@ -236,6 +243,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         self._yarn_rm_auth = yarn_rm_auth
 
         self.reconnect_on_retry = reconnect_on_retry
+        self._track_driver_via_k8s_api = track_driver_via_k8s_api
         self._openlineage_inject_parent_job_info = 
openlineage_inject_parent_job_info
         self._openlineage_inject_transport_info = 
openlineage_inject_transport_info
 
@@ -251,6 +259,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         if self._hook is None:
             self._hook = self._get_hook()
         hook = self._hook
+        if self._track_driver_via_k8s_api:
+            hook._validate_track_driver_via_k8s_api_config()
         if hook._should_track_driver_status:
             if self.reconnect_on_retry:
                 return self.execute_resumable(context)
@@ -258,6 +268,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
             driver_id = self.submit_job(context)
             self.poll_until_complete(driver_id, context)
             return self.get_job_result(driver_id, context)
+        if hook._should_track_driver_via_k8s_api():
+            # TODO: Wire into execute_resumable() via ResumableJobMixin
+            # (fill submit_job / poll_until_complete K8s stubs) to enable 
crash recovery.
+            hook.submit(self.application)
+            hook._poll_k8s_driver_via_api()
+            return
         hook.submit(self.application)
 
     def submit_job(self, context: Context) -> str:
@@ -402,6 +418,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
             deploy_mode=self._deploy_mode,
             use_krb5ccache=self._use_krb5ccache,
             post_submit_commands=self.post_submit_commands,
+            track_driver_via_k8s_api=self._track_driver_via_k8s_api,
             yarn_track_via_rm_api=self._yarn_track_via_rm_api,
             yarn_rm_auth=self._yarn_rm_auth,
         )
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index 301f720443d..f4a610a9408 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -24,11 +24,14 @@ from pathlib import Path
 from types import ModuleType
 from unittest.mock import MagicMock, call, mock_open, patch
 
+import kubernetes
 import pytest
 import requests
+from kubernetes.client import V1Pod, V1PodStatus
 
 from airflow.models import Connection
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.cncf.kubernetes import kube_client
 from airflow.providers.common.compat.sdk import AirflowException
 
 
@@ -101,6 +104,14 @@ class TestSparkSubmitHook:
                 extra='{"deploy-mode": "client", "namespace": "mynamespace"}',
             )
         )
+        create_connection_without_db(
+            Connection(
+                conn_id="spark_k8s_cluster_no_namespace",
+                conn_type="spark",
+                host="k8s://https://k8s-master";,
+                extra='{"deploy-mode": "cluster"}',
+            )
+        )
         create_connection_without_db(
             Connection(conn_id="spark_default_mesos", conn_type="spark", 
host="mesos://host", port=5050)
         )
@@ -930,6 +941,18 @@ class TestSparkSubmitHook:
         # Then
         assert hook._spark_exit_code == 999
 
+    def test_process_spark_submit_log_k8s_submission_id_format(self):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster")
+        log_lines = [
+            "INFO Client: Deployed Spark application arrow-spark with 
application ID "
+            "spark-1e22d65826b74ac2927249b0e607ed54 and submission ID "
+            "spark:arrow-spark-c8e2e29e73db9c93-driver into Kubernetes",
+        ]
+
+        hook._process_spark_submit_log(log_lines)
+
+        assert hook._kubernetes_driver_pod == 
"arrow-spark-c8e2e29e73db9c93-driver"
+
     def test_process_spark_client_mode_submit_log_k8s(self):
         # Given
         hook = SparkSubmitHook(conn_id="spark_k8s_client")
@@ -1146,6 +1169,29 @@ class TestSparkSubmitHook:
             "spark-pi-edf2ace37be7353a958b38733a12f8e6-driver", "mynamespace", 
**kwargs
         )
 
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def 
test_on_kill_deletes_pod_when_k8s_api_tracking_and_submit_sp_already_exited(self,
 mock_get_client):
+        """on_kill must delete the driver pod when K8s-API tracking is active 
even if spark-submit
+        has already exited.
+        """
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+        hook._submit_sp = MagicMock()
+        # spark-submit already exited
+        hook._submit_sp.poll.return_value = 0
+
+        mock_client = mock_get_client.return_value
+
+        hook.on_kill()
+
+        mock_client.delete_namespaced_pod.assert_called_once_with(
+            "spark-app-abc-driver",
+            "mynamespace",
+            body=kubernetes.client.V1DeleteOptions(),
+            pretty=True,
+        )
+
     @pytest.mark.parametrize(
         ("command", "expected"),
         [
@@ -1347,7 +1393,6 @@ class TestSparkSubmitHook:
     def test_run_post_submit_commands_success(self, mock_run):
         """Test that post_submit_commands are run with shell=False and 
shlex.split."""
         import subprocess
-        from unittest.mock import MagicMock
 
         mock_result = MagicMock(spec=subprocess.CompletedProcess)
         mock_result.returncode = 0
@@ -1375,7 +1420,6 @@ class TestSparkSubmitHook:
     def test_run_post_submit_commands_nonzero_exit_warns(self, mock_run):
         """Test that a non-zero exit code logs a warning but does not raise."""
         import subprocess
-        from unittest.mock import MagicMock
 
         mock_result = MagicMock(spec=subprocess.CompletedProcess)
         mock_result.returncode = 1
@@ -1412,6 +1456,163 @@ class TestSparkSubmitHook:
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
 
+    @pytest.mark.parametrize(
+        ("conn_id", "flag", "expected"),
+        [
+            ("spark_k8s_cluster", False, False),
+            ("spark_k8s_cluster", True, True),
+            ("spark_k8s_client", True, False),
+        ],
+    )
+    def test_should_track_driver_via_k8s_api(self, conn_id, flag, expected):
+        hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=flag)
+        assert hook._should_track_driver_via_k8s_api() is expected
+
+    @pytest.mark.parametrize(
+        ("conn_id", "match"),
+        [
+            ("spark_yarn_cluster", "requires Spark master to be Kubernetes"),
+            ("spark_k8s_client", "requires `deploy_mode='cluster'`"),
+            ("spark_k8s_cluster_no_namespace", "requires a namespace"),
+        ],
+    )
+    def test_validate_track_driver_via_k8s_api_raises(self, conn_id, match):
+        hook = SparkSubmitHook(conn_id=conn_id, track_driver_via_k8s_api=True)
+        with pytest.raises(ValueError, match=match):
+            hook._validate_track_driver_via_k8s_api_config()
+
+    def 
test_validate_track_driver_via_k8s_api_raises_on_conflicting_user_conf(self):
+        hook = SparkSubmitHook(
+            conn_id="spark_k8s_cluster",
+            track_driver_via_k8s_api=True,
+            conf={"spark.kubernetes.submission.waitAppCompletion": "true"},
+        )
+        with pytest.raises(ValueError, match="incompatible 
with.*waitAppCompletion=true"):
+            hook._validate_track_driver_via_k8s_api_config()
+
+    def test_conf_injection_adds_wait_app_completion(self):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        cmd = hook._build_spark_submit_command("app.jar")
+        conf_pairs = [cmd[i + 1] for i, v in enumerate(cmd) if v == "--conf"]
+        assert "spark.kubernetes.submission.waitAppCompletion=false" in 
conf_pairs
+
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_succeeds(self, mock_get_client):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        running_pod = V1Pod(status=V1PodStatus(phase="Running"))
+        succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded"))
+        mock_client.read_namespaced_pod.side_effect = [running_pod, 
succeeded_pod]
+
+        with patch.object(hook, "_run_post_submit_commands"):
+            hook._poll_k8s_driver_via_api()
+
+        assert mock_client.delete_namespaced_pod.call_args.args[:2] == 
("spark-app-abc-driver", "mynamespace")
+
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_raises_on_failed(self, mock_get_client):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        failed_pod = V1Pod(status=V1PodStatus(phase="Failed"))
+        mock_client.read_namespaced_pod.return_value = failed_pod
+
+        with pytest.raises(RuntimeError, match="phase=Failed"):
+            hook._poll_k8s_driver_via_api()
+
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_raises_after_consecutive_unknown(self, 
mock_get_client):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        mock_client.read_namespaced_pod.return_value = 
V1Pod(status=V1PodStatus(phase="Unknown"))
+
+        with patch("time.sleep"), pytest.raises(RuntimeError, match="Unknown 
phase"):
+            hook._poll_k8s_driver_via_api()
+
+        # assert that it was polled minimum 3 times to confirm the Unknown 
status before raising
+        assert mock_client.read_namespaced_pod.call_count == 3
+
+    @patch("time.sleep")
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_tolerates_transient_api_errors(self, 
mock_get_client, _):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        api_error = kube_client.ApiException(status=500, reason="Internal 
Server Error")
+        succeeded_pod = V1Pod(status=V1PodStatus(phase="Succeeded"))
+        mock_client.read_namespaced_pod.side_effect = [api_error, api_error, 
succeeded_pod]
+
+        with patch.object(hook, "_run_post_submit_commands"):
+            hook._poll_k8s_driver_via_api()
+
+        assert mock_client.read_namespaced_pod.call_count == 3
+
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_post_submit_commands_run_exactly_once_on_k8s_path(self, 
mock_get_client):
+        """_run_post_submit_commands must fire exactly once: in 
_poll_k8s_driver_via_api finally."""
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        mock_client.read_namespaced_pod.return_value = 
V1Pod(status=V1PodStatus(phase="Succeeded"))
+
+        with patch.object(hook, "_run_post_submit_commands") as mock_cmd:
+            hook._poll_k8s_driver_via_api()
+
+        mock_cmd.assert_called_once()
+
+    @patch("time.sleep")
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_raises_after_consecutive_api_errors(self, 
mock_get_client, _):
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        api_error = kube_client.ApiException(status=500, reason="Internal 
Server Error")
+        mock_client.read_namespaced_pod.side_effect = api_error
+
+        with pytest.raises(RuntimeError, match="K8s API unreachable"):
+            hook._poll_k8s_driver_via_api()
+
+        assert mock_client.read_namespaced_pod.call_count == 3
+
+    @patch("airflow.providers.cncf.kubernetes.kube_client.get_kube_client")
+    def test_poll_k8s_driver_exits_cleanly_on_404(self, mock_get_client):
+        """404 from read_namespaced_pod means pod was deleted by on_kill — 
should return cleanly, not raise."""
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
track_driver_via_k8s_api=True)
+        hook._kubernetes_driver_pod = "spark-app-abc-driver"
+        hook._kubernetes_application_id = "spark-abc"
+
+        mock_client = mock_get_client.return_value
+        mock_client.read_namespaced_pod.side_effect = 
kube_client.ApiException(status=404, reason="Not Found")
+
+        hook._poll_k8s_driver_via_api()
+
+        mock_client.delete_namespaced_pod.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.run")
+    def test_run_post_submit_commands_runs_only_once(self, mock_run):
+        """Calling _run_post_submit_commands twice must execute commands 
exactly once."""
+        mock_run.return_value = MagicMock(returncode=0, stdout="")
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
post_submit_commands=["echo done"])
+
+        hook._run_post_submit_commands()
+        hook._run_post_submit_commands()
+
+        mock_run.assert_called_once()
+
     _YARN_LOG_LINES = [
         "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
         "INFO Client: Uploading resource file:/tmp/lib.zip -> "
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index e18de1dd576..47cada84ce6 100644
--- 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -733,4 +733,43 @@ class TestSparkSubmitOperatorResumable:
         with pytest.raises(RuntimeError, match="FAILED"):
             operator.poll_until_complete("driver-001", {})
 
-        assert post_submit_called, "_run_post_submit_commands must be called 
even on driver failure"
+
+class TestSparkSubmitOperatorK8sTracking:
+    def setup_method(self):
+        args = {"owner": "airflow", "start_date": DEFAULT_DATE}
+        self.dag = DAG("test_k8s_tracking_dag", schedule=None, 
default_args=args)
+
+    def _make_operator(self, **kwargs):
+        return SparkSubmitOperator(task_id="test", dag=self.dag, 
application="test.jar", **kwargs)
+
+    def _make_k8s_hook(self):
+        hook = MagicMock()
+        hook._should_track_driver_status = False
+        hook._should_track_driver_via_k8s_api.return_value = True
+        return hook
+
+    def test_execute_calls_submit_then_poll_when_flag_set(self):
+        operator = self._make_operator(track_driver_via_k8s_api=True)
+        hook = self._make_k8s_hook()
+        operator._hook = hook
+        call_order = []
+        hook.submit.side_effect = lambda *a, **kw: call_order.append("submit")
+        hook._poll_k8s_driver_via_api.side_effect = lambda: 
call_order.append("poll")
+
+        operator.execute(context={})
+
+        hook.submit.assert_called_once_with("test.jar")
+        hook._poll_k8s_driver_via_api.assert_called_once()
+        assert call_order == ["submit", "poll"]
+
+    def test_execute_falls_through_to_plain_submit_when_flag_off(self):
+        operator = self._make_operator(track_driver_via_k8s_api=False)
+        hook = MagicMock()
+        hook._should_track_driver_status = False
+        hook._should_track_driver_via_k8s_api.return_value = False
+        operator._hook = hook
+
+        operator.execute(context={})
+
+        hook.submit.assert_called_once_with("test.jar")
+        hook._poll_k8s_driver_via_api.assert_not_called()


Reply via email to