ashb commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809955067
##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
assert task_map.map_index == -1
assert task_map.length == expected_length
assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+ @pytest.mark.parametrize(
+ "literal, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="literal", session=session) as dag:
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=literal)
+
+ dag_run = dag_maker.create_dagrun()
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(literal)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ @pytest.mark.parametrize(
+ "upstream_return, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_xcom(self, upstream_return, expected_outputs, dag_maker,
session):
+ outputs = set()
+
+ with dag_maker(dag_id="xcom", session=session) as dag:
+
+ @dag.task
+ def emit():
+ return upstream_return
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=emit())
+
+ dag_run = dag_maker.create_dagrun()
+ emit_ti = dag_run.get_task_instance("emit", session=session)
+ emit_ti.refresh_from_task(dag.get_task("emit"))
+ emit_ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(upstream_return)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ def test_map_product(self, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="product", session=session) as dag:
+
+ @dag.task
+ def emit_numbers():
+ return [1, 2]
+
+ @dag.task
+ def emit_letters():
+ return {"a": "x", "b": "y", "c": "z"}
+
+ @dag.task
+ def show(number, letter):
+ outputs.add((number, letter))
+
+ show.map(number=emit_numbers(), letter=emit_letters())
+
+ dag_run = dag_maker.create_dagrun()
+ for task_id in ["emit_numbers", "emit_letters"]:
+ ti = dag_run.get_task_instance(task_id, session=session)
+ ti.refresh_from_task(dag.get_task(task_id))
+ ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == 6
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == {
+ (1, ("a", "x")),
+ (2, ("a", "x")),
+ (1, ("b", "y")),
+ (2, ("b", "y")),
+ (1, ("c", "z")),
+ (2, ("c", "z")),
+ }
Review comment:
I don't think it matters, but I would have expected `emit_numbers` to be
the outer loop, since it's the "first" argument:
```python
assert outputs == {
(1, ("a", "x")),
(1, ("c", "z")),
(1, ("b", "y")),
(2, ("a", "x")),
(2, ("b", "y")),
(2, ("c", "z")),
}
```
i.e. `show.map(number=emit_numbers(), letter=emit_letters())` is the equivalent
of:
```python
for number in emit_numbers():
for letter in emit_letters():
show(number, letter)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]