This is an automated email from the ASF dual-hosted git repository.
rom pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v2-10-test by this push:
new 6ac51ca7c95 Added task_instance_mutation_hook for mapped operator
index 0 (#42661) (#43089)
6ac51ca7c95 is described below
commit 6ac51ca7c956d11dc701011f2cd72345c5dc2991
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Oct 16 20:44:26 2024 +0200
Added task_instance_mutation_hook for mapped operator index 0 (#42661)
(#43089)
* Added task_instance_mutation_hook for mapped operator index 0
* Added unit test
---------
Co-authored-by: Marco Küttelwesch <[email protected]>
(cherry picked from commit b7007e2b146e6ef929a211925a3d4397b1e9955d)
Co-authored-by: AutomationDev85
<[email protected]>
---
airflow/models/abstractoperator.py | 2 ++
tests/models/test_mappedoperator.py | 25 +++++++++++++++++++++++++
2 files changed, 27 insertions(+)
diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 5e5d13d5dc2..45eb3c5fff1 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -654,6 +654,8 @@ class AbstractOperator(Templater, DAGNode):
unmapped_ti.map_index = 0
self.log.debug("Updated in place to become %s",
unmapped_ti)
all_expanded_tis.append(unmapped_ti)
+ # execute hook for task instance map index 0
+ task_instance_mutation_hook(unmapped_ti)
session.flush()
else:
self.log.debug("Deleting the original task instance: %s",
unmapped_ti)
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 2b0cd50165c..3b7eff19036 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
from collections import defaultdict
from datetime import timedelta
from typing import TYPE_CHECKING
+from unittest import mock
from unittest.mock import patch
import pendulum
@@ -716,6 +717,30 @@ def test_expand_mapped_task_instance_with_named_index(
assert indices == expected_rendered_names
+@pytest.mark.parametrize(
+    "create_mapped_task",  # NOTE(review): not referenced by the test body; both ids run the same scenario
+    [
+        pytest.param(_create_mapped_with_name_template_classic, id="classic"),
+        pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
+    ],
+)
+def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None:
+    """Test that task_instance_mutation_hook is called for every expanded map index, including 0."""
+    expected_map_index = [0, 1, 2]
+
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output)
+
+    dr = dag_maker.create_dagrun()
+
+    with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_hook:
+        expand_mapped_task(mapped, dr.run_id, task1.task_id, length=len(expected_map_index), session=session)
+
+    # Compare the full list so a missing or extra hook call fails the test instead of passing vacuously.
+    called_map_indexes = [call.args[0].map_index for call in mock_hook.call_args_list]
+    assert called_map_indexes == expected_map_index
+
@pytest.mark.skip_if_database_isolation_mode # Does not work in db isolation
mode
@pytest.mark.parametrize(
"map_index, expected",