This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c37ccaa610b Track Spark job status for YARN cluster mode via RM REST 
API to free JVM (#65991)
c37ccaa610b is described below

commit c37ccaa610b364be1ffb2ffdcac75c201a1afb91
Author: Aaron Chen <[email protected]>
AuthorDate: Wed Jun 3 21:42:53 2026 -0700

    Track Spark job status for YARN cluster mode via RM REST API to free JVM 
(#65991)
    
    Adds an opt-in yarn_track_via_rm_api flag to track YARN cluster 
applications via the ResourceManager REST API, allowing Airflow to release the 
local spark-submit process after submission and reduce worker memory usage.
    
    ---------
    
    Co-authored-by: Amogh Desai <[email protected]>
---
 providers/apache/spark/docs/operators.rst          |  58 +++
 providers/apache/spark/provider.yaml               |  10 +
 .../providers/apache/spark/get_provider_info.py    |   5 +
 .../providers/apache/spark/hooks/spark_submit.py   | 285 +++++++++++-
 .../apache/spark/operators/spark_submit.py         |  28 +-
 .../unit/apache/spark/hooks/test_spark_submit.py   | 513 ++++++++++++++++++++-
 6 files changed, 884 insertions(+), 15 deletions(-)

diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index 64af53454f4..f20d389811e 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -214,3 +214,61 @@ See :doc:`connections/spark-submit` for how to configure 
these fields.
 .. note::
     Crash recovery in cluster mode requires Airflow 3.3+ (``task_state`` 
support). On earlier
     versions the operator falls back to the previous behavior of always 
submitting fresh.
+
+YARN ResourceManager API tracking
+"""""""""""""""""""""""""""""""""
+
+When running Spark applications on YARN in cluster deploy mode, the default
+submission flow keeps the local ``spark-submit`` JVM alive on the Airflow worker
+for as long as the YARN application runs. For long-running Spark applications,
+this ties up worker memory for the whole application lifetime.
+
+Set ``yarn_track_via_rm_api=True`` to release the local ``spark-submit`` JVM 
after YARN accepts the
+application, then poll the YARN ResourceManager REST API until the application 
reaches a terminal
+state. The ResourceManager API polling interval is controlled by 
``status_poll_interval`` with a
+minimum of 10 seconds.
+
+This mode requires the Spark connection extra to set 
``yarn_resourcemanager_webapp_address`` before
+the application is submitted:
+
+.. code-block:: bash
+
+    airflow connections add spark_yarn_rm \
+        --conn-type spark \
+        --conn-host yarn \
+        --conn-extra '{
+            "deploy-mode": "cluster",
+            "yarn_resourcemanager_webapp_address": "http://rm.example.com:8088"
+        }'
+
+.. code-block:: python
+
+    SparkSubmitOperator(
+        task_id="spark_pi",
+        conn_id="spark_yarn_rm",
+        application="/path/to/spark-examples.jar",
+        java_class="org.apache.spark.examples.SparkPi",
+        deploy_mode="cluster",
+        yarn_track_via_rm_api=True,
+    )
+
+For Kerberized clusters, install ``requests-kerberos`` in the Airflow 
environment. When the
+Spark connection has both ``keytab`` and ``principal`` configured, Airflow 
automatically uses
+``HTTPKerberosAuth()`` for the ResourceManager REST requests.
+
+Use ``yarn_rm_auth`` only when the ResourceManager needs a custom ``requests`` 
authentication
+object:
+
+.. code-block:: python
+
+    import requests
+
+    SparkSubmitOperator(
+        task_id="spark_pi",
+        conn_id="spark_yarn_rm",
+        application="/path/to/spark-examples.jar",
+        java_class="org.apache.spark.examples.SparkPi",
+        deploy_mode="cluster",
+        yarn_track_via_rm_api=True,
+        yarn_rm_auth=requests.auth.HTTPBasicAuth("user", "password"),
+    )
diff --git a/providers/apache/spark/provider.yaml 
b/providers/apache/spark/provider.yaml
index af228a2aa29..2a57351b875 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -227,6 +227,16 @@ connection-types:
             - string
             - 'null'
           default: '6066'
+      yarn_resourcemanager_webapp_address:
+        label: YARN ResourceManager webapp address
+        description: >
+          YARN ResourceManager webapp URL (e.g. `http://rm.example.com:8088`),
+          required when `yarn_track_via_rm_api=True` on `SparkSubmitOperator` /
+          `SparkSubmitHook`. Mirrors Hadoop's 
`yarn.resourcemanager.webapp.address`.
+        schema:
+          type:
+            - string
+            - 'null'
 
 task-decorators:
   - class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index ef09d0a6ae9..3fefe31f5ea 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -136,6 +136,11 @@ def get_provider_info():
                         "description": "Port for the Spark standalone REST API 
(spark.master.rest.port). Default is 6066.",
                         "schema": {"type": ["string", "null"], "default": 
"6066"},
                     },
+                    "yarn_resourcemanager_webapp_address": {
+                        "label": "YARN ResourceManager webapp address",
+                        "description": "YARN ResourceManager webapp URL (e.g. 
`http://rm.example.com:8088`), required when `yarn_track_via_rm_api=True` on 
`SparkSubmitOperator` / `SparkSubmitHook`. Mirrors Hadoop's 
`yarn.resourcemanager.webapp.address`.\n",
+                        "schema": {"type": ["string", "null"]},
+                    },
                 },
             },
         ],
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 9aa3ddc885e..7306563a078 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -28,16 +28,27 @@ import tempfile
 import time
 import uuid
 from collections.abc import Iterator
+from functools import cached_property
 from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
 
-from airflow.providers.common.compat.sdk import AirflowException, BaseHook, 
conf as airflow_conf
+import requests
+
+from airflow.providers.common.compat.sdk import (
+    AirflowException,
+    AirflowNotFoundException,
+    BaseHook,
+    conf as airflow_conf,
+)
 from airflow.security.kerberos import renew_from_kt
 from airflow.utils.log.logging_mixin import LoggingMixin
 
 with contextlib.suppress(ImportError, NameError):
     from airflow.providers.cncf.kubernetes import kube_client
 
+if TYPE_CHECKING:
+    from requests.auth import AuthBase
+
 DEFAULT_SPARK_BINARY = "spark-submit"
 ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit", 
"spark3-submit"]
 
@@ -79,7 +90,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :param name: Name of the job (default airflow-spark)
     :param num_executors: Number of executors to launch
     :param status_poll_interval: Seconds to wait between polls of driver 
status in cluster
-        mode (Default: 1)
+        mode. Used both by the Spark standalone driver-status tracker and (when
+        ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
+        polling loop. The YARN ResourceManager REST API polling loop uses at
+        least 10 seconds to avoid flooding the ResourceManager on long-running
+        jobs (Default: 1).
     :param application_args: Arguments for the application being submitted
     :param env_vars: Environment variables for spark-submit. It
         supports yarn and k8s mode too.
@@ -99,6 +114,22 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         job finishes (on both success and on_kill). Useful for cleaning up 
sidecars such
         as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``). 
Each command
         is executed via the shell; failures produce a warning but do not fail 
the task.
+    :param yarn_track_via_rm_api: If True (when master is YARN and 
``deploy_mode``
+        is ``cluster``), release the ``spark-submit`` JVM once the application 
has
+        been submitted to YARN, then poll the YARN ResourceManager REST API
+        (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a
+        final state. The polling interval is controlled by 
``status_poll_interval``
+        with a 10-second minimum. This frees the worker from holding the
+        long-lived submit JVM. Requires the Spark connection's
+        ``extra`` JSON to set ``yarn_resourcemanager_webapp_address``
+        (e.g. ``http://rm:8088``). Once tracking switches to polling, use the
+        cluster-side (YARN) driver logs, since ``spark-submit`` output is no
+        longer available. Defaults to ``False``.
+    :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for
+        every call to the YARN ResourceManager REST API (status polling and
+        kill). When omitted, Kerberos-enabled Spark connections with both
+        ``keytab`` and ``principal`` configured use ``requests-kerberos``
+        automatically. Defaults to ``None`` (no auth for non-Kerberos
+        connections).
     """
 
     conn_name_attr = "conn_id"
@@ -106,6 +137,15 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     conn_type = "spark"
     hook_name = "Spark"
 
+    # YARN ApplicationReport final-application-status values.
+    # See org.apache.hadoop.yarn.api.records.FinalApplicationStatus.
+    _YARN_FINAL_SUCCESS = "SUCCEEDED"
+    _YARN_FINAL_FAILURES = frozenset({"FAILED", "KILLED"})
+    _YARN_FINAL_UNDEFINED = "UNDEFINED"
+    _YARN_WAIT_APP_COMPLETION_CONF = "spark.yarn.submit.waitAppCompletion"
+    _YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY = "yarn_resourcemanager_webapp_address"
+    _HTTP_TIMEOUT = (5, 30)  # (connect, read) seconds, matches old CLI 30s 
read budget
+
     @classmethod
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Return custom UI field behaviour for Spark connection."""
@@ -172,6 +212,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 description="Port for the Spark standalone REST API 
(spark.master.rest.port). Default: 6066.",
                 validators=[Optional()],
             ),
+            "yarn_resourcemanager_webapp_address": StringField(
+                lazy_gettext("YARN ResourceManager webapp address"),
+                widget=BS3TextFieldWidget(),
+                description=(
+                    "YARN ResourceManager webapp URL (e.g. 
http://rm.example.com:8088), "
+                    "required when yarn_track_via_rm_api=True on 
SparkSubmitOperator / "
+                    "SparkSubmitHook. Mirrors Hadoop's 
yarn.resourcemanager.webapp.address."
+                ),
+                validators=[Optional()],
+            ),
         }
 
     def __init__(
@@ -207,6 +257,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         *,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
+        yarn_track_via_rm_api: bool = False,
+        yarn_rm_auth: AuthBase | None = None,
     ) -> None:
         super().__init__()
         self._conf = conf or {}
@@ -256,6 +308,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._spark_exit_code: int | None = None
         self._env: dict[str, Any] | None = None
         self._post_submit_commands: list[str] = list(post_submit_commands) if 
post_submit_commands else []
+        self._yarn_track_via_rm_api = yarn_track_via_rm_api
+        self._yarn_rm_auth = yarn_rm_auth
+        # Cached after first successful resolution so the polling loop in
+        # `_track_yarn_application` does not re-fetch the Spark connection
+        # (and re-hit any configured Secrets Backend) on every iteration.
+        self._yarn_rm_base_url: str | None = None
 
     def _resolve_should_track_driver_status(self) -> bool:
         """
@@ -268,6 +326,24 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         """
         return "spark://" in self._connection["master"] and 
self._connection["deploy_mode"] == "cluster"
 
+    def _should_track_yarn_application_via_rm_api(self) -> bool:
+        """Return whether this submit should switch to YARN RM REST API 
polling."""
+        return self._yarn_track_via_rm_api and self._is_yarn and 
self._connection["deploy_mode"] == "cluster"
+
+    def _validate_yarn_track_via_rm_api_config(self) -> None:
+        """Validate that YARN RM REST API tracking can run for this submit."""
+        if not self._yarn_track_via_rm_api:
+            return
+        if not self._is_yarn:
+            raise ValueError("`yarn_track_via_rm_api=True` requires Spark 
master to be YARN.")
+        if self._connection["deploy_mode"] != "cluster":
+            raise ValueError(
+                "`yarn_track_via_rm_api=True` requires 
`deploy_mode='cluster'`; "
+                f"got {self._connection['deploy_mode']!r}."
+            )
+        self._get_yarn_rm_base_url()
+        self._resolved_yarn_rm_auth
+
     def _resolve_connection(self) -> dict[str, Any]:
         # Build from connection master or default to yarn if not available
         conn_data: dict[str, Any] = {
@@ -426,6 +502,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
         for key in self._conf:
             args += ["--conf", f"{key}={self._conf[key]}"]
+        if self._should_track_yarn_application_via_rm_api():
+            wait_app_completion = 
self._conf.get(self._YARN_WAIT_APP_COMPLETION_CONF)
+            if wait_app_completion is not None:
+                if str(wait_app_completion).strip().lower() != "false":
+                    raise ValueError(
+                        f"`{self._YARN_WAIT_APP_COMPLETION_CONF}=false` is 
required when "
+                        "`yarn_track_via_rm_api=True`."
+                    )
+            else:
+                args += ["--conf", 
f"{self._YARN_WAIT_APP_COMPLETION_CONF}=false"]
         if self._env_vars and (self._is_kubernetes or self._is_yarn):
             if self._is_yarn:
                 tmpl = "spark.yarn.appMasterEnv.{}={}"
@@ -611,6 +697,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         :param application: Submitted application, jar or py file
         :param kwargs: extra arguments to Popen (see subprocess.Popen)
         """
+        self._validate_yarn_track_via_rm_api_config()
         spark_submit_cmd = self._build_spark_submit_command(application)
 
         if self._env:
@@ -643,6 +730,15 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                     f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}. 
Error code is: {returncode}."
                 )
 
