This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/v3-0-test by this push:
new 486dfda72dd [v3-0-test] fix(task_instances): handle
upstream_mapped_index when xcom access is needed (#50641) (#50950)
486dfda72dd is described below
commit 486dfda72ddf023653a6672d857f84c62ad7bf9f
Author: github-actions[bot]
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Mon May 26 20:33:15 2025 +0800
[v3-0-test] fix(task_instances): handle upstream_mapped_index when xcom
access is needed (#50641) (#50950)
* fix(task_instances): handle upstream_mapped_index when xcom access is
needed
* style(expand_input): fix expand_input and SchedulerExpandInput types
* test(task_instances): add test_dynamic_task_mapping_with_parse_time_value
* test(task_instance): add test_dynamic_task_mapping_with_xcom
* style: import typing
* style: move the SchedulerExpandInput into type checking block
* Revert "style: move the SchedulerExpandInput into type checking block"
This reverts commit c2c87ca304bfe721120bda19f3dcc3a0ddab8804.
(cherry picked from commit 5458e7e7be86c6de034d7a589bd26db85c532308)
Co-authored-by: Wei Lee <[email protected]>
---
.../execution_api/routes/task_instances.py | 20 +++-
airflow-core/src/airflow/models/expandinput.py | 13 ++-
.../airflow/serialization/serialized_objects.py | 4 +-
.../versions/head/test_task_instances.py | 124 ++++++++++++++++++++-
.../src/airflow/sdk/definitions/mappedoperator.py | 2 +-
task-sdk/src/airflow/sdk/definitions/taskgroup.py | 4 +-
6 files changed, 153 insertions(+), 14 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a48070deb26..ac1d1602460 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -57,6 +57,7 @@ from airflow.models.taskinstance import TaskInstance as TI,
_stop_remaining_task
from airflow.models.taskreschedule import TaskReschedule
from airflow.models.trigger import Trigger
from airflow.models.xcom import XComModel
+from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
from airflow.sdk.definitions.taskgroup import MappedTaskGroup
from airflow.utils import timezone
from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -244,7 +245,9 @@ def ti_run(
)
if dag := dag_bag.get_dag(ti.dag_id):
- upstream_map_indexes =
dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
+ upstream_map_indexes = dict(
+ _get_upstream_map_indexes(dag.get_task(ti.task_id),
ti.map_index, ti.run_id, session)
+ )
else:
upstream_map_indexes = None
@@ -274,7 +277,7 @@ def ti_run(
def _get_upstream_map_indexes(
- task: Operator, ti_map_index: int
+ task: Operator, ti_map_index: int, run_id: str, session: SessionDep
) -> Iterator[tuple[str, int | list[int] | None]]:
for upstream_task in task.upstream_list:
map_indexes: int | list[int] | None
@@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
map_indexes = ti_map_index
else:
# tasks not in the same mapped task group
- # the upstream mapped task group should combine the xcom as a list
and return it
- mapped_ti_count: int =
upstream_task.task_group.get_parse_time_mapped_ti_count()
+ # the upstream mapped task group should combine the return xcom as
a list and return it
+ mapped_ti_count: int
+ upstream_mapped_group = upstream_task.task_group
+ try:
+ # for cases that do not need to resolve xcom
+ mapped_ti_count =
upstream_mapped_group.get_parse_time_mapped_ti_count()
+ except NotFullyPopulated:
+ # for cases that need to resolve xcom to get the correct count
+ mapped_ti_count =
upstream_mapped_group._expand_input.get_total_map_length(
+ run_id, session=session
+ )
map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is
not None else None
yield upstream_task.task_id, map_indexes
diff --git a/airflow-core/src/airflow/models/expandinput.py
b/airflow-core/src/airflow/models/expandinput.py
index f3e6aab1680..b126c6f24b0 100644
--- a/airflow-core/src/airflow/models/expandinput.py
+++ b/airflow-core/src/airflow/models/expandinput.py
@@ -20,7 +20,7 @@ from __future__ import annotations
import functools
import operator
from collections.abc import Iterable, Sized
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar, Union
import attrs
@@ -32,7 +32,6 @@ if TYPE_CHECKING:
from airflow.sdk.definitions._internal.expandinput import (
DictOfListsExpandInput,
- ExpandInput,
ListOfDictsExpandInput,
MappedArgument,
NotFullyPopulated,
@@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) ->
TypeGuard[MappedArg
class SchedulerDictOfListsExpandInput:
value: dict
+ EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"
+
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
return ((k, v) for k, v in self.value.items() if not
_needs_run_time_resolution(v))
@@ -114,6 +115,8 @@ class SchedulerDictOfListsExpandInput:
class SchedulerListOfDictsExpandInput:
value: list
+ EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
+
def get_parse_time_mapped_ti_count(self) -> int:
if isinstance(self.value, Sized):
return len(self.value)
@@ -130,11 +133,13 @@ class SchedulerListOfDictsExpandInput:
return length
-_EXPAND_INPUT_TYPES = {
+_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
"dict-of-lists": SchedulerDictOfListsExpandInput,
"list-of-dicts": SchedulerListOfDictsExpandInput,
}
+SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput,
SchedulerListOfDictsExpandInput]
+
-def create_expand_input(kind: str, value: Any) -> ExpandInput:
+def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
return _EXPAND_INPUT_TYPES[kind](value)
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py
b/airflow-core/src/airflow/serialization/serialized_objects.py
index dbd55c1adde..2a8f3cd9da6 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -100,7 +100,7 @@ if TYPE_CHECKING:
from inspect import Parameter
from airflow.models import DagRun
- from airflow.models.expandinput import ExpandInput
+ from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk import BaseOperatorLink
from airflow.sdk.definitions._internal.node import DAGNode
from airflow.sdk.types import Operator
@@ -577,7 +577,7 @@ class _ExpandInputRef(NamedTuple):
possible ExpandInput cases.
"""
- def deref(self, dag: DAG) -> ExpandInput:
+ def deref(self, dag: DAG) -> SchedulerExpandInput:
"""
De-reference into a concrete ExpandInput object.
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index f1b8982eb04..21c733a5cce 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,7 +33,7 @@ from airflow.models.asset import AssetActive,
AssetAliasModel, AssetEvent, Asset
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskinstancehistory import TaskInstanceHistory
from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import TaskGroup
+from airflow.sdk import TaskGroup, task, task_group
from airflow.utils import timezone
from airflow.utils.state import State, TaskInstanceState, TerminalTIState
@@ -236,6 +236,128 @@ class TestTIRunState:
)
assert response.status_code == 409
+ def test_dynamic_task_mapping_with_parse_time_value(self, client,
dag_maker):
+ """
+ Test that the Task Instance upstream_map_indexes is correctly fetched
when running the Task Instances
+ """
+
+ with dag_maker("test_dynamic_task_mapping_with_parse_time_value",
serialized=True):
+
+ @task_group
+ def task_group_1(arg1):
+ @task
+ def group1_task_1(arg1):
+ return {"a": arg1}
+
+ @task
+ def group1_task_2(arg2):
+ return arg2
+
+ group1_task_2(group1_task_1(arg1))
+
+ @task
+ def task2():
+ return None
+
+ task_group_1.expand(arg1=[0, 1]) >> task2()
+
+ dr = dag_maker.create_dagrun()
+ for ti in dr.get_task_instances():
+ ti.set_state(State.QUEUED)
+ dag_maker.session.flush()
+
+ # key: (task_id, map_index)
+ # value: result upstream_map_indexes ({task_id: map_indexes})
+ expected_upstream_map_indexes = {
+ # no upstream task for task_group_1.group1_task_1
+ ("task_group_1.group1_task_1", 0): {},
+ ("task_group_1.group1_task_1", 1): {},
+ # the upstream task for task_group_1.group1_task_2 is
task_group_1.group1_task_1
+ # since they are in the same task group, the upstream map index
should be the same as the task
+ ("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1":
0},
+ ("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1":
1},
+ # the upstream task for task2 is the last task of task_group_1,
which is
+ # task_group_1.group1_task_2
+ # since they are not in the same task group, the upstream map
index should include all the
+ # expanded tasks
+ ("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
+ }
+
+ for ti in dr.get_task_instances():
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": "2024-09-30T12:00:00Z",
+ },
+ )
+
+ assert response.status_code == 200
+ upstream_map_indexes = response.json()["upstream_map_indexes"]
+ assert upstream_map_indexes ==
expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
+
+ def test_dynamic_task_mapping_with_xcom(self, client, dag_maker,
create_task_instance, session, run_task):
+ """
+ Test that the Task Instance upstream_map_indexes is correctly fetched
when running the Task Instances with xcom
+ """
+ from airflow.models.taskmap import TaskMap
+
+ with dag_maker(session=session):
+
+ @task
+ def task_1():
+ return [0, 1]
+
+ @task_group
+ def tg(x, y):
+ @task
+ def task_2():
+ pass
+
+ task_2()
+
+ @task
+ def task_3():
+ pass
+
+ tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()
+
+ dr = dag_maker.create_dagrun()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+
+ # Simulate task_1 execution to produce TaskMap.
+ (ti_1,) = decision.schedulable_tis
+ # ti_1 = dr.get_task_instance(task_id="task_1")
+ ti_1.state = TaskInstanceState.SUCCESS
+ session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
+ session.flush()
+
+ # Now task_2 in mapped task group is expanded.
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ for ti in decision.schedulable_tis:
+ ti.state = TaskInstanceState.SUCCESS
+ session.flush()
+
+ decision = dr.task_instance_scheduling_decisions(session=session)
+ (task_3_ti,) = decision.schedulable_tis
+ task_3_ti.set_state(State.QUEUED)
+
+ response = client.patch(
+ f"/execution/task-instances/{task_3_ti.id}/run",
+ json={
+ "state": "running",
+ "hostname": "random-hostname",
+ "unixname": "random-unixname",
+ "pid": 100,
+ "start_date": "2024-09-30T12:00:00Z",
+ },
+ )
+ assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1,
2, 3, 4, 5]}
+
def test_next_kwargs_still_encoded(self, client, session,
create_task_instance, time_machine):
instant_str = "2024-09-30T12:00:00Z"
instant = timezone.parse(instant_str)
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index b2e3baaeccf..cb24a7cc6bd 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -64,13 +64,13 @@ if TYPE_CHECKING:
TaskStateChangeCallback,
)
from airflow.models.expandinput import (
- ExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.xcom_arg import XComArg
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.bases.operatorlink import BaseOperatorLink
+ from airflow.sdk.definitions._internal.expandinput import ExpandInput
from airflow.sdk.definitions.dag import DAG
from airflow.sdk.definitions.param import ParamsDict
from airflow.sdk.types import Operator
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 3363424dee6..03cc4bbad8d 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -40,7 +40,7 @@ from airflow.sdk.definitions._internal.node import DAGNode,
validate_group_key
from airflow.utils.trigger_rule import TriggerRule
if TYPE_CHECKING:
- from airflow.models.expandinput import ExpandInput
+ from airflow.models.expandinput import SchedulerExpandInput
from airflow.sdk.bases.operator import BaseOperator
from airflow.sdk.definitions._internal.abstractoperator import
AbstractOperator
from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
a ``@task_group`` function instead.
"""
- def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
+ def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any)
-> None:
super().__init__(**kwargs)
self._expand_input = expand_input