This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 471e9a17faf Pass DagRun to task_instance_mutation_hook for run-aware task mutation (#68198)
471e9a17faf is described below

commit 471e9a17faf2c36b7a1a3123999524fe2d393cca
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Wed Jun 10 16:53:01 2026 -0500

    Pass DagRun to task_instance_mutation_hook for run-aware task mutation (#68198)
    
    * Pass DagRun to task_instance_mutation_hook for run-aware task mutation
    
    Cluster policies could mutate a TaskInstance before scheduling but had no
    supported way to read the DagRun it belongs to -- so routing a task to a queue
    based on the run's configuration was impossible. Calling
    TaskInstance.get_dagrun() without passing the active session opens a nested
    session that commits on exit, which trips the scheduler's prohibit_commit
    guard and crashes the scheduler.
    
    Add an optional dag_run argument to the task_instance_mutation_hook policy hook
    and thread the in-scope DagRun through every scheduler-side call site (no extra
    database access). A hook can now route on dag_run.conf, for example:
    
        def task_instance_mutation_hook(task_instance, dag_run=None):
            if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
                task_instance.queue = "high_priority_queue"
    
    Backward compatibility is preserved through the local-settings plugin shim.
    Pluggy passes a hookimpl only the parameters it declares without a default, so
    this commit also makes make_plugin_from_local_settings register a function as-is
    when its parameters are a default-free subset of the hookspec's, and shim it
    (forwarding the call positionally) when it renames a parameter or gives a
    hookspec parameter a
    default value. As a result every signature keeps working: the legacy
    single-argument hook, an explicit (task_instance, dag_run) hook, and the
    ergonomic (task_instance, dag_run=None) hook all behave correctly; a hook
    declaring an unknown parameter is still rejected.
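The registration rule described above can be sketched with the standard library alone. This is not Airflow's code. `needs_shim` mirrors the needs_shim condition in the policies.py change. `make_positional_shim` is a hypothetical closure standing in for the `compile()`-generated shim; the real shim keeps the hookspec's parameter names so pluggy can validate it, and this closure does not attempt that. The `spec`, `legacy`, `explicit`, `ergonomic` and `renamed` functions are illustrative stand-ins.

```python
import inspect
from typing import Any, Callable


def needs_shim(local_fn: Callable[..., Any], spec_fn: Callable[..., Any]) -> bool:
    """Return True when a local-settings function cannot be registered with pluggy as-is."""
    local_params = inspect.signature(local_fn).parameters
    spec_params = inspect.signature(spec_fn).parameters
    # A name the hookspec lacks is a genuine mismatch, e.g. a renamed argument.
    has_unknown_name = not (local_params.keys() <= spec_params.keys())
    # A default on a hookspec parameter would make pluggy never pass the real value.
    has_defaulted_spec_param = any(
        param.default is not inspect.Parameter.empty and name in spec_params
        for name, param in local_params.items()
    )
    return has_unknown_name or has_defaulted_spec_param


def make_positional_shim(
    spec_fn: Callable[..., Any], target: Callable[..., Any]
) -> Callable[..., Any]:
    """Hypothetical shim: receive hookspec kwargs, forward the leading ones positionally."""
    spec_names = list(inspect.signature(spec_fn).parameters)
    # Forward only as many hookspec parameters as the target declares, in hookspec order.
    forwarded = spec_names[: len(inspect.signature(target).parameters)]

    def shim(**kwargs: Any) -> Any:
        return target(*(kwargs[name] for name in forwarded))

    return shim


# Stand-in hookspec with the widened (task_instance, dag_run) signature.
def spec(task_instance, dag_run): ...


def legacy(task_instance): ...


def explicit(task_instance, dag_run): ...


def ergonomic(task_instance, dag_run=None):
    return (task_instance, dag_run)


def renamed(ti, dr):
    return (ti, dr)


if __name__ == "__main__":
    for fn in (legacy, explicit, ergonomic, renamed):
        print(fn.__name__, needs_shim(fn, spec))
    # The shim makes the ergonomic signature see the real dag_run, not its default.
    print(make_positional_shim(spec, ergonomic)(task_instance="ti", dag_run="dr"))
```

Run as a script, this prints `legacy False`, `explicit False`, `ergonomic True`, `renamed True` and then `('ti', 'dr')`.

The same shape also fails loudly for an unknown parameter. For `def dag_policy(not_a_real_parameter, dag)` checked against a one-parameter spec, only `dag` is forwarded, so the call raises `TypeError` for the missing second positional argument.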
    
    The dag_run may be None during early task-instance construction (when the hook
    is re-applied from refresh_from_task before the instance is bound to a run), and
    the hook must not open a new session; both are documented on the hookspec and in
    the cluster-policies guide, which also corrects the prior claim that the hook
    runs on the worker.
    
    * Store resolved param defaults in dag_run.conf for runs with no explicit conf
    
    When a DAG run is created without an explicit conf dict (all scheduled runs,
    and manual triggers that omit config), dag_run.conf was stored as None even
    though the param defaults had already been merged and validated. Any runtime
    consumer reading dag_run.conf -- tasks using {{ dag_run.conf }}, the
    task_instance_mutation_hook reading dag_run.conf for queue routing, sensors
    checking conf values -- silently got None for scheduled runs.
    
    The fix is one line: pass the already-computed copied_params dict as conf when
    the caller supplied none. copied_params is the result of deep_merge(conf) and
    validate(), so no extra work is done -- the resolved value is just preserved
    instead of discarded.
    
    Manual/triggered runs that supply an explicit conf are unaffected: their
    caller-supplied dict is stored as before.
    
    * Address PR review feedback on task_instance_mutation_hook dag_run arg
    
    - Revert conf=params fallback in create_dagrun (unrelated behaviour change;
      document that scheduled runs carry empty conf instead)
    - Add type annotations to settings.task_instance_mutation_hook
    - Hoist dr = unmapped_ti.dag_run in taskmap.py to avoid a second lazy-load
    - Rename defaults_a_spec_param -> has_defaulted_spec_param; inline target_arity
    - Remove redundant `or {}` guard on dag_run.conf (always a dict)
    - Delete newsfragment (not a major/breaking change)
    
    * Fix test hooks to accept dag_run kwarg after hookspec expansion
    
    Upstream main added five tests in test_dagrun.py that define single-arg
    mutation hooks (task_instance) registered via mock side_effect. After
    rebasing onto main, the hookspec now passes dag_run as a keyword argument
    to every hook, so these hooks crashed with an unexpected keyword argument.
    
    Add dag_run=None to each hook signature so they accept the new argument
    without changing their behaviour. Also restore the (dag_run.conf or {})
    guard in cluster_policies/__init__.py -- DagRun.conf is typed as
    dict | None at the ORM level, so a plain .get() fails mypy.
    
    * Simplify needs_shim condition and defer dag_run lazy-load in taskmap
    
    - policies.py: collapse names_are_subset + has_defaulted_spec_param into a
      single needs_shim boolean; expand comment to explain that pluggy puts
      defaulted params into kwargnames and never routes them (the why)
    - taskmap.py: move dr = unmapped_ti.dag_run inside the else: branch where
      it is actually consumed, avoiding an unconditional lazy DB load when
      total_length is None or < 1 (upstream-failed / skipped paths)
    
    * Fix compat and mypy failures from static checks CI
    
    - pytest_plugin.py: gate dag_run= kwarg on AIRFLOW_V_3_3_PLUS; the param
      was added to refresh_from_task in this PR (targeting 3.3) and does not
      exist on 3.0.x/3.1.x/3.2.x used by the provider compat test matrix
    - check_partition_mapper_defaults_in_sync.py: fix mypy union-attr errors
      by narrowing stmt inside each isinstance branch before accessing .value,
      rather than using a target_name sentinel that leaves stmt unnarrowed
    
    * Simplify version gate and partition mapper script
    
    - pytest_plugin.py: use dag (local) instead of self.dag in the
      AIRFLOW_V_3_0_PLUS elif branch, matching the 3.3+ branch above it
    - check_partition_mapper_defaults_in_sync.py: extract _attr_value() helper
      to eliminate the duplicated 'value = stmt.value' bodies; _find_attr_value
      becomes a straight-line loop (call helper, skip None, validate, return)
    
    * Fix dag_maker dropping refresh_from_task on Airflow 2.x compat path
    
    The version-gate refactor for the dag_run kwarg removed the pre-3.0
    else branch, so under Airflow 2.x neither the 3.3+ nor 3.0+ condition
    matched and refresh_from_task was never called. Task instances kept
    ti.task as None, failing ~100 provider compat tests with
    "'NoneType' object has no attribute 'dag'/'operator_extra_links'/'queue'".
    Restore the else branch that calls refresh_from_task for 2.x.
---
 .../cluster-policies.rst                           | 19 ++++--
 airflow-core/src/airflow/models/dagrun.py          |  8 +--
 airflow-core/src/airflow/models/taskinstance.py    | 10 ++-
 airflow-core/src/airflow/models/taskmap.py         | 11 ++--
 airflow-core/src/airflow/policies.py               | 54 ++++++++++++---
 airflow-core/src/airflow/settings.py               |  8 ++-
 .../tests/unit/cluster_policies/__init__.py        |  8 ++-
 airflow-core/tests/unit/core/test_policies.py      | 77 ++++++++++++++++++++++
 airflow-core/tests/unit/models/test_dagrun.py      | 51 ++++++++++++--
 .../tests/unit/models/test_taskinstance.py         |  2 +-
 devel-common/src/tests_common/pytest_plugin.py     |  4 +-
 .../check_partition_mapper_defaults_in_sync.py     | 30 +++++----
 12 files changed, 232 insertions(+), 50 deletions(-)

diff --git a/airflow-core/docs/administration-and-deployment/cluster-policies.rst b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
index fd7d0b622f5..e23fd7ffad2 100644
--- a/airflow-core/docs/administration-and-deployment/cluster-policies.rst
+++ b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
@@ -37,10 +37,21 @@ There are three main types of cluster policy:
  task running in a DagRun. The ``task_policy`` defined is applied to all the task instances that will be
   executed in the future.
 * ``task_instance_mutation_hook``: Takes a :class:`~airflow.models.taskinstance.TaskInstance` parameter called
-  ``task_instance``. The ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that
-  relates to a particular DagRun. It is executed in a "worker", not in the Dag file processor, just before the
-  task instance is executed. The policy is only applied to the currently executed run (i.e. instance) of that
-  task.
+  ``task_instance`` and an optional :class:`~airflow.models.dagrun.DagRun` parameter called ``dag_run``. The
+  ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that
+  relates to a particular DagRun. It is executed scheduler-side while task instances are created or
+  reconciled (not in the Dag file processor, and not on the worker). The policy is only applied to the
+  currently executed run (i.e. instance) of that task. The ``dag_run`` argument lets the policy route on
+  run configuration (``dag_run.conf``); it may be ``None`` in early task-instance construction, and a hook
+  that only declares ``task_instance`` keeps working unchanged. Note that ``dag_run.conf`` is only populated
+  for manually triggered or API-triggered runs; scheduled runs carry an empty ``conf``.
+
+.. warning::
+
+    ``task_instance_mutation_hook`` runs inside a scheduler transaction that prohibits committing the
+    session. Do not open a new database session or commit from within the hook -- in particular, do not
+    call ``task_instance.get_dagrun()`` without passing the active session, as the resulting commit
+    crashes the scheduler. Use the ``dag_run`` argument to read run configuration instead.
 
 The Dag and Task cluster policies can raise the  :class:`~airflow.exceptions.AirflowClusterPolicyViolation`
 exception to indicate that the Dag/task they were passed is not compliant and should not be loaded.
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 364b6f0953b..09cb8b1206f 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1722,7 +1722,7 @@ class DagRun(Base, LoggingMixin):
         # check for removed or restored tasks
         task_ids = set()
         for ti in tis:
-            ti_mutation_hook(ti)
+            ti_mutation_hook(ti, dag_run=self)
             task_ids.add(ti.task_id)
             try:
                 task = dag.get_task(ti.task_id)
@@ -1841,7 +1841,7 @@ class DagRun(Base, LoggingMixin):
             def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
                 for map_index in indexes:
                     ti = TI(task, run_id=self.run_id, map_index=map_index, dag_version_id=dag_version_id)
-                    ti_mutation_hook(ti)
+                    ti_mutation_hook(ti, dag_run=self)
                     if ti.operator:
                         created_counts[ti.operator] += 1
                     yield ti
@@ -1980,9 +1980,9 @@ class DagRun(Base, LoggingMixin):
                 continue
             ti = TI(task, run_id=self.run_id, map_index=index, state=None, dag_version_id=dag_version_id)
             self.log.debug("Expanding TIs upserted %s", ti)
-            task_instance_mutation_hook(ti)
+            task_instance_mutation_hook(ti, dag_run=self)
             ti = session.merge(ti)
-            ti.refresh_from_task(task)
+            ti.refresh_from_task(task, dag_run=self)
             session.flush()
             yield ti
 
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 55a838529da..a31346097e8 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -386,7 +386,7 @@ def clear_task_instances(
             task_id = ti.task_id
             if ti_dag and ti_dag.has_task(task_id):
                 task = ti_dag.get_task(task_id)
-                ti.refresh_from_task(task)
+                ti.refresh_from_task(task, dag_run=dr)
                 if TYPE_CHECKING:
                     assert ti.task
                 ti.max_tries = ti.try_number + task.retries
@@ -930,12 +930,16 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         else:
             self.state = None
 
-    def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
+    def refresh_from_task(
+        self, task: Operator, pool_override: str | None = None, *, dag_run: DagRun | None = None
+    ) -> None:
         """
         Copy common attributes from the given task.
 
         :param task: The task object to copy from
         :param pool_override: Use the pool_override instead of task's pool
+        :param dag_run: the DagRun this task instance belongs to, forwarded to the mutation hook so a
+            cluster policy can route on ``dag_run.conf``; ``None`` when no run is available yet
         """
         self.task = task
         self.queue = task.queue
@@ -954,7 +958,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
         op_name = getattr(task, "operator_name", None)
         self.custom_operator_name = op_name if isinstance(op_name, str) else ""
         # Re-apply cluster policy here so that task default do not overload previous data
-        task_instance_mutation_hook(self)
+        task_instance_mutation_hook(self, dag_run=dag_run)
 
     @property
     def key(self) -> TaskInstanceKey:
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 60486b8ce86..b7c394a7c13 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -200,6 +200,7 @@ class TaskMap(TaskInstanceDependencies):
                 )
                 unmapped_ti.state = TaskInstanceState.SKIPPED
             else:
+                dr = unmapped_ti.dag_run
                 zero_index_ti_exists = exists_query(
                     TaskInstance.dag_id == task.dag_id,
                     TaskInstance.task_id == task.task_id,
@@ -214,7 +215,7 @@ class TaskMap(TaskInstanceDependencies):
                     task.log.debug("Updated in place to become %s", unmapped_ti)
                     all_expanded_tis.append(unmapped_ti)
                     # execute hook for task instance map index 0
-                    task_instance_mutation_hook(unmapped_ti)
+                    task_instance_mutation_hook(unmapped_ti, dag_run=dr)
                     session.flush()
                 else:
                     task.log.debug("Deleting the original task instance: %s", unmapped_ti)
@@ -245,9 +246,7 @@ class TaskMap(TaskInstanceDependencies):
         else:
             dag_version_id = None
 
-        if unmapped_ti:
-            dr = unmapped_ti.dag_run
-        else:
+        if not unmapped_ti:
             from airflow.models import DagRun
 
             dr = session.scalar(
@@ -267,10 +266,10 @@ class TaskMap(TaskInstanceDependencies):
                 dag_version_id=dag_version_id,
             )
             task.log.debug("Expanding TIs upserted %s", ti)
-            task_instance_mutation_hook(ti)
+            task_instance_mutation_hook(ti, dag_run=dr)
             ti = session.merge(ti)
             ti.context_carrier = new_task_run_carrier(dr.context_carrier)
-            ti.refresh_from_task(task)  # session.merge() loses task information.
+            ti.refresh_from_task(task, dag_run=dr)  # session.merge() loses task information.
             all_expanded_tis.append(ti)
 
         # Coerce the None case to 0 -- these two are almost treated identically,
diff --git a/airflow-core/src/airflow/policies.py b/airflow-core/src/airflow/policies.py
index 7a8311bd67d..e08e4791e08 100644
--- a/airflow-core/src/airflow/policies.py
+++ b/airflow-core/src/airflow/policies.py
@@ -27,6 +27,7 @@ hookimpl = pluggy.HookimplMarker("airflow.policy")
 __all__: list[str] = ["hookimpl"]
 
 if TYPE_CHECKING:
+    from airflow.models.dagrun import DagRun
     from airflow.models.taskinstance import TaskInstance
 
 
@@ -67,13 +68,27 @@ def dag_policy(dag) -> None:
 
 
 @local_settings_hookspec
-def task_instance_mutation_hook(task_instance: TaskInstance) -> None:
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None) -> None:
     """
     Allow altering task instances before being queued by the Airflow scheduler.
 
-    This could be used, for instance, to modify the task instance during retries.
+    This could be used, for instance, to modify the task instance during retries, or to route a task to
+    a different queue based on the run's configuration (``dag_run.conf``).
+
+    This hook runs scheduler-side while task instances are created or reconciled, inside a transaction
+    that prohibits committing the session. Implementations must therefore not open a new database
+    session or commit -- in particular, do not call ``task_instance.get_dagrun()`` without passing the
+    active session, as the resulting commit raises ``RuntimeError("UNEXPECTED COMMIT ...")`` and crashes
+    the scheduler. Use the ``dag_run`` argument instead.
+
+    ``dag_run`` is provided wherever it is available without extra database access. It may be ``None`` in
+    early task-instance construction (for example when the hook is re-applied from
+    ``TaskInstance.refresh_from_task`` before the instance is bound to a run); implementations must handle
+    that case. Note that ``dag_run.conf`` is only populated for manually triggered or API-triggered runs;
+    scheduled runs carry an empty ``conf``.
 
     :param task_instance: task instance to be mutated
+    :param dag_run: the DagRun the task instance belongs to, or ``None`` when not yet available
     """
 
 
@@ -157,21 +172,25 @@ def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set
     hook_methods = set()
 
     def _make_shim_fn(name, desired_sig, target):
-        # Functions defined in airflow_local_settings are called by positional parameters, so the names don't
-        # have to match what we define in the "template" policy.
+        # Functions defined in airflow_local_settings are called by positional parameters, so the names
+        # don't have to match what we define in the "template" policy.
         #
         # However Pluggy validates the names match (and will raise an error if they don't!)
         #
-        # To maintain compat, if we detect the names don't match, we will wrap it with a dynamically created
-        # shim function that looks somewhat like this:
+        # To maintain compat, if we detect a name that isn't in the hookspec, we wrap the function with a
+        # dynamically created shim that looks somewhat like this:
         #
         #  def dag_policy_name_mismatch_shim(dag):
         #      airflow_local_settings.dag_policy(dag)
         #
+        # The target is called positionally, so we forward only as many of the hookspec's parameters as
+        # the target declares (in hookspec order). This lets a misnamed function still receive the leading
+        # hookspec arguments without breaking when the hookspec grows additional trailing parameters.
+        forwarded = list(desired_sig.parameters)[: len(inspect.signature(target).parameters)]
         codestr = textwrap.dedent(
             f"""
             def {name}_name_mismatch_shim{desired_sig}:
-                return __target({" ,".join(desired_sig.parameters)})
+                return __target({", ".join(forwarded)})
             """
         )
         code = compile(codestr, "<policy-shim>", "single")
@@ -199,8 +218,25 @@ def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set
 
         local_sig = inspect.signature(policy)
         policy_sig = inspect.signature(globals()[name])
-        # We only care if names/order/number of parameters match, not type hints
-        if local_sig.parameters.keys() != policy_sig.parameters.keys():
+        # Decide whether the local settings function can be registered as-is or needs a shim. Pluggy
+        # passes a hookimpl only the parameters it declares *without a default* (its `argnames`), by name.
+        # Parameters with defaults go into `kwargnames` and are never routed, so an impl that writes
+        # ``def hook(task_instance, dag_run=None)`` would silently never receive ``dag_run``. We shim when:
+        #
+        # * the function declares a name the hookspec does not have (a genuine mismatch, e.g. a renamed
+        #   argument), or
+        # * the function gives a hookspec parameter a default value -- pluggy puts it in kwargnames and
+        #   the function would always see its own default rather than the real value.
+        #
+        # A function declaring a (default-free) subset of the hookspec parameters is registered as-is, so a
+        # single-parameter hook keeps working unchanged after the hookspec gains new parameters. The shim
+        # forwards positionally, so it transparently handles renames, defaulted parameters, and signatures
+        # that declare fewer parameters than the hookspec.
+        needs_shim = not (local_sig.parameters.keys() <= policy_sig.parameters.keys()) or any(
+            param.default is not inspect.Parameter.empty and param_name in policy_sig.parameters
+            for param_name, param in local_sig.parameters.items()
+        )
+        if needs_shim:
             policy = _make_shim_fn(name, policy_sig, target=policy)
 
         setattr(AirflowLocalSettingsPolicy, name, staticmethod(hookimpl(policy, specname=name)))
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index 739e2437f6a..56fc3d07ced 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -73,6 +73,8 @@ if TYPE_CHECKING:
     from sqlalchemy.engine import Engine
 
     from airflow.api_fastapi.common.types import UIAlert
+    from airflow.models.dagrun import DagRun
+    from airflow.models.taskinstance import TaskInstance
 
 log = logging.getLogger(__name__)
 
@@ -204,8 +206,10 @@ def dag_policy(dag):
     return get_policy_plugin_manager().hook.dag_policy(dag=dag)
 
 
-def task_instance_mutation_hook(task_instance):
-    return get_policy_plugin_manager().hook.task_instance_mutation_hook(task_instance=task_instance)
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None = None):
+    return get_policy_plugin_manager().hook.task_instance_mutation_hook(
+        task_instance=task_instance, dag_run=dag_run
+    )
 
 
 task_instance_mutation_hook.is_noop = True  # type: ignore
diff --git a/airflow-core/tests/unit/cluster_policies/__init__.py b/airflow-core/tests/unit/cluster_policies/__init__.py
index e9c380838e9..485f65ed406 100644
--- a/airflow-core/tests/unit/cluster_policies/__init__.py
+++ b/airflow-core/tests/unit/cluster_policies/__init__.py
@@ -28,6 +28,7 @@ from airflow.sdk import BaseOperator
 
 if TYPE_CHECKING:
     from airflow.models.dag import DAG
+    from airflow.models.dagrun import DagRun
     from airflow.models.taskinstance import TaskInstance
 
 
@@ -109,9 +110,14 @@ def task_policy(task: TimedOperator):
 
 
 # [START example_task_mutation_hook]
-def task_instance_mutation_hook(task_instance: TaskInstance):
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None = None):
+    # Route retries to a dedicated queue.
     if task_instance.try_number >= 1:
         task_instance.queue = "retry_queue"
+    # Route on the run configuration. ``dag_run`` may be None during early task-instance construction,
+    # so guard for it; read ``dag_run.conf`` directly rather than opening a new database session.
+    if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
+        task_instance.queue = "high_priority_queue"
 
 
 # [END example_task_mutation_hook]
diff --git a/airflow-core/tests/unit/core/test_policies.py b/airflow-core/tests/unit/core/test_policies.py
index c7caff80b2c..47671e95a96 100644
--- a/airflow-core/tests/unit/core/test_policies.py
+++ b/airflow-core/tests/unit/core/test_policies.py
@@ -69,3 +69,80 @@ def test_local_settings_misnamed_argument(plugin_manager: pluggy.PluginManager):
     plugin_manager.hook.dag_policy(dag="passed_dag_value")
 
     assert called_with == "passed_dag_value"
+
+
+def test_local_settings_subset_of_parameters(plugin_manager: pluggy.PluginManager):
+    """
+    A local_settings function may declare a subset of a hookspec's parameters.
+
+    ``task_instance_mutation_hook`` accepts ``(task_instance, dag_run)``; a legacy hook that only declares
+    ``task_instance`` is registered as-is (no shim) and pluggy passes it just the argument it declares, so
+    single-parameter hooks keep working after the hookspec gained the optional ``dag_run`` parameter.
+    """
+    called_with = None
+
+    def local_hook(task_instance):
+        nonlocal called_with
+        called_with = task_instance
+
+    mod = Namespace(task_instance_mutation_hook=local_hook)
+
+    policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+    plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+    assert called_with == "ti"
+
+
+def test_local_settings_receives_all_declared_parameters(plugin_manager: pluggy.PluginManager):
+    """A hook declaring ``(task_instance, dag_run)`` with no defaults receives both arguments."""
+    received = {}
+
+    def task_instance_mutation_hook(task_instance, dag_run):
+        received["task_instance"] = task_instance
+        received["dag_run"] = dag_run
+
+    mod = Namespace(task_instance_mutation_hook=task_instance_mutation_hook)
+
+    policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+    plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+    assert received == {"task_instance": "ti", "dag_run": "dr"}
+
+
+def test_local_settings_defaulted_parameter_is_still_forwarded(plugin_manager: pluggy.PluginManager):
+    """A hook that gives a hookspec parameter a default still receives the passed value.
+
+    Pluggy forwards a hookimpl only the parameters it declares *without* a default, so it would silently
+    drop ``dag_run`` from ``def hook(task_instance, dag_run=None)`` and leave the default in place. The
+    shim detects the defaulted hookspec parameter and forwards the call positionally so the ergonomic
+    ``dag_run=None`` signature receives the real value.
+    """
+    received = {}
+
+    def task_instance_mutation_hook(task_instance, dag_run=None):
+        received["task_instance"] = task_instance
+        received["dag_run"] = dag_run
+
+    mod = Namespace(task_instance_mutation_hook=task_instance_mutation_hook)
+
+    policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+    plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+    assert received == {"task_instance": "ti", "dag_run": "dr"}
+
+
+def test_local_settings_unknown_argument_still_raises(plugin_manager: pluggy.PluginManager):
+    """A local_settings function declaring a name the hookspec does not have is still rejected loudly."""
+
+    def dag_policy(not_a_real_parameter, dag): ...
+
+    mod = Namespace(dag_policy=dag_policy)
+
+    # ``not_a_real_parameter`` is not a subset of the hookspec params, so it is shimmed and forwarded
+    # positionally -- pluggy then rejects the extra positional argument rather than silently ignoring it.
+    policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"})
+    with pytest.raises(TypeError):
+        plugin_manager.hook.dag_policy(dag="passed_dag_value")
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index c33f3b007bf..68442bbfbc0 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -854,7 +854,7 @@ class TestDagRun:
     @pytest.mark.parametrize("state", State.task_states)
     @mock.patch.object(settings, "task_instance_mutation_hook", autospec=True)
     def test_task_instance_mutation_hook(self, mock_hook, dag_maker, session, state):
-        def mutate_task_instance(task_instance):
+        def mutate_task_instance(task_instance, dag_run=None):
             if task_instance.queue == "queue1":
                 task_instance.queue = "queue2"
             else:
@@ -887,7 +887,7 @@ class TestDagRun:
         """
         observed_run_ids = []
 
-        def mutate_task_instance(task_instance):
+        def mutate_task_instance(task_instance, dag_run=None):
             observed_run_ids.append(task_instance.run_id)
             if task_instance.run_id and task_instance.run_id.startswith("manual__"):
                 task_instance.pool = "manual_pool"
@@ -907,6 +907,43 @@ class TestDagRun:
             f"task_instance_mutation_hook was called with run_id=None. Observed run_ids: {observed_run_ids}"
         )
 
+    def test_task_instance_mutation_hook_receives_dag_run(self, dag_maker, session, monkeypatch):
+        """task_instance_mutation_hook can route on dag_run.conf for mapped task instances.
+
+        Exercises the supported dag_run argument end-to-end: a conf-routing hook sets queue based on
+        DagRun.conf, and the routed value persists on every expanded mapped task instance. This is the
+        first-class replacement for reaching the DagRun through unsupported internals.
+        """
+        observed_confs = []
+
+        def mutate_task_instance(task_instance, dag_run=None):
+            observed_confs.append(None if dag_run is None else dag_run.conf)
+            if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
+                task_instance.queue = "high_queue"
+
+        with mock.patch.object(
+            get_policy_plugin_manager().hook, "task_instance_mutation_hook", autospec=True
+        ) as mock_hook:
+            mock_hook.side_effect = mutate_task_instance
+            # Force the non-noop task-creation path so the scheduler invokes the hook with dag_run while
+            # materializing the mapped instances (mocking the hook does not flip the is_noop flag that
+            # import_local_settings would set for a real registered policy).
+            monkeypatch.setattr(settings.task_instance_mutation_hook, "is_noop", False)
+            with dag_maker(dag_id="test_mutation_hook_dag_run", session=session):
+                MockOperator.partial(task_id="mapped").expand(arg2=[1, 2, 3])
+
+            dr = dag_maker.create_dagrun(conf={"route": "high"})
+
+        # The hook saw the run conf (at least once with the routing value), and every expanded mapped
+        # task instance was routed to the conf-selected queue.
+        assert {"route": "high"} in observed_confs
+        queues = session.scalars(
+            select(TI.queue)
+            .where(TI.task_id == "mapped", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
+            .order_by(TI.map_index)
+        ).all()
+        assert queues == ["high_queue", "high_queue", "high_queue"]
+
     @pytest.mark.parametrize(
         ("prev_ti_state", "is_ti_schedulable"),
         [
@@ -1652,7 +1689,7 @@ def test_mutation_hook_committing_session_crashes_under_prohibit_commit(dag_make
     before_commit guard with RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!").
     """
 
-    def naive_hook(task_instance):
+    def naive_hook(task_instance, dag_run=None):
         # Reads DagRun.conf the unsafe way -- opens a fresh @provide_session session that commits on exit.
         task_instance.get_dagrun()
 
@@ -1683,7 +1720,7 @@ def test_mutation_hook_safe_session_reuse_routes_mapped_tis_under_prohibit_commi
     Resolving via the attached session is the discipline a real conf-routing hook must follow.
     """
 
-    def safe_hook(task_instance):
+    def safe_hook(task_instance, dag_run=None):
         attached_session = sa_inspect(task_instance).session
         if attached_session is None:
             # Transient instance (pre-merge); it will be re-invoked once attached. Nothing safe to do.
@@ -1722,7 +1759,7 @@ def test_mutation_hook_deterministic_across_repeated_invocation_during_expansion
     """
     call_counts: dict[int, int] = defaultdict(int)
 
-    def deterministic_hook(task_instance):
+    def deterministic_hook(task_instance, dag_run=None):
         call_counts[task_instance.map_index] += 1
         task_instance.queue = f"q_{task_instance.map_index}"
 
@@ -1806,7 +1843,7 @@ def test_naive_committing_hook_crashes_on_verify_integrity_under_guard(dag_maker
     session; create_session() commits on exit and trips the before_commit guard.
     """
 
-    def naive_hook(task_instance):
+    def naive_hook(task_instance, dag_run=None):
         task_instance.get_dagrun()
 
     dr, dag_version_id = _make_literal_mapped_dagrun(
@@ -1835,7 +1872,7 @@ def test_resolve_dagrun_attribute_access_is_safe_on_verify_integrity_under_guard
     """
     seen_map_indices = []
 
-    def resolve_dagrun_like_hook(task_instance):
+    def resolve_dagrun_like_hook(task_instance, dag_run=None):
         state = sa_inspect(task_instance)
         if "dag_run" not in state.unloaded:
             _ = task_instance.dag_run  # eager-loaded: cheap attribute read
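With the hook now receiving the in-scope DagRun, a policy can route on dag_run.conf without calling TaskInstance.get_dagrun(). The tests above show that call tripping the prohibit_commit guard. The following is a minimal standalone sketch of such a hook. It assumes types.SimpleNamespace stand-ins for TaskInstance and DagRun, carrying only the attributes the hook touches, and it reuses the "high_queue" name from the test. It exercises the routing logic only, not the scheduler.

```python
from types import SimpleNamespace


def task_instance_mutation_hook(task_instance, dag_run=None):
    """Send the task to a dedicated queue when the run's conf asks for it.

    Reads only the in-memory dag_run passed by the caller; it never opens a
    database session, so it cannot trigger a nested commit.
    """
    if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
        task_instance.queue = "high_queue"


# Stand-ins (assumption): SimpleNamespace objects with only the attributes the hook reads/writes.
high_run = SimpleNamespace(conf={"route": "high"})
mapped_tis = [SimpleNamespace(map_index=i, queue="default") for i in range(3)]
for ti in mapped_tis:
    task_instance_mutation_hook(ti, dag_run=high_run)

no_conf_ti = SimpleNamespace(map_index=-1, queue="default")
task_instance_mutation_hook(no_conf_ti, dag_run=SimpleNamespace(conf=None))  # conf=None is treated as {}

no_run_ti = SimpleNamespace(map_index=-1, queue="default")
task_instance_mutation_hook(no_run_ti)  # no dag_run passed: queue left unchanged

print([ti.queue for ti in mapped_tis])  # ['high_queue', 'high_queue', 'high_queue']
print(no_conf_ti.queue, no_run_ti.queue)  # default default
```

Because dag_run defaults to None, the same function can also be called with the task instance alone. In that case it simply leaves the queue untouched.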
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index e840b01b859..f1a46cd52d9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -2762,7 +2762,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
     expected_queue = queue_by_policy or default_queue
     if queue_by_policy:
         # Apply a dummy cluster policy to check if it is always applied
-        def mock_policy(task_instance: TaskInstance):
+        def mock_policy(task_instance: TaskInstance, dag_run=None):
             task_instance.queue = queue_by_policy
 
         monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index a36cb50a206..0958dbd663a 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -1193,7 +1193,9 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
 
             self.dag_run = dag.create_dagrun(**kwargs)
             for ti in self.dag_run.task_instances:
-                if AIRFLOW_V_3_0_PLUS:
+                if AIRFLOW_V_3_3_PLUS:
+                    ti.refresh_from_task(dag.get_task(ti.task_id), dag_run=self.dag_run)
+                elif AIRFLOW_V_3_0_PLUS:
                     ti.refresh_from_task(dag.get_task(ti.task_id))
                 else:
                     ti.refresh_from_task(self.dag.get_task(ti.task_id))
diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
index 24f7520f1b5..b6d010e3be1 100755
--- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
+++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
@@ -76,6 +76,20 @@ SDK_FIXED_KEY_FILE = (
 )
 
 
+def _attr_value(stmt: ast.stmt) -> ast.expr | None:
+    """Return the value node of a statement that assigns ``ATTR_NAME``, or None."""
+    if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.target.id == ATTR_NAME:
+        return stmt.value
+    if (
+        isinstance(stmt, ast.Assign)
+        and len(stmt.targets) == 1
+        and isinstance(stmt.targets[0], ast.Name)
+        and stmt.targets[0].id == ATTR_NAME
+    ):
+        return stmt.value
+    return None
+
+
 def _find_attr_value(file_path: Path) -> ast.Dict:
     """Return the AST node assigned to ``FanOutMapper.default_downstream_mapper_by_window_name``."""
     tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
@@ -83,23 +97,15 @@ def _find_attr_value(file_path: Path) -> ast.Dict:
         if not (isinstance(node, ast.ClassDef) and node.name == CLASS_NAME):
             continue
         for stmt in node.body:
-            target_name = None
-            if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
-                target_name = stmt.target.id
-            elif (
-                isinstance(stmt, ast.Assign)
-                and len(stmt.targets) == 1
-                and isinstance(stmt.targets[0], ast.Name)
-            ):
-                target_name = stmt.targets[0].id
-            if target_name != ATTR_NAME:
+            value = _attr_value(stmt)
+            if value is None:
                 continue
-            if not isinstance(stmt.value, ast.Dict):
+            if not isinstance(value, ast.Dict):
                 raise ValueError(
                     f"{file_path}: {CLASS_NAME}.{ATTR_NAME} is not a dict literal; "
                     f"this check parses a dict literal and must be updated."
                 )
-            return stmt.value
+            return value
         raise ValueError(f"{file_path}: {CLASS_NAME} has no {ATTR_NAME} attribute.")
     raise ValueError(f"{file_path}: no class {CLASS_NAME} found.")
 

