This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 008cbe90e2a AIP-103: Adding ability for per task state key retention 
from operators (#66699)
008cbe90e2a is described below

commit 008cbe90e2a8c3fc7e315ee61de0223854e88ec0
Author: Amogh Desai <[email protected]>
AuthorDate: Tue May 19 11:59:45 2026 +0530

    AIP-103: Adding ability for per task state key retention from operators 
(#66699)
---
 .../execution_api/datamodels/task_state.py         |  3 ++
 .../api_fastapi/execution_api/routes/task_state.py |  2 +-
 airflow-core/src/airflow/state/metastore.py        | 56 ++++++++++++--------
 .../execution_api/versions/head/test_task_state.py | 30 +++++++++++
 airflow-core/tests/unit/state/test_metastore.py    | 19 +++++--
 shared/state/src/airflow_shared/state/__init__.py  | 24 ++++++++-
 task-sdk/docs/api.rst                              |  4 ++
 task-sdk/src/airflow/sdk/__init__.py               |  3 ++
 task-sdk/src/airflow/sdk/api/client.py             |  5 +-
 .../src/airflow/sdk/api/datamodels/_generated.py   |  1 +
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  1 +
 task-sdk/src/airflow/sdk/execution_time/context.py | 31 ++++++++++--
 .../src/airflow/sdk/execution_time/supervisor.py   |  2 +-
 task-sdk/tests/task_sdk/api/test_client.py         | 39 +++++++++++++-
 .../tests/task_sdk/execution_time/test_context.py  | 59 ++++++++++++++++++++--
 .../task_sdk/execution_time/test_supervisor.py     | 24 ++++++++-
 .../task_sdk/execution_time/test_task_runner.py    | 35 +++++++++++--
 17 files changed, 295 insertions(+), 43 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
index 3200f3177af..20980b315c3 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/task_state.py
@@ -17,6 +17,8 @@
 
 from __future__ import annotations
 
+from pydantic import AwareDatetime
+
 from airflow.api_fastapi.core_api.base import StrictBaseModel
 
 
@@ -30,3 +32,4 @@ class TaskStatePutBody(StrictBaseModel):
     """Request body for setting a task state value."""
 
     value: str
+    expires_at: AwareDatetime | None = None
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
index acdaa8c6a24..2f824e3ebb2 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_state.py
@@ -86,7 +86,7 @@ def set_task_state(
 ) -> None:
     """Set a task state key, creating or updating the row."""
     scope = _get_task_scope_for_ti(task_instance_id, session)
-    get_state_backend().set(scope, key, body.value, session=session)
+    get_state_backend().set(scope, key, body.value, 
expires_at=body.expires_at, session=session)
 
 
 @router.delete("/{task_instance_id}/{key}", 
status_code=status.HTTP_204_NO_CONTENT)
diff --git a/airflow-core/src/airflow/state/metastore.py 
b/airflow-core/src/airflow/state/metastore.py
index f58c69f5808..e5e0a82be3e 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from collections.abc import AsyncGenerator
 from contextlib import asynccontextmanager
-from datetime import datetime, timedelta
+from datetime import datetime
 from typing import TYPE_CHECKING
 
 import structlog
@@ -46,18 +46,6 @@ if TYPE_CHECKING:
 log = structlog.get_logger(__name__)
 
 
-def _compute_expires_at(now: datetime) -> datetime | None:
-    """
-    Return the expiry timestamp for a new task state row based on config.
-
-    Returns None if default_retention_days is 0 (never expires).
-    """
-    retention_days = conf.getint("state_store", "default_retention_days")
-    if retention_days <= 0:
-        return None
-    return now + timedelta(days=retention_days)
-
-
 @asynccontextmanager
 async def _async_session(session: AsyncSession | None) -> 
AsyncGenerator[AsyncSession, None]:
     """Use provided async session or create a new one."""
@@ -111,12 +99,20 @@ class MetastoreStateBackend(BaseStateBackend):
                 assert_never(scope)
 
     @provide_session
-    def set(self, scope: StateScope, key: str, value: str, *, session: Session 
| None = NEW_SESSION) -> None:
+    def set(
+        self,
+        scope: StateScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: Session | None = NEW_SESSION,
+    ) -> None:
         if TYPE_CHECKING:
             assert session is not None
         match scope:
             case TaskScope():
-                self._set_task_state(scope, key, value, session=session)
+                self._set_task_state(scope, key, value, expires_at=expires_at, 
session=session)
             case AssetScope():
                 self._set_asset_state(scope, key, value, session=session)
             case _:
@@ -163,12 +159,18 @@ class MetastoreStateBackend(BaseStateBackend):
                     assert_never(scope)
 
     async def aset(
-        self, scope: StateScope, key: str, value: str, *, session: 
AsyncSession | None = None
+        self,
+        scope: StateScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: AsyncSession | None = None,
     ) -> None:
         async with _async_session(session) as s:
             match scope:
                 case TaskScope():
-                    await self._aset_task_state(scope, key, value, session=s)
+                    await self._aset_task_state(scope, key, value, 
expires_at=expires_at, session=s)
                 case AssetScope():
                     await self._aset_asset_state(scope, key, value, session=s)
                 case _:
@@ -208,7 +210,15 @@ class MetastoreStateBackend(BaseStateBackend):
         )
         return row.value if row is not None else None
 
-    def _set_task_state(self, scope: TaskScope, key: str, value: str, *, 
session: Session) -> None:
+    def _set_task_state(
+        self,
+        scope: TaskScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: Session,
+    ) -> None:
         dag_run_id = session.scalar(
             select(DagRun.id).where(
                 DagRun.dag_id == scope.dag_id,
@@ -218,7 +228,6 @@ class MetastoreStateBackend(BaseStateBackend):
         if dag_run_id is None:
             raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} 
run_id={scope.run_id!r}")
         now = timezone.utcnow()
-        expires_at = _compute_expires_at(now)
         values = dict(
             dag_run_id=dag_run_id,
             dag_id=scope.dag_id,
@@ -354,7 +363,13 @@ class MetastoreStateBackend(BaseStateBackend):
         return row.value if row is not None else None
 
     async def _aset_task_state(
-        self, scope: TaskScope, key: str, value: str, *, session: AsyncSession
+        self,
+        scope: TaskScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: AsyncSession,
     ) -> None:
         dag_run_id = await session.scalar(
             select(DagRun.id).where(
@@ -365,7 +380,6 @@ class MetastoreStateBackend(BaseStateBackend):
         if dag_run_id is None:
             raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r} 
run_id={scope.run_id!r}")
         now = timezone.utcnow()
-        expires_at = _compute_expires_at(now)
         values = dict(
             dag_run_id=dag_run_id,
             dag_id=scope.dag_id,
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
index 8a66a0a23c7..d83751050e7 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_state.py
@@ -16,9 +16,11 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime
 from typing import TYPE_CHECKING
 from uuid import uuid4
 
+import pendulum
 import pytest
 from fastapi import Request
 from fastapi.testclient import TestClient
@@ -95,6 +97,34 @@ class TestPutTaskState:
             assert row is not None
             assert row.value == "spark_001"
 
+    def test_put_with_expires_at_creates_row(
+        self, client: TestClient, create_task_instance: CreateTaskInstance, 
time_machine
+    ):
+
+        ti = create_task_instance()
+        time_machine.move_to(datetime(2026, 5, 5, 12, 0, 0), tick=False)
+        response = client.put(
+            _api_url(ti.id, "job_id"),
+            json={
+                "value": "spark_001",
+                "expires_at": datetime(2026, 5, 15, 12, 0, 0, 
tzinfo=pendulum.UTC).isoformat(),
+            },
+        )
+
+        assert response.status_code == 204
+        with create_session() as session:
+            row = session.scalar(
+                select(TaskStateModel).where(
+                    TaskStateModel.dag_id == ti.dag_id,
+                    TaskStateModel.run_id == ti.run_id,
+                    TaskStateModel.task_id == ti.task_id,
+                    TaskStateModel.key == "job_id",
+                )
+            )
+            assert row is not None
+            assert row.value == "spark_001"
+            assert row.expires_at == datetime(2026, 5, 15, 12, 0, 0, 
tzinfo=pendulum.UTC)
+
     def test_put_overwrites_existing(self, client: TestClient, 
create_task_instance: CreateTaskInstance):
         ti = create_task_instance()
         client.put(_api_url(ti.id, "job_id"), json={"value": "spark_001"})
diff --git a/airflow-core/tests/unit/state/test_metastore.py 
b/airflow-core/tests/unit/state/test_metastore.py
index d9e1ff33afd..fbd37ddc30e 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -239,18 +239,29 @@ class TestMetastoreStateBackendTaskScope:
         assert backend.get(scope0, "job_id", session=session) is None
         assert backend.get(scope1, "job_id", session=session) is None
 
-    def test_set_populates_expires_at(
+    def test_set_without_expires_at_stores_null(
         self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
     ):
-        """set() always populates expires_at so cleanup has a single pass."""
+        """set() without expires_at stores NULL — the worker is responsible 
for computing expiry."""
         scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
         backend.set(scope, "job_id", "app_1234", session=session)
         session.flush()
 
         row = session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "job_id"))
         assert row is not None
-        assert row.expires_at is not None
-        assert row.expires_at > row.updated_at
+        assert row.expires_at is None
+
+    def test_set_expires_at_none_stores_null(
+        self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+    ):
+        """Explicit expires_at=None stores NULL — the key never expires regardless of 
global config."""
+        scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+        backend.set(scope, "job_id", "app_1234", expires_at=None, session=session)
+        session.flush()
+
+        row = session.scalar(select(TaskStateModel).where(TaskStateModel.key 
== "job_id"))
+        assert row is not None
+        assert row.expires_at is None
 
     def test_cleanup_removes_expired_rows(
         self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
diff --git a/shared/state/src/airflow_shared/state/__init__.py 
b/shared/state/src/airflow_shared/state/__init__.py
index e231bdfd3bd..bfce8db3328 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -21,6 +21,8 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
+    from datetime import datetime
+
     from sqlalchemy.ext.asyncio import AsyncSession
     from sqlalchemy.orm import Session
 
@@ -84,11 +86,23 @@ class BaseStateBackend(ABC):
         """
 
     @abstractmethod
-    def set(self, scope: StateScope, key: str, value: str, *, session: Session 
| None = None) -> None:
+    def set(
+        self,
+        scope: StateScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: Session | None = None,
+    ) -> None:
         """
         Write or overwrite the value for the given key.
 
         Must handle both ``TaskScope`` and ``AssetScope``.
+
+        ``expires_at`` is an absolute UTC datetime after which the row may be 
deleted.
+        Pass ``None`` (default) for a key that should never expire — stored as 
``NULL``,
+        skipped by garbage collection.
         """
 
     @abstractmethod
@@ -125,7 +139,13 @@ class BaseStateBackend(ABC):
 
     @abstractmethod
     async def aset(
-        self, scope: StateScope, key: str, value: str, *, session: 
AsyncSession | None = None
+        self,
+        scope: StateScope,
+        key: str,
+        value: str,
+        *,
+        expires_at: datetime | None = None,
+        session: AsyncSession | None = None,
     ) -> None:
         """
         Async variant of set. Must handle both ``TaskScope`` and 
``AssetScope``.
diff --git a/task-sdk/docs/api.rst b/task-sdk/docs/api.rst
index cb9789f5bb6..222b28ad53d 100644
--- a/task-sdk/docs/api.rst
+++ b/task-sdk/docs/api.rst
@@ -261,6 +261,10 @@ For a complete list of available context variables (such 
as ``dag_run``,
 ``task_instance``, ``logical_date``, etc.), see the
 :ref:`Templates reference <templates-ref>`.
 
+.. rubric:: Task State
+
+.. autodata:: airflow.sdk.NEVER_EXPIRE
+
 .. rubric:: Logging
 
 .. autofunction:: airflow.sdk.log.mask_secret
diff --git a/task-sdk/src/airflow/sdk/__init__.py 
b/task-sdk/src/airflow/sdk/__init__.py
index f304b068237..05ececc2956 100644
--- a/task-sdk/src/airflow/sdk/__init__.py
+++ b/task-sdk/src/airflow/sdk/__init__.py
@@ -55,6 +55,7 @@ __all__ = [
     "IdentityMapper",
     "Label",
     "Metadata",
+    "NEVER_EXPIRE",
     "MultipleCronTriggerTimetable",
     "ObjectStoragePath",
     "Param",
@@ -170,6 +171,7 @@ if TYPE_CHECKING:
     from airflow.sdk.definitions.variable import Variable
     from airflow.sdk.definitions.xcom_arg import XComArg
     from airflow.sdk.execution_time import macros
+    from airflow.sdk.execution_time.context import NEVER_EXPIRE
     from airflow.sdk.io.path import ObjectStoragePath
     from airflow.sdk.types import TaskInstance
 
@@ -245,6 +247,7 @@ __lazy_imports: dict[str, str] = {
     "conf": ".configuration",
     "cross_downstream": ".bases.operator",
     "dag": ".definitions.dag",
+    "NEVER_EXPIRE": ".execution_time.context",
     "get_current_context": ".definitions.context",
     "get_parsing_context": ".definitions.context",
     "literal": ".definitions.template",
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 269978ac9dd..24074824b88 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -21,6 +21,7 @@ import logging
 import ssl
 import sys
 import uuid
+from datetime import datetime
 from functools import cache
 from http import HTTPStatus
 from typing import TYPE_CHECKING, Any, TypeVar
@@ -693,9 +694,9 @@ class TaskStateOperations:
             raise
         return TaskStateResponse.model_validate_json(resp.read())
 
-    def set(self, ti_id: uuid.UUID, key: str, value: str) -> OKResponse:
+    def set(self, ti_id: uuid.UUID, key: str, value: str, expires_at: datetime 
| None = None) -> OKResponse:
         """Set a task state value via the API server."""
-        body = TaskStatePutBody(value=value)
+        body = TaskStatePutBody(value=value, expires_at=expires_at)
         self.client.put(f"state/ti/{ti_id}/{key}", 
content=body.model_dump_json())
         return OKResponse(ok=True)
 
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 9f1dadeef51..fc966a76969 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -376,6 +376,7 @@ class TaskStatePutBody(BaseModel):
         extra="forbid",
     )
     value: Annotated[str, Field(title="Value")]
+    expires_at: Annotated[AwareDatetime | None, Field(title="Expires At")] = 
None
 
 
 class TaskStateResponse(BaseModel):
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index c56f5b23ab3..2364e942ed0 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -924,6 +924,7 @@ class SetTaskState(BaseModel):
     ti_id: UUID
     key: str
     value: str
+    expires_at: AwareDatetime | None
     type: Literal["SetTaskState"] = "SetTaskState"
 
 
diff --git a/task-sdk/src/airflow/sdk/execution_time/context.py 
b/task-sdk/src/airflow/sdk/execution_time/context.py
index cdd25803989..afc5a120c0c 100644
--- a/task-sdk/src/airflow/sdk/execution_time/context.py
+++ b/task-sdk/src/airflow/sdk/execution_time/context.py
@@ -21,7 +21,7 @@ import contextlib
 import functools
 import inspect
 from collections.abc import Generator, Iterable, Iterator, Mapping, Sequence
-from datetime import datetime
+from datetime import datetime, timedelta, timezone
 from functools import cache
 from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
 from uuid import UUID
@@ -29,6 +29,7 @@ from uuid import UUID
 import attrs
 import structlog
 
+from airflow.sdk.configuration import conf
 from airflow.sdk.definitions._internal.contextmanager import _CURRENT_CONTEXT
 from airflow.sdk.definitions._internal.types import NOTSET
 from airflow.sdk.definitions.asset import (
@@ -108,6 +109,11 @@ AIRFLOW_VAR_NAME_FORMAT_MAPPING = {
 
 log = structlog.get_logger(logger_name="task")
 
+#: Pass as ``retention`` to ``task_state.set()`` to store a key that never 
expires,
+#: regardless of the global ``[state_store] default_retention_days`` config.
+#: Example: ``context["task_state"].set("job_id", job_id, 
retention=NEVER_EXPIRE)``
+NEVER_EXPIRE: timedelta = timedelta.max
+
 T = TypeVar("T")
 
 
@@ -467,12 +473,29 @@ class TaskStateAccessor:
             return resp.value
         return None
 
-    def set(self, key: str, value: str) -> None:
-        """Write or overwrite the value for the given key."""
+    def set(self, key: str, value: str, *, retention: timedelta | None = None) 
-> None:
+        """
+        Write or overwrite the value for the given key.
+
+        ``retention`` is an optional argument that controls when this key expires:
+
+        - ``timedelta(...)`` — expire after the given duration (e.g. 
``timedelta(hours=6)``).
+        - ``NEVER_EXPIRE`` — key never expires, regardless of the global 
config, and is skipped by garbage collection.
+        - ``None`` (default) — use the global ``[state_store] 
default_retention_days`` config.
+        """
         from airflow.sdk.execution_time.comms import SetTaskState
         from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
 
-        SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, 
value=value))
+        # expires_at is always resolved on the worker in UTC before being sent.
+        now = datetime.now(tz=timezone.utc)
+        if retention == NEVER_EXPIRE:  # by value: a copied/unpickled timedelta.max is a distinct object
+            expires_at = None
+        elif retention is not None:
+            expires_at = now + retention
+        else:
+            days = conf.getint("state_store", "default_retention_days")
+            expires_at = None if days <= 0 else now + timedelta(days=days)
+        SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, 
value=value, expires_at=expires_at))
 
     def delete(self, key: str) -> None:
         """Delete a single key. No-op if the key does not exist."""
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 3e6236c5786..5e46b6bc864 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -1660,7 +1660,7 @@ class ActivitySubprocess(WatchedSubprocess):
                 else TaskStateResult.from_task_state_response(task_state)
             )
         elif isinstance(msg, SetTaskState):
-            self.client.task_state.set(msg.ti_id, msg.key, msg.value)
+            self.client.task_state.set(msg.ti_id, msg.key, msg.value, 
expires_at=msg.expires_at)
             resp = OKResponse(ok=True)
         elif isinstance(msg, DeleteTaskState):
             self.client.task_state.delete(msg.ti_id, msg.key)
diff --git a/task-sdk/tests/task_sdk/api/test_client.py 
b/task-sdk/tests/task_sdk/api/test_client.py
index a179ff08436..805dac87934 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -1764,6 +1764,8 @@ class TestTaskStateOperations:
         assert result.error == ErrorType.TASK_STATE_NOT_FOUND
 
     def test_set_success(self):
+        expires = datetime(2026, 6, 13, 12, 0, 0, tzinfo=dt_timezone.utc)
+
         def handle_request(request: httpx.Request) -> httpx.Response:
             assert request.method == "PUT"
             assert request.url.path == f"/state/ti/{self.TI_ID}/job_id"
@@ -1771,7 +1773,42 @@ class TestTaskStateOperations:
             return httpx.Response(status_code=204)
 
         client = make_client(transport=httpx.MockTransport(handle_request))
-        result = client.task_state.set(ti_id=self.TI_ID, key="job_id", 
value="spark_app_001")
+        result = client.task_state.set(
+            ti_id=self.TI_ID, key="job_id", value="spark_app_001", 
expires_at=expires
+        )
+        assert result == OKResponse(ok=True)
+
+    def test_set_with_expires_at_sends_field(self):
+        """expires_at is forwarded as an ISO datetime string in the request 
body."""
+        expires = datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt_timezone.utc)
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            body = json.loads(request.content)
+            assert body["value"] == "spark_app_001"
+            assert body["expires_at"] == "2026-05-21T12:00:00Z"
+            return httpx.Response(status_code=204)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_state.set(
+            ti_id=self.TI_ID, key="job_id", value="spark_app_001", 
expires_at=expires
+        )
+        assert result == OKResponse(ok=True)
+
+    def test_set_with_never_expire_sends_null_expires_at(self):
+        """NEVER_EXPIRE sends expires_at=null — stored as NULL in DB, GC skips 
it."""
+
+        def handle_request(request: httpx.Request) -> httpx.Response:
+            body = json.loads(request.content)
+            assert body.get("expires_at") is None
+            return httpx.Response(status_code=204)
+
+        client = make_client(transport=httpx.MockTransport(handle_request))
+        result = client.task_state.set(
+            ti_id=self.TI_ID,
+            key="job_id",
+            value="v",
+            expires_at=None,
+        )
         assert result == OKResponse(ok=True)
 
     def test_delete_success(self):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_context.py 
b/task-sdk/tests/task_sdk/execution_time/test_context.py
index ff0e6025c63..a5ff7be9ce8 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_context.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_context.py
@@ -17,6 +17,7 @@
 
 from __future__ import annotations
 
+from datetime import datetime, timedelta, timezone as dt_timezone
 from unittest import mock
 from unittest.mock import MagicMock, patch
 from uuid import UUID
@@ -73,6 +74,7 @@ from airflow.sdk.execution_time.comms import (
     XComResult,
 )
 from airflow.sdk.execution_time.context import (
+    NEVER_EXPIRE,
     AssetStateAccessor,
     AssetStateAccessors,
     ConnectionAccessor,
@@ -92,6 +94,8 @@ from airflow.sdk.execution_time.context import (
 )
 from airflow.sdk.execution_time.secrets import ExecutionAPISecretsBackend
 
+from tests_common.test_utils.config import conf_vars
+
 
 def test_convert_connection_result_conn():
     """Test that the ConnectionResult is converted to a Connection object."""
@@ -1085,13 +1089,62 @@ class TestTaskStateAccessor:
         with pytest.raises(AirflowRuntimeError):
             TaskStateAccessor(ti_id=self.TI_ID).get("some_key")
 
-    def test_set_operation(self, mock_supervisor_comms):
+    def test_set_operation_with_global_retention(self, mock_supervisor_comms, 
time_machine):
+        """set() with no retention uses global default_retention_days 
config."""
+
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+        now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(now, tick=False)
+
+        with conf_vars({("state_store", "default_retention_days"): "30"}):
+            TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetTaskState(
+                ti_id=self.TI_ID,
+                key="job_id",
+                value="app_001",
+                expires_at=datetime(2026, 6, 13, 12, 0, 0, 
tzinfo=dt_timezone.utc),
+            )
+        )
+
+    def test_set_with_retention_computes_expires_at(self, 
mock_supervisor_comms, time_machine):
+        """set(retention=timedelta(...)) computes expires_at on the worker and 
sends it."""
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+        now = datetime(2026, 5, 14, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(now, tick=False)
+
+        TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", 
retention=timedelta(days=7))
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetTaskState(
+                ti_id=self.TI_ID,
+                key="job_id",
+                value="app_001",
+                expires_at=datetime(2026, 5, 21, 12, 0, 0, 
tzinfo=dt_timezone.utc),
+            )
+        )
+
+    def test_set_with_never_expire_sends_null_expires_at(self, 
mock_supervisor_comms):
+        """set(retention=NEVER_EXPIRE) sends expires_at=None"""
+
+        mock_supervisor_comms.send.return_value = OKResponse(ok=True)
+
+        TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001", 
retention=NEVER_EXPIRE)
+
+        mock_supervisor_comms.send.assert_called_once_with(
+            SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", 
expires_at=None)
+        )
+
+    def test_set_global_default_zero_sends_null_expires_at(self, 
mock_supervisor_comms):
+        """When default_retention_days=0 (never expire globally), 
expires_at=None (stored as NULL)."""
         mock_supervisor_comms.send.return_value = OKResponse(ok=True)
 
-        TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
+        with conf_vars({("state_store", "default_retention_days"): "0"}):
+            TaskStateAccessor(ti_id=self.TI_ID).set("job_id", "app_001")
 
         mock_supervisor_comms.send.assert_called_once_with(
-            SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001")
+            SetTaskState(ti_id=self.TI_ID, key="job_id", value="app_001", 
expires_at=None)
         )
 
     def test_delete_operation(self, mock_supervisor_comms):
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 51131f3c48d..b4ba8de42e2 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -2722,11 +2722,33 @@ REQUEST_TEST_CASES = [
         expected_body={"value": "spark_app_001", "type": "TaskStateResult"},
     ),
     RequestTestCase(
-        message=SetTaskState(ti_id=TI_ID, key="job_id", value="spark_app_001"),
+        message=SetTaskState(
+            ti_id=TI_ID,
+            key="job_id",
+            value="spark_app_001",
+            expires_at=None,
+        ),
         test_id="set_task_state",
         client_mock=ClientMock(
             method_path="task_state.set",
             args=(TI_ID, "job_id", "spark_app_001"),
+            kwargs={"expires_at": None},
+            response=OKResponse(ok=True),
+        ),
+        expected_body={"ok": True, "type": "OKResponse"},
+    ),
+    RequestTestCase(
+        message=SetTaskState(
+            ti_id=TI_ID,
+            key="job_id",
+            value="spark_app_001",
+            expires_at=datetime(2026, 5, 21, 12, 0, 0, tzinfo=dt_timezone.utc),
+        ),
+        test_id="set_task_state_with_expires_at",
+        client_mock=ClientMock(
+            method_path="task_state.set",
+            args=(TI_ID, "job_id", "spark_app_001"),
+            kwargs={"expires_at": datetime(2026, 5, 21, 12, 0, 0, 
tzinfo=dt_timezone.utc)},
             response=OKResponse(ok=True),
         ),
         expected_body={"ok": True, "type": "OKResponse"},
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index 5234b30b84c..54afc412f56 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -4966,7 +4966,7 @@ def test_dag_add_result(create_runtime_ti, 
mock_supervisor_comms):
 class TestTaskInstanceStateOperations:
     """Tests to verify that tasks can perform state operations (task / asset) 
via the supervisor."""
 
-    def test_task_can_set_and_get_state(self, create_runtime_ti, 
mock_supervisor_comms):
+    def test_task_can_set_and_get_state(self, create_runtime_ti, 
mock_supervisor_comms, time_machine):
         class MyOperator(BaseOperator):
             def execute(self, context):
                 ts = context["task_state"]
@@ -4975,14 +4975,43 @@ class TestTaskInstanceStateOperations:
 
         task = MyOperator(task_id="t")
         runtime_ti = create_runtime_ti(task=task)
+        frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(frozen_dt, tick=False)
 
-        run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+        with conf_vars({("state_store", "default_retention_days"): "30"}):
+            run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
 
         mock_supervisor_comms.send.assert_any_call(
-            SetTaskState(ti_id=runtime_ti.id, key="job_id", 
value="spark_app_001")
+            SetTaskState(
+                ti_id=runtime_ti.id,
+                key="job_id",
+                value="spark_app_001",
+                expires_at=frozen_dt + timedelta(days=30),
+            )
         )
         
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id, 
key="job_id"))
 
+    def test_task_can_set_state_with_retention(self, create_runtime_ti, 
mock_supervisor_comms, time_machine):
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                context["task_state"].set("job_id", "spark_app_001", 
retention=timedelta(days=7))
+
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+        frozen_dt = datetime(2026, 1, 1, 12, 0, 0, tzinfo=dt_timezone.utc)
+        time_machine.move_to(frozen_dt, tick=False)
+
+        run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(
+                ti_id=runtime_ti.id,
+                key="job_id",
+                value="spark_app_001",
+                expires_at=frozen_dt + timedelta(days=7),
+            )
+        )
+
     def test_task_can_delete_state(self, create_runtime_ti, 
mock_supervisor_comms):
         class MyOperator(BaseOperator):
             def execute(self, context):

Reply via email to