This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 75b2b7d1ece Refactor/sqla2 multiple files (#59540)
75b2b7d1ece is described below
commit 75b2b7d1ece22b2429e33d39d3933cc84c530d92
Author: KUAN-HAO HUANG <[email protected]>
AuthorDate: Sat Dec 20 04:53:29 2025 +0800
Refactor/sqla2 multiple files (#59540)
* Refactor deprecated SQLA models/test_serialized_dag.py
* Refactor deprecated SQLA models/test_pool.py
* Refactor deprecated SQLA models/test_trigger.py
* Refactor deprecated SQLA models/test_callback.py
* Refactor deprecated SQLA models/test_xcom.py
* Refactor deprecated SQLA models/test_cleartasks.py
* Refactor deprecated SQLA models/test_dagrun.py
* Refactor deprecated SQLA test_log_handlers.py
* fix error
* Refactor deprecated SQLA utils/test_state.py
* fix error
* remove redundant parts
* change to where()
* fix pre-commit error
* fix error
* remove redundant commit()
---------
Co-authored-by: Jarek Potiuk <[email protected]>
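Every file in this change applies the same SQLAlchemy 2.0 migration: legacy Query-API calls
(session.query(...) with .filter()/.filter_by()/.count()/.delete()) are replaced by select(),
delete(), and func.count() constructs executed through session.scalars(), session.scalar(),
or session.execute(). As a rough before/after sketch of the pattern (illustrative only, using
the Pool model touched by this diff; these are not lines from the patch):

    from sqlalchemy import delete, func, select

    from airflow.models import Pool

    # Deprecated SQLAlchemy 1.x Query API:
    pools = session.query(Pool).filter(Pool.slots > 0).order_by(Pool.pool).all()
    n_pools = session.query(Pool).count()
    session.query(Pool).filter_by(pool="test_pool").delete()

    # SQLAlchemy 2.0 style used in the refactored tests:
    pools = session.scalars(select(Pool).where(Pool.slots > 0).order_by(Pool.pool)).all()
    n_pools = session.scalar(select(func.count()).select_from(Pool))
    session.execute(delete(Pool).where(Pool.pool == "test_pool"))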
---
.pre-commit-config.yaml | 9 ++
airflow-core/tests/unit/models/test_callback.py | 5 +-
airflow-core/tests/unit/models/test_cleartasks.py | 37 +++---
airflow-core/tests/unit/models/test_dagrun.py | 126 +++++++++++----------
airflow-core/tests/unit/models/test_pool.py | 7 +-
.../tests/unit/models/test_serialized_dag.py | 45 ++++----
airflow-core/tests/unit/models/test_trigger.py | 69 ++++++-----
airflow-core/tests/unit/models/test_xcom.py | 15 +--
airflow-core/tests/unit/utils/test_log_handlers.py | 21 ++--
airflow-core/tests/unit/utils/test_state.py | 11 +-
10 files changed, 184 insertions(+), 161 deletions(-)
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f417efb2333..572189651b3 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -425,6 +425,13 @@ repos:
^airflow-ctl.*\.py$|
^airflow-core/src/airflow/models/.*\.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_assets.py$|
+ ^airflow-core/tests/unit/models/test_serialized_dag.py$|
+ ^airflow-core/tests/unit/models/test_pool.py$|
+ ^airflow-core/tests/unit/models/test_trigger.py$|
+ ^airflow-core/tests/unit/models/test_callback.py$|
+ ^airflow-core/tests/unit/models/test_cleartasks.py$|
+ ^airflow-core/tests/unit/models/test_xcom.py$|
+ ^airflow-core/tests/unit/models/test_dagrun.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_dag_sources.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/public/test_hitl.py$|
^airflow-core/tests/unit/api_fastapi/core_api/routes/ui/test_assets.py$|
@@ -444,6 +451,8 @@ repos:
^airflow-core/tests/unit/models/test_dagwarning.py$|
^airflow-core/tests/integration/otel/test_otel.py$|
^airflow-core/tests/unit/utils/test_db_cleanup.py$|
+ ^airflow-core/tests/unit/utils/test_state.py$|
+ ^airflow-core/tests/unit/utils/test_log_handlers.py$|
^dev/airflow_perf/scheduler_dag_execution_timing.py$|
^providers/celery/.*\.py$|
^providers/cncf/kubernetes/.*\.py$|
diff --git a/airflow-core/tests/unit/models/test_callback.py b/airflow-core/tests/unit/models/test_callback.py
index 09d0931557d..dfc19fc61a3 100644
--- a/airflow-core/tests/unit/models/test_callback.py
+++ b/airflow-core/tests/unit/models/test_callback.py
@@ -17,6 +17,7 @@
from __future__ import annotations
import pytest
+from sqlalchemy import select
from airflow.models import Trigger
from airflow.models.callback import (
@@ -118,7 +119,7 @@ class TestTriggererCallback:
session.add(callback)
session.commit()
- retrieved = session.query(Callback).filter_by(id=callback.id).one()
+ retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
assert isinstance(retrieved, TriggererCallback)
assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
assert retrieved.data == TEST_ASYNC_CALLBACK.serialize()
@@ -188,7 +189,7 @@ class TestExecutorCallback:
session.add(callback)
session.commit()
- retrieved = session.query(Callback).filter_by(id=callback.id).one()
+ retrieved = session.scalar(select(Callback).where(Callback.id == callback.id))
assert isinstance(retrieved, ExecutorCallback)
assert retrieved.fetch_method == CallbackFetchMethod.IMPORT_PATH
assert retrieved.data == TEST_SYNC_CALLBACK.serialize()
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py b/airflow-core/tests/unit/models/test_cleartasks.py
index 4d8063f8419..e92f57e741d 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -21,7 +21,7 @@ import datetime
import random
import pytest
-from sqlalchemy import select
+from sqlalchemy import func, select
from airflow.models.dag_version import DagVersion
from airflow.models.dagrun import DagRun
@@ -87,7 +87,7 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
ti0.refresh_from_db(session)
@@ -121,7 +121,7 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
ti0.refresh_from_db()
@@ -186,12 +186,12 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
- assert session.query(TaskInstanceHistory).count() == 0
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
+ assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 0
clear_task_instances(qry, session, dag_run_state=state)
session.flush()
# 2 TIs were cleared so 2 history records should be created
- assert session.query(TaskInstanceHistory).count() == 2
+ assert session.scalar(select(func.count()).select_from(TaskInstanceHistory)) == 2
session.refresh(dr)
@@ -229,7 +229,7 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()
@@ -282,7 +282,7 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()
@@ -394,7 +394,7 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
ti0.refresh_from_db(session=session)
@@ -477,7 +477,9 @@ class TestClearTasks:
with create_session() as session:
def count_task_reschedule(ti):
- return session.query(TaskReschedule).filter(TaskReschedule.ti_id == ti.id).count()
+ return session.scalar(
+ select(func.count()).select_from(TaskReschedule).where(TaskReschedule.ti_id == ti.id)
+ )
assert count_task_reschedule(ti0) == 1
assert count_task_reschedule(ti1) == 1
@@ -485,12 +487,9 @@ class TestClearTasks:
# this is equivalent to topological sort. It would not work in general case
# but it works for our case because we specifically constructed test DAGS
# in the way that those two sort methods are equivalent
- qry = (
- session.query(TI)
- .filter(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id)
- .order_by(TI.task_id)
- .all()
- )
+ qry = session.scalars(
+ select(TI).where(TI.dag_id == dag.dag_id, TI.task_id == ti0.task_id).order_by(TI.task_id)
+ ).all()
clear_task_instances(qry, session)
assert count_task_reschedule(ti0) == 0
assert count_task_reschedule(ti1) == 1
@@ -531,7 +530,7 @@ class TestClearTasks:
ti1.state = state
session = dag_maker.session
session.flush()
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session)
session.flush()
@@ -716,10 +715,10 @@ class TestClearTasks:
new_dag_version = DagVersion.get_latest_version(dag.dag_id)
assert old_dag_version.id != new_dag_version.id
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).order_by(TI.task_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id).order_by(TI.task_id)).all()
clear_task_instances(qry, session, run_on_latest_version=run_on_latest_version)
session.commit()
- dr = session.query(DagRun).filter(DagRun.dag_id == dag.dag_id).one()
+ dr = session.scalar(select(DagRun).where(DagRun.dag_id == dag.dag_id))
if run_on_latest_version:
assert dr.created_dag_version_id == new_dag_version.id
assert dr.bundle_version == new_dag_version.bundle_version
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index daeff282424..a09aceccc02 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -27,7 +27,7 @@ from unittest.mock import call
import pendulum
import pytest
-from sqlalchemy import exists, select
+from sqlalchemy import exists, func, select
from sqlalchemy.orm import joinedload
from airflow import settings
@@ -155,10 +155,10 @@ class TestDagRun:
EmptyOperator(task_id="backfill_task_0")
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
clear_task_instances(qry, session)
session.flush()
- dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.logical_date == now).first()
+ dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id, DagRun.logical_date == now))
assert dr0.state == state
assert dr0.clear_number < 1
@@ -170,10 +170,10 @@ class TestDagRun:
EmptyOperator(task_id="backfill_task_0")
self.create_dag_run(dag, logical_date=now, is_backfill=True, state=state, session=session)
- qry = session.query(TI).filter(TI.dag_id == dag.dag_id).all()
+ qry = session.scalars(select(TI).where(TI.dag_id == dag.dag_id)).all()
clear_task_instances(qry, session)
session.flush()
- dr0 = session.query(DagRun).filter(DagRun.dag_id == dag_id,
DagRun.logical_date == now).first()
+ dr0 = session.scalar(select(DagRun).where(DagRun.dag_id == dag_id,
DagRun.logical_date == now))
assert dr0.state == DagRunState.QUEUED
assert dr0.clear_number == 1
@@ -721,7 +721,7 @@ class TestDagRun:
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -729,14 +729,14 @@ class TestDagRun:
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is None
dr.set_state(DagRunState.FAILED)
session.merge(dr)
session.commit()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -764,7 +764,7 @@ class TestDagRun:
dr.update_state()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -772,7 +772,7 @@ class TestDagRun:
ti_op2.set_state(state=TaskInstanceState.RUNNING, session=session)
dr.update_state()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr._state == DagRunState.RUNNING
assert dr.end_date is None
@@ -782,7 +782,7 @@ class TestDagRun:
ti_op2.set_state(state=TaskInstanceState.FAILED, session=session)
dr.update_state()
- dr_database = session.query(DagRun).filter(DagRun.run_id == dr.run_id).one()
+ dr_database = session.scalar(select(DagRun).where(DagRun.run_id == dr.run_id))
assert dr_database.end_date is not None
assert dr.end_date == dr_database.end_date
@@ -1216,7 +1216,7 @@ class TestDagRun:
EmptyOperator(task_id="empty")
dag_run = dag_maker.create_dagrun()
- dm = session.query(DagModel).options(joinedload(DagModel.dag_versions)).one()
+ dm = session.scalar(select(DagModel).options(joinedload(DagModel.dag_versions)))
assert dag_run.dag_versions[0].id == dm.dag_versions[0].id
def test_dag_run_version_number(self, dag_maker, session):
@@ -1231,7 +1231,7 @@ class TestDagRun:
tis[1].dag_version = dag_v
session.merge(tis[1])
session.flush()
- dag_run = session.query(DagRun).filter(DagRun.run_id == dag_run.run_id).one()
+ dag_run = session.scalar(select(DagRun).where(DagRun.run_id == dag_run.run_id))
# Check that dag_run.version_number returns the version number of
# the latest task instance dag_version
assert dag_run.version_number == dag_v.version_number
@@ -1337,14 +1337,14 @@ class TestDagRun:
dag_run1_deadline = exists().where(Deadline.dagrun_id == dag_run1.id)
dag_run2_deadline = exists().where(Deadline.dagrun_id == dag_run2.id)
- assert session.query(dag_run1_deadline).scalar()
- assert session.query(dag_run2_deadline).scalar()
+ assert session.scalar(select(dag_run1_deadline))
+ assert session.scalar(select(dag_run2_deadline))
session.add(dag_run1)
dag_run1.update_state()
- assert not session.query(dag_run1_deadline).scalar()
- assert session.query(dag_run2_deadline).scalar()
+ assert not session.scalar(select(dag_run1_deadline))
+ assert session.scalar(select(dag_run2_deadline))
assert dag_run1.state == DagRunState.SUCCESS
assert dag_run2.state == DagRunState.RUNNING
@@ -1399,13 +1399,12 @@ def test_expand_mapped_task_instance_at_create(is_noop, dag_maker, session):
mapped = MockOperator.partial(task_id="task_2").expand(arg2=literal)
dr = dag_maker.create_dagrun()
- indices = (
- session.query(TI.map_index)
- .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+ indices = session.scalars(
+ select(TI.map_index)
+ .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
- .all()
- )
- assert indices == [(0,), (1,), (2,), (3,)]
+ ).all()
+ assert indices == [0, 1, 2, 3]
@pytest.mark.parametrize("is_noop", [True, False])
@@ -1422,13 +1421,12 @@ def test_expand_mapped_task_instance_task_decorator(is_noop, dag_maker, session)
mynameis.expand(arg=literal)
dr = dag_maker.create_dagrun()
- indices = (
- session.query(TI.map_index)
- .filter_by(task_id="mynameis", dag_id=dr.dag_id, run_id=dr.run_id)
+ indices = session.scalars(
+ select(TI.map_index)
+ .where(TI.task_id == "mynameis", TI.dag_id == dr.dag_id, TI.run_id
== dr.run_id)
.order_by(TI.map_index)
- .all()
- )
- assert indices == [(0,), (1,), (2,), (3,)]
+ ).all()
+ assert indices == [0, 1, 2, 3]
def test_mapped_literal_verify_integrity(dag_maker, session):
@@ -1444,7 +1442,7 @@ def test_mapped_literal_verify_integrity(dag_maker, session):
query = (
select(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1483,12 +1481,11 @@ def test_mapped_literal_to_xcom_arg_verify_integrity(dag_maker, session):
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
dr.verify_integrity(dag_version_id=dag_version_id, session=session)
- indices = (
- session.query(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ indices = session.execute(
+ select(TI.map_index, TI.state)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
- .all()
- )
+ ).all()
assert indices == [
(0, TaskInstanceState.REMOVED),
@@ -1511,7 +1508,7 @@ def test_mapped_literal_length_increase_adds_additional_ti(dag_maker, session):
query = (
select(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1552,7 +1549,7 @@ def test_mapped_literal_length_reduction_adds_removed_state(dag_maker, session):
dr = dag_maker.create_dagrun()
query = (
select(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1661,7 +1658,7 @@ def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker
dr.task_instance_scheduling_decisions(session=session)
query = (
select(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1751,7 +1748,7 @@ def test_calls_to_verify_integrity_with_mapped_task_zero_length_at_runtime(dag_m
query = (
select(TI.map_index, TI.state)
- .filter_by(task_id="task_2", dag_id=dr.dag_id, run_id=dr.run_id)
+ .where(TI.task_id == "task_2", TI.dag_id == dr.dag_id, TI.run_id ==
dr.run_id)
.order_by(TI.map_index)
)
indices = session.execute(query).all()
@@ -1786,17 +1783,17 @@ def test_mapped_mixed_literal_not_expanded_at_create(dag_maker, session):
dr = dag_maker.create_dagrun()
query = (
- session.query(TI.map_index, TI.state)
- .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+ select(TI.map_index, TI.state)
+ .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
)
- assert query.all() == [(-1, None)]
+ assert session.execute(query).all() == [(-1, None)]
# Verify_integrity shouldn't change the result now that the TIs exist
dag_version_id = DagVersion.get_latest_version(dag_id=dr.dag_id, session=session).id
dr.verify_integrity(dag_version_id=dag_version_id, session=session)
- assert query.all() == [(-1, None)]
+ assert session.execute(query).all() == [(-1, None)]
def test_mapped_task_group_expands_at_create(dag_maker, session):
@@ -1823,11 +1820,11 @@ def test_mapped_task_group_expands_at_create(dag_maker, session):
dr = dag_maker.create_dagrun()
query = (
- session.query(TI.task_id, TI.map_index, TI.state)
- .filter_by(dag_id=dr.dag_id, run_id=dr.run_id)
+ select(TI.task_id, TI.map_index, TI.state)
+ .where(TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
.order_by(TI.task_id, TI.map_index)
)
- assert query.all() == [
+ assert session.execute(query).all() == [
("tg.t1", 0, None),
("tg.t1", 1, None),
# ("tg.t2", 0, None),
@@ -1904,12 +1901,11 @@ def test_ti_scheduling_mapped_zero_length(dag_maker, session):
# expanded against a zero-length XCom.
assert decision.finished_tis == [ti1, ti2]
- indices = (
- session.query(TI.map_index, TI.state)
- .filter_by(task_id=mapped.task_id, dag_id=mapped.dag_id, run_id=dr.run_id)
+ indices = session.execute(
+ select(TI.map_index, TI.state)
+ .where(TI.task_id == mapped.task_id, TI.dag_id == mapped.dag_id, TI.run_id == dr.run_id)
.order_by(TI.map_index)
- .all()
- )
+ ).all()
assert indices == [(-1, TaskInstanceState.SKIPPED)]
@@ -2576,8 +2572,14 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
dr1: DagRun = dag_maker.create_dagrun(run_type=DagRunType.SCHEDULED)
ti = dr1.get_task_instances()[0]
- filter_kwargs = dict(dag_id=ti.dag_id, task_id=ti.task_id, run_id=ti.run_id, map_index=ti.map_index)
- ti = session.query(TaskInstance).filter_by(**filter_kwargs).one()
+ ti = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
+ )
+ )
tr = TaskReschedule(
ti_id=ti.id,
@@ -2598,10 +2600,10 @@ def test_clearing_task_and_moving_from_non_mapped_to_mapped(dag_maker, session):
XComModel.set(key="test", value="value", task_id=ti.task_id, dag_id=dag.dag_id, run_id=ti.run_id)
session.commit()
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
- assert session.query(table).count() == 1
+ assert session.scalar(select(func.count()).select_from(table)) == 1
dr1.task_instance_scheduling_decisions(session)
for table in [TaskInstanceNote, TaskReschedule, XComModel]:
- assert session.query(table).count() == 0
+ assert session.scalar(select(func.count()).select_from(table)) == 0
def test_dagrun_with_note(dag_maker, session):
@@ -2619,14 +2621,14 @@ def test_dagrun_with_note(dag_maker, session):
session.add(dr)
session.commit()
- dr_note = session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one()
+ dr_note = session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id))
assert dr_note.content == "dag run with note"
session.delete(dr)
session.commit()
- assert session.query(DagRun).filter(DagRun.id == dr.id).one_or_none() is None
- assert session.query(DagRunNote).filter(DagRunNote.dag_run_id == dr.id).one_or_none() is None
+ assert session.scalar(select(DagRun).where(DagRun.id == dr.id)) is None
+ assert session.scalar(select(DagRunNote).where(DagRunNote.dag_run_id == dr.id)) is None
@pytest.mark.parametrize(
@@ -2655,7 +2657,7 @@ def test_teardown_failure_behaviour_on_dagrun(dag_maker, session, dag_run_state,
session.flush()
dr.update_state()
session.flush()
- dr = session.query(DagRun).one()
+ dr = session.scalar(select(DagRun))
assert dr.state == dag_run_state
@@ -2694,7 +2696,7 @@ def test_teardown_failure_on_non_leaf_behaviour_on_dagrun(
session.flush()
dr.update_state()
session.flush()
- dr = session.query(DagRun).one()
+ dr = session.scalar(select(DagRun))
assert dr.state == dag_run_state
@@ -2729,7 +2731,7 @@ def test_work_task_failure_when_setup_teardown_are_successful(dag_maker, session
session.flush()
dr.update_state()
session.flush()
- dr = session.query(DagRun).one()
+ dr = session.scalar(select(DagRun))
assert dr.state == DagRunState.FAILED
@@ -2765,7 +2767,7 @@ def test_failure_of_leaf_task_not_connected_to_teardown_task(dag_maker, session)
session.flush()
dr.update_state()
session.flush()
- dr = session.query(DagRun).one()
+ dr = session.scalar(select(DagRun))
assert dr.state == DagRunState.FAILED
diff --git a/airflow-core/tests/unit/models/test_pool.py b/airflow-core/tests/unit/models/test_pool.py
index fa59d85ffe8..a275bf2d4b2 100644
--- a/airflow-core/tests/unit/models/test_pool.py
+++ b/airflow-core/tests/unit/models/test_pool.py
@@ -21,6 +21,7 @@ from typing import TYPE_CHECKING
import pendulum
import pytest
+from sqlalchemy import func, select
from airflow import settings
from airflow.exceptions import AirflowException, PoolNotFound
@@ -292,7 +293,7 @@ class TestPool:
assert pool.slots == 5
assert pool.description == ""
assert pool.include_deferred is True
- assert session.query(Pool).count() == self.TOTAL_POOL_COUNT + 1
+ assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT + 1
def test_create_pool_existing(self, session):
self.add_pools()
@@ -303,13 +304,13 @@ class TestPool:
assert pool.slots == 5
assert pool.description == ""
assert pool.include_deferred is False
- assert session.query(Pool).count() == self.TOTAL_POOL_COUNT
+ assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT
def test_delete_pool(self, session):
self.add_pools()
pool = Pool.delete_pool(name=self.pools[-1].pool)
assert pool.pool == self.pools[-1].pool
- assert session.query(Pool).count() == self.TOTAL_POOL_COUNT - 1
+ assert session.scalar(select(func.count()).select_from(Pool)) == self.TOTAL_POOL_COUNT - 1
def test_delete_pool_non_existing(self):
with pytest.raises(PoolNotFound, match="^Pool 'test' doesn't exist$"):
diff --git a/airflow-core/tests/unit/models/test_serialized_dag.py b/airflow-core/tests/unit/models/test_serialized_dag.py
index 2472c072046..ddc98f72064 100644
--- a/airflow-core/tests/unit/models/test_serialized_dag.py
+++ b/airflow-core/tests/unit/models/test_serialized_dag.py
@@ -24,7 +24,7 @@ from unittest import mock
import pendulum
import pytest
-from sqlalchemy import func, select, update
+from sqlalchemy import delete, func, select, update
import airflow.example_dags as example_dags_module
from airflow.dag_processing.dagbag import DagBag
@@ -59,7 +59,12 @@ def make_example_dags(module):
from airflow.utils.session import create_session
with create_session() as session:
- if session.query(DagBundleModel).filter(DagBundleModel.name == "testing").count() == 0:
+ if (
+ session.scalar(
+ select(func.count()).select_from(DagBundleModel).where(DagBundleModel.name == "testing")
+ )
+ == 0
+ ):
testing = DagBundleModel(name="testing")
session.add(testing)
@@ -101,7 +106,7 @@ class TestSerializedDagModel:
with create_session() as session:
for dag in example_dags.values():
assert SDM.has_dag(dag.dag_id)
- result = session.query(SDM).filter(SDM.dag_id == dag.dag_id).one()
+ result = session.scalar(select(SDM).where(SDM.dag_id == dag.dag_id))
assert result.dag_version.dag_code.fileloc == dag.fileloc
# Verifies JSON schema.
@@ -118,7 +123,7 @@ class TestSerializedDagModel:
with dag_maker("dag1"):
PythonOperator(task_id="task1", python_callable=lambda x: None)
dag_maker.create_dagrun(run_id="test2",
logical_date=pendulum.datetime(2025, 1, 1))
- assert len(session.query(DagVersion).all()) == 2
+ assert len(session.scalars(select(DagVersion)).all()) == 2
with dag_maker("dag2"):
@@ -136,7 +141,7 @@ class TestSerializedDagModel:
pass
my_callable2()
- assert len(session.query(DagVersion).all()) == 4
+ assert len(session.scalars(select(DagVersion)).all()) == 4
def test_serialized_dag_is_updated_if_dag_is_changed(self, testing_dag_bundle):
"""Test Serialized DAG is updated if DAG is changed"""
@@ -212,7 +217,7 @@ class TestSerializedDagModel:
dag.doc_md = "new doc string"
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="testing")
serialized_dags2 = SDM.read_all_dags()
- sdags = session.query(SDM).all()
+ sdags = session.scalars(select(SDM)).all()
# assert only the latest SDM is returned
assert len(sdags) != len(serialized_dags2)
@@ -334,7 +339,7 @@ class TestSerializedDagModel:
def test_new_dag_versions_are_not_created_if_no_dagruns(self, dag_maker, session):
with dag_maker("dag1") as dag:
PythonOperator(task_id="task1", python_callable=lambda: None)
- assert session.query(SDM).count() == 1
+ assert session.scalar(select(func.count()).select_from(SDM)) == 1
sdm1 = SDM.get(dag.dag_id, session=session)
dag_hash = sdm1.dag_hash
created_at = sdm1.created_at
@@ -347,21 +352,21 @@ class TestSerializedDagModel:
assert sdm2.dag_hash != dag_hash # first recorded serdag
assert sdm2.created_at == created_at
assert sdm2.last_updated != last_updated
- assert session.query(DagVersion).count() == 1
- assert session.query(SDM).count() == 1
+ assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
+ assert session.scalar(select(func.count()).select_from(SDM)) == 1
def test_new_dag_versions_are_created_if_there_is_a_dagrun(self, dag_maker, session):
with dag_maker("dag1") as dag:
PythonOperator(task_id="task1", python_callable=lambda: None)
dag_maker.create_dagrun(run_id="test3",
logical_date=pendulum.datetime(2025, 1, 2))
- assert session.query(SDM).count() == 1
- assert session.query(DagVersion).count() == 1
+ assert session.scalar(select(func.count()).select_from(SDM)) == 1
+ assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
# new task
PythonOperator(task_id="task2", python_callable=lambda: None, dag=dag)
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name="dag_maker")
- assert session.query(DagVersion).count() == 2
- assert session.query(SDM).count() == 2
+ assert session.scalar(select(func.count()).select_from(DagVersion)) == 2
+ assert session.scalar(select(func.count()).select_from(SDM)) == 2
def test_example_dag_sorting_serialised_dag(self, session):
"""
@@ -517,14 +522,14 @@ class TestSerializedDagModel:
# Create TIs
dag_maker.create_dagrun(run_id="test_run")
- assert session.query(DagVersion).count() == 1
+ assert session.scalar(select(func.count()).select_from(DagVersion)) == 1
# Write the same DAG (no changes, so hash is the same) with a new bundle_name
new_bundle = "bundleB"
SDM.write_dag(LazyDeserializedDAG.from_dag(dag), bundle_name=new_bundle)
# There should now be two versions of the DAG
- assert session.query(DagVersion).count() == 2
+ assert session.scalar(select(func.count()).select_from(DagVersion)) == 2
def test_hash_method_removes_fileloc_and_remains_consistent(self):
"""Test that the hash method removes fileloc before hashing."""
@@ -632,7 +637,7 @@ class TestSerializedDagModel:
assert dag_version is not None
# Manually delete SerializedDagModel (simulates edge case)
- session.query(SDM).filter(SDM.dag_id == "test_missing_serdag").delete()
+ session.execute(delete(SDM).where(SDM.dag_id == "test_missing_serdag"))
session.commit()
# Verify no SerializedDagModel exists
@@ -709,7 +714,9 @@ class TestSerializedDagModel:
EmptyOperator(task_id="task1")
dag = dag_maker.dag
- initial_version_count = session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
+ initial_version_count = session.scalar(
+ select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
+ )
assert initial_version_count == 1, "Should have one DagVersion after initial write"
dag_maker.create_dagrun() # ensure the second dag version is created
@@ -732,8 +739,8 @@ class TestSerializedDagModel:
# Verify that no new DagVersion was committed
# Use a fresh session to ensure we're reading from committed data
with create_session() as fresh_session:
- final_version_count = (
- fresh_session.query(DagVersion).filter(DagVersion.dag_id == dag.dag_id).count()
+ final_version_count = fresh_session.scalar(
+ select(func.count()).select_from(DagVersion).where(DagVersion.dag_id == dag.dag_id)
)
assert final_version_count == initial_version_count, (
"DagVersion should not be committed when
DagCode.write_code fails"
diff --git a/airflow-core/tests/unit/models/test_trigger.py b/airflow-core/tests/unit/models/test_trigger.py
index fe2fbeb6b98..df169852cb5 100644
--- a/airflow-core/tests/unit/models/test_trigger.py
+++ b/airflow-core/tests/unit/models/test_trigger.py
@@ -26,6 +26,7 @@ import pendulum
import pytest
import pytz
from cryptography.fernet import Fernet
+from sqlalchemy import delete, func, select
from airflow._shared.timezones import timezone
from airflow.jobs.job import Job
@@ -61,21 +62,21 @@ def session():
@pytest.fixture(autouse=True)
def clear_db(session):
- session.query(TaskInstance).delete()
- session.query(AssetWatcherModel).delete()
- session.query(Callback).delete()
- session.query(Trigger).delete()
- session.query(AssetModel).delete()
- session.query(AssetEvent).delete()
- session.query(Job).delete()
+ session.execute(delete(TaskInstance))
+ session.execute(delete(AssetWatcherModel))
+ session.execute(delete(Callback))
+ session.execute(delete(Trigger))
+ session.execute(delete(AssetModel))
+ session.execute(delete(AssetEvent))
+ session.execute(delete(Job))
yield session
- session.query(TaskInstance).delete()
- session.query(AssetWatcherModel).delete()
- session.query(Callback).delete()
- session.query(Trigger).delete()
- session.query(AssetModel).delete()
- session.query(AssetEvent).delete()
- session.query(Job).delete()
+ session.execute(delete(TaskInstance))
+ session.execute(delete(AssetWatcherModel))
+ session.execute(delete(Callback))
+ session.execute(delete(Trigger))
+ session.execute(delete(AssetModel))
+ session.execute(delete(AssetEvent))
+ session.execute(delete(Job))
session.commit()
@@ -121,7 +122,7 @@ def test_clean_unused(session, create_task_instance):
session.add(trigger5)
session.add(trigger6)
session.commit()
- assert session.query(Trigger).count() == 6
+ assert session.scalar(select(func.count()).select_from(Trigger)) == 6
# Tie one to a fake TaskInstance that is not deferred, and one to one that is
task_instance = create_task_instance(
session=session, task_id="fake", state=State.DEFERRED, logical_date=timezone.utcnow()
@@ -150,7 +151,7 @@ def test_clean_unused(session, create_task_instance):
asset.add_trigger(trigger5, "test_asset_watcher2")
session.add(asset)
session.commit()
- assert session.query(AssetModel).count() == 1
+ assert session.scalar(select(func.count()).select_from(AssetModel)) == 1
# Create callback with trigger
callback = TriggererCallback(
@@ -162,7 +163,7 @@ def test_clean_unused(session, create_task_instance):
# Run clear operation
Trigger.clean_unused()
- results = session.query(Trigger).all()
+ results = session.scalars(select(Trigger)).all()
assert len(results) == 4
assert {result.id for result in results} == {trigger1.id, trigger4.id, trigger5.id, trigger6.id}
@@ -196,7 +197,10 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
session.commit()
# Check that the asset has 0 event prior to sending an event to the trigger
- assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 0
+ assert (
+ session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
+ == 0
+ )
# Create event
payload = "payload"
@@ -210,8 +214,11 @@ def test_submit_event(mock_callback_handle_event, session, create_task_instance)
assert task_instance.state == State.SCHEDULED
assert task_instance.next_kwargs == {"event": payload, "cheesecake": True}
# Check that the asset has received an event
- assert session.query(AssetEvent).filter_by(asset_id=asset.id).count() == 1
- asset_event = session.query(AssetEvent).filter_by(asset_id=asset.id).first()
+ assert (
+ session.scalar(select(func.count()).select_from(AssetEvent).where(AssetEvent.asset_id == asset.id))
+ == 1
+ )
+ asset_event = session.scalar(select(AssetEvent).where(AssetEvent.asset_id == asset.id))
assert asset_event.extra == {"from_trigger": True, "payload": payload}
# Check that the callback's handle_event was called
@@ -233,7 +240,7 @@ def test_submit_failure(session, create_task_instance):
# Call submit_event
Trigger.submit_failure(trigger.id, session=session)
# Check that the task instance is now scheduled to fail
- updated_task_instance = session.query(TaskInstance).one()
+ updated_task_instance = session.scalar(select(TaskInstance))
assert updated_task_instance.state == State.SCHEDULED
assert updated_task_instance.next_method == "__fail__"
@@ -272,7 +279,7 @@ def test_submit_event_task_end(mock_utcnow, session, create_task_instance, event
# now for the real test
# first check initial state
- ti: TaskInstance = session.query(TaskInstance).one()
+ ti: TaskInstance = session.scalar(select(TaskInstance))
assert ti.state == "deferred"
assert get_xcoms(ti) == []
@@ -285,7 +292,7 @@ def test_submit_event_task_end(mock_utcnow, session, create_task_instance, event
# commit changes made by submit event and expire all cache to read from db.
session.flush()
# Check that the task instance is now correct
- ti = session.query(TaskInstance).one()
+ ti = session.scalar(select(TaskInstance))
assert ti.state == expected
assert ti.next_kwargs is None
assert ti.end_date == now
@@ -370,26 +377,26 @@ def test_assign_unassigned(session, create_task_instance):
session.add(ti_trigger_unassigned_to_triggerer)
assert trigger_unassigned_to_triggerer.triggerer_id is None
session.commit()
- assert session.query(Trigger).count() == 4
+ assert session.scalar(select(func.count()).select_from(Trigger)) == 4
Trigger.assign_unassigned(new_triggerer.id, 100, health_check_threshold=30)
session.expire_all()
# Check that trigger on killed triggerer and unassigned trigger are assigned to new triggerer
assert (
- session.query(Trigger).filter(Trigger.id == trigger_on_killed_triggerer.id).one().triggerer_id
+ session.scalar(select(Trigger).where(Trigger.id == trigger_on_killed_triggerer.id)).triggerer_id
== new_triggerer.id
)
assert (
- session.query(Trigger).filter(Trigger.id == trigger_unassigned_to_triggerer.id).one().triggerer_id
+ session.scalar(select(Trigger).where(Trigger.id == trigger_unassigned_to_triggerer.id)).triggerer_id
== new_triggerer.id
)
# Check that trigger on healthy triggerer still assigned to existing triggerer
assert (
- session.query(Trigger).filter(Trigger.id == trigger_on_healthy_triggerer.id).one().triggerer_id
+ session.scalar(select(Trigger).where(Trigger.id == trigger_on_healthy_triggerer.id)).triggerer_id
== healthy_triggerer.id
)
# Check that trigger on unhealthy triggerer is assigned to new triggerer
assert (
- session.query(Trigger).filter(Trigger.id == trigger_on_unhealthy_triggerer.id).one().triggerer_id
+ session.scalar(select(Trigger).where(Trigger.id == trigger_on_unhealthy_triggerer.id)).triggerer_id
== new_triggerer.id
)
@@ -453,7 +460,7 @@ def test_get_sorted_triggers_same_priority_weight(session, create_task_instance)
)
session.add(trigger_callback)
session.commit()
- assert session.query(Trigger).count() == 5
+ assert session.scalar(select(func.count()).select_from(Trigger)) == 5
# Create assets
asset = AssetModel("test")
asset.add_trigger(trigger_asset, "test_asset_watcher")
@@ -534,7 +541,7 @@ def test_get_sorted_triggers_different_priority_weights(session, create_task_ins
session.add(TI_new)
session.commit()
- assert session.query(Trigger).count() == 5
+ assert session.scalar(select(func.count()).select_from(Trigger)) == 5
trigger_ids_query = Trigger.get_sorted_triggers(capacity=100, alive_triggerer_ids=[], session=session)
@@ -605,7 +612,7 @@ def test_get_sorted_triggers_dont_starve_for_ha(session, create_task_instance):
asset_triggers.append(trigger)
session.commit()
- assert session.query(Trigger).count() == 60
+ assert session.scalar(select(func.count()).select_from(Trigger)) == 60
# Mock max_trigger_to_select_per_loop to 5 for testing
with patch.object(Trigger, "max_trigger_to_select_per_loop", 5):
diff --git a/airflow-core/tests/unit/models/test_xcom.py b/airflow-core/tests/unit/models/test_xcom.py
index 1bc7105cd8a..acf7ad752bf 100644
--- a/airflow-core/tests/unit/models/test_xcom.py
+++ b/airflow-core/tests/unit/models/test_xcom.py
@@ -23,6 +23,7 @@ from unittest import mock
from unittest.mock import MagicMock
import pytest
+from sqlalchemy import delete, func, select
from airflow._shared.timezones import timezone
from airflow.configuration import conf
@@ -88,7 +89,7 @@ def task_instance_factory(request, session: Session):
def cleanup_database():
# This should also clear task instances by cascading.
- session.query(DagRun).filter_by(id=run.id).delete()
+ session.execute(delete(DagRun).where(DagRun.id == run.id))
session.commit()
request.addfinalizer(cleanup_database)
@@ -384,7 +385,7 @@ class TestXComSet:
run_id=task_instance.run_id,
session=session,
)
- stored_xcoms = session.query(XComModel).all()
+ stored_xcoms = session.scalars(select(XComModel)).all()
assert stored_xcoms[0].key == key
assert isinstance(stored_xcoms[0].value, type(json.dumps(expected_value)))
assert stored_xcoms[0].value == json.dumps(expected_value)
@@ -398,7 +399,7 @@ class TestXComSet:
@pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
def test_xcom_set_again_replace(self, session, task_instance):
- assert session.query(XComModel).one().value == json.dumps({"key1":
"value1"})
+ assert session.scalar(select(XComModel)).value == json.dumps({"key1":
"value1"})
XComModel.set(
key="xcom_1",
value={"key2": "value2"},
@@ -407,7 +408,7 @@ class TestXComSet:
run_id=task_instance.run_id,
session=session,
)
- assert session.query(XComModel).one().value == json.dumps({"key2":
"value2"})
+ assert session.scalar(select(XComModel)).value == json.dumps({"key2":
"value2"})
def test_xcom_set_invalid_key(self, session, task_instance):
"""Test that setting an XCom with an invalid key raises a
ValueError."""
@@ -440,14 +441,14 @@ class TestXComClear:
@pytest.mark.usefixtures("setup_for_xcom_clear")
@mock.patch("airflow.sdk.execution_time.xcom.XCom.purge")
def test_xcom_clear(self, mock_purge, session, task_instance):
- assert session.query(XComModel).count() == 1
+ assert session.scalar(select(func.count()).select_from(XComModel)) == 1
XComModel.clear(
dag_id=task_instance.dag_id,
task_id=task_instance.task_id,
run_id=task_instance.run_id,
session=session,
)
- assert session.query(XComModel).count() == 0
+ assert session.scalar(select(func.count()).select_from(XComModel)) == 0
# purge will not be done when we clear, will be handled in task sdk
assert mock_purge.call_count == 0
@@ -459,7 +460,7 @@ class TestXComClear:
run_id="different_run",
session=session,
)
- assert session.query(XComModel).count() == 1
+ assert session.scalar(select(func.count()).select_from(XComModel)) == 1
class TestXComRoundTrip:
diff --git a/airflow-core/tests/unit/utils/test_log_handlers.py b/airflow-core/tests/unit/utils/test_log_handlers.py
index 30669ab0599..cbe1a61d1c6 100644
--- a/airflow-core/tests/unit/utils/test_log_handlers.py
+++ b/airflow-core/tests/unit/utils/test_log_handlers.py
@@ -36,6 +36,7 @@ import pytest
from pydantic import TypeAdapter
from pydantic.v1.utils import deep_update
from requests.adapters import Response
+from sqlalchemy import delete, select
from airflow import settings
from airflow.config_templates.airflow_local_settings import DEFAULT_LOGGING_CONFIG
@@ -98,8 +99,8 @@ def cleanup_tables():
class TestFileTaskLogHandler:
def clean_up(self):
with create_session() as session:
- session.query(DagRun).delete()
- session.query(TaskInstance).delete()
+ session.execute(delete(DagRun))
+ session.execute(delete(TaskInstance))
def setup_method(self):
settings.configure_logging()
@@ -781,16 +782,14 @@ class TestFilenameRendering:
)
TaskInstanceHistory.record_ti(ti, session=session)
session.flush()
- tih = (
- session.query(TaskInstanceHistory)
- .filter_by(
- dag_id=ti.dag_id,
- task_id=ti.task_id,
- run_id=ti.run_id,
- map_index=ti.map_index,
- try_number=ti.try_number,
+ tih = session.scalar(
+ select(TaskInstanceHistory).where(
+ TaskInstanceHistory.dag_id == ti.dag_id,
+ TaskInstanceHistory.task_id == ti.task_id,
+ TaskInstanceHistory.run_id == ti.run_id,
+ TaskInstanceHistory.map_index == ti.map_index,
+ TaskInstanceHistory.try_number == ti.try_number,
)
- .one()
)
fth = FileTaskHandler("")
rendered_ti = fth._render_filename(ti, ti.try_number, session=session)
diff --git a/airflow-core/tests/unit/utils/test_state.py b/airflow-core/tests/unit/utils/test_state.py
index 463f9433204..88a1925842e 100644
--- a/airflow-core/tests/unit/utils/test_state.py
+++ b/airflow-core/tests/unit/utils/test_state.py
@@ -19,6 +19,7 @@ from __future__ import annotations
from datetime import timedelta
import pytest
+from sqlalchemy import select
from airflow.models.dagrun import DagRun
from airflow.sdk import DAG
@@ -58,22 +59,18 @@ def test_dagrun_state_enum_escape(testing_dag_bundle):
triggered_by=DagRunTriggeredByType.TEST,
)
- query = session.query(
- DagRun.dag_id,
- DagRun.state,
- DagRun.run_type,
- ).filter(
+ stmt = select(DagRun.dag_id, DagRun.state, DagRun.run_type).where(
DagRun.dag_id == dag.dag_id,
# make sure enum value can be used in filter queries
DagRun.state == DagRunState.QUEUED,
)
- assert str(query.statement.compile(compile_kwargs={"literal_binds":
True})) == (
+ assert str(stmt.compile(compile_kwargs={"literal_binds": True})) == (
"SELECT dag_run.dag_id, dag_run.state, dag_run.run_type \n"
"FROM dag_run \n"
"WHERE dag_run.dag_id = 'test_dagrun_state_enum_escape' AND
dag_run.state = 'queued'"
)
- rows = query.all()
+ rows = session.execute(stmt).all()
assert len(rows) == 1
assert rows[0].dag_id == dag.dag_id
# make sure value in db is stored as `queued`, not `DagRunType.QUEUED`