This is an automated email from the ASF dual-hosted git repository.
amoghrajesh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 0cf6462459e AIP-103: Adding periodic task state garbage collection and
retention support (#66463)
0cf6462459e is described below
commit 0cf6462459e6f1dae463aedfa38b52fef31dbbb3
Author: Amogh Desai <[email protected]>
AuthorDate: Thu May 14 11:58:46 2026 +0530
AIP-103: Adding periodic task state garbage collection and retention
support (#66463)
---
airflow-core/src/airflow/cli/cli_config.py | 19 +++
.../airflow/cli/commands/state_store_command.py | 49 ++++++++
.../src/airflow/config_templates/config.yml | 18 +++
.../src/airflow/jobs/scheduler_job_runner.py | 32 ++++-
..._3_3_0_add_task_state_and_asset_state_tables.py | 7 +-
airflow-core/src/airflow/models/task_state.py | 21 +++-
airflow-core/src/airflow/state/metastore.py | 73 ++++++++++-
.../unit/cli/commands/test_state_store_command.py | 65 ++++++++++
airflow-core/tests/unit/state/test_metastore.py | 140 ++++++++++++++++++++-
shared/state/src/airflow_shared/state/__init__.py | 9 ++
10 files changed, 421 insertions(+), 12 deletions(-)
diff --git a/airflow-core/src/airflow/cli/cli_config.py
b/airflow-core/src/airflow/cli/cli_config.py
index 4c44ab39d67..81b9dcf0600 100644
--- a/airflow-core/src/airflow/cli/cli_config.py
+++ b/airflow-core/src/airflow/cli/cli_config.py
@@ -1531,6 +1531,20 @@ TEAMS_COMMANDS = (
args=(ARG_VERBOSE,),
),
)
+STATE_STORE_COMMANDS = (
+ ActionCommand(
+ name="cleanup-task-states",
+ help="Remove expired task state rows (MetastoreStateBackend only)",
+ description=(
+ "Reads [state_store] default_retention_days from config and
deletes task_state rows "
+ "older than the configured threshold. Only applies when
MetastoreStateBackend is configured; "
+ "custom backends are skipped. Use --dry-run to preview without
deleting."
+ ),
+
func=lazy_load_command("airflow.cli.commands.state_store_command.cleanup_task_states"),
+ args=(ARG_DB_DRY_RUN, ARG_VERBOSE),
+ ),
+)
+
DB_COMMANDS = (
ActionCommand(
name="check-migrations",
@@ -2115,6 +2129,11 @@ core_commands: list[CLICommand] = [
help="Display providers",
subcommands=PROVIDERS_COMMANDS,
),
+ GroupCommand(
+ name="state-store",
+ help="Manage task and asset state storage",
+ subcommands=STATE_STORE_COMMANDS,
+ ),
ActionCommand(
name="rotate-fernet-key",
func=lazy_load_command("airflow.cli.commands.rotate_fernet_key_command.rotate_fernet_key"),
diff --git a/airflow-core/src/airflow/cli/commands/state_store_command.py
b/airflow-core/src/airflow/cli/commands/state_store_command.py
new file mode 100644
index 00000000000..52bd0952561
--- /dev/null
+++ b/airflow-core/src/airflow/cli/commands/state_store_command.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+
+from airflow.state import get_state_backend
+from airflow.state.metastore import MetastoreStateBackend
+
+log = logging.getLogger(__name__)
+
+# Other state operations (list, get, delete per key) will be added here in the
future.
+
+
+def cleanup_task_states(args) -> None:
+ """Remove expired task state rows (MetastoreStateBackend only)."""
+ backend = get_state_backend()
+
+ if not isinstance(backend, MetastoreStateBackend):
+ print("Custom backend configured — skipping cleanup (not supported).")
+ return
+
+ if args.dry_run:
+ summary = backend._summary_dry_run()
+ expired = summary["expired"]
+ if not expired:
+ print("Nothing to delete.")
+ return
+ print(f"Would delete {len(expired)} task state row(s):\n")
+ for dag_id, run_id, task_id, map_index, key in expired:
+ print(f" Dag {dag_id!r}, run {run_id!r}, task {task_id!r},
map_index {map_index!r}, key {key!r}")
+ return
+
+ log.info("Running task state cleanup")
+ backend.cleanup()
diff --git a/airflow-core/src/airflow/config_templates/config.yml
b/airflow-core/src/airflow/config_templates/config.yml
index 8d5d6e5fd26..4b183f9c2b4 100644
--- a/airflow-core/src/airflow/config_templates/config.yml
+++ b/airflow-core/src/airflow/config_templates/config.yml
@@ -3025,6 +3025,24 @@ state_store:
type: string
example: "mypackage.state.CustomStateBackend"
default: "airflow.state.metastore.MetastoreStateBackend"
+ default_retention_days:
+ description: |
+ Number of days to retain task state rows after their last update.
+ Rows older than this are removed when cleanup is triggered.
+ This config does not affect asset_state rows.
+ Set to 0 to disable time-based cleanup entirely.
+ version_added: 3.3.0
+ type: integer
+ example: "7"
+ default: "30"
+ state_cleanup_batch_size:
+ description: |
+ Number of rows deleted per batch during cleanup. Defaults to 0 (no
batching).
+ Tune this on deployments with large task_state tables to improve
performance per transaction.
+ version_added: 3.3.0
+ type: integer
+ example: "10000"
+ default: "0"
profiling:
description: |
diff --git a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
index 1a3f55b7f6f..9a650b110c9 100644
--- a/airflow-core/src/airflow/jobs/scheduler_job_runner.py
+++ b/airflow-core/src/airflow/jobs/scheduler_job_runner.py
@@ -33,7 +33,20 @@ from functools import lru_cache, partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, cast
-from sqlalchemy import CTE, and_, case, delete, exists, func, inspect, or_,
select, text, tuple_, update
+from sqlalchemy import (
+ CTE,
+ and_,
+ case,
+ delete,
+ exists,
+ func,
+ inspect,
+ or_,
+ select,
+ text,
+ tuple_,
+ update,
+)
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.orm import joinedload, lazyload, load_only, make_transient,
selectinload
from sqlalchemy.sql import expression
@@ -70,6 +83,7 @@ from airflow.models.asset import (
TaskInletAssetReference,
TaskOutletAssetReference,
)
+from airflow.models.asset_state import AssetStateModel
from airflow.models.backfill import Backfill, BackfillDagRun
from airflow.models.callback import Callback, CallbackType, ExecutorCallback
from airflow.models.dag import DagModel
@@ -3096,6 +3110,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
self._orphan_unreferenced_assets(orphan_query, session=session)
self._activate_referenced_assets(activate_query, session=session)
+ self._cleanup_orphaned_asset_state(session=session)
@staticmethod
def _orphan_unreferenced_assets(assets_query: CTE, *, session: Session) ->
None:
@@ -3204,6 +3219,21 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)
+ @staticmethod
+ def _cleanup_orphaned_asset_state(*, session: Session) -> None:
+ """
+ Delete asset_state rows for assets no longer active in any Dag.
+
+ When _orphan_unreferenced_assets removes an asset from asset_active,
its
+ asset_state rows become unreachable — no task can write to them
anymore.
+ This runs in the same pass as asset orphaning to keep the table clean.
+ """
+ active_asset_ids = select(AssetModel.id).join(
+ AssetActive,
+ (AssetActive.name == AssetModel.name) & (AssetActive.uri ==
AssetModel.uri),
+ )
+
session.execute(delete(AssetStateModel).where(AssetStateModel.asset_id.not_in(active_asset_ids)))
+
def _executor_to_workloads(
self,
workloads: Iterable[SchedulerWorkload],
diff --git
a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
index 7f852d05c6c..e64f80a05b1 100644
---
a/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
+++
b/airflow-core/src/airflow/migrations/versions/0112_3_3_0_add_task_state_and_asset_state_tables.py
@@ -57,6 +57,7 @@ def upgrade():
)
op.create_table(
"task_state",
+ sa.Column("id", sa.Integer(), nullable=False, autoincrement=True),
sa.Column("dag_run_id", sa.Integer(), nullable=False),
sa.Column("task_id", StringID(), nullable=False),
sa.Column("map_index", sa.Integer(), server_default="-1",
nullable=False),
@@ -65,20 +66,24 @@ def upgrade():
sa.Column("run_id", StringID(), nullable=False),
sa.Column("value", sa.Text().with_variant(mysql.MEDIUMTEXT(),
"mysql"), nullable=False),
sa.Column("updated_at", UtcDateTime(), nullable=False),
+ sa.Column("expires_at", UtcDateTime(), nullable=True),
sa.ForeignKeyConstraint(
["dag_run_id"], ["dag_run.id"], name="task_state_dag_run_fkey",
ondelete="CASCADE"
),
- sa.PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key",
name="task_state_pkey"),
+ sa.PrimaryKeyConstraint("id", name="task_state_pkey"),
+ sa.UniqueConstraint("dag_run_id", "task_id", "map_index", "key",
name="task_state_uq"),
)
with op.batch_alter_table("task_state", schema=None) as batch_op:
batch_op.create_index(
"idx_task_state_lookup", ["dag_id", "run_id", "task_id",
"map_index"], unique=False
)
+ batch_op.create_index("idx_task_state_expires_at", ["expires_at"],
unique=False)
def downgrade():
"""Unapply add task_state and asset_state tables."""
with op.batch_alter_table("task_state", schema=None) as batch_op:
+ batch_op.drop_index("idx_task_state_expires_at")
batch_op.drop_index("idx_task_state_lookup")
op.drop_table("task_state")
diff --git a/airflow-core/src/airflow/models/task_state.py
b/airflow-core/src/airflow/models/task_state.py
index dbc17e3b069..72a7624eddd 100644
--- a/airflow-core/src/airflow/models/task_state.py
+++ b/airflow-core/src/airflow/models/task_state.py
@@ -19,7 +19,7 @@ from __future__ import annotations
from datetime import datetime
-from sqlalchemy import ForeignKeyConstraint, Index, Integer,
PrimaryKeyConstraint, String, Text
+from sqlalchemy import ForeignKeyConstraint, Index, Integer, String, Text,
UniqueConstraint
from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.orm import Mapped, mapped_column
@@ -39,19 +39,27 @@ class TaskStateModel(Base):
__tablename__ = "task_state"
- dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False,
primary_key=True)
- task_id: Mapped[str] = mapped_column(StringID(), nullable=False,
primary_key=True)
- map_index: Mapped[int] = mapped_column(Integer, primary_key=True,
nullable=False, server_default="-1")
- key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS),
nullable=False, primary_key=True)
+ id: Mapped[int] = mapped_column(Integer, primary_key=True,
autoincrement=True)
+
+ dag_run_id: Mapped[int] = mapped_column(Integer, nullable=False)
+ task_id: Mapped[str] = mapped_column(StringID(), nullable=False)
+ map_index: Mapped[int] = mapped_column(Integer, nullable=False,
server_default="-1")
+ key: Mapped[str] = mapped_column(String(512, **COLLATION_ARGS),
nullable=False)
dag_id: Mapped[str] = mapped_column(StringID(), nullable=False)
run_id: Mapped[str] = mapped_column(StringID(), nullable=False)
value: Mapped[str] = mapped_column(Text().with_variant(MEDIUMTEXT,
"mysql"), nullable=False)
updated_at: Mapped[datetime] = mapped_column(UtcDateTime,
default=timezone.utcnow, nullable=False)
+ # Expiry timestamp used by garbage collection. Cleanup deletes this row when
+ # expires_at < now(), even if updated_at is recent. NULL means the row never
+ # expires: cleanup only matches rows with expires_at < now(), and NULL is written when default_retention_days is 0.
+ # Populated via task_state.set(retention_days=N) for keys that should expire differently
+ # than the deployment-wide default.
+ expires_at: Mapped[datetime | None] = mapped_column(UtcDateTime,
nullable=True)
__table_args__ = (
- PrimaryKeyConstraint("dag_run_id", "task_id", "map_index", "key",
name="task_state_pkey"),
+ UniqueConstraint("dag_run_id", "task_id", "map_index", "key",
name="task_state_uq"),
ForeignKeyConstraint(
["dag_run_id"],
["dag_run.id"],
@@ -59,4 +67,5 @@ class TaskStateModel(Base):
ondelete="CASCADE",
),
Index("idx_task_state_lookup", "dag_id", "run_id", "task_id",
"map_index"),
+ Index("idx_task_state_expires_at", "expires_at"),
)
diff --git a/airflow-core/src/airflow/state/metastore.py
b/airflow-core/src/airflow/state/metastore.py
index 31b4de3158f..f58c69f5808 100644
--- a/airflow-core/src/airflow/state/metastore.py
+++ b/airflow-core/src/airflow/state/metastore.py
@@ -19,17 +19,20 @@ from __future__ import annotations
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
+from datetime import datetime, timedelta
from typing import TYPE_CHECKING
+import structlog
from sqlalchemy import delete, select
from airflow._shared.state import AssetScope, BaseStateBackend, StateScope,
TaskScope
from airflow._shared.timezones import timezone
+from airflow.configuration import conf
from airflow.models.asset_state import AssetStateModel
from airflow.models.dagrun import DagRun
from airflow.models.task_state import TaskStateModel
from airflow.typing_compat import assert_never
-from airflow.utils.session import NEW_SESSION, create_session_async,
provide_session
+from airflow.utils.session import NEW_SESSION, create_session,
create_session_async, provide_session
from airflow.utils.sqlalchemy import get_dialect_name
if TYPE_CHECKING:
@@ -40,6 +43,21 @@ if TYPE_CHECKING:
from sqlalchemy.orm import Session
+log = structlog.get_logger(__name__)
+
+
+def _compute_expires_at(now: datetime) -> datetime | None:
+ """
+ Return the expiry timestamp for a new task state row based on config.
+
+ Returns None if default_retention_days is 0 or negative (never expires).
+ """
+ retention_days = conf.getint("state_store", "default_retention_days")
+ if retention_days <= 0:
+ return None
+ return now + timedelta(days=retention_days)
+
+
@asynccontextmanager
async def _async_session(session: AsyncSession | None) ->
AsyncGenerator[AsyncSession, None]:
"""Use provided async session or create a new one."""
@@ -200,6 +218,7 @@ class MetastoreStateBackend(BaseStateBackend):
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r}
run_id={scope.run_id!r}")
now = timezone.utcnow()
+ expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
@@ -209,13 +228,14 @@ class MetastoreStateBackend(BaseStateBackend):
key=key,
value=value,
updated_at=now,
+ expires_at=expires_at,
)
stmt = _build_upsert_stmt(
get_dialect_name(session),
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
- dict(value=value, updated_at=now),
+ dict(value=value, updated_at=now, expires_at=expires_at),
)
session.execute(stmt)
@@ -276,6 +296,51 @@ class MetastoreStateBackend(BaseStateBackend):
)
)
+ def cleanup(self) -> None:
+ """
+ Remove expired task state rows.
+
+ ``expires_at`` is set at write time on every ``set()`` call, so
cleanup is a single
+ ``WHERE expires_at < now()`` pass. Rows with ``expires_at=NULL``
(default_retention_days=0)
+ are never deleted. Batching is configurable via ``[state_store]
state_cleanup_batch_size``.
+ """
+ batch_size = conf.getint("state_store", "state_cleanup_batch_size")
+ now = timezone.utcnow()
+
+ def _delete_batched(where_clause) -> int:
+ total = 0
+ with create_session() as session:
+ while True:
+ id_query = select(TaskStateModel.id).where(where_clause)
+ if batch_size > 0:
+ id_query = id_query.limit(batch_size)
+ ids = session.scalars(id_query).all()
+ if not ids:
+ break
+
session.execute(delete(TaskStateModel).where(TaskStateModel.id.in_(ids)))
+ session.commit()
+ total += len(ids)
+ if batch_size <= 0 or len(ids) < batch_size:
+ break
+ return total
+
+ deleted = _delete_batched(TaskStateModel.expires_at < now)
+ log.info("Deleted expired task_state rows", rows_deleted=deleted)
+
+ def _summary_dry_run(self) -> dict[str, list]:
+ """Return rows that would be deleted by cleanup() without deleting
anything."""
+ now = timezone.utcnow()
+ cols = (
+ TaskStateModel.dag_id,
+ TaskStateModel.run_id,
+ TaskStateModel.task_id,
+ TaskStateModel.map_index,
+ TaskStateModel.key,
+ )
+ with create_session() as session:
+ expired =
session.execute(select(*cols).where(TaskStateModel.expires_at < now)).all()
+ return {"expired": list(expired)}
+
async def _aget_task_state(self, scope: TaskScope, key: str, *, session:
AsyncSession) -> str | None:
row = await session.scalar(
select(TaskStateModel).where(
@@ -300,6 +365,7 @@ class MetastoreStateBackend(BaseStateBackend):
if dag_run_id is None:
raise ValueError(f"No DagRun found for dag_id={scope.dag_id!r}
run_id={scope.run_id!r}")
now = timezone.utcnow()
+ expires_at = _compute_expires_at(now)
values = dict(
dag_run_id=dag_run_id,
dag_id=scope.dag_id,
@@ -309,6 +375,7 @@ class MetastoreStateBackend(BaseStateBackend):
key=key,
value=value,
updated_at=now,
+ expires_at=expires_at,
)
# get_dialect_name expects a sync Session; sync_session is the
underlying Session the async wrapper delegates to
stmt = _build_upsert_stmt(
@@ -316,7 +383,7 @@ class MetastoreStateBackend(BaseStateBackend):
TaskStateModel,
["dag_run_id", "task_id", "map_index", "key"],
values,
- dict(value=value, updated_at=now),
+ dict(value=value, updated_at=now, expires_at=expires_at),
)
await session.execute(stmt)
diff --git a/airflow-core/tests/unit/cli/commands/test_state_store_command.py
b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
new file mode 100644
index 00000000000..e4b44eee13f
--- /dev/null
+++ b/airflow-core/tests/unit/cli/commands/test_state_store_command.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from argparse import Namespace
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.cli.commands.state_store_command import cleanup_task_states
+from airflow.state.metastore import MetastoreStateBackend
+
+pytestmark = pytest.mark.db_test
+
+
+class TestStateStoreCleanupCommand:
+ def test_cleanup_calls_backend(self):
+ args = Namespace(dry_run=False, verbose=False)
+ backend = MetastoreStateBackend()
+ with (
+
mock.patch("airflow.cli.commands.state_store_command.get_state_backend",
return_value=backend),
+ patch.object(backend, "cleanup"),
+ ):
+ cleanup_task_states(args)
+
+ backend.cleanup.assert_called_once_with()
+
+ def test_dry_run_does_not_call_backend(self, capsys):
+ args = Namespace(dry_run=True, verbose=False)
+ backend = MetastoreStateBackend()
+ with (
+
mock.patch("airflow.cli.commands.state_store_command.get_state_backend",
return_value=backend),
+ patch.object(backend, "_summary_dry_run", return_value={"expired":
[]}),
+ ):
+ cleanup_task_states(args)
+
+ captured = capsys.readouterr()
+ assert "Nothing to delete" in captured.out
+
+ def test_custom_backend_is_skipped(self, capsys):
+ args = Namespace(dry_run=False, verbose=False)
+ custom_backend = MagicMock(spec=[])
+ with mock.patch(
+ "airflow.cli.commands.state_store_command.get_state_backend",
return_value=custom_backend
+ ):
+ cleanup_task_states(args)
+
+ captured = capsys.readouterr()
+ assert "Custom backend configured" in captured.out
+ assert not hasattr(custom_backend, "cleanup") or not
custom_backend.cleanup.called
diff --git a/airflow-core/tests/unit/state/test_metastore.py
b/airflow-core/tests/unit/state/test_metastore.py
index dfd154cc92a..d9e1ff33afd 100644
--- a/airflow-core/tests/unit/state/test_metastore.py
+++ b/airflow-core/tests/unit/state/test_metastore.py
@@ -17,13 +17,18 @@
# under the License.
from __future__ import annotations
+from contextlib import contextmanager
+from datetime import timedelta
from typing import TYPE_CHECKING
+from unittest.mock import patch
import pytest
-from sqlalchemy import select
+from sqlalchemy import Delete, select
from airflow._shared.timezones import timezone
+from airflow.configuration import conf
from airflow.models.asset import AssetModel
+from airflow.models.asset_state import AssetStateModel
from airflow.models.dagrun import DagRun, DagRunType
from airflow.models.task_state import TaskStateModel
from airflow.state import AssetScope, TaskScope, resolve_state_backend
@@ -234,6 +239,113 @@ class TestMetastoreStateBackendTaskScope:
assert backend.get(scope0, "job_id", session=session) is None
assert backend.get(scope1, "job_id", session=session) is None
+ def test_set_populates_expires_at(
+ self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+ ):
+ """set() always populates expires_at so cleanup has a single pass."""
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ backend.set(scope, "job_id", "app_1234", session=session)
+ session.flush()
+
+ row = session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "job_id"))
+ assert row is not None
+ assert row.expires_at is not None
+ assert row.expires_at > row.updated_at
+
+ def test_cleanup_removes_expired_rows(
+ self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+ ):
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ backend.set(scope, "old_key", "old_value", session=session)
+ backend.set(scope, "new_key", "new_value", session=session)
+ session.flush()
+
+ # Backdate expires_at on old_key to simulate it having expired
+ old_row = session.scalar(
+ select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID,
TaskStateModel.key == "old_key")
+ )
+ assert old_row is not None
+ old_row.expires_at = timezone.utcnow() - timedelta(hours=1)
+ session.flush()
+ session.commit()
+
+ backend.cleanup()
+
+ session.expire_all()
+ assert session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "old_key")) is None
+ assert session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "new_key")) is not None
+
+ def test_cleanup_removes_expires_at_rows(
+ self, session: Session, backend: MetastoreStateBackend, dag_run: DagRun
+ ):
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ backend.set(scope, "short_lived", "value", session=session)
+ session.flush()
+
+ row = session.scalar(
+ select(TaskStateModel).where(TaskStateModel.dag_id == DAG_ID,
TaskStateModel.key == "short_lived")
+ )
+ assert row is not None
+ row.expires_at = timezone.utcnow() - timedelta(hours=1)
+ session.flush()
+ session.commit()
+
+ backend.cleanup()
+
+ session.expire_all()
+
+ # cleaned up via expires_at, even though updated_at is recent
+ assert session.scalar(select(TaskStateModel).where(TaskStateModel.key
== "short_lived")) is None
+
+ @conf_vars({("state_store", "state_cleanup_batch_size"): "2"})
+ def test_cleanup_batches_deletes(self, session: Session, backend:
MetastoreStateBackend, dag_run: DagRun):
+ """cleanup() issues one DELETE per batch, not one DELETE for all rows
at once.
+
+ Verifying this is not straightforward because cleanup() creates its own internal session,
+ which cannot simply be inspected from outside, so the test works as follows:
+
+ 1. Patch `create_session` in the metastore module with a thin wrapper
(`tracking_cs`) that
+ yields the real session but replaces `session.execute` with a spy.
+ 2. The spy checks whether the statement being executed is a sqla
Delete object and
+ records it if so.
+ 3. After cleanup() returns, we assert that exactly ceil(<number of
rows>/<batch size>) DELETE statements were recorded.
+ """
+ import airflow.state.metastore as metastore_mod
+
+ scope = TaskScope(dag_id=DAG_ID, run_id=RUN_ID, task_id=TASK_ID)
+ for key in ("k1", "k2", "k3", "k4", "k5"):
+ backend.set(scope, key, "v", session=session)
+ session.flush()
+
+ session.execute(
+
TaskStateModel.__table__.update().values(expires_at=timezone.utcnow() -
timedelta(hours=1))
+ )
+ session.commit()
+
+ deletes = []
+ original_cs = metastore_mod.create_session
+
+ @contextmanager
+ def tracking_cs(*args, **kwargs):
+ with original_cs(*args, **kwargs) as s:
+ orig_execute = s.execute
+
+ def tracked(stmt, *a, **kw):
+ if isinstance(stmt, Delete):
+ deletes.append(stmt)
+ return orig_execute(stmt, *a, **kw)
+
+ s.execute = tracked
+ yield s
+
+ with patch.object(metastore_mod, "create_session",
side_effect=tracking_cs):
+ backend.cleanup()
+
+ session.expire_all()
+
+ # batch_size=2, 5 rows -> delete runs 3 times (2+2+1)
+ assert len(deletes) == 3
+
class TestMetastoreStateBackendAssetScope:
def test_get_returns_none_for_missing_key(
@@ -306,6 +418,19 @@ class TestMetastoreStateBackendAssetScope:
assert backend.get(scope2, "watermark", session=session) is None
+ def test_cleanup_does_not_touch_asset_state(
+ self, session: Session, backend: MetastoreStateBackend, asset:
AssetModel
+ ):
+ scope = AssetScope(asset_id=asset.id)
+ backend.set(scope, "watermark", "2026-01-01", session=session)
+ session.flush()
+ session.commit()
+
+ backend.cleanup()
+
+ session.expire_all()
+ assert
session.scalar(select(AssetStateModel).where(AssetStateModel.asset_id ==
asset.id)) is not None
+
@pytest.mark.asyncio(loop_scope="class")
class TestMetastoreStateBackendAsync:
@@ -390,6 +515,19 @@ class TestMetastoreStateBackendAsync:
assert result == "app_with_session"
+class TestStateStoreConfig:
+ def test_defaults(self):
+ assert conf.getint("state_store", "default_retention_days") == 30
+ assert conf.getint("state_store", "state_cleanup_batch_size") == 0
+
+ @conf_vars(
+ {("state_store", "default_retention_days"): "7", ("state_store",
"state_cleanup_batch_size"): "50"}
+ )
+ def test_overrides(self):
+ assert conf.getint("state_store", "default_retention_days") == 7
+ assert conf.getint("state_store", "state_cleanup_batch_size") == 50
+
+
class TestResolveStateBackend:
@conf_vars({("state_store", "backend"):
"airflow.state.metastore.MetastoreStateBackend"})
def test_resolve_returns_configured_backend(self):
diff --git a/shared/state/src/airflow_shared/state/__init__.py
b/shared/state/src/airflow_shared/state/__init__.py
index 4920f66ae67..e231bdfd3bd 100644
--- a/shared/state/src/airflow_shared/state/__init__.py
+++ b/shared/state/src/airflow_shared/state/__init__.py
@@ -157,3 +157,12 @@ class BaseStateBackend(ABC):
``session`` is optional. If provided, implementations should use it
directly.
If ``None``, implementations manage their own async session internally.
"""
+
+ def cleanup(self) -> None:
+ """
+ Remove expired and orphaned state records.
+
+ This is a no-op by default. Custom backends override this to implement
their own
+ retention policy. The backend is responsible for reading any relevant
config (e.g.
+ ``[state_store] default_retention_days``) and deciding what to delete.
+ """