This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 3873230a11d Remove tuple_in_condition helpers (#45201)
3873230a11d is described below

commit 3873230a11de8b9cc24d012ecdfe6848bc6ae0cf
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Thu Dec 26 11:04:00 2024 +0800

    Remove tuple_in_condition helpers (#45201)
---
 airflow/jobs/scheduler_job_runner.py               | 35 ++++++--------
 airflow/models/dag.py                              |  8 ++--
 airflow/models/dagrun.py                           |  7 +--
 airflow/models/skipmixin.py                        |  7 ++-
 airflow/models/taskinstance.py                     | 16 +++----
 airflow/utils/sqlalchemy.py                        | 54 ++--------------------
 airflow/www/utils.py                               | 10 +---
 .../providers/standard/utils/sensor_helper.py      | 22 ++++-----
 8 files changed, 44 insertions(+), 115 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index a0558e9040d..ece0d4c1cb6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -34,7 +34,7 @@ from pathlib import Path
 from typing import TYPE_CHECKING, Any, Callable
 
 from deprecated import deprecated
-from sqlalchemy import and_, delete, exists, func, not_, select, text, update
+from sqlalchemy import and_, delete, exists, func, select, text, tuple_, update
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import lazyload, load_only, make_transient, selectinload
 from sqlalchemy.sql import expression
@@ -77,12 +77,7 @@ from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, 
run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import (
-    is_lock_not_available_error,
-    prohibit_commit,
-    tuple_in_condition,
-    with_row_locks,
-)
+from airflow.utils.sqlalchemy import is_lock_not_available_error, 
prohibit_commit, with_row_locks
 from airflow.utils.state import DagRunState, JobState, State, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
@@ -357,28 +352,25 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 .join(TI.dag_run)
                 .where(DR.state == DagRunState.RUNNING)
                 .join(TI.dag_model)
-                .where(not_(DM.is_paused))
+                .where(~DM.is_paused)
                 .where(TI.state == TaskInstanceState.SCHEDULED)
                 .options(selectinload(TI.dag_model))
                 .order_by(-TI.priority_weight, DR.logical_date, TI.map_index)
             )
 
             if starved_pools:
-                query = query.where(not_(TI.pool.in_(starved_pools)))
+                query = query.where(TI.pool.not_in(starved_pools))
 
             if starved_dags:
-                query = query.where(not_(TI.dag_id.in_(starved_dags)))
+                query = query.where(TI.dag_id.not_in(starved_dags))
 
             if starved_tasks:
-                task_filter = tuple_in_condition((TI.dag_id, TI.task_id), 
starved_tasks)
-                query = query.where(not_(task_filter))
+                query = query.where(tuple_(TI.dag_id, 
TI.task_id).not_in(starved_tasks))
 
             if starved_tasks_task_dagrun_concurrency:
-                task_filter = tuple_in_condition(
-                    (TI.dag_id, TI.run_id, TI.task_id),
-                    starved_tasks_task_dagrun_concurrency,
+                query = query.where(
+                    tuple_(TI.dag_id, TI.run_id, 
TI.task_id).not_in(starved_tasks_task_dagrun_concurrency)
                 )
-                query = query.where(not_(task_filter))
 
             query = query.limit(max_tis)
 
@@ -1314,9 +1306,8 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         existing_dagruns = (
             session.execute(
                 select(DagRun.dag_id, DagRun.logical_date).where(
-                    tuple_in_condition(
-                        (DagRun.dag_id, DagRun.logical_date),
-                        ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
+                    tuple_(DagRun.dag_id, DagRun.logical_date).in_(
+                        (dm.dag_id, dm.next_dagrun) for dm in dag_models
                     ),
                 )
             )
@@ -1402,7 +1393,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         existing_dagruns: set[tuple[str, timezone.DateTime]] = set(
             session.execute(
                 select(DagRun.dag_id, DagRun.logical_date).where(
-                    tuple_in_condition((DagRun.dag_id, DagRun.logical_date), 
logical_dates.items())
+                    tuple_(DagRun.dag_id, 
DagRun.logical_date).in_(logical_dates.items())
                 )
             )
         )
@@ -2188,7 +2179,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         if assets:
             session.execute(
                 delete(AssetActive).where(
-                    tuple_in_condition((AssetActive.name, AssetActive.uri), 
((a.name, a.uri) for a in assets))
+                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
                 )
             )
         Stats.gauge("asset.orphaned", len(assets))
@@ -2201,7 +2192,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
         active_assets = set(
             session.execute(
                 select(AssetActive.name, AssetActive.uri).where(
-                    tuple_in_condition((AssetActive.name, AssetActive.uri), 
((a.name, a.uri) for a in assets))
+                    tuple_(AssetActive.name, AssetActive.uri).in_((a.name, 
a.uri) for a in assets)
                 )
             )
         )
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index f2090649301..d127914a8c5 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -58,9 +58,9 @@ from sqlalchemy import (
     and_,
     case,
     func,
-    not_,
     or_,
     select,
+    tuple_,
     update,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
@@ -108,7 +108,7 @@ from airflow.utils import timezone
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, 
tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, lock_rows, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunTriggeredByType, DagRunType
 
@@ -1081,7 +1081,7 @@ class DAG(TaskSDKDag, LoggingMixin):
                     tis = tis.where(TaskInstance.state.in_(state))
 
         if exclude_run_ids:
-            tis = tis.where(not_(TaskInstance.run_id.in_(exclude_run_ids)))
+            tis = tis.where(TaskInstance.run_id.not_in(exclude_run_ids))
 
         if include_dependent_dags:
             # Recursively find external tasks indicated by ExternalTaskMarker
@@ -1192,7 +1192,7 @@ class DAG(TaskSDKDag, LoggingMixin):
         elif isinstance(next(iter(exclude_task_ids), None), str):
             tis = tis.where(TI.task_id.notin_(exclude_task_ids))
         else:
-            tis = tis.where(not_(tuple_in_condition((TI.task_id, 
TI.map_index), exclude_task_ids)))
+            tis = tis.where(tuple_(TI.task_id, 
TI.map_index).not_in(exclude_task_ids))
 
         return tis
 
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index 7278e88742e..a5bef7e589c 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -42,6 +42,7 @@ from sqlalchemy import (
     not_,
     or_,
     text,
+    tuple_,
     update,
 )
 from sqlalchemy.exc import IntegrityError
@@ -74,7 +75,7 @@ from airflow.utils.helpers import chunks, is_container, prune_dict
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.retries import retry_db_transaction
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, 
tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime, nulls_first, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, DagRunTriggeredByType, DagRunType
 
@@ -1644,7 +1645,7 @@ class DagRun(Base, LoggingMixin):
                     .where(
                         TI.dag_id == self.dag_id,
                         TI.run_id == self.run_id,
-                        tuple_in_condition((TI.task_id, TI.map_index), 
schedulable_ti_ids_chunk),
+                        tuple_(TI.task_id, 
TI.map_index).in_(schedulable_ti_ids_chunk),
                     )
                     .values(
                         state=TaskInstanceState.SCHEDULED,
@@ -1668,7 +1669,7 @@ class DagRun(Base, LoggingMixin):
                     .where(
                         TI.dag_id == self.dag_id,
                         TI.run_id == self.run_id,
-                        tuple_in_condition((TI.task_id, TI.map_index), 
dummy_ti_ids_chunk),
+                        tuple_(TI.task_id, 
TI.map_index).in_(dummy_ti_ids_chunk),
                     )
                     .values(
                         state=TaskInstanceState.SUCCESS,
diff --git a/airflow/models/skipmixin.py b/airflow/models/skipmixin.py
index ad5c5d01539..8b59043ecef 100644
--- a/airflow/models/skipmixin.py
+++ b/airflow/models/skipmixin.py
@@ -21,18 +21,17 @@ from collections.abc import Iterable, Sequence
 from types import GeneratorType
 from typing import TYPE_CHECKING
 
-from sqlalchemy import update
+from sqlalchemy import tuple_, update
 
 from airflow.exceptions import AirflowException
 from airflow.models.taskinstance import TaskInstance
 from airflow.utils import timezone
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import TaskInstanceState
 
 if TYPE_CHECKING:
-    from sqlalchemy import Session
+    from sqlalchemy.orm import Session
 
     from airflow.models.dagrun import DagRun
     from airflow.models.operator import Operator
@@ -74,7 +73,7 @@ class SkipMixin(LoggingMixin):
                     .where(
                         TaskInstance.dag_id == dag_run.dag_id,
                         TaskInstance.run_id == dag_run.run_id,
-                        tuple_in_condition((TaskInstance.task_id, 
TaskInstance.map_index), tasks),
+                        tuple_(TaskInstance.task_id, 
TaskInstance.map_index).in_(tasks),
                     )
                     .values(state=TaskInstanceState.SKIPPED, start_date=now, 
end_date=now)
                     .execution_options(synchronize_session=False)
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6ef4452834f..d519af41d90 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -55,12 +55,14 @@ from sqlalchemy import (
     Text,
     UniqueConstraint,
     and_,
+    case,
     delete,
     extract,
     false,
     func,
     inspect,
     or_,
+    select,
     text,
     tuple_,
     update,
@@ -71,7 +73,6 @@ from sqlalchemy.ext.hybrid import hybrid_property
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.orm import lazyload, reconstructor, relationship
 from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
-from sqlalchemy.sql.expression import case, select
 from sqlalchemy_utils import UUIDType
 
 from airflow import settings
@@ -131,12 +132,7 @@ from airflow.utils.operator_helpers import ExecutionCallableRunner, context_to_a
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import (
-    ExecutorConfigType,
-    ExtendedJSON,
-    UtcDateTime,
-    tuple_in_condition,
-)
+from airflow.utils.sqlalchemy import ExecutorConfigType, ExtendedJSON, 
UtcDateTime
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.task_group import MappedTaskGroup
 from airflow.utils.task_instance_session import 
set_current_task_instance_session
@@ -3497,7 +3493,7 @@ class TaskInstance(Base, LoggingMixin):
         if task_id_only:
             filters.append(cls.task_id.in_(task_id_only))
         if with_map_index:
-            filters.append(tuple_in_condition((cls.task_id, cls.map_index), 
with_map_index))
+            filters.append(tuple_(cls.task_id, 
cls.map_index).in_(with_map_index))
 
         if not filters:
             return false()
@@ -3675,8 +3671,8 @@ class TaskInstance(Base, LoggingMixin):
             AssetUniqueKey(name, uri)
             for name, uri in session.execute(
                 select(AssetActive.name, AssetActive.uri).where(
-                    tuple_in_condition(
-                        (AssetActive.name, AssetActive.uri), 
[attrs.astuple(key) for key in asset_unique_keys]
+                    tuple_(AssetActive.name, AssetActive.uri).in_(
+                        attrs.astuple(key) for key in asset_unique_keys
                     )
                 )
             )
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index 23597f25a95..917af7c1f1d 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -23,7 +23,7 @@ import datetime
 import logging
 from collections.abc import Generator, Iterable
 from importlib import metadata
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, Any
 
 from packaging import version
 from sqlalchemy import TIMESTAMP, PickleType, event, nullsfirst, tuple_
@@ -438,22 +438,6 @@ def is_lock_not_available_error(error: OperationalError):
     return False
 
 
-@overload
-def tuple_in_condition(
-    columns: tuple[ColumnElement, ...],
-    collection: Iterable[Any],
-) -> ColumnOperators: ...
-
-
-@overload
-def tuple_in_condition(
-    columns: tuple[ColumnElement, ...],
-    collection: Select,
-    *,
-    session: Session,
-) -> ColumnOperators: ...
-
-
 def tuple_in_condition(
     columns: tuple[ColumnElement, ...],
     collection: Iterable[Any] | Select,
@@ -463,46 +447,14 @@ def tuple_in_condition(
     """
     Generate a tuple-in-collection operator to use in ``.where()``.
 
-    For most SQL backends, this generates a simple ``([col, ...]) IN 
[condition]``
-    clause.
+    Kept for backward compatibility. Remove when providers drop support for
+    apache-airflow<3.0.
 
     :meta private:
     """
     return tuple_(*columns).in_(collection)
 
 
-@overload
-def tuple_not_in_condition(
-    columns: tuple[ColumnElement, ...],
-    collection: Iterable[Any],
-) -> ColumnOperators: ...
-
-
-@overload
-def tuple_not_in_condition(
-    columns: tuple[ColumnElement, ...],
-    collection: Select,
-    *,
-    session: Session,
-) -> ColumnOperators: ...
-
-
-def tuple_not_in_condition(
-    columns: tuple[ColumnElement, ...],
-    collection: Iterable[Any] | Select,
-    *,
-    session: Session | None = None,
-) -> ColumnOperators:
-    """
-    Generate a tuple-not-in-collection operator to use in ``.where()``.
-
-    This is similar to ``tuple_in_condition`` except generating ``NOT IN``.
-
-    :meta private:
-    """
-    return tuple_(*columns).not_in(collection)
-
-
 def get_orm_mapper():
     """Get the correct ORM mapper for the installed SQLAlchemy version."""
     import sqlalchemy.orm.mapper
diff --git a/airflow/www/utils.py b/airflow/www/utils.py
index 727139a9a6b..9c319424574 100644
--- a/airflow/www/utils.py
+++ b/airflow/www/utils.py
@@ -37,7 +37,7 @@ from markdown_it import MarkdownIt
 from markupsafe import Markup
 from pygments import highlight, lexers
 from pygments.formatters import HtmlFormatter
-from sqlalchemy import delete, func, select, types
+from sqlalchemy import delete, func, select, tuple_, types
 from sqlalchemy.ext.associationproxy import AssociationProxy
 
 from airflow.api_fastapi.app import get_auth_manager
@@ -49,7 +49,6 @@ from airflow.utils import timezone
 from airflow.utils.code_utils import get_python_source
 from airflow.utils.helpers import alchemy_to_dict
 from airflow.utils.json import WebEncoder
-from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import State, TaskInstanceState
 from airflow.www.forms import DateTimeWithTimezoneField
 from airflow.www.widgets import AirflowDateTimePickerWidget
@@ -867,12 +866,7 @@ class DagRunCustomSQLAInterface(CustomSQLAInterface):
 
     def delete_all(self, items: list[Model]) -> bool:
         self.session.execute(
-            delete(TI).where(
-                tuple_in_condition(
-                    (TI.dag_id, TI.run_id),
-                    ((x.dag_id, x.run_id) for x in items),
-                )
-            )
+            delete(TI).where(tuple_(TI.dag_id, TI.run_id).in_((x.dag_id, 
x.run_id) for x in items))
         )
         return super().delete_all(items)
 
diff --git a/providers/src/airflow/providers/standard/utils/sensor_helper.py b/providers/src/airflow/providers/standard/utils/sensor_helper.py
index 57d906da671..8c4524cba65 100644
--- a/providers/src/airflow/providers/standard/utils/sensor_helper.py
+++ b/providers/src/airflow/providers/standard/utils/sensor_helper.py
@@ -18,14 +18,14 @@ from __future__ import annotations
 
 from typing import TYPE_CHECKING, cast
 
-from sqlalchemy import func, select
+from sqlalchemy import func, select, tuple_
 
 from airflow.models import DagBag, DagRun, TaskInstance
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import tuple_in_condition
 
 if TYPE_CHECKING:
-    from sqlalchemy.orm import Query, Session
+    from sqlalchemy.orm import Session
+    from sqlalchemy.sql import Executable
 
 
 @provide_session
@@ -55,9 +55,7 @@ def _get_count(
     if external_task_ids:
         count = (
             session.scalar(
-                _count_query(TI, states, dttm_filter, external_dag_id, 
session).filter(
-                    TI.task_id.in_(external_task_ids)
-                )
+                _count_stmt(TI, states, dttm_filter, 
external_dag_id).where(TI.task_id.in_(external_task_ids))
             )
         ) / len(external_task_ids)
     elif external_task_group_id:
@@ -69,17 +67,17 @@ def _get_count(
         else:
             count = (
                 session.scalar(
-                    _count_query(TI, states, dttm_filter, external_dag_id, 
session).filter(
-                        tuple_in_condition((TI.task_id, TI.map_index), 
external_task_group_task_ids)
+                    _count_stmt(TI, states, dttm_filter, 
external_dag_id).where(
+                        tuple_(TI.task_id, 
TI.map_index).in_(external_task_group_task_ids)
                     )
                 )
             ) / len(external_task_group_task_ids)
     else:
-        count = session.scalar(_count_query(DR, states, dttm_filter, 
external_dag_id, session))
+        count = session.scalar(_count_stmt(DR, states, dttm_filter, 
external_dag_id))
     return cast(int, count)
 
 
-def _count_query(model, states, dttm_filter, external_dag_id, session: 
Session) -> Query:
+def _count_stmt(model, states, dttm_filter, external_dag_id) -> Executable:
     """
     Get the count of records against dttm filter and states.
 
@@ -87,12 +85,10 @@ def _count_query(model, states, dttm_filter, external_dag_id, session: Session)
     :param states: task or dag states
     :param dttm_filter: date time filter for logical date
     :param external_dag_id: The ID of the external DAG.
-    :param session: airflow session object
     """
-    query = select(func.count()).filter(
+    return select(func.count()).where(
         model.dag_id == external_dag_id, model.state.in_(states), 
model.logical_date.in_(dttm_filter)
     )
-    return query
 
 
 def _get_external_task_group_task_ids(dttm_filter, external_task_group_id, 
external_dag_id, session):

Reply via email to