ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809950052
##########
File path: airflow/models/mappedoperator.py
##########
@@ -103,6 +134,24 @@ def prevent_duplicates(kwargs1: Dict[str, Any], kwargs2:
Dict[str, Any], *, fail
raise TypeError(f"{fail_reason} arguments: {duplicated_keys_display}")
+def create_mocked_kwargs(kwargs: Dict[str, "MapArgument"]) -> Dict[str,
unittest.mock.MagicMock]:
+ """Create a mapping of mocks for given map arguments.
+
+ When a mapped operator is created, we want to perform basic validation on
+ the map arguments, especially the count of arguments. However, most of this
+ kind of logic lives directly on an operator class's ``__init__``, and
+ there's no good way to validate the arguments except to actually try to
+ create an operator instance.
+
+ Since the map arguments are yet to be populated when the mapped operator is
+ being parsed, we need to "invent" some mocked values for this validation
+ purpose. The :class:`~unittest.mock.MagicMock` class is a good fit for this
+ since it not only provide good run-time properties, but also enjoy special
+ treatments in Mypy.
+ """
+ return {k: unittest.mock.MagicMock(name=k) for k in kwargs}
Review comment:
Future possible enhancement (absolutely not in this PR):
Set `spec` for the mock based on the type hint of the param we are
creating the mock for.
##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
"""
if isinstance(self.operator_class, str):
return # No need to validate deserialized operator.
- operator = self._create_unmapped_operator(
- mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in
self.mapped_kwargs},
- partial_kwargs=self.partial_kwargs,
- real=False,
- )
- if operator.task_group:
- operator.task_group._remove(operator)
- dag = operator.get_dag()
+ mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+ op =
self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)
Review comment:
I do wonder if we'll need an escape hatch for this.
I guess we can add that when someone finds a specific case.
##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
assert task_map.map_index == -1
assert task_map.length == expected_length
assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+ @pytest.mark.parametrize(
+ "literal, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="literal", session=session) as dag:
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=literal)
+
+ dag_run = dag_maker.create_dagrun()
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(literal)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ @pytest.mark.parametrize(
+ "upstream_return, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_xcom(self, upstream_return, expected_outputs, dag_maker,
session):
+ outputs = set()
+
+ with dag_maker(dag_id="xcom", session=session) as dag:
+
+ @dag.task
+ def emit():
+ return upstream_return
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=emit())
+
+ dag_run = dag_maker.create_dagrun()
+ emit_ti = dag_run.get_task_instance("emit", session=session)
+ emit_ti.refresh_from_task(dag.get_task("emit"))
+ emit_ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(upstream_return)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ def test_map_product(self, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="product", session=session) as dag:
+
+ @dag.task
+ def emit_numbers():
+ return [1, 2]
+
+ @dag.task
+ def emit_letters():
+ return {"a": "x", "b": "y", "c": "z"}
+
+ @dag.task
+ def show(number, letter):
+ outputs.add((number, letter))
+
+ show.map(number=emit_numbers(), letter=emit_letters())
+
+ dag_run = dag_maker.create_dagrun()
+ for task_id in ["emit_numbers", "emit_letters"]:
+ ti = dag_run.get_task_instance(task_id, session=session)
+ ti.refresh_from_task(dag.get_task(task_id))
+ ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == 6
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == {
+ (1, ("a", "x")),
+ (2, ("a", "x")),
+ (1, ("b", "y")),
+ (2, ("b", "y")),
+ (1, ("c", "z")),
+ (2, ("c", "z")),
+ }
Review comment:
I don't think it matters, but I would have expected `emit_numbers` to be
the outer loop since it's the "first" argument:
```python
assert outputs == {
(1, ("a", "x")),
(1, ("c", "z")),
(1, ("b", "y")),
(2, ("a", "x")),
(2, ("b", "y")),
(2, ("c", "z")),
}
```
i.e. it's the equivalent of:
```python
for number in emit_numbers():
for letter in emit_letters():
show(number, letter)
```
##########
File path: airflow/models/mappedoperator.py
##########
@@ -245,16 +296,13 @@ def _validate_argument_count(self) -> None:
"""
if isinstance(self.operator_class, str):
return # No need to validate deserialized operator.
- operator = self._create_unmapped_operator(
- mapped_kwargs={k: unittest.mock.MagicMock(name=k) for k in
self.mapped_kwargs},
- partial_kwargs=self.partial_kwargs,
- real=False,
- )
- if operator.task_group:
- operator.task_group._remove(operator)
- dag = operator.get_dag()
+ mocked_mapped_kwargs = create_mocked_kwargs(self.mapped_kwargs)
+ op =
self._create_unmapped_operator(mapped_kwargs=mocked_mapped_kwargs, real=False)
+ if op.task_group:
+ op.task_group._remove(op)
+ dag = op.get_dag()
if dag:
- dag._remove_task(operator.task_id)
+ dag._remove_task(op.task_id)
Review comment:
```python
dag = op.get_dag()
if dag:
dag._remove_task(op.task_id)
```
I think that is all we need -- `dag._remove_task` already does the
task group check/remove.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]