This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6dbe76a5d4e Enhance `ResumableJobMixin.get_job_status` with context
for better job status tracking (#68009)
6dbe76a5d4e is described below
commit 6dbe76a5d4eebf5c9825b5da900c228961b99eec
Author: Amogh Desai <[email protected]>
AuthorDate: Thu Jun 4 17:51:18 2026 +0530
Enhance `ResumableJobMixin.get_job_status` with context for better job
status tracking (#68009)
---
.../providers/apache/spark/operators/spark_submit.py | 2 +-
.../unit/apache/spark/operators/test_spark_submit.py | 14 +++++++-------
task-sdk/src/airflow/sdk/bases/resumablemixin.py | 15 +++++++++++----
task-sdk/tests/task_sdk/bases/test_resumablemixin.py | 2 +-
4 files changed, 20 insertions(+), 13 deletions(-)
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 5321dbbb8be..3ac4870f313 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -285,7 +285,7 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
self.log.info("Spark driver submitted: %s", driver_id)
return driver_id
- def get_job_status(self, external_id: JsonValue) -> str:
+ def get_job_status(self, external_id: JsonValue, context: Context) -> str:
# called from submit_job which always returns a str (Spark driver IDs are strings)
external_id = cast("str", external_id)
if self._hook is None:
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 47cada84ce6..95ad9f5142a 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -555,7 +555,7 @@ class TestSparkSubmitOperatorResumable:
operator._hook.submit.return_value = "driver-new"
task_store = FakeTaskState({"spark_job_id": "driver-001"})
- operator.get_job_status = lambda external_id: prior_status
+ operator.get_job_status = lambda external_id, context: prior_status
polled = []
operator.poll_until_complete = lambda external_id, context: polled.append(external_id)
@@ -639,9 +639,9 @@ class TestSparkSubmitOperatorResumable:
with mock.patch("requests.get", return_value=mock_response):
if expected_error:
with pytest.raises(RuntimeError, match=expected_error):
- operator.get_job_status("driver-001")
+ operator.get_job_status("driver-001", {})
else:
- assert operator.get_job_status("driver-001") == expected_status
+ assert operator.get_job_status("driver-001", {}) == expected_status
def test_get_job_status_ha_tries_next_master(self):
operator = self._make_operator()
@@ -661,7 +661,7 @@ class TestSparkSubmitOperatorResumable:
return good_response
with mock.patch("requests.get", side_effect=side_effect):
- assert operator.get_job_status("driver-001") == "RUNNING"
+ assert operator.get_job_status("driver-001", {}) == "RUNNING"
assert all(":6066/" in url for url in captured_urls), "REST API must use port 6066, not the RPC port"
@@ -683,7 +683,7 @@ class TestSparkSubmitOperatorResumable:
return good_response
with mock.patch("requests.get", side_effect=side_effect):
- assert operator.get_job_status("driver-001") == "RUNNING"
+ assert operator.get_job_status("driver-001", {}) == "RUNNING"
def test_get_job_status_ha_raises_when_all_masters_unreachable(self):
operator = self._make_operator()
@@ -693,7 +693,7 @@ class TestSparkSubmitOperatorResumable:
with mock.patch("requests.get", side_effect=ConnectionError("unreachable")):
with pytest.raises(ConnectionError):
- operator.get_job_status("driver-001")
+ operator.get_job_status("driver-001", {})
def test_get_job_status_uses_rest_scheme_from_connection(self):
operator = self._make_operator()
@@ -710,7 +710,7 @@ class TestSparkSubmitOperatorResumable:
return mock_response
with mock.patch("requests.get", side_effect=capture):
- operator.get_job_status("driver-001")
+ operator.get_job_status("driver-001", {})
assert len(captured_urls) == 1
assert captured_urls[0].startswith("https://")
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
index 68620924bcf..55a561c030d 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablemixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -59,7 +59,7 @@ class ResumableJobMixin:
def submit_job(self, context) -> JsonValue:
return self.hook.submit(...)
- def get_job_status(self, external_id: JsonValue) -> str:
+ def get_job_status(self, external_id: JsonValue, context: Context) -> str:
return self.hook.get_status(external_id)
def is_job_active(self, status: str) -> bool:
@@ -106,7 +106,7 @@ class ResumableJobMixin:
if task_store is not None:
external_id = task_store.get(self.external_id_key)
if external_id:
- status = self.get_job_status(external_id)
+ status = self.get_job_status(external_id, context)
if self.is_job_active(status):
self.log.info(
"Reconnecting to existing job identified by: %s (status: %s)", external_id, status
@@ -141,8 +141,15 @@ class ResumableJobMixin:
"""
raise NotImplementedError
- def get_job_status(self, external_id: JsonValue) -> str:
- """Query the external system for the current job status."""
+ def get_job_status(self, external_id: JsonValue, context: Context) -> str:
+ """
+ Query the external system for the current job status.
+
+ ``context`` is provided so implementations can use it if needed to implement advanced features such as:
+
+ - cache terminal status to ``task_store`` when the remote resource may be
+ ephemeral (e.g. a K8s driver pod that gets garbage-collected after completion),
+ """
raise NotImplementedError
def is_job_active(self, status: str) -> bool:
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
index 6796ec1029c..9e28a53d31d 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -45,7 +45,7 @@ class ConcreteResumableOperator(ResumableJobMixin, BaseOperator):
self.submitted_ids.append(self._next_id)
return self._next_id
- def get_job_status(self, external_id: JsonValue) -> str:
+ def get_job_status(self, external_id: JsonValue, context) -> str:
return self._status_map.get(str(external_id), "UNKNOWN")
def is_job_active(self, status: str) -> bool: