This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 471e9a17faf Pass DagRun to task_instance_mutation_hook for run-aware task mutation (#68198)
471e9a17faf is described below
commit 471e9a17faf2c36b7a1a3123999524fe2d393cca
Author: Dheeraj Turaga <[email protected]>
AuthorDate: Wed Jun 10 16:53:01 2026 -0500

Pass DagRun to task_instance_mutation_hook for run-aware task mutation (#68198)

* Pass DagRun to task_instance_mutation_hook for run-aware task mutation

Cluster policies could mutate a TaskInstance before scheduling but had no
supported way to read the DagRun it belongs to, so routing a task to a
queue based on the run's configuration was impossible. Reaching the DagRun
via TaskInstance.get_dagrun() opens a nested committing session that trips
the scheduler's prohibit_commit guard and crashes the scheduler.

Add an optional dag_run argument to the task_instance_mutation_hook policy
hook and thread the in-scope DagRun through every scheduler-side call site
(no extra database access). A hook can now route on dag_run.conf, for
example:

    def task_instance_mutation_hook(task_instance, dag_run=None):
        if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
            task_instance.queue = "high_priority_queue"
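
A minimal sketch of how that example behaves in the cases this commit
cares about, calling it directly with types.SimpleNamespace stand-ins for
TaskInstance and DagRun. The stand-ins and the queue_after_hook helper are
hypothetical and exist only for illustration; in a real deployment the
hook is typically defined in airflow_local_settings and invoked by the
scheduler, never called like this:

```python
from types import SimpleNamespace


def task_instance_mutation_hook(task_instance, dag_run=None):
    # dag_run may be None during early task-instance construction, so guard for it.
    # DagRun.conf is typed dict | None at the ORM level, hence the ``or {}``.
    if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
        task_instance.queue = "high_priority_queue"


def queue_after_hook(conf):
    """Hypothetical helper: run the hook on a fresh stand-in TI and return its queue.

    Pass the string "no-run" to simulate dag_run=None (no DagRun bound yet).
    """
    ti = SimpleNamespace(queue="default")
    dag_run = None if conf == "no-run" else SimpleNamespace(conf=conf)
    task_instance_mutation_hook(ti, dag_run=dag_run)
    return ti.queue


print(queue_after_hook({"route": "high"}))  # -> high_priority_queue (triggered with routing conf)
print(queue_after_hook({}))                 # -> default (scheduled run, empty conf)
print(queue_after_hook(None))               # -> default (conf stored as NULL)
print(queue_after_hook("no-run"))           # -> default (no DagRun available yet)
```

Only the first call changes the queue; every other case falls through to
the task's existing queue without touching the database.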

Backward compatibility is preserved through the local-settings plugin shim.
Pluggy passes a hookimpl only the parameters it declares without a
default, so this commit also makes make_plugin_from_local_settings
register a function as-is when its parameters are a subset of the
hookspec's, and shim it (forwarding the call positionally) when it
declares a parameter name the hookspec lacks (for example a renamed
argument) or gives a hookspec parameter a default value. As a result,
every signature keeps working: the legacy single-argument hook, an
explicit (task_instance, dag_run) hook, and the ergonomic
(task_instance, dag_run=None) hook all behave correctly, while a hook
declaring an unknown parameter is still rejected.
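
The compatibility rule above can be modelled in plain Python. This is a
simplified sketch, not pluggy's or Airflow's code, and exact pluggy
internals may differ between versions: required_argnames approximates how
pluggy derives a hookimpl's argnames (parameters without defaults),
call_like_pluggy approximates how it then calls the impl positionally, and
needs_shim / make_shim / register mirror the condition and the
positional-forwarding shim this commit adds to
make_plugin_from_local_settings. All of these helper names are
hypothetical:

```python
import inspect

# Hookspec parameters of task_instance_mutation_hook after this commit, in order.
SPEC_PARAMS = ("task_instance", "dag_run")


def required_argnames(fn):
    """Approximation of pluggy's argnames: declared parameters that have no default."""
    return [
        name
        for name, param in inspect.signature(fn).parameters.items()
        if param.default is inspect.Parameter.empty
    ]


def call_like_pluggy(fn, **caller_kwargs):
    """Approximation of a hook call: only the required argnames are passed, positionally."""
    return fn(*(caller_kwargs[name] for name in required_argnames(fn)))


def needs_shim(fn):
    """Same shape as the needs_shim expression in make_plugin_from_local_settings."""
    params = inspect.signature(fn).parameters
    return not (params.keys() <= set(SPEC_PARAMS)) or any(
        param.default is not inspect.Parameter.empty and name in SPEC_PARAMS
        for name, param in params.items()
    )


def make_shim(fn):
    """Default-free wrapper that forwards the leading hookspec arguments positionally."""
    forwarded = SPEC_PARAMS[: len(inspect.signature(fn).parameters)]

    def shim(task_instance, dag_run):
        values = {"task_instance": task_instance, "dag_run": dag_run}
        return fn(*(values[name] for name in forwarded))

    return shim


def register(fn):
    """Register as-is, or behind the shim when needs_shim says so."""
    return make_shim(fn) if needs_shim(fn) else fn


def legacy(task_instance):
    return (task_instance,)


def explicit(task_instance, dag_run):
    return (task_instance, dag_run)


def ergonomic(task_instance, dag_run=None):
    return (task_instance, dag_run)


# For "ergonomic", the raw call keeps dag_run=None; the registered (shimmed) call gets "dr".
for hook in (legacy, explicit, ergonomic):
    raw = call_like_pluggy(hook, task_instance="ti", dag_run="dr")
    registered = call_like_pluggy(register(hook), task_instance="ti", dag_run="dr")
    print(hook.__name__, needs_shim(hook), raw, registered)
```

The shim has no defaults of its own, so it receives both hookspec
arguments and hands them to the wrapped function by position; that is
what lets the dag_run=None form see the real value.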

The dag_run may be None during early task-instance construction (when the
hook is re-applied from refresh_from_task before the instance is bound to
a run), and the hook must not open a new session. Both constraints are
documented on the hookspec and in the cluster-policies guide, which also
corrects the earlier claim that the hook runs on the worker: it runs
scheduler-side.

* Store resolved param defaults in dag_run.conf for runs with no explicit conf

When a DAG run is created without an explicit conf dict (all scheduled
runs, and manual triggers that omit config), dag_run.conf was stored as
None even though the param defaults had already been merged and
validated. Any runtime consumer reading dag_run.conf (tasks using
{{ dag_run.conf }}, the task_instance_mutation_hook reading dag_run.conf
for queue routing, sensors checking conf values) silently got None for
scheduled runs.

The fix is one line: pass the already-computed copied_params dict as conf
when the caller supplied none. copied_params is the result of
deep_merge(conf) and validate(), so no extra work is done; the resolved
value is simply preserved instead of discarded.

Manual/triggered runs that supply an explicit conf are unaffected: their
caller-supplied dict is stored as before.

* Address PR review feedback on task_instance_mutation_hook dag_run arg

- Revert conf=params fallback in create_dagrun (unrelated behaviour change;
  document that scheduled runs carry empty conf instead)
- Add type annotations to settings.task_instance_mutation_hook
- Hoist dr = unmapped_ti.dag_run in taskmap.py to avoid a second lazy-load
- Rename defaults_a_spec_param -> has_defaulted_spec_param; inline
  target_arity
- Remove redundant `or {}` guard on dag_run.conf (always a dict)
- Delete newsfragment (not a major/breaking change)

* Fix test hooks to accept dag_run kwarg after hookspec expansion

Upstream main added five tests in test_dagrun.py that define single-arg
mutation hooks (task_instance) registered via mock side_effect. After
rebasing onto main, the call sites now pass dag_run as a keyword argument
to every hook, so these hooks crashed with an unexpected keyword argument.
Add dag_run=None to each hook signature so they accept the new argument
without changing their behaviour. Also restore the (dag_run.conf or {})
guard in cluster_policies/__init__.py: DagRun.conf is typed as
dict | None at the ORM level, so a plain .get() fails mypy.
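
The crash is easy to reproduce with unittest.mock: a side_effect function
is called with exactly the arguments the mock receives, so once the call
sites pass dag_run=..., a single-argument side_effect raises TypeError. In
this sketch a plain Mock stands in for the patched
settings.task_instance_mutation_hook, a dict stands in for a TaskInstance,
and the hook names are hypothetical:

```python
from unittest import mock


def single_arg_hook(task_instance):
    # Pre-change test hook: only declares task_instance.
    task_instance["queue"] = "queue2"


def updated_hook(task_instance, dag_run=None):
    # Same behaviour; dag_run is accepted and ignored.
    task_instance["queue"] = "queue2"


hook = mock.Mock(side_effect=single_arg_hook)
ti = {"queue": "queue1"}

error = None
try:
    hook(ti, dag_run="dr")  # the call sites now pass dag_run as a keyword
except TypeError as exc:
    error = exc
print(repr(error))  # unexpected keyword argument 'dag_run'

hook.side_effect = updated_hook
hook(ti, dag_run="dr")
print(ti["queue"])  # -> queue2
```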

* Simplify needs_shim condition and defer dag_run lazy-load in taskmap

- policies.py: collapse names_are_subset + has_defaulted_spec_param into a
  single needs_shim boolean; expand the comment to explain the why: pluggy
  puts defaulted params into kwargnames and never routes them
- taskmap.py: move dr = unmapped_ti.dag_run inside the else: branch where
  it is actually consumed, avoiding an unconditional lazy DB load when
  total_length is None or < 1 (upstream-failed / skipped paths)

* Fix compat and mypy failures from static checks CI

- pytest_plugin.py: gate the dag_run= kwarg on AIRFLOW_V_3_3_PLUS; the
  param was added to refresh_from_task in this PR (targeting 3.3) and does
  not exist on the 3.0.x/3.1.x/3.2.x versions used by the provider compat
  test matrix
- check_partition_mapper_defaults_in_sync.py: fix mypy union-attr errors
  by narrowing stmt inside each isinstance branch before accessing .value,
  rather than using a target_name sentinel that leaves stmt unnarrowed

* Simplify version gate and partition mapper script

- pytest_plugin.py: use dag (local) instead of self.dag in the
  AIRFLOW_V_3_0_PLUS elif branch, matching the 3.3+ branch above it
- check_partition_mapper_defaults_in_sync.py: extract an _attr_value()
  helper to eliminate the duplicated 'value = stmt.value' bodies;
  _find_attr_value becomes a straight-line loop (call helper, skip None,
  validate, return)

* Fix dag_maker dropping refresh_from_task on Airflow 2.x compat path

The version-gate refactor for the dag_run kwarg removed the pre-3.0 else
branch, so under Airflow 2.x neither the 3.3+ nor the 3.0+ condition
matched and refresh_from_task was never called. Task instances kept
ti.task as None, failing ~100 provider compat tests with
"'NoneType' object has no attribute 'dag'/'operator_extra_links'/'queue'".
Restore the else branch that calls refresh_from_task for 2.x.

---
.../cluster-policies.rst | 19 ++++--
airflow-core/src/airflow/models/dagrun.py | 8 +--
airflow-core/src/airflow/models/taskinstance.py | 10 ++-
airflow-core/src/airflow/models/taskmap.py | 11 ++--
airflow-core/src/airflow/policies.py | 54 ++++++++++++---
airflow-core/src/airflow/settings.py | 8 ++-
.../tests/unit/cluster_policies/__init__.py | 8 ++-
airflow-core/tests/unit/core/test_policies.py | 77 ++++++++++++++++++++++
airflow-core/tests/unit/models/test_dagrun.py | 51 ++++++++++++--
.../tests/unit/models/test_taskinstance.py | 2 +-
devel-common/src/tests_common/pytest_plugin.py | 4 +-
.../check_partition_mapper_defaults_in_sync.py | 30 +++++----
12 files changed, 232 insertions(+), 50 deletions(-)
diff --git a/airflow-core/docs/administration-and-deployment/cluster-policies.rst b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
index fd7d0b622f5..e23fd7ffad2 100644
--- a/airflow-core/docs/administration-and-deployment/cluster-policies.rst
+++ b/airflow-core/docs/administration-and-deployment/cluster-policies.rst
@@ -37,10 +37,21 @@ There are three main types of cluster policy:
task running in a DagRun. The ``task_policy`` defined is applied to all the task instances that will be
executed in the future.
* ``task_instance_mutation_hook``: Takes a :class:`~airflow.models.taskinstance.TaskInstance` parameter called
- ``task_instance``. The ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that
- relates to a particular DagRun. It is executed in a "worker", not in the Dag file processor, just before the
- task instance is executed. The policy is only applied to the currently executed run (i.e. instance) of that
- task.
+ ``task_instance`` and an optional :class:`~airflow.models.dagrun.DagRun` parameter called ``dag_run``. The
+ ``task_instance_mutation_hook`` applies not to a task but to the instance of a task that
+ relates to a particular DagRun. It is executed scheduler-side while task instances are created or
+ reconciled (not in the Dag file processor, and not on the worker). The policy is only applied to the
+ currently executed run (i.e. instance) of that task. The ``dag_run`` argument lets the policy route on
+ run configuration (``dag_run.conf``); it may be ``None`` in early task-instance construction, and a hook
+ that only declares ``task_instance`` keeps working unchanged. Note that ``dag_run.conf`` is only populated
+ for manually triggered or API-triggered runs; scheduled runs carry an empty ``conf``.
+
+.. warning::
+
+ ``task_instance_mutation_hook`` runs inside a scheduler transaction that prohibits committing the
+ session. Do not open a new database session or commit from within the hook -- in particular, do not
+ call ``task_instance.get_dagrun()`` without passing the active session, as the resulting commit
+ crashes the scheduler. Use the ``dag_run`` argument to read run configuration instead.
The Dag and Task cluster policies can raise the :class:`~airflow.exceptions.AirflowClusterPolicyViolation`
exception to indicate that the Dag/task they were passed is not compliant and should not be loaded.
diff --git a/airflow-core/src/airflow/models/dagrun.py b/airflow-core/src/airflow/models/dagrun.py
index 364b6f0953b..09cb8b1206f 100644
--- a/airflow-core/src/airflow/models/dagrun.py
+++ b/airflow-core/src/airflow/models/dagrun.py
@@ -1722,7 +1722,7 @@ class DagRun(Base, LoggingMixin):
# check for removed or restored tasks
task_ids = set()
for ti in tis:
- ti_mutation_hook(ti)
+ ti_mutation_hook(ti, dag_run=self)
task_ids.add(ti.task_id)
try:
task = dag.get_task(ti.task_id)
@@ -1841,7 +1841,7 @@ class DagRun(Base, LoggingMixin):
def create_ti(task: Operator, indexes: Iterable[int]) -> Iterator[TI]:
for map_index in indexes:
ti = TI(task, run_id=self.run_id, map_index=map_index, dag_version_id=dag_version_id)
- ti_mutation_hook(ti)
+ ti_mutation_hook(ti, dag_run=self)
if ti.operator:
created_counts[ti.operator] += 1
yield ti
@@ -1980,9 +1980,9 @@ class DagRun(Base, LoggingMixin):
continue
ti = TI(task, run_id=self.run_id, map_index=index, state=None, dag_version_id=dag_version_id)
self.log.debug("Expanding TIs upserted %s", ti)
- task_instance_mutation_hook(ti)
+ task_instance_mutation_hook(ti, dag_run=self)
ti = session.merge(ti)
- ti.refresh_from_task(task)
+ ti.refresh_from_task(task, dag_run=self)
session.flush()
yield ti
diff --git a/airflow-core/src/airflow/models/taskinstance.py b/airflow-core/src/airflow/models/taskinstance.py
index 55a838529da..a31346097e8 100644
--- a/airflow-core/src/airflow/models/taskinstance.py
+++ b/airflow-core/src/airflow/models/taskinstance.py
@@ -386,7 +386,7 @@ def clear_task_instances(
task_id = ti.task_id
if ti_dag and ti_dag.has_task(task_id):
task = ti_dag.get_task(task_id)
- ti.refresh_from_task(task)
+ ti.refresh_from_task(task, dag_run=dr)
if TYPE_CHECKING:
assert ti.task
ti.max_tries = ti.try_number + task.retries
@@ -930,12 +930,16 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
else:
self.state = None
- def refresh_from_task(self, task: Operator, pool_override: str | None = None) -> None:
+ def refresh_from_task(
+ self, task: Operator, pool_override: str | None = None, *, dag_run: DagRun | None = None
+ ) -> None:
"""
Copy common attributes from the given task.
:param task: The task object to copy from
:param pool_override: Use the pool_override instead of task's pool
+ :param dag_run: the DagRun this task instance belongs to, forwarded to the mutation hook so a
+ cluster policy can route on ``dag_run.conf``; ``None`` when no run is available yet
"""
self.task = task
self.queue = task.queue
@@ -954,7 +958,7 @@ class TaskInstance(Base, LoggingMixin, BaseWorkload):
op_name = getattr(task, "operator_name", None)
self.custom_operator_name = op_name if isinstance(op_name, str) else ""
# Re-apply cluster policy here so that task default do not overload previous data
- task_instance_mutation_hook(self)
+ task_instance_mutation_hook(self, dag_run=dag_run)
@property
def key(self) -> TaskInstanceKey:
diff --git a/airflow-core/src/airflow/models/taskmap.py b/airflow-core/src/airflow/models/taskmap.py
index 60486b8ce86..b7c394a7c13 100644
--- a/airflow-core/src/airflow/models/taskmap.py
+++ b/airflow-core/src/airflow/models/taskmap.py
@@ -200,6 +200,7 @@ class TaskMap(TaskInstanceDependencies):
)
unmapped_ti.state = TaskInstanceState.SKIPPED
else:
+ dr = unmapped_ti.dag_run
zero_index_ti_exists = exists_query(
TaskInstance.dag_id == task.dag_id,
TaskInstance.task_id == task.task_id,
@@ -214,7 +215,7 @@ class TaskMap(TaskInstanceDependencies):
task.log.debug("Updated in place to become %s", unmapped_ti)
all_expanded_tis.append(unmapped_ti)
# execute hook for task instance map index 0
- task_instance_mutation_hook(unmapped_ti)
+ task_instance_mutation_hook(unmapped_ti, dag_run=dr)
session.flush()
else:
task.log.debug("Deleting the original task instance: %s", unmapped_ti)
@@ -245,9 +246,7 @@ class TaskMap(TaskInstanceDependencies):
else:
dag_version_id = None
- if unmapped_ti:
- dr = unmapped_ti.dag_run
- else:
+ if not unmapped_ti:
from airflow.models import DagRun
dr = session.scalar(
@@ -267,10 +266,10 @@ class TaskMap(TaskInstanceDependencies):
dag_version_id=dag_version_id,
)
task.log.debug("Expanding TIs upserted %s", ti)
- task_instance_mutation_hook(ti)
+ task_instance_mutation_hook(ti, dag_run=dr)
ti = session.merge(ti)
ti.context_carrier = new_task_run_carrier(dr.context_carrier)
- ti.refresh_from_task(task) # session.merge() loses task information.
+ ti.refresh_from_task(task, dag_run=dr) # session.merge() loses task information.
all_expanded_tis.append(ti)
# Coerce the None case to 0 -- these two are almost treated identically,
diff --git a/airflow-core/src/airflow/policies.py b/airflow-core/src/airflow/policies.py
index 7a8311bd67d..e08e4791e08 100644
--- a/airflow-core/src/airflow/policies.py
+++ b/airflow-core/src/airflow/policies.py
@@ -27,6 +27,7 @@ hookimpl = pluggy.HookimplMarker("airflow.policy")
__all__: list[str] = ["hookimpl"]
if TYPE_CHECKING:
+ from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@@ -67,13 +68,27 @@ def dag_policy(dag) -> None:
@local_settings_hookspec
-def task_instance_mutation_hook(task_instance: TaskInstance) -> None:
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None) -> None:
"""
Allow altering task instances before being queued by the Airflow scheduler.
- This could be used, for instance, to modify the task instance during retries.
+ This could be used, for instance, to modify the task instance during retries, or to route a task to
+ a different queue based on the run's configuration (``dag_run.conf``).
+
+ This hook runs scheduler-side while task instances are created or reconciled, inside a transaction
+ that prohibits committing the session. Implementations must therefore not open a new database
+ session or commit -- in particular, do not call ``task_instance.get_dagrun()`` without passing the
+ active session, as the resulting commit raises ``RuntimeError("UNEXPECTED COMMIT ...")`` and crashes
+ the scheduler. Use the ``dag_run`` argument instead.
+
+ ``dag_run`` is provided wherever it is available without extra database access. It may be ``None`` in
+ early task-instance construction (for example when the hook is re-applied from
+ ``TaskInstance.refresh_from_task`` before the instance is bound to a run); implementations must handle
+ that case. Note that ``dag_run.conf`` is only populated for manually triggered or API-triggered runs;
+ scheduled runs carry an empty ``conf``.
:param task_instance: task instance to be mutated
+ :param dag_run: the DagRun the task instance belongs to, or ``None`` when not yet available
"""
@@ -157,21 +172,25 @@ def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set
hook_methods = set()
def _make_shim_fn(name, desired_sig, target):
- # Functions defined in airflow_local_settings are called by positional parameters, so the names don't
- # have to match what we define in the "template" policy.
+ # Functions defined in airflow_local_settings are called by positional parameters, so the names
+ # don't have to match what we define in the "template" policy.
#
# However Pluggy validates the names match (and will raise an error if they don't!)
#
- # To maintain compat, if we detect the names don't match, we will wrap it with a dynamically created
- # shim function that looks somewhat like this:
+ # To maintain compat, if we detect a name that isn't in the hookspec, we wrap the function with a
+ # dynamically created shim that looks somewhat like this:
#
# def dag_policy_name_mismatch_shim(dag):
# airflow_local_settings.dag_policy(dag)
#
+ # The target is called positionally, so we forward only as many of the hookspec's parameters as
+ # the target declares (in hookspec order). This lets a misnamed function still receive the leading
+ # hookspec arguments without breaking when the hookspec grows additional trailing parameters.
+ forwarded = list(desired_sig.parameters)[: len(inspect.signature(target).parameters)]
codestr = textwrap.dedent(
f"""
def {name}_name_mismatch_shim{desired_sig}:
- return __target({" ,".join(desired_sig.parameters)})
+ return __target({", ".join(forwarded)})
"""
)
code = compile(codestr, "<policy-shim>", "single")
@@ -199,8 +218,25 @@ def make_plugin_from_local_settings(pm: pluggy.PluginManager, module, names: set
local_sig = inspect.signature(policy)
policy_sig = inspect.signature(globals()[name])
- # We only care if names/order/number of parameters match, not type hints
- if local_sig.parameters.keys() != policy_sig.parameters.keys():
+ # Decide whether the local settings function can be registered as-is or needs a shim. Pluggy
+ # passes a hookimpl only the parameters it declares *without a default* (its `argnames`), by name.
+ # Parameters with defaults go into `kwargnames` and are never routed, so an impl that writes
+ # ``def hook(task_instance, dag_run=None)`` would silently never receive ``dag_run``. We shim when:
+ #
+ # * the function declares a name the hookspec does not have (a genuine mismatch, e.g. a renamed
+ # argument), or
+ # * the function gives a hookspec parameter a default value -- pluggy puts it in kwargnames and
+ # the function would always see its own default rather than the real value.
+ #
+ # A function declaring a (default-free) subset of the hookspec parameters is registered as-is, so a
+ # single-parameter hook keeps working unchanged after the hookspec gains new parameters. The shim
+ # forwards positionally, so it transparently handles renames, defaulted parameters, and signatures
+ # that declare fewer parameters than the hookspec.
+ needs_shim = not (local_sig.parameters.keys() <= policy_sig.parameters.keys()) or any(
+ param.default is not inspect.Parameter.empty and param_name in policy_sig.parameters
+ for param_name, param in local_sig.parameters.items()
+ )
+ if needs_shim:
policy = _make_shim_fn(name, policy_sig, target=policy)
setattr(AirflowLocalSettingsPolicy, name, staticmethod(hookimpl(policy, specname=name)))
diff --git a/airflow-core/src/airflow/settings.py b/airflow-core/src/airflow/settings.py
index 739e2437f6a..56fc3d07ced 100644
--- a/airflow-core/src/airflow/settings.py
+++ b/airflow-core/src/airflow/settings.py
@@ -73,6 +73,8 @@ if TYPE_CHECKING:
from sqlalchemy.engine import Engine
from airflow.api_fastapi.common.types import UIAlert
+ from airflow.models.dagrun import DagRun
+ from airflow.models.taskinstance import TaskInstance
log = logging.getLogger(__name__)
@@ -204,8 +206,10 @@ def dag_policy(dag):
return get_policy_plugin_manager().hook.dag_policy(dag=dag)
-def task_instance_mutation_hook(task_instance):
- return get_policy_plugin_manager().hook.task_instance_mutation_hook(task_instance=task_instance)
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None = None):
+ return get_policy_plugin_manager().hook.task_instance_mutation_hook(
+ task_instance=task_instance, dag_run=dag_run
+ )
task_instance_mutation_hook.is_noop = True # type: ignore
diff --git a/airflow-core/tests/unit/cluster_policies/__init__.py b/airflow-core/tests/unit/cluster_policies/__init__.py
index e9c380838e9..485f65ed406 100644
--- a/airflow-core/tests/unit/cluster_policies/__init__.py
+++ b/airflow-core/tests/unit/cluster_policies/__init__.py
@@ -28,6 +28,7 @@ from airflow.sdk import BaseOperator
if TYPE_CHECKING:
from airflow.models.dag import DAG
+ from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
@@ -109,9 +110,14 @@ def task_policy(task: TimedOperator):
# [START example_task_mutation_hook]
-def task_instance_mutation_hook(task_instance: TaskInstance):
+def task_instance_mutation_hook(task_instance: TaskInstance, dag_run: DagRun | None = None):
+ # Route retries to a dedicated queue.
if task_instance.try_number >= 1:
task_instance.queue = "retry_queue"
+ # Route on the run configuration. ``dag_run`` may be None during early task-instance construction,
+ # so guard for it; read ``dag_run.conf`` directly rather than opening a new database session.
+ if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
+ task_instance.queue = "high_priority_queue"
# [END example_task_mutation_hook]
diff --git a/airflow-core/tests/unit/core/test_policies.py b/airflow-core/tests/unit/core/test_policies.py
index c7caff80b2c..47671e95a96 100644
--- a/airflow-core/tests/unit/core/test_policies.py
+++ b/airflow-core/tests/unit/core/test_policies.py
@@ -69,3 +69,80 @@ def test_local_settings_misnamed_argument(plugin_manager: pluggy.PluginManager):
plugin_manager.hook.dag_policy(dag="passed_dag_value")
assert called_with == "passed_dag_value"
+
+
+def test_local_settings_subset_of_parameters(plugin_manager: pluggy.PluginManager):
+ """
+ A local_settings function may declare a subset of a hookspec's parameters.
+
+ ``task_instance_mutation_hook`` accepts ``(task_instance, dag_run)``; a legacy hook that only declares
+ ``task_instance`` is registered as-is (no shim) and pluggy passes it just the argument it declares, so
+ single-parameter hooks keep working after the hookspec gained the optional ``dag_run`` parameter.
+ """
+ called_with = None
+
+ def local_hook(task_instance):
+ nonlocal called_with
+ called_with = task_instance
+
+ mod = Namespace(task_instance_mutation_hook=local_hook)
+
+ policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+ plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+ assert called_with == "ti"
+
+
+def test_local_settings_receives_all_declared_parameters(plugin_manager: pluggy.PluginManager):
+ """A hook declaring ``(task_instance, dag_run)`` with no defaults receives both arguments."""
+ received = {}
+
+ def task_instance_mutation_hook(task_instance, dag_run):
+ received["task_instance"] = task_instance
+ received["dag_run"] = dag_run
+
+ mod = Namespace(task_instance_mutation_hook=task_instance_mutation_hook)
+
+ policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+ plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+ assert received == {"task_instance": "ti", "dag_run": "dr"}
+
+
+def test_local_settings_defaulted_parameter_is_still_forwarded(plugin_manager: pluggy.PluginManager):
+ """A hook that gives a hookspec parameter a default still receives the passed value.
+
+ Pluggy forwards a hookimpl only the parameters it declares *without* a default, so it would silently
+ drop ``dag_run`` from ``def hook(task_instance, dag_run=None)`` and leave the default in place. The
+ shim detects the defaulted hookspec parameter and forwards the call positionally so the ergonomic
+ ``dag_run=None`` signature receives the real value.
+ """
+ received = {}
+
+ def task_instance_mutation_hook(task_instance, dag_run=None):
+ received["task_instance"] = task_instance
+ received["dag_run"] = dag_run
+
+ mod = Namespace(task_instance_mutation_hook=task_instance_mutation_hook)
+
+ policies.make_plugin_from_local_settings(plugin_manager, mod, {"task_instance_mutation_hook"})
+
+ plugin_manager.hook.task_instance_mutation_hook(task_instance="ti", dag_run="dr")
+
+ assert received == {"task_instance": "ti", "dag_run": "dr"}
+
+
+def test_local_settings_unknown_argument_still_raises(plugin_manager: pluggy.PluginManager):
+ """A local_settings function declaring a name the hookspec does not have is still rejected loudly."""
+
+ def dag_policy(not_a_real_parameter, dag): ...
+
+ mod = Namespace(dag_policy=dag_policy)
+
+ # ``not_a_real_parameter`` is not a subset of the hookspec params, so it is shimmed and forwarded
+ # positionally -- pluggy then rejects the extra positional argument rather than silently ignoring it.
+ policies.make_plugin_from_local_settings(plugin_manager, mod, {"dag_policy"})
+ with pytest.raises(TypeError):
+ plugin_manager.hook.dag_policy(dag="passed_dag_value")
diff --git a/airflow-core/tests/unit/models/test_dagrun.py b/airflow-core/tests/unit/models/test_dagrun.py
index c33f3b007bf..68442bbfbc0 100644
--- a/airflow-core/tests/unit/models/test_dagrun.py
+++ b/airflow-core/tests/unit/models/test_dagrun.py
@@ -854,7 +854,7 @@ class TestDagRun:
@pytest.mark.parametrize("state", State.task_states)
@mock.patch.object(settings, "task_instance_mutation_hook", autospec=True)
def test_task_instance_mutation_hook(self, mock_hook, dag_maker, session, state):
- def mutate_task_instance(task_instance):
+ def mutate_task_instance(task_instance, dag_run=None):
if task_instance.queue == "queue1":
task_instance.queue = "queue2"
else:
@@ -887,7 +887,7 @@ class TestDagRun:
"""
observed_run_ids = []
- def mutate_task_instance(task_instance):
+ def mutate_task_instance(task_instance, dag_run=None):
observed_run_ids.append(task_instance.run_id)
if task_instance.run_id and task_instance.run_id.startswith("manual__"):
task_instance.pool = "manual_pool"
@@ -907,6 +907,43 @@ class TestDagRun:
f"task_instance_mutation_hook was called with run_id=None. Observed run_ids: {observed_run_ids}"
)
+ def test_task_instance_mutation_hook_receives_dag_run(self, dag_maker, session, monkeypatch):
+ """task_instance_mutation_hook can route on dag_run.conf for mapped task instances.
+
+ Exercises the supported dag_run argument end-to-end: a conf-routing hook sets queue based on
+ DagRun.conf, and the routed value persists on every expanded mapped task instance. This is the
+ first-class replacement for reaching the DagRun through unsupported internals.
+ """
+ observed_confs = []
+
+ def mutate_task_instance(task_instance, dag_run=None):
+ observed_confs.append(None if dag_run is None else dag_run.conf)
+ if dag_run is not None and (dag_run.conf or {}).get("route") == "high":
+ task_instance.queue = "high_queue"
+
+ with mock.patch.object(
+ get_policy_plugin_manager().hook, "task_instance_mutation_hook", autospec=True
+ ) as mock_hook:
+ mock_hook.side_effect = mutate_task_instance
+ # Force the non-noop task-creation path so the scheduler invokes the hook with dag_run while
+ # materializing the mapped instances (mocking the hook does not flip the is_noop flag that
+ # import_local_settings would set for a real registered policy).
+ monkeypatch.setattr(settings.task_instance_mutation_hook, "is_noop", False)
+ with dag_maker(dag_id="test_mutation_hook_dag_run", session=session):
+ MockOperator.partial(task_id="mapped").expand(arg2=[1, 2, 3])
+
+ dr = dag_maker.create_dagrun(conf={"route": "high"})
+
+ # The hook saw the run conf (at least once with the routing value), and every expanded mapped
+ # task instance was routed to the conf-selected queue.
+ assert {"route": "high"} in observed_confs
+ queues = session.scalars(
+ select(TI.queue)
+ .where(TI.task_id == "mapped", TI.dag_id == dr.dag_id, TI.run_id == dr.run_id)
+ .order_by(TI.map_index)
+ ).all()
+ assert queues == ["high_queue", "high_queue", "high_queue"]
+
@pytest.mark.parametrize(
("prev_ti_state", "is_ti_schedulable"),
[
@@ -1652,7 +1689,7 @@ def test_mutation_hook_committing_session_crashes_under_prohibit_commit(dag_make
before_commit guard with RuntimeError("UNEXPECTED COMMIT - THIS WILL BREAK HA LOCKS!").
"""
- def naive_hook(task_instance):
+ def naive_hook(task_instance, dag_run=None):
# Reads DagRun.conf the unsafe way -- opens a fresh @provide_session
session that commits on exit.
task_instance.get_dagrun()
@@ -1683,7 +1720,7 @@ def test_mutation_hook_safe_session_reuse_routes_mapped_tis_under_prohibit_commi
Resolving via the attached session is the discipline a real conf-routing
hook must follow.
"""
- def safe_hook(task_instance):
+ def safe_hook(task_instance, dag_run=None):
attached_session = sa_inspect(task_instance).session
if attached_session is None:
# Transient instance (pre-merge); it will be re-invoked once attached. Nothing safe to do.
@@ -1722,7 +1759,7 @@ def test_mutation_hook_deterministic_across_repeated_invocation_during_expansion
"""
call_counts: dict[int, int] = defaultdict(int)
- def deterministic_hook(task_instance):
+ def deterministic_hook(task_instance, dag_run=None):
call_counts[task_instance.map_index] += 1
task_instance.queue = f"q_{task_instance.map_index}"
@@ -1806,7 +1843,7 @@ def test_naive_committing_hook_crashes_on_verify_integrity_under_guard(dag_maker
session; create_session() commits on exit and trips the before_commit guard.
"""
- def naive_hook(task_instance):
+ def naive_hook(task_instance, dag_run=None):
task_instance.get_dagrun()
dr, dag_version_id = _make_literal_mapped_dagrun(
@@ -1835,7 +1872,7 @@ def test_resolve_dagrun_attribute_access_is_safe_on_verify_integrity_under_guard
"""
seen_map_indices = []
- def resolve_dagrun_like_hook(task_instance):
+ def resolve_dagrun_like_hook(task_instance, dag_run=None):
state = sa_inspect(task_instance)
if "dag_run" not in state.unloaded:
_ = task_instance.dag_run # eager-loaded: cheap attribute read
diff --git a/airflow-core/tests/unit/models/test_taskinstance.py b/airflow-core/tests/unit/models/test_taskinstance.py
index e840b01b859..f1a46cd52d9 100644
--- a/airflow-core/tests/unit/models/test_taskinstance.py
+++ b/airflow-core/tests/unit/models/test_taskinstance.py
@@ -2762,7 +2762,7 @@ def test_refresh_from_task(pool_override, queue_by_policy, monkeypatch):
expected_queue = queue_by_policy or default_queue
if queue_by_policy:
# Apply a dummy cluster policy to check if it is always applied
- def mock_policy(task_instance: TaskInstance):
+ def mock_policy(task_instance: TaskInstance, dag_run=None):
task_instance.queue = queue_by_policy
monkeypatch.setattr("airflow.models.taskinstance.task_instance_mutation_hook", mock_policy)
diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py
index a36cb50a206..0958dbd663a 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -1193,7 +1193,9 @@ def dag_maker(request) -> Generator[DagMaker, None, None]:
self.dag_run = dag.create_dagrun(**kwargs)
for ti in self.dag_run.task_instances:
- if AIRFLOW_V_3_0_PLUS:
+ if AIRFLOW_V_3_3_PLUS:
+ ti.refresh_from_task(dag.get_task(ti.task_id), dag_run=self.dag_run)
+ elif AIRFLOW_V_3_0_PLUS:
ti.refresh_from_task(dag.get_task(ti.task_id))
else:
ti.refresh_from_task(self.dag.get_task(ti.task_id))
diff --git a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
index 24f7520f1b5..b6d010e3be1 100755
--- a/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
+++ b/scripts/ci/prek/check_partition_mapper_defaults_in_sync.py
@@ -76,6 +76,20 @@ SDK_FIXED_KEY_FILE = (
)
+def _attr_value(stmt: ast.stmt) -> ast.expr | None:
+ """Return the value node of a statement that assigns ``ATTR_NAME``, or None."""
+ if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name) and stmt.target.id == ATTR_NAME:
+ return stmt.value
+ if (
+ isinstance(stmt, ast.Assign)
+ and len(stmt.targets) == 1
+ and isinstance(stmt.targets[0], ast.Name)
+ and stmt.targets[0].id == ATTR_NAME
+ ):
+ return stmt.value
+ return None
+
+
def _find_attr_value(file_path: Path) -> ast.Dict:
"""Return the AST node assigned to ``FanOutMapper.default_downstream_mapper_by_window_name``."""
tree = ast.parse(file_path.read_text(encoding="utf-8"), filename=str(file_path))
@@ -83,23 +97,15 @@ def _find_attr_value(file_path: Path) -> ast.Dict:
if not (isinstance(node, ast.ClassDef) and node.name == CLASS_NAME):
continue
for stmt in node.body:
- target_name = None
- if isinstance(stmt, ast.AnnAssign) and isinstance(stmt.target, ast.Name):
- target_name = stmt.target.id
- elif (
- isinstance(stmt, ast.Assign)
- and len(stmt.targets) == 1
- and isinstance(stmt.targets[0], ast.Name)
- ):
- target_name = stmt.targets[0].id
- if target_name != ATTR_NAME:
+ value = _attr_value(stmt)
+ if value is None:
continue
- if not isinstance(stmt.value, ast.Dict):
+ if not isinstance(value, ast.Dict):
raise ValueError(
f"{file_path}: {CLASS_NAME}.{ATTR_NAME} is not a dict literal; "
f"this check parses a dict literal and must be updated."
)
- return stmt.value
+ return value
raise ValueError(f"{file_path}: {CLASS_NAME} has no {ATTR_NAME} attribute.")
raise ValueError(f"{file_path}: no class {CLASS_NAME} found.")