+            if self._should_track_yarn_application_via_rm_api():
+                # Once spark-submit exits successfully, rely on RM REST API 
polling instead
+                # of requiring a particular Spark log line such as "Submitted 
application ...".
+                # The RM REST API is the authoritative source for the 
application's lifecycle.
+                if not self._yarn_application_id:
+                    raise RuntimeError("No YARN application id found after 
spark-submit completed.")
+                self._track_yarn_application(self._yarn_application_id)
+                return self._driver_id
+
             if self._should_track_driver_status and self._driver_id is None:
                 raise AirflowException(
                     "No driver id is known: something went wrong when 
executing the spark submit command"
@@ -712,6 +808,155 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
             self.log.info(line)
 
+    def _track_yarn_application(self, application_id: str) -> None:
+        """Poll the YARN RM REST API until the application reaches a terminal 
state."""
+        self.log.info(
+            "Tracking YARN application %s via ResourceManager REST API 
polling",
+            application_id,
+        )
+        poll_interval = max(self._status_poll_interval, 10)
+        # Tolerate transient RM REST API failures (RM hiccup, network blip, 
request
+        # timeout) the same way `_start_driver_status_tracking` does for spark
+        # standalone — only give up after this many consecutive failures.
+        consecutive_failures = 0
+        max_consecutive_failures = 10
+        while True:
+            self.log.debug("Polling YARN RM REST API for application %s", 
application_id)
+            try:
+                state, final_status = 
self._query_yarn_application_status(application_id)
+            except RuntimeError as exc:
+                consecutive_failures += 1
+                if consecutive_failures > max_consecutive_failures:
+                    raise RuntimeError(
+                        f"Giving up tracking YARN application {application_id} 
after "
+                        f"{max_consecutive_failures} consecutive YARN RM REST 
API "
+                        f"failures. Last error: {exc}"
+                    ) from exc
+                self.log.warning(
+                    "Transient YARN RM REST API failure (%d/%d): %s",
+                    consecutive_failures,
+                    max_consecutive_failures,
+                    exc,
+                )
+                time.sleep(poll_interval)
+                continue
+            consecutive_failures = 0
+            if state in self._YARN_FINAL_FAILURES:
+                raise RuntimeError(
+                    f"YARN application {application_id} ended with state: 
{state}, "
+                    f"final status: {final_status}"
+                )
+            if final_status == self._YARN_FINAL_SUCCESS:
+                self.log.info("YARN application %s finished with SUCCEEDED", 
application_id)
+                return
+            if final_status in self._YARN_FINAL_FAILURES:
+                raise RuntimeError(
+                    f"YARN application {application_id} ended with final 
status: {final_status}"
+                )
+            if final_status != self._YARN_FINAL_UNDEFINED:
+                raise RuntimeError(
+                    f"YARN application {application_id} returned unexpected 
final status: {final_status}"
+                )
+            time.sleep(poll_interval)
+
+    def _get_yarn_rm_base_url(self) -> str:
+        """
+        Resolve the YARN ResourceManager webapp base URL from the Spark 
connection.
+
+        Reads the ``yarn_resourcemanager_webapp_address`` key from the Spark
+        connection's ``extra`` JSON. Bare ``host:port`` values get ``http://``
+        prepended; fully-qualified URLs are used as-is. Trailing slashes 
stripped.
+        The resolved URL is cached on the hook instance so the polling loop 
does
+        not re-fetch the connection (or re-hit any Secrets Backend) on every 
iteration.
+        """
+        if self._yarn_rm_base_url is not None:
+            return self._yarn_rm_base_url
+        try:
+            conn = self.get_connection(self._conn_id)
+        except AirflowNotFoundException:
+            conn = None
+        raw = ""
+        if conn is not None:
+            raw = 
(conn.extra_dejson.get(self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY) or "").strip()
+        if not raw:
+            raise ValueError(
+                f"`yarn_track_via_rm_api=True` requires the Spark connection's 
`extra` to set "
+                f"`{self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY}` (e.g. 
`http://rm.example.com:8088`)."
+            )
+        url = raw if "://" in raw else f"http://{raw}"
+        self._yarn_rm_base_url = url.rstrip("/")
+        return self._yarn_rm_base_url
+
+    @cached_property
+    def _resolved_yarn_rm_auth(self) -> AuthBase | None:
+        """
+        Resolve the auth object for YARN ResourceManager REST API requests.
+
+        Explicit ``yarn_rm_auth`` wins. If omitted, Kerberos-enabled Spark
+        connections automatically use ``requests_kerberos.HTTPKerberosAuth``.
+        """
+        if self._yarn_rm_auth is not None:
+            return self._yarn_rm_auth
+        if self._connection.get("keytab") and 
self._connection.get("principal"):
+            try:
+                from requests_kerberos import HTTPKerberosAuth
+            except ImportError as exc:
+                raise RuntimeError(
+                    "Kerberos credentials are configured for Spark submit, but 
`requests-kerberos` "
+                    "is not installed. Install `requests-kerberos` to use "
+                    "`yarn_track_via_rm_api=True` with Kerberos, or pass 
`yarn_rm_auth` explicitly."
+                ) from exc
+            return HTTPKerberosAuth()
+
+        return None
+
+    def _query_yarn_application_status(self, application_id: str) -> 
tuple[str, str]:
+        """GET ``/ws/v1/cluster/apps/{id}`` once and return ``app.state`` and 
``app.finalStatus``."""
+        url = 
f"{self._get_yarn_rm_base_url()}/ws/v1/cluster/apps/{application_id}"
+        try:
+            resp = requests.get(url, auth=self._resolved_yarn_rm_auth, 
timeout=self._HTTP_TIMEOUT)
+        except requests.exceptions.RequestException as exc:
+            raise RuntimeError(
+                f"YARN RM REST API request for application {application_id} 
failed: {exc}"
+            ) from exc
+        if resp.status_code != 200:
+            raise RuntimeError(
+                f"YARN RM REST API returned HTTP {resp.status_code} for 
application "
+                f"{application_id}: {resp.text[:200]}"
+            )
+        try:
+            app = resp.json()["app"]
+            return app["state"], app["finalStatus"]
+        except (ValueError, KeyError, TypeError) as exc:
+            raise RuntimeError(
+                f"YARN RM REST API returned unexpected payload for application 
"
+                f"{application_id}: {resp.text[:200]}"
+            ) from exc
+
+    def _kill_yarn_application(self, application_id: str) -> None:
+        """PUT ``/ws/v1/cluster/apps/{id}/state`` to kill the application 
(best-effort)."""
+        try:
+            url = 
f"{self._get_yarn_rm_base_url()}/ws/v1/cluster/apps/{application_id}/state"
+            auth = self._resolved_yarn_rm_auth
+        except (ValueError, RuntimeError) as exc:
+            self.log.warning(
+                "Cannot send YARN kill for %s: %s",
+                application_id,
+                exc,
+            )
+            return
+        try:
+            resp = requests.put(
+                url,
+                json={"state": "KILLED"},
+                auth=auth,
+                timeout=self._HTTP_TIMEOUT,
+            )
+        except requests.exceptions.RequestException as exc:
+            self.log.warning("YARN kill request for %s failed: %s", 
application_id, exc)
+            return
+        self.log.info("YARN kill request for %s returned HTTP %s", 
application_id, resp.status_code)
+
     def _process_spark_status_log(self, itr: Iterator[Any]) -> None:
         """
         Parse the logs of the spark driver status query process.
@@ -839,22 +1084,29 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             self.log.info("Sending kill signal to %s", 
self._connection["spark_binary"])
             self._submit_sp.kill()
 
-            if self._yarn_application_id:
+            # Legacy YARN CLI kill — gated on a live `_submit_sp` to preserve
+            # pre-`yarn_track_via_rm_api` behavior. The REST kill path below is
+            # the opt-in replacement and is intentionally not gated this way:
+            # `yarn_track_via_rm_api=True` deliberately terminates `_submit_sp`
+            # after submission, so by the time `on_kill` fires the gate would
+            # always be False and the YARN app would never be killed.
+            if self._yarn_application_id and not self._yarn_track_via_rm_api:
                 kill_cmd = f"yarn application -kill 
{self._yarn_application_id}".split()
                 env = {**os.environ, **(self._env or {})}
                 if self._connection["keytab"] is not None and 
self._connection["principal"] is not None:
-                    # we are ignoring renewal failures from renew_from_kt
-                    # here as the failure could just be due to a non-renewable 
ticket,
-                    # we still attempt to kill the yarn application
+                    # Renewal failures from `renew_from_kt` are ignored here — 
a
+                    # non-renewable ticket should not block the YARN kill 
attempt.
                     renew_from_kt(
-                        self._connection["principal"], 
self._connection["keytab"], exit_on_fail=False
+                        self._connection["principal"],
+                        self._connection["keytab"],
+                        exit_on_fail=False,
                     )
-                    env = os.environ.copy()
-                    ccacche = airflow_conf.get_mandatory_value("kerberos", 
"ccache")
-                    env["KRB5CCNAME"] = ccacche
-
+                    env["KRB5CCNAME"] = 
airflow_conf.get_mandatory_value("kerberos", "ccache")
                 with subprocess.Popen(
-                    kill_cmd, env=env, stdout=subprocess.PIPE, 
stderr=subprocess.PIPE
+                    kill_cmd,
+                    env=env,
+                    stdout=subprocess.PIPE,
+                    stderr=subprocess.PIPE,
                 ) as yarn_kill:
                     self.log.info("YARN app killed with return code: %s", 
yarn_kill.wait())
 
@@ -878,4 +1130,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
                 except kube_client.ApiException:
                     self.log.exception("Exception when attempting to kill 
Spark on K8s")
 
+        # Opt-in REST kill path — uses the same RM endpoint as polling, no
+        # `yarn` CLI dependency on the worker. Independent of `_submit_sp`
+        # state because `yarn_track_via_rm_api=True` deliberately terminates
+        # `_submit_sp` right after submission to free the JVM.
+        if self._yarn_application_id and self._yarn_track_via_rm_api:
+            self._kill_yarn_application(self._yarn_application_id)
+
         self._run_post_submit_commands()
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 76b010107da..ea7b4a8e4ef 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -49,6 +49,7 @@ except ImportError:
 
 if TYPE_CHECKING:
     from pydantic import JsonValue
+    from requests.auth import AuthBase
 
     from airflow.providers.common.compat.sdk import Context
 
@@ -91,7 +92,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
     :param name: Name of the job (default airflow-spark). (templated)
     :param num_executors: Number of executors to launch
     :param status_poll_interval: Seconds to wait between polls of driver 
status in cluster
-        mode (Default: 1)
+        mode. Used both by the Spark standalone driver-status tracker and (when
+        ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
+        polling loop. The YARN ResourceManager REST API polling loop uses at
+        least 10 seconds to avoid flooding the ResourceManager on long-running
+        jobs (Default: 1).
     :param application_args: Arguments for the application being submitted 
(templated)
     :param env_vars: Environment variables for spark-submit. It supports yarn 
and k8s mode too. (templated)
     :param verbose: Whether to pass the verbose flag to spark-submit process 
for debugging
@@ -108,6 +113,21 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
                            on keytab for Kerberos login
     :param post_submit_commands: Optional list of shell commands to run after 
the Spark job finishes.
         Useful for cleaning up sidecars such as Istio. Failures produce a 
warning but do not fail the task.
+    :param yarn_track_via_rm_api: If True (when master is YARN and 
``deploy_mode``
+        is ``cluster``), release the ``spark-submit`` JVM once the application 
has
+        been submitted to YARN, then poll the YARN ResourceManager REST API
+        (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a
+        final state. The polling interval is controlled by 
``status_poll_interval``
+        with a 10-second minimum. This frees the worker from holding the
+        long-lived submit JVM. Requires the Spark connection's ``extra``
+        JSON to set ``yarn_resourcemanager_webapp_address`` (e.g. 
``http://rm:8088``).
+        Once tracking switches to polling, use the cluster-side (YARN) driver
+        logs, since ``spark-submit`` output is no longer available.
+        Defaults to ``False``.
+    :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for 
every
+        call to the YARN ResourceManager REST API (status polling and kill). 
When
+        omitted, Kerberos-enabled Spark connections with both ``keytab`` and
+        ``principal`` configured use ``requests-kerberos`` automatically.
+        Defaults to ``None`` (no auth for non-Kerberos connections).
     """
 
     # Generic key used across all Spark deployment modes (standalone driver ID,
@@ -168,6 +188,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
         reconnect_on_retry: bool = True,
+        yarn_track_via_rm_api: bool = False,
+        yarn_rm_auth: AuthBase | None = None,
         openlineage_inject_parent_job_info: bool = conf.getboolean(
             "openlineage", "spark_inject_parent_job_info", fallback=False
         ),
@@ -210,6 +232,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         self._post_submit_commands = list(post_submit_commands) if 
post_submit_commands else []
         self._conn_id = conn_id
         self._use_krb5ccache = use_krb5ccache
+        self._yarn_track_via_rm_api = yarn_track_via_rm_api
+        self._yarn_rm_auth = yarn_rm_auth
 
         self.reconnect_on_retry = reconnect_on_retry
         self._openlineage_inject_parent_job_info = 
openlineage_inject_parent_job_info
@@ -378,4 +402,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
             deploy_mode=self._deploy_mode,
             use_krb5ccache=self._use_krb5ccache,
             post_submit_commands=self.post_submit_commands,
+            yarn_track_via_rm_api=self._yarn_track_via_rm_api,
+            yarn_rm_auth=self._yarn_rm_auth,
         )
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index c909e9f12ab..301f720443d 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -21,9 +21,11 @@ import base64
 import os
 from io import StringIO
 from pathlib import Path
