This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ec2d56a473d AIP-103: Worker side custom state backend support (#66859)
ec2d56a473d is described below

commit ec2d56a473d6a919814788f7a0e3fff28b33c615
Author: Amogh Desai <[email protected]>
AuthorDate: Wed May 20 16:06:38 2026 +0530

    AIP-103: Worker side custom state backend support (#66859)
---
 .../src/airflow/config_templates/config.yml        |  11 ++
 shared/state/src/airflow_shared/state/__init__.py  |  71 ++++++-
 shared/state/tests/state/test_state.py             |  83 ++++++++-
 task-sdk/pyproject.toml                            |   3 +
 task-sdk/src/airflow/sdk/_shared/state             |   1 +
 task-sdk/src/airflow/sdk/api/client.py             |   3 +-
 task-sdk/src/airflow/sdk/configuration.py          |   3 +-
 task-sdk/src/airflow/sdk/execution_time/context.py |  97 +++++++++-
 .../src/airflow/sdk/execution_time/task_runner.py  |  13 +-
 task-sdk/src/airflow/sdk/state.py                  |  25 +++
 task-sdk/tests/task_sdk/docs/test_public_api.py    |   1 +
 .../tests/task_sdk/execution_time/test_context.py  | 207 +++++++++++++++++++--
 .../task_sdk/execution_time/test_task_runner.py    |  98 +++++++++-
 13 files changed, 588 insertions(+), 28 deletions(-)

diff --git a/airflow-core/src/airflow/config_templates/config.yml b/airflow-core/src/airflow/config_templates/config.yml
index f93642dd8ed..0dfbbbda6c4 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -1899,6 +1899,17 @@ workers:
       sensitive: true
       example: ~
       default: ""
+    state_backend:
+      description: |
+        Full class name of a custom worker-side state backend. When set, task state values are
+        routed through this backend so large payloads or credentialed storage stay on worker
+        infrastructure. The Execution API still records a reference string in the database.
+
+        Leave empty (default) to use the standard path through the task sdk supervisor.
+      version_added: 3.3.0
+      type: string
+      example: "mypackage.state.S3StateBackend"
+      default: ""
     min_heartbeat_interval:
       description: |
         The minimum interval (in seconds) at which the worker checks the task instance's
diff --git a/shared/state/src/airflow_shared/state/__init__.py b/shared/state/src/airflow_shared/state/__init__.py
index bfce8db3328..7aa9fcba837 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -47,9 +47,25 @@ class TaskScope:
 
 @dataclass(frozen=True)
 class AssetScope:
-    """Identifies the state namespace for an asset."""
+    """
+    Identifies the state namespace for an asset.
+
+    Server-side backends receive ``asset_id``. Worker-side backends receive 
``name`` or ``uri``
+    since workers do not have access to the integer ``asset_id``.
+
+    Note: ``name`` and ``uri`` are not guaranteed to be unique over time — if 
an asset is
+    deactivated and a new one created with the same name, both share the same 
``name`` value.
+    State for inactive assets is cleaned up by the orphan GC pass; until then, 
stale rows exist
+    in the DB but cannot be written to (the Execution API resolver filters to 
active assets only).
+    """
+
+    asset_id: int | None = None
+    name: str | None = None
+    uri: str | None = None
 
-    asset_id: int
+    def __post_init__(self) -> None:
+        if self.asset_id is None and self.name is None and self.uri is None:
+            raise ValueError("AssetScope requires at least one of: asset_id, 
name, or uri")
 
 
 StateScope = TaskScope | AssetScope
@@ -186,3 +202,54 @@ class BaseStateBackend(ABC):
         retention policy. The backend is responsible for reading any relevant 
config (e.g.
         ``[state_store] default_retention_days``) and deciding what to delete.
         """
+
+    def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) 
-> str:
+        """
+        Serialize a task state value before it is sent to the execution API 
for db persistence.
+
+        Called by ``TaskStateAccessor.set()`` on the worker. The return value 
is what gets
+        stored in the DB — typically a reference path (e.g. an S3 key) rather 
than the
+        actual value. Default: return ``value`` unchanged.
+
+        The returned reference must be deterministic — given the same 
``ti_id`` and ``key`` it
+        must always return the same string. Do not use timestamps or random 
UUIDs as part of
+        the reference, otherwise ``delete()``/``clear()`` cannot reconstruct 
it and the external
+        object will be orphaned.
+        """
+        return value
+
+    def deserialize_task_state_from_ref(self, stored: str) -> str:
+        """
+        Resolve a stored task state string back to the actual value.
+
+        Called by ``TaskStateAccessor.get()`` after the stored string is 
retrieved from
+        the execution API. Default: return ``stored`` unchanged.
+        """
+        return stored
+
+    def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: 
str) -> str:
+        """
+        Serialize an asset state value before it is sent to the Execution API 
for db persistence.
+
+        Called by ``AssetStateAccessor.set()`` on the worker. The return value 
is what gets
+        stored in the DB — typically a reference path rather than the actual 
value.
+        Default: return ``value`` unchanged.
+
+        ``asset_ref`` is either the asset name or URI, depending on how the 
accessor was
+        constructed. It may be a URI string if the task inlet was declared as 
``AssetUriRef``.
+
+        The returned reference must be deterministic — given the same 
``asset_ref`` and ``key`` it
+        must always return the same string. Do not use timestamps or random 
UUIDs as part of
+        the reference, otherwise ``delete()``/``clear()`` cannot reconstruct 
it and the external
+        object will be orphaned.
+        """
+        return value
+
+    def deserialize_asset_state_from_ref(self, stored: str) -> str:
+        """
+        Resolve a stored asset state string back to the actual value.
+
+        Called by ``AssetStateAccessor.get()`` after the stored string is 
retrieved from
+        the Execution API. Default: return ``stored`` unchanged.
+        """
+        return stored
diff --git a/shared/state/tests/state/test_state.py b/shared/state/tests/state/test_state.py
index 47bce18a69e..1ea31194e27 100644
--- a/shared/state/tests/state/test_state.py
+++ b/shared/state/tests/state/test_state.py
@@ -18,7 +18,22 @@ from __future__ import annotations
 
 import pytest
 
-from airflow_shared.state import BaseStateBackend, StateScope
+from airflow_shared.state import AssetScope, BaseStateBackend, StateScope
+
+
+class TestAssetScope:
+    def test_requires_at_least_one_identifier(self):
+        with pytest.raises(ValueError, match="at least one of"):
+            AssetScope()
+
+    def test_asset_id_alone_is_valid(self):
+        AssetScope(asset_id=1)
+
+    def test_name_alone_is_valid(self):
+        AssetScope(name="my_asset")
+
+    def test_uri_alone_is_valid(self):
+        AssetScope(uri="s3://bucket/key")
 
 
 class TestBaseStateBackend:
@@ -70,3 +85,69 @@ class TestBaseStateBackend:
         """BaseStateBackend enforces all 8 sync+async methods as abstract."""
         expected = {"get", "set", "delete", "clear", "aget", "aset", 
"adelete", "aclear"}
         assert BaseStateBackend.__abstractmethods__ == expected
+
+    def test_task_state_serialize_deserialize_round_trip(self, backend):
+        original = "app_1234"
+        serialized = backend.serialize_task_state_to_ref(value=original, 
key="job_id", ti_id="abc-123")
+        deserialized = backend.deserialize_task_state_from_ref(serialized)
+        assert deserialized == original
+
+    def test_custom_backend_overrides_task_state_ser_deser(self):
+        class MyBackend(BaseStateBackend):
+            def get(self, scope, key): ...
+            def set(self, scope, key, value): ...
+            def delete(self, scope, key): ...
+            def clear(self, scope, *, all_map_indices=False): ...
+            async def aget(self, scope, key): ...
+            async def aset(self, scope, key, value): ...
+            async def adelete(self, scope, key): ...
+            async def aclear(self, scope, *, all_map_indices=False): ...
+
+            def serialize_task_state_to_ref(self, *, value, key, ti_id):
+                return f"s3://bucket/{ti_id}/{key}"
+
+            def deserialize_task_state_from_ref(self, stored):
+                return f"fetched:{stored}"
+
+        b = MyBackend()
+        assert b.serialize_task_state_to_ref(value="app_1234", key="job_id", 
ti_id="abc-123") == (
+            "s3://bucket/abc-123/job_id"
+        )
+        assert (
+            b.deserialize_task_state_from_ref("s3://bucket/abc-123/job_id")
+            == "fetched:s3://bucket/abc-123/job_id"
+        )
+
+    def test_asset_state_serialize_deserialize_round_trip(self, backend):
+        original = "2026-05-01"
+        serialized = backend.serialize_asset_state_to_ref(
+            value="2026-05-01", key="watermark", asset_ref="my_asset"
+        )
+        deserialized = backend.deserialize_asset_state_from_ref(serialized)
+        assert deserialized == original
+
+    def test_custom_backend_overrides_asset_state_ser_deser(self):
+        class MyBackend(BaseStateBackend):
+            def get(self, scope, key): ...
+            def set(self, scope, key, value): ...
+            def delete(self, scope, key): ...
+            def clear(self, scope, *, all_map_indices=False): ...
+            async def aget(self, scope, key): ...
+            async def aset(self, scope, key, value): ...
+            async def adelete(self, scope, key): ...
+            async def aclear(self, scope, *, all_map_indices=False): ...
+
+            def serialize_asset_state_to_ref(self, *, value, key, asset_ref):
+                return f"s3://bucket/assets/{asset_ref}/{key}"
+
+            def deserialize_asset_state_from_ref(self, stored):
+                return f"resolved:{stored}"
+
+        b = MyBackend()
+        assert b.serialize_asset_state_to_ref(value="2026-05-01", 
key="watermark", asset_ref="my_asset") == (
+            "s3://bucket/assets/my_asset/watermark"
+        )
+        assert (
+            
b.deserialize_asset_state_from_ref("s3://bucket/assets/my_asset/watermark")
+            == "resolved:s3://bucket/assets/my_asset/watermark"
+        )
diff --git a/task-sdk/pyproject.toml b/task-sdk/pyproject.toml
index 89a17fb52c3..4fc9f3d586a 100644
--- a/task-sdk/pyproject.toml
+++ b/task-sdk/pyproject.toml
@@ -147,6 +147,7 @@ path = "src/airflow/sdk/__init__.py"
 "../shared/listeners/src/airflow_shared/listeners" = 
"src/airflow/sdk/_shared/listeners"
 "../shared/plugins_manager/src/airflow_shared/plugins_manager" = 
"src/airflow/sdk/_shared/plugins_manager"
 "../shared/providers_discovery/src/airflow_shared/providers_discovery" = 
"src/airflow/sdk/_shared/providers_discovery"
+"../shared/state/src/airflow_shared/state" = "src/airflow/sdk/_shared/state"
 "../shared/template_rendering/src/airflow_shared/template_rendering" = 
"src/airflow/sdk/_shared/template_rendering"
 
 [tool.hatch.build.targets.wheel]
@@ -240,6 +241,7 @@ apache-airflow = {workspace = true}
 apache-airflow-devel-common = {workspace = true}
 apache-airflow-providers-common-sql = {workspace = true}
 apache-airflow-providers-standard = {workspace = true}
+apache-airflow-shared-state = {workspace = true}
 
 # To use:
 #
@@ -316,6 +318,7 @@ shared_distributions = [
     "apache-airflow-shared-secrets-backend",
     "apache-airflow-shared-secrets-masker",
     "apache-airflow-shared-serialization",
+    "apache-airflow-shared-state",
     "apache-airflow-shared-timezones",
     "apache-airflow-shared-observability",
     "apache-airflow-shared-plugins-manager",
diff --git a/task-sdk/src/airflow/sdk/_shared/state b/task-sdk/src/airflow/sdk/_shared/state
new file mode 120000
index 00000000000..752da632206
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/_shared/state
@@ -0,0 +1 @@
+../../../../../shared/state/src/airflow_shared/state
\ No newline at end of file
diff --git a/task-sdk/src/airflow/sdk/api/client.py b/task-sdk/src/airflow/sdk/api/client.py
index 4c8c565fe43..99b1aadb37f 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -93,6 +93,7 @@ from airflow.sdk.execution_time.comms import (
     OKResponse,
     PreviousDagRunResult,
     PreviousTIResult,
+    RescheduleTask,
     SkipDownstreamTasks,
     TaskRescheduleStartDate,
     TICount,
@@ -104,8 +105,6 @@ if TYPE_CHECKING:
     from datetime import datetime
     from typing import ParamSpec
 
-    from airflow.sdk.execution_time.comms import RescheduleTask
-
     P = ParamSpec("P")
     T = TypeVar("T")
 
diff --git a/task-sdk/src/airflow/sdk/configuration.py b/task-sdk/src/airflow/sdk/configuration.py
index 4e438a2cbf7..fb32f990c58 100644
--- a/task-sdk/src/airflow/sdk/configuration.py
+++ b/task-sdk/src/airflow/sdk/configuration.py
@@ -32,6 +32,7 @@ from airflow.sdk._shared.configuration.parser import (
     configure_parser_from_configuration_description,
     expand_env_var,
 )
+from airflow.sdk._shared.module_loading import import_string
 from airflow.sdk.execution_time.secrets import 
_SERVER_DEFAULT_SECRETS_SEARCH_PATH
 
 log = logging.getLogger(__name__)
@@ -236,8 +237,6 @@ def initialize_secrets_backends(
 
     Uses SDK's conf instead of Core's conf.
     """
-    from airflow.sdk._shared.module_loading import import_string
-
     backend_list = []
     worker_mode = False
     # Determine worker mode - if default_backends is not the server default, 
it's worker mode
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py b/task-sdk/src/airflow/sdk/execution_time/context.py
index 1e6874121fc..14922780da4 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -56,6 +56,7 @@ if TYPE_CHECKING:
     from typing_extensions import Self
 
     from airflow.sdk import Variable
+    from airflow.sdk._shared.state import TaskScope
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.definitions.connection import Connection
     from airflow.sdk.definitions.context import Context
@@ -70,6 +71,7 @@ if TYPE_CHECKING:
         ReceiveMsgType,
         VariableResult,
     )
+    from airflow.sdk.state import BaseStateBackend
     from airflow.sdk.types import OutletEventAccessorsProtocol
 
 
@@ -454,11 +456,29 @@ class VariableAccessor:
             raise
 
 
+@cache
+def _get_worker_state_backend() -> BaseStateBackend | None:
+    """Return the configured worker-side state backend, instantiated once and 
cached."""
+    class_name = conf.get("workers", "state_backend", fallback="")
+    if not class_name:
+        return None
+    from airflow.sdk._shared.module_loading import import_string
+
+    try:
+        return import_string(class_name)()
+    except (ImportError, AttributeError) as e:
+        raise ValueError(
+            f"Could not load worker state backend {class_name!r}. "
+            f"Check the [workers] state_backend config value. Error: {e}"
+        ) from e
+
+
 class TaskStateAccessor:
     """Accessor for task state scoped to the current task instance. Available 
as ``context['task_state']`` at task execution time."""
 
-    def __init__(self, ti_id: UUID) -> None:
+    def __init__(self, ti_id: UUID, scope: TaskScope) -> None:
         self._ti_id = ti_id
+        self._scope = scope
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, TaskStateAccessor):
@@ -484,7 +504,11 @@ class TaskStateAccessor:
         if isinstance(resp, ErrorResponse) and resp.error != 
ErrorType.TASK_STATE_NOT_FOUND:
             raise AirflowRuntimeError(resp)
         if isinstance(resp, TaskStateResult):
-            return resp.value
+            stored = resp.value
+            # if custom backend is configured, the stored value in DB is a 
reference, fetch the actual value from
+            # custom backend using the reference
+            backend = _get_worker_state_backend()
+            return backend.deserialize_task_state_from_ref(stored) if backend 
else stored
         return None
 
     def set(self, key: str, value: str, *, retention: timedelta | None = None) 
-> None:
@@ -509,14 +533,29 @@ class TaskStateAccessor:
         else:
             days = conf.getint("state_store", "default_retention_days")
             expires_at = None if days <= 0 else now + timedelta(days=days)
-        SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, 
value=value, expires_at=expires_at))
+
+        # if custom backend is configured, store the value on the custom 
backend, and return the reference
+        # to the stored value to store in the DB
+        backend = _get_worker_state_backend()
+        stored = (
+            backend.serialize_task_state_to_ref(value=value, key=key, 
ti_id=str(self._ti_id))
+            if backend
+            else value
+        )
+
+        SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, 
value=stored, expires_at=expires_at))
 
     def delete(self, key: str) -> None:
         """Delete a single key. No-op if the key does not exist."""
         from airflow.sdk.execution_time.comms import DeleteTaskState
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
+        # cleanup the DB ref first, if backend cleanup fails after this, the 
ref is gone and
+        # deterministic keys are recoverable on next set().
         SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key))
+        backend = _get_worker_state_backend()
+        if backend is not None:
+            backend.delete(self._scope, key)
 
     def clear(self, all_map_indices: bool = False) -> None:
         """
@@ -529,7 +568,23 @@ class TaskStateAccessor:
         from airflow.sdk.execution_time.comms import ClearTaskState
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
+        # cleanup the DB ref first, if backend cleanup fails after this, the 
ref is gone and
+        # deterministic keys are recoverable on next set().
         SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, 
all_map_indices=all_map_indices))
+        backend = _get_worker_state_backend()
+        if backend is not None:
+            backend.clear(self._scope, all_map_indices=all_map_indices)
+
+    def _clear_backend_only(self) -> None:
+        """
+        Clear external storage via the worker backend without sending a comms 
message.
+
+        Used by clear_on_success: the server already clears DB rows as part of 
SucceedTask,
+        so the comms round-trip is redundant.
+        """
+        backend = _get_worker_state_backend()
+        if backend is not None:
+            backend.clear(self._scope)
 
 
 class AssetStateAccessor:
@@ -579,7 +634,11 @@ class AssetStateAccessor:
         if isinstance(resp, ErrorResponse) and resp.error != 
ErrorType.ASSET_STATE_NOT_FOUND:
             raise AirflowRuntimeError(resp)
         if isinstance(resp, AssetStateResult):
-            return resp.value
+            stored = resp.value
+            # if custom backend is configured, the stored value in DB is a 
reference, fetch the actual value from
+            # custom backend using the reference
+            backend = _get_worker_state_backend()
+            return backend.deserialize_asset_state_from_ref(stored) if backend 
else stored
         return None
 
     def set(self, key: str, value: str) -> None:
@@ -587,15 +646,26 @@ class AssetStateAccessor:
         from airflow.sdk.execution_time.comms import SetAssetStateByName, 
SetAssetStateByUri, ToSupervisor
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
+        # if custom backend is configured, store the value on the custom 
backend, and return the reference
+        # to the stored value to store in the DB
+        backend = _get_worker_state_backend()
+        asset_ref = self._name or self._uri or ""
+        stored = (
+            backend.serialize_asset_state_to_ref(value=value, key=key, 
asset_ref=asset_ref)
+            if backend
+            else value
+        )
+
         msg: ToSupervisor
         if self._name:
-            msg = SetAssetStateByName(name=self._name, key=key, value=value)
+            msg = SetAssetStateByName(name=self._name, key=key, value=stored)
         elif self._uri:
-            msg = SetAssetStateByUri(uri=self._uri, key=key, value=value)
+            msg = SetAssetStateByUri(uri=self._uri, key=key, value=stored)
         SUPERVISOR_COMMS.send(msg)
 
     def delete(self, key: str) -> None:
         """Delete a single key. No-op if the key does not exist."""
+        from airflow.sdk._shared.state import AssetScope
         from airflow.sdk.execution_time.comms import (
             DeleteAssetStateByName,
             DeleteAssetStateByUri,
@@ -608,11 +678,21 @@ class AssetStateAccessor:
             msg = DeleteAssetStateByName(name=self._name, key=key)
         elif self._uri:
             msg = DeleteAssetStateByUri(uri=self._uri, key=key)
+        # DB ref first: if backend cleanup fails after this, the ref is gone 
and
+        # deterministic keys are recoverable on next set().
         SUPERVISOR_COMMS.send(msg)
+        backend = _get_worker_state_backend()
+        if backend is not None:
+            backend.delete(AssetScope(name=self._name, uri=self._uri), key)
 
     def clear(self) -> None:
         """Delete all state keys for this asset."""
-        from airflow.sdk.execution_time.comms import ClearAssetStateByName, 
ClearAssetStateByUri, ToSupervisor
+        from airflow.sdk._shared.state import AssetScope
+        from airflow.sdk.execution_time.comms import (
+            ClearAssetStateByName,
+            ClearAssetStateByUri,
+            ToSupervisor,
+        )
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
         msg: ToSupervisor
@@ -621,6 +701,9 @@ class AssetStateAccessor:
         elif self._uri:
             msg = ClearAssetStateByUri(uri=self._uri)
         SUPERVISOR_COMMS.send(msg)
+        backend = _get_worker_state_backend()
+        if backend is not None:
+            backend.clear(AssetScope(name=self._name, uri=self._uri))
 
 
 class AssetStateAccessors:
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 852b5da9cc6..10977fb011b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -134,6 +134,7 @@ from airflow.sdk.execution_time.sentry import Sentry
 from airflow.sdk.execution_time.xcom import XCom
 from airflow.sdk.listener import get_listener_manager
 from airflow.sdk.observability.metrics import stats_utils
+from airflow.sdk.state import TaskScope
 from airflow.sdk.timezone import coerce_datetime
 
 if TYPE_CHECKING:
@@ -260,7 +261,15 @@ class RuntimeTaskInstance(TaskInstance):
                     "value": VariableAccessor(deserialize_json=False),
                 },
                 "conn": ConnectionAccessor(),
-                "task_state": TaskStateAccessor(ti_id=self.id),
+                "task_state": TaskStateAccessor(
+                    ti_id=self.id,
+                    scope=TaskScope(
+                        dag_id=self.dag_id,
+                        run_id=self.run_id,
+                        task_id=self.task_id,
+                        map_index=self.map_index if self.map_index is not None 
else -1,
+                    ),
+                ),
             }
             if any(isinstance(i, (Asset, AssetNameRef, AssetUriRef, 
AssetAlias)) for i in self.task.inlets):
                 self._cached_template_context["asset_state"] = 
AssetStateAccessors(self.task.inlets)
@@ -1492,6 +1501,8 @@ def _handle_current_task_success(
     if conf.getboolean("state_store", "clear_on_success"):
         log.info("Task state will be cleared by the server because 
clear_on_success is enabled.")
 
+        context["task_state"]._clear_backend_only()
+
     msg = SucceedTask(
         end_date=end_date,
         task_outlets=task_outlets,
diff --git a/task-sdk/src/airflow/sdk/state.py b/task-sdk/src/airflow/sdk/state.py
new file mode 100644
index 00000000000..ac2a1126fe4
--- /dev/null
+++ b/task-sdk/src/airflow/sdk/state.py
@@ -0,0 +1,25 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.sdk._shared.state import (
+    AssetScope as AssetScope,
+    BaseStateBackend as BaseStateBackend,
+    TaskScope as TaskScope,
+)
diff --git a/task-sdk/tests/task_sdk/docs/test_public_api.py b/task-sdk/tests/task_sdk/docs/test_public_api.py
index 98391927f8a..a21424ea101 100644
--- a/task-sdk/tests/task_sdk/docs/test_public_api.py
+++ b/task-sdk/tests/task_sdk/docs/test_public_api.py
@@ -65,6 +65,7 @@ def test_airflow_sdk_no_unexpected_exports():
         "providers_manager_runtime",
         "lineage",
         "types",
+        "state",
     }
     unexpected = actual - public - ignore
     assert not unexpected, f"Unexpected exports in airflow.sdk: 
{sorted(unexpected)}"
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py b/task-sdk/tests/task_sdk/execution_time/test_context.py
index 062645d25b1..1763604e477 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -25,7 +25,12 @@ from uuid import UUID
 import pytest
 
 from airflow.sdk import BaseOperator, get_current_context, timezone
-from airflow.sdk.api.datamodels._generated import AssetEventResponse, 
AssetResponse, DagRun
+from airflow.sdk._shared.state import TaskScope
+from airflow.sdk.api.datamodels._generated import (
+    AssetEventResponse,
+    AssetResponse,
+    DagRun,
+)
 from airflow.sdk.bases.xcom import BaseXCom
 from airflow.sdk.definitions.asset import (
     Asset,
@@ -93,6 +98,7 @@ from airflow.sdk.execution_time.context import (
     set_current_context,
 )
 from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend
+from airflow.sdk.state import BaseStateBackend
 
 from tests_common.test_utils.config import conf_vars
 
@@ -1092,11 +1098,12 @@ class TestSecretsBackend:
 
 class TestTaskStateAccessor:
     TI_ID = UUID("01900000-0000-0000-0000-000000000001")
+    SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task")
 
     def test_get_returns_value(self, mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = 
TaskStateResult(value="app_001")
 
-        result = TaskStateAccessor(ti_id=self.TI_ID).get("job_id")
+        result = TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).get("job_id")
 
         assert result == "app_001"
         
mock_supervisor_comms.send.assert_called_once_with(GetTaskState(ti_id=self.TI_ID,
 key="job_id"))
@@ -1106,7 +1113,7 @@ class TestTaskStateAccessor:
             error=ErrorType.TASK_STATE_NOT_FOUND, detail={"key": "missing_key"}
         )
 
-        result = TaskStateAccessor(ti_id=self.TI_ID).get("missing_key")
+        result = TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).get("missing_key")
 
         assert result is None
 
@@ -1116,7 +1123,7 @@ class TestTaskStateAccessor:
         )
 
         with pytest.raises(AirflowRuntimeError):
-            TaskStateAccessor(ti_id=self.TI_ID).get("some_key")
+            TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).get("some_key")
 
     def test_set_operation_with_global_retention(self, mock_supervisor_comms, 
time_machine):
         """set() with no retention uses global default_retention_days 
config."""
@@ -1126,7 +1133,7 @@ class TestTaskStateAccessor:
         time_machine.move_to(now, tick=False)
 
         with conf_vars({("state_store", "default_retention_days"): "30"}):
-            TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+            TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).set("job_id", "app_001")
 
         mock_supervisor_comms.send.assert_called_once_with(
             SetTaskState(
@@ -1143,7 +1150,9 @@ class TestTaskStateAccessor:
         now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
         time_machine.move_to(now, tick=False)
 
-        TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", 
retention=timedelta(days=7))
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set(
+            "job_id", "app_001", retention=timedelta(days=7)
+        )
 
         mock_supervisor_comms.send.assert_called_once_with(
             SetTaskState(
@@ -1159,7 +1168,7 @@ class TestTaskStateAccessor:
 
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
-        TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", 
retention=NEVER_EXPIRE)
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", 
"app_001", retention=NEVER_EXPIRE)
 
         mock_supervisor_comms.send.assert_called_once_with(
             SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", 
expires_at=None)
@@ -1170,7 +1179,7 @@ class TestTaskStateAccessor:
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
         with conf_vars({("state_store", "default_retention_days"): "0"}):
-            TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+            TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).set("job_id", "app_001")
 
         mock_supervisor_comms.send.assert_called_once_with(
             SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", 
expires_at=None)
@@ -1179,14 +1188,14 @@ class TestTaskStateAccessor:
     def test_delete_operation(self, mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
-        TaskStateAccessor(ti_id=self.TI_ID).delete("job_id")
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id")
 
         
mock_supervisor_comms.send.assert_called_once_with(DeleteTaskState(ti_id=self.TI_ID,
 key="job_id"))
 
     def test_clear_default_sends_all_map_indices_false(self, 
mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
-        TaskStateAccessor(ti_id=self.TI_ID).clear()
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear()
 
         mock_supervisor_comms.send.assert_called_once_with(
             ClearTaskState(ti_id=self.TI_ID, all_map_indices=False)
@@ -1195,7 +1204,7 @@ class TestTaskStateAccessor:
     def test_clear_all_map_indices_sends_flag_true(self, 
mock_supervisor_comms):
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
-        TaskStateAccessor(ti_id=self.TI_ID).clear(all_map_indices=True)
+        TaskStateAccessor(ti_id=self.TI_ID, 
scope=self.SCOPE).clear(all_map_indices=True)
 
         mock_supervisor_comms.send.assert_called_once_with(
             ClearTaskState(ti_id=self.TI_ID, all_map_indices=True)
@@ -1399,3 +1408,179 @@ class TestAssetStateAccessors:
         accessors = AssetStateAccessors([alias])
 
         assert accessors._total == 0
+
+
+class InMemoryStateBackend(BaseStateBackend):
+    """Simple in-memory test backend."""
+
+    def __init__(self):
+        self._actual_key_value_store: dict[str, str] = {}  # key -> actual value
+        self.reference: dict[str, str] = {}  # key -> stored ref (mem:// URI)
+
+    def serialize_task_state_to_ref(self, *, value: str, key: str, ti_id: str) -> str:
+        ref = f"mem://{ti_id}/{key}"
+        self._actual_key_value_store[key] = value
+        self.reference[key] = ref
+        return ref
+
+    def deserialize_task_state_from_ref(self, stored: str) -> str:
+        key = stored.rsplit("/", 1)[-1]
+        return self._actual_key_value_store.get(key, stored)
+
+    def serialize_asset_state_to_ref(self, *, value: str, key: str, asset_ref: str) -> str:
+        ref = f"mem://{asset_ref}/{key}"
+        self._actual_key_value_store[key] = value
+        self.reference[key] = ref
+        return ref
+
+    def deserialize_asset_state_from_ref(self, stored: str) -> str:
+        key = stored.rsplit("/", 1)[-1]
+        return self._actual_key_value_store.get(key, stored)
+
+    def get(self, scope, key, *, session=None): ...
+    def set(self, scope, key, value, *, session=None): ...
+
+    def delete(self, scope, key, *, session=None) -> None:
+        self._actual_key_value_store.pop(key, None)
+        self.reference.pop(key, None)
+
+    def clear(self, scope, *, all_map_indices=False, session=None) -> None:
+        self._actual_key_value_store.clear()
+        self.reference.clear()
+
+    async def aget(self, scope, key): ...
+    async def aset(self, scope, key, value): ...
+    async def adelete(self, scope, key): ...
+    async def aclear(self, scope, *, all_map_indices=False): ...
+
+
+class TestTaskStateAccessorWithCustomBackend:
+    TI_ID = UUID("01900000-0000-0000-0000-000000000002")
+    SCOPE = TaskScope(dag_id="dag", run_id="run", task_id="task")
+
+    @pytest.fixture(autouse=True)
+    def backend(self):
+        b = InMemoryStateBackend()
+        with mock.patch(
+            "airflow.sdk.execution_time.context._get_worker_state_backend",
+            return_value=b,
+        ):
+            yield b
+
+    def test_set_returns_reference_to_storage(self, mock_supervisor_comms, backend, time_machine):
+        """set() stores actual value in backend and sends mem:// reference via comms."""
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+        expected_ref = f"mem://{self.TI_ID}/job_id"
+
+        frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(frozen_dt, tick=False)
+
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).set("job_id", "app_001")
+        # comms message has the mem:// reference, not the actual value
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetTaskState(
+                ti_id=self.TI_ID, key="job_id", value=expected_ref, expires_at=frozen_dt + timedelta(days=30)
+            )
+        )
+        # actual value is stored on the backend, reference is stored for DB
+        assert backend._actual_key_value_store["job_id"] == "app_001"
+        assert backend.reference["job_id"] == expected_ref
+
+    def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
+        """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
+        ref = f"mem://{self.TI_ID}/job_id"
+        backend._actual_key_value_store["job_id"] = "app_001"
+        mock_supervisor_comms.send.return_value = TaskStateResult(value=ref)
+
+        result = TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).get("job_id")
+        # actual value is resolved from mem:// reference via backend
+        assert result == "app_001"
+
+    def test_deletes_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
+        """delete() purges from backend storage and removes the DB reference."""
+        backend._actual_key_value_store["job_id"] = "app_001"
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).delete("job_id")
+
+        # backend does not have the value anymore
+        assert "job_id" not in backend._actual_key_value_store
+        # request to delete reference in DB was made
+        mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=self.TI_ID, key="job_id"))
+
+    def test_clears_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
+        """clear() purges all backend objects for the TI and removes all DB references."""
+        backend._actual_key_value_store["job_id"] = "app_001"
+        backend._actual_key_value_store["checkpoint"] = "step_3"
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        TaskStateAccessor(ti_id=self.TI_ID, scope=self.SCOPE).clear()
+
+        assert "job_id" not in backend._actual_key_value_store
+        assert "checkpoint" not in backend._actual_key_value_store
+        mock_supervisor_comms.send.assert_any_call(ClearTaskState(ti_id=self.TI_ID, all_map_indices=False))
+
+
+class TestAssetStateAccessorWithCustomBackend:
+    ASSET_NAME = "my_asset"
+
+    @pytest.fixture(autouse=True)
+    def backend(self):
+        b = InMemoryStateBackend()
+        with mock.patch(
+            "airflow.sdk.execution_time.context._get_worker_state_backend",
+            return_value=b,
+        ):
+            yield b
+
+    def test_set_sends_reference_not_value(self, mock_supervisor_comms, backend):
+        """set() stores actual value in backend and sends mem:// reference via comms."""
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        AssetStateAccessor(name=self.ASSET_NAME).set("watermark", "2026-05-01")
+
+        expected_ref = f"mem://{self.ASSET_NAME}/watermark"
+        # comms message has the mem:// reference, not the actual value
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetAssetStateByName(name=self.ASSET_NAME, key="watermark", value=expected_ref)
+        )
+        # actual value is stored on the backend, reference is stored for DB
+        assert backend._actual_key_value_store["watermark"] == "2026-05-01"
+        assert backend.reference["watermark"] == expected_ref
+
+    def test_get_resolves_reference_to_actual_value(self, mock_supervisor_comms, backend):
+        """get() fetches mem:// reference from DB, resolves it to actual value via backend."""
+        ref = f"mem://{self.ASSET_NAME}/watermark"
+        backend._actual_key_value_store["watermark"] = "2026-05-01"
+        mock_supervisor_comms.send.return_value = AssetStateResult(value=ref)
+
+        result = AssetStateAccessor(name=self.ASSET_NAME).get("watermark")
+
+        # actual value is resolved from mem:// reference via backend
+        assert result == "2026-05-01"
+
+    def test_delete_purges_from_backend_and_removes_db_ref(self, mock_supervisor_comms, backend):
+        """delete() purges from backend storage and removes the DB reference."""
+        backend._actual_key_value_store["watermark"] = "2026-05-01"
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        AssetStateAccessor(name=self.ASSET_NAME).delete("watermark")
+
+        # backend doesn't have the value anymore
+        assert "watermark" not in backend._actual_key_value_store
+        # request to delete reference in DB was made
+        mock_supervisor_comms.send.assert_any_call(
+            DeleteAssetStateByName(name=self.ASSET_NAME, key="watermark")
+        )
+
+    def test_clear_purges_all_from_backend_and_clears_db(self, mock_supervisor_comms, backend):
+        """clear() purges all backend objects and removes all DB references."""
+        backend._actual_key_value_store["watermark"] = "2026-05-01"
+        backend._actual_key_value_store["file_count"] = "42"
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        AssetStateAccessor(name=self.ASSET_NAME).clear()
+
+        assert "watermark" not in backend._actual_key_value_store
+        assert "file_count" not in backend._actual_key_value_store
+        mock_supervisor_comms.send.assert_any_call(ClearAssetStateByName(name=self.ASSET_NAME))
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 9425ce362ed..b37d2569ea4 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -55,6 +55,7 @@ from airflow.sdk import (
     timezone,
 )
 from airflow.sdk._shared.observability.metrics.base_stats_logger import StatsLogger
+from airflow.sdk._shared.state import TaskScope
 from airflow.sdk.api.datamodels._generated import (
     AssetProfile,
     AssetResponse,
@@ -1899,7 +1900,9 @@ class TestRuntimeTaskInstance:
             "run_id": "test_run",
             "task": task,
             "task_instance": runtime_ti,
-            "task_state": TaskStateAccessor(ti_id=ti_id),
+            "task_state": TaskStateAccessor(
+                ti_id=ti_id, scope=TaskScope(dag_id=dag_id, run_id="test_run", task_id="hello")
+            ),
             "ti": runtime_ti,
         }
 
@@ -1945,7 +1948,10 @@ class TestRuntimeTaskInstance:
             "run_id": "test_run",
             "task": task,
             "task_instance": runtime_ti,
-            "task_state": TaskStateAccessor(ti_id=runtime_ti.id),
+            "task_state": TaskStateAccessor(
+                ti_id=runtime_ti.id,
+                scope=TaskScope(dag_id=runtime_ti.dag_id, run_id="test_run", task_id="hello"),
+            ),
             "ti": runtime_ti,
             "dag_run": dr,
             "data_interval_end": timezone.datetime(2024, 12, 1, 1, 0, 0),
@@ -5372,3 +5378,91 @@ class TestTaskInstanceStateOperations:
         mock_supervisor_comms.send.assert_any_call(
             SetAssetStateByName(name="asset_b", key="watermark_b", value="2026-05-02")
         )
+
+    def test_asset_state_set_sends_reference_via_custom_backend(
+        self, create_runtime_ti, mock_supervisor_comms
+    ):
+        """When a worker backend is configured, asset state set() sends a reference, not the actual value."""
+        watched = Asset(name="my_asset", uri="s3://bucket/data")
+
+        class WatcherOperator(BaseOperator):
+            def execute(self, context):
+                context["asset_state"].set("watermark", "2026-05-01")
+
+        task = WatcherOperator(task_id="t", inlets=[watched])
+        runtime_ti = create_runtime_ti(task=task)
+        mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect
+
+        mock_backend = mock.MagicMock()
+        mock_backend.serialize_asset_state_to_ref.return_value = "mem://my_asset/watermark"
+
+        with mock.patch(
+            "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+        ):
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        mock_backend.serialize_asset_state_to_ref.assert_called_once_with(
+            value="2026-05-01", key="watermark", asset_ref="my_asset"
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            SetAssetStateByName(name="my_asset", key="watermark", value="mem://my_asset/watermark")
+        )
+
+    def test_task_state_set_sends_reference_via_custom_backend(
+        self, create_runtime_ti, mock_supervisor_comms, time_machine
+    ):
+        """When a worker backend is configured, task state set() sends a reference, not the actual value."""
+
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                context["task_state"].set("job_id", "app_001")
+
+        frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(frozen_dt, tick=False)
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+        mock_supervisor_comms.send.side_effect = TestTaskInstanceStateOperations._watcher_side_effect
+
+        mock_backend = mock.MagicMock()
+        ref = f"mem://{runtime_ti.id}/job_id"
+        mock_backend.serialize_task_state_to_ref.return_value = ref
+
+        with mock.patch(
+            "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+        ):
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        mock_backend.serialize_task_state_to_ref.assert_called_once_with(
+            value="app_001", key="job_id", ti_id=str(runtime_ti.id)
+        )
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(
+                ti_id=runtime_ti.id, key="job_id", value=ref, expires_at=frozen_dt + timedelta(days=30)
+            )
+        )
+
+    @conf_vars({("state_store", "clear_on_success"): "True"})
+    def test_clear_on_success_clears_backend_without_comms_roundtrip(
+        self, create_runtime_ti, mock_supervisor_comms
+    ):
+        """clear_on_success calls backend.clear() directly without sending ClearTaskState comms."""
+        mock_backend = mock.MagicMock()
+
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                pass
+
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+
+        with mock.patch(
+            "airflow.sdk.execution_time.context._get_worker_state_backend", return_value=mock_backend
+        ):
+            run(runtime_ti, context=runtime_ti.get_template_context(), log=mock.MagicMock())
+
+        mock_backend.clear.assert_called_once()
+        sent_types = [
+            type(call.kwargs.get("msg") or (call.args[0] if call.args else None))
+            for call in mock_supervisor_comms.send.call_args_list
+        ]
+        assert ClearTaskState not in sent_types


Reply via email to