This is an automated email from the ASF dual-hosted git repository. bbovenzi pushed a commit to branch mapped-instance-actions in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 898480765fe8117938037b194a10663565d44e3a Author: Tzu-ping Chung <[email protected]> AuthorDate: Tue Apr 19 18:01:55 2022 +0800 Introduce tuple_().in_() shim for MSSQL compat --- airflow/api/common/mark_tasks.py | 5 +++-- airflow/jobs/scheduler_job.py | 47 +++++++++++++++------------------------- airflow/models/dag.py | 8 +++---- airflow/models/taskinstance.py | 21 +++++------------- airflow/utils/sqlalchemy.py | 25 +++++++++++++++++++-- 5 files changed, 52 insertions(+), 54 deletions(-) diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py index 349b935e82..1d4709fb82 100644 --- a/airflow/api/common/mark_tasks.py +++ b/airflow/api/common/mark_tasks.py @@ -20,7 +20,7 @@ from datetime import datetime from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union -from sqlalchemy import or_, tuple_ +from sqlalchemy import or_ from sqlalchemy.orm import contains_eager from sqlalchemy.orm.session import Session as SASession @@ -32,6 +32,7 @@ from airflow.operators.subdag import SubDagOperator from airflow.utils import timezone from airflow.utils.helpers import exactly_one from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.sqlalchemy import tuple_in_condition from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -203,7 +204,7 @@ def get_all_dag_task_query( if is_string_list: qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids)) else: - qry_dag = qry_dag.filter(tuple_(TaskInstance.task_id, TaskInstance.map_index).in_(task_ids)) + qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids)) qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options( contains_eager(TaskInstance.dag_run) ) diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py index e0b8c437ac..ac1d25833b 100644 --- a/airflow/jobs/scheduler_job.py +++ b/airflow/jobs/scheduler_job.py @@ -28,7 +28,7 @@ from collections import defaultdict from datetime import timedelta from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple -from sqlalchemy import and_, func, not_, or_, text, tuple_ +from sqlalchemy import func, not_, or_, text from sqlalchemy.exc import OperationalError from sqlalchemy.orm import load_only, selectinload from sqlalchemy.orm.session import Session, make_transient @@ -55,7 +55,13 @@ from airflow.utils.docs import get_docs_url from airflow.utils.event_scheduler import EventScheduler from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries from airflow.utils.session import create_session, provide_session -from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks +from airflow.utils.sqlalchemy import ( + is_lock_not_available_error, + prohibit_commit, + skip_locked, + tuple_in_condition, + with_row_locks, +) from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import DagRunType @@ -321,17 +327,7 @@ class SchedulerJob(BaseJob): query = query.filter(not_(TI.dag_id.in_(starved_dags))) if starved_tasks: - if settings.engine.dialect.name == 'mssql': - task_filter = or_( - and_( - TaskInstance.dag_id == dag_id, - TaskInstance.task_id == task_id, - ) - for (dag_id, task_id) in starved_tasks - ) - else: - task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks) - + task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks) query = query.filter(not_(task_filter)) query = query.limit(max_tis) @@ -980,24 +976,15 @@ class SchedulerJob(BaseJob): # as DagModel.dag_id and DagModel.next_dagrun # This list is used to verify if the DagRun already exist so that we don't attempt to create # duplicate dag runs - - if session.bind.dialect.name == 'mssql': - existing_dagruns_filter = or_( - *( - and_( - DagRun.dag_id == dm.dag_id, - DagRun.execution_date == dm.next_dagrun, - ) - for dm in dag_models - ) - ) - else: - existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_( - [(dm.dag_id, dm.next_dagrun) for dm in dag_models] - ) - existing_dagruns = ( - session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all() + session.query(DagRun.dag_id, DagRun.execution_date) + .filter( + tuple_in_condition( + (DagRun.dag_id, DagRun.execution_date), + ((dm.dag_id, dm.next_dagrun) for dm in dag_models), + ), + ) + .all() ) active_runs_of_dags = defaultdict( diff --git a/airflow/models/dag.py b/airflow/models/dag.py index 9c93bcef13..83860ba591 100644 --- a/airflow/models/dag.py +++ b/airflow/models/dag.py @@ -52,7 +52,7 @@ import jinja2 import pendulum from dateutil.relativedelta import relativedelta from pendulum.tz.timezone import Timezone -from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_, tuple_ +from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_ from sqlalchemy.orm import backref, joinedload, relationship from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session @@ -85,7 +85,7 @@ from airflow.utils.file import correct_maybe_zipped from airflow.utils.helpers import exactly_one, validate_key from airflow.utils.log.logging_mixin import LoggingMixin from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks +from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType @@ -1451,7 +1451,7 @@ class DAG(LoggingMixin): elif isinstance(next(iter(task_ids), None), str): tis = tis.filter(TI.task_id.in_(task_ids)) else: - tis = tis.filter(tuple_(TI.task_id, TI.map_index).in_(task_ids)) + tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids)) # This allows allow_trigger_in_future config to take affect, rather than mandating exec_date <= UTC if end_date or not self.allow_future_exec_dates: @@ -1611,7 +1611,7 @@ class DAG(LoggingMixin): elif isinstance(next(iter(exclude_task_ids), None), str): tis = tis.filter(TI.task_id.notin_(exclude_task_ids)) else: - tis = tis.filter(tuple_(TI.task_id, TI.map_index).notin_(exclude_task_ids)) + tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids))) return tis diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 48d3a047fb..9d135a47b8 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -67,7 +67,6 @@ from sqlalchemy import ( inspect, or_, text, - tuple_, ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.mutable import MutableDict @@ -122,7 +121,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.platform import getuser from airflow.utils.retries import run_with_db_retries from airflow.utils.session import NEW_SESSION, create_session, provide_session -from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks +from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks from airflow.utils.state import DagRunState, State, TaskInstanceState from airflow.utils.timeout import timeout @@ -2540,20 +2539,10 @@ class TaskInstance(Base, LoggingMixin): TaskInstance.task_id == first_task_id, ) - if settings.engine.dialect.name == 'mssql': - return or_( - and_( - TaskInstance.dag_id == ti.dag_id, - TaskInstance.task_id == ti.task_id, - TaskInstance.run_id == ti.run_id, - TaskInstance.map_index == ti.map_index, - ) - for ti in tis - ) - else: - return tuple_( - TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index - ).in_([ti.key.primary for ti in tis]) + return tuple_in_condition( + (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index), + (ti.key.primary for ti in tis), + ) # State of the task instance. diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py index c240a94456..5c36d826b2 100644 --- a/airflow/utils/sqlalchemy.py +++ b/airflow/utils/sqlalchemy.py @@ -19,15 +19,19 @@ import datetime import json import logging -from typing import Any, Dict +from operator import and_, or_ +from typing import Any, Dict, Iterable, Tuple import pendulum from dateutil import relativedelta -from sqlalchemy import event, nullsfirst +from sqlalchemy import event, nullsfirst, tuple_ from sqlalchemy.exc import OperationalError from sqlalchemy.orm.session import Session +from sqlalchemy.sql import ColumnElement +from sqlalchemy.sql.expression import ColumnOperators from sqlalchemy.types import JSON, DateTime, Text, TypeDecorator, TypeEngine, UnicodeText +from airflow import settings from airflow.configuration import conf log = logging.getLogger(__name__) @@ -319,3 +323,20 @@ def is_lock_not_available_error(error: OperationalError): if db_err_code in ('55P03', 1205, 3572): return True return False + + +def tuple_in_condition( + columns: Tuple[ColumnElement, ...], + collection: Iterable[Any], +) -> ColumnOperators: + """Generates a tuple-in-collection operator to use in ``.filter()``. + + For most SQL backends, this generates a simple ``([col, ...]) IN [condition]`` + clause. This however does not work with MSSQL, where we need to expand to + ``(c1 = v1a AND c2 = v2a ...) OR (c1 = v1b AND c2 = v2b ...) ...`` manually. + + :meta private: + """ + if settings.engine.dialect.name != "mssql": + return tuple_(*columns).in_(collection) + return or_(*(and_(*(c == v for c, v in zip(columns, values))) for values in collection))
