This is an automated email from the ASF dual-hosted git repository.

husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d1e6a5c48d Fix rendering the mapped parameters when using `expand_kwargs` method (#32272)
d1e6a5c48d is described below

commit d1e6a5c48d03322dda090113134f745d1f9c34d4
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Aug 18 21:17:07 2023 +0200

    Fix rendering the mapped parameters when using `expand_kwargs` method (#32272)
    
    * Fix rendering the mapped parameters in the mapped operator
    
    Signed-off-by: Hussein Awala <[email protected]>
    
    * add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not
    
    * fix dag serialization tests
    
    * Revert "fix dag serialization tests"
    
    This reverts commit 191351cda7b51bc6d49e7fcee5ab8ccd6cd219f5.
    
    * Revert "add template_in_template arg to expand method to tell Airflow whether to resolve the xcom data or not"
    
    This reverts commit 14bd392c7e5c6f25ce25d61dfd440e9b27c1bc2e.
    
    * Fix ListOfDictsExpandInput resolve method
    
    * remove _iter_parse_time_resolved_kwargs method
    
    * remove unnecessary step
    
    ---------
    
    Signed-off-by: Hussein Awala <[email protected]>
---
 airflow/models/expandinput.py       |  5 ++-
 tests/models/test_mappedoperator.py | 89 ++++++++++++++++++++++++++++++++-----
 2 files changed, 81 insertions(+), 13 deletions(-)

diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 36fb5f4165..a9128568d3 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -265,7 +265,10 @@ class ListOfDictsExpandInput(NamedTuple):
                     f"expand_kwargs() input dict keys must all be str, "
                     f"but {key!r} is of type {_describe_type(key)}"
                 )
-        return mapping, {id(v) for v in mapping.values()}
+        # filter out parse time resolved values from the resolved_oids
+        resolved_oids = {id(v) for k, v in mapping.items() if not 
_is_parse_time_mappable(v)}
+
+        return mapping, resolved_oids
 
 
 EXPAND_INPUT_EMPTY = DictOfListsExpandInput({})  # Sentinel value.
diff --git a/tests/models/test_mappedoperator.py 
b/tests/models/test_mappedoperator.py
index a7f6d0660c..6d4a2fbca5 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 import logging
 from collections import defaultdict
 from datetime import timedelta
+from unittest import mock
 from unittest.mock import patch
 
 import pendulum
@@ -399,17 +400,31 @@ def test_mapped_expand_against_params(dag_maker, 
dag_params, task_params, expect
 
 
 def test_mapped_render_template_fields_validating_operator(dag_maker, session):
-    class MyOperator(MockOperator):
-        def __init__(self, value, arg1, **kwargs):
-            assert isinstance(value, str), "value should have been resolved 
before unmapping"
-            assert isinstance(arg1, str), "value should have been resolved 
before unmapping"
-            super().__init__(arg1=arg1, **kwargs)
-            self.value = value
+    class MyOperator(BaseOperator):
+        template_fields = ("partial_template", "map_template", "file_template")
+        template_ext = (".ext",)
+
+        def __init__(
+            self, partial_template, partial_static, map_template, map_static, 
file_template, **kwargs
+        ):
+            for value in [partial_template, partial_static, map_template, 
map_static, file_template]:
+                assert isinstance(value, str), "value should have been 
resolved before unmapping"
+            super().__init__(**kwargs)
+            self.partial_template = partial_template
+            self.partial_static = partial_static
+            self.map_template = map_template
+            self.map_static = map_static
+            self.file_template = file_template
+
+        def execute(self, context):
+            pass
 
     with dag_maker(session=session):
         task1 = BaseOperator(task_id="op1")
         output1 = task1.output
-        mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id 
}}").expand(value=output1, arg1=output1)
+        mapped = MyOperator.partial(
+            task_id="a", partial_template="{{ ti.task_id }}", 
partial_static="{{ ti.task_id }}"
+        ).expand(map_template=output1, map_static=output1, 
file_template=["/path/to/file.ext"])
 
     dr = dag_maker.create_dagrun()
     ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
@@ -432,12 +447,62 @@ def 
test_mapped_render_template_fields_validating_operator(dag_maker, session):
     mapped_ti.map_index = 0
 
     assert isinstance(mapped_ti.task, MappedOperator)
-    
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), 
patch(
+        "os.path.isfile", return_value=True
+    ), patch("os.path.getmtime", return_value=0):
+        
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+    assert isinstance(mapped_ti.task, MyOperator)
+
+    assert mapped_ti.task.partial_template == "a", "Should be templated!"
+    assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be 
templated!"
+    assert mapped_ti.task.map_template == "{{ ds }}", "Should not be 
templated!"
+    assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
+    assert mapped_ti.task.file_template == "loaded data", "Should be 
templated!"
+
+
+def 
test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker, 
session):
+    class MyOperator(BaseOperator):
+        template_fields = ("partial_template", "map_template", "file_template")
+        template_ext = (".ext",)
+
+        def __init__(
+            self, partial_template, partial_static, map_template, map_static, 
file_template, **kwargs
+        ):
+            for value in [partial_template, partial_static, map_template, 
map_static, file_template]:
+                assert isinstance(value, str), "value should have been 
resolved before unmapping"
+            super().__init__(**kwargs)
+            self.partial_template = partial_template
+            self.partial_static = partial_static
+            self.map_template = map_template
+            self.map_static = map_static
+            self.file_template = file_template
+
+        def execute(self, context):
+            pass
+
+    with dag_maker(session=session):
+        mapped = MyOperator.partial(
+            task_id="a", partial_template="{{ ti.task_id }}", 
partial_static="{{ ti.task_id }}"
+        ).expand_kwargs(
+            [{"map_template": "{{ ds }}", "map_static": "{{ ds }}", 
"file_template": "/path/to/file.ext"}]
+        )
+
+    dr = dag_maker.create_dagrun()
+
+    mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, 
session=session, map_index=0)
+
+    assert isinstance(mapped_ti.task, MappedOperator)
+    with patch("builtins.open", mock.mock_open(read_data=b"loaded data")), 
patch(
+        "os.path.isfile", return_value=True
+    ), patch("os.path.getmtime", return_value=0):
+        
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
     assert isinstance(mapped_ti.task, MyOperator)
 
-    assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
-    assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
-    assert mapped_ti.task.arg2 == "a"
+    assert mapped_ti.task.partial_template == "a", "Should be templated!"
+    assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be 
templated!"
+    assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
+    assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
+    assert mapped_ti.task.file_template == "loaded data", "Should be 
templated!"
 
 
 def test_mapped_render_nested_template_fields(dag_maker, session):
@@ -534,7 +599,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, 
session, num_existing_tis
 @pytest.mark.parametrize(
     "map_index, expected",
     [
-        pytest.param(0, "{{ ds }}", id="0"),
+        pytest.param(0, "2016-01-01", id="0"),
         pytest.param(1, 2, id="1"),
     ],
 )

Reply via email to