This is an automated email from the ASF dual-hosted git repository.
husseinawala pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d1e6a5c48d Fix rendering the mapped parameters when using
`expand_kwargs` method (#32272)
d1e6a5c48d is described below
commit d1e6a5c48d03322dda090113134f745d1f9c34d4
Author: Hussein Awala <[email protected]>
AuthorDate: Fri Aug 18 21:17:07 2023 +0200
Fix rendering the mapped parameters when using `expand_kwargs` method
(#32272)
* Fix rendering the mapped parameters in the mapped operator
Signed-off-by: Hussein Awala <[email protected]>
* add template_in_template arg to expand method to tell Airflow whether to
resolve the xcom data or not
* fix dag serialization tests
* Revert "fix dag serialization tests"
This reverts commit 191351cda7b51bc6d49e7fcee5ab8ccd6cd219f5.
* Revert "add template_in_template arg to expand method to tell Airflow
whether to resolve the xcom data or not"
This reverts commit 14bd392c7e5c6f25ce25d61dfd440e9b27c1bc2e.
* Fix ListOfDictsExpandInput resolve method
* remove _iter_parse_time_resolved_kwargs method
* remove unnecessary step
---------
Signed-off-by: Hussein Awala <[email protected]>
---
airflow/models/expandinput.py | 5 ++-
tests/models/test_mappedoperator.py | 89 ++++++++++++++++++++++++++++++++-----
2 files changed, 81 insertions(+), 13 deletions(-)
diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py
index 36fb5f4165..a9128568d3 100644
--- a/airflow/models/expandinput.py
+++ b/airflow/models/expandinput.py
@@ -265,7 +265,10 @@ class ListOfDictsExpandInput(NamedTuple):
f"expand_kwargs() input dict keys must all be str, "
f"but {key!r} is of type {_describe_type(key)}"
)
- return mapping, {id(v) for v in mapping.values()}
+ # filter out parse time resolved values from the resolved_oids
+ resolved_oids = {id(v) for k, v in mapping.items() if not
_is_parse_time_mappable(v)}
+
+ return mapping, resolved_oids
EXPAND_INPUT_EMPTY = DictOfListsExpandInput({}) # Sentinel value.
diff --git a/tests/models/test_mappedoperator.py
b/tests/models/test_mappedoperator.py
index a7f6d0660c..6d4a2fbca5 100644
--- a/tests/models/test_mappedoperator.py
+++ b/tests/models/test_mappedoperator.py
@@ -20,6 +20,7 @@ from __future__ import annotations
import logging
from collections import defaultdict
from datetime import timedelta
+from unittest import mock
from unittest.mock import patch
import pendulum
@@ -399,17 +400,31 @@ def test_mapped_expand_against_params(dag_maker,
dag_params, task_params, expect
def test_mapped_render_template_fields_validating_operator(dag_maker, session):
- class MyOperator(MockOperator):
- def __init__(self, value, arg1, **kwargs):
- assert isinstance(value, str), "value should have been resolved
before unmapping"
- assert isinstance(arg1, str), "value should have been resolved
before unmapping"
- super().__init__(arg1=arg1, **kwargs)
- self.value = value
+ class MyOperator(BaseOperator):
+ template_fields = ("partial_template", "map_template", "file_template")
+ template_ext = (".ext",)
+
+ def __init__(
+ self, partial_template, partial_static, map_template, map_static,
file_template, **kwargs
+ ):
+ for value in [partial_template, partial_static, map_template,
map_static, file_template]:
+ assert isinstance(value, str), "value should have been
resolved before unmapping"
+ super().__init__(**kwargs)
+ self.partial_template = partial_template
+ self.partial_static = partial_static
+ self.map_template = map_template
+ self.map_static = map_static
+ self.file_template = file_template
+
+ def execute(self, context):
+ pass
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
output1 = task1.output
- mapped = MyOperator.partial(task_id="a", arg2="{{ ti.task_id
}}").expand(value=output1, arg1=output1)
+ mapped = MyOperator.partial(
+ task_id="a", partial_template="{{ ti.task_id }}",
partial_static="{{ ti.task_id }}"
+ ).expand(map_template=output1, map_static=output1,
file_template=["/path/to/file.ext"])
dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
@@ -432,12 +447,62 @@ def
test_mapped_render_template_fields_validating_operator(dag_maker, session):
mapped_ti.map_index = 0
assert isinstance(mapped_ti.task, MappedOperator)
-
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ with patch("builtins.open", mock.mock_open(read_data=b"loaded data")),
patch(
+ "os.path.isfile", return_value=True
+ ), patch("os.path.getmtime", return_value=0):
+
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
+ assert isinstance(mapped_ti.task, MyOperator)
+
+ assert mapped_ti.task.partial_template == "a", "Should be templated!"
+ assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be
templated!"
+ assert mapped_ti.task.map_template == "{{ ds }}", "Should not be
templated!"
+ assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
+ assert mapped_ti.task.file_template == "loaded data", "Should be
templated!"
+
+
+def
test_mapped_expand_kwargs_render_template_fields_validating_operator(dag_maker,
session):
+ class MyOperator(BaseOperator):
+ template_fields = ("partial_template", "map_template", "file_template")
+ template_ext = (".ext",)
+
+ def __init__(
+ self, partial_template, partial_static, map_template, map_static,
file_template, **kwargs
+ ):
+ for value in [partial_template, partial_static, map_template,
map_static, file_template]:
+ assert isinstance(value, str), "value should have been
resolved before unmapping"
+ super().__init__(**kwargs)
+ self.partial_template = partial_template
+ self.partial_static = partial_static
+ self.map_template = map_template
+ self.map_static = map_static
+ self.file_template = file_template
+
+ def execute(self, context):
+ pass
+
+ with dag_maker(session=session):
+ mapped = MyOperator.partial(
+ task_id="a", partial_template="{{ ti.task_id }}",
partial_static="{{ ti.task_id }}"
+ ).expand_kwargs(
+ [{"map_template": "{{ ds }}", "map_static": "{{ ds }}",
"file_template": "/path/to/file.ext"}]
+ )
+
+ dr = dag_maker.create_dagrun()
+
+ mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id,
session=session, map_index=0)
+
+ assert isinstance(mapped_ti.task, MappedOperator)
+ with patch("builtins.open", mock.mock_open(read_data=b"loaded data")),
patch(
+ "os.path.isfile", return_value=True
+ ), patch("os.path.getmtime", return_value=0):
+
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)
- assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
- assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
- assert mapped_ti.task.arg2 == "a"
+ assert mapped_ti.task.partial_template == "a", "Should be templated!"
+ assert mapped_ti.task.partial_static == "{{ ti.task_id }}", "Should not be
templated!"
+ assert mapped_ti.task.map_template == "2016-01-01", "Should be templated!"
+ assert mapped_ti.task.map_static == "{{ ds }}", "Should not be templated!"
+ assert mapped_ti.task.file_template == "loaded data", "Should be
templated!"
def test_mapped_render_nested_template_fields(dag_maker, session):
@@ -534,7 +599,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker,
session, num_existing_tis
@pytest.mark.parametrize(
"map_index, expected",
[
- pytest.param(0, "{{ ds }}", id="0"),
+ pytest.param(0, "2016-01-01", id="0"),
pytest.param(1, 2, id="1"),
],
)