uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r810680291
##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
assert task_map.map_index == -1
assert task_map.length == expected_length
assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+ @pytest.mark.parametrize(
+ "literal, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="literal", session=session) as dag:
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=literal)
+
+ dag_run = dag_maker.create_dagrun()
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(literal)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ @pytest.mark.parametrize(
+ "upstream_return, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_xcom(self, upstream_return, expected_outputs, dag_maker,
session):
+ outputs = set()
+
+ with dag_maker(dag_id="xcom", session=session) as dag:
+
+ @dag.task
+ def emit():
+ return upstream_return
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=emit())
+
+ dag_run = dag_maker.create_dagrun()
+ emit_ti = dag_run.get_task_instance("emit", session=session)
+ emit_ti.refresh_from_task(dag.get_task("emit"))
+ emit_ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(upstream_return)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ def test_map_product(self, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="product", session=session) as dag:
+
+ @dag.task
+ def emit_numbers():
+ return [1, 2]
+
+ @dag.task
+ def emit_letters():
+ return {"a": "x", "b": "y", "c": "z"}
+
+ @dag.task
+ def show(number, letter):
+ outputs.add((number, letter))
+
+ show.map(number=emit_numbers(), letter=emit_letters())
+
+ dag_run = dag_maker.create_dagrun()
+ for task_id in ["emit_numbers", "emit_letters"]:
+ ti = dag_run.get_task_instance(task_id, session=session)
+ ti.refresh_from_task(dag.get_task(task_id))
+ ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == 6
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == {
+ (1, ("a", "x")),
+ (2, ("a", "x")),
+ (1, ("b", "y")),
+ (2, ("b", "y")),
+ (1, ("c", "z")),
+ (2, ("c", "z")),
+ }
Review comment:
We guarantee that map_index values are generated in a predictable order,
but not that the mapped task instances execute in that same order (which is
the potentially unstable part here).
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]