This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 75b2b7d1ece Refactor/sqla2 multiple files (#59540)
75b2b7d1ece is described below

commit 75b2b7d1ece22b2429e33d39d3933cc84c530d92
Author: KUAN-HAO HUANG <[email protected]>
AuthorDate: Sat Dec 20 04:53:29 2025 +0800

    Refactor/sqla2 multiple files (#59540)
    
    * Refactor deprecated SQLA models/test_serialized_dag.py
    
    * Refactor deprecated SQLA models/test_pool.py
    
    * Refactor deprecated SQLA models/test_trigger.py
    
    * Refactor deprecated SQLA models/test_callback.py
    
    * Refactor deprecated SQLA models/test_xcom.py
    
    * Refactor deprecated SQLA models/test_cleartasks.py
    
    * Refactor deprecated SQLA models/test_dagrun.py
    
    * Refactor deprecated SQLA test_log_handlers.py
    
    * fix error
    
    * Refactor deprecated SQLA utils/test_state.py
    
    * fix error
    
    * remove redundant parts
    
    * change to where()
    
    * fix pre-commit error
    
    * fix error
    
    * remove redundant commit()
    
    ---------
    
    Co-authored-by: Jarek Potiuk <[email protected]>
---
 .pre-commit-config.yaml                            |   9 ++
 airflow-core/tests/unit/models/test_callback.py    |   5 +-
 airflow-core/tests/unit/models/test_cleartasks.py  |  37 +++---
 airflow-core/tests/unit/models/test_dagrun.py      | 126 +++++++++++----------
 airflow-core/tests/unit/models/test_pool.py        |   7 +-
 .../tests/unit/models/test_serialized_dag.py       |  45 ++++----
 airflow-core/tests/unit/models/test_trigger.py     |  69 ++++++-----
 airflow-core/tests/unit/models/test_xcom.py        |  15 +--
 airflow-core/tests/unit/utils/test_log_handlers.py |  21 ++--
 airflow-core/tests/unit/utils/test_state.py        |  11 +-
 10 files changed, 184 insertions(+), 161 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f417efb2333..572189651b3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -425,6 +425,13 @@ repos:
             ^airflow-ctl.*\.py$|
             ^airflow-core/src/airflow/models/.*\.py$|
             ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
+            ^airflow-core/tests/unit/models/test_serialized_dag.py$|
+            ^airflow-core/tests/unit/models/test_pool.py$|
+            ^airflow-core/tests/unit/models/test_trigger.py$|
+            ^airflow-core/tests/unit/models/test_callback.py$|
+            ^airflow-core/tests/unit/models/test_cleartasks.py$|
+            ^airflow-core/tests/unit/models/test_xcom.py$|
+            ^airflow-core/tests/unit/models/test_dagrun.py$|
             ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py$|
             ^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py$|
             ^airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py$|
@@ -444,6 +451,8 @@ repos:
             ^airflow-core/tests/unit/models/test_dagwarning.py$|
             ^airflow-core/tests/integration/otel/test_otel.py$|
             ^airflow-core/tests/unit/utils/test_db_cleanup.py$|
+            ^airflow-core/tests/unit/utils/test_state.py$|
+            ^airflow-core/tests/unit/utils/test_log_handlers.py$|
             ^dev/airflow_perf/scheduler_dag_execution_timing.py$|
             ^providers/celery/.*\.py$|
             ^providers/cncf/kubernetes/.*\.py$|
diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py
index 09d0931557d..dfc19fc61a3 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -17,6 +17,7 @@
 from __future__ import annotations
 
 import pytest
+from sqlalchemy import select
 
 from airflow.models import Trigger
 from airflow.models.callback import (
@@ -118,7 +119,7 @@ class TestTriggererCallback:
         session.add(callback)
         session.commit()
 
-        retrieved = session.query(Callback).filter_by(id=callback.id).one()
+        retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
         assert isinstance(retrieved, TriggererCallback)
         assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
         assert retrieved.data == TEST_ASYNC_CALLBACK.serialize()
@@ -188,7 +189,7 @@ class TestExecutorCallback:
         session.add(callback)
         session.commit()
 
-        retrieved = session.query(Callback).filter_by(id=callback.id).one()
+        retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
         assert isinstance(retrieved, ExecutorCallback)
         assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
         assert retrieved.data == TEST_SYNC_CALLBACK.serialize()
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 4d8063f8419..e92f57e741d 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -21,7 +21,7 @@ import datetime
 import random
 
 import pytest
-from sqlalchemy import select
+from sqlalchemy import func, select
 
 from airflow.models.dag_version import DagVersion
 from airflow.models.dagrun import DagRun
@@ -87,7 +87,7 @@ class TestClearTasks:
             # this is equivalent to topological sort. It would not work in general case
             # but it works for our case because we specifically constructed test DAGS
             # in the way that those two sort methods are equivalent
-            qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+            qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
             clear_task_instances(qry, session)
 
             ti0.refresh_from_db(session)
@@ -121,7 +121,7 @@ class TestClearTasks:
             # this is equivalent to topological sort. It would not work in general case
             # but it works for our case because we specifically constructed test DAGS
             # in the way that those two sort methods are equivalent
-            qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+            qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
             clear_task_instances(qry, session)
 
             ti0.refresh_from_db()
@@ -186,12 +186,12 @@ class TestClearTasks:
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
-        assert session.query(TaskInstanceHistory).count() == 0
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
+        assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 0
         clear_task_instances(qry, session, dag_run_state=state)
         session.flush()
         # 2 TIs were cleared so 2 history records should be created
-        assert session.query(TaskInstanceHistory).count() == 2
+        assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 2
 
         session.refresh(dr)
 
@@ -229,7 +229,7 @@ class TestClearTasks:
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
 
@@ -282,7 +282,7 @@ class TestClearTasks:
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
 
@@ -394,7 +394,7 @@ class TestClearTasks:
         # this is equivalent to topological sort. It would not work in general case
         # but it works for our case because we specifically constructed test DAGS
         # in the way that those two sort methods are equivalent
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
 
         ti0.refresh_from_db(session=session)
@@ -477,7 +477,9 @@ class TestClearTasks:
         with create_session() as session:
 
             def count_task_reschedule(ti):
-                return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count()
+                return session.scalar(
+                    select(func.count()).select_from(TaskReschedule).where(TaskReschedule.ti_id == ti.id)
+                )
 
             assert count_task_reschedule(ti0) == 1
             assert count_task_reschedule(ti1) == 1
@@ -485,12 +487,9 @@ class TestClearTasks:
             # this is equivalent to topological sort. It would not work in general case
             # but it works for our case because we specifically constructed test DAGS
             # in the way that those two sort methods are equivalent
-            qry = (
-                session.query(TI)
-                .filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id)
-                .order_by(TI.task_id)
-                .all()
-            )
+            qry = session.scalars(
+                select(TI).where(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).order_by(TI.task_id)
+            ).all()
             clear_task_instances(qry, session)
             assert count_task_reschedule(ti0) == 0
             assert count_task_reschedule(ti1) == 1
@@ -531,7 +530,7 @@ class TestClearTasks:
         ti1.state = state
         session = dag_maker.session
         session.flush()
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session)
         session.flush()
 
@@ -716,10 +715,10 @@ class TestClearTasks:
         new_dag_version = DagVersion.get_latest_version(dag.dag_id)
 
         assert old_dag_version.id != new_dag_version.id
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
         clear_task_instances(qry, session, run_on_latest_version=run_on_latest_version)
         session.commit()
-        dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one()
+        dr = session.scalar(select(DagRun).where(DagRun.dag_id == dag.dag_id))
         if run_on_latest_version:
             assert dr.created_dag_version_id == new_dag_version.id
             assert dr.bundle_version == new_dag_version.bundle_version
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index daeff282424..a09aceccc02 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,7 @@ from unittest.mock import call
 
 import pendulum
 import pytest
-from sqlalchemy import exists, select
+from sqlalchemy import exists, func, select
 from sqlalchemy.orm import joinedload
 
 from airflow import settings
@@ -155,10 +155,10 @@ class TestDagRun:
             EmptyOperator(task_id="backfill_task_0")
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
 
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
         assert dr0.state == state
         assert dr0.clear_number < 1
 
@@ -170,10 +170,10 @@ class TestDagRun:
             EmptyOperator(task_id="backfill_task_0")
         self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
 
-        qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+        qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
         clear_task_instances(qry, session)
         session.flush()
-        dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+        dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
         assert dr0.state == DagRunState.QUEUED
         assert dr0.clear_number == 1
 
@@ -721,7 +721,7 @@ class TestDagRun:
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -729,14 +729,14 @@ class TestDagRun:
         session.merge(dr)
         session.commit()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is None
 
         dr.set_state(DagRunState.FAILED)
         session.merge(dr)
         session.commit()
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
@@ -764,7 +764,7 @@ class TestDagRun:
 
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
 
@@ -772,7 +772,7 @@ class TestDagRun:
         ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr._state == DagRunState.RUNNING
         assert dr.end_date is None
@@ -782,7 +782,7 @@ class TestDagRun:
         ti_op2.set_state(state=TaskInstanceState.FAILED, session=session)
         dr.update_state()
 
-        dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+        dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
 
         assert dr_database.end_date is not None
         assert dr.end_date == dr_database.end_date
@@ -1216,7 +1216,7 @@ class TestDagRun:
             EmptyOperator(task_id="empty")
         dag_run = dag_maker.create_dagrun()
 
-        dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one()
+        dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions)))
         assert dag_run.dag_versions[0].id == dm.dag_versions[0].id
 
     def test_dag_run_version_number(self, dag_maker, session):
