Lee-W commented on code in PR #66160:
URL: https://github.com/apache/airflow/pull/66160#discussion_r3199201272


##########
task-sdk/tests/task_sdk/execution_time/test_task_runner.py:
##########
@@ -4868,3 +4887,218 @@ def test_dag_add_result(create_runtime_ti, mock_supervisor_comms):
             dag_result=True,
         )
     )
+
+
+class TestTaskInstanceStateOperations:
+    """Tests to verify that tasks can perform state operations (task / asset) 
via the supervisor."""
+
+    def test_task_can_set_and_get_state(self, create_runtime_ti, 
mock_supervisor_comms):
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                ts = context["task_state"]
+                ts.set("job_id", "spark_app_001")
+                return ts.get("job_id")
+
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+
+        run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+
+        mock_supervisor_comms.send.assert_any_call(
+            SetTaskState(ti_id=runtime_ti.id, key="job_id", 
value="spark_app_001")
+        )
+        
mock_supervisor_comms.send.assert_any_call(GetTaskState(ti_id=runtime_ti.id, 
key="job_id"))
+
+    def test_task_can_delete_state(self, create_runtime_ti, 
mock_supervisor_comms):
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                context["task_state"].delete("job_id")
+
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+
+        run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+
+        
mock_supervisor_comms.send.assert_any_call(DeleteTaskState(ti_id=runtime_ti.id, 
key="job_id"))
+
+    @pytest.mark.parametrize(
+        ("call_kwargs", "expected_flag"),
+        [
+            pytest.param({}, False, id="default"),
+            pytest.param({"all_map_indices": True}, True, id="fleet-wipe"),
+        ],
+    )
+    def test_task_can_clear_state(self, call_kwargs, expected_flag, 
create_runtime_ti, mock_supervisor_comms):
+        class MyOperator(BaseOperator):
+            def execute(self, context):
+                context["task_state"].clear(**call_kwargs)
+
+        task = MyOperator(task_id="t")
+        runtime_ti = create_runtime_ti(task=task)
+        run(runtime_ti, context=runtime_ti.get_template_context(), 
log=mock.MagicMock())
+        mock_supervisor_comms.send.assert_any_call(
+            ClearTaskState(ti_id=runtime_ti.id, all_map_indices=expected_flag)
+        )
+
+    @staticmethod
+    def _watcher_side_effect(msg=None, *args, **kwargs):
+        actual = msg or (args[0] if args else None)
+        if isinstance(actual, ValidateInletsAndOutlets):
+            return InactiveAssetsResult(inactive_assets=[])
+        if isinstance(actual, GetAssetByUri):
+            return AssetResult(name=actual.uri, uri=actual.uri, group="asset")

Review Comment:
   ```suggestion
               return AssetResult(name=actual.name, uri=actual.uri, group="asset")
   ```



##########
task-sdk/src/airflow/sdk/execution_time/context.py:
##########
@@ -406,6 +405,235 @@ def get(self, key, default: Any = NOTSET) -> Any:
             raise
 
 
+class TaskStateAccessor:
+    """Accessor for task state scoped to the current task instance. Available 
as ``context['task_state']`` at task execution time."""
+
+    def __init__(self, ti_id: UUID) -> None:
+        self._ti_id = ti_id
+
+    def __eq__(self, other: object) -> bool:
+        if not isinstance(other, TaskStateAccessor):
+            return False
+        return self._ti_id == other._ti_id
+
+    def __hash__(self) -> int:
+        return hash(self._ti_id)
+
+    def __repr__(self) -> str:
+        return f"<TaskStateAccessor ti_id={self._ti_id}>"
+
+    # TODO: ``__getattr__`` for jinja template access like ``{{ 
task_state.job_id }}``
+    # is not implemented yet cos it's unclear whether task state values will be
+    # used in templates.
+
+    def get(self, key: str) -> str | None:
+        """Return the stored value, or ``None`` if the key does not exist."""
+        from airflow.sdk.execution_time.comms import ErrorResponse, 
GetTaskState, TaskStateResult
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        resp = SUPERVISOR_COMMS.send(GetTaskState(ti_id=self._ti_id, key=key))
+        if isinstance(resp, ErrorResponse) and resp.error != 
ErrorType.TASK_STATE_NOT_FOUND:
+            raise AirflowRuntimeError(resp)
+        if isinstance(resp, TaskStateResult):
+            return resp.value
+        return None
+
+    def set(self, key: str, value: str) -> None:
+        """Write or overwrite the value for the given key."""
+        from airflow.sdk.execution_time.comms import SetTaskState
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        SUPERVISOR_COMMS.send(SetTaskState(ti_id=self._ti_id, key=key, 
value=value))
+
+    def delete(self, key: str) -> None:
+        """Delete a single key. No-op if the key does not exist."""
+        from airflow.sdk.execution_time.comms import DeleteTaskState
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        SUPERVISOR_COMMS.send(DeleteTaskState(ti_id=self._ti_id, key=key))
+
+    def clear(self, all_map_indices: bool = False) -> None:
+        """
+        Delete all keys for this task instance.
+
+        Pass ``all_map_indices=True`` to wipe state across every mapped
+        instance of the task (fleet-wide reset). Defaults to clearing only
+        this task instance's own state.
+        """
+        from airflow.sdk.execution_time.comms import ClearTaskState
+        from airflow.sdk.execution_time.task_runner import SUPERVISOR_COMMS
+
+        SUPERVISOR_COMMS.send(ClearTaskState(ti_id=self._ti_id, 
all_map_indices=all_map_indices))
+
+
+class AssetStateAccessor:
+    """
+    Accessor for asset state scoped to a single asset.
+
+    Obtained via ``context['asset_state'][MY_ASSET]`` or, as sugar for 
single-inlet
+    tasks, directly as ``context['asset_state']``.
+    """
+
+    def __init__(self, *, name: str | None = None, uri: str | None = None) -> 
None:
+        self._name = name
+        self._uri = uri

Review Comment:
   I kinda feel we might run into a lint or mypy issue by doing this 🤔 (as there won't be an else case), but it's probably not a bad idea to try.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to