This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6dbe76a5d4e Enhance `ResumableJobMixin.get_job_status` with context for better job status tracking (#68009)
6dbe76a5d4e is described below

commit 6dbe76a5d4eebf5c9825b5da900c228961b99eec
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jun 4 17:51:18 2026 +0530

    Enhance `ResumableJobMixin.get_job_status` with context for better job status tracking (#68009)
---
 .../providers/apache/spark/operators/spark_submit.py      |  2 +-
 .../unit/apache/spark/operators/test_spark_submit.py      | 14 +++++++-------
 task-sdk/src/airflow/sdk/bases/resumablemixin.py          | 15 +++++++++++----
 task-sdk/tests/task_sdk/bases/test_resumablemixin.py      |  2 +-
 4 files changed, 20 insertions(+), 13 deletions(-)

diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 5321dbbb8be..3ac4870f313 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -285,7 +285,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
         self.log.info("Spark driver submitted: %s", driver_id)
         return driver_id
 
-    def get_job_status(self, external_id: JsonValue) -> str:
+    def get_job_status(self, external_id: JsonValue, context: Context) -> str:
         # called from submit_job which always returns a str (Spark driver IDs are strings)
         external_id = cast("str", external_id)
         if self._hook is None:
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 47cada84ce6..95ad9f5142a 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -555,7 +555,7 @@ class TestSparkSubmitOperatorResumable:
         operator._hook.submit.return_value = "driver-new"
         task_store = FakeTaskState({"spark_job_id": "driver-001"})
 
-        operator.get_job_status = lambda external_id: prior_status
+        operator.get_job_status = lambda external_id, context: prior_status
         polled = []
         operator.poll_until_complete = lambda external_id, context: polled.append(external_id)
 
@@ -639,9 +639,9 @@ class TestSparkSubmitOperatorResumable:
         with mock.patch("requests.get", return_value=mock_response):
             if expected_error:
                 with pytest.raises(RuntimeError, match=expected_error):
-                    operator.get_job_status("driver-001")
+                    operator.get_job_status("driver-001", {})
             else:
-                assert operator.get_job_status("driver-001") == expected_status
+                assert operator.get_job_status("driver-001", {}) == expected_status
 
     def test_get_job_status_ha_tries_next_master(self):
         operator = self._make_operator()
@@ -661,7 +661,7 @@ class TestSparkSubmitOperatorResumable:
             return good_response
 
         with mock.patch("requests.get", side_effect=side_effect):
-            assert operator.get_job_status("driver-001") == "RUNNING"
+            assert operator.get_job_status("driver-001", {}) == "RUNNING"
 
         assert all(":6066/" in url for url in captured_urls), "REST API must use port 6066, not the RPC port"
 
@@ -683,7 +683,7 @@ class TestSparkSubmitOperatorResumable:
             return good_response
 
         with mock.patch("requests.get", side_effect=side_effect):
-            assert operator.get_job_status("driver-001") == "RUNNING"
+            assert operator.get_job_status("driver-001", {}) == "RUNNING"
 
     def test_get_job_status_ha_raises_when_all_masters_unreachable(self):
         operator = self._make_operator()
@@ -693,7 +693,7 @@ class TestSparkSubmitOperatorResumable:
 
         with mock.patch("requests.get", side_effect=ConnectionError("unreachable")):
             with pytest.raises(ConnectionError):
-                operator.get_job_status("driver-001")
+                operator.get_job_status("driver-001", {})
 
     def test_get_job_status_uses_rest_scheme_from_connection(self):
         operator = self._make_operator()
@@ -710,7 +710,7 @@ class TestSparkSubmitOperatorResumable:
             return mock_response
 
         with mock.patch("requests.get", side_effect=capture):
-            operator.get_job_status("driver-001")
+            operator.get_job_status("driver-001", {})
 
         assert len(captured_urls) == 1
         assert captured_urls[0].startswith("https://")
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
index 68620924bcf..55a561c030d 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablemixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -59,7 +59,7 @@ class ResumableJobMixin:
             def submit_job(self, context) -> JsonValue:
                 return self.hook.submit(...)
 
-            def get_job_status(self, external_id: JsonValue) -> str:
+            def get_job_status(self, external_id: JsonValue, context: Context) -> str:
                 return self.hook.get_status(external_id)
 
             def is_job_active(self, status: str) -> bool:
@@ -106,7 +106,7 @@ class ResumableJobMixin:
         if task_store is not None:
             external_id = task_store.get(self.external_id_key)
             if external_id:
-                status = self.get_job_status(external_id)
+                status = self.get_job_status(external_id, context)
                 if self.is_job_active(status):
                     self.log.info(
                         "Reconnecting to existing job identified by: %s (status: %s)", external_id, status
@@ -141,8 +141,15 @@ class ResumableJobMixin:
         """
         raise NotImplementedError
 
-    def get_job_status(self, external_id: JsonValue) -> str:
-        """Query the external system for the current job status."""
+    def get_job_status(self, external_id: JsonValue, context: Context) -> str:
+        """
+        Query the external system for the current job status.
+
+        ``context`` is provided so implementations can use it if needed to implement advanced features such as:
+
+        - cache terminal status to ``task_store`` when the remote resource may be
+          ephemeral (e.g. a K8s driver pod that gets garbage-collected after completion),
+        """
         raise NotImplementedError
 
     def is_job_active(self, status: str) -> bool:
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
index 6796ec1029c..9e28a53d31d 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -45,7 +45,7 @@ class ConcreteResumableOperator(ResumableJobMixin, BaseOperator):
         self.submitted_ids.append(self._next_id)
         return self._next_id
 
-    def get_job_status(self, external_id: JsonValue) -> str:
+    def get_job_status(self, external_id: JsonValue, context) -> str:
         return self._status_map.get(str(external_id), "UNKNOWN")
 
     def is_job_active(self, status: str) -> bool:

Reply via email to