This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 201acda3545 Remove findings from positional session check in Core TI Modules (#67809)
201acda3545 is described below

commit 201acda35458470de9818ab8344a41b43a704399
Author: Jens Scheffler <[email protected]>
AuthorDate: Mon Jun 1 12:31:52 2026 +0200

    Remove findings from positional session check in Core TI Modules (#67809)
    
    * Fix the positional `session` usages in the airflow-core task instance model modules that were listed as known exceptions
    
    * Update tests to pass `session` as a keyword argument
---
 .../execution_api/routes/task_instances.py         |  6 +--
 airflow-core/src/airflow/models/taskinstance.py    | 47 ++++++++++++++--------
 .../src/airflow/models/taskinstancehistory.py      |  4 +-
 .../src/airflow/ti_deps/deps/dagrun_exists_dep.py  |  2 +-
 .../ti_deps/deps/mapped_task_upstream_dep.py       |  2 +-
 .../ti_deps/deps/not_previously_skipped_dep.py     |  4 +-
 .../airflow/ti_deps/deps/runnable_exec_date_dep.py |  2 +-
 .../src/airflow/ti_deps/deps/trigger_rule_dep.py   | 12 +++---
 airflow-core/tests/unit/jobs/test_scheduler_job.py | 12 +++---
 airflow-core/tests/unit/models/test_cleartasks.py  |  6 +--
 airflow-core/tests/unit/models/test_dagrun.py      |  2 +-
 .../tests/unit/models/test_taskinstance.py         | 17 ++++----
 .../unit/standard/utils/test_sensor_helper.py      |  2 +-
 .../ci/prek/known_provide_session_positional.txt   |  3 --
 14 files changed, 68 insertions(+), 53 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a0677c55701..2afd96806c4 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -578,7 +578,7 @@ def _create_ti_state_update_query_and_update_state(
                 _handle_fail_fast_for_dag(ti=ti, dag_id=dag_id, 
session=session, dag_bag=dag_bag)
         elif isinstance(ti_patch_payload, TIRetryStatePayload):
             if ti is not None:
-                ti.prepare_db_for_next_try(session)
+                ti.prepare_db_for_next_try(session=session)
             # Store retry policy overrides so next_retry_datetime() can read 
them.
             # These are cleared when the task enters RUNNING (ti_run).
             query = query.values(
@@ -591,7 +591,7 @@ def _create_ti_state_update_query_and_update_state(
                     ti,
                     ti_patch_payload.task_outlets,
                     ti_patch_payload.outlet_events,
-                    session,
+                    session=session,
                 )
         try:
             _emit_task_span(ti, state=updated_state)
@@ -894,7 +894,7 @@ def ti_put_rtif(
         raise HTTPException(
             status_code=status.HTTP_404_NOT_FOUND,
         )
-    task_instance.update_rtif(put_rtif_payload, session)
+    task_instance.update_rtif(put_rtif_payload, session=session)
     log.debug("RenderedTaskInstanceFields updated successfully")
 
     return {"message": "Rendered task instance fields successfully set"}
diff --git a/airflow-core/src/airflow/models/taskinstance.py 
b/airflow-core/src/airflow/models/taskinstance.py
index cd254e53a83..f9d498c1ab1 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -148,6 +148,7 @@ def _add_log(
     owner=None,
     owner_display_name=None,
     extra=None,
+    *,
     session: Session = NEW_SESSION,
     **kwargs,
 ):
@@ -191,7 +192,7 @@ def _stop_remaining_tasks(*, task_instance: TaskInstance, 
task_teardown_map=None
                 log.info("Forcing task %s to fail due to dag's `fail_fast` 
setting", ti.task_id)
                 msg = "Forcing task to fail due to dag's `fail_fast` setting."
                 session.add(Log(event="fail task", extra=msg, 
task_instance=ti.key))
-                ti.error(session)
+                ti.error(session=session)
             else:
                 log.info("Setting task %s to SKIPPED due to dag's `fail_fast` 
setting.", ti.task_id)
                 msg = "Skipping task due to dag's `fail_fast` setting."
@@ -822,7 +823,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         return self.log_url
 
     @provide_session
-    def error(self, session: Session = NEW_SESSION) -> None:
+    def error(self, *, session: Session = NEW_SESSION) -> None:
         """
         Force the task instance's state to FAILED in the database.
 
@@ -842,6 +843,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         task_id: str,
         map_index: int,
         lock_for_update: bool = False,
+        *,
         session: Session = NEW_SESSION,
     ) -> TaskInstance | None:
         query = (
@@ -866,7 +868,11 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
 
     @provide_session
     def refresh_from_db(
-        self, session: Session = NEW_SESSION, lock_for_update: bool = False, 
keep_local_changes: bool = False
+        self,
+        *,
+        session: Session = NEW_SESSION,
+        lock_for_update: bool = False,
+        keep_local_changes: bool = False,
     ) -> None:
         """
         Refresh the task instance from the database based on the primary key.
@@ -962,7 +968,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         return self.executor
 
     @provide_session
-    def set_state(self, state: str | None, session: Session = NEW_SESSION) -> 
bool:
+    def set_state(self, state: str | None, *, session: Session = NEW_SESSION) 
-> bool:
         """
         Set TaskInstance state.
 
@@ -976,7 +982,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         current_time = timezone.utcnow()
         self.log.debug("Setting task state for %s to %s", self, state)
         if self not in session:
-            self.refresh_from_db(session)
+            self.refresh_from_db(session=session)
         self.state = state
         self.start_date = self.start_date or current_time
         if self.state in State.finished or self.state == 
TaskInstanceState.UP_FOR_RETRY:
@@ -1001,7 +1007,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         self.id = uuid7()
 
     @provide_session
-    def are_dependents_done(self, session: Session = NEW_SESSION) -> bool:
+    def are_dependents_done(self, *, session: Session = NEW_SESSION) -> bool:
         """
         Check whether the immediate dependents of this task instance have 
succeeded or have been skipped.
 
@@ -1033,6 +1039,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
     def get_previous_dagrun(
         self,
         state: DagRunState | None = None,
+        *,
         session: Session | None = None,
     ) -> DagRun | None:
         """
@@ -1073,6 +1080,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
     def get_previous_ti(
         self,
         state: DagRunState | None = None,
+        *,
         session: Session = NEW_SESSION,
     ) -> TaskInstance | None:
         """
@@ -1088,7 +1096,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
 
     @provide_session
     def are_dependencies_met(
-        self, dep_context: DepContext | None = None, session: Session = 
NEW_SESSION, verbose: bool = False
+        self, dep_context: DepContext | None = None, *, session: Session = 
NEW_SESSION, verbose: bool = False
     ) -> bool:
         """
         Are all conditions met for this task instance to be run given the 
context for the dependencies.
@@ -1129,7 +1137,9 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         return True
 
     @provide_session
-    def get_failed_dep_statuses(self, dep_context: DepContext | None = None, 
session: Session = NEW_SESSION):
+    def get_failed_dep_statuses(
+        self, dep_context: DepContext | None = None, *, session: Session = 
NEW_SESSION
+    ):
         """Get failed Dependencies."""
         if TYPE_CHECKING:
             assert self.task is not None
@@ -1229,7 +1239,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         return dr
 
     @provide_session
-    def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
+    def get_dagrun(self, *, session: Session = NEW_SESSION) -> DagRun:
         """
         Return the DagRun for this TaskInstance.
 
@@ -1273,6 +1283,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         hostname: str = "",
         pool: str | None = None,
         external_executor_id: str | None = None,
+        *,
         session: Session = NEW_SESSION,
     ) -> bool:
         """
@@ -1410,6 +1421,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         test_mode: bool = False,
         pool: str | None = None,
         external_executor_id: str | None = None,
+        *,
         session: Session = NEW_SESSION,
     ) -> bool:
         return TaskInstance._check_and_change_state_before_execution(
@@ -1489,6 +1501,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         ti: TaskInstance,
         task_outlets: list[AssetProfile],
         outlet_events: list[dict[str, Any]],
+        *,
         session: Session = NEW_SESSION,
     ) -> None:
         from airflow.serialization.definitions.assets import (
@@ -1667,7 +1680,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
                     )
 
     @provide_session
-    def update_rtif(self, rendered_fields, session: Session = NEW_SESSION):
+    def update_rtif(self, rendered_fields, *, session: Session = NEW_SESSION):
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
 
         rtif = RenderedTaskInstanceFields(ti=self, render_templates=False, 
rendered_fields=rendered_fields)
@@ -1696,7 +1709,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
     #       the side effect is the changes done to the task instance aren't 
picked up by the scheduler and
     #       thus the task instance isn't processed until the scheduler is 
restarted.
     @provide_session
-    def defer_task(self, session: Session = NEW_SESSION) -> bool:
+    def defer_task(self, *, session: Session = NEW_SESSION) -> bool:
         """
         Mark the task as deferred and sets up the trigger that is needed to 
resume it when TaskDeferred is raised.
 
@@ -1773,7 +1786,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         if error:
             cls.logger().error("%s", error)
         if not test_mode:
-            ti.refresh_from_db(session)
+            ti.refresh_from_db(session=session)
 
         ti.end_date = timezone.utcnow()
         ti.set_duration()
@@ -1825,7 +1838,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
 
     @staticmethod
     @provide_session
-    def save_to_db(ti: TaskInstance, session: Session = NEW_SESSION):
+    def save_to_db(ti: TaskInstance, *, session: Session = NEW_SESSION):
         ti.updated_at = timezone.utcnow()
         session.merge(ti)
         session.flush()
@@ -1836,6 +1849,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         self,
         error: None | str,
         test_mode: bool | None = None,
+        *,
         session: Session = NEW_SESSION,
     ) -> None:
         """
@@ -1865,7 +1879,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         _log_state(task_instance=self)
 
         if not test_mode:
-            TaskInstance.save_to_db(ti, session)
+            TaskInstance.save_to_db(ti, session=session)
 
     def is_eligible_to_retry(self) -> bool:
         """Is task instance is eligible for retry."""
@@ -1896,6 +1910,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         self,
         key: str,
         value: Any,
+        *,
         session: Session = NEW_SESSION,
     ) -> None:
         """
@@ -1921,8 +1936,8 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         dag_id: str | None = None,
         key: str = XCOM_RETURN_KEY,
         include_prior_dates: bool = False,
-        session: Session = NEW_SESSION,
         *,
+        session: Session = NEW_SESSION,
         map_indexes: int | Iterable[int] | None = None,
         default: Any = None,
         run_id: str | None = None,
@@ -1998,7 +2013,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         )
 
     @provide_session
-    def get_num_running_task_instances(self, session: Session, same_dagrun: 
bool = False) -> int:
+    def get_num_running_task_instances(self, *, session: Session, same_dagrun: 
bool = False) -> int:
         """Count running TIs from the DB."""
         warnings.warn(
             "This function is deprecated and will be removed in Airflow.",
diff --git a/airflow-core/src/airflow/models/taskinstancehistory.py 
b/airflow-core/src/airflow/models/taskinstancehistory.py
index 4989b949a52..b0d55114bb6 100644
--- a/airflow-core/src/airflow/models/taskinstancehistory.py
+++ b/airflow-core/src/airflow/models/taskinstancehistory.py
@@ -195,7 +195,7 @@ class TaskInstanceHistory(Base):
 
     @staticmethod
     @provide_session
-    def record_ti(ti: TaskInstance, session: Session = NEW_SESSION) -> None:
+    def record_ti(ti: TaskInstance, *, session: Session = NEW_SESSION) -> None:
         """Record a TaskInstance to TaskInstanceHistory."""
         exists_q = session.scalar(
             select(func.count(TaskInstanceHistory.task_id)).where(
@@ -221,6 +221,6 @@ class TaskInstanceHistory(Base):
             session.add(HITLDetailHistory(ti_hitl_detail))
 
     @provide_session
-    def get_dagrun(self, session: Session = NEW_SESSION) -> DagRun:
+    def get_dagrun(self, *, session: Session = NEW_SESSION) -> DagRun:
         """Return the DagRun for this TaskInstanceHistory, matching 
TaskInstance."""
         return self.dag_run
diff --git a/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
index a97b2714eaa..19d66b5c1f7 100644
--- a/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/dagrun_exists_dep.py
@@ -30,7 +30,7 @@ class DagrunRunningDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, dep_context, *, session):
-        dr = ti.get_dagrun(session)
+        dr = ti.get_dagrun(session=session)
         if dr.state != DagRunState.RUNNING:
             yield self._failing_status(
                 reason=f"Task instance's dagrun was not in the 'running' state 
but in the state '{dr.state}'."
diff --git a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
index 062be5f872a..e1957b9a342 100644
--- a/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/mapped_task_upstream_dep.py
@@ -104,6 +104,6 @@ class MappedTaskUpstreamDep(BaseTIDep):
                 new_state = TaskInstanceState.UPSTREAM_FAILED
             elif TaskInstanceState.SKIPPED in finished_states:
                 new_state = TaskInstanceState.SKIPPED
-            if new_state is not None and ti.set_state(new_state, session):
+            if new_state is not None and ti.set_state(new_state, 
session=session):
                 dep_context.have_changed_ti_states = True
         yield self._failing_status(reason="At least one of task's mapped 
dependencies has not succeeded!")
diff --git 
a/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
index e00ca976d53..4b767a04f23 100644
--- a/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/not_previously_skipped_dep.py
@@ -48,7 +48,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
 
         upstream = ti.task.get_direct_relatives(upstream=True)
 
-        finished_tis = dep_context.ensure_finished_tis(ti.get_dagrun(session), 
session)
+        finished_tis = 
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
 
         finished_task_ids = {t.task_id for t in finished_tis}
 
@@ -100,7 +100,7 @@ class NotPreviouslySkippedDep(BaseTIDep):
                                 reason="Task should be skipped but the past 
depends are not met"
                             )
                             return
-                    ti.set_state(TaskInstanceState.SKIPPED, session)
+                    ti.set_state(TaskInstanceState.SKIPPED, session=session)
                     yield self._failing_status(
                         reason=f"Skipping because of previous XCom result from 
parent task {parent.task_id}"
                     )
diff --git a/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
index 7ed248fca09..96b5dfd3952 100644
--- a/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/runnable_exec_date_dep.py
@@ -30,7 +30,7 @@ class RunnableExecDateDep(BaseTIDep):
 
     @provide_session
     def _get_dep_statuses(self, ti, dep_context, *, session):
-        logical_date = ti.get_dagrun(session).logical_date
+        logical_date = ti.get_dagrun(session=session).logical_date
         if logical_date is None:
             return
 
diff --git a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py 
b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
index 5ddba214dd3..ba69f513152 100644
--- a/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
+++ b/airflow-core/src/airflow/ti_deps/deps/trigger_rule_dep.py
@@ -277,7 +277,7 @@ class TriggerRuleDep(BaseTIDep):
             indirect_setups = {k: v for k, v in relevant_setups.items() if k 
not in task.upstream_task_ids}
             finished_upstream_tis = (
                 x
-                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
                 if _is_relevant_upstream(upstream=x, 
relevant_ids=indirect_setups.keys())
             )
             upstream_states = 
_UpstreamTIStates.calculate(finished_upstream_tis)
@@ -332,7 +332,7 @@ class TriggerRuleDep(BaseTIDep):
                             changed,
                         )
                         return
-                changed = ti.set_state(new_state, session)
+                changed = ti.set_state(new_state, session=session)
 
             if changed:
                 dep_context.have_changed_ti_states = True
@@ -360,7 +360,9 @@ class TriggerRuleDep(BaseTIDep):
 
             finished_upstream_tis = (
                 finished_ti
-                for finished_ti in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                for finished_ti in dep_context.ensure_finished_tis(
+                    ti.get_dagrun(session=session), session=session
+                )
                 if _is_relevant_upstream(upstream=finished_ti, 
relevant_ids=task.upstream_task_ids)
             )
             upstream_states = 
_UpstreamTIStates.calculate(finished_upstream_tis)
@@ -465,7 +467,7 @@ class TriggerRuleDep(BaseTIDep):
                             reason="Task should be skipped but the past 
depends are not met"
                         )
                         return
-                changed = ti.set_state(new_state, session)
+                changed = ti.set_state(new_state, session=session)
 
             if changed:
                 dep_context.have_changed_ti_states = True
@@ -652,7 +654,7 @@ class TriggerRuleDep(BaseTIDep):
 
             done = sum(
                 1
-                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
+                for x in 
dep_context.ensure_finished_tis(ti.get_dagrun(session=session), session=session)
                 if _is_relevant_upstream(upstream=x, relevant_ids=in_scope_ids)
             )
 
diff --git a/airflow-core/tests/unit/jobs/test_scheduler_job.py 
b/airflow-core/tests/unit/jobs/test_scheduler_job.py
index 872d133c122..3b576271610 100644
--- a/airflow-core/tests/unit/jobs/test_scheduler_job.py
+++ b/airflow-core/tests/unit/jobs/test_scheduler_job.py
@@ -3052,10 +3052,10 @@ class TestSchedulerJob:
         assert res == 6
         session.flush()
         for ti in tis1[:3] + tis2[:3]:
-            ti.refresh_from_db(session)
+            ti.refresh_from_db(session=session)
             assert ti.state == TaskInstanceState.QUEUED
         for ti in tis1[3:] + tis2[3:]:
-            ti.refresh_from_db(session)
+            ti.refresh_from_db(session=session)
             assert ti.state == TaskInstanceState.SCHEDULED
 
         # The remaining TIs are queued
@@ -3064,7 +3064,7 @@ class TestSchedulerJob:
         session.flush()
 
         for ti in tis1 + tis2:
-            ti.refresh_from_db(session)
+            ti.refresh_from_db(session=session)
             assert ti.state == State.QUEUED
 
     @pytest.mark.parametrize(
@@ -4014,7 +4014,7 @@ class TestSchedulerJob:
         session = settings.Session()
 
         ti = dr.get_task_instance("dummy")
-        ti.set_state(State.SUCCESS, session)
+        ti.set_state(State.SUCCESS, session=session)
 
         with mock.patch("airflow.jobs.scheduler_job_runner.prohibit_commit") 
as mock_guard:
             mock_guard.return_value.__enter__.return_value.commit.side_effect 
= session.commit
@@ -4092,7 +4092,7 @@ class TestSchedulerJob:
         session = settings.Session()
         dr = dag_maker.create_dagrun()
         ti = dr.get_task_instance("test_task")
-        ti.set_state(state, session)
+        ti.set_state(state, session=session)
 
         self.job_runner._do_scheduling(session)
 
@@ -4126,7 +4126,7 @@ class TestSchedulerJob:
         dr = dag_maker.create_dagrun()
 
         ti = dr.get_task_instance("dummy")
-        ti.set_state(State.SUCCESS, session)
+        ti.set_state(State.SUCCESS, session=session)
 
         self.job_runner._do_scheduling(session)
 
diff --git a/airflow-core/tests/unit/models/test_cleartasks.py 
b/airflow-core/tests/unit/models/test_cleartasks.py
index 58eb1b37b12..421f4776433 100644
--- a/airflow-core/tests/unit/models/test_cleartasks.py
+++ b/airflow-core/tests/unit/models/test_cleartasks.py
@@ -85,10 +85,10 @@ class TestClearTasks:
             # but it works for our case because we specifically constructed 
test DAGS
             # in the way that those two sort methods are equivalent
             qry = session.scalars(select(TI).where(TI.dag_id == 
dag.dag_id).order_by(TI.task_id)).all()
-            clear_task_instances(qry, session)
+            clear_task_instances(qry, session=session)
 
-            ti0.refresh_from_db(session)
-            ti1.refresh_from_db(session)
+            ti0.refresh_from_db(session=session)
+            ti1.refresh_from_db(session=session)
 
         # Next try to run will be try 2
         assert ti0.state is None
diff --git a/airflow-core/tests/unit/models/test_dagrun.py 
b/airflow-core/tests/unit/models/test_dagrun.py
index e8a946a6aa1..9405d8a9555 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -183,7 +183,7 @@ class TestDagRun:
                 ti = dag_run.get_task_instance(task_id)
                 if TYPE_CHECKING:
                     assert ti
-                ti.set_state(task_state, session)
+                ti.set_state(task_state, session=session)
             session.flush()
 
         return dag_run
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py 
b/airflow-core/tests/unit/models/test_taskinstance.py
index 3b4dceb628d..dae7f419ff4 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -437,7 +437,7 @@ class TestTaskInstance:
             )
 
     @provide_session
-    def test_ti_updates_with_task(self, dag_maker, create_task_instance, 
session):
+    def test_ti_updates_with_task(self, dag_maker, create_task_instance, *, 
session: Session):
         """
         test that updating the executor_config propagates to the TaskInstance 
DB
         """
@@ -467,7 +467,7 @@ class TestTaskInstance:
         run_task_instance(ti2, task2, session=session)
         # Ensure it's reloaded
         ti2.executor_config = None
-        ti2.refresh_from_db(session)
+        ti2.refresh_from_db(session=session)
         assert ti2.executor_config == {"bar": "baz"}
         session.rollback()
 
@@ -520,7 +520,7 @@ class TestTaskInstance:
         def run_with_error(ti):
             with contextlib.suppress(AirflowException):
                 dag_maker.run_ti(ti.task_id, ti.dag_run)
-            ti.refresh_from_db(session)
+            ti.refresh_from_db(session=session)
 
         ti = 
dag_maker.create_dagrun(logical_date=timezone.utcnow()).task_instances[0]
         ti.task = dag_maker.serialized_dag.get_task(ti.task_id)
@@ -1410,7 +1410,8 @@ class TestTaskInstance:
         downstream_ti_state,
         expected_are_dependents_done,
         dag_maker,
-        session,
+        *,
+        session: Session,
     ):
         with dag_maker():
             EmptyOperator(task_id="0") >> 
EmptyOperator(task_id="downstream_task")
@@ -1418,11 +1419,11 @@ class TestTaskInstance:
         dr = dag_maker.create_dagrun()
         downstream_ti = dr.get_task_instance("downstream_task", 
session=session)
 
-        downstream_ti.set_state(downstream_ti_state, session)
+        downstream_ti.set_state(downstream_ti_state, session=session)
         session.flush()
 
         ti0 = dr.get_task_instance(task_id="0", session=session)
-        assert ti0.are_dependents_done(session) == expected_are_dependents_done
+        assert ti0.are_dependents_done(session=session) == 
expected_are_dependents_done
 
     def test_xcom_push_flag(self, dag_maker):
         """
@@ -1504,7 +1505,7 @@ class TestTaskInstance:
         assert ti_from_deserialized_task.try_number == 0
 
     @provide_session
-    def test_external_executor_id_accepts_long_values(self, 
create_task_instance, session):
+    def test_external_executor_id_accepts_long_values(self, 
create_task_instance, *, session: Session):
         """Test that external_executor_id can store values exceeding 250 
characters."""
         # Kubernetes pod names and other executor IDs can exceed 250 chars
         long_executor_id = "k8s-pod-" + "a" * 300  # 308 characters total
@@ -2314,7 +2315,7 @@ class TestTaskInstance:
         assert ti_list[3].get_previous_ti(state=State.SUCCESS).run_id != 
ti_list[2].run_id
 
     @provide_session
-    def test_handle_failure_calls_listener(self, dag_maker, session):
+    def test_handle_failure_calls_listener(self, dag_maker, *, session: 
Session):
         class CustomOp(BaseOperator):
             def execute(self, context): ...
 
diff --git a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py 
b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
index 8ba35d9b943..d52d1bc761d 100644
--- a/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
+++ b/providers/standard/tests/unit/standard/utils/test_sensor_helper.py
@@ -114,7 +114,7 @@ class TestSensorHelper:
                 ti = dag_run.get_task_instance(task_id)
                 if TYPE_CHECKING:
                     assert ti
-                ti.set_state(task_state, session)
+                ti.set_state(task_state, session=session)
             session.flush()
 
     @pytest.mark.parametrize(
diff --git a/scripts/ci/prek/known_provide_session_positional.txt 
b/scripts/ci/prek/known_provide_session_positional.txt
index ec271c99df4..6ffd1b7cac0 100644
--- a/scripts/ci/prek/known_provide_session_positional.txt
+++ b/scripts/ci/prek/known_provide_session_positional.txt
@@ -5,11 +5,8 @@ airflow-core/src/airflow/models/pool.py::11
 airflow-core/src/airflow/models/renderedtifields.py::4
 airflow-core/src/airflow/models/revoked_token.py::2
 airflow-core/src/airflow/models/serialized_dag.py::6
-airflow-core/src/airflow/models/taskinstance.py::21
-airflow-core/src/airflow/models/taskinstancehistory.py::2
 airflow-core/src/airflow/models/team.py::1
 airflow-core/src/airflow/models/trigger.py::7
 airflow-core/src/airflow/models/variable.py::2
-airflow-core/tests/unit/models/test_taskinstance.py::4
 airflow-core/tests/unit/models/test_timestamp.py::2
 providers/edge3/src/airflow/providers/edge3/executors/edge_executor.py::1

Reply via email to