This is an automated email from the ASF dual-hosted git repository.
ash pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 197cff3194 Ensure TaskMap only checks "relevant" dependencies (#23053)
197cff3194 is described below
commit 197cff3194e855b9207c3c0da8ae093a0d5dda55
Author: Tzu-ping Chung <[email protected]>
AuthorDate: Wed Apr 20 02:02:15 2022 +0800
Ensure TaskMap only checks "relevant" dependencies (#23053)
When looking for "mapped dependants" of a task, we only want a downstream
task if it is not only a direct downstream of the task, but also actually
"uses" the task's pushed XCom for task mapping. So we need to peek into
the mapped downstream task's expansion kwargs, and only count it as a
mapped dependant if the upstream is referenced there.
---
airflow/jobs/backfill_job.py | 2 +-
airflow/models/mappedoperator.py | 9 +++++++
airflow/models/taskinstance.py | 2 +-
airflow/models/taskmixin.py | 23 +++++++++++++-----
airflow/models/xcom_arg.py | 30 +++++++++++++++--------
tests/models/test_taskinstance.py | 51 +++++++++++++++++++++++++++++++++++++--
6 files changed, 97 insertions(+), 20 deletions(-)
diff --git a/airflow/jobs/backfill_job.py b/airflow/jobs/backfill_job.py
index 7a7541eb80..6695566cc7 100644
--- a/airflow/jobs/backfill_job.py
+++ b/airflow/jobs/backfill_job.py
@@ -266,7 +266,7 @@ class BackfillJob(BaseJob):
if ti.state not in self.STATES_COUNT_AS_RUNNING:
# Don't use ti.task; if this task is mapped, that attribute
# would hold the unmapped task. We need to original task here.
- for node in self.dag.get_task(ti.task_id,
include_subdags=True).mapped_dependants():
+ for node in self.dag.get_task(ti.task_id,
include_subdags=True).iter_mapped_dependants():
new_tis, num_mapped_tis =
node.expand_mapped_task(ti.run_id, session=session)
yield node, ti.run_id, new_tis, num_mapped_tis
diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py
index ddc32906df..4e4bd14e01 100644
--- a/airflow/models/mappedoperator.py
+++ b/airflow/models/mappedoperator.py
@@ -30,6 +30,7 @@ from typing import (
Dict,
FrozenSet,
Iterable,
+ Iterator,
List,
Optional,
Sequence,
@@ -76,6 +77,7 @@ if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator, BaseOperatorLink
from airflow.models.dag import DAG
+ from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom_arg import XComArg
from airflow.utils.task_group import TaskGroup
@@ -775,6 +777,13 @@ class MappedOperator(AbstractOperator):
return k, v
raise IndexError(f"index {map_index} is over mapped length")
+ def iter_mapped_dependencies(self) -> Iterator["Operator"]:
+ """Upstream dependencies that provide XComs used by this task for task
mapping."""
+ from airflow.models.xcom_arg import XComArg
+
+ for ref in XComArg.iter_xcom_args(self._get_expansion_kwargs()):
+ yield ref.operator
+
@cached_property
def parse_time_mapped_ti_count(self) -> Optional[int]:
"""
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 89ec9d17c5..48d3a047fb 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -2333,7 +2333,7 @@ class TaskInstance(Base, LoggingMixin):
# currently possible for a downstream to depend on one individual
mapped
# task instance, only a task as a whole. This will change in AIP-42
# Phase 2, and we'll need to further analyze the mapped task case.
- if task.is_mapped or not task.has_mapped_dependants():
+ if task.is_mapped or next(task.iter_mapped_dependants(), None) is None:
return
if value is None:
raise XComForMappingNotPushed()
diff --git a/airflow/models/taskmixin.py b/airflow/models/taskmixin.py
index c5d6165e8d..1a7129f66e 100644
--- a/airflow/models/taskmixin.py
+++ b/airflow/models/taskmixin.py
@@ -291,13 +291,20 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
"""This is used by SerializedTaskGroup to serialize a task group's
content."""
raise NotImplementedError()
- def mapped_dependants(self) -> Iterator["MappedOperator"]:
- """Return any mapped nodes that are direct dependencies of the current
task
+ def _iter_all_mapped_downstreams(self) -> Iterator["MappedOperator"]:
+ """Return mapped nodes that are direct downstreams of the current
task.
For now, this walks the entire DAG to find mapped nodes that has this
current task as an upstream. We cannot use ``downstream_list`` since it
only contains operators, not task groups. In the future, we should
provide a way to record an DAG node's all downstream nodes instead.
+
+ Note that this does not guarantee the returned tasks actually use the
+ current task for task mapping, but only checks that those tasks are mapped
+ operators, and are downstreams of the current task.
+
+ To get a list of tasks that use the current task for task mapping, use
+ :meth:`iter_mapped_dependants` instead.
"""
from airflow.models.mappedoperator import MappedOperator
from airflow.utils.task_group import TaskGroup
@@ -315,7 +322,7 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
tg = self.task_group
if not tg:
- raise RuntimeError("Cannot check for mapped_dependants when not
attached to a DAG")
+ raise RuntimeError("Cannot check for mapped dependants when not
attached to a DAG")
for key, child in _walk_group(tg):
if key == self.node_id:
continue
@@ -324,12 +331,16 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
if self.node_id in child.upstream_task_ids:
yield child
- def has_mapped_dependants(self) -> bool:
- """Whether any downstream dependencies depend on this task for mapping.
+ def iter_mapped_dependants(self) -> Iterator["MappedOperator"]:
+ """Return mapped nodes that depend on the current task for the expansion.
For now, this walks the entire DAG to find mapped nodes that has this
current task as an upstream. We cannot use ``downstream_list`` since it
only contains operators, not task groups. In the future, we should
provide a way to record an DAG node's all downstream nodes instead.
"""
- return any(self.mapped_dependants())
+ return (
+ downstream
+ for downstream in self._iter_all_mapped_downstreams()
+ if any(p.node_id == self.node_id for p in
downstream.iter_mapped_dependencies())
+ )
diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py
index 449fab8af5..2c2a8f9b58 100644
--- a/airflow/models/xcom_arg.py
+++ b/airflow/models/xcom_arg.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union
+from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Sequence,
Union
from airflow.exceptions import AirflowException
from airflow.models.abstractoperator import AbstractOperator
@@ -156,21 +156,31 @@ class XComArg(DependencyMixin):
return result
@staticmethod
- def apply_upstream_relationship(op: "Operator", arg: Any):
- """
- Set dependency for XComArgs.
+ def iter_xcom_args(arg: Any) -> Iterator["XComArg"]:
+ """Return XComArg instances in an arbitrary value.
- This looks for XComArg objects in ``arg`` "deeply" (looking inside
lists, dicts and classes decorated
- with "template_fields") and sets the relationship to ``op`` on any
found.
+ This recursively traverses ``arg`` and looks for XComArg instances in any
+ collection objects, and instances with ``template_fields`` set.
"""
if isinstance(arg, XComArg):
- op.set_upstream(arg.operator)
+ yield arg
elif isinstance(arg, (tuple, set, list)):
for elem in arg:
- XComArg.apply_upstream_relationship(op, elem)
+ yield from XComArg.iter_xcom_args(elem)
elif isinstance(arg, dict):
for elem in arg.values():
- XComArg.apply_upstream_relationship(op, elem)
+ yield from XComArg.iter_xcom_args(elem)
elif isinstance(arg, AbstractOperator):
for elem in arg.template_fields:
- XComArg.apply_upstream_relationship(op, elem)
+ yield from XComArg.iter_xcom_args(elem)
+
+ @staticmethod
+ def apply_upstream_relationship(op: "Operator", arg: Any):
+ """Set dependency for XComArgs.
+
+ This looks for XComArg objects in ``arg`` "deeply" (looking inside
+ collection objects and classes decorated with ``template_fields``),
and
+ sets the relationship to ``op`` on any found.
+ """
+ for ref in XComArg.iter_xcom_args(arg):
+ op.set_upstream(ref.operator)
diff --git a/tests/models/test_taskinstance.py
b/tests/models/test_taskinstance.py
index 3e5f30239a..2c52fd0eb6 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -2382,8 +2382,8 @@ class TestTaskInstanceRecordTaskMapXComPush:
session.query(TaskMap).delete()
@pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
- def test_not_recorded_for_unused(self, dag_maker, xcom_value):
- """A value not used for task-mapping should not be recorded."""
+ def test_not_recorded_if_leaf(self, dag_maker, xcom_value):
+ """Return value should not be recorded if there are no downstreams."""
with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
@dag.task()
@@ -2397,6 +2397,53 @@ class TestTaskInstanceRecordTaskMapXComPush:
assert dag_maker.session.query(TaskMap).count() == 0
+ @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
+ def test_not_recorded_if_not_used(self, dag_maker, xcom_value):
+ """Return value should not be recorded if no downstreams are mapped."""
+ with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+
+ @dag.task()
+ def push_something():
+ return xcom_value
+
+ @dag.task()
+ def completely_different():
+ pass
+
+ push_something() >> completely_different()
+
+ ti = next(ti for ti in dag_maker.create_dagrun().task_instances if
ti.task_id == "push_something")
+ ti.run()
+
+ assert dag_maker.session.query(TaskMap).count() == 0
+
+ @pytest.mark.parametrize("xcom_value", [[1, 2, 3], {"a": 1, "b": 2},
"abc"])
+ def test_not_recorded_if_irrelevant(self, dag_maker, xcom_value):
+ """Return value should only be recorded if a mapped downstream uses
it."""
+ with dag_maker(dag_id="test_not_recorded_for_unused") as dag:
+
+ @dag.task()
+ def push_1():
+ return xcom_value
+
+ @dag.task()
+ def push_2():
+ return [-1, -2]
+
+ @dag.task()
+ def show(arg1, arg2):
+ print(arg1, arg2)
+
+ show.partial(arg1=push_1()).expand(arg2=push_2())
+
+ tis = {ti.task_id: ti for ti in
dag_maker.create_dagrun().task_instances}
+
+ tis["push_1"].run()
+ assert dag_maker.session.query(TaskMap).count() == 0
+
+ tis["push_2"].run()
+ assert dag_maker.session.query(TaskMap).count() == 1
+
@pytest.mark.parametrize(
"return_value, exception_type, error_message",
[