@@ -1231,7 +1231,7 @@ class TestDagRun:
         tis[1].dag_version = dag_v
         session.merge(tis[1])
         session.flush()
-        dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one()
+        dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id))
         # Check that dag_run.version_number returns the version number of
         # the latest task instance dag_version
         assert dag_run.version_number == dag_v.version_number
@@ -1337,14 +1337,14 @@ class TestDagRun:
         dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id)
         dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id)
 
-        assert session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
 
         session.add(dag_run1)
         dag_run1.update_state()
 
-        assert not session.query(dag_run1_deadline).scalar()
-        assert session.query(dag_run2_deadline).scalar()
+        assert not session.scalar(select(dag_run1_deadline))
+        assert session.scalar(select(dag_run2_deadline))
         assert dag_run1.state == DagRunState.SUCCESS
         assert dag_run2.state == DagRunState.RUNNING
 
@@ -1399,13 +1399,12 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session):
             mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)
 
         dr = dag_maker.create_dagrun()
-        indices = (
-            session.query(TI.map_index)
-            .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        indices = session.scalars(
+            select(TI.map_index)
+            .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
             .order_by(TI.map_index)
-            .all()
-        )
-        assert indices == [(0,), (1,), (2,), (3,)]
+        ).all()
+        assert indices == [0, 1, 2, 3]
 
 
 @pytest.mark.parametrize("is_noop", [True, False])
