uranusjr commented on a change in pull request #21641:
URL: https://github.com/apache/airflow/pull/21641#discussion_r809974265
##########
File path: tests/models/test_taskinstance.py
##########
@@ -2353,3 +2353,110 @@ def pull_something(value):
assert task_map.map_index == -1
assert task_map.length == expected_length
assert task_map.keys == expected_keys
+
+
+class TestMappedTaskInstanceReceiveValue:
+ @pytest.mark.parametrize(
+ "literal, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_literal(self, literal, expected_outputs, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="literal", session=session) as dag:
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=literal)
+
+ dag_run = dag_maker.create_dagrun()
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(literal)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ @pytest.mark.parametrize(
+ "upstream_return, expected_outputs",
+ [
+ pytest.param([1, 2, 3], {1, 2, 3}, id="list"),
+ pytest.param({"a": 1, "b": 2}, {("a", 1), ("b", 2)}, id="dict"),
+ ],
+ )
+ def test_map_xcom(self, upstream_return, expected_outputs, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="xcom", session=session) as dag:
+
+ @dag.task
+ def emit():
+ return upstream_return
+
+ @dag.task
+ def show(value):
+ outputs.add(value)
+
+ show.map(value=emit())
+
+ dag_run = dag_maker.create_dagrun()
+ emit_ti = dag_run.get_task_instance("emit", session=session)
+ emit_ti.refresh_from_task(dag.get_task("emit"))
+ emit_ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == len(upstream_return)
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == expected_outputs
+
+ def test_map_product(self, dag_maker, session):
+ outputs = set()
+
+ with dag_maker(dag_id="product", session=session) as dag:
+
+ @dag.task
+ def emit_numbers():
+ return [1, 2]
+
+ @dag.task
+ def emit_letters():
+ return {"a": "x", "b": "y", "c": "z"}
+
+ @dag.task
+ def show(number, letter):
+ outputs.add((number, letter))
+
+ show.map(number=emit_numbers(), letter=emit_letters())
+
+ dag_run = dag_maker.create_dagrun()
+ for task_id in ["emit_numbers", "emit_letters"]:
+ ti = dag_run.get_task_instance(task_id, session=session)
+ ti.refresh_from_task(dag.get_task(task_id))
+ ti.run()
+
+ show_task = dag.get_task("show")
+ tis = show_task.expand_mapped_task(dag_run.run_id, session=session)
+ assert len(tis) == 6
+
+ for ti in tis:
+ ti.refresh_from_task(show_task)
+ ti.run()
+ assert outputs == {
+ (1, ("a", "x")),
+ (2, ("a", "x")),
+ (1, ("b", "y")),
+ (2, ("b", "y")),
+ (1, ("c", "z")),
+ (2, ("c", "z")),
+ }
Review comment:
This is a set, so the ordering doesn’t matter :p You’re right about the
actual ordering. (I didn’t want to rely on that because the *list of task
instances* is not guaranteed to be ordered.)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]