This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c37ccaa610b Track Spark job status for YARN cluster mode via RM REST
API to free JVM (#65991)
c37ccaa610b is described below
commit c37ccaa610b364be1ffb2ffdcac75c201a1afb91
Author: Aaron Chen <[email protected]>
AuthorDate: Wed Jun 3 21:42:53 2026 -0700
Track Spark job status for YARN cluster mode via RM REST API to free JVM
(#65991)
Adds an opt-in yarn_track_via_rm_api flag to track YARN cluster
applications via the ResourceManager REST API, allowing Airflow to release the
local spark-submit process after submission and reduce worker memory usage.
---------
Co-authored-by: Amogh Desai <[email protected]>
---
providers/apache/spark/docs/operators.rst | 58 +++
providers/apache/spark/provider.yaml | 10 +
.../providers/apache/spark/get_provider_info.py | 5 +
.../providers/apache/spark/hooks/spark_submit.py | 285 +++++++++++-
.../apache/spark/operators/spark_submit.py | 28 +-
.../unit/apache/spark/hooks/test_spark_submit.py | 513 ++++++++++++++++++++-
6 files changed, 884 insertions(+), 15 deletions(-)
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index 64af53454f4..f20d389811e 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -214,3 +214,61 @@ See :doc:`connections/spark-submit` for how to configure
these fields.
.. note::
Crash recovery in cluster mode requires Airflow 3.3+ (``task_state``
support). On earlier
versions the operator falls back to the previous behavior of always
submitting fresh.
+
+YARN ResourceManager API tracking
+"""""""""""""""""""""""""""""""""
+
+When running Spark applications on YARN in cluster deploy mode, the default Spark submit path keeps
+the local ``spark-submit`` JVM alive on the Airflow worker for as long as the YARN
+application runs. For long-running Spark applications this can tie up worker memory for the
+whole application lifetime.
+
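+Set ``yarn_track_via_rm_api=True`` so that Airflow releases the local ``spark-submit`` JVM once the
+application has been submitted to YARN and then polls the YARN ResourceManager REST API until the
+application reaches a terminal state. In this mode Airflow passes
+``--conf spark.yarn.submit.waitAppCompletion=false`` to ``spark-submit``; setting
+``spark.yarn.submit.waitAppCompletion`` to any other value in ``conf`` raises an error. The
+ResourceManager API polling interval is controlled by ``status_poll_interval`` with a
+minimum of 10 seconds.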
+
+This mode requires the Spark connection's ``extra`` to set ``yarn_resourcemanager_webapp_address``
+before the application is submitted (a bare ``host:port`` value is treated as ``http://host:port``):
+
+.. code-block:: bash
+
+    airflow connections add spark_yarn_rm \
+      --conn-type spark \
+      --conn-host yarn \
+      --conn-extra '{
+        "deploy-mode": "cluster",
+        "yarn_resourcemanager_webapp_address": "http://rm.example.com:8088"
+      }'
+
+.. code-block:: python
+
+    SparkSubmitOperator(
+        task_id="spark_pi",
+        conn_id="spark_yarn_rm",
+        application="/path/to/spark-examples.jar",
+        java_class="org.apache.spark.examples.SparkPi",
+        deploy_mode="cluster",
+        yarn_track_via_rm_api=True,
+    )
+
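+For Kerberized clusters, install ``requests-kerberos`` in the Airflow environment. When the
+Spark connection has both ``keytab`` and ``principal`` configured and no ``yarn_rm_auth`` is passed,
+Airflow automatically uses ``HTTPKerberosAuth()`` for the ResourceManager REST requests. If
+``requests-kerberos`` is not installed in that case, the task fails before ``spark-submit`` is launched.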
+
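+Pass ``yarn_rm_auth`` only when the ResourceManager needs a custom ``requests`` authentication
+object. It takes precedence over the automatic Kerberos authentication and is used for both status
+polling and the kill request: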
+
+.. code-block:: python
+
+    import requests
+
+    SparkSubmitOperator(
+        task_id="spark_pi",
+        conn_id="spark_yarn_rm",
+        application="/path/to/spark-examples.jar",
+        java_class="org.apache.spark.examples.SparkPi",
+        deploy_mode="cluster",
+        yarn_track_via_rm_api=True,
+        yarn_rm_auth=requests.auth.HTTPBasicAuth("user", "password"),
+    )
diff --git a/providers/apache/spark/provider.yaml
b/providers/apache/spark/provider.yaml
index af228a2aa29..2a57351b875 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -227,6 +227,16 @@ connection-types:
- string
- 'null'
default: '6066'
+ yarn_resourcemanager_webapp_address:
+ label: YARN ResourceManager webapp address
+ description: >
+ YARN ResourceManager webapp URL (e.g. `http://rm.example.com:8088`),
+ required when `yarn_track_via_rm_api=True` on `SparkSubmitOperator` /
+ `SparkSubmitHook`. Mirrors Hadoop's
`yarn.resourcemanager.webapp.address`.
+ schema:
+ type:
+ - string
+ - 'null'
task-decorators:
- class-name: airflow.providers.apache.spark.decorators.pyspark.pyspark_task
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index ef09d0a6ae9..3fefe31f5ea 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -136,6 +136,11 @@ def get_provider_info():
"description": "Port for the Spark standalone REST API
(spark.master.rest.port). Default is 6066.",
"schema": {"type": ["string", "null"], "default":
"6066"},
},
+ "yarn_resourcemanager_webapp_address": {
+ "label": "YARN ResourceManager webapp address",
+ "description": "YARN ResourceManager webapp URL (e.g.
`http://rm.example.com:8088`), required when `yarn_track_via_rm_api=True` on
`SparkSubmitOperator` / `SparkSubmitHook`. Mirrors Hadoop's
`yarn.resourcemanager.webapp.address`.\n",
+ "schema": {"type": ["string", "null"]},
+ },
},
},
],
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 9aa3ddc885e..7306563a078 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -28,16 +28,27 @@ import tempfile
import time
import uuid
from collections.abc import Iterator
+from functools import cached_property
from pathlib import Path
-from typing import Any
+from typing import TYPE_CHECKING, Any
-from airflow.providers.common.compat.sdk import AirflowException, BaseHook,
conf as airflow_conf
+import requests
+
+from airflow.providers.common.compat.sdk import (
+ AirflowException,
+ AirflowNotFoundException,
+ BaseHook,
+ conf as airflow_conf,
+)
from airflow.security.kerberos import renew_from_kt
from airflow.utils.log.logging_mixin import LoggingMixin
with contextlib.suppress(ImportError, NameError):
from airflow.providers.cncf.kubernetes import kube_client
+if TYPE_CHECKING:
+ from requests.auth import AuthBase
+
DEFAULT_SPARK_BINARY = "spark-submit"
ALLOWED_SPARK_BINARIES = [DEFAULT_SPARK_BINARY, "spark2-submit",
"spark3-submit"]
@@ -79,7 +90,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param name: Name of the job (default airflow-spark)
:param num_executors: Number of executors to launch
:param status_poll_interval: Seconds to wait between polls of driver
status in cluster
- mode (Default: 1)
+ mode. Used both by the Spark standalone driver-status tracker and (when
+ ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
+ polling loop. The YARN ResourceManager REST API polling loop uses at
+ least 10 seconds to avoid flooding the ResourceManager on long-running
+ jobs (Default: 1).
:param application_args: Arguments for the application being submitted
:param env_vars: Environment variables for spark-submit. It
supports yarn and k8s mode too.
@@ -99,6 +114,22 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
job finishes (on both success and on_kill). Useful for cleaning up
sidecars such
as Istio (e.g. ``["curl -X POST localhost:15020/quitquitquit"]``).
Each command
is executed via the shell; failures produce a warning but do not fail
the task.
+ :param yarn_track_via_rm_api: If True (when master is YARN and
``deploy_mode``
+ is ``cluster``), release the ``spark-submit`` JVM once the application
has
+ been submitted to YARN, then poll the YARN ResourceManager REST API
+ (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a
+ final state. The polling interval is controlled by
``status_poll_interval``
+ with a 10-second minimum. This frees the worker from holding the
+ long-lived submit JVM. Requires the Spark connection's
+ ``extra`` JSON to set ``yarn_resourcemanager_webapp_address``
+ (e.g. ``http://rm:8088``). Cluster-side driver logs should be used
after
+ the switch to polling. Defaults to ``False``.
+ :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for
+ every call to the YARN ResourceManager REST API (status polling and
+ kill). When omitted, Kerberos-enabled Spark connections with both
+ ``keytab`` and ``principal`` configured use ``requests-kerberos``
+ automatically. Defaults to ``None`` (no auth for non-Kerberos
+ connections).
"""
conn_name_attr = "conn_id"
@@ -106,6 +137,15 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
conn_type = "spark"
hook_name = "Spark"
+    # Values of the YARN ApplicationReport ``finalStatus`` field
+    # (org.apache.hadoop.yarn.api.records.FinalApplicationStatus).
+    # `_YARN_FINAL_FAILURES` is also compared against the application ``state`` while polling.
+    _YARN_FINAL_SUCCESS = "SUCCEEDED"
+    _YARN_FINAL_UNDEFINED = "UNDEFINED"
+    _YARN_FINAL_FAILURES = frozenset(("FAILED", "KILLED"))
+    # Spark conf forced to ``false`` when tracking via the ResourceManager REST API.
+    _YARN_WAIT_APP_COMPLETION_CONF = "spark.yarn.submit.waitAppCompletion"
+    # Connection-extra key that holds the ResourceManager webapp URL.
+    _YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY = "yarn_resourcemanager_webapp_address"
+    # ``requests`` timeout as (connect, read) seconds; the 30 s read budget matches the old CLI.
+    _HTTP_TIMEOUT = (5, 30)
+
@classmethod
def get_ui_field_behaviour(cls) -> dict[str, Any]:
"""Return custom UI field behaviour for Spark connection."""
@@ -172,6 +212,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
description="Port for the Spark standalone REST API
(spark.master.rest.port). Default: 6066.",
validators=[Optional()],
),
+ "yarn_resourcemanager_webapp_address": StringField(
+ lazy_gettext("YARN ResourceManager webapp address"),
+ widget=BS3TextFieldWidget(),
+ description=(
+ "YARN ResourceManager webapp URL (e.g.
http://rm.example.com:8088), "
+ "required when yarn_track_via_rm_api=True on
SparkSubmitOperator / "
+ "SparkSubmitHook. Mirrors Hadoop's
yarn.resourcemanager.webapp.address."
+ ),
+ validators=[Optional()],
+ ),
}
def __init__(
@@ -207,6 +257,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
*,
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
+ yarn_track_via_rm_api: bool = False,
+ yarn_rm_auth: AuthBase | None = None,
) -> None:
super().__init__()
self._conf = conf or {}
@@ -256,6 +308,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self._spark_exit_code: int | None = None
self._env: dict[str, Any] | None = None
self._post_submit_commands: list[str] = list(post_submit_commands) if
post_submit_commands else []
+ self._yarn_track_via_rm_api = yarn_track_via_rm_api
+ self._yarn_rm_auth = yarn_rm_auth
+ # Cached after first successful resolution so the polling loop in
+ # `_track_yarn_application` does not re-fetch the Spark connection
+ # (and re-hit any configured Secrets Backend) on every iteration.
+ self._yarn_rm_base_url: str | None = None
def _resolve_should_track_driver_status(self) -> bool:
"""
@@ -268,6 +326,24 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
"""
return "spark://" in self._connection["master"] and
self._connection["deploy_mode"] == "cluster"
+    def _should_track_yarn_application_via_rm_api(self) -> bool:
+        """
+        Decide whether this submit hands tracking over to the YARN RM REST API.
+
+        Only true when the user opted in, the master is YARN and the deploy mode is ``cluster``.
+        """
+        if not (self._yarn_track_via_rm_api and self._is_yarn):
+            return False
+        return self._connection["deploy_mode"] == "cluster"
+
+    def _validate_yarn_track_via_rm_api_config(self) -> None:
+        """
+        Fail fast when ``yarn_track_via_rm_api=True`` cannot work for this submit.
+
+        Does nothing when the flag is off. Otherwise checks that the master is YARN and the
+        deploy mode is ``cluster``. It then resolves the ResourceManager base URL and auth
+        object once, so misconfiguration is reported up front. Both results are cached on the
+        hook for later use.
+
+        :raises ValueError: non-YARN master, non-cluster deploy mode, or a Spark connection
+            without the ``yarn_resourcemanager_webapp_address`` extra.
+        :raises RuntimeError: Kerberos credentials are configured but ``requests-kerberos``
+            is not installed.
+        """
+        if self._yarn_track_via_rm_api:
+            if not self._is_yarn:
+                raise ValueError("`yarn_track_via_rm_api=True` requires Spark master to be YARN.")
+            deploy_mode = self._connection["deploy_mode"]
+            if deploy_mode != "cluster":
+                raise ValueError(
+                    "`yarn_track_via_rm_api=True` requires `deploy_mode='cluster'`; "
+                    f"got {deploy_mode!r}."
+                )
+            # Both lookups are evaluated only for the exceptions they may raise.
+            self._get_yarn_rm_base_url()
+            _ = self._resolved_yarn_rm_auth
+
def _resolve_connection(self) -> dict[str, Any]:
# Build from connection master or default to yarn if not available
conn_data: dict[str, Any] = {
@@ -426,6 +502,16 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
for key in self._conf:
args += ["--conf", f"{key}={self._conf[key]}"]
+ if self._should_track_yarn_application_via_rm_api():
+ wait_app_completion =
self._conf.get(self._YARN_WAIT_APP_COMPLETION_CONF)
+ if wait_app_completion is not None:
+ if str(wait_app_completion).strip().lower() != "false":
+ raise ValueError(
+ f"`{self._YARN_WAIT_APP_COMPLETION_CONF}=false` is
required when "
+ "`yarn_track_via_rm_api=True`."
+ )
+ else:
+ args += ["--conf",
f"{self._YARN_WAIT_APP_COMPLETION_CONF}=false"]
if self._env_vars and (self._is_kubernetes or self._is_yarn):
if self._is_yarn:
tmpl = "spark.yarn.appMasterEnv.{}={}"
@@ -611,6 +697,7 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
:param application: Submitted application, jar or py file
:param kwargs: extra arguments to Popen (see subprocess.Popen)
"""
+ self._validate_yarn_track_via_rm_api_config()
spark_submit_cmd = self._build_spark_submit_command(application)
if self._env:
@@ -643,6 +730,15 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
f"Cannot execute: {self._mask_cmd(spark_submit_cmd)}.
Error code is: {returncode}."
)
+ if self._should_track_yarn_application_via_rm_api():
+ # Once spark-submit exits successfully, rely on RM REST API
polling instead
+ # of requiring a particular Spark log line such as "Submitted
application ...".
+ # The RM REST API is the authoritative source for the
application's lifecycle.
+ if not self._yarn_application_id:
+ raise RuntimeError("No YARN application id found after
spark-submit completed.")
+ self._track_yarn_application(self._yarn_application_id)
+ return self._driver_id
+
if self._should_track_driver_status and self._driver_id is None:
raise AirflowException(
"No driver id is known: something went wrong when
executing the spark submit command"
@@ -712,6 +808,155 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.info(line)
+    def _track_yarn_application(self, application_id: str) -> None:
+        """
+        Poll the YARN RM REST API until the application reaches a terminal state.
+
+        Polls every ``max(status_poll_interval, 10)`` seconds. Up to
+        ``max_consecutive_failures`` failed status queries in a row are logged and retried.
+        The next consecutive failure aborts tracking.
+
+        :param application_id: YARN application id to track.
+        :raises RuntimeError: any of the following:
+
+            - the application state or final status is FAILED/KILLED;
+            - the final status is not a known value;
+            - the application reached state FINISHED without a final status;
+            - too many consecutive status queries failed.
+        """
+        self.log.info(
+            "Tracking YARN application %s via ResourceManager REST API polling",
+            application_id,
+        )
+        poll_interval = max(self._status_poll_interval, 10)
+        # Tolerate transient RM REST API failures (RM hiccup, network blip, request
+        # timeout) the same way `_start_driver_status_tracking` does for spark
+        # standalone — only give up once this many consecutive failures were exceeded.
+        consecutive_failures = 0
+        max_consecutive_failures = 10
+        while True:
+            self.log.debug("Polling YARN RM REST API for application %s", application_id)
+            try:
+                state, final_status = self._query_yarn_application_status(application_id)
+            except RuntimeError as exc:
+                consecutive_failures += 1
+                if consecutive_failures > max_consecutive_failures:
+                    raise RuntimeError(
+                        f"Giving up tracking YARN application {application_id} after "
+                        f"{max_consecutive_failures} consecutive YARN RM REST API "
+                        f"failures. Last error: {exc}"
+                    ) from exc
+                self.log.warning(
+                    "Transient YARN RM REST API failure (%d/%d): %s",
+                    consecutive_failures,
+                    max_consecutive_failures,
+                    exc,
+                )
+                time.sleep(poll_interval)
+                continue
+            consecutive_failures = 0
+            if state in self._YARN_FINAL_FAILURES:
+                raise RuntimeError(
+                    f"YARN application {application_id} ended with state: {state}, "
+                    f"final status: {final_status}"
+                )
+            if final_status == self._YARN_FINAL_SUCCESS:
+                self.log.info("YARN application %s finished with SUCCEEDED", application_id)
+                return
+            if final_status in self._YARN_FINAL_FAILURES:
+                raise RuntimeError(
+                    f"YARN application {application_id} ended with final status: {final_status}"
+                )
+            if final_status != self._YARN_FINAL_UNDEFINED:
+                raise RuntimeError(
+                    f"YARN application {application_id} returned unexpected final status: {final_status}"
+                )
+            # FINISHED is a terminal application state, so an UNDEFINED final status will
+            # never change. Fail instead of polling forever.
+            if state == "FINISHED":
+                raise RuntimeError(
+                    f"YARN application {application_id} reached state: {state} "
+                    f"without a final status (final status: {final_status})"
+                )
+            time.sleep(poll_interval)
+
+    def _get_yarn_rm_base_url(self) -> str:
+        """
+        Return the YARN ResourceManager webapp base URL for this hook's Spark connection.
+
+        Reads the ``yarn_resourcemanager_webapp_address`` key of the connection ``extra``
+        and strips surrounding whitespace. A value without a scheme (``host:port``) gets
+        ``http://`` prepended, and trailing slashes are removed. The result is memoized on
+        the instance, so the polling loop does not look the connection up again (or re-hit
+        a Secrets Backend) on every iteration.
+
+        :raises ValueError: the connection does not exist or does not set the key.
+        """
+        if self._yarn_rm_base_url is None:
+            try:
+                connection = self.get_connection(self._conn_id)
+            except AirflowNotFoundException:
+                connection = None
+            configured = (
+                None
+                if connection is None
+                else connection.extra_dejson.get(self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY)
+            )
+            address = (configured or "").strip()
+            if not address:
+                raise ValueError(
+                    f"`yarn_track_via_rm_api=True` requires the Spark connection's `extra` to set "
+                    f"`{self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY}` (e.g. `http://rm.example.com:8088`)."
+                )
+            if "://" not in address:
+                address = f"http://{address}"
+            self._yarn_rm_base_url = address.rstrip("/")
+        return self._yarn_rm_base_url
+
+    @cached_property
+    def _resolved_yarn_rm_auth(self) -> AuthBase | None:
+        """
+        Pick the ``requests`` auth object used for YARN ResourceManager REST calls.
+
+        Precedence:
+
+        1. an explicit ``yarn_rm_auth`` is returned unchanged;
+        2. otherwise a Spark connection with both ``keytab`` and ``principal`` gets a
+           ``requests_kerberos.HTTPKerberosAuth()``;
+        3. otherwise ``None`` (no authentication).
+
+        :raises RuntimeError: Kerberos credentials are configured but ``requests-kerberos``
+            cannot be imported.
+        """
+        explicit_auth = self._yarn_rm_auth
+        if explicit_auth is not None:
+            return explicit_auth
+
+        connection = self._connection
+        has_kerberos_credentials = bool(connection.get("keytab") and connection.get("principal"))
+        if not has_kerberos_credentials:
+            return None
+
+        try:
+            from requests_kerberos import HTTPKerberosAuth
+        except ImportError as exc:
+            raise RuntimeError(
+                "Kerberos credentials are configured for Spark submit, but `requests-kerberos` "
+                "is not installed. Install `requests-kerberos` to use "
+                "`yarn_track_via_rm_api=True` with Kerberos, or pass `yarn_rm_auth` explicitly."
+            ) from exc
+        return HTTPKerberosAuth()
+
+    def _query_yarn_application_status(self, application_id: str) -> tuple[str, str]:
+        """
+        Fetch a single application report from ``GET /ws/v1/cluster/apps/{id}``.
+
+        :param application_id: YARN application id to look up.
+        :return: ``(app.state, app.finalStatus)`` from the ResourceManager response.
+        :raises RuntimeError: the HTTP request fails, the status code is not 200, or the
+            body is not JSON containing ``app.state`` and ``app.finalStatus``. Error
+            messages include at most the first 200 characters of the response body.
+        """
+        apps_endpoint = f"{self._get_yarn_rm_base_url()}/ws/v1/cluster/apps/{application_id}"
+        try:
+            response = requests.get(
+                apps_endpoint,
+                auth=self._resolved_yarn_rm_auth,
+                timeout=self._HTTP_TIMEOUT,
+            )
+        except requests.exceptions.RequestException as exc:
+            raise RuntimeError(
+                f"YARN RM REST API request for application {application_id} failed: {exc}"
+            ) from exc
+
+        if response.status_code == 200:
+            try:
+                report = response.json()["app"]
+                return report["state"], report["finalStatus"]
+            except (ValueError, KeyError, TypeError) as exc:
+                raise RuntimeError(
+                    f"YARN RM REST API returned unexpected payload for application "
+                    f"{application_id}: {response.text[:200]}"
+                ) from exc
+
+        raise RuntimeError(
+            f"YARN RM REST API returned HTTP {response.status_code} for application "
+            f"{application_id}: {response.text[:200]}"
+        )
+
+    def _kill_yarn_application(self, application_id: str) -> None:
+        """
+        Ask the ResourceManager to kill the application (best-effort).
+
+        Sends ``PUT /ws/v1/cluster/apps/{id}/state`` with ``{"state": "KILLED"}``. A
+        ``ValueError``/``RuntimeError`` while resolving the RM URL or auth, and any
+        ``requests`` transport error, is logged as a warning instead of being raised.
+        The HTTP status code of the kill response is only logged.
+        """
+        try:
+            kill_url = f"{self._get_yarn_rm_base_url()}/ws/v1/cluster/apps/{application_id}/state"
+            rm_auth = self._resolved_yarn_rm_auth
+        except (ValueError, RuntimeError) as exc:
+            self.log.warning("Cannot send YARN kill for %s: %s", application_id, exc)
+            return
+
+        try:
+            response = requests.put(
+                kill_url,
+                json={"state": "KILLED"},
+                auth=rm_auth,
+                timeout=self._HTTP_TIMEOUT,
+            )
+        except requests.exceptions.RequestException as exc:
+            self.log.warning("YARN kill request for %s failed: %s", application_id, exc)
+        else:
+            self.log.info("YARN kill request for %s returned HTTP %s", application_id, response.status_code)
+
def _process_spark_status_log(self, itr: Iterator[Any]) -> None:
"""
Parse the logs of the spark driver status query process.
@@ -839,22 +1084,29 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
self.log.info("Sending kill signal to %s",
self._connection["spark_binary"])
self._submit_sp.kill()
- if self._yarn_application_id:
+ # Legacy YARN CLI kill — gated on a live `_submit_sp` to preserve
+ # pre-`yarn_track_via_rm_api` behavior. The REST kill path below is
+ # the opt-in replacement and is intentionally not gated this way:
+ # `yarn_track_via_rm_api=True` deliberately terminates `_submit_sp`
+ # after submission, so by the time `on_kill` fires the gate would
+ # always be False and the YARN app would never be killed.
+ if self._yarn_application_id and not self._yarn_track_via_rm_api:
kill_cmd = f"yarn application -kill
{self._yarn_application_id}".split()
env = {**os.environ, **(self._env or {})}
if self._connection["keytab"] is not None and
self._connection["principal"] is not None:
- # we are ignoring renewal failures from renew_from_kt
- # here as the failure could just be due to a non-renewable
ticket,
- # we still attempt to kill the yarn application
+ # Renewal failures from `renew_from_kt` are ignored here —
a
+ # non-renewable ticket should not block the YARN kill
attempt.
renew_from_kt(
- self._connection["principal"],
self._connection["keytab"], exit_on_fail=False
+ self._connection["principal"],
+ self._connection["keytab"],
+ exit_on_fail=False,
)
- env = os.environ.copy()
- ccacche = airflow_conf.get_mandatory_value("kerberos",
"ccache")
- env["KRB5CCNAME"] = ccacche
-
+ env["KRB5CCNAME"] =
airflow_conf.get_mandatory_value("kerberos", "ccache")
with subprocess.Popen(
- kill_cmd, env=env, stdout=subprocess.PIPE,
stderr=subprocess.PIPE
+ kill_cmd,
+ env=env,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
) as yarn_kill:
self.log.info("YARN app killed with return code: %s",
yarn_kill.wait())
@@ -878,4 +1130,11 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
except kube_client.ApiException:
self.log.exception("Exception when attempting to kill
Spark on K8s")
+ # Opt-in REST kill path — uses the same RM endpoint as polling, no
+ # `yarn` CLI dependency on the worker. Independent of `_submit_sp`
+ # state because `yarn_track_via_rm_api=True` deliberately terminates
+ # `_submit_sp` right after submission to free the JVM.
+ if self._yarn_application_id and self._yarn_track_via_rm_api:
+ self._kill_yarn_application(self._yarn_application_id)
+
self._run_post_submit_commands()
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 76b010107da..ea7b4a8e4ef 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -49,6 +49,7 @@ except ImportError:
if TYPE_CHECKING:
from pydantic import JsonValue
+ from requests.auth import AuthBase
from airflow.providers.common.compat.sdk import Context
@@ -91,7 +92,11 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
:param name: Name of the job (default airflow-spark). (templated)
:param num_executors: Number of executors to launch
:param status_poll_interval: Seconds to wait between polls of driver
status in cluster
- mode (Default: 1)
+ mode. Used both by the Spark standalone driver-status tracker and (when
+ ``yarn_track_via_rm_api=True``) by the YARN ResourceManager REST API
+ polling loop. The YARN ResourceManager REST API polling loop uses at
+ least 10 seconds to avoid flooding the ResourceManager on long-running
+ jobs (Default: 1).
:param application_args: Arguments for the application being submitted
(templated)
:param env_vars: Environment variables for spark-submit. It supports yarn
and k8s mode too. (templated)
:param verbose: Whether to pass the verbose flag to spark-submit process
for debugging
@@ -108,6 +113,21 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
on keytab for Kerberos login
:param post_submit_commands: Optional list of shell commands to run after
the Spark job finishes.
Useful for cleaning up sidecars such as Istio. Failures produce a
warning but do not fail the task.
+ :param yarn_track_via_rm_api: If True (when master is YARN and
``deploy_mode``
+ is ``cluster``), release the ``spark-submit`` JVM once the application
has
+ been submitted to YARN, then poll the YARN ResourceManager REST API
+ (``GET /ws/v1/cluster/apps/{appId}``) until the application reaches a
+ final state. The polling interval is controlled by
``status_poll_interval``
+ with a 10-second minimum. This frees the worker from holding the
+ long-lived submit JVM. Requires the Spark connection's ``extra``
+ JSON to set ``yarn_resourcemanager_webapp_address`` (e.g.
``http://rm:8088``).
+ Cluster-side driver logs should be used after the switch to polling.
+ Defaults to ``False``.
+ :param yarn_rm_auth: Optional ``requests.auth.AuthBase`` instance used for
every
+ call to the YARN ResourceManager REST API (status polling and kill).
When
+ omitted, Kerberos-enabled Spark connections with both ``keytab`` and
+ ``principal`` configured use ``requests-kerberos`` automatically.
+ Defaults to ``None`` (no auth for non-Kerberos connections).
"""
# Generic key used across all Spark deployment modes (standalone driver ID,
@@ -168,6 +188,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
use_krb5ccache: bool = False,
post_submit_commands: list[str] | None = None,
reconnect_on_retry: bool = True,
+ yarn_track_via_rm_api: bool = False,
+ yarn_rm_auth: AuthBase | None = None,
openlineage_inject_parent_job_info: bool = conf.getboolean(
"openlineage", "spark_inject_parent_job_info", fallback=False
),
@@ -210,6 +232,8 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
self._post_submit_commands = list(post_submit_commands) if
post_submit_commands else []
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
+ self._yarn_track_via_rm_api = yarn_track_via_rm_api
+ self._yarn_rm_auth = yarn_rm_auth
self.reconnect_on_retry = reconnect_on_retry
self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
@@ -378,4 +402,6 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
deploy_mode=self._deploy_mode,
use_krb5ccache=self._use_krb5ccache,
post_submit_commands=self.post_submit_commands,
+ yarn_track_via_rm_api=self._yarn_track_via_rm_api,
+ yarn_rm_auth=self._yarn_rm_auth,
)
diff --git
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
index c909e9f12ab..301f720443d 100644
--- a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py
@@ -21,9 +21,11 @@ import base64
import os
from io import StringIO
from pathlib import Path
-from unittest.mock import call, mock_open, patch
+from types import ModuleType
+from unittest.mock import MagicMock, call, mock_open, patch
import pytest
+import requests
from airflow.models import Connection
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
@@ -198,6 +200,29 @@ class TestSparkSubmitHook:
uri="spark://local",
)
)
+ create_connection_without_db(
+ Connection(
+ conn_id="spark_yarn_rm",
+ conn_type="spark",
+ host="yarn",
+ extra=(
+ '{"deploy-mode": "cluster",
"yarn_resourcemanager_webapp_address": "http://rm.test:8088"}'
+ ),
+ )
+ )
+ create_connection_without_db(
+ Connection(
+ conn_id="spark_yarn_rm_kerberos",
+ conn_type="spark",
+ host="yarn",
+ extra=(
+ '{"deploy-mode": "cluster", '
+ '"yarn_resourcemanager_webapp_address":
"http://rm.test:8088", '
+ '"principal": "[email protected]", '
+ '"keytab": "cHJpdmlsZWdlZF91c2VyLmtleXRhYg=="}'
+ ),
+ )
+ )
@pytest.mark.db_test
@patch(
@@ -1036,6 +1061,26 @@ class TestSparkSubmitHook:
in mock_popen.mock_calls
)
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_legacy_on_kill_skips_yarn_cli_when_submit_sp_already_exited(self, mock_popen):
+        """
+        Regression guard for the legacy path (`yarn_track_via_rm_api=False`).
+
+        The YARN CLI kill must stay gated on a live `_submit_sp`. Once the spark-submit
+        subprocess has exited, `on_kill` must neither kill it again nor spawn
+        `yarn application -kill`. This preserves the pre-PR behavior for users who have
+        not opted in to the REST kill path.
+        """
+        exited_submit = MagicMock(spec=["kill", "poll"])
+        exited_submit.poll.return_value = 0  # non-None return code: the process is gone
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")  # RM API tracking left at default (off)
+        hook._yarn_application_id = "application_1486558679801_1820"
+        hook._submit_sp = exited_submit
+
+        hook.on_kill()
+
+        mock_popen.assert_not_called()
+        exited_submit.kill.assert_not_called()
+
def test_standalone_cluster_process_on_kill(self):
# Given
log_lines = [
@@ -1366,3 +1411,469 @@ class TestSparkSubmitHook:
"""Test that None post_submit_commands results in an empty list."""
hook = SparkSubmitHook(conn_id="")
assert hook._post_submit_commands == []
+
+    _RM_BASE_URL = "http://rm.test:8088"
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    # spark-submit client output for a YARN cluster-mode submission of `_RM_APP_ID`.
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        f"hdfs://namenode:8020/user/root/.sparkStaging/{_RM_APP_ID}/lib.zip",
+        f"INFO Client: Submitting application {_RM_APP_ID} to ResourceManager",
+        f"INFO YarnClientImpl: Submitted application {_RM_APP_ID}",
+        f"INFO Client: Application report for {_RM_APP_ID} (state: ACCEPTED)",
+        f"INFO Client: Application report for {_RM_APP_ID} (state: RUNNING)",
+        f"INFO Client: Application report for {_RM_APP_ID} (state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        """Return the RM REST status URL for ``app_id`` (``_RM_APP_ID`` when falsy)."""
+        resolved_id = app_id or cls._RM_APP_ID
+        return "/".join((cls._RM_BASE_URL, "ws/v1/cluster/apps", resolved_id))
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        """Return the RM REST kill (``/state``) URL for ``app_id`` (``_RM_APP_ID`` when falsy)."""
+        resolved_id = app_id or cls._RM_APP_ID
+        return "/".join((cls._RM_BASE_URL, "ws/v1/cluster/apps", resolved_id, "state"))
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> MagicMock:
+        """Build a fake HTTP 200 RM response reporting ``state``/``final_status`` for ``_RM_APP_ID``."""
+        app_report = {"id": cls._RM_APP_ID, "state": state, "finalStatus": final_status}
+        fake = MagicMock(spec=requests.Response)
+        fake.status_code = 200
+        fake.json.return_value = {"app": app_report}
+        return fake
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server Error") -> MagicMock:
+        """Build a fake RM response with the given HTTP status code and body text."""
+        failure = MagicMock(spec=requests.Response)
+        failure.configure_mock(status_code=status_code, text=text)
+        return failure
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, mock_get, mock_put):
+        """With the flag at its default (False) no RM HTTP call is made and submit behaves as before."""
+        submit_proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        submit_proc.stdout = iter(self._YARN_LOG_LINES)
+        submit_proc.wait.return_value = 0
+        mock_popen.return_value = submit_proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        assert hook._yarn_application_id == self._RM_APP_ID
+        for untouched in (submit_proc.terminate, mock_get, mock_put):
+            untouched.assert_not_called()
+
+    def test_yarn_status_tracking_requires_yarn_master(self):
+        """Opting in to RM API tracking with a non-YARN master is rejected with ValueError."""
+        k8s_hook = SparkSubmitHook(yarn_track_via_rm_api=True, conn_id="spark_k8s_cluster")
+
+        with pytest.raises(ValueError, match="requires Spark master to be YARN"):
+            k8s_hook.submit()
+
+    def test_yarn_status_tracking_requires_cluster_deploy_mode(self):
+        """Opting in to RM API tracking in client deploy mode is rejected with ValueError."""
+        client_mode_hook = SparkSubmitHook(
+            yarn_track_via_rm_api=True,
+            deploy_mode="client",
+            conn_id="spark_yarn_rm",
+        )
+
+        with pytest.raises(ValueError, match="requires `deploy_mode='cluster'`"):
+            client_mode_hook.submit()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, mock_sleep):
+        """RM reports UNDEFINED (while RUNNING) and then SUCCEEDED -> submit returns normally."""
+        submit_proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        submit_proc.stdout = iter(self._YARN_LOG_LINES)
+        submit_proc.wait.return_value = 0
+        mock_popen.return_value = submit_proc
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        SparkSubmitHook(conn_id="spark_yarn_rm", yarn_track_via_rm_api=True).submit()
+
+        # spark-submit must be told not to wait for the YARN application to complete.
+        submitted_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in submitted_cmd
+        submit_proc.terminate.assert_not_called()
+        # Exactly one sleep between the two polls, clamped up to the 10 s minimum.
+        mock_sleep.assert_called_once_with(10)
+        polled_urls = [recorded.args[0] for recorded in mock_get.call_args_list]
+        assert polled_urls == [self._rm_status_url()] * 2
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get,
mock_sleep):
+ """RM returns KILLED -> raise with message containing app id and
KILLED."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("KILLED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+ hook.submit()
+ proc.terminate.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_fails_on_failed_state_with_undefined_final_status(
+ self, mock_popen, mock_get, mock_sleep
+ ):
+ """RM state FAILED with finalStatus UNDEFINED should not poll
forever."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("UNDEFINED",
state="FAILED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*state:
FAILED"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_sleep.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_fails_on_unexpected_final_status(self,
mock_popen, mock_get, mock_sleep):
+ """RM returns a non-standard finalStatus ('BOGUS') -> raise without
sleeping."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.return_value = self._rm_status_resp("BOGUS")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match="unexpected final status:
BOGUS"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_sleep.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_polls_without_application_submission_log(self,
mock_popen, mock_get):
+ """Missing 'Submitted application' log line should not block RM REST
polling."""
+ yarn_log_lines = [
+ "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+ "INFO Client: Submitting application
application_1700000000000_0001 to ResourceManager",
+ ]
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(yarn_log_lines)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert hook._yarn_application_id == self._RM_APP_ID
+ assert mock_get.call_args.args[0] == self._rm_status_url()
+ proc.terminate.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_checks_spark_submit_exit_code_before_polling(self,
mock_popen, mock_get):
+ """spark-submit exits non-zero -> raise BEFORE issuing any HTTP
request."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 1
+ mock_popen.return_value = proc
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(AirflowException, match="Error code is: 1"):
+ hook.submit()
+
+ proc.terminate.assert_not_called()
+ mock_get.assert_not_called()
+
+ def
test_yarn_status_tracking_rejects_conflicting_wait_app_completion_conf(self):
+ """User-set spark.yarn.submit.waitAppCompletion=true conflicts with
flag -> ValueError."""
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ conf={"spark.yarn.submit.waitAppCompletion": "true"},
+ yarn_track_via_rm_api=True,
+ )
+
+ with pytest.raises(ValueError,
match="spark.yarn.submit.waitAppCompletion=false"):
+ hook._build_spark_submit_command("")
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_tolerates_transient_failures(self,
mock_popen, mock_get, mock_sleep):
+ """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
+ mock_get.side_effect = [
+ self._rm_failure_resp(503, "Service Unavailable"),
+ self._rm_failure_resp(502, "Bad Gateway"),
+ self._rm_failure_resp(500, "Internal Server Error"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert mock_get.call_count == 4
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen,
mock_get, mock_sleep):
+ """First requests.exceptions.Timeout, second call succeeds -> normal
completion."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ mock_get.side_effect = [
+ requests.exceptions.Timeout("read timed out"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook.submit()
+
+ assert mock_get.call_count == 2
+ # All calls must include the (connect, read) timeout tuple.
+ for call_obj in mock_get.call_args_list:
+ assert call_obj.kwargs["timeout"] == (5, 30)
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_status_tracking_raises_after_too_many_failures(self,
mock_popen, mock_get, mock_sleep):
+ """11 consecutive 5xx responses -> raise 'Giving up tracking YARN
application'."""
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ # 11 failures: 10 tolerated; the 11th trips the budget.
+ mock_get.side_effect = [self._rm_failure_resp() for _ in range(11)]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with pytest.raises(RuntimeError, match="Giving up tracking YARN
application"):
+ hook.submit()
+
+ assert mock_get.call_count == 11
+
+ @pytest.mark.parametrize("use_auth", [False, True])
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def test_yarn_status_query_passes_auth_to_requests(self, mock_get,
use_auth):
+ """Explicit yarn_rm_auth is passed to requests.get, including the
default None."""
+
+ class _SentinelAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ auth = _SentinelAuth() if use_auth else None
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ yarn_track_via_rm_api=True,
+ yarn_rm_auth=auth,
+ )
+ hook._query_yarn_application_status(self._RM_APP_ID)
+
+ assert mock_get.call_args.kwargs["auth"] is auth
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def test_yarn_status_query_uses_kerberos_auth_from_connection(self,
mock_get):
+ """Connection keytab + principal auto-enable HTTPKerberosAuth for RM
requests."""
+
+ class _SentinelKerberosAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ requests_kerberos = ModuleType("requests_kerberos")
+ requests_kerberos.HTTPKerberosAuth = _SentinelKerberosAuth
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ with (
+ patch.object(
+ SparkSubmitHook,
+ "_create_keytab_path_from_base64_keytab",
+ return_value="privileged_user.keytab",
+ ),
+ patch.dict("sys.modules", {"requests_kerberos":
requests_kerberos}),
+ ):
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos",
yarn_track_via_rm_api=True)
+ hook._query_yarn_application_status(self._RM_APP_ID)
+
+ assert isinstance(mock_get.call_args.kwargs["auth"],
_SentinelKerberosAuth)
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+ def
test_yarn_status_query_prefers_provided_auth_over_kerberos_connection(self,
mock_get):
+ """Explicit yarn_rm_auth stays an escape hatch even when Kerberos is
configured."""
+
+ class _SentinelAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ auth = _SentinelAuth()
+ mock_get.return_value = self._rm_status_resp("SUCCEEDED")
+
+ with (
+ patch.object(
+ SparkSubmitHook,
+ "_create_keytab_path_from_base64_keytab",
+ return_value="privileged_user.keytab",
+ ),
+ patch.dict("sys.modules", {"requests_kerberos": None}),
+ ):
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm_kerberos",
+ yarn_track_via_rm_api=True,
+ yarn_rm_auth=auth,
+ )
+ hook._query_yarn_application_status(self._RM_APP_ID)
+
+ assert mock_get.call_args.kwargs["auth"] is auth
+
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_fails_before_submit_when_kerberos_auth_dependency_missing(self,
mock_popen):
+ """Kerberos RM tracking requires requests-kerberos before spark-submit
starts."""
+ with patch.object(
+ SparkSubmitHook,
+ "_create_keytab_path_from_base64_keytab",
+ return_value="privileged_user.keytab",
+ ):
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos",
yarn_track_via_rm_api=True)
+
+ with (
+ patch.dict("sys.modules", {"requests_kerberos": None}),
+ pytest.raises(RuntimeError, match="requests-kerberos"),
+ ):
+ hook.submit()
+
+ mock_popen.assert_not_called()
+
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def
test_yarn_status_tracking_fails_before_submit_when_rm_url_missing(self,
mock_popen):
+ """Missing yarn_resourcemanager_webapp_address extra -> fail before
spark-submit starts."""
+ hook = SparkSubmitHook(conn_id="spark_yarn_cluster",
yarn_track_via_rm_api=True)
+
+ with pytest.raises(ValueError,
match="yarn_resourcemanager_webapp_address"):
+ hook.submit()
+
+ mock_popen.assert_not_called()
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+ def test_yarn_rm_base_url_is_resolved_once_across_polling_loop(self,
mock_popen, mock_get, mock_sleep):
+ """Connection lookup must run once even if the polling loop runs many
iterations.
+
+ Regression guard: a job polling every few seconds for hours must not
re-fetch
+ the Spark connection (and potentially re-hit a Secrets Backend) on
every iteration.
+ """
+ proc = MagicMock(spec=["stdout", "terminate", "wait"])
+ proc.stdout = iter(self._YARN_LOG_LINES)
+ proc.wait.return_value = 0
+ mock_popen.return_value = proc
+
+ # 4 UNDEFINED iterations then SUCCEEDED -> 5 polling iterations total.
+ mock_get.side_effect = [
+ self._rm_status_resp("UNDEFINED", state="RUNNING"),
+ self._rm_status_resp("UNDEFINED", state="RUNNING"),
+ self._rm_status_resp("UNDEFINED", state="RUNNING"),
+ self._rm_status_resp("UNDEFINED", state="RUNNING"),
+ self._rm_status_resp("SUCCEEDED"),
+ ]
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ with patch.object(hook, "get_connection", wraps=hook.get_connection)
as spy_get_conn:
+ hook.submit()
+
+ assert mock_get.call_count == 5
+ assert spy_get_conn.call_count == 1
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+ def test_on_kill_sends_put_to_rm_when_app_id_known(self, mock_put):
+ """_yarn_application_id known -> PUT {state: KILLED} to RM with
configured auth."""
+
+ class _SentinelAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ sentinel = _SentinelAuth()
+ mock_put.return_value = MagicMock(spec=requests.Response,
status_code=202)
+
+ hook = SparkSubmitHook(
+ conn_id="spark_yarn_rm",
+ yarn_track_via_rm_api=True,
+ yarn_rm_auth=sentinel,
+ )
+ hook._yarn_application_id = self._RM_APP_ID
+ hook.on_kill()
+
+ mock_put.assert_called_once()
+ call_obj = mock_put.call_args
+ assert call_obj.args[0] == self._rm_kill_url()
+ assert call_obj.kwargs["json"] == {"state": "KILLED"}
+ assert call_obj.kwargs["auth"] is sentinel
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+ def test_on_kill_uses_kerberos_auth_from_connection(self, mock_put):
+ """Connection keytab + principal auto-enable HTTPKerberosAuth for RM
kill requests."""
+
+ class _SentinelKerberosAuth(requests.auth.AuthBase):
+ def __call__(self, r):
+ return r
+
+ requests_kerberos = ModuleType("requests_kerberos")
+ requests_kerberos.HTTPKerberosAuth = _SentinelKerberosAuth
+ mock_put.return_value = MagicMock(spec=requests.Response,
status_code=202)
+
+ with (
+ patch.object(
+ SparkSubmitHook,
+ "_create_keytab_path_from_base64_keytab",
+ return_value="privileged_user.keytab",
+ ),
+ patch.dict("sys.modules", {"requests_kerberos":
requests_kerberos}),
+ ):
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm_kerberos",
yarn_track_via_rm_api=True)
+ hook._yarn_application_id = self._RM_APP_ID
+ hook.on_kill()
+
+ assert isinstance(mock_put.call_args.kwargs["auth"],
_SentinelKerberosAuth)
+
+ @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+ def test_on_kill_tolerates_rm_failure(self, mock_put):
+ """RM PUT raises -> on_kill does not raise (best-effort, mirrors
today)."""
+ mock_put.side_effect = requests.exceptions.ConnectionError("RM
unreachable")
+
+ hook = SparkSubmitHook(conn_id="spark_yarn_rm",
yarn_track_via_rm_api=True)
+ hook._yarn_application_id = self._RM_APP_ID
+
+ # Must not raise.
+ hook.on_kill()
+
+ mock_put.assert_called_once()