This is an automated email from the ASF dual-hosted git repository.

rom pushed a commit to branch v2-10-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v2-10-test by this push:
     new 6ac51ca7c95 Added task_instance_mutation_hook for mapped operator index 0 (#42661) (#43089)
6ac51ca7c95 is described below

commit 6ac51ca7c956d11dc701011f2cd72345c5dc2991
Author: Jens Scheffler <[email protected]>
AuthorDate: Wed Oct 16 20:44:26 2024 +0200

    Added task_instance_mutation_hook for mapped operator index 0 (#42661) (#43089)
    
    * Added task_instance_mutation_hook for mapped operator index 0
    
    * Added unit test
    
    ---------
    
    Co-authored-by: Marco Küttelwesch <[email protected]>
    (cherry picked from commit b7007e2b146e6ef929a211925a3d4397b1e9955d)
    
    Co-authored-by: AutomationDev85 <[email protected]>
---
 airflow/models/abstractoperator.py  |  2 ++
 tests/models/test_mappedoperator.py | 25 +++++++++++++++++++++++++
 2 files changed, 27 insertions(+)

diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py
index 5e5d13d5dc2..45eb3c5fff1 100644
--- a/airflow/models/abstractoperator.py
+++ b/airflow/models/abstractoperator.py
@@ -654,6 +654,8 @@ class AbstractOperator(Templater, DAGNode):
                     unmapped_ti.map_index = 0
                     self.log.debug("Updated in place to become %s", unmapped_ti)
                     all_expanded_tis.append(unmapped_ti)
+                    # execute hook for task instance map index 0
+                    task_instance_mutation_hook(unmapped_ti)
                     session.flush()
                 else:
                     self.log.debug("Deleting the original task instance: %s", unmapped_ti)
diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py
index 2b0cd50165c..3b7eff19036 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 from collections import defaultdict
 from datetime import timedelta
 from typing import TYPE_CHECKING
+from unittest import mock
 from unittest.mock import patch
 
 import pendulum
@@ -716,6 +717,30 @@ def test_expand_mapped_task_instance_with_named_index(
     assert indices == expected_rendered_names
 
 
+@pytest.mark.parametrize(
+    "create_mapped_task",
+    [
+        pytest.param(_create_mapped_with_name_template_classic, id="classic"),
+        pytest.param(_create_mapped_with_name_template_taskflow, id="taskflow"),
+    ],
+)
+def test_expand_mapped_task_task_instance_mutation_hook(dag_maker, session, create_mapped_task) -> None:
+    """Test that the tast_instance_mutation_hook is called."""
+    expected_map_index = [0, 1, 2]
+
+    with dag_maker(session=session):
+        task1 = BaseOperator(task_id="op1")
+        mapped = MockOperator.partial(task_id="task_2").expand(arg2=task1.output)
+
+    dr = dag_maker.create_dagrun()
+
+    with mock.patch("airflow.settings.task_instance_mutation_hook") as mock_hook:
+        expand_mapped_task(mapped, dr.run_id, task1.task_id, length=len(expected_map_index), session=session)
+
+        for index, call in enumerate(mock_hook.call_args_list):
+            assert call.args[0].map_index == expected_map_index[index]
+
+
 @pytest.mark.skip_if_database_isolation_mode  # Does not work in db isolation mode
 @pytest.mark.parametrize(
     "map_index, expected",

Reply via email to