-from unittest.mock import call, mock_open, patch
+from types import ModuleType
+from unittest.mock import MagicMock, call, mock_open, patch
 
 import pytest
+import requests
 
 from airflow.models import Connection
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
@@ -198,6 +200,29 @@ class TestSparkSubmitHook:
                 uri="spark://local",
             )
         )
+        create_connection_without_db(
+            Connection(
+                conn_id="spark_yarn_rm",
+                conn_type="spark",
+                host="yarn",
+                extra=(
+                    '{"deploy-mode": "cluster", 
"yarn_resourcemanager_webapp_address": "http://rm.test:8088"}'
+                ),
+            )
+        )
+        create_connection_without_db(
+            Connection(
+                conn_id="spark_yarn_rm_kerberos",
+                conn_type="spark",
+                host="yarn",
+                extra=(
+                    '{"deploy-mode": "cluster", '
+                    '"yarn_resourcemanager_webapp_address": 
"http://rm.test:8088", '
+                    '"principal": "[email protected]", '
+                    '"keytab": "cHJpdmlsZWdlZF91c2VyLmtleXRhYg=="}'
+                ),
+            )
+        )
 
     @pytest.mark.db_test
     @patch(
@@ -1036,6 +1061,26 @@ class TestSparkSubmitHook:
             in mock_popen.mock_calls
         )
 
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_legacy_on_kill_skips_yarn_cli_when_submit_sp_already_exited(self, 
mock_popen):
+        """Regression guard: when `yarn_track_via_rm_api=False` (legacy
+        path), the YARN CLI kill must stay gated on a live `_submit_sp`. If the
+        spark-submit subprocess has already exited, `on_kill` must not spawn
+        `yarn application -kill` — preserving pre-PR behavior for users who 
have
+        not opted in to the REST kill path.
+        """
+        submit_process = MagicMock(spec=["kill", "poll"])
+        submit_process.poll.return_value = 0  # already exited
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")  # 
yarn_track_via_rm_api=False
+        hook._submit_sp = submit_process
+        hook._yarn_application_id = "application_1486558679801_1820"
+
+        hook.on_kill()
+
+        submit_process.kill.assert_not_called()
+        mock_popen.assert_not_called()
+
     def test_standalone_cluster_process_on_kill(self):
         # Given
         log_lines = [
@@ -1366,3 +1411,469 @@ class TestSparkSubmitHook:
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088"
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    def test_yarn_status_tracking_requires_yarn_master(self):
+        """yarn_track_via_rm_api=True should fail fast outside YARN."""
+        hook = SparkSubmitHook(conn_id="spark_k8s_cluster", 
yarn_track_via_rm_api=True)
+
+        with pytest.raises(ValueError, match="requires Spark master to be 
YARN"):
+            hook.submit()
+
+    def test_yarn_status_tracking_requires_cluster_deploy_mode(self):
+        """yarn_track_via_rm_api=True should fail fast outside cluster deploy 
mode."""
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            deploy_mode="client",
+            yarn_track_via_rm_api=True,
+        )
+
+        with pytest.raises(ValueError, match="requires 
`deploy_mode='cluster'`"):
+            hook.submit()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        proc.terminate.assert_not_called()
+        assert mock_get.call_count == 2
+        mock_sleep.assert_called_once_with(10)
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.args[0] == self._rm_status_url()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns KILLED -> raise with message containing app id and 
KILLED."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("KILLED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+            hook.submit()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
+        self, mock_popen, mock_get, mock_sleep
+    ):
+        """RM state FAILED with finalStatus UNDEFINED should not poll 
forever."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("UNDEFINED", 
state="FAILED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state: 
FAILED"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_sleep.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_popen, mock_get, mock_sleep):
+        """RM returns a non-standard finalStatus ('BOGUS') -> raise without 
sleeping."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("BOGUS")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="unexpected final status: 
BOGUS"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_sleep.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_polls_without_application_submission_log(self, 
mock_popen, mock_get):
+        """Missing 'Submitted application' log line should not block RM REST 
polling."""
+        yarn_log_lines = [
+            "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+            
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+            "INFO Client: Submitting application 
application_1700000000000_0001 to ResourceManager",
+        ]
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(yarn_log_lines)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert hook._yarn_application_id == self._RM_APP_ID
+        assert mock_get.call_args.args[0] == self._rm_status_url()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_checks_spark_submit_exit_code_before_polling(self, 
mock_popen, mock_get):
+        """spark-submit exits non-zero -> raise BEFORE issuing any HTTP 
request."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 1
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(AirflowException, match="Error code is: 1"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+
+    def 
test_yarn_status_tracking_rejects_conflicting_wait_app_completion_conf(self):
+        """User-set spark.yarn.submit.waitAppCompletion=true conflicts with 
flag -> ValueError."""
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            conf={"spark.yarn.submit.waitAppCompletion": "true"},
+            yarn_track_via_rm_api=True,
+        )
+
+        with pytest.raises(ValueError, 
match="spark.yarn.submit.waitAppCompletion=false"):
+            hook._build_spark_submit_command("")
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_transient_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
+        mock_get.side_effect = [
+            self._rm_failure_resp(503, "Service Unavailable"),
+            self._rm_failure_resp(502, "Bad Gateway"),
+            self._rm_failure_resp(500, "Internal Server Error"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 4
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen, 
mock_get, mock_sleep):
+        """First requests.exceptions.Timeout, second call succeeds -> normal 
completion."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            requests.exceptions.Timeout("read timed out"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 2
+        # All calls must include the (connect, read) timeout tuple.
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.kwargs["timeout"] == (5, 30)
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_raises_after_too_many_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """11 consecutive 5xx responses -> raise 'Giving up tracking YARN 
application'."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 11 failures: 10 tolerated; the 11th trips the budget.
+        mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="Giving up tracking YARN 
application"):
+            hook.submit()
+
+        assert mock_get.call_count == 11
+
+    @pytest.mark.parametrize("use_auth", [False, True])
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def test_yarn_status_query_passes_auth_to_requests(self, mock_get, 
use_auth):
+        """Explicit yarn_rm_auth is passed to requests.get, including the 
default None."""
+
+        class _SentinelAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        auth = _SentinelAuth() if use_auth else None
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            yarn_track_via_rm_api=True,
+            yarn_rm_auth=auth,
+        )
+        hook._query_yarn_application_status(self._RM_APP_ID)
+
+        assert mock_get.call_args.kwargs["auth"] is auth
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def test_yarn_status_query_uses_kerberos_auth_from_connection(self, 
mock_get):
+        """Connection keytab + principal auto-enable HTTPKerberosAuth for RM 
requests."""
+
+        class _SentinelKerberosAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        requests_kerberos = ModuleType("requests_kerberos")
+        requests_kerberos.HTTPKerberosAuth = _SentinelKerberosAuth
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        with (
+            patch.object(
+                SparkSubmitHook,
+                "_create_keytab_path_from_base64_keytab",
+                return_value="privileged_user.keytab",
+            ),
+            patch.dict("sys.modules", {"requests_kerberos": 
requests_kerberos}),
+        ):
+            hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos", 
yarn_track_via_rm_api=True)
+            hook._query_yarn_application_status(self._RM_APP_ID)
+
+        assert isinstance(mock_get.call_args.kwargs["auth"], 
_SentinelKerberosAuth)
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    def 
test_yarn_status_query_prefers_provided_auth_over_kerberos_connection(self, 
mock_get):
+        """Explicit yarn_rm_auth stays an escape hatch even when Kerberos is 
configured."""
+
+        class _SentinelAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        auth = _SentinelAuth()
+        mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+        with (
+            patch.object(
+                SparkSubmitHook,
+                "_create_keytab_path_from_base64_keytab",
+                return_value="privileged_user.keytab",
+            ),
+            patch.dict("sys.modules", {"requests_kerberos": None}),
+        ):
+            hook = SparkSubmitHook(
+                conn_id="spark_yarn_rm_kerberos",
+                yarn_track_via_rm_api=True,
+                yarn_rm_auth=auth,
+            )
+            hook._query_yarn_application_status(self._RM_APP_ID)
+
+        assert mock_get.call_args.kwargs["auth"] is auth
+
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_fails_before_submit_when_kerberos_auth_dependency_missing(self,
 mock_popen):
+        """Kerberos RM tracking requires requests-kerberos before spark-submit 
starts."""
+        with patch.object(
+            SparkSubmitHook,
+            "_create_keytab_path_from_base64_keytab",
+            return_value="privileged_user.keytab",
+        ):
+            hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos", 
yarn_track_via_rm_api=True)
+
+        with (
+            patch.dict("sys.modules", {"requests_kerberos": None}),
+            pytest.raises(RuntimeError, match="requests-kerberos"),
+        ):
+            hook.submit()
+
+        mock_popen.assert_not_called()
+
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_fails_before_submit_when_rm_url_missing(self, 
mock_popen):
+        """Missing yarn_resourcemanager_webapp_address extra -> fail before 
spark-submit starts."""
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster", 
yarn_track_via_rm_api=True)
+
+        with pytest.raises(ValueError, 
match="yarn_resourcemanager_webapp_address"):
+            hook.submit()
+
+        mock_popen.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self, 
mock_popen, mock_get, mock_sleep):
+        """Connection lookup must run once even if the polling loop runs many 
iterations.
+
+        Regression guard: a job polling every few seconds for hours must not 
re-fetch
+        the Spark connection (and potentially re-hit a Secrets Backend) on 
every iteration.
+        """
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 4 UNDEFINED iterations then SUCCEEDED -> 5 polling iterations total.
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with patch.object(hook, "get_connection", wraps=hook.get_connection) 
as spy_get_conn:
+            hook.submit()
+
+        assert mock_get.call_count == 5
+        assert spy_get_conn.call_count == 1
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    def test_on_kill_sends_put_to_rm_when_app_id_known(self, mock_put):
+        """_yarn_application_id known -> PUT {state: KILLED} to RM with 
configured auth."""
+
+        class _SentinelAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        sentinel = _SentinelAuth()
+        mock_put.return_value = MagicMock(spec=requests.Response, 
status_code=202)
+
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            yarn_track_via_rm_api=True,
+            yarn_rm_auth=sentinel,
+        )
+        hook._yarn_application_id = self._RM_APP_ID
+        hook.on_kill()
+
+        mock_put.assert_called_once()
+        call_obj = mock_put.call_args
+        assert call_obj.args[0] == self._rm_kill_url()
+        assert call_obj.kwargs["json"] == {"state": "KILLED"}
+        assert call_obj.kwargs["auth"] is sentinel
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    def test_on_kill_uses_kerberos_auth_from_connection(self, mock_put):
+        """Connection keytab + principal auto-enable HTTPKerberosAuth for RM 
kill requests."""
+
+        class _SentinelKerberosAuth(requests.auth.AuthBase):
+            def __call__(self, r):
+                return r
+
+        requests_kerberos = ModuleType("requests_kerberos")
+        requests_kerberos.HTTPKerberosAuth = _SentinelKerberosAuth
+        mock_put.return_value = MagicMock(spec=requests.Response, 
status_code=202)
+
+        with (
+            patch.object(
+                SparkSubmitHook,
+                "_create_keytab_path_from_base64_keytab",
+                return_value="privileged_user.keytab",
+            ),
+            patch.dict("sys.modules", {"requests_kerberos": 
requests_kerberos}),
+        ):
+            hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos", 
yarn_track_via_rm_api=True)
+            hook._yarn_application_id = self._RM_APP_ID
+            hook.on_kill()
+
+        assert isinstance(mock_put.call_args.kwargs["auth"], 
_SentinelKerberosAuth)
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    def test_on_kill_tolerates_rm_failure(self, mock_put):
+        """RM PUT raises -> on_kill does not raise (best-effort, mirrors the 
existing behavior)."""
+        mock_put.side_effect = requests.exceptions.ConnectionError("RM 
unreachable")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook._yarn_application_id = self._RM_APP_ID
+
+        # Must not raise.
+        hook.on_kill()
+
+        mock_put.assert_called_once()

Reply via email to