This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 5bcc4fb2ad4 Remove findings from positional session check in all
leftover Airflow-Core Model Modules (#67872)
5bcc4fb2ad4 is described below
commit 5bcc4fb2ad4a04a6031b2cf407c8a47e6281f30a
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Jun 3 07:52:25 2026 +0200
Remove findings from positional session check in all leftover Airflow-Core Model Modules (#67872)
* Fix exceptions of positional session use in all leftover airflow-core models modules
* Fix pytests
---
.../src/airflow/cli/commands/team_command.py | 2 +-
airflow-core/src/airflow/models/connection.py | 4 ++--
airflow-core/src/airflow/models/dagrun.py | 2 +-
airflow-core/src/airflow/models/deadline.py | 2 +-
airflow-core/src/airflow/models/deadline_alert.py | 2 +-
airflow-core/src/airflow/models/pool.py | 24 +++++++++++-----------
.../src/airflow/models/renderedtifields.py | 6 ++++--
airflow-core/src/airflow/models/revoked_token.py | 4 ++--
airflow-core/src/airflow/models/serialized_dag.py | 11 +++++-----
airflow-core/src/airflow/models/team.py | 2 +-
airflow-core/src/airflow/models/trigger.py | 12 ++++++-----
airflow-core/src/airflow/models/variable.py | 6 ++++--
airflow-core/tests/unit/models/test_dag.py | 2 +-
airflow-core/tests/unit/models/test_dagrun.py | 4 ++--
airflow-core/tests/unit/models/test_timestamp.py | 4 ++--
devel-common/src/tests_common/test_utils/db.py | 2 +-
.../ci/prek/known_provide_session_positional.txt | 11 ----------
17 files changed, 48 insertions(+), 52 deletions(-)
diff --git a/airflow-core/src/airflow/cli/commands/team_command.py
b/airflow-core/src/airflow/cli/commands/team_command.py
index 6106fcddb2e..931b5507be3 100644
--- a/airflow-core/src/airflow/cli/commands/team_command.py
+++ b/airflow-core/src/airflow/cli/commands/team_command.py
@@ -172,7 +172,7 @@ def team_sync(args, *, session=NEW_SESSION):
teams_added = 0
try:
- for team_name in dag_bundle_teams - Team.get_all_team_names(session):
+ for team_name in dag_bundle_teams - Team.get_all_team_names(session=session):
team = Team(name=team_name)
session.add(team)
teams_added += 1
diff --git a/airflow-core/src/airflow/models/connection.py
b/airflow-core/src/airflow/models/connection.py
index 819ec8d8281..170351cc651 100644
--- a/airflow-core/src/airflow/models/connection.py
+++ b/airflow-core/src/airflow/models/connection.py
@@ -646,14 +646,14 @@ class Connection(Base, LoggingMixin):
@staticmethod
@provide_session
- def get_team_name(connection_id: str, session=NEW_SESSION) -> str | None:
+ def get_team_name(connection_id: str, *, session=NEW_SESSION) -> str | None:
stmt = select(Connection.team_name).where(Connection.conn_id ==
connection_id)
return session.scalar(stmt)
@staticmethod
@provide_session
def get_conn_id_to_team_name_mapping(
- connection_ids: list[str], session=NEW_SESSION
+ connection_ids: list[str], *, session=NEW_SESSION
) -> dict[str, str | None]:
stmt = select(Connection.conn_id,
Connection.team_name).where(Connection.conn_id.in_(connection_ids))
return {conn_id: team_name for conn_id, team_name in
session.execute(stmt)}
diff --git a/airflow-core/src/airflow/models/dagrun.py
b/airflow-core/src/airflow/models/dagrun.py
index 36ef372e39a..564e11a9522 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1223,7 +1223,7 @@ class DagRun(Base, LoggingMixin):
if dag.deadline:
# The dagrun has succeeded. If there were any Deadlines for
it which were not breached, they are no longer needed.
deadline_alerts = [
- DeadlineAlertModel.get_by_id(alert_id, session) for alert_id in dag.deadline
+ DeadlineAlertModel.get_by_id(alert_id, session=session) for alert_id in dag.deadline
]
if any(
diff --git a/airflow-core/src/airflow/models/deadline.py
b/airflow-core/src/airflow/models/deadline.py
index 13351c72c77..6fda3597504 100644
--- a/airflow-core/src/airflow/models/deadline.py
+++ b/airflow-core/src/airflow/models/deadline.py
@@ -490,7 +490,7 @@ DeadlineReferenceType =
ReferenceModels.BaseDeadlineReference
@provide_session
-def _fetch_from_db(model_reference: Mapped, session=None, **conditions) -> datetime | None:
+def _fetch_from_db(model_reference: Mapped, *, session=None, **conditions) -> datetime | None:
"""
Fetch a datetime value from the database using the provided model
reference and filtering conditions.
diff --git a/airflow-core/src/airflow/models/deadline_alert.py
b/airflow-core/src/airflow/models/deadline_alert.py
index d9b6590c0f7..20bfed459ee 100644
--- a/airflow-core/src/airflow/models/deadline_alert.py
+++ b/airflow-core/src/airflow/models/deadline_alert.py
@@ -102,7 +102,7 @@ class DeadlineAlert(Base):
@classmethod
@provide_session
- def get_by_id(cls, deadline_alert_id: str | UUID, session: Session = NEW_SESSION) -> DeadlineAlert:
+ def get_by_id(cls, deadline_alert_id: str | UUID, *, session: Session = NEW_SESSION) -> DeadlineAlert:
"""
Retrieve a DeadlineAlert record by its UUID.
diff --git a/airflow-core/src/airflow/models/pool.py
b/airflow-core/src/airflow/models/pool.py
index 65f17c15b19..08eff959ea2 100644
--- a/airflow-core/src/airflow/models/pool.py
+++ b/airflow-core/src/airflow/models/pool.py
@@ -94,13 +94,13 @@ class Pool(Base):
@staticmethod
@provide_session
- def get_pools(session: Session = NEW_SESSION) -> Sequence[Pool]:
+ def get_pools(*, session: Session = NEW_SESSION) -> Sequence[Pool]:
"""Get all pools."""
return session.scalars(select(Pool)).all()
@staticmethod
@provide_session
- def get_pool(pool_name: str, session: Session = NEW_SESSION) -> Pool | None:
+ def get_pool(pool_name: str, *, session: Session = NEW_SESSION) -> Pool | None:
"""
Get the Pool with specific pool name from the Pools.
@@ -112,7 +112,7 @@ class Pool(Base):
@staticmethod
@provide_session
- def get_default_pool(session: Session = NEW_SESSION) -> Pool | None:
+ def get_default_pool(*, session: Session = NEW_SESSION) -> Pool | None:
"""
Get the Pool of the default_pool from the Pools.
@@ -251,7 +251,7 @@ class Pool(Base):
}
@provide_session
- def occupied_slots(self, session: Session = NEW_SESSION) -> int:
+ def occupied_slots(self, *, session: Session = NEW_SESSION) -> int:
"""
Get the number of slots used by running/queued tasks at the moment.
@@ -279,7 +279,7 @@ class Pool(Base):
return EXECUTION_STATES
@provide_session
- def running_slots(self, session: Session = NEW_SESSION) -> int:
+ def running_slots(self, *, session: Session = NEW_SESSION) -> int:
"""
Get the number of slots used by running tasks at the moment.
@@ -298,7 +298,7 @@ class Pool(Base):
)
@provide_session
- def queued_slots(self, session: Session = NEW_SESSION) -> int:
+ def queued_slots(self, *, session: Session = NEW_SESSION) -> int:
"""
Get the number of slots used by queued tasks at the moment.
@@ -317,7 +317,7 @@ class Pool(Base):
)
@provide_session
- def scheduled_slots(self, session: Session = NEW_SESSION) -> int:
+ def scheduled_slots(self, *, session: Session = NEW_SESSION) -> int:
"""
Get the number of slots scheduled at the moment.
@@ -336,7 +336,7 @@ class Pool(Base):
)
@provide_session
- def deferred_slots(self, session: Session = NEW_SESSION) -> int:
+ def deferred_slots(self, *, session: Session = NEW_SESSION) -> int:
"""
Get the number of slots deferred at the moment.
@@ -355,7 +355,7 @@ class Pool(Base):
)
@provide_session
- def open_slots(self, session: Session = NEW_SESSION) -> float:
+ def open_slots(self, *, session: Session = NEW_SESSION) -> float:
"""
Get the number of slots open at the moment.
@@ -364,18 +364,18 @@ class Pool(Base):
"""
if self.slots == -1:
return float("inf")
- return self.slots - self.occupied_slots(session)
+ return self.slots - self.occupied_slots(session=session)
@staticmethod
@provide_session
- def get_team_name(pool_name: str, session: Session = NEW_SESSION) -> str | None:
+ def get_team_name(pool_name: str, *, session: Session = NEW_SESSION) -> str | None:
stmt = select(Pool.team_name).where(Pool.pool == pool_name)
return session.scalar(stmt)
@staticmethod
@provide_session
def get_name_to_team_name_mapping(
- pool_names: list[str], session: Session = NEW_SESSION
+ pool_names: list[str], *, session: Session = NEW_SESSION
) -> dict[str, str | None]:
stmt = select(Pool.pool,
Pool.team_name).where(Pool.pool.in_(pool_names))
return {pool: team_name for pool, team_name in session.execute(stmt)}
diff --git a/airflow-core/src/airflow/models/renderedtifields.py
b/airflow-core/src/airflow/models/renderedtifields.py
index e405f3bfce7..01e68849d05 100644
--- a/airflow-core/src/airflow/models/renderedtifields.py
+++ b/airflow-core/src/airflow/models/renderedtifields.py
@@ -194,6 +194,7 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
def get_templated_fields(
cls,
ti: TaskInstance | TaskInstanceKey,
+ *,
session: Session = NEW_SESSION,
) -> dict | None:
"""
@@ -219,7 +220,7 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
@classmethod
@provide_session
- def get_k8s_pod_yaml(cls, ti: TaskInstance, session: Session = NEW_SESSION) -> dict | None:
+ def get_k8s_pod_yaml(cls, ti: TaskInstance, *, session: Session = NEW_SESSION) -> dict | None:
"""
Get rendered Kubernetes Pod Yaml for a TaskInstance from the
RenderedTaskInstanceFields table.
@@ -239,7 +240,7 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
@provide_session
@retry_db_transaction
- def write(self, session: Session = NEW_SESSION):
+ def write(self, *, session: Session = NEW_SESSION) -> None:
"""
Write instance to database.
@@ -279,6 +280,7 @@ class RenderedTaskInstanceFields(TaskInstanceDependencies):
task_id: str,
dag_id: str,
num_to_keep: int = conf.getint("core",
"num_dag_runs_to_retain_rendered_fields", fallback=0),
+ *,
session: Session = NEW_SESSION,
) -> None:
"""
diff --git a/airflow-core/src/airflow/models/revoked_token.py
b/airflow-core/src/airflow/models/revoked_token.py
index 39dc25c6884..200be269e32 100644
--- a/airflow-core/src/airflow/models/revoked_token.py
+++ b/airflow-core/src/airflow/models/revoked_token.py
@@ -49,13 +49,13 @@ class RevokedToken(Base):
@classmethod
@provide_session
- def revoke(cls, jti: str, exp: datetime, session: Session = NEW_SESSION) -> None:
+ def revoke(cls, jti: str, exp: datetime, *, session: Session = NEW_SESSION) -> None:
"""Add a token JTI to the revoked tokens."""
session.merge(cls(jti=jti, exp=exp))
@classmethod
@provide_session
- def is_revoked(cls, jti: str, session: Session = NEW_SESSION) -> bool:
+ def is_revoked(cls, jti: str, *, session: Session = NEW_SESSION) -> bool:
"""Check if a token JTI has been revoked."""
cls._maybe_cleanup_expired(session)
return bool(session.scalar(select(exists().where(cls.jti == jti))))
diff --git a/airflow-core/src/airflow/models/serialized_dag.py
b/airflow-core/src/airflow/models/serialized_dag.py
index d2d9d47d61c..4f20c3a180e 100644
--- a/airflow-core/src/airflow/models/serialized_dag.py
+++ b/airflow-core/src/airflow/models/serialized_dag.py
@@ -599,6 +599,7 @@ class SerializedDagModel(Base):
bundle_version: str | None = None,
version_data: dict | None = None,
min_update_interval: int | None = None,
+ *,
session: Session = NEW_SESSION,
_prefetched: DagWriteMetadata | None = None,
) -> bool:
@@ -807,7 +808,7 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def read_all_dags(cls, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
+ def read_all_dags(cls, *, session: Session = NEW_SESSION) -> dict[str, SerializedDAG]:
"""
Read all DAGs in serialized_dag table.
@@ -866,7 +867,7 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def has_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> bool:
+ def has_dag(cls, dag_id: str, *, session: Session = NEW_SESSION) -> bool:
"""
Check a DAG exist in serialized_dag table.
@@ -877,7 +878,7 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def get_dag(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDAG | None:
+ def get_dag(cls, dag_id: str, *, session: Session = NEW_SESSION) -> SerializedDAG | None:
row = cls.get(dag_id, session=session)
if row:
return row.dag
@@ -885,7 +886,7 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def get(cls, dag_id: str, session: Session = NEW_SESSION) -> SerializedDagModel | None:
+ def get(cls, dag_id: str, *, session: Session = NEW_SESSION) -> SerializedDagModel | None:
"""
Get the SerializedDAG for the given dag ID.
@@ -896,7 +897,7 @@ class SerializedDagModel(Base):
@classmethod
@provide_session
- def get_dag_dependencies(cls, session: Session = NEW_SESSION) -> dict[str, list[DagDependency]]:
+ def get_dag_dependencies(cls, *, session: Session = NEW_SESSION) -> dict[str, list[DagDependency]]:
"""
Get the dependencies between DAGs.
diff --git a/airflow-core/src/airflow/models/team.py
b/airflow-core/src/airflow/models/team.py
index 7ff6085eabf..64877489c66 100644
--- a/airflow-core/src/airflow/models/team.py
+++ b/airflow-core/src/airflow/models/team.py
@@ -68,7 +68,7 @@ class Team(Base):
@classmethod
@provide_session
- def get_all_team_names(cls, session: Session = NEW_SESSION) -> set[str]:
+ def get_all_team_names(cls, *, session: Session = NEW_SESSION) -> set[str]:
"""
Return a set of all team names from the database.
diff --git a/airflow-core/src/airflow/models/trigger.py
b/airflow-core/src/airflow/models/trigger.py
index 09c5a1054e5..6fecd4d825f 100644
--- a/airflow-core/src/airflow/models/trigger.py
+++ b/airflow-core/src/airflow/models/trigger.py
@@ -204,7 +204,7 @@ class Trigger(Base):
@classmethod
@provide_session
- def bulk_fetch(cls, ids: Iterable[int], session: Session = NEW_SESSION) -> dict[int, Trigger]:
+ def bulk_fetch(cls, ids: Iterable[int], *, session: Session = NEW_SESSION) -> dict[int, Trigger]:
"""Fetch all the Triggers by ID and return a dict mapping ID ->
Trigger instance."""
stmt = (
select(cls)
@@ -219,7 +219,7 @@ class Trigger(Base):
@classmethod
@provide_session
- def fetch_trigger_ids_with_non_task_associations(cls, session: Session = NEW_SESSION) -> set[int]:
+ def fetch_trigger_ids_with_non_task_associations(cls, *, session: Session = NEW_SESSION) -> set[int]:
"""Fetch all trigger IDs actively associated with non-task entities
like assets and callbacks."""
from airflow.models.callback import Callback # to avoid circular
import: Callback -> Trigger
@@ -231,7 +231,7 @@ class Trigger(Base):
@classmethod
@provide_session
- def clean_unused(cls, session: Session = NEW_SESSION) -> None:
+ def clean_unused(cls, *, session: Session = NEW_SESSION) -> None:
"""
Delete all triggers that have no tasks dependent on them and are not
associated to an asset.
@@ -270,7 +270,7 @@ class Trigger(Base):
@classmethod
@provide_session
- def submit_event(cls, trigger_id, event: TriggerEvent, session: Session = NEW_SESSION) -> None:
+ def submit_event(cls, trigger_id, event: TriggerEvent, *, session: Session = NEW_SESSION) -> None:
"""
Fire an event.
@@ -301,7 +301,7 @@ class Trigger(Base):
@classmethod
@provide_session
- def submit_failure(cls, trigger_id, exc=None, session: Session = NEW_SESSION) -> None:
+ def submit_failure(cls, trigger_id, exc=None, *, session: Session = NEW_SESSION) -> None:
"""
When a trigger has failed unexpectedly, mark everything that depended
on it as failed.
@@ -341,6 +341,7 @@ class Trigger(Base):
triggerer_id,
queues: set[str] | None = None,
team_name: str | None = None,
+ *,
session: Session = NEW_SESSION,
) -> list[int]:
"""Retrieve a list of trigger ids."""
@@ -372,6 +373,7 @@ class Trigger(Base):
health_check_threshold,
queues: set[str] | None = None,
team_name: str | None = None,
+ *,
session: Session = NEW_SESSION,
) -> None:
"""
diff --git a/airflow-core/src/airflow/models/variable.py
b/airflow-core/src/airflow/models/variable.py
index eb50b92f988..f7a2aba5dda 100644
--- a/airflow-core/src/airflow/models/variable.py
+++ b/airflow-core/src/airflow/models/variable.py
@@ -497,12 +497,14 @@ class Variable(Base, LoggingMixin):
@staticmethod
@provide_session
- def get_team_name(variable_key: str, session=NEW_SESSION) -> str | None:
+ def get_team_name(variable_key: str, *, session=NEW_SESSION) -> str | None:
stmt = select(Variable.team_name).where(Variable.key == variable_key)
return session.scalar(stmt)
@staticmethod
@provide_session
- def get_key_to_team_name_mapping(variable_keys: list[str], session=NEW_SESSION) -> dict[str, str | None]:
+ def get_key_to_team_name_mapping(
+ variable_keys: list[str], *, session=NEW_SESSION
+ ) -> dict[str, str | None]:
stmt = select(Variable.key,
Variable.team_name).where(Variable.key.in_(variable_keys))
return {key: team_name for key, team_name in session.execute(stmt)}
diff --git a/airflow-core/tests/unit/models/test_dag.py
b/airflow-core/tests/unit/models/test_dag.py
index 0a15c9ab09d..56b16accd06 100644
--- a/airflow-core/tests/unit/models/test_dag.py
+++ b/airflow-core/tests/unit/models/test_dag.py
@@ -2750,7 +2750,7 @@ class TestDagModel:
scheduler_dag = sync_dag_to_db(dag)
assert scheduler_dag._processor_dags_folder == settings.DAGS_FOLDER
- sdm = SerializedDagModel.get(dag.dag_id, session)
+ sdm = SerializedDagModel.get(dag.dag_id, session=session)
assert sdm.dag._processor_dags_folder == settings.DAGS_FOLDER
@pytest.mark.need_serialized_dag
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index 9405d8a9555..1154a19617e 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -1401,7 +1401,7 @@ class TestDagRun:
assert mock_get_by_id.call_count == len(deadline_ids)
for deadline_id in deadline_ids:
- mock_get_by_id.assert_any_call(deadline_id, session)
+ mock_get_by_id.assert_any_call(deadline_id, session=session)
mock_prune.assert_called_once_with(session=session,
conditions={DagRun.id: dag_run.id})
assert dag_run.state == DagRunState.SUCCESS
@@ -1431,7 +1431,7 @@ class TestDagRun:
dag_run.update_state(session=session)
- mock_get_by_id.assert_called_once_with(deadline_id, session)
+ mock_get_by_id.assert_called_once_with(deadline_id, session=session)
mock_prune.assert_not_called()
assert dag_run.state == DagRunState.SUCCESS
diff --git a/airflow-core/tests/unit/models/test_timestamp.py
b/airflow-core/tests/unit/models/test_timestamp.py
index b200fe3ecdb..9f469931a50 100644
--- a/airflow-core/tests/unit/models/test_timestamp.py
+++ b/airflow-core/tests/unit/models/test_timestamp.py
@@ -55,7 +55,7 @@ def add_log(execdate, session, dag_maker,
timezone_override=None):
@provide_session
-def test_timestamp_behaviour(dag_maker, session):
+def test_timestamp_behaviour(dag_maker, *, session):
execdate = timezone.utcnow()
with time_machine.travel(execdate, tick=False):
current_time = timezone.utcnow()
@@ -67,7 +67,7 @@ def test_timestamp_behaviour(dag_maker, session):
@provide_session
-def test_timestamp_behaviour_with_timezone(dag_maker, session):
+def test_timestamp_behaviour_with_timezone(dag_maker, *, session):
execdate = timezone.utcnow()
with time_machine.travel(execdate, tick=False):
current_time = timezone.utcnow()
diff --git a/devel-common/src/tests_common/test_utils/db.py
b/devel-common/src/tests_common/test_utils/db.py
index bf8a7b6d59f..6aad852a2d4 100644
--- a/devel-common/src/tests_common/test_utils/db.py
+++ b/devel-common/src/tests_common/test_utils/db.py
@@ -380,7 +380,7 @@ def clear_db_callbacks():
@_retry_db
def set_default_pool_slots(slots):
with create_session() as session:
- default_pool = Pool.get_default_pool(session)
+ default_pool = Pool.get_default_pool(session=session)
default_pool.slots = slots
diff --git a/scripts/ci/prek/known_provide_session_positional.txt
b/scripts/ci/prek/known_provide_session_positional.txt
index 7142b0362c5..e69de29bb2d 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -1,11 +0,0 @@
-airflow-core/src/airflow/models/connection.py::2
-airflow-core/src/airflow/models/deadline.py::1
-airflow-core/src/airflow/models/deadline_alert.py::1
-airflow-core/src/airflow/models/pool.py::11
-airflow-core/src/airflow/models/renderedtifields.py::4
-airflow-core/src/airflow/models/revoked_token.py::2
-airflow-core/src/airflow/models/serialized_dag.py::6
-airflow-core/src/airflow/models/team.py::1
-airflow-core/src/airflow/models/trigger.py::7
-airflow-core/src/airflow/models/variable.py::2
-airflow-core/tests/unit/models/test_timestamp.py::2