@@ -1422,13 +1421,12 @@ def test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
             mynameis.expand(arg=literal)
 
         dr = dag_maker.create_dagrun()
-        indices = (
-            session.query(TI.map_index)
-            .filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id)
+        indices = session.scalars(
+            select(TI.map_index)
+            .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
             .order_by(TI.map_index)
-            .all()
-        )
-        assert indices == [(0,), (1,), (2,), (3,)]
+        ).all()
+        assert indices == [0, 1, 2, 3]
 
 
 def test_mapped_literal_verify_integrity(dag_maker, session):
@@ -1444,7 +1442,7 @@ def test_mapped_literal_verify_integrity(dag_maker, session):
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1483,12 +1481,11 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [
         (0, TaskInstanceState.REMOVED),
@@ -1511,7 +1508,7 @@ def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1552,7 +1549,7 @@ def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
     dr = dag_maker.create_dagrun()
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1661,7 +1658,7 @@ def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker
     dr.task_instance_scheduling_decisions(session=session)
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1751,7 +1748,7 @@ def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_m
 
     query = (
         select(TI.map_index, TI.state)
-        .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+        .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
     indices = session.execute(query).all()
@@ -1786,17 +1783,17 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):
 
     dr = dag_maker.create_dagrun()
     query = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
     )
 
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
     # Verify_integrity shouldn't change the result now that the TIs exist
     dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
     dr.verify_integrity(dag_version_id=dag_version_id, session=session)
-    assert query.all() == [(-1, None)]
+    assert session.execute(query).all() == [(-1, None)]
 
 
 def test_mapped_task_group_expands_at_create(dag_maker, session):
@@ -1823,11 +1820,11 @@ def test_mapped_task_group_expands_at_create(dag_maker, session):
 
     dr = dag_maker.create_dagrun()
     query = (
-        session.query(TI.task_id, TI.map_index, TI.state)
-        .filter_by(dag_id=dr.dag_id, run_id=dr.run_id)
+        select(TI.task_id, TI.map_index, TI.state)
+        .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.task_id, TI.map_index)
     )
-    assert query.all() == [
+    assert session.execute(query).all() == [
         ("tg.t1", 0, None),
         ("tg.t1", 1, None),
         # ("tg.t2", 0, None),
@@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session):
     # expanded against a zero-length XCom.
     assert decision.finished_tis == [ti1, ti2]
 
-    indices = (
-        session.query(TI.map_index, TI.state)
-        .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+    indices = session.execute(
+        select(TI.map_index, TI.state)
+        .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
         .order_by(TI.map_index)
-        .all()
-    )
+    ).all()
 
     assert indices == [(-1, TaskInstanceState.SKIPPED)]
 
@@ -2576,8 +2572,14 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
 
     dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
     ti = dr1.get_task_instances()[0]
-    filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
-    ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
+    ti = session.scalar(
+        select(TaskInstance).where(
+            TaskInstance.dag_id == ti.dag_id,
+            TaskInstance.task_id == ti.task_id,
+            TaskInstance.run_id == ti.run_id,
+            TaskInstance.map_index == ti.map_index,
+        )
+    )
 
     tr = TaskReschedule(
         ti_id=ti.id,
@@ -2598,10 +2600,10 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
     XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
     session.commit()
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 1
+        assert session.scalar(select(func.count()).select_from(table)) == 1
     dr1.task_instance_scheduling_decisions(session)
     for table in [TaskInstanceNote, TaskReschedule, XComModel]:
-        assert session.query(table).count() == 0
+        assert session.scalar(select(func.count()).select_from(table)) == 0
 
 
 def test_dagrun_with_note(dag_maker, session):
@@ -2619,14 +2621,14 @@ def test_dagrun_with_note(dag_maker, session):
     session.add(dr)
     session.commit()
 
-    dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
+    dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id))
     assert dr_note.content == "dag run with note"
 
     session.delete(dr)
     session.commit()
 
-    assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
-    assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
+    assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None
+    assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None
 
 
 @pytest.mark.parametrize(
@@ -2655,7 +2657,7 @@ def test_teardown_failure_behaviour_on_dagrun(dag_maker, session, dag_run_state,
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
     assert dr.state == dag_run_state
 
 
@@ -2694,7 +2696,7 @@ def test_teardown_failure_on_non_leaf_behaviour_on_dagrun(
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
     assert dr.state == dag_run_state
 
 
@@ -2729,7 +2731,7 @@ def test_work_task_failure_when_setup_teardown_are_successful(dag_maker, session
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
     assert dr.state == DagRunState.FAILED
 
 
@@ -2765,7 +2767,7 @@ def test_failure_of_leaf_task_not_connected_to_teardown_task(dag_maker, session)
     session.flush()
     dr.update_state()
     session.flush()
-    dr = session.query(DagRun).one()
+    dr = session.scalar(select(DagRun))
     assert dr.state == DagRunState.FAILED
 
 
diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py
index fa59d85ffe8..a275bf2d4b2 100644
--- a/airflow-core/tests/unit/models/test_pool.py
+++ b/airflow-core/tests/unit/models/test_pool.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
 
 import pendulum
 import pytest
+from sqlalchemy import func, select
 
 from airflow import settings
 from airflow.exceptions import AirflowException, PoolNotFound
@@ -292,7 +293,7 @@ class TestPool:
         assert pool.slots == 5
         assert pool.description == ""
         assert pool.include_deferred is True
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT + 1
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT + 1
 
     def test_create_pool_existing(self, session):
         self.add_pools()
@@ -303,13 +304,13 @@ class TestPool:
         assert pool.slots == 5
         assert pool.description == ""
         assert pool.include_deferred is False
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT
 
     def test_delete_pool(self, session):
         self.add_pools()
         pool = Pool.delete_pool(name=self.pools[-1].pool)
         assert pool.pool == self.pools[-1].pool
-        assert session.query(Pool).count() == self.TOTAL_POOL_COUNT - 1
+        assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT - 1
 
     def test_delete_pool_non_existing(self):
         with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index 2472c072046..ddc98f72064 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -24,7 +24,7 @@ from unittest import mock
 
 import pendulum
 import pytest
-from sqlalchemy import func, select, update
+from sqlalchemy import delete, func, select, update
 
 import airflow.example_dags as example_dags_module
 from airflow.dag_processing.dagbag import DagBag
@@ -59,7 +59,12 @@ def make_example_dags(module):
     from airflow.utils.session import create_session
 
     with create_session() as session:
-        if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
+        if (
+            session.scalar(
+                select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
+            )
+            == 0
+        ):
             testing = DagBundleModel(name="testing")
             session.add(testing)
 
@@ -101,7 +106,7 @@ class TestSerializedDagModel:
         with create_session() as session:
             for dag in example_dags.values():
                 assert SDM.has_dag(dag.dag_id)
-                result = session.query(SDM).filter(SDM.dag_id == dag.dag_id).one()
+                result = session.scalar(select(SDM).where(SDM.dag_id == dag.dag_id))
 
                 assert result.dag_version.dag_code.fileloc == dag.fileloc
                 # Verifies JSON schema.
@@ -118,7 +123,7 @@ class TestSerializedDagModel:
         with dag_maker("dag1"):
             PythonOperator(task_id="task1", python_callable=lambda x: None)
         dag_maker.create_dagrun(run_id="test2", logical_date=pendulum.datetime(2025, 1, 1))
-        assert len(session.query(DagVersion).all()) == 2
+        assert len(session.scalars(select(DagVersion)).all()) == 2
 
         with dag_maker("dag2"):
 
@@ -136,7 +141,7 @@ class TestSerializedDagModel:
                 pass
 
             my_callable2()
-        assert len(session.query(DagVersion).all()) == 4
+        assert len(session.scalars(select(DagVersion)).all()) == 4
 
     def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle):
         """Test Serialized DAG is updated if DAG is changed"""
@@ -212,7 +217,7 @@ class TestSerializedDagModel:
         dag.doc_md = "new doc string"
         SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing")
         serialized_dags2 = SDM.read_all_dags()
-        sdags = session.query(SDM).all()
+        sdags = session.scalars(select(SDM)).all()
         # assert only the latest SDM is returned
         assert len(sdags) != len(serialized_dags2)
 
@@ -334,7 +339,7 @@ class TestSerializedDagModel:
     def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session):
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
-        assert session.query(SDM).count() == 1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
         sdm1 = SDM.get(dag.dag_id, session=session)
         dag_hash = sdm1.dag_hash
         created_at = sdm1.created_at
@@ -347,21 +352,21 @@ class TestSerializedDagModel:
         assert sdm2.dag_hash != dag_hash  # first recorded serdag
         assert sdm2.created_at == created_at
         assert sdm2.last_updated != last_updated
-        assert session.query(DagVersion).count() == 1
-        assert session.query(SDM).count() == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 
1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
 
     def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, 
dag_maker, session):
         with dag_maker("dag1") as dag:
             PythonOperator(task_id="task1", python_callable=lambda: None)
         dag_maker.create_dagrun(run_id="test3", 
logical_date=pendulum.datetime(2025, 1, 2))
-        assert session.query(SDM).count() == 1
-        assert session.query(DagVersion).count() == 1
+        assert session.scalar(select(func.count()).select_from(SDM)) == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 
1
         # new task
         PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
         SDM.write_dag(LazyDeserializedDAG.from_dag(dag), 
bundle_name="dag_maker")
 
-        assert session.query(DagVersion).count() == 2
-        assert session.query(SDM).count() == 2
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 
2
+        assert session.scalar(select(func.count()).select_from(SDM)) == 2
 
     def test_example_dag_sorting_serialised_dag(self, session):
         """
@@ -517,14 +522,14 @@ class TestSerializedDagModel:
         # Create TIs
         dag_maker.create_dagrun(run_id="test_run")
 
-        assert session.query(DagVersion).count() == 1
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 
1
 
         # Write the same DAG (no changes, so hash is the same) with a new 
bundle_name
         new_bundle = "bundleB"
         SDM.write_dag(LazyDeserializedDAG.from_dag(dag), 
bundle_name=new_bundle)
 
         # There should now be two versions of the DAG
-        assert session.query(DagVersion).count() == 2
+        assert session.scalar(select(func.count()).select_from(DagVersion)) == 
2
 
     def test_hash_method_removes_fileloc_and_remains_consistent(self):
         """Test that the hash method removes fileloc before hashing."""
@@ -632,7 +637,7 @@ class TestSerializedDagModel:
         assert dag_version is not None
 
         # Manually delete SerializedDagModel (simulates edge case)
-        session.query(SDM).filter(SDM.dag_id == "test_missing_serdag").delete()
+        session.execute(delete(SDM).where(SDM.dag_id == "test_missing_serdag"))
         session.commit()
 
         # Verify no SerializedDagModel exists
@@ -709,7 +714,9 @@ class TestSerializedDagModel:
             EmptyOperator(task_id="task1")
 
         dag = dag_maker.dag
-        initial_version_count = 
session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
+        initial_version_count = session.scalar(
+            
select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == 
dag.dag_id)
+        )
         assert initial_version_count == 1, "Should have one DagVersion after 
initial write"
         dag_maker.create_dagrun()  # ensure the second dag version is created
 
@@ -732,8 +739,8 @@ class TestSerializedDagModel:
             # Verify that no new DagVersion was committed
             # Use a fresh session to ensure we're reading from committed data
             with create_session() as fresh_session:
-                final_version_count = (
-                    fresh_session.query(DagVersion).filter(DagVersion.dag_id 
== dag.dag_id).count()
+                final_version_count = fresh_session.scalar(
+                    
select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == 
dag.dag_id)
                 )
                 assert final_version_count == initial_version_count, (
                     "DagVersion should not be committed when 
DagCode.write_code fails"
diff --git a/airflow-core/tests/unit/models/test_trigger.py 
b/airflow-core/tests/unit/models/test_trigger.py
index fe2fbeb6b98..df169852cb5 100644
--- a/airflow-core/tests/unit/models/test_trigger.py
+++ b/airflow-core/tests/unit/models/test_trigger.py
@@ -26,6 +26,7 @@ import pendulum
 import pytest
 import pytz
 from cryptography.fernet import Fernet
+from sqlalchemy import delete, func, select
 
 from airflow._shared.timezones import timezone
 from airflow.jobs.job import Job
@@ -61,21 +62,21 @@ def session():
 
 @pytest.fixture(autouse=True)
 def clear_db(session):
-    session.query(TaskInstance).delete()
-    session.query(AssetWatcherModel).delete()
-    session.query(Callback).delete()
-    session.query(Trigger).delete()
-    session.query(AssetModel).delete()
-    session.query(AssetEvent).delete()
-    session.query(Job).delete()
+    session.execute(delete(TaskInstance))
+    session.execute(delete(AssetWatcherModel))
+    session.execute(delete(Callback))
+    session.execute(delete(Trigger))
+    session.execute(delete(AssetModel))
+    session.execute(delete(AssetEvent))
+    session.execute(delete(Job))
     yield session
-    session.query(TaskInstance).delete()
-    session.query(AssetWatcherModel).delete()
-    session.query(Callback).delete()
-    session.query(Trigger).delete()
-    session.query(AssetModel).delete()
-    session.query(AssetEvent).delete()
-    session.query(Job).delete()
+    session.execute(delete(TaskInstance))
+    session.execute(delete(AssetWatcherModel))
+    session.execute(delete(Callback))
+    session.execute(delete(Trigger))
+    session.execute(delete(AssetModel))
+    session.execute(delete(AssetEvent))
+    session.execute(delete(Job))
     session.commit()
 
 
@@ -121,7 +122,7 @@ def test_clean_unused(session, create_task_instance):
     session.add(trigger5)
     session.add(trigger6)
     session.commit()
-    assert session.query(Trigger).count() == 6
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 6
     # Tie one to a fake TaskInstance that is not deferred, and one to one that 
is
     task_instance = create_task_instance(
         session=session, task_id="fake", state=State.DEFERRED, 
logical_date=timezone.utcnow()
@@ -150,7 +151,7 @@ def test_clean_unused(session, create_task_instance):
     asset.add_trigger(trigger5, "test_asset_watcher2")
     session.add(asset)
     session.commit()
-    assert session.query(AssetModel).count() == 1
+    assert session.scalar(select(func.count()).select_from(AssetModel)) == 1
 
     # Create callback with trigger
     callback = TriggererCallback(
@@ -162,7 +163,7 @@ def test_clean_unused(session, create_task_instance):
 
     # Run clear operation
     Trigger.clean_unused()
-    results = session.query(Trigger).all()
+    results = session.scalars(select(Trigger)).all()
     assert len(results) == 4
     assert {result.id for result in results} == {trigger1.id, trigger4.id, 
trigger5.id, trigger6.id}
 
@@ -196,7 +197,10 @@ def test_submit_event(mock_callback_handle_event, session, 
create_task_instance)
     session.commit()
 
     # Check that the asset has 0 event prior to sending an event to the trigger
-    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
+    assert (
+        
session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id
 == asset.id))
+        == 0
+    )
 
     # Create event
     payload = "payload"
@@ -210,8 +214,11 @@ def test_submit_event(mock_callback_handle_event, session, 
create_task_instance)
     assert task_instance.state == State.SCHEDULED
     assert task_instance.next_kwargs == {"event": payload, "cheesecake": True}
     # Check that the asset has received an event
-    assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
-    asset_event = 
session.query(AssetEvent).filter_by(asset_id=asset.id).first()
+    assert (
+        
session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id
 == asset.id))
+        == 1
+    )
+    asset_event = session.scalar(select(AssetEvent).where(AssetEvent.asset_id 
== asset.id))
     assert asset_event.extra == {"from_trigger": True, "payload": payload}
 
     # Check that the callback's handle_event was called
@@ -233,7 +240,7 @@ def test_submit_failure(session, create_task_instance):
     # Call submit_event
     Trigger.submit_failure(trigger.id, session=session)
     # Check that the task instance is now scheduled to fail
-    updated_task_instance = session.query(TaskInstance).one()
+    updated_task_instance = session.scalar(select(TaskInstance))
     assert updated_task_instance.state == State.SCHEDULED
     assert updated_task_instance.next_method == "__fail__"
 
@@ -272,7 +279,7 @@ def test_submit_event_task_end(mock_utcnow, session, 
create_task_instance, event
 
     # now for the real test
     # first check initial state
-    ti: TaskInstance = session.query(TaskInstance).one()
+    ti: TaskInstance = session.scalar(select(TaskInstance))
     assert ti.state == "deferred"
     assert get_xcoms(ti) == []
 
@@ -285,7 +292,7 @@ def test_submit_event_task_end(mock_utcnow, session, 
create_task_instance, event
     # commit changes made by submit event and expire all cache to read from db.
     session.flush()
     # Check that the task instance is now correct
-    ti = session.query(TaskInstance).one()
+    ti = session.scalar(select(TaskInstance))
     assert ti.state == expected
     assert ti.next_kwargs is None
     assert ti.end_date == now
@@ -370,26 +377,26 @@ def test_assign_unassigned(session, create_task_instance):
     session.add(ti_trigger_unassigned_to_triggerer)
     assert trigger_unassigned_to_triggerer.triggerer_id is None
     session.commit()
-    assert session.query(Trigger).count() == 4
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 4
     Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
     session.expire_all()
     # Check that trigger on killed triggerer and unassigned trigger are 
assigned to new triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == 
trigger_on_killed_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == 
trigger_on_killed_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
     assert (
-        session.query(Trigger).filter(Trigger.id == 
trigger_unassigned_to_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == 
trigger_unassigned_to_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
     # Check that trigger on healthy triggerer still assigned to existing 
triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == 
trigger_on_healthy_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == 
trigger_on_healthy_triggerer.id)).triggerer_id
         == healthy_triggerer.id
     )
     # Check that trigger on unhealthy triggerer is assigned to new triggerer
     assert (
-        session.query(Trigger).filter(Trigger.id == 
trigger_on_unhealthy_triggerer.id).one().triggerer_id
+        session.scalar(select(Trigger).where(Trigger.id == 
trigger_on_unhealthy_triggerer.id)).triggerer_id
         == new_triggerer.id
     )
 
@@ -453,7 +460,7 @@ def test_get_sorted_triggers_same_priority_weight(session, 
create_task_instance)
     )
     session.add(trigger_callback)
     session.commit()
-    assert session.query(Trigger).count() == 5
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 5
     # Create assets
     asset = AssetModel("test")
     asset.add_trigger(trigger_asset, "test_asset_watcher")
@@ -534,7 +541,7 @@ def 
test_get_sorted_triggers_different_priority_weights(session, create_task_ins
     session.add(TI_new)
 
     session.commit()
-    assert session.query(Trigger).count() == 5
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 5
 
     trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, 
alive_triggerer_ids=[], session=session)
 
@@ -605,7 +612,7 @@ def test_get_sorted_triggers_dont_starve_for_ha(session, 
create_task_instance):
         asset_triggers.append(trigger)
 
     session.commit()
-    assert session.query(Trigger).count() == 60
+    assert session.scalar(select(func.count()).select_from(Trigger)) == 60
 
     # Mock max_trigger_to_select_per_loop to 5 for testing
     with patch.object(Trigger, "max_trigger_to_select_per_loop", 5):
diff --git a/airflow-core/tests/unit/models/test_xcom.py 
b/airflow-core/tests/unit/models/test_xcom.py
index 1bc7105cd8a..acf7ad752bf 100644
--- a/airflow-core/tests/unit/models/test_xcom.py
+++ b/airflow-core/tests/unit/models/test_xcom.py
@@ -23,6 +23,7 @@ from unittest import mock
 from unittest.mock import MagicMock
 
 import pytest
+from sqlalchemy import delete, func, select
 
 from airflow._shared.timezones import timezone
 from airflow.configuration import conf
@@ -88,7 +89,7 @@ def task_instance_factory(request, session: Session):
 
         def cleanup_database():
             # This should also clear task instances by cascading.
-            session.query(DagRun).filter_by(id=run.id).delete()
+            session.execute(delete(DagRun).where(DagRun.id == run.id))
             session.commit()
 
         request.addfinalizer(cleanup_database)
@@ -384,7 +385,7 @@ class TestXComSet:
             run_id=task_instance.run_id,
             session=session,
         )
-        stored_xcoms = session.query(XComModel).all()
+        stored_xcoms = session.scalars(select(XComModel)).all()
         assert stored_xcoms[0].key == key
         assert isinstance(stored_xcoms[0].value, 
type(json.dumps(expected_value)))
         assert stored_xcoms[0].value == json.dumps(expected_value)
@@ -398,7 +399,7 @@ class TestXComSet:
 
     @pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
     def test_xcom_set_again_replace(self, session, task_instance):
-        assert session.query(XComModel).one().value == json.dumps({"key1": 
"value1"})
+        assert session.scalar(select(XComModel)).value == json.dumps({"key1": 
"value1"})
         XComModel.set(
             key="xcom_1",
             value={"key2": "value2"},
@@ -407,7 +408,7 @@ class TestXComSet:
             run_id=task_instance.run_id,
             session=session,
         )
-        assert session.query(XComModel).one().value == json.dumps({"key2": 
"value2"})
+        assert session.scalar(select(XComModel)).value == json.dumps({"key2": 
"value2"})
 
     def test_xcom_set_invalid_key(self, session, task_instance):
         """Test that setting an XCom with an invalid key raises a 
ValueError."""
@@ -440,14 +441,14 @@ class TestXComClear:
     @pytest.mark.usefixtures("setup_for_xcom_clear")
     @mock.patch("airflow.sdk.execution_time.xcom.XCom.purge")
     def test_xcom_clear(self, mock_purge, session, task_instance):
-        assert session.query(XComModel).count() == 1
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 1
         XComModel.clear(
             dag_id=task_instance.dag_id,
             task_id=task_instance.task_id,
             run_id=task_instance.run_id,
             session=session,
         )
-        assert session.query(XComModel).count() == 0
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 0
         # purge will not be done when we clear, will be handled in task sdk
         assert mock_purge.call_count == 0
 
@@ -459,7 +460,7 @@ class TestXComClear:
             run_id="different_run",
             session=session,
         )
-        assert session.query(XComModel).count() == 1
+        assert session.scalar(select(func.count()).select_from(XComModel)) == 1
 
 
 class TestXComRoundTrip:
diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py 
b/airflow-core/tests/unit/utils/test_log_handlers.py
index 30669ab0599..cbe1a61d1c6 100644
--- a/airflow-core/tests/unit/utils/test_log_handlers.py
+++ b/airflow-core/tests/unit/utils/test_log_handlers.py
@@ -36,6 +36,7 @@ import pytest
 from pydantic import TypeAdapter
 from pydantic.v1.utils import deep_update
 from requests.adapters import Response
+from sqlalchemy import delete, select
 
 from airflow import settings
 from airflow.config_templates.airflow_local_settings import 
DEFAULT_LOGGING_CONFIG
@@ -98,8 +99,8 @@ def cleanup_tables():
 class TestFileTaskLogHandler:
     def clean_up(self):
         with create_session() as session:
-            session.query(DagRun).delete()
-            session.query(TaskInstance).delete()
+            session.execute(delete(DagRun))
+            session.execute(delete(TaskInstance))
 
     def setup_method(self):
         settings.configure_logging()
@@ -781,16 +782,14 @@ class TestFilenameRendering:
         )
         TaskInstanceHistory.record_ti(ti, session=session)
         session.flush()
-        tih = (
-            session.query(TaskInstanceHistory)
-            .filter_by(
-                dag_id=ti.dag_id,
-                task_id=ti.task_id,
-                run_id=ti.run_id,
-                map_index=ti.map_index,
-                try_number=ti.try_number,
+        tih = session.scalar(
+            select(TaskInstanceHistory).where(
+                TaskInstanceHistory.dag_id == ti.dag_id,
+                TaskInstanceHistory.task_id == ti.task_id,
+                TaskInstanceHistory.run_id == ti.run_id,
+                TaskInstanceHistory.map_index == ti.map_index,
+                TaskInstanceHistory.try_number == ti.try_number,
             )
-            .one()
         )
         fth = FileTaskHandler("")
         rendered_ti = fth._render_filename(ti, ti.try_number, session=session)
diff --git a/airflow-core/tests/unit/utils/test_state.py 
b/airflow-core/tests/unit/utils/test_state.py
index 463f9433204..88a1925842e 100644
--- a/airflow-core/tests/unit/utils/test_state.py
+++ b/airflow-core/tests/unit/utils/test_state.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 from datetime import timedelta
 
 import pytest
+from sqlalchemy import select
 
 from airflow.models.dagrun import DagRun
 from airflow.sdk import DAG
@@ -58,22 +59,18 @@ def test_dagrun_state_enum_escape(testing_dag_bundle):
             triggered_by=DagRunTriggeredByType.TEST,
         )
 
-        query = session.query(
-            DagRun.dag_id,
-            DagRun.state,
-            DagRun.run_type,
-        ).filter(
+        stmt = select(DagRun.dag_id, DagRun.state, DagRun.run_type).where(
             DagRun.dag_id == dag.dag_id,
             # make sure enum value can be used in filter queries
             DagRun.state == DagRunState.QUEUED,
         )
-        assert str(query.statement.compile(compile_kwargs={"literal_binds": 
True})) == (
+        assert str(stmt.compile(compile_kwargs={"literal_binds": True})) == (
             "SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n"
             "FROM dag_run \n"
             "WHERE dag_run.dag_id = 'test_dagrun_state_enum_escape' AND 
dag_run.state = 'queued'"
         )
 
-        rows = query.all()
+        rows = session.execute(stmt).all()
         assert len(rows) == 1
         assert rows[0].dag_id == dag.dag_id
         # make sure value in db is stored as `queued`, not `DagRunType.QUEUED`

Reply via email to