This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch v3-2-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-2-test by this push:
     new 1ff7aace593 [v3-2-test] fix(scheduler): catch StaleDataError in 
verify_integrity to prevent scheduler crash (#64503) (#66727)
1ff7aace593 is described below

commit 1ff7aace5939eabe379dd5cb14ca4d62d0aae93a
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Tue May 12 05:10:31 2026 +0200

    [v3-2-test] fix(scheduler): catch StaleDataError in verify_integrity to 
prevent scheduler crash (#64503) (#66727)
    
    Closes #63926
    
    SQLAlchemy's optimistic locking raises StaleDataError when a concurrent
    session modifies the same row, and that unhandled error can crash the
    scheduler during verify_integrity. Fix this by catching StaleDataError
    alongside IntegrityError in dagrun.verify_integrity() and by adding it to
    the retry exceptions in run_with_db_retries()/retry_db_transaction(), so
    the operation is retried automatically.
    (cherry picked from commit dcfa2715632de7f665c3eba1b42d2e3084f08361)
    
    Co-authored-by: Pradeep Kalluri 
<[email protected]>
---
 airflow-core/newsfragments/64503.bugfix.rst   |  1 +
 airflow-core/src/airflow/models/dagrun.py     |  8 ++++++--
 airflow-core/src/airflow/utils/retries.py     |  5 +++--
 airflow-core/tests/unit/models/test_dagrun.py | 24 ++++++++++++++++++++++
 airflow-core/tests/unit/utils/test_retries.py | 29 ++++++++++++++++++---------
 5 files changed, 54 insertions(+), 13 deletions(-)

diff --git a/airflow-core/newsfragments/64503.bugfix.rst 
b/airflow-core/newsfragments/64503.bugfix.rst
new file mode 100644
index 00000000000..0358708ea1f
--- /dev/null
+++ b/airflow-core/newsfragments/64503.bugfix.rst
@@ -0,0 +1 @@
+Fix scheduler crashing with ``StaleDataError`` when a task instance is 
completed or removed by another session between ``verify_integrity`` loading 
task instances and ``session.flush()`` persisting them. Now caught and rolled 
back like the existing ``IntegrityError`` path.
diff --git a/airflow-core/src/airflow/models/dagrun.py 
b/airflow-core/src/airflow/models/dagrun.py
index 7eabadd73cf..afe73a43b96 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -57,6 +57,7 @@ from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import Mapped, declared_attr, joinedload, mapped_column, 
relationship, synonym, validates
+from sqlalchemy.orm.exc import StaleDataError
 from sqlalchemy.sql.expression import false, select
 from sqlalchemy.sql.functions import coalesce
 
@@ -1873,14 +1874,17 @@ class DagRun(Base, LoggingMixin):
                     extra_tags={"task_type": task_type},
                 )
             session.flush()
-        except IntegrityError:
+        except (IntegrityError, StaleDataError) as exc:
             self.log.info(
-                "Hit IntegrityError while creating the TIs for %s- %s",
+                "Hit %s while creating the TIs for %s- %s",
+                type(exc).__name__,
                 dag_id,
                 run_id,
                 exc_info=True,
             )
             self.log.info("Doing session rollback.")
+            # Catching StaleDataError and rolling back is sufficient here 
because
+            # the next scheduler loop will re-read the latest state from the 
DB.
             # TODO[HA]: We probably need to savepoint this so we can keep the 
transaction alive.
             session.rollback()
 
diff --git a/airflow-core/src/airflow/utils/retries.py 
b/airflow-core/src/airflow/utils/retries.py
index a30d6766853..69b71046acb 100644
--- a/airflow-core/src/airflow/utils/retries.py
+++ b/airflow-core/src/airflow/utils/retries.py
@@ -23,6 +23,7 @@ from inspect import signature
 from typing import TYPE_CHECKING, TypeVar, overload
 
 from sqlalchemy.exc import DBAPIError
+from sqlalchemy.orm.exc import StaleDataError
 
 from airflow.configuration import conf
 
@@ -40,7 +41,7 @@ def run_with_db_retries(max_retries: int = MAX_DB_RETRIES, 
logger: Logger | None
 
     # Default kwargs
     retry_kwargs = dict(
-        retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError)),
+        retry=tenacity.retry_if_exception_type(exception_types=(DBAPIError, 
StaleDataError)),
         wait=tenacity.wait_random_exponential(multiplier=0.5, max=5),
         stop=tenacity.stop_after_attempt(max_retries),
         reraise=True,
@@ -104,7 +105,7 @@ def retry_db_transaction(_func: Callable | None = None, *, 
retries: int = MAX_DB
                     )
                     try:
                         return func(*args, **kwargs)
-                    except DBAPIError:
+                    except (DBAPIError, StaleDataError):
                         session.rollback()
                         raise
 
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index 93bf2dcbdf4..b259b62552e 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -39,6 +39,7 @@ from sqlalchemy import (
     update,
 )
 from sqlalchemy.orm import joinedload
+from sqlalchemy.orm.exc import StaleDataError
 
 from airflow import settings
 from airflow._shared.observability.metrics.stats import Stats
@@ -1443,6 +1444,29 @@ def 
test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
         assert indices == [0, 1, 2, 3]
 
 
+def test_verify_integrity_handles_stale_data_error(dag_maker, session):
+    """Test that StaleDataError during _create_task_instances is caught and 
session is rolled back."""
+    with dag_maker("test_stale_data_error_dag", session=session) as dag:
+        task = EmptyOperator(task_id="task1")
+
+    dr = dag_maker.create_dagrun()
+    dag_version_id = DagVersion.get_latest_version(dag.dag_id, 
session=session).id
+
+    with mock.patch.object(session, "flush", side_effect=StaleDataError()):
+        with mock.patch.object(session, "rollback") as mock_rollback:
+            # Should not raise — StaleDataError must be caught gracefully.
+            # Call _create_task_instances directly with a non-empty task list 
so the
+            # test exercises the session.flush() → StaleDataError → 
session.rollback() path.
+            dr._create_task_instances(
+                dag_id=dag.dag_id,
+                tasks=[TI(task=task, run_id=dr.run_id, 
dag_version_id=dag_version_id)],
+                created_counts={"EmptyOperator": 1},
+                hook_is_noop=False,
+                session=session,
+            )
+            mock_rollback.assert_called_once()
+
+
 def test_mapped_literal_verify_integrity(dag_maker, session):
     """Test that when the length of a mapped literal changes we remove extra 
TIs"""
 
diff --git a/airflow-core/tests/unit/utils/test_retries.py 
b/airflow-core/tests/unit/utils/test_retries.py
index 1f44ee9ebf8..f0976d0e358 100644
--- a/airflow-core/tests/unit/utils/test_retries.py
+++ b/airflow-core/tests/unit/utils/test_retries.py
@@ -18,17 +18,14 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING
 from unittest import mock
 
 import pytest
 from sqlalchemy.exc import InternalError, OperationalError
+from sqlalchemy.orm.exc import StaleDataError
 
 from airflow.utils.retries import retry_db_transaction
 
-if TYPE_CHECKING:
-    from sqlalchemy.exc import DBAPIError
-
 
 class TestRetries:
     def test_retry_db_transaction_with_passing_retries(self):
@@ -48,15 +45,29 @@ class TestRetries:
 
         assert mock_obj.call_count == 2
 
-    @pytest.mark.db_test
-    @pytest.mark.parametrize("excection_type", [OperationalError, 
InternalError])
-    def test_retry_db_transaction_with_default_retries(self, caplog, 
excection_type: type[DBAPIError]):
+    @pytest.mark.parametrize(
+        ("exception_type", "exception_kwargs"),
+        [
+            pytest.param(
+                InternalError,
+                {"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
+                id="dbapi-internal",
+            ),
+            pytest.param(
+                OperationalError,
+                {"statement": mock.ANY, "params": mock.ANY, "orig": mock.ANY},
+                id="dbapi-operational",
+            ),
+            pytest.param(StaleDataError, {}, id="stale-data"),
+        ],
+    )
+    def test_retry_db_transaction_with_default_retries(self, caplog, 
exception_type, exception_kwargs):
         """Test that by default 3 retries will be carried out"""
         mock_obj = mock.MagicMock()
         mock_session = mock.MagicMock()
         mock_rollback = mock.MagicMock()
         mock_session.rollback = mock_rollback
-        db_error = excection_type(statement=mock.ANY, params=mock.ANY, 
orig=mock.ANY)
+        db_error = exception_type(**exception_kwargs)
 
         @retry_db_transaction
         def test_function(session):
@@ -66,7 +77,7 @@ class TestRetries:
 
         caplog.set_level(logging.DEBUG)
         caplog.clear()
-        with pytest.raises(excection_type):
+        with pytest.raises(exception_type):
             test_function(session=mock_session)
 
         for try_no in range(1, 4):

Reply via email to