amoghrajesh commented on code in PR #65991:
URL: https://github.com/apache/airflow/pull/65991#discussion_r3297150643


##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -704,6 +770,128 @@ def _process_spark_submit_log(self, itr: Iterator[Any]) 
-> None:
 
             self.log.info(line)
 
+    def _track_yarn_application(self, application_id: str) -> None:
+        """Poll the YARN RM REST API until ``app.finalStatus`` reaches a 
terminal value."""
+        self.log.info(
+            "Tracking YARN application %s via ResourceManager REST API 
polling",
+            application_id,
+        )
+        poll_interval = max(self._status_poll_interval, 1)
+        # Tolerate transient RM REST API failures (RM hiccup, network blip, 
request
+        # timeout) the same way `_start_driver_status_tracking` does for spark
+        # standalone — only give up after this many consecutive failures.
+        consecutive_failures = 0
+        max_consecutive_failures = 10
+        while True:
+            self.log.debug("Polling YARN RM REST API for application %s", 
application_id)
+            try:
+                final_status = 
self._query_yarn_application_final_status(application_id)
+            except RuntimeError as exc:
+                consecutive_failures += 1
+                if consecutive_failures > max_consecutive_failures:
+                    raise RuntimeError(
+                        f"Giving up tracking YARN application {application_id} 
after "
+                        f"{max_consecutive_failures} consecutive YARN RM REST 
API "
+                        f"failures. Last error: {exc}"
+                    ) from exc
+                self.log.warning(
+                    "Transient YARN RM REST API failure (%d/%d): %s",
+                    consecutive_failures,
+                    max_consecutive_failures,
+                    exc,
+                )
+                time.sleep(poll_interval)
+                continue
+            consecutive_failures = 0
+            if final_status == self._YARN_FINAL_SUCCESS:
+                self.log.info("YARN application %s finished with SUCCEEDED", 
application_id)
+                return
+            if final_status in self._YARN_FINAL_FAILURES:
+                raise RuntimeError(
+                    f"YARN application {application_id} ended with final 
status: {final_status}"
+                )
+            if final_status != self._YARN_FINAL_UNDEFINED:
+                raise RuntimeError(
+                    f"YARN application {application_id} returned unexpected 
final status: {final_status}"
+                )
+            time.sleep(poll_interval)
+
+    def _get_yarn_rm_base_url(self) -> str:
+        """
+        Resolve the YARN ResourceManager webapp base URL from the Spark 
connection.
+
+        Reads the ``yarn_resourcemanager_webapp_address`` key from the Spark
+        connection's ``extra`` JSON. Bare ``host:port`` values get ``http://``
+        prepended; fully-qualified URLs are used as-is. Trailing slashes 
stripped.
+        """
+        try:
+            conn = self.get_connection(self._conn_id)
+        except AirflowException:
+            conn = None
+        raw = ""
+        if conn is not None:
+            raw = 
(conn.extra_dejson.get(self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY) or "").strip()
+        if not raw:
+            raise ValueError(
+                f"`yarn_track_via_rm_api=True` requires the Spark connection's 
`extra` to set "
+                f"`{self._YARN_RM_WEBAPP_ADDRESS_EXTRA_KEY}` (e.g. 
`http://rm.example.com:8088`)."
+            )
+        url = raw if "://" in raw else f"http://{raw}";
+        return url.rstrip("/")
+
+    def _query_yarn_application_final_status(self, application_id: str) -> str:
+        """GET ``/ws/v1/cluster/apps/{id}`` once and return 
``app.finalStatus``."""
+        url = 
f"{self._get_yarn_rm_base_url()}/ws/v1/cluster/apps/{application_id}"

Review Comment:
   `_get_yarn_rm_base_url()` is called on every loop iteration of 
`_track_yarn_application()`, and `get_connection()` performs a connection lookup 
each time. For a job polling every second for hours, this is thousands of 
unnecessary connection fetch calls. 



##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -195,6 +220,8 @@ def __init__(
         *,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
+        yarn_track_via_rm_api: bool = False,
+        yarn_rm_auth: Any = None,

Review Comment:
   Please narrow the type down; it is currently annotated as `Any`.



##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -141,6 +155,8 @@ def __init__(
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
         post_submit_commands: list[str] | None = None,
+        yarn_track_via_rm_api: bool = False,
+        yarn_rm_auth: Any = None,

Review Comment:
   Please narrow the type down; it is currently annotated as `Any`.



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1336,3 +1371,316 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    # ---------------------------------------------------------------
+    # yarn_track_via_rm_api (issue #24171)
+    # ---------------------------------------------------------------
+    # Tests for the YARN ResourceManager REST API polling path that lets
+    # SparkSubmitHook release the spark-submit JVM after YARN accepts the
+    # application, instead of holding the JVM open just to read stdout.
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088";
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock()

Review Comment:
   Please use `spec` or `autospec` for this `MagicMock`.



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1336,3 +1371,316 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    # ---------------------------------------------------------------
+    # yarn_track_via_rm_api (issue #24171)
+    # ---------------------------------------------------------------
+    # Tests for the YARN ResourceManager REST API polling path that lets
+    # SparkSubmitHook release the spark-submit JVM after YARN accepts the
+    # application, instead of holding the JVM open just to read stdout.

Review Comment:
   ```suggestion
   ```



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1336,3 +1371,316 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    # ---------------------------------------------------------------
+    # yarn_track_via_rm_api (issue #24171)
+    # ---------------------------------------------------------------
+    # Tests for the YARN ResourceManager REST API polling path that lets
+    # SparkSubmitHook release the spark-submit JVM after YARN accepts the
+    # application, instead of holding the JVM open just to read stdout.
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088";
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        proc.terminate.assert_not_called()
+        assert mock_get.call_count == 2
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.args[0] == self._rm_status_url()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns KILLED -> raise with message containing app id and 
KILLED."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("KILLED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+            hook.submit()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_popen, mock_get, mock_sleep):
+        """RM returns a non-standard finalStatus ('BOGUS') -> raise without 
sleeping."""
+        proc = MagicMock()

Review Comment:
   Same here: please use `spec` or `autospec` for this `MagicMock`.



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1336,3 +1371,316 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    # ---------------------------------------------------------------
+    # yarn_track_via_rm_api (issue #24171)
+    # ---------------------------------------------------------------
+    # Tests for the YARN ResourceManager REST API polling path that lets
+    # SparkSubmitHook release the spark-submit JVM after YARN accepts the
+    # application, instead of holding the JVM open just to read stdout.
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088";
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        proc.terminate.assert_not_called()
+        assert mock_get.call_count == 2
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.args[0] == self._rm_status_url()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns KILLED -> raise with message containing app id and 
KILLED."""
+        proc = MagicMock()

Review Comment:
   Same here: please use `spec` or `autospec` for this `MagicMock`.



##########
providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py:
##########
@@ -870,4 +1039,32 @@ def on_kill(self) -> None:
                 except kube_client.ApiException:
                     self.log.exception("Exception when attempting to kill 
Spark on K8s")
 
+        if self._yarn_application_id:

Review Comment:
   The kill block was moved outside the `if self._submit_sp` check. For the CLI 
path (`yarn_track_via_rm_api=False`), this is a behavioral change: previously 
the `yarn application -kill` was only issued when `_submit_sp` existed; now it 
fires whenever `_yarn_application_id` is set, even if there is no submit 
process. This may be intentional for cleanup correctness, but it is undocumented 
and changes existing behavior for users who have not opted in to the new flag.



##########
providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_submit.py:
##########
@@ -1336,3 +1371,316 @@ def 
test_post_submit_commands_none_gives_empty_list(self):
         """Test that None post_submit_commands results in an empty list."""
         hook = SparkSubmitHook(conn_id="")
         assert hook._post_submit_commands == []
+
+    # ---------------------------------------------------------------
+    # yarn_track_via_rm_api (issue #24171)
+    # ---------------------------------------------------------------
+    # Tests for the YARN ResourceManager REST API polling path that lets
+    # SparkSubmitHook release the spark-submit JVM after YARN accepts the
+    # application, instead of holding the JVM open just to read stdout.
+
+    _YARN_LOG_LINES = [
+        "INFO Client: Requesting a new application from cluster with 1 
NodeManagers",
+        "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+        
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+        "INFO Client: Submitting application application_1700000000000_0001 to 
ResourceManager",
+        "INFO YarnClientImpl: Submitted application 
application_1700000000000_0001",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: ACCEPTED)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: RUNNING)",
+        "INFO Client: Application report for application_1700000000000_0001 
(state: FINISHED)",
+        "INFO Client: final status: SUCCEEDED",
+    ]
+
+    _RM_BASE_URL = "http://rm.test:8088";
+    _RM_APP_ID = "application_1700000000000_0001"
+
+    @classmethod
+    def _rm_status_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}"
+
+    @classmethod
+    def _rm_kill_url(cls, app_id: str | None = None) -> str:
+        return f"{cls._RM_BASE_URL}/ws/v1/cluster/apps/{app_id or 
cls._RM_APP_ID}/state"
+
+    @classmethod
+    def _rm_status_resp(cls, final_status: str, state: str = "FINISHED") -> 
MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = 200
+        resp.json.return_value = {"app": {"id": cls._RM_APP_ID, "state": 
state, "finalStatus": final_status}}
+        return resp
+
+    @staticmethod
+    def _rm_failure_resp(status_code: int = 500, text: str = "Internal Server 
Error") -> MagicMock:
+        resp = MagicMock(spec=requests.Response)
+        resp.status_code = status_code
+        resp.text = text
+        return resp
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.put")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_default_keeps_existing_behavior_in_yarn_cluster(self, mock_popen, 
mock_get, mock_put):
+        """Flag default False -> no HTTP calls; behavior identical to today."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster")
+        hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+        mock_put.assert_not_called()
+        assert hook._yarn_application_id == "application_1700000000000_0001"
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_succeeds(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns UNDEFINED then SUCCEEDED -> hook returns normally."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            self._rm_status_resp("UNDEFINED", state="RUNNING"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        spark_submit_cmd = mock_popen.call_args.args[0]
+        assert "spark.yarn.submit.waitAppCompletion=false" in spark_submit_cmd
+        proc.terminate.assert_not_called()
+        assert mock_get.call_count == 2
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.args[0] == self._rm_status_url()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_killed(self, mock_popen, mock_get, 
mock_sleep):
+        """RM returns KILLED -> raise with message containing app id and 
KILLED."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("KILLED")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match=f"{self._RM_APP_ID}.*KILLED"):
+            hook.submit()
+        proc.terminate.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_fails_on_unexpected_final_status(self, 
mock_popen, mock_get, mock_sleep):
+        """RM returns a non-standard finalStatus ('BOGUS') -> raise without 
sleeping."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.return_value = self._rm_status_resp("BOGUS")
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="unexpected final status: 
BOGUS"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_sleep.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def 
test_yarn_status_tracking_checks_spark_submit_exit_code_before_polling(self, 
mock_popen, mock_get):
+        """spark-submit exits non-zero -> raise BEFORE issuing any HTTP 
request."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 1
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(AirflowException, match="Error code is: 1"):
+            hook.submit()
+
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_requires_application_submission_signal(self, 
mock_popen, mock_get):
+        """.sparkStaging app id without 'Submitted application' log -> raise; 
no HTTP."""
+        yarn_log_lines = [
+            "INFO Client: Uploading resource file:/tmp/lib.zip -> "
+            
"hdfs://namenode:8020/user/root/.sparkStaging/application_1700000000000_0001/lib.zip",
+            "INFO Client: Submitting application 
application_1700000000000_0001 to ResourceManager",
+        ]
+        proc = MagicMock()
+        proc.stdout = iter(yarn_log_lines)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        with pytest.raises(RuntimeError, match="not confirmed as submitted"):
+            hook.submit()
+
+        assert hook._yarn_application_id == self._RM_APP_ID
+        assert hook._yarn_application_submitted is False
+        proc.terminate.assert_not_called()
+        mock_get.assert_not_called()
+
+    def 
test_yarn_status_tracking_rejects_conflicting_wait_app_completion_conf(self):
+        """User-set spark.yarn.submit.waitAppCompletion=true conflicts with 
flag -> ValueError."""
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_rm",
+            conf={"spark.yarn.submit.waitAppCompletion": "true"},
+            yarn_track_via_rm_api=True,
+        )
+
+        with pytest.raises(ValueError, 
match="spark.yarn.submit.waitAppCompletion=false"):
+            hook._build_spark_submit_command("")
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_transient_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """3 consecutive 5xx responses then SUCCEEDED -> normal completion."""
+        proc = MagicMock()
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        # 3 transient failures (within the 10-failure budget), then SUCCEEDED.
+        mock_get.side_effect = [
+            self._rm_failure_resp(503, "Service Unavailable"),
+            self._rm_failure_resp(502, "Bad Gateway"),
+            self._rm_failure_resp(500, "Internal Server Error"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 4
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_tolerates_status_timeouts(self, mock_popen, 
mock_get, mock_sleep):
+        """First requests.exceptions.Timeout, second call succeeds -> normal 
completion."""
+        proc = MagicMock(spec=["stdout", "terminate", "wait"])
+        proc.stdout = iter(self._YARN_LOG_LINES)
+        proc.wait.return_value = 0
+        mock_popen.return_value = proc
+
+        mock_get.side_effect = [
+            requests.exceptions.Timeout("read timed out"),
+            self._rm_status_resp("SUCCEEDED"),
+        ]
+
+        hook = SparkSubmitHook(conn_id="spark_yarn_rm", 
yarn_track_via_rm_api=True)
+        hook.submit()
+
+        assert mock_get.call_count == 2
+        # All calls must include the (connect, read) timeout tuple.
+        for call_obj in mock_get.call_args_list:
+            assert call_obj.kwargs["timeout"] == (5, 30)
+
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.time.sleep")
+    @patch("airflow.providers.apache.spark.hooks.spark_submit.requests.get")
+    
@patch("airflow.providers.apache.spark.hooks.spark_submit.subprocess.Popen")
+    def test_yarn_status_tracking_raises_after_too_many_failures(self, 
mock_popen, mock_get, mock_sleep):
+        """11 consecutive 5xx responses -> raise 'Giving up tracking YARN 
application'."""
+        proc = MagicMock()

Review Comment:
   Same here: please use `spec` or `autospec` for this `MagicMock`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to