This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 201acda3545 Remove findings from positional session check in Core TI
Modules (#67809)
201acda3545 is described below
commit 201acda35458470de9818ab8344a41b43a704399
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Jun 1 12:31:52 2026 +0200
Remove findings from positional session check in Core TI Modules (#67809)
* Resolve the allow-listed positional `session` usages in the airflow-core
task instance model modules
* Update pytest tests to pass `session` as a keyword argument
---
.../execution_api/routes/task_instances.py | 6 +--
airflow-core/src/airflow/models/taskinstance.py | 47 ++++++++++++++--------
.../src/airflow/models/taskinstancehistory.py | 4 +-
.../src/airflow/ti_deps/deps/dagrun_exists_dep.py | 2 +-
.../ti_deps/deps/mapped_task_upstream_dep.py | 2 +-
.../ti_deps/deps/not_previously_skipped_dep.py | 4 +-
.../airflow/ti_deps/deps/runnable_exec_date_dep.py | 2 +-
.../src/airflow/ti_deps/deps/trigger_rule_dep.py | 12 +++---
airflow-core/tests/unit/jobs/test_scheduler_job.py | 12 +++---
airflow-core/tests/unit/models/test_cleartasks.py | 6 +--
airflow-core/tests/unit/models/test_dagrun.py | 2 +-
.../tests/unit/models/test_taskinstance.py | 17 ++++----
.../unit/standard/utils/test_sensor_helper.py | 2 +-
.../ci/prek/known_provide_session_positional.txt | 3 --
14 files changed, 68 insertions(+), 53 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a0677c55701..2afd96806c4 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -578,7 +578,7 @@ def _create_ti_state_update_query_and_update_state(
_handle_fail_fast_for_dag(ti=ti, dag_id=dag_id,
session=session, dag_bag=dag_bag)
elif isinstance(ti_patch_payload, TIRetryStatePayload):
if ti is not None:
- ti.prepare_db_for_next_try(session)
+ ti.prepare_db_for_next_try(session=session)
# Store retry policy overrides so next_retry_datetime() can read
them.
# These are cleared when the task enters RUNNING (ti_run).
query = query.values(
@@ -591,7 +591,7 @@ def _create_ti_state_update_query_and_update_state(
ti,
ti_patch_payload.task_outlets,
ti_patch_payload.outlet_events,
- session,
+ session=session,
)
try:
_emit_task_span(ti, state=updated_state)
@@ -894,7 +894,7 @@ def ti_put_rtif(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
)
- task_instance.update_rtif(put_rtif_payload, session)
+ task_instance.update_rtif(put_rtif_payload, session=session)
log.debug("RenderedTaskInstanceFields updated successfully")
return {"message": "Rendered task instance fields successfully set"}
diff --git a/airflow-core/src/airflow/models/taskinstance.py
b/airflow-core/src/airflow/models/taskinstance.py
index cd254e53a83..f9d498c1ab1 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -148,6 +148,7 @@ def _add_log(
owner=None,
owner_display_name=None,
extra=None,
+ *,
session: Session = NEW_SESSION,
**kwargs,
):
@@ -191,7 +192,7 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance,
task_teardown_map=None
log.info("Forcing task %s to fail due to dag's `fail_fast`
setting", ti.task_id)
msg = "Forcing task to fail due to dag's `fail_fast` setting."
session.add(Log(event="fail task", extra=msg,
task_instance=ti.key))
- ti.error(session)
+ ti.error(session=session)
else:
log.info("Setting task %s to SKIPPED due to dag's `fail_fast`
setting.", ti.task_id)
msg = "Skipping task due to dag's `fail_fast` setting."
@@ -822,7 +823,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
return self.log_url
@provide_session
- def error(self, session: Session = NEW_SESSION) -> None:
+ def error(self, *, session: Session = NEW_SESSION) -> None:
"""
Force the task instance's state to FAILED in the database.
@@ -842,6 +843,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
task_id: str,
map_index: int,
lock_for_update: bool = False,
+ *,
session: Session = NEW_SESSION,
) -> TaskInstance | None:
query = (
@@ -866,7 +868,11 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
@provide_session
def refresh_from_db(
- self, session: Session = NEW_SESSION, lock_for_update: bool = False,
keep_local_changes: bool = False
+ self,
+ *,
+ session: Session = NEW_SESSION,
+ lock_for_update: bool = False,
+ keep_local_changes: bool = False,
) -> None:
"""
Refresh the task instance from the database based on the primary key.
@@ -962,7 +968,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
return self.executor
@provide_session
- def set_state(self, state: str | None, session: Session = NEW_SESSION) ->
bool:
+ def set_state(self, state: str | None, *, session: Session = NEW_SESSION)
-> bool:
"""
Set TaskInstance state.
@@ -976,7 +982,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
current_time = timezone.utcnow()
self.log.debug("Setting task state for %s to %s", self, state)
if self not in session:
- self.refresh_from_db(session)
+ self.refresh_from_db(session=session)
self.state = state
self.start_date = self.start_date or current_time
if self.state in State.finished or self.state ==
TaskInstanceState.UP_FOR_RETRY:
@@ -1001,7 +1007,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
self.id = uuid7()
@provide_session
- def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
+ def are_dependents_done(self, *, session: Session = NEW_SESSION) -> bool:
"""
Check whether the immediate dependents of this task instance have
succeeded or have been skipped.
@@ -1033,6 +1039,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
def get_previous_dagrun(
self,
state: DagRunState | None = None,
+ *,
session: Session | None = None,
) -> DagRun | None:
"""
@@ -1073,6 +1080,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
def get_previous_ti(
self,
state: DagRunState | None = None,
+ *,
session: Session = NEW_SESSION,
) -> TaskInstance | None:
"""
@@ -1088,7 +1096,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
@provide_session
def are_dependencies_met(
- self, dep_context: DepContext | None = None, session: Session =
NEW_SESSION, verbose: bool = False
+ self, dep_context: DepContext | None = None, *, session: Session =
NEW_SESSION, verbose: bool = False
) -> bool:
"""
Are all conditions met for this task instance to be run given the
context for the dependencies.
@@ -1129,7 +1137,9 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
return True
@provide_session
- def get_failed_dep_statuses(self, dep_context: DepContext | None = None,
session: Session = NEW_SESSION):
+ def get_failed_dep_statuses(
+ self, dep_context: DepContext | None = None, *, session: Session =
NEW_SESSION
+ ):
"""Get failed Dependencies."""
if TYPE_CHECKING:
assert self.task is not None
@@ -1229,7 +1239,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
return dr
@provide_session
- def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
+ def get_dagrun(self, *, session: Session = NEW_SESSION) -> DagRun:
"""
Return the DagRun for this TaskInstance.
@@ -1273,6 +1283,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
hostname: str = "",
pool: str | None = None,
external_executor_id: str | None = None,
+ *,
session: Session = NEW_SESSION,
) -> bool:
"""
@@ -1410,6 +1421,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
test_mode: bool = False,
pool: str | None = None,
external_executor_id: str | None = None,
+ *,
session: Session = NEW_SESSION,
) -> bool:
return TaskInstance._check_and_change_state_before_execution(
@@ -1489,6 +1501,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
ti: TaskInstance,
task_outlets: list[AssetProfile],
outlet_events: list[dict[str, Any]],
+ *,
session: Session = NEW_SESSION,
) -> None:
from airflow.serialization.definitions.assets import (
@@ -1667,7 +1680,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
)
@provide_session
- def update_rtif(self, rendered_fields, session: Session = NEW_SESSION):
+ def update_rtif(self, rendered_fields, *, session: Session = NEW_SESSION):
from airflow.models.renderedtifields import RenderedTaskInstanceFields
rtif = RenderedTaskInstanceFields(ti=self, render_templates=False,
rendered_fields=rendered_fields)
@@ -1696,7 +1709,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
# the side effect is the changes done to the task instance aren't
picked up by the scheduler and
# thus the task instance isn't processed until the scheduler is
restarted.
@provide_session
- def defer_task(self, session: Session = NEW_SESSION) -> bool:
+ def defer_task(self, *, session: Session = NEW_SESSION) -> bool:
"""
Mark the task as deferred and sets up the trigger that is needed to
resume it when TaskDeferred is raised.
@@ -1773,7 +1786,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
if error:
cls.logger().error("%s", error)
if not test_mode:
- ti.refresh_from_db(session)
+ ti.refresh_from_db(session=session)
ti.end_date = timezone.utcnow()
ti.set_duration()
@@ -1825,7 +1838,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
@staticmethod
@provide_session
- def save_to_db(ti: TaskInstance, session: Session = NEW_SESSION):
+ def save_to_db(ti: TaskInstance, *, session: Session = NEW_SESSION):
ti.updated_at = timezone.utcnow()
session.merge(ti)
session.flush()
@@ -1836,6 +1849,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
self,
error: None | str,
test_mode: bool | None = None,
+ *,
session: Session = NEW_SESSION,
) -> None:
"""
@@ -1865,7 +1879,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
_log_state(task_instance=self)
if not test_mode:
- TaskInstance.save_to_db(ti, session)
+ TaskInstance.save_to_db(ti, session=session)
def is_eligible_to_retry(self) -> bool:
"""Is task instance is eligible for retry."""
@@ -1896,6 +1910,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
self,
key: str,
value: Any,
+ *,
session: Session = NEW_SESSION,
) -> None:
"""
@@ -1921,8 +1936,8 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
dag_id: str | None = None,
key: str = XCOM_RETURN_KEY,
include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
*,
+ session: Session = NEW_SESSION,
map_indexes: int | Iterable[int] | None = None,
default: Any = None,
run_id: str | None = None,
@@ -1998,7 +2013,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
)
@provide_session
- def get_num_running_task_instances(self, session: Session, same_dagrun:
bool = False) -> int:
+ def get_num_running_task_instances(self, *, session: Session, same_dagrun:
bool = False) -> int:
"""Count running TIs from the DB."""
warnings.warn(
"This function is deprecated and will be removed in Airflow.",
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py
b/airflow-core/src/airflow/models/taskinstancehistory.py
index 4989b949a52..b0d55114bb6 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -195,7 +195,7 @@ class TaskInstanceHistory(Base):
@staticmethod
@provide_session
- def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
+ def record_ti(ti: TaskInstance, *, session: Session = NEW_SESSION) -> None:
"""Record a TaskInstance to TaskInstanceHistory."""
exists_q = session.scalar(
select(func.count(TaskInstanceHistory.task_id)).where(
@@ -221,6 +221,6 @@ class TaskInstanceHistory(Base):
session.add(HITLDetailHistory(ti_hitl_detail))
@provide_session
- def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
+ def get_dagrun(self, *, session: Session = NEW_SESSION) -> DagRun:
"""Return the DagRun for this TaskInstanceHistory, matching
TaskInstance."""
return self.dag_run
diff --git a/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
b/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
index a97b2714eaa..19d66b5c1f7 100644
--- a/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -30,7 +30,7 @@ class DagrunRunningDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, dep_context, *, session):
- dr = ti.get_dagrun(session)
+ dr = ti.get_dagrun(session=session)
if dr.state != DagRunState.RUNNING:
yield self._failing_status(
reason=f"Task instance's dagrun was not in the 'running' state
but in the state '{dr.state}'."
diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
index 062be5f872a..e1957b9a342 100644
--- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
@@ -104,6 +104,6 @@ class MappedTaskUpstreamDep(BaseTIDep):
new_state = TaskInstanceState.UPSTREAM_FAILED
elif TaskInstanceState.SKIPPED in finished_states:
new_state = TaskInstanceState.SKIPPED
- if new_state is not None and ti.set_state(new_state, session):
+ if new_state is not None and ti.set_state(new_state,
session=session):
dep_context.have_changed_ti_states = True
yield self._failing_status(reason="At least one of task's mapped
dependencies has not succeeded!")
diff --git
a/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
b/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
index e00ca976d53..4b767a04f23 100644
--- a/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -48,7 +48,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
upstream = ti.task.get_direct_relatives(upstream=True)
- finished_tis = dep_context.ensure_finished_tis(ti.get_dagrun(session),
session)
+ finished_tis =
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
finished_task_ids = {t.task_id for t in finished_tis}
@@ -100,7 +100,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
reason="Task should be skipped but the past
depends are not met"
)
return
- ti.set_state(TaskInstanceState.SKIPPED, session)
+ ti.set_state(TaskInstanceState.SKIPPED, session=session)
yield self._failing_status(
reason=f"Skipping because of previous XCom result from
parent task {parent.task_id}"
)
diff --git a/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
b/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
index 7ed248fca09..96b5dfd3952 100644
--- a/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
@@ -30,7 +30,7 @@ class RunnableExecDateDep(BaseTIDep):
@provide_session
def _get_dep_statuses(self, ti, dep_context, *, session):
- logical_date = ti.get_dagrun(session).logical_date
+ logical_date = ti.get_dagrun(session=session).logical_date
if logical_date is None:
return
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 5ddba214dd3..ba69f513152 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -277,7 +277,7 @@ class TriggerRuleDep(BaseTIDep):
indirect_setups = {k: v for k, v in relevant_setups.items() if k
not in task.upstream_task_ids}
finished_upstream_tis = (
x
- for x in
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+ for x in
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
if _is_relevant_upstream(upstream=x,
relevant_ids=indirect_setups.keys())
)
upstream_states =
_UpstreamTIStates.calculate(finished_upstream_tis)
@@ -332,7 +332,7 @@ class TriggerRuleDep(BaseTIDep):
changed,
)
return
- changed = ti.set_state(new_state, session)
+ changed = ti.set_state(new_state, session=session)
if changed:
dep_context.have_changed_ti_states = True
@@ -360,7 +360,9 @@ class TriggerRuleDep(BaseTIDep):
finished_upstream_tis = (
finished_ti
- for finished_ti in
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+ for finished_ti in dep_context.ensure_finished_tis(
+ ti.get_dagrun(session=session), session=session
+ )
if _is_relevant_upstream(upstream=finished_ti,
relevant_ids=task.upstream_task_ids)
)
upstream_states =
_UpstreamTIStates.calculate(finished_upstream_tis)
@@ -465,7 +467,7 @@ class TriggerRuleDep(BaseTIDep):
reason="Task should be skipped but the past
depends are not met"
)
return
- changed = ti.set_state(new_state, session)
+ changed = ti.set_state(new_state, session=session)
if changed:
dep_context.have_changed_ti_states = True
@@ -652,7 +654,7 @@ class TriggerRuleDep(BaseTIDep):
done = sum(
1
- for x in
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+ for x in
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
if _is_relevant_upstream(upstream=x, relevant_ids=in_scope_ids)
)
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 872d133c122..3b576271610 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -3052,10 +3052,10 @@ class TestSchedulerJob:
assert res == 6
session.flush()
for ti in tis1[:3] + tis2[:3]:
- ti.refresh_from_db(session)
+ ti.refresh_from_db(session=session)
assert ti.state == TaskInstanceState.QUEUED
for ti in tis1[3:] + tis2[3:]:
- ti.refresh_from_db(session)
+ ti.refresh_from_db(session=session)
assert ti.state == TaskInstanceState.SCHEDULED
# The remaining TIs are queued
@@ -3064,7 +3064,7 @@ class TestSchedulerJob:
session.flush()
for ti in tis1 + tis2:
- ti.refresh_from_db(session)
+ ti.refresh_from_db(session=session)
assert ti.state == State.QUEUED
@pytest.mark.parametrize(
@@ -4014,7 +4014,7 @@ class TestSchedulerJob:
session = settings.Session()
ti = dr.get_task_instance("dummy")
- ti.set_state(State.SUCCESS, session)
+ ti.set_state(State.SUCCESS, session=session)
with mock.patch("airflow.jobs.scheduler_job_runner.prohibit_commit")
as mock_guard:
mock_guard.return_value.__enter__.return_value.commit.side_effect
= session.commit
@@ -4092,7 +4092,7 @@ class TestSchedulerJob:
session = settings.Session()
dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("test_task")
- ti.set_state(state, session)
+ ti.set_state(state, session=session)
self.job_runner._do_scheduling(session)
@@ -4126,7 +4126,7 @@ class TestSchedulerJob:
dr = dag_maker.create_dagrun()
ti = dr.get_task_instance("dummy")
- ti.set_state(State.SUCCESS, session)
+ ti.set_state(State.SUCCESS, session=session)
self.job_runner._do_scheduling(session)
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py
b/airflow-core/tests/unit/models/test_cleartasks.py
index 58eb1b37b12..421f4776433 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -85,10 +85,10 @@ class TestClearTasks:
# but it works for our case because we specifically constructed
test DAGS
# in the way that those two sort methods are equivalent
qry = session.scalars(select(TI).where(TI.dag_id ==
dag.dag_id).order_by(TI.task_id)).all()
- clear_task_instances(qry, session)
+ clear_task_instances(qry, session=session)
- ti0.refresh_from_db(session)
- ti1.refresh_from_db(session)
+ ti0.refresh_from_db(session=session)
+ ti1.refresh_from_db(session=session)
# Next try to run will be try 2
assert ti0.state is None
diff --git a/airflow-core/tests/unit/models/test_dagrun.py
b/airflow-core/tests/unit/models/test_dagrun.py
index e8a946a6aa1..9405d8a9555 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -183,7 +183,7 @@ class TestDagRun:
ti = dag_run.get_task_instance(task_id)
if TYPE_CHECKING:
assert ti
- ti.set_state(task_state, session)
+ ti.set_state(task_state, session=session)
session.flush()
return dag_run
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py
b/airflow-core/tests/unit/models/test_taskinstance.py
index 3b4dceb628d..dae7f419ff4 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -437,7 +437,7 @@ class TestTaskInstance:
)
@provide_session
- def test_ti_updates_with_task(self, dag_maker, create_task_instance,
session):
+ def test_ti_updates_with_task(self, dag_maker, create_task_instance, *,
session: Session):
"""
test that updating the executor_config propagates to the TaskInstance
DB
"""
@@ -467,7 +467,7 @@ class TestTaskInstance:
run_task_instance(ti2, task2, session=session)
# Ensure it's reloaded
ti2.executor_config = None
- ti2.refresh_from_db(session)
+ ti2.refresh_from_db(session=session)
assert ti2.executor_config == {"bar": "baz"}
session.rollback()
@@ -520,7 +520,7 @@ class TestTaskInstance:
def run_with_error(ti):
with contextlib.suppress(AirflowException):
dag_maker.run_ti(ti.task_id, ti.dag_run)
- ti.refresh_from_db(session)
+ ti.refresh_from_db(session=session)
ti =
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
ti.task = dag_maker.serialized_dag.get_task(ti.task_id)
@@ -1410,7 +1410,8 @@ class TestTaskInstance:
downstream_ti_state,
expected_are_dependents_done,
dag_maker,
- session,
+ *,
+ session: Session,
):
with dag_maker():
EmptyOperator(task_id="0") >>
EmptyOperator(task_id="downstream_task")
@@ -1418,11 +1419,11 @@ class TestTaskInstance:
dr = dag_maker.create_dagrun()
downstream_ti = dr.get_task_instance("downstream_task",
session=session)
- downstream_ti.set_state(downstream_ti_state, session)
+ downstream_ti.set_state(downstream_ti_state, session=session)
session.flush()
ti0 = dr.get_task_instance(task_id="0", session=session)
- assert ti0.are_dependents_done(session) == expected_are_dependents_done
+ assert ti0.are_dependents_done(session=session) ==
expected_are_dependents_done
def test_xcom_push_flag(self, dag_maker):
"""
@@ -1504,7 +1505,7 @@ class TestTaskInstance:
assert ti_from_deserialized_task.try_number == 0
@provide_session
- def test_external_executor_id_accepts_long_values(self,
create_task_instance, session):
+ def test_external_executor_id_accepts_long_values(self,
create_task_instance, *, session: Session):
"""Test that external_executor_id can store values exceeding 250
characters."""
# Kubernetes pod names and other executor IDs can exceed 250 chars
long_executor_id = "k8s-pod-" + "a" * 300 # 308 characters total
@@ -2314,7 +2315,7 @@ class TestTaskInstance:
assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id !=
ti_list[2].run_id
@provide_session
- def test_handle_failure_calls_listener(self, dag_maker, session):
+ def test_handle_failure_calls_listener(self, dag_maker, *, session:
Session):
class CustomOp(BaseOperator):
def execute(self, context): ...
diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
index 8ba35d9b943..d52d1bc761d 100644
--- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
+++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
@@ -114,7 +114,7 @@ class TestSensorHelper:
ti = dag_run.get_task_instance(task_id)
if TYPE_CHECKING:
assert ti
- ti.set_state(task_state, session)
+ ti.set_state(task_state, session=session)
session.flush()
@pytest.mark.parametrize(
diff --git a/scripts/ci/prek/known_provide_session_positional.txt
b/scripts/ci/prek/known_provide_session_positional.txt
index ec271c99df4..6ffd1b7cac0 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -5,11 +5,8 @@ airflow-core/src/airflow/models/pool.py::11
airflow-core/src/airflow/models/renderedtifields.py::4
airflow-core/src/airflow/models/revoked_token.py::2
airflow-core/src/airflow/models/serialized_dag.py::6
-airflow-core/src/airflow/models/taskinstance.py::21
-airflow-core/src/airflow/models/taskinstancehistory.py::2
airflow-core/src/airflow/models/team.py::1
airflow-core/src/airflow/models/trigger.py::7
airflow-core/src/airflow/models/variable.py::2
-airflow-core/tests/unit/models/test_taskinstance.py::4
airflow-core/tests/unit/models/test_timestamp.py::2
providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::1