This is an automated email from the ASF dual-hosted git repository.

jedcunningham pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new cabd768309 Merge nowait and skip_locked into with_row_locks (#36889)
cabd768309 is described below

commit cabd768309296f5a9c92604d704307f816ff8786
Author: Tzu-ping Chung <uranu...@gmail.com>
AuthorDate: Sat Jan 20 07:13:05 2024 +0800

    Merge nowait and skip_locked into with_row_locks (#36889)
    
    Since the two helper functions are always used in conjunction with
    with_row_locks, we can handle the two arguments directly in
    with_row_locks instead of repeating the same dialect checks at every
    call site.
    
    The two functions are removed outright since they are not documented
    and thus technically not subject to backward compatibility. I doubt
    anyone uses them directly, given their highly specific nature.
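
For context, a minimal before/after sketch of the call-site change, with names
taken from the diff below:

    # Before: each call site expanded a per-dialect helper into kwargs.
    query = with_row_locks(
        query, of=TI, session=session, **skip_locked(session=session)
    )

    # After: with_row_locks performs the dialect check itself.
    query = with_row_locks(query, of=TI, session=session, skip_locked=True)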
---
 airflow/dag_processing/manager.py    |  6 +--
 airflow/jobs/scheduler_job_runner.py | 21 ++---------
 airflow/models/abstractoperator.py   |  4 +-
 airflow/models/dag.py                |  3 +-
 airflow/models/dagrun.py             |  4 +-
 airflow/models/pool.py               |  4 +-
 airflow/utils/sqlalchemy.py          | 72 ++++++++++++++----------------------
 tests/utils/test_sqlalchemy.py       | 66 ---------------------------------
 8 files changed, 40 insertions(+), 140 deletions(-)

diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index b82a26f376..e1fa7a43bd 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -63,7 +63,7 @@ from airflow.utils.process_utils import (
 )
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks
 
 if TYPE_CHECKING:
     from multiprocessing.connection import Connection as MultiprocessingConnection
@@ -681,9 +681,7 @@ class DagFileProcessorManager(LoggingMixin):
                     DbCallbackRequest.processor_subdir == self.get_dag_directory(),
                 )
             query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks)
-            query = with_row_locks(
-                query, of=DbCallbackRequest, session=session, **skip_locked(session=session)
-            )
+            query = with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True)
             callbacks = session.scalars(query)
             for callback in callbacks:
                 try:
diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 85dccbb26a..627e0d1468 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -68,7 +68,6 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session
 from airflow.utils.sqlalchemy import (
     is_lock_not_available_error,
     prohibit_commit,
-    skip_locked,
     tuple_in_condition,
     with_row_locks,
 )
@@ -399,12 +398,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             timer.start()
 
             try:
-                query = with_row_locks(
-                    query,
-                    of=TI,
-                    session=session,
-                    **skip_locked(session=session),
-                )
+                query = with_row_locks(query, of=TI, session=session, skip_locked=True)
                 task_instances_to_examine: list[TI] = session.scalars(query).all()
 
                 timer.stop(send=True)
@@ -706,12 +700,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         query = select(TI).where(filter_for_tis).options(selectinload(TI.dag_model))
         # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have
         # multi-schedulers
-        tis_query: Query = with_row_locks(
-            query,
-            of=TI,
-            session=session,
-            **skip_locked(session=session),
-        )
+        tis_query: Query = with_row_locks(query, of=TI, session=session, skip_locked=True)
         tis: Iterator[TI] = session.scalars(tis_query)
         for ti in tis:
             try_number = ti_primary_key_to_try_number_map[ti.key.primary]
@@ -1434,7 +1423,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
             select(DagModel).where(DagModel.dag_id == dag_run.dag_id).options(joinedload(DagModel.parent_dag))
         )
         dag_model = session.scalars(
-            with_row_locks(query, of=DagModel, session=session, **skip_locked(session=session))
+            with_row_locks(query, of=DagModel, session=session, skip_locked=True)
         ).one_or_none()
 
         if not dag:
@@ -1660,9 +1649,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                     )
 
                     # Lock these rows, so that another scheduler can't try and adopt these too
-                    tis_to_adopt_or_reset = with_row_locks(
-                        query, of=TI, session=session, **skip_locked(session=session)
-                    )
+                    tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True)
                     tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
                     to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
 
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index f5a266f4b1..4ec8335255 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -35,7 +35,7 @@ from airflow.utils.db import exists_query
 from airflow.utils.log.secrets_masker import redact
 from airflow.utils.session import NEW_SESSION, provide_session
 from airflow.utils.setup_teardown import SetupTeardownContext
-from airflow.utils.sqlalchemy import skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.trigger_rule import TriggerRule
@@ -625,7 +625,7 @@ class AbstractOperator(Templater, DAGNode):
             TaskInstance.run_id == run_id,
             TaskInstance.map_index >= total_expanded_ti_count,
         )
-        query = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session))
+        query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True)
         to_update = session.scalars(query)
         for ti in to_update:
             ti.state = TaskInstanceState.REMOVED
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 4e70b87817..9ee3409c0d 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -128,7 +128,6 @@ from airflow.utils.sqlalchemy import (
     Interval,
     UtcDateTime,
     lock_rows,
-    skip_locked,
     tuple_in_condition,
     with_row_locks,
 )
@@ -3789,7 +3788,7 @@ class DagModel(Base):
         )
 
         return (
-            session.scalars(with_row_locks(query, of=cls, session=session, **skip_locked(session=session))),
+            session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)),
             dataset_triggered_dag_info,
         )
 
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 6a1e71d4d7..501470fd56 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -65,7 +65,7 @@ from airflow.utils import timezone
 from airflow.utils.helpers import chunks, is_container, prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, DagRunType
 
@@ -365,7 +365,7 @@ class DagRun(Base, LoggingMixin):
             query = query.where(DagRun.execution_date <= func.now())
 
         return session.scalars(
-            with_row_locks(query.limit(max_number), of=cls, session=session, **skip_locked(session=session))
+            with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True)
         )
 
     @classmethod
diff --git a/airflow/models/pool.py b/airflow/models/pool.py
index 1960c0a867..3ca7293ffe 100644
--- a/airflow/models/pool.py
+++ b/airflow/models/pool.py
@@ -27,7 +27,7 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES
 from airflow.typing_compat import TypedDict
 from airflow.utils.db import exists_query
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import nowait, with_row_locks
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
@@ -172,7 +172,7 @@ class Pool(Base):
         query = select(Pool.pool, Pool.slots, Pool.include_deferred)
 
         if lock_rows:
-            query = with_row_locks(query, session=session, **nowait(session))
+            query = with_row_locks(query, session=session, nowait=True)
 
         pool_rows = session.execute(query)
         for pool_name, total_slots, include_deferred in pool_rows:
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 9d9b248ec7..2dc495811a 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -334,46 +334,6 @@ class Interval(TypeDecorator):
         return data
 
 
-def skip_locked(session: Session) -> dict[str, Any]:
-    """
-    Return kargs for passing to `with_for_update()` suitable for the current DB engine version.
-
-    We do this as we document the fact that on DB engines that don't support this construct, we do not
-    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
-    work, just slightly slower in some circumstances.
-
-    Specifically don't emit SKIP LOCKED for MySQL < 8, or MariaDB, neither of which support this construct
-
-    See https://jira.mariadb.org/browse/MDEV-13115
-    """
-    dialect = session.bind.dialect
-
-    if dialect.name != "mysql" or dialect.supports_for_update_of:
-        return {"skip_locked": True}
-    else:
-        return {}
-
-
-def nowait(session: Session) -> dict[str, Any]:
-    """
-    Return kwargs for passing to `with_for_update()` suitable for the current DB engine version.
-
-    We do this as we document the fact that on DB engines that don't support this construct, we do not
-    support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still
-    work, just slightly slower in some circumstances.
-
-    Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct
-
-    See https://jira.mariadb.org/browse/MDEV-13115
-    """
-    dialect = session.bind.dialect
-
-    if dialect.name != "mysql" or dialect.supports_for_update_of:
-        return {"nowait": True}
-    else:
-        return {}
-
-
 def nulls_first(col, session: Session) -> dict[str, Any]:
     """Specify *NULLS FIRST* to the column ordering.
 
@@ -390,22 +350,44 @@ def nulls_first(col, session: Session) -> dict[str, Any]:
 USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True)
 
 
-def with_row_locks(query: Query, session: Session, **kwargs) -> Query:
+def with_row_locks(
+    query: Query,
+    session: Session,
+    *,
+    nowait: bool = False,
+    skip_locked: bool = False,
+    **kwargs,
+) -> Query:
     """
-    Apply with_for_update to an SQLAlchemy query, if row level locking is in use.
+    Apply with_for_update to the SQLAlchemy query if row level locking is in use.
+
+    This wrapper is needed so we don't use the syntax on unsupported database
+    engines. In particular, MySQL (prior to 8.0) and MariaDB do not support
+    row locking; on those engines we neither support nor recommend running an
+    HA scheduler. If a user ignores this and tries anyway, everything will
+    still work, just slightly slower in some circumstances.
+
+    See https://jira.mariadb.org/browse/MDEV-13115
 
     :param query: An SQLAlchemy Query object
     :param session: ORM Session
+    :param nowait: If set to True, will pass NOWAIT to supported database backends.
+    :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends.
     :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc)
     :return: updated query
     """
     dialect = session.bind.dialect
 
     # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it.
-    if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of):
-        return query.with_for_update(**kwargs)
-    else:
+    if not USE_ROW_LEVEL_LOCKING:
+        return query
+    if dialect.name == "mysql" and not dialect.supports_for_update_of:
         return query
+    if nowait:
+        kwargs["nowait"] = True
+    if skip_locked:
+        kwargs["skip_locked"] = True
+    return query.with_for_update(**kwargs)
 
 
 @contextlib.contextmanager
diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py
index e01d0904ad..16ba6b392d 100644
--- a/tests/utils/test_sqlalchemy.py
+++ b/tests/utils/test_sqlalchemy.py
@@ -36,9 +36,7 @@ from airflow.settings import Session
 from airflow.utils.sqlalchemy import (
     ExecutorConfigType,
     ensure_pod_is_valid_after_unpickling,
-    nowait,
     prohibit_commit,
-    skip_locked,
     with_row_locks,
 )
 from airflow.utils.state import State
@@ -117,70 +115,6 @@ class TestSqlAlchemyUtils:
             )
         dag.clear()
 
-    @pytest.mark.parametrize(
-        "dialect, supports_for_update_of, expected_return_value",
-        [
-            (
-                "postgresql",
-                True,
-                {"skip_locked": True},
-            ),
-            (
-                "mysql",
-                False,
-                {},
-            ),
-            (
-                "mysql",
-                True,
-                {"skip_locked": True},
-            ),
-            (
-                "sqlite",
-                False,
-                {"skip_locked": True},
-            ),
-        ],
-    )
-    def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value):
-        session = mock.Mock()
-        session.bind.dialect.name = dialect
-        session.bind.dialect.supports_for_update_of = supports_for_update_of
-        assert skip_locked(session=session) == expected_return_value
-
-    @pytest.mark.parametrize(
-        "dialect, supports_for_update_of, expected_return_value",
-        [
-            (
-                "postgresql",
-                True,
-                {"nowait": True},
-            ),
-            (
-                "mysql",
-                False,
-                {},
-            ),
-            (
-                "mysql",
-                True,
-                {"nowait": True},
-            ),
-            (
-                "sqlite",
-                False,
-                {
-                    "nowait": True,
-                },
-            ),
-        ],
-    )
-    def test_nowait(self, dialect, supports_for_update_of, expected_return_value):
-        session = mock.Mock()
-        session.bind.dialect.name = dialect
-        session.bind.dialect.supports_for_update_of = supports_for_update_of
-        assert nowait(session=session) == expected_return_value
-
     @pytest.mark.parametrize(
         "dialect, supports_for_update_of, use_row_level_lock_conf, 
expected_use_row_level_lock",
         [

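For readers following the change, a rough sketch of the merged function's
behavior, in the style of the deleted parametrized tests. It is mock-based and
assumes [scheduler] use_row_level_locking is left at its default of True:

    from unittest import mock

    from airflow.utils.sqlalchemy import with_row_locks

    query, session = mock.Mock(), mock.Mock()

    # MariaDB / MySQL < 8: the query is returned unchanged, no locking syntax.
    session.bind.dialect.name = "mysql"
    session.bind.dialect.supports_for_update_of = False
    assert with_row_locks(query, session=session, skip_locked=True) is query

    # PostgreSQL (or MySQL 8+): the flag is forwarded to with_for_update().
    session.bind.dialect.name = "postgresql"
    session.bind.dialect.supports_for_update_of = True
    with_row_locks(query, session=session, skip_locked=True)
    query.with_for_update.assert_called_once_with(skip_locked=True)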