amoghrajesh commented on code in PR #66160:
URL: https://github.com/apache/airflow/pull/66160#discussion_r3194336519


##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -406,6 +405,227 @@ def get(self, key, default: Any = NOTSET) -> Any:
             raise
 
 
class TaskStateAccessor:
    """
    Read/write handle for the state of the currently running task instance.

    Exposed to task code as ``context['task_state']`` at execution time. Every
    operation is forwarded to the supervisor as a message carrying this
    accessor's task instance id.
    """

    def __init__(self, ti_id: UUID) -> None:
        self._ti_id = ti_id

    def __eq__(self, other: object) -> bool:
        # Accessors are interchangeable exactly when they target the same task instance.
        return isinstance(other, TaskStateAccessor) and self._ti_id == other._ti_id

    def __hash__(self) -> int:
        return hash(self._ti_id)

    def __repr__(self) -> str:
        return f"<TaskStateAccessor ti_id={self._ti_id}>"

    # TODO: ``__getattr__`` support for Jinja access such as ``{{ task_state.job_id }}``
    # is not implemented yet because it is unclear whether task state values will be
    # used in templates.

    def get(self, key: str) -> str | None:
        """
        Fetch the value stored under ``key``.

        :return: the stored value, or ``None`` if the key does not exist.
        :raises AirflowRuntimeError: if the supervisor reports an error other than
            ``TASK_STATE_NOT_FOUND``.
        """
        from airflow.sdk.execution_time.comms import ErrorResponse, GetTaskState, TaskStateResult
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = GetTaskState(ti_id=self._ti_id, key=key)
        reply = SUPERVISOR_COMMS.send(request)
        if isinstance(reply, ErrorResponse):
            # "Not found" is the normal signal for an unset key; anything else is a real failure.
            if reply.error == ErrorType.TASK_STATE_NOT_FOUND:
                return None
            raise AirflowRuntimeError(reply)
        return reply.value if isinstance(reply, TaskStateResult) else None

    def set(self, key: str, value: str) -> None:
        """Store ``value`` under ``key``, replacing any existing value. The supervisor's reply is not inspected."""
        from airflow.sdk.execution_time.comms import SetTaskState
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = SetTaskState(ti_id=self._ti_id, key=key, value=value)
        SUPERVISOR_COMMS.send(request)

    def delete(self, key: str) -> None:
        """Remove a single key; removing a key that does not exist is a no-op."""
        from airflow.sdk.execution_time.comms import DeleteTaskState
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = DeleteTaskState(ti_id=self._ti_id, key=key)
        SUPERVISOR_COMMS.send(request)

    def clear(self, all_map_indices: bool = False) -> None:
        """
        Remove every key belonging to this task instance.

        :param all_map_indices: when ``True``, wipe the state of every mapped
            instance of the task (a fleet-wide reset) instead of only this task
            instance's own state.
        """
        from airflow.sdk.execution_time.comms import ClearTaskState
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        request = ClearTaskState(ti_id=self._ti_id, all_map_indices=all_map_indices)
        SUPERVISOR_COMMS.send(request)
+
+
class AssetStateAccessor:
    """
    Accessor for asset state scoped to a single asset.

    Obtained via ``context['asset_state'][MY_ASSET]`` or, as sugar for single-inlet
    tasks, directly as ``context['asset_state']``.

    The asset is addressed by ``name`` or by ``uri``. Supervisor requests use the
    name-keyed message variant when a non-empty ``name`` is set, and fall back to
    the uri-keyed variant otherwise.
    """

    def __init__(self, *, name: str | None = None, uri: str | None = None) -> None:
        self._name = name
        self._uri = uri

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, AssetStateAccessor):
            return False
        return self._name == other._name and self._uri == other._uri

    def __hash__(self) -> int:
        return hash((self._name, self._uri))

    def __repr__(self) -> str:
        if self._name is not None:
            return f"<AssetStateAccessor name={self._name!r}>"
        return f"<AssetStateAccessor uri={self._uri!r}>"

    def _build_msg(self, by_name: type, by_uri: type, **fields: str) -> Any:
        """
        Instantiate the name- or uri-keyed variant of a supervisor message.

        :param by_name: message class used when this accessor has a non-empty ``name``.
        :param by_uri: message class used when it only has a non-empty ``uri``.
        :param fields: remaining message fields (e.g. ``key``, ``value``).
        :raises ValueError: if neither ``name`` nor ``uri`` is set.
        """
        # Truthiness (not ``is not None``) is deliberate: an empty name falls
        # back to the uri, matching the dispatch every public method relies on.
        if self._name:
            return by_name(name=self._name, **fields)
        if self._uri:
            return by_uri(uri=self._uri, **fields)
        raise ValueError("Either `name` or `uri` must be provided")

    def get(self, key: str) -> str | None:
        """
        Return the stored value, or ``None`` if the key does not exist.

        :raises ValueError: if neither ``name`` nor ``uri`` is set.
        :raises AirflowRuntimeError: if the supervisor reports an error other than
            ``ASSET_STATE_NOT_FOUND``.
        """
        from airflow.sdk.execution_time.comms import (
            AssetStateResult,
            ErrorResponse,
            GetAssetStateByName,
            GetAssetStateByUri,
        )
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg = self._build_msg(GetAssetStateByName, GetAssetStateByUri, key=key)
        resp = SUPERVISOR_COMMS.send(msg)
        # A missing key is reported as ASSET_STATE_NOT_FOUND and maps to None.
        if isinstance(resp, ErrorResponse) and resp.error != ErrorType.ASSET_STATE_NOT_FOUND:
            raise AirflowRuntimeError(resp)
        if isinstance(resp, AssetStateResult):
            return resp.value
        return None

    def set(self, key: str, value: str) -> None:
        """
        Write or overwrite the value for the given key.

        :raises ValueError: if neither ``name`` nor ``uri`` is set.
        """
        from airflow.sdk.execution_time.comms import SetAssetStateByName, SetAssetStateByUri
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg = self._build_msg(SetAssetStateByName, SetAssetStateByUri, key=key, value=value)
        SUPERVISOR_COMMS.send(msg)

    def delete(self, key: str) -> None:
        """
        Delete a single key. No-op if the key does not exist.

        :raises ValueError: if neither ``name`` nor ``uri`` is set.
        """
        from airflow.sdk.execution_time.comms import DeleteAssetStateByName, DeleteAssetStateByUri
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg = self._build_msg(DeleteAssetStateByName, DeleteAssetStateByUri, key=key)
        SUPERVISOR_COMMS.send(msg)

    def clear(self) -> None:
        """
        Delete all state keys for this asset.

        :raises ValueError: if neither ``name`` nor ``uri`` is set.
        """
        from airflow.sdk.execution_time.comms import ClearAssetStateByName, ClearAssetStateByUri
        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS

        msg = self._build_msg(ClearAssetStateByName, ClearAssetStateByUri)
        SUPERVISOR_COMMS.send(msg)
+
+
class AssetStateAccessors:
    """
    Mapping of asset state accessors for all concrete inlets of a task.

    Available as ``context['asset_state']``. Subscript by asset to get a per-asset
    accessor as: ``context['asset_state'][MY_ASSET].get('watermark')``.

    For tasks with exactly one concrete inlet, the accessor methods (``get``, ``set``,
    ``delete``, ``clear``) can be called directly without subscripting.
    """

    def __init__(self, inlets: list) -> None:
        # Asset / AssetNameRef inlets are keyed by name, AssetUriRef inlets by URI;
        # inlets of any other type are skipped.
        self._by_name: dict[str, AssetStateAccessor] = {}
        self._by_uri: dict[str, AssetStateAccessor] = {}

        for inlet in inlets:
            if isinstance(inlet, (Asset, AssetNameRef)):
                self._by_name[inlet.name] = AssetStateAccessor(name=inlet.name)
            elif isinstance(inlet, AssetUriRef):
                self._by_uri[inlet.uri] = AssetStateAccessor(uri=inlet.uri)

        # Count of distinct accessors; repeated names/URIs collapse into one entry.
        self._total = len(self._by_name) + len(self._by_uri)

    def __getitem__(self, key: Asset | AssetNameRef | AssetUriRef) -> AssetStateAccessor:
        """
        Return the accessor for ``key``.

        :raises KeyError: if ``key`` is not one of this task's inlets.
        :raises TypeError: if ``key`` is not an Asset, AssetNameRef, or AssetUriRef.
        """
        if isinstance(key, (Asset, AssetNameRef)):
            table, lookup = self._by_name, key.name
        elif isinstance(key, AssetUriRef):
            table, lookup = self._by_uri, key.uri
        else:
            raise TypeError(f"Expected Asset, AssetNameRef, or AssetUriRef; got {type(key).__name__}")
        try:
            return table[lookup]
        except KeyError:
            # ``from None`` drops the internal dict-lookup KeyError from the traceback.
            raise KeyError(f"{key!r} is not in this task's inlets") from None

    def __contains__(self, key: object) -> bool:
        """Return whether ``key`` is one of this task's inlets; unsupported key types give ``False``."""
        # Without this, ``in`` falls back to the sequence protocol and calls
        # ``__getitem__(0)``, which raises a confusing TypeError.
        if isinstance(key, (Asset, AssetNameRef)):
            return key.name in self._by_name
        if isinstance(key, AssetUriRef):
            return key.uri in self._by_uri
        return False

    def _single_accessor(self) -> AssetStateAccessor:
        """Return the only accessor, or raise ValueError unless there is exactly one inlet."""
        if self._total != 1:
            raise ValueError(
                f"Task has {self._total} concrete inlets — use context['asset_state'][MY_ASSET] to specify which"
            )
        if self._by_name:
            return next(iter(self._by_name.values()))
        return next(iter(self._by_uri.values()))

    def get(self, key: str) -> str | None:
        """Return the stored value for the single-inlet task, or ``None`` if not found."""
        return self._single_accessor().get(key)

    def set(self, key: str, value: str) -> None:
        """Write or overwrite the value for the single-inlet task."""
        self._single_accessor().set(key, value)

    def delete(self, key: str) -> None:
        """Delete a single key for the single-inlet task."""
        self._single_accessor().delete(key)

    def clear(self) -> None:
        """Delete all state keys for the single-inlet task."""
        self._single_accessor().clear()

    def __repr__(self) -> str:
        # Prefix each entry so an asset name cannot be mistaken for a URI (and vice versa).
        parts = [f"name={n!r}" for n in self._by_name] + [f"uri={u!r}" for u in self._by_uri]
        return f"<AssetStateAccessors [{', '.join(parts)}]>"

Review Comment:
   Good catch! Changed repr to show `name` and `uri` prefixes



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to