This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 99d88c13a0e Add default parameter to task and asset state get() accessors (#67842)
99d88c13a0e is described below
commit 99d88c13a0eaaa5398a1f6e6dc6a48b0e59dc1d2
Author: Amogh Desai <[email protected]>
AuthorDate: Wed Jun 3 09:02:44 2026 +0530
Add default parameter to task and asset state get() accessors (#67842)
---
.../airflow/example_dags/example_asset_store.py | 5 +--
task-sdk/src/airflow/sdk/bases/resumablemixin.py | 9 ++++-
task-sdk/src/airflow/sdk/execution_time/context.py | 28 ++++++++-----
.../tests/task_sdk/bases/test_resumablemixin.py | 22 +++++++++++
.../tests/task_sdk/execution_time/test_context.py | 46 ++++++++++++++++++++++
.../task_sdk/execution_time/test_task_runner.py | 39 ++++++++++++++++++
6 files changed, 134 insertions(+), 15 deletions(-)
diff --git a/airflow-core/src/airflow/example_dags/example_asset_store.py b/airflow-core/src/airflow/example_dags/example_asset_store.py
index 6e6d30e3501..febef84e6ad 100644
--- a/airflow-core/src/airflow/example_dags/example_asset_store.py
+++ b/airflow-core/src/airflow/example_dags/example_asset_store.py
@@ -53,14 +53,13 @@ with DAG(
def load(asset_store=None):
state = asset_store[ORDERS]
- # First run: watermark is None — fall back to epoch start.
- watermark = state.get("watermark") or "2026-01-01T00:00:00+00:00"
+ watermark = state.get("watermark", default="2026-01-01T00:00:00+00:00")
records = _fetch_records(since=watermark)
row_count = len(records)
now = datetime.now(tz=timezone.utc).isoformat()
state.set("watermark", now)
- state.set("total_runs", (state.get("total_runs") or 0) + 1)
+ state.set("total_runs", state.get("total_runs", default=0) + 1)
state.set(
"last_run_summary",
{
diff --git a/task-sdk/src/airflow/sdk/bases/resumablemixin.py b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
index 55e2d97ff7a..68620924bcf 100644
--- a/task-sdk/src/airflow/sdk/bases/resumablemixin.py
+++ b/task-sdk/src/airflow/sdk/bases/resumablemixin.py
@@ -126,14 +126,19 @@ class ResumableJobMixin:
external_id = self.submit_job(context)
- if task_store is not None:
+ if task_store is not None and external_id is not None:
task_store.set(self.external_id_key, external_id)
self.poll_until_complete(external_id, context)
return self.get_job_result(external_id, context)
def submit_job(self, context: Context) -> JsonValue:
- """Submit the job to the external system. Return its external ID."""
+ """
+ Submit the job to the external system. Return its external ID.
+
+ The returned ID should not be ``None``; a ``None`` return is treated as
+ "no ID available" and nothing is persisted to task state.
+ """
raise NotImplementedError
def get_job_status(self, external_id: JsonValue) -> str:
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 5971c076ccb..7a601e6f320 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -506,9 +506,9 @@ class TaskStoreAccessor:
# is not implemented yet cos it's unclear whether task state values will be
# used in templates.
- def get(self, key: str) -> JsonValue:
+ def get(self, key: str, default: JsonValue = None) -> JsonValue:
"""
- Return the stored value, or ``None`` if the key does not exist.
+ Return the stored value, or ``default`` if the key does not exist.
Supported types: ``str``, ``int``, ``float``, ``bool``, ``list``,
``dict``.
``datetime`` is not JSON-serializable; store it as
``value.isoformat()`` and
@@ -535,12 +535,14 @@ class TaskStoreAccessor:
key,
)
return stored
- return None
+ return default
def set(self, key: str, value: JsonValue, *, retention: timedelta | None =
None) -> None:
"""
Write or overwrite the value for the given key.
+ ``value`` must not be ``None``.
+
``retention`` is an optional key that controls when this key expires:
- ``timedelta(...)`` — expire after the given duration (e.g.
``timedelta(hours=6)``).
@@ -550,6 +552,9 @@ class TaskStoreAccessor:
from airflow.sdk.execution_time.comms import SetTaskStore
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ if value is None:
+ raise ValueError("Cannot set value as None")
+
# expires_at is always resolved on the worker in UTC before being sent.
now = datetime.now(tz=timezone.utc)
if retention is NEVER_EXPIRE:
@@ -640,8 +645,8 @@ class AssetStoreAccessor:
return f"<AssetStoreAccessor name={self._name!r}>"
return f"<AssetStoreAccessor uri={self._uri!r}>"
- def get(self, key: str) -> JsonValue:
- """Return the stored value, or ``None`` if the key does not exist."""
+ def get(self, key: str, default: JsonValue = None) -> JsonValue:
+ """Return the stored value, or ``default`` if the key does not
exist."""
from airflow.sdk.execution_time.comms import (
AssetStoreResult,
ErrorResponse,
@@ -674,13 +679,16 @@ class AssetStoreAccessor:
self._name or self._uri,
)
return stored
- return None
+ return default
def set(self, key: str, value: JsonValue) -> None:
- """Write or overwrite the value for the given key."""
+ """Write or overwrite the value for the given key. ``value`` must not
be ``None``."""
from airflow.sdk.execution_time.comms import SetAssetStoreByName,
SetAssetStoreByUri, ToSupervisor
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+ if value is None:
+ raise ValueError("Cannot set value as None")
+
# if custom backend is configured, store the value on the custom
backend, and return the reference
# to the stored value to store in the DB
backend = _get_worker_state_backend()
@@ -797,9 +805,9 @@ class AssetStoreAccessors:
return next(iter(self._by_name.values()))
return next(iter(self._by_uri.values()))
- def get(self, key: str) -> JsonValue:
- """Return the stored value for the single-inlet task, or ``None`` if
not found."""
- return self._single_accessor().get(key)
+ def get(self, key: str, default: JsonValue = None) -> JsonValue:
+ """Return the stored value for the single-inlet or single-outlet task,
or ``default`` if not found."""
+ return self._single_accessor().get(key, default)
def set(self, key: str, value: JsonValue) -> None:
"""Write or overwrite the value for the single-inlet task."""
diff --git a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
index 4d0f8fa8cdb..6796ec1029c 100644
--- a/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
+++ b/task-sdk/tests/task_sdk/bases/test_resumablemixin.py
@@ -164,6 +164,28 @@ class TestRetryWithDifferentJobStatuses:
assert op.polled_ids == ["job-002"]
+class TestNoneExternalId:
+ def test_none_external_id_is_not_stored(self):
+ """submit_job() returning None must not call task_state.set()."""
+
+ class NoneIdOp(ConcreteResumableOperator):
+ def submit_job(self, context) -> JsonValue:
+ return None
+
+ def poll_until_complete(self, external_id, context) -> None:
+ pass
+
+ def get_job_result(self, external_id, context) -> str:
+ return "done"
+
+ op = NoneIdOp(task_id="test_task")
+ task_state = FakeTaskState()
+
+ op.execute_resumable(make_context(task_state))
+
+ assert task_state._store == {}
+
+
class TestExternalIdKey:
def test_custom_key_used_for_storage_and_retrieval(self):
class CustomKeyOp(ConcreteResumableOperator):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 2531d003cc0..9bcaa513d90 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -1123,6 +1123,24 @@ class TestTaskStoreAccessor:
assert result is None
+ def test_get_returns_default_when_key_missing(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.TASK_STORE_NOT_FOUND, detail={"key": "job_id"}
+ )
+
+ result = TaskStoreAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("job_id", default="default-id")
+
+ assert result == "default-id"
+
+ def test_get_ignores_default_when_key_exists(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value =
TaskStoreResult(value="job-001")
+
+ result = TaskStoreAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get(
+ "job_id", default="do-not-start-here"
+ )
+
+ assert result == "job-001"
+
def test_get_raises_on_error(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = ErrorResponse(
error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
@@ -1131,6 +1149,10 @@ class TestTaskStoreAccessor:
with pytest.raises(AirflowRuntimeError):
TaskStoreAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).get("some_key")
+ def test_set_none_raises(self, mock_supervisor_comms):
+ with pytest.raises(ValueError, match="Cannot set value as None"):
+ TaskStoreAccessor(ti_id=self.TI_ID,
scope=self.SCOPE).set("job_id", None)
+
def test_set_operation_with_global_retention(self, mock_supervisor_comms,
time_machine):
"""set() with no retention uses global default_retention_days
config."""
@@ -1287,6 +1309,26 @@ class TestAssetStoreAccessor:
assert result is None
+ def test_get_returns_default_when_key_missing(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.ASSET_STORE_NOT_FOUND, detail={"key": "watermark"}
+ )
+
+ result = AssetStoreAccessor(name=self.ASSET_NAME).get(
+ "watermark", default="2026-01-01T00:00:00+00:00"
+ )
+
+ assert result == "2026-01-01T00:00:00+00:00"
+
+ def test_get_ignores_default_when_key_exists(self, mock_supervisor_comms):
+ mock_supervisor_comms.send.return_value =
AssetStoreResult(value="2026-06-01T00:00:00+00:00")
+
+ result = AssetStoreAccessor(name=self.ASSET_NAME).get(
+ "watermark", default="2026-01-01T00:00:00+00:00"
+ )
+
+ assert result == "2026-06-01T00:00:00+00:00"
+
def test_get_raises_on_error(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = ErrorResponse(
error=ErrorType.GENERIC_ERROR, detail={"message": "server error"}
@@ -1304,6 +1346,10 @@ class TestAssetStoreAccessor:
SetAssetStoreByName(name=self.ASSET_NAME, key="watermark",
value="2026-04-30T00:00:00Z")
)
+ def test_set_none_raises(self, mock_supervisor_comms):
+ with pytest.raises(ValueError, match="Cannot set value as None"):
+ AssetStoreAccessor(name=self.ASSET_NAME).set("watermark", None)
+
def test_delete_operation(self, mock_supervisor_comms):
mock_supervisor_comms.send.return_value = OKResponse(ok=True)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 893419cc25c..f8953dc2329 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5325,6 +5325,24 @@ class TestTaskInstanceStateOperations:
)
mock_supervisor_comms.send.assert_any_call(GetTaskStore(ti_id=runtime_ti.id,
key="job_id"))
+ def test_task_state_get_returns_default_when_key_missing(self,
create_runtime_ti, mock_supervisor_comms):
+ captured = {}
+
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ captured["result"] = context["task_store"].get(
+ "watermark", default="2026-01-01T00:00:00+00:00"
+ )
+
+ mock_supervisor_comms.send.return_value = ErrorResponse(
+ error=ErrorType.TASK_STORE_NOT_FOUND, detail={"key": "watermark"}
+ )
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ assert captured["result"] == "2026-01-01T00:00:00+00:00"
+
def test_task_state_set_sends_typed_values(self, create_runtime_ti,
mock_supervisor_comms, time_machine):
"""set() accepts any JsonValue — dict, int, list — not just strings."""
@@ -5441,6 +5459,27 @@ class TestTaskInstanceStateOperations:
)
mock_supervisor_comms.send.assert_any_call(GetAssetStoreByName(name="my_asset",
key="watermark"))
+ def test_asset_state_get_returns_default_when_key_missing(self,
create_runtime_ti, mock_supervisor_comms):
+ watched = Asset(name="my_asset", uri="s3://bucket/data")
+ captured = {}
+
+ class WatcherOperator(BaseOperator):
+ def execute(self, context):
+ captured["result"] = context["asset_store"].get(
+ "watermark", default="2026-01-01T00:00:00+00:00"
+ )
+
+ task = WatcherOperator(task_id="t", inlets=[watched])
+ runtime_ti = create_runtime_ti(task=task)
+ mock_supervisor_comms.send.side_effect = lambda msg: (
+ ErrorResponse(error=ErrorType.ASSET_STORE_NOT_FOUND,
detail={"key": "watermark"})
+ if isinstance(msg, GetAssetStoreByName)
+ else TestTaskInstanceStateOperations._watcher_side_effect(msg)
+ )
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ assert captured["result"] == "2026-01-01T00:00:00+00:00"
+
def test_asset_state_delete(self, create_runtime_ti,
mock_supervisor_comms):
watched = Asset(name="my_asset", uri="s3://bucket/data")