This is an automated email from the ASF dual-hosted git repository. jedcunningham pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new cabd768309 Merge nowait and skip_locked into with_row_locks (#36889) cabd768309 is described below commit cabd768309296f5a9c92604d704307f816ff8786 Author: Tzu-ping Chung <uranu...@gmail.com> AuthorDate: Sat Jan 20 07:13:05 2024 +0800 Merge nowait and skip_locked into with_row_locks (#36889) Since the two functions are always used in conjunction with the last, we can simply handle the two arguments specially in with_row_locks, instead of doing the same checks over and over again. The two functions are removed outright since they are not documented and thus technically not subject to backward compatibility. I highly doubt anyone is using them directly due to their highly specific nature. --- airflow/dag_processing/manager.py | 6 +-- airflow/jobs/scheduler_job_runner.py | 21 ++--------- airflow/models/abstractoperator.py | 4 +- airflow/models/dag.py | 3 +- airflow/models/dagrun.py | 4 +- airflow/models/pool.py | 4 +- airflow/utils/sqlalchemy.py | 72 ++++++++++++++---------------------- tests/utils/test_sqlalchemy.py | 66 --------------------------------- 8 files changed, 40 insertions(+), 140 deletions(-) diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py index b82a26f376..e1fa7a43bd 100644 --- a/airflow/dag_processing/manager.py +++ b/airflow/dag_processing/manager.py @@ -63,7 +63,7 @@ from airflow.utils.process_utils import ( ) from airflow.utils.retries import retry_db_transaction from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import prohibit_commit, skip_locked, with_row_locks +from airflow.utils.sqlalchemy import prohibit_commit, with_row_locks if TYPE_CHECKING: from multiprocessing.connection import Connection as MultiprocessingConnection @@ -681,9 +681,7 @@ class DagFileProcessorManager(LoggingMixin): DbCallbackRequest.processor_subdir == self.get_dag_directory(), ) query = query.order_by(DbCallbackRequest.priority_weight.asc()).limit(max_callbacks) - query = with_row_locks( - query, of=DbCallbackRequest, session=session, **skip_locked(session=session) - ) + query = with_row_locks(query, of=DbCallbackRequest, session=session, skip_locked=True) callbacks = session.scalars(query) for callback in callbacks: try: diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py index 85dccbb26a..627e0d1468 100644 --- a/airflow/jobs/scheduler_job_runner.py +++ b/airflow/jobs/scheduler_job_runner.py @@ -68,7 +68,6 @@ from airflow.utils.session import NEW_SESSION, create_session, provide_session from airflow.utils.sqlalchemy import ( is_lock_not_available_error, prohibit_commit, - skip_locked, tuple_in_condition, with_row_locks, ) @@ -399,12 +398,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): timer.start() try: - query = with_row_locks( - query, - of=TI, - session=session, - **skip_locked(session=session), - ) + query = with_row_locks(query, of=TI, session=session, skip_locked=True) task_instances_to_examine: list[TI] = session.scalars(query).all() timer.stop(send=True) @@ -706,12 +700,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): query = select(TI).where(filter_for_tis).options(selectinload(TI.dag_model)) # row lock this entire set of taskinstances to make sure the scheduler doesn't fail when we have # multi-schedulers - tis_query: Query = with_row_locks( - query, - of=TI, - session=session, - **skip_locked(session=session), - ) + tis_query: Query = with_row_locks(query, of=TI, session=session, skip_locked=True) tis: Iterator[TI] = session.scalars(tis_query) for ti in tis: try_number = ti_primary_key_to_try_number_map[ti.key.primary] @@ -1434,7 +1423,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): select(DagModel).where(DagModel.dag_id == dag_run.dag_id).options(joinedload(DagModel.parent_dag)) ) dag_model = session.scalars( - with_row_locks(query, of=DagModel, session=session, **skip_locked(session=session)) + with_row_locks(query, of=DagModel, session=session, skip_locked=True) ).one_or_none() if not dag: @@ -1660,9 +1649,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin): ) # Lock these rows, so that another scheduler can't try and adopt these too - tis_to_adopt_or_reset = with_row_locks( - query, of=TI, session=session, **skip_locked(session=session) - ) + tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True) tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all() to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset) diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index f5a266f4b1..4ec8335255 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -35,7 +35,7 @@ from airflow.utils.db import exists_query from airflow.utils.log.secrets_masker import redact from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.setup_teardown import SetupTeardownContext -from airflow.utils.sqlalchemy import skip_locked, with_row_locks +from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import State, TaskInstanceState from airflow.utils.task_group import MappedTaskGroup from airflow.utils.trigger_rule import TriggerRule @@ -625,7 +625,7 @@ class AbstractOperator(Templater, DAGNode): TaskInstance.run_id == run_id, TaskInstance.map_index >= total_expanded_ti_count, ) - query = with_row_locks(query, of=TaskInstance, session=session, **skip_locked(session=session)) + query = with_row_locks(query, of=TaskInstance, session=session, skip_locked=True) to_update = session.scalars(query) for ti in to_update: ti.state = TaskInstanceState.REMOVED diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 4e70b87817..9ee3409c0d 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -128,7 +128,6 @@ from airflow.utils.sqlalchemy import ( Interval, UtcDateTime, lock_rows, - skip_locked, tuple_in_condition, with_row_locks, ) @@ -3789,7 +3788,7 @@ class DagModel(Base): ) return ( - session.scalars(with_row_locks(query, of=cls, session=session, **skip_locked(session=session))), + session.scalars(with_row_locks(query, of=cls, session=session, skip_locked=True)), dataset_triggered_dag_info, ) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index 6a1e71d4d7..501470fd56 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -65,7 +65,7 @@ from airflow.utils import timezone from airflow.utils.helpers import chunks, is_container, prune_dict from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, skip_locked, tuple_in_condition, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, DagRunType @@ -365,7 +365,7 @@ class DagRun(Base, LoggingMixin): query = query.where(DagRun.execution_date <= func.now()) return session.scalars( - with_row_locks(query.limit(max_number), of=cls, session=session, **skip_locked(session=session)) + with_row_locks(query.limit(max_number), of=cls, session=session, skip_locked=True) ) @classmethod diff --git a/airflow/models/pool.py b/airflow/models/pool.py index 1960c0a867..3ca7293ffe 100644 --- a/airflow/models/pool.py +++ b/airflow/models/pool.py @@ -27,7 +27,7 @@ from airflow.ti_deps.dependencies_states import EXECUTION_STATES from airflow.typing_compat import TypedDict from airflow.utils.db import exists_query from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import nowait, with_row_locks +from airflow.utils.sqlalchemy import with_row_locks from airflow.utils.state import TaskInstanceState if TYPE_CHECKING: @@ -172,7 +172,7 @@ class Pool(Base): query = select(Pool.pool, Pool.slots, Pool.include_deferred) if lock_rows: - query = with_row_locks(query, session=session, **nowait(session)) + query = with_row_locks(query, session=session, nowait=True) pool_rows = session.execute(query) for pool_name, total_slots, include_deferred in pool_rows: diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index 9d9b248ec7..2dc495811a 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -334,46 +334,6 @@ class Interval(TypeDecorator): return data -def skip_locked(session: Session) -> dict[str, Any]: - """ - Return kargs for passing to `with_for_update()` suitable for the current DB engine version. - - We do this as we document the fact that on DB engines that don't support this construct, we do not - support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still - work, just slightly slower in some circumstances. - - Specifically don't emit SKIP LOCKED for MySQL < 8, or MariaDB, neither of which support this construct - - See https://jira.mariadb.org/browse/MDEV-13115 - """ - dialect = session.bind.dialect - - if dialect.name != "mysql" or dialect.supports_for_update_of: - return {"skip_locked": True} - else: - return {} - - -def nowait(session: Session) -> dict[str, Any]: - """ - Return kwargs for passing to `with_for_update()` suitable for the current DB engine version. - - We do this as we document the fact that on DB engines that don't support this construct, we do not - support/recommend running HA scheduler. If a user ignores this and tries anyway everything will still - work, just slightly slower in some circumstances. - - Specifically don't emit NOWAIT for MySQL < 8, or MariaDB, neither of which support this construct - - See https://jira.mariadb.org/browse/MDEV-13115 - """ - dialect = session.bind.dialect - - if dialect.name != "mysql" or dialect.supports_for_update_of: - return {"nowait": True} - else: - return {} - - def nulls_first(col, session: Session) -> dict[str, Any]: """Specify *NULLS FIRST* to the column ordering. @@ -390,22 +350,44 @@ def nulls_first(col, session: Session) -> dict[str, Any]: USE_ROW_LEVEL_LOCKING: bool = conf.getboolean("scheduler", "use_row_level_locking", fallback=True) -def with_row_locks(query: Query, session: Session, **kwargs) -> Query: +def with_row_locks( + query: Query, + session: Session, + *, + nowait: bool = False, + skip_locked: bool = False, + **kwargs, +) -> Query: """ - Apply with_for_update to an SQLAlchemy query, if row level locking is in use. + Apply with_for_update to the SQLAlchemy query if row level locking is in use. + + This wrapper is needed so we don't use the syntax on unsupported database + engines. In particular, MySQL (prior to 8.0) and MariaDB do not support + row locking, where we do not support nor recommend running HA scheduler. If + a user ignores this and tries anyway, everything will still work, just + slightly slower in some circumstances. + + See https://jira.mariadb.org/browse/MDEV-13115 :param query: An SQLAlchemy Query object :param session: ORM Session + :param nowait: If set to True, will pass NOWAIT to supported database backends. + :param skip_locked: If set to True, will pass SKIP LOCKED to supported database backends. :param kwargs: Extra kwargs to pass to with_for_update (of, nowait, skip_locked, etc) :return: updated query """ dialect = session.bind.dialect # Don't use row level locks if the MySQL dialect (Mariadb & MySQL < 8) does not support it. - if USE_ROW_LEVEL_LOCKING and (dialect.name != "mysql" or dialect.supports_for_update_of): - return query.with_for_update(**kwargs) - else: + if not USE_ROW_LEVEL_LOCKING: + return query + if dialect.name == "mysql" and not dialect.supports_for_update_of: return query + if nowait: + kwargs["nowait"] = True + if skip_locked: + kwargs["skip_locked"] = True + return query.with_for_update(**kwargs) @contextlib.contextmanager diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index e01d0904ad..16ba6b392d 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -36,9 +36,7 @@ from airflow.settings import Session from airflow.utils.sqlalchemy import ( ExecutorConfigType, ensure_pod_is_valid_after_unpickling, - nowait, prohibit_commit, - skip_locked, with_row_locks, ) from airflow.utils.state import State @@ -117,70 +115,6 @@ class TestSqlAlchemyUtils: ) dag.clear() - @pytest.mark.parametrize( - "dialect, supports_for_update_of, expected_return_value", - [ - ( - "postgresql", - True, - {"skip_locked": True}, - ), - ( - "mysql", - False, - {}, - ), - ( - "mysql", - True, - {"skip_locked": True}, - ), - ( - "sqlite", - False, - {"skip_locked": True}, - ), - ], - ) - def test_skip_locked(self, dialect, supports_for_update_of, expected_return_value): - session = mock.Mock() - session.bind.dialect.name = dialect - session.bind.dialect.supports_for_update_of = supports_for_update_of - assert skip_locked(session=session) == expected_return_value - - @pytest.mark.parametrize( - "dialect, supports_for_update_of, expected_return_value", - [ - ( - "postgresql", - True, - {"nowait": True}, - ), - ( - "mysql", - False, - {}, - ), - ( - "mysql", - True, - {"nowait": True}, - ), - ( - "sqlite", - False, - { - "nowait": True, - }, - ), - ], - ) - def test_nowait(self, dialect, supports_for_update_of, expected_return_value): - session = mock.Mock() - session.bind.dialect.name = dialect - session.bind.dialect.supports_for_update_of = supports_for_update_of - assert nowait(session=session) == expected_return_value - @pytest.mark.parametrize( "dialect, supports_for_update_of, use_row_level_lock_conf, expected_use_row_level_lock", [