This is an automated email from the ASF dual-hosted git repository.

amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 7d615cff9e9 AIP-103: Implement clear_on_success config to wipe task 
state on success (#66586)
7d615cff9e9 is described below

commit 7d615cff9e94ff73bbca18779050aafb52772561
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 14 14:58:34 2026 +0530

    AIP-103: Implement clear_on_success config to wipe task state on success 
(#66586)
---
 .../api_fastapi/core_api/routes/public/dag_run.py  |  3 +
 .../core_api/services/public/task_instances.py     | 39 ++++++++++
 .../execution_api/routes/task_instances.py         | 28 ++++++++
 .../src/airflow/config_templates/config.yml        | 14 ++++
 .../core_api/routes/public/test_task_instances.py  | 74 +++++++++++++++++++
 .../versions/head/test_task_instances.py           | 83 ++++++++++++++++++++++
 .../src/airflow/sdk/execution_time/task_runner.py  |  4 ++
 7 files changed, 245 insertions(+)

diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py 
b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
index 2e23ad1a171..b6896b4278b 100644
--- a/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
+++ b/airflow-core/src/airflow/api_fastapi/core_api/routes/public/dag_run.py
@@ -226,6 +226,9 @@ def patch_dag_run(
                     
get_listener_manager().hook.on_dag_run_success(dag_run=dag_run, msg="")
                 except Exception:
                     log.exception("error calling listener")
+
+            # TODO AIP-103: https://github.com/apache/airflow/issues/66755
+            # Handle clearing states for all task instances in a dagrun when 
cleared
             elif attr_value == DAGRunPatchStates.QUEUED:
                 set_dag_run_state_to_queued(dag=dag, run_id=dag_run.run_id, 
commit=True, session=session)
                 # Not notifying on queued - only notifying on RUNNING, this is 
happening in scheduler
diff --git 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
index 65e8260c2f4..f72dfc52fbb 100644
--- 
a/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/core_api/services/public/task_instances.py
@@ -28,6 +28,7 @@ from sqlalchemy import select, tuple_
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm.session import Session
 
+from airflow._shared.state import TaskScope
 from airflow.api_fastapi.app import get_auth_manager
 from airflow.api_fastapi.auth.managers.models.resource_details import 
DagAccessEntity, DagDetails
 from airflow.api_fastapi.common.dagbag import DagBagDep, 
get_latest_version_of_dag
@@ -47,15 +48,47 @@ from airflow.api_fastapi.core_api.datamodels.task_instances 
import (
 )
 from airflow.api_fastapi.core_api.security import GetUserDep
 from airflow.api_fastapi.core_api.services.public.common import BulkService
+from airflow.configuration import conf
 from airflow.listeners.listener import get_listener_manager
 from airflow.models.dag import DagModel
 from airflow.models.taskinstance import TaskInstance as TI
 from airflow.serialization.definitions.dag import SerializedDAG
+from airflow.state import get_state_backend
 from airflow.utils.state import TaskInstanceState
 
 log = structlog.get_logger(__name__)
 
 
+def _clear_task_state_on_success(tis: Sequence[TI], session: Session) -> None:
+    """Clear task state rows for each TI if clear_on_success is enabled."""
+    if not conf.getboolean("state_store", "clear_on_success", fallback=False):
+        return
+    backend = get_state_backend()
+    for ti in tis:
+        scope = TaskScope(
+            dag_id=ti.dag_id,
+            run_id=ti.run_id,
+            task_id=ti.task_id,
+            map_index=ti.map_index if ti.map_index is not None else -1,
+        )
+        try:
+            backend.clear(scope=scope, session=session)
+            log.info(
+                "Cleared task state on success",
+                dag_id=ti.dag_id,
+                run_id=ti.run_id,
+                task_id=ti.task_id,
+                map_index=ti.map_index,
+            )
+        except Exception:
+            log.warning(
+                "Failed to clear task state on success",
+                dag_id=ti.dag_id,
+                run_id=ti.run_id,
+                task_id=ti.task_id,
+            )
+
+
 def _validate_patch_task_instance_body(
     body: PatchTaskInstanceBody,
     update_mask: list[str] | None,
@@ -231,6 +264,9 @@ def _patch_task_instance_state(
             f"Task id {task_id} is already in {data['new_state']} state",
         )
 
+    if data["new_state"] == TaskInstanceState.SUCCESS:
+        _clear_task_state_on_success(updated_tis, session)
+
     _emit_state_listener_hooks(updated_tis, data["new_state"])
 
     return updated_tis
@@ -263,6 +299,9 @@ def _patch_task_group_state(
             f"All task instances in the group are already in 
{data['new_state']} state",
         )
 
+    if data["new_state"] == TaskInstanceState.SUCCESS:
+        _clear_task_state_on_success(updated_tis, session)
+
     _emit_state_listener_hooks(updated_tis, data["new_state"])
 
     return updated_tis
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index 13d8245621d..b9242583e69 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -41,6 +41,7 @@ from sqlalchemy.sql import select
 from structlog.contextvars import bind_contextvars
 
 from airflow._shared.observability.traces import override_ids
+from airflow._shared.state import TaskScope
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.auth.tokens import JWTGenerator
 from airflow.api_fastapi.common.dagbag import DagBagDep, 
get_latest_version_of_dag
@@ -67,6 +68,7 @@ from 
airflow.api_fastapi.execution_api.datamodels.taskinstance import (
 from airflow.api_fastapi.execution_api.datamodels.token import TIToken
 from airflow.api_fastapi.execution_api.deps import DepContainer
 from airflow.api_fastapi.execution_api.security import CurrentTIToken, 
ExecutionAPIRoute, require_auth
+from airflow.configuration import conf
 from airflow.exceptions import TaskNotFound
 from airflow.models.asset import AssetActive
 from airflow.models.dag import DagModel
@@ -78,6 +80,7 @@ from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
 from airflow.serialization.definitions.assets import SerializedAsset, 
SerializedAssetUniqueKey
+from airflow.state import get_state_backend
 from airflow.utils.sqlalchemy import get_dialect_name
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
 
@@ -460,6 +463,31 @@ def ti_update_state(
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, 
detail="Database error occurred"
         )
 
+    if updated_state == TaskInstanceState.SUCCESS:
+        if conf.getboolean("state_store", "clear_on_success"):
+            scope = TaskScope(
+                dag_id=dag_id,
+                run_id=run_id,
+                task_id=task_id,
+                map_index=map_index if map_index is not None else -1,
+            )
+            try:
+                get_state_backend().clear(scope, session=session)
+                log.info(
+                    "Cleared task state on success",
+                    dag_id=dag_id,
+                    run_id=run_id,
+                    task_id=task_id,
+                    map_index=map_index,
+                )
+            except Exception:
+                log.warning(
+                    "Failed to clear task state on success",
+                    dag_id=dag_id,
+                    run_id=run_id,
+                    task_id=task_id,
+                )
+
 
 def _emit_task_span(ti, state):
     # just to be safe
diff --git a/airflow-core/src/airflow/config_templates/config.yml 
b/airflow-core/src/airflow/config_templates/config.yml
index 4b183f9c2b4..03593ce4ba0 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -3025,6 +3025,20 @@ state_store:
       type: string
       example: "mypackage.state.CustomStateBackend"
       default: "airflow.state.metastore.MetastoreStateBackend"
+    clear_on_success:
+      description: |
+        If set to True, all task state keys for a task instance are 
automatically cleared
+        when that task instance moves to SUCCESS.
+
+        Defaults to False so that task state persists after success for 
observability —
+        operators and the UI can inspect what the task wrote (e.g. submitted 
job IDs,
+        advanced watermarks) after the run completes.
+        Consider setting to True if you do not need post-success visibility 
and want automatic
+        cleanup without waiting for the global retention period.
+      version_added: 3.3.0
+      type: boolean
+      example: "True"
+      default: "False"
     default_retention_days:
       description: |
         Number of days to retain task state after its last update.
diff --git 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
index 505a390f7bf..b53f2f71522 100644
--- 
a/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_task_instances.py
@@ -31,6 +31,7 @@ from fastapi.testclient import TestClient
 from sqlalchemy import delete, func, select, update
 from sqlalchemy.orm import joinedload
 
+from airflow._shared.state import TaskScope
 from airflow._shared.timezones.timezone import datetime
 from airflow.api_fastapi.auth.managers.simple.user import SimpleAuthManagerUser
 from airflow.dag_processing.bundles.manager import DagBundlesManager
@@ -41,12 +42,14 @@ from airflow.models import DagModel, DagRun, Log, 
TaskInstance
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagbundle import DagBundleModel
 from airflow.models.renderedtifields import RenderedTaskInstanceFields as RTIF
+from airflow.models.task_state import TaskStateModel
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.models.taskmap import TaskMap
 from airflow.models.team import Team
 from airflow.models.trigger import Trigger
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import BaseOperator, TaskGroup
+from airflow.state.metastore import MetastoreStateBackend
 from airflow.utils.platform import getuser
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
@@ -5160,6 +5163,77 @@ class TestPatchTaskInstance(TestTaskInstanceEndpoint):
         assert response.status_code == 409
         assert "Task id print_the_context is already in success state" in 
response.text
 
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "True"})
+    def test_patch_task_instance_to_success_clears_task_state(self, 
test_client, session):
+        """When clear_on_success=True, task_state rows are deleted after 
manual mark-as-success."""
+        self.create_task_instances(session)
+        ti = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.task_id == self.TASK_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).one()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.commit()
+
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
self.TASK_ID)).all()
+
+        test_client.patch(self.ENDPOINT_URL, json={"new_state": "success"})
+
+        session.expire_all()
+        assert not 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
self.TASK_ID)).all()
+
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "True"})
+    def test_patch_task_instance_to_failed_does_not_clear_task_state(self, 
test_client, session):
+        """Task state rows are preserved when manually marking a TI as 
FAILED."""
+        self.create_task_instances(session)
+        ti = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.task_id == self.TASK_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).one()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.commit()
+
+        test_client.patch(self.ENDPOINT_URL, json={"new_state": "failed"})
+
+        session.expire_all()
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
self.TASK_ID)).all()
+
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "False"})
+    def 
test_patch_task_instance_to_success_skips_clear_when_config_disabled(self, 
test_client, session):
+        """Task state rows are preserved on manual mark-as-success when 
clear_on_success=False."""
+        self.create_task_instances(session)
+        ti = session.scalars(
+            select(TaskInstance).where(
+                TaskInstance.dag_id == self.DAG_ID,
+                TaskInstance.task_id == self.TASK_ID,
+                TaskInstance.run_id == self.RUN_ID,
+            )
+        ).one()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.commit()
+
+        test_client.patch(self.ENDPOINT_URL, json={"new_state": "success"})
+
+        session.expire_all()
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
self.TASK_ID)).all()
+
 
 class TestPatchTaskInstanceDryRun(TestTaskInstanceEndpoint):
     ENDPOINT_URL = 
"/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context"
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index d26bcf7bfd8..6f19cff9389 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -36,6 +36,7 @@ from sqlalchemy.exc import SQLAlchemyError
 from sqlalchemy.orm import Session
 
 from airflow._shared.observability.traces import OverrideableRandomIdGenerator
+from airflow._shared.state import TaskScope
 from airflow._shared.timezones import timezone
 from airflow.api_fastapi.auth.tokens import JWTGenerator, JWTValidator
 from airflow.api_fastapi.execution_api.app import lifespan
@@ -47,10 +48,12 @@ from airflow.models import RenderedTaskInstanceFields, 
TaskReschedule, Trigger
 from airflow.models.asset import AssetActive, AssetAliasModel, AssetEvent, 
AssetModel
 from airflow.models.dag import DagModel
 from airflow.models.log import Log
+from airflow.models.task_state import TaskStateModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
 from airflow.sdk import Asset, TaskGroup, TriggerRule, task, task_group
+from airflow.state.metastore import MetastoreStateBackend
 from airflow.utils.state import DagRunState, State, TaskInstanceState, 
TerminalTIState
 
 from tests_common.test_utils.config import conf_vars
@@ -1891,6 +1894,86 @@ class TestTIUpdateState:
         ti1 = session.get(TaskInstance, ti1.id)
         assert ti1.state == State.FAILED
 
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "True"})
+    def test_ti_update_state_to_success_clears_task_state(self, client, 
session, create_task_instance):
+        """When clear_on_success=True, task_state rows are deleted after TI 
transitions to SUCCESS."""
+        ti = create_task_instance(
+            task_id="test_clear_on_success",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        backend.set(scope, "checkpoint", "step_3", session=session)
+        session.commit()
+
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
ti.task_id)).all()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={"state": "success", "end_date": 
DEFAULT_END_DATE.isoformat()},
+        )
+
+        assert response.status_code == 204
+        session.expire_all()
+        assert not 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
ti.task_id)).all()
+
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "True"})
+    def test_ti_update_state_to_failed_does_not_clear_task_state(self, client, 
session, create_task_instance):
+        """Task state rows are preserved when a TI transitions to FAILED."""
+        ti = create_task_instance(
+            task_id="test_no_clear_on_failed",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={"state": "failed", "end_date": DEFAULT_END_DATE.isoformat()},
+        )
+
+        assert response.status_code == 204
+        session.expire_all()
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
ti.task_id)).all()
+
+    @pytest.mark.db_test
+    @conf_vars({("state_store", "clear_on_success"): "False"})
+    def test_ti_update_state_to_success_skips_clear_when_config_disabled(
+        self, client, session, create_task_instance
+    ):
+        """Task state rows are preserved on SUCCESS when 
clear_on_success=False."""
+        ti = create_task_instance(
+            task_id="test_clear_disabled",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        backend = MetastoreStateBackend()
+        scope = TaskScope(dag_id=ti.dag_id, run_id=ti.run_id, 
task_id=ti.task_id, map_index=ti.map_index)
+        backend.set(scope, "job_id", "app_1234", session=session)
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={"state": "success", "end_date": 
DEFAULT_END_DATE.isoformat()},
+        )
+
+        assert response.status_code == 204
+        session.expire_all()
+        assert 
session.scalars(select(TaskStateModel).where(TaskStateModel.task_id == 
ti.task_id)).all()
+
 
 class TestTISkipDownstream:
     def setup_method(self):
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 2cb336c8ea7..ef797ca558d 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -1440,6 +1440,10 @@ def _handle_current_task_success(
 
     task_outlets = list(_build_asset_profiles(ti.task.outlets))
     outlet_events = list(_serialize_outlet_events(context["outlet_events"]))
+
+    if conf.getboolean("state_store", "clear_on_success"):
+        log.info("Task state will be cleared by the server because 
clear_on_success is enabled.")
+
     msg = SucceedTask(
         end_date=end_date,
         task_outlets=task_outlets,

Reply via email to