This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new b59126d7c90 Simplifying authoring of task and asset states by allowing
JSON types (#67418)
b59126d7c90 is described below
commit b59126d7c90e355476649dfb7a74a483c468422b
Author: Amogh Desai <[email protected]>
AuthorDate: Mon May 25 19:16:28 2026 +0530
Simplifying authoring of task and asset states by allowing JSON types
(#67418)
---
.../execution_api/datamodels/asset_state.py | 17 ++++++-
.../execution_api/datamodels/task_state.py | 16 ++++++-
.../execution_api/routes/asset_state.py | 9 ++--
.../api_fastapi/execution_api/routes/task_state.py | 5 +-
.../execution_api/versions/v2026_06_16.py | 2 +-
.../versions/head/test_asset_state.py | 51 +++++++++++++++++++-
.../execution_api/versions/head/test_task_state.py | 47 +++++++++++++++++-
shared/state/pyproject.toml | 4 +-
shared/state/src/airflow_shared/state/__init__.py | 55 ++++++++++++----------
shared/state/tests/state/test_state.py | 23 +++++++++
task-sdk/src/airflow/sdk/api/client.py | 8 ++--
.../src/airflow/sdk/api/datamodels/_generated.py | 48 +++++++++----------
task-sdk/src/airflow/sdk/execution_time/comms.py | 6 +--
task-sdk/src/airflow/sdk/execution_time/context.py | 37 ++++++++++-----
.../airflow/sdk/execution_time/schema/schema.json | 30 ++++++++----
.../tests/task_sdk/execution_time/test_context.py | 23 +++++++--
.../task_sdk/execution_time/test_task_runner.py | 34 +++++++++++++
uv.lock | 4 ++
18 files changed, 324 insertions(+), 95 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
index ec773201c7e..ab8b3aa2aec 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/asset_state.py
@@ -17,16 +17,29 @@
from __future__ import annotations
+import math
+
+from pydantic import JsonValue, field_validator
+
from airflow.api_fastapi.core_api.base import StrictBaseModel
class AssetStateResponse(StrictBaseModel):
"""Asset state value returned to a worker."""
- value: str
+ value: JsonValue
class AssetStatePutBody(StrictBaseModel):
"""Request body for setting an asset state value."""
- value: str
+ value: JsonValue
+
+ @field_validator("value")
+ @classmethod
+ def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
+ if v is None:
+ raise ValueError("value cannot be null")
+ if isinstance(v, float) and not math.isfinite(v):
+ raise ValueError("value must be a finite number; NaN and Inf are
not JSON representable")
+ return v
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
index 20980b315c3..15fc44b7267 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
@@ -17,19 +17,31 @@
from __future__ import annotations
+import math
from datetime import datetime
+from pydantic import JsonValue, field_validator
+
from airflow.api_fastapi.core_api.base import StrictBaseModel
class TaskStateResponse(StrictBaseModel):
"""Task state value returned to a worker."""
- value: str
+ value: JsonValue
class TaskStatePutBody(StrictBaseModel):
"""Request body for setting a task state value."""
- value: str
+ value: JsonValue
expires_at: datetime | None = None
+
+ @field_validator("value")
+ @classmethod
+ def value_is_json_representable(cls, v: JsonValue) -> JsonValue:
+ if v is None:
+ raise ValueError("value cannot be null")
+ if isinstance(v, float) and not math.isfinite(v):
+ raise ValueError("value must be a finite number; NaN and Inf are
not JSON representable")
+ return v
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
index 2351caa6dfa..f7001c3158c 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/asset_state.py
@@ -28,6 +28,7 @@ Per-task asset registration checks are intentionally not
implemented here
from __future__ import annotations
+import json
from typing import Annotated
from cadwyn import VersionedAPIRouter
@@ -93,7 +94,7 @@ def get_asset_state_by_name(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": f"Asset state key
{key!r} not found"},
)
- return AssetStateResponse(value=value)
+ return AssetStateResponse(value=json.loads(value))
@router.put("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -105,7 +106,7 @@ def set_asset_state_by_name(
) -> None:
"""Set an asset state value by asset name."""
asset_id = _resolve_asset_id_by_name(name, session)
- get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session)
+ get_state_backend().set(AssetScope(asset_id=asset_id), key,
json.dumps(body.value), session=session)
@router.delete("/by-name/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -143,7 +144,7 @@ def get_asset_state_by_uri(
status_code=status.HTTP_404_NOT_FOUND,
detail={"reason": "not_found", "message": f"Asset state key
{key!r} not found"},
)
- return AssetStateResponse(value=value)
+ return AssetStateResponse(value=json.loads(value))
@router.put("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
@@ -155,7 +156,7 @@ def set_asset_state_by_uri(
) -> None:
"""Set an asset state value by asset URI."""
asset_id = _resolve_asset_id_by_uri(uri, session)
- get_state_backend().set(AssetScope(asset_id=asset_id), key, body.value,
session=session)
+ get_state_backend().set(AssetScope(asset_id=asset_id), key,
json.dumps(body.value), session=session)
@router.delete("/by-uri/value", status_code=status.HTTP_204_NO_CONTENT)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index 2f824e3ebb2..c59f2461e2a 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import json
from typing import Annotated
from uuid import UUID
@@ -74,7 +75,7 @@ def get_task_state(
"message": f"Task state key {key!r} not found",
},
)
- return TaskStateResponse(value=value)
+ return TaskStateResponse(value=json.loads(value))
@router.put("/{task_instance_id}/{key}",
status_code=status.HTTP_204_NO_CONTENT)
@@ -86,7 +87,7 @@ def set_task_state(
) -> None:
"""Set a task state key, creating or updating the row."""
scope = _get_task_scope_for_ti(task_instance_id, session)
- get_state_backend().set(scope, key, body.value,
expires_at=body.expires_at, session=session)
+ get_state_backend().set(scope, key, json.dumps(body.value),
expires_at=body.expires_at, session=session)
@router.delete("/{task_instance_id}/{key}",
status_code=status.HTTP_204_NO_CONTENT)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
index 779612bbde1..cd2a6861d11 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2026_06_16.py
@@ -60,7 +60,7 @@ class AddAssetsByAliasEndpoint(VersionChange):
class AddStateEndpoints(VersionChange):
- """Add task state and asset state CRUD endpoints."""
+ """Add task state and asset state API endpoints."""
description = __doc__
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
index 6041d01e7f1..c91171aa05e 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_asset_state.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import json
from typing import TYPE_CHECKING
import pytest
@@ -113,7 +114,37 @@ class TestPutAssetStateByName:
)
)
assert row is not None
- assert row.value == "2026-04-29"
+ # DB stores JSON-encoded string
+ assert row.value == '"2026-04-29"'
+
+ def test_put_int_value_roundtrip(self, client: TestClient, asset:
AssetModel):
+ response = client.put(
+ _BY_NAME_VALUE, params={"name": asset.name, "key": "total_runs"},
json={"value": 5}
+ )
+ assert response.status_code == 204
+ assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key":
"total_runs"}).json() == {
+ "value": 5
+ }
+
+ def test_put_dict_value_roundtrip(self, client: TestClient, asset:
AssetModel):
+ response = client.put(
+ _BY_NAME_VALUE,
+ params={"name": asset.name, "key": "last_run"},
+ json={"value": {"rows": 1234, "status": "ok"}},
+ )
+ assert response.status_code == 204
+ assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key":
"last_run"}).json() == {
+ "value": {"rows": 1234, "status": "ok"}
+ }
+
+ def test_put_list_value_roundtrip(self, client: TestClient, asset:
AssetModel):
+ response = client.put(
+ _BY_NAME_VALUE, params={"name": asset.name, "key": "ids"},
json={"value": [1, 2, 3]}
+ )
+ assert response.status_code == 204
+ assert client.get(_BY_NAME_VALUE, params={"name": asset.name, "key":
"ids"}).json() == {
+ "value": [1, 2, 3]
+ }
def test_put_overwrites_existing(self, client: TestClient, asset:
AssetModel):
client.put(
@@ -134,6 +165,22 @@ class TestPutAssetStateByName:
assert response.status_code == 422
+ def test_put_null_value_returns_422(self, client: TestClient, asset:
AssetModel):
+ response = client.put(
+ _BY_NAME_VALUE, params={"name": asset.name, "key": "watermark"},
json={"value": None}
+ )
+ assert response.status_code == 422
+
+ @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"),
float("-inf")])
+ def test_put_non_finite_float_returns_422(self, client: TestClient, asset:
AssetModel, bad_float: float):
+ with pytest.raises(ValueError, match="Out of range float values are
not JSON compliant"):
+ _ = client.put(
+ _BY_NAME_VALUE,
+ params={"name": asset.name, "key": "watermark"},
+ content=json.dumps({"value": bad_float},
allow_nan=True).encode(),
+ headers={"Content-Type": "application/json"},
+ )
+
def test_put_unknown_asset_returns_404(self, client: TestClient):
response = client.put(
_BY_NAME_VALUE, params={"name": "nonexistent", "key":
"watermark"}, json={"value": "x"}
@@ -208,7 +255,7 @@ class TestPutAssetStateByUri:
)
)
assert row is not None
- assert row.value == "2026-04-29"
+ assert row.value == '"2026-04-29"'
def test_put_unknown_uri_returns_404(self, client: TestClient):
response = client.put(
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
index d83751050e7..97acc576a50 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import json
from datetime import datetime
from typing import TYPE_CHECKING
from uuid import uuid4
@@ -95,7 +96,37 @@ class TestPutTaskState:
)
)
assert row is not None
- assert row.value == "spark_001"
+ # DB stores a json string
+ assert row.value == '"spark_001"'
+
+ def test_put_int_value_roundtrip(self, client: TestClient,
create_task_instance: CreateTaskInstance):
+ ti = create_task_instance()
+
+ response = client.put(_api_url(ti.id, "retry_count"), json={"value":
3})
+
+ assert response.status_code == 204
+ assert client.get(_api_url(ti.id, "retry_count")).json() == {"value":
3}
+
+ def test_put_dict_value_roundtrip(self, client: TestClient,
create_task_instance: CreateTaskInstance):
+ ti = create_task_instance()
+
+ response = client.put(
+ _api_url(ti.id, "poll_result"),
+ json={"value": {"status": "succeeded", "rows": 1234}},
+ )
+
+ assert response.status_code == 204
+ assert client.get(_api_url(ti.id, "poll_result")).json() == {
+ "value": {"status": "succeeded", "rows": 1234}
+ }
+
+ def test_put_list_value_roundtrip(self, client: TestClient,
create_task_instance: CreateTaskInstance):
+ ti = create_task_instance()
+
+ response = client.put(_api_url(ti.id, "checkpoints"), json={"value":
[1, 2, 3]})
+
+ assert response.status_code == 204
+ assert client.get(_api_url(ti.id, "checkpoints")).json() == {"value":
[1, 2, 3]}
def test_put_with_expires_at_creates_row(
self, client: TestClient, create_task_instance: CreateTaskInstance,
time_machine
@@ -122,7 +153,7 @@ class TestPutTaskState:
)
)
assert row is not None
- assert row.value == "spark_001"
+ assert row.value == '"spark_001"'
assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0,
tzinfo=pendulum.UTC)
def test_put_overwrites_existing(self, client: TestClient,
create_task_instance: CreateTaskInstance):
@@ -155,6 +186,18 @@ class TestPutTaskState:
assert response.status_code == 422
+ @pytest.mark.parametrize("bad_float", [float("nan"), float("inf"),
float("-inf")])
+ def test_put_non_finite_float_returns_422(
+ self, client: TestClient, create_task_instance: CreateTaskInstance,
bad_float: float
+ ):
+ ti = create_task_instance()
+ with pytest.raises(ValueError, match="Out of range float values are
not JSON compliant"):
+ _ = client.put(
+ _api_url(ti.id, "job_id"),
+ content=json.dumps({"value": bad_float},
allow_nan=True).encode(),
+ headers={"Content-Type": "application/json"},
+ )
+
def test_put_missing_ti_returns_404(self, client: TestClient):
response = client.put(_api_url(uuid4(), "job_id"), json={"value": "x"})
diff --git a/shared/state/pyproject.toml b/shared/state/pyproject.toml
index f317eba1995..3d16d465726 100644
--- a/shared/state/pyproject.toml
+++ b/shared/state/pyproject.toml
@@ -23,7 +23,9 @@ classifiers = [
"Private :: Do Not Upload",
]
-dependencies = []
+dependencies = [
+ "pydantic>=2.11.0",
+]
[dependency-groups]
dev = [
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index 7aa9fcba837..688cd6a6301 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations
+import json
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING
@@ -23,6 +24,7 @@ from typing import TYPE_CHECKING
if TYPE_CHECKING:
from datetime import datetime
+ from pydantic import JsonValue
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
@@ -96,9 +98,10 @@ class BaseStateBackend(ABC):
@abstractmethod
def get(self, scope: StateScope, key: str, *, session: Session | None =
None) -> str | None:
"""
- Return the stored value, or None if the key does not exist.
+ Return the stored JSON encoded value string, or None if the key does
not exist.
- Must handle both ``TaskScope`` and ``AssetScope``.
+ Must handle both ``TaskScope`` and ``AssetScope``. The execution API
calls
+ ``json.loads`` on the returned string from here, so it must be a valid
JSON document.
"""
@abstractmethod
@@ -112,9 +115,11 @@ class BaseStateBackend(ABC):
session: Session | None = None,
) -> None:
"""
- Write or overwrite the value for the given key.
+ Write or overwrite ``value`` for the given key.
- Must handle both ``TaskScope`` and ``AssetScope``.
+ Must handle both ``TaskScope`` and ``AssetScope``. ``value`` is always
a
+ JSON encoded string (the execution API calls ``json.dumps`` before
passing it
+ here); store it verbatim so ``get`` can return it unchanged.
``expires_at`` is an absolute UTC datetime after which the row may be
deleted.
Pass ``None`` (default) for a key that should never expire — stored as
``NULL``,
@@ -147,10 +152,10 @@ class BaseStateBackend(ABC):
@abstractmethod
async def aget(self, scope: StateScope, key: str, *, session: AsyncSession
| None = None) -> str | None:
"""
- Async variant of get. Must handle both ``TaskScope`` and
``AssetScope``.
+ Async variant of ``get`` which returns a JSON encoded value string or
None.
- ``session`` is optional. If provided, implementations should use it
directly.
- If ``None``, implementations manage their own async session internally.
+ Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used
directly
+ when provided; otherwise implementations manage their own session
internally.
"""
@abstractmethod
@@ -164,10 +169,10 @@ class BaseStateBackend(ABC):
session: AsyncSession | None = None,
) -> None:
"""
- Async variant of set. Must handle both ``TaskScope`` and
``AssetScope``.
+ Async variant of ``set``. ``value`` is always a JSON encoded string.
- ``session`` is optional. If provided, implementations should use it
directly.
- If ``None``, implementations manage their own async session internally.
+ Must handle both ``TaskScope`` and ``AssetScope``. ``session`` is used
directly
+ when provided; otherwise implementations manage their own session
internally.
"""
@abstractmethod
@@ -203,7 +208,7 @@ class BaseStateBackend(ABC):
``[state_store] default_retention_days``) and deciding what to delete.
"""
- def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str)
-> str:
+ def serialize_task_state_to_ref(self, *, value: JsonValue, key: str,
ti_id: str) -> str:
"""
Serialize a task state value before it is sent to the execution API
for db persistence.
@@ -214,20 +219,21 @@ class BaseStateBackend(ABC):
The returned reference must be deterministic — given the same
``ti_id`` and ``key`` it
must always return the same string. Do not use timestamps or random
UUIDs as part of
the reference, otherwise ``delete()``/``clear()`` cannot reconstruct
it and the external
- object will be orphaned.
+ object will be orphaned. By default, it JSON dumps the value and
returns a JSON string.
"""
- return value
+ return json.dumps(value)
- def deserialize_task_state_from_ref(self, stored: str) -> str:
+ def deserialize_task_state_from_ref(self, stored: str) -> JsonValue:
"""
- Resolve a stored task state string back to the actual value.
+ Resolve a stored task state reference back to the actual value.
Called by ``TaskStateAccessor.get()`` after the stored string is
retrieved from
- the execution API. Default: return ``stored`` unchanged.
+ the execution API. By default, it JSON decodes ``stored`` to reverse
the default
+ ``serialize_task_state_to_ref`` encoding.
"""
- return stored
+ return json.loads(stored)
- def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref:
str) -> str:
+ def serialize_asset_state_to_ref(self, *, value: JsonValue, key: str,
asset_ref: str) -> str:
"""
Serialize an asset state value before it is sent to the Execution API
for db persistence.
@@ -241,15 +247,16 @@ class BaseStateBackend(ABC):
The returned reference must be deterministic — given the same
``asset_ref`` and ``key`` it
must always return the same string. Do not use timestamps or random
UUIDs as part of
the reference, otherwise ``delete()``/``clear()`` cannot reconstruct
it and the external
- object will be orphaned.
+ object will be orphaned. By default, it JSON dumps the value and
returns a JSON string.
"""
- return value
+ return json.dumps(value)
- def deserialize_asset_state_from_ref(self, stored: str) -> str:
+ def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue:
"""
- Resolve a stored asset state string back to the actual value.
+ Resolve a stored asset state reference back to the actual value.
Called by ``AssetStateAccessor.get()`` after the stored string is
retrieved from
- the Execution API. Default: return ``stored`` unchanged.
+ the Execution API. By default, it JSON decodes ``stored`` to reverse
the default
+ ``serialize_asset_state_to_ref`` encoding.
"""
- return stored
+ return json.loads(stored)
diff --git a/shared/state/tests/state/test_state.py
b/shared/state/tests/state/test_state.py
index 1ea31194e27..eb658ff8c74 100644
--- a/shared/state/tests/state/test_state.py
+++ b/shared/state/tests/state/test_state.py
@@ -92,6 +92,18 @@ class TestBaseStateBackend:
deserialized = backend.deserialize_task_state_from_ref(serialized)
assert deserialized == original
+ def test_task_state_serialize_deserialize_typed_values(self, backend):
+ """Default backend passes typed values through unchanged (custom
backends handle storage)."""
+ assert (
+ backend.deserialize_task_state_from_ref(
+ backend.serialize_task_state_to_ref(value=42, key="count",
ti_id="abc-123")
+ )
+ == 42
+ )
+ assert backend.deserialize_task_state_from_ref(
+ backend.serialize_task_state_to_ref(value={"status": "ok"},
key="result", ti_id="abc-123")
+ ) == {"status": "ok"}
+
def test_custom_backend_overrides_task_state_ser_deser(self):
class MyBackend(BaseStateBackend):
def get(self, scope, key): ...
@@ -126,6 +138,17 @@ class TestBaseStateBackend:
deserialized = backend.deserialize_asset_state_from_ref(serialized)
assert deserialized == original
+ def test_asset_state_serialize_deserialize_typed_values(self, backend):
+ assert (
+ backend.deserialize_asset_state_from_ref(
+ backend.serialize_asset_state_to_ref(value=5,
key="total_runs", asset_ref="my_asset")
+ )
+ == 5
+ )
+ assert backend.deserialize_asset_state_from_ref(
+ backend.serialize_asset_state_to_ref(value={"rows": 1234},
key="last_run", asset_ref="my_asset")
+ ) == {"rows": 1234}
+
def test_custom_backend_overrides_asset_state_ser_deser(self):
class MyBackend(BaseStateBackend):
def get(self, scope, key): ...
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 99b1aadb37f..1da539f29a3 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -32,7 +32,7 @@ import msgspec
import structlog
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import
TraceContextTextMapPropagator
-from pydantic import BaseModel
+from pydantic import BaseModel, JsonValue
from tenacity import (
before_log,
retry,
@@ -721,7 +721,7 @@ class TaskStateOperations:
raise
return TaskStateResponse.model_validate_json(resp.read())
- def set(self, ti_id: uuid.UUID, key: str, value: str, expires_at: datetime
| None) -> OKResponse:
+ def set(self, ti_id: uuid.UUID, key: str, value: JsonValue, expires_at:
datetime | None) -> OKResponse:
"""Set a task state value via the API server."""
body = TaskStatePutBody(value=value, expires_at=expires_at)
self.client.put(f"state/ti/{ti_id}/{key}",
content=body.model_dump_json())
@@ -774,7 +774,9 @@ class AssetStateOperations:
raise
return AssetStateResponse.model_validate_json(resp.read())
- def set(self, key: str, value: str, *, name: str | None = None, uri: str |
None = None) -> OKResponse:
+ def set(
+ self, key: str, value: JsonValue, *, name: str | None = None, uri: str
| None = None
+ ) -> OKResponse:
"""Set an asset state value via the API server."""
endpoint, params = self._resolve_endpoint("value", key=key, name=name,
uri=uri)
self.client.put(endpoint, params=params,
content=AssetStatePutBody(value=value).model_dump_json())
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index fc966a76969..62c43ac17d1 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -63,28 +63,6 @@ class AssetProfile(BaseModel):
type: Annotated[str, Field(title="Type")]
-class AssetStatePutBody(BaseModel):
- """
- Request body for setting an asset state value.
- """
-
- model_config = ConfigDict(
- extra="forbid",
- )
- value: Annotated[str, Field(title="Value")]
-
-
-class AssetStateResponse(BaseModel):
- """
- Asset state value returned to a worker.
- """
-
- model_config = ConfigDict(
- extra="forbid",
- )
- value: Annotated[str, Field(title="Value")]
-
-
class ConnectionResponse(BaseModel):
"""
Connection schema for responses with fields that are needed for Runtime.
@@ -375,7 +353,7 @@ class TaskStatePutBody(BaseModel):
model_config = ConfigDict(
extra="forbid",
)
- value: Annotated[str, Field(title="Value")]
+ value: JsonValue
expires_at: Annotated[AwareDatetime | None, Field(title="Expires At")] =
None
@@ -387,7 +365,7 @@ class TaskStateResponse(BaseModel):
model_config = ConfigDict(
extra="forbid",
)
- value: Annotated[str, Field(title="Value")]
+ value: JsonValue
class TaskStatesResponse(BaseModel):
@@ -596,6 +574,28 @@ class AssetResponse(BaseModel):
extra: Annotated[dict[str, JsonValue] | None, Field(title="Extra")] = None
+class AssetStatePutBody(BaseModel):
+ """
+ Request body for setting an asset state value.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ value: JsonValue
+
+
+class AssetStateResponse(BaseModel):
+ """
+ Asset state value returned to a worker.
+ """
+
+ model_config = ConfigDict(
+ extra="forbid",
+ )
+ value: JsonValue
+
+
class HITLDetailRequest(BaseModel):
"""
Schema for the request part of a Human-in-the-loop detail for a specific
task instance.
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 2364e942ed0..7b494cab835 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -923,7 +923,7 @@ class GetTaskState(BaseModel):
class SetTaskState(BaseModel):
ti_id: UUID
key: str
- value: str
+ value: JsonValue
expires_at: AwareDatetime | None
type: Literal["SetTaskState"] = "SetTaskState"
@@ -955,14 +955,14 @@ class GetAssetStateByUri(BaseModel):
class SetAssetStateByName(BaseModel):
name: str
key: str
- value: str
+ value: JsonValue
type: Literal["SetAssetStateByName"] = "SetAssetStateByName"
class SetAssetStateByUri(BaseModel):
uri: str
key: str
- value: str
+ value: JsonValue
type: Literal["SetAssetStateByUri"] = "SetAssetStateByUri"
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py
b/task-sdk/src/airflow/sdk/execution_time/context.py
index 14922780da4..cba613da85a 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -495,12 +495,19 @@ class TaskStateAccessor:
# is not implemented yet cos it's unclear whether task state values will be
# used in templates.
- def get(self, key: str) -> str | None:
- """Return the stored value, or ``None`` if the key does not exist."""
+ def get(self, key: str) -> JsonValue:
+ """
+ Return the stored value, or ``None`` if the key does not exist.
+
+ Supported types: ``str``, ``int``, ``float``, ``bool``, ``list``,
``dict``.
+ ``datetime`` is not JSON-serializable; store it as
``value.isoformat()`` and
+ parse it back with ``datetime.fromisoformat(result)``.
+ """
from airflow.sdk.execution_time.comms import ErrorResponse,
GetTaskState, TaskStateResult
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key))
+
if isinstance(resp, ErrorResponse) and resp.error !=
ErrorType.TASK_STATE_NOT_FOUND:
raise AirflowRuntimeError(resp)
if isinstance(resp, TaskStateResult):
@@ -508,10 +515,15 @@ class TaskStateAccessor:
# if custom backend is configured, the stored value in DB is a
reference, fetch the actual value from
# custom backend using the reference
backend = _get_worker_state_backend()
- return backend.deserialize_task_state_from_ref(stored) if backend
else stored
+ if backend is not None:
+ # serialize_task_state_to_ref always returns str by contract;
stored contains the ref.
+ if TYPE_CHECKING:
+ assert isinstance(stored, str)
+ return backend.deserialize_task_state_from_ref(stored)
+ return stored
return None
- def set(self, key: str, value: str, *, retention: timedelta | None = None)
-> None:
+ def set(self, key: str, value: JsonValue, *, retention: timedelta | None =
None) -> None:
"""
Write or overwrite the value for the given key.
@@ -614,7 +626,7 @@ class AssetStateAccessor:
return f"<AssetStateAccessor name={self._name!r}>"
return f"<AssetStateAccessor uri={self._uri!r}>"
- def get(self, key: str) -> str | None:
+ def get(self, key: str) -> JsonValue:
"""Return the stored value, or ``None`` if the key does not exist."""
from airflow.sdk.execution_time.comms import (
AssetStateResult,
@@ -635,13 +647,16 @@ class AssetStateAccessor:
raise AirflowRuntimeError(resp)
if isinstance(resp, AssetStateResult):
stored = resp.value
- # if custom backend is configured, the stored value in DB is a
reference, fetch the actual value from
- # custom backend using the reference
backend = _get_worker_state_backend()
- return backend.deserialize_asset_state_from_ref(stored) if backend
else stored
+ if backend is not None:
+ # serialize_asset_state_to_ref always returns str by contract;
stored contains the ref.
+ if TYPE_CHECKING:
+ assert isinstance(stored, str)
+ return backend.deserialize_asset_state_from_ref(stored)
+ return stored
return None
- def set(self, key: str, value: str) -> None:
+ def set(self, key: str, value: JsonValue) -> None:
"""Write or overwrite the value for the given key."""
from airflow.sdk.execution_time.comms import SetAssetStateByName,
SetAssetStateByUri, ToSupervisor
from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
@@ -756,11 +771,11 @@ class AssetStateAccessors:
return next(iter(self._by_name.values()))
return next(iter(self._by_uri.values()))
- def get(self, key: str) -> str | None:
+ def get(self, key: str) -> JsonValue:
"""Return the stored value for the single-inlet task, or ``None`` if
not found."""
return self._single_accessor().get(key)
- def set(self, key: str, value: str) -> None:
+ def set(self, key: str, value: JsonValue) -> None:
"""Write or overwrite the value for the single-inlet task."""
self._single_accessor().set(key, value)
diff --git a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
index d4eb3d9c5a8..fec9596e493 100644
--- a/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
+++ b/task-sdk/src/airflow/sdk/execution_time/schema/schema.json
@@ -317,6 +317,9 @@
"type": "object"
},
"AssetStateResult": {
+ "$defs": {
+ "JsonValue": {}
+ },
"additionalProperties": false,
"description": "Response to GetAssetState; wraps the generated API
response for supervisor to worker comms.",
"properties": {
@@ -327,8 +330,7 @@
"type": "string"
},
"value": {
- "title": "Value",
- "type": "string"
+ "$ref": "#/$defs/JsonValue"
}
},
"required": [
@@ -4549,6 +4551,9 @@
"type": "object"
},
"SetAssetStateByName": {
+ "$defs": {
+ "JsonValue": {}
+ },
"properties": {
"key": {
"title": "Key",
@@ -4565,8 +4570,7 @@
"type": "string"
},
"value": {
- "title": "Value",
- "type": "string"
+ "$ref": "#/$defs/JsonValue"
}
},
"required": [
@@ -4578,6 +4582,9 @@
"type": "object"
},
"SetAssetStateByUri": {
+ "$defs": {
+ "JsonValue": {}
+ },
"properties": {
"key": {
"title": "Key",
@@ -4594,8 +4601,7 @@
"type": "string"
},
"value": {
- "title": "Value",
- "type": "string"
+ "$ref": "#/$defs/JsonValue"
}
},
"required": [
@@ -4653,6 +4659,9 @@
"type": "object"
},
"SetTaskState": {
+ "$defs": {
+ "JsonValue": {}
+ },
"properties": {
"expires_at": {
"anyOf": [
@@ -4682,8 +4691,7 @@
"type": "string"
},
"value": {
- "title": "Value",
- "type": "string"
+ "$ref": "#/$defs/JsonValue"
}
},
"required": [
@@ -5773,6 +5781,9 @@
"type": "object"
},
"TaskStateResult": {
+ "$defs": {
+ "JsonValue": {}
+ },
"additionalProperties": false,
"description": "Response to GetTaskState; wraps the generated API
response for supervisor to worker comms.",
"properties": {
@@ -5783,8 +5794,7 @@
"type": "string"
},
"value": {
- "title": "Value",
- "type": "string"
+ "$ref": "#/$defs/JsonValue"
}
},
"required": [
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 1763604e477..35c74278748 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -18,11 +18,13 @@
from __future__ import annotations
from datetime import datetime, timedelta, timezone as dt_timezone
+from typing import TYPE_CHECKING
from unittest import mock
from unittest.mock import MagicMock, patch
from uuid import UUID
import pytest
+from pydantic import ValidationError
from airflow.sdk import BaseOperator, get_current_context, timezone
from airflow.sdk._shared.state import TaskScope
@@ -102,6 +104,9 @@ from airflow.sdk.state import BaseStateBackend
from tests_common.test_utils.config import conf_vars
+if TYPE_CHECKING:
+ from pydantic import JsonValue
+
def test_convert_connection_result_conn():
"""Test that the ConnectionResult is converted to a Connection object."""
@@ -1210,6 +1215,16 @@ class TestTaskStateAccessor:
ClearTaskState(ti_id=self.TI_ID, all_map_indices=True)
)
+ def test_set_datetime_raises_validation_error(self, mock_supervisor_comms):
+ """datetime is not JSON-serializable; callers must use .isoformat()
first."""
+ with pytest.raises(ValidationError):
+ TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set(
+ "watermark",
+ datetime(2026, 5, 15, tzinfo=dt_timezone.utc),
+ )
+
+ mock_supervisor_comms.send.assert_not_called()
+
class TestAssetStateAccessor:
ASSET_NAME = "debug_watcher_asset"
@@ -1417,23 +1432,23 @@ class InMemoryStateBackend(BaseStateBackend):
self._actual_key_value_store: dict[str, str] = {} # key -> actual
value
self.reference: dict[str, str] = {} # key -> stored ref (mem:// URI)
- def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str)
-> str:
+ def serialize_task_state_to_ref(self, *, value, key: str, ti_id: str) ->
str:
ref = f"mem://{ti_id}/{key}"
self._actual_key_value_store[key] = value
self.reference[key] = ref
return ref
- def deserialize_task_state_from_ref(self, stored: str) -> str:
+ def deserialize_task_state_from_ref(self, stored: str) -> JsonValue:
key = stored.rsplit("/", 1)[-1]
return self._actual_key_value_store.get(key, stored)
- def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref:
str) -> str:
+ def serialize_asset_state_to_ref(self, *, value, key: str, asset_ref: str)
-> str:
ref = f"mem://{asset_ref}/{key}"
self._actual_key_value_store[key] = value
self.reference[key] = ref
return ref
- def deserialize_asset_state_from_ref(self, stored: str) -> str:
+ def deserialize_asset_state_from_ref(self, stored: str) -> JsonValue:
key = stored.rsplit("/", 1)[-1]
return self._actual_key_value_store.get(key, stored)
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 56900fbadab..4ba821e537e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -5263,6 +5263,40 @@ class TestTaskInstanceStateOperations:
)
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id,
key="job_id"))
+ def test_task_state_set_sends_typed_values(self, create_runtime_ti,
mock_supervisor_comms, time_machine):
+ """set() accepts any JsonValue — dict, int, list — not just strings."""
+
+ class MyOperator(BaseOperator):
+ def execute(self, context):
+ ts = context["task_state"]
+ ts.set("retry_count", 3)
+ ts.set("poll_result", {"status": "succeeded", "rows": 1234})
+ ts.set("checkpoints", [1, 2, 3])
+
+ frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+ time_machine.move_to(frozen_dt, tick=False)
+ task = MyOperator(task_id="t")
+ runtime_ti = create_runtime_ti(task=task)
+
+ with conf_vars({("state_store", "default_retention_days"): "30"}):
+ run(runtime_ti, context=runtime_ti.get_template_context(),
log=mock.MagicMock())
+
+ expires_at = frozen_dt + timedelta(days=30)
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(ti_id=runtime_ti.id, key="retry_count", value=3,
expires_at=expires_at)
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(
+ ti_id=runtime_ti.id,
+ key="poll_result",
+ value={"status": "succeeded", "rows": 1234},
+ expires_at=expires_at,
+ )
+ )
+ mock_supervisor_comms.send.assert_any_call(
+ SetTaskState(ti_id=runtime_ti.id, key="checkpoints", value=[1, 2,
3], expires_at=expires_at)
+ )
+
def test_task_can_set_state_with_retention(self, create_runtime_ti,
mock_supervisor_comms, time_machine):
class MyOperator(BaseOperator):
def execute(self, context):
diff --git a/uv.lock b/uv.lock
index c1e290951ca..8cdd633364c 100644
--- a/uv.lock
+++ b/uv.lock
@@ -8528,6 +8528,9 @@ mypy = [{ name = "apache-airflow-devel-common", extras =
["mypy"], editable = "d
name = "apache-airflow-shared-state"
version = "0.0"
source = { editable = "shared/state" }
+dependencies = [
+ { name = "pydantic" },
+]
[package.dev-dependencies]
dev = [
@@ -8538,6 +8541,7 @@ mypy = [
]
[package.metadata]
+requires-dist = [{ name = "pydantic", specifier = ">=2.11.0" }]
[package.metadata.requires-dev]
dev = [{ name = "apache-airflow-devel-common", editable = "devel-common" }]