This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/v3-0-test by this push:
     new 486dfda72dd [v3-0-test] fix(task_instances): handle 
upstream_mapped_index when xcom access is needed (#50641) (#50950)
486dfda72dd is described below

commit 486dfda72ddf023653a6672d857f84c62ad7bf9f
Author: github-actions[bot] 
<41898282+github-actions[bot]@users.noreply.github.com>
AuthorDate: Mon May 26 20:33:15 2025 +0800

    [v3-0-test] fix(task_instances): handle upstream_mapped_index when xcom 
access is needed (#50641) (#50950)
    
    * fix(task_instances): handle upstream_mapped_index when xcom access is 
needed
    
    * style(expand_input): fix expand_input and SchedulerExpandInput types
    
    * test(task_instances): add test_dynamic_task_mapping_with_parse_time_value
    
    * test(task_instance): add test_dynamic_task_mapping_with_xcom
    
    * style: import typing
    
    * style: move the SchedulerExpandInput into type checking block
    
    * Revert "style: move the SchedulerExpandInput into type checking block"
    
    This reverts commit c2c87ca304bfe721120bda19f3dcc3a0ddab8804.
    (cherry picked from commit 5458e7e7be86c6de034d7a589bd26db85c532308)
    
    Co-authored-by: Wei Lee <[email protected]>
---
 .../execution_api/routes/task_instances.py         |  20 +++-
 airflow-core/src/airflow/models/expandinput.py     |  13 ++-
 .../airflow/serialization/serialized_objects.py    |   4 +-
 .../versions/head/test_task_instances.py           | 124 ++++++++++++++++++++-
 .../src/airflow/sdk/definitions/mappedoperator.py  |   2 +-
 task-sdk/src/airflow/sdk/definitions/taskgroup.py  |   4 +-
 6 files changed, 153 insertions(+), 14 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index a48070deb26..ac1d1602460 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -57,6 +57,7 @@ from airflow.models.taskinstance import TaskInstance as TI, 
_stop_remaining_task
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
+from airflow.sdk.definitions._internal.expandinput import NotFullyPopulated
 from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
@@ -244,7 +245,9 @@ def ti_run(
         )
 
         if dag := dag_bag.get_dag(ti.dag_id):
-            upstream_map_indexes = 
dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
+            upstream_map_indexes = dict(
+                _get_upstream_map_indexes(dag.get_task(ti.task_id), 
ti.map_index, ti.run_id, session)
+            )
         else:
             upstream_map_indexes = None
 
@@ -274,7 +277,7 @@ def ti_run(
 
 
 def _get_upstream_map_indexes(
-    task: Operator, ti_map_index: int
+    task: Operator, ti_map_index: int, run_id: str, session: SessionDep
 ) -> Iterator[tuple[str, int | list[int] | None]]:
     for upstream_task in task.upstream_list:
         map_indexes: int | list[int] | None
@@ -287,8 +290,17 @@ def _get_upstream_map_indexes(
             map_indexes = ti_map_index
         else:
             # tasks not in the same mapped task group
-            # the upstream mapped task group should combine the xcom as a list 
and return it
-            mapped_ti_count: int = 
upstream_task.task_group.get_parse_time_mapped_ti_count()
+            # the upstream mapped task group should combine the returned xcom as a list and return it
+            mapped_ti_count: int
+            upstream_mapped_group = upstream_task.task_group
+            try:
+                # for cases that do not need to resolve xcom
+                mapped_ti_count = 
upstream_mapped_group.get_parse_time_mapped_ti_count()
+            except NotFullyPopulated:
+                # for cases that needs to resolve xcom to get the correct count
+                mapped_ti_count = 
upstream_mapped_group._expand_input.get_total_map_length(
+                    run_id, session=session
+                )
             map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is 
not None else None
 
         yield upstream_task.task_id, map_indexes
diff --git a/airflow-core/src/airflow/models/expandinput.py 
b/airflow-core/src/airflow/models/expandinput.py
index f3e6aab1680..b126c6f24b0 100644
--- a/airflow-core/src/airflow/models/expandinput.py
+++ b/airflow-core/src/airflow/models/expandinput.py
@@ -20,7 +20,7 @@ from __future__ import annotations
 import functools
 import operator
 from collections.abc import Iterable, Sized
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, ClassVar, Union
 
 import attrs
 
@@ -32,7 +32,6 @@ if TYPE_CHECKING:
 
 from airflow.sdk.definitions._internal.expandinput import (
     DictOfListsExpandInput,
-    ExpandInput,
     ListOfDictsExpandInput,
     MappedArgument,
     NotFullyPopulated,
@@ -62,6 +61,8 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> 
TypeGuard[MappedArg
 class SchedulerDictOfListsExpandInput:
     value: dict
 
+    EXPAND_INPUT_TYPE: ClassVar[str] = "dict-of-lists"
+
     def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
         """Generate kwargs with values available on parse-time."""
         return ((k, v) for k, v in self.value.items() if not 
_needs_run_time_resolution(v))
@@ -114,6 +115,8 @@ class SchedulerDictOfListsExpandInput:
 class SchedulerListOfDictsExpandInput:
     value: list
 
+    EXPAND_INPUT_TYPE: ClassVar[str] = "list-of-dicts"
+
     def get_parse_time_mapped_ti_count(self) -> int:
         if isinstance(self.value, Sized):
             return len(self.value)
@@ -130,11 +133,13 @@ class SchedulerListOfDictsExpandInput:
         return length
 
 
-_EXPAND_INPUT_TYPES = {
+_EXPAND_INPUT_TYPES: dict[str, type[SchedulerExpandInput]] = {
     "dict-of-lists": SchedulerDictOfListsExpandInput,
     "list-of-dicts": SchedulerListOfDictsExpandInput,
 }
 
+SchedulerExpandInput = Union[SchedulerDictOfListsExpandInput, 
SchedulerListOfDictsExpandInput]
+
 
-def create_expand_input(kind: str, value: Any) -> ExpandInput:
+def create_expand_input(kind: str, value: Any) -> SchedulerExpandInput:
     return _EXPAND_INPUT_TYPES[kind](value)
diff --git a/airflow-core/src/airflow/serialization/serialized_objects.py 
b/airflow-core/src/airflow/serialization/serialized_objects.py
index dbd55c1adde..2a8f3cd9da6 100644
--- a/airflow-core/src/airflow/serialization/serialized_objects.py
+++ b/airflow-core/src/airflow/serialization/serialized_objects.py
@@ -100,7 +100,7 @@ if TYPE_CHECKING:
     from inspect import Parameter
 
     from airflow.models import DagRun
-    from airflow.models.expandinput import ExpandInput
+    from airflow.models.expandinput import SchedulerExpandInput
     from airflow.sdk import BaseOperatorLink
     from airflow.sdk.definitions._internal.node import DAGNode
     from airflow.sdk.types import Operator
@@ -577,7 +577,7 @@ class _ExpandInputRef(NamedTuple):
         possible ExpandInput cases.
         """
 
-    def deref(self, dag: DAG) -> ExpandInput:
+    def deref(self, dag: DAG) -> SchedulerExpandInput:
         """
         De-reference into a concrete ExpandInput object.
 
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index f1b8982eb04..21c733a5cce 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -33,7 +33,7 @@ from airflow.models.asset import AssetActive, 
AssetAliasModel, AssetEvent, Asset
 from airflow.models.taskinstance import TaskInstance
 from airflow.models.taskinstancehistory import TaskInstanceHistory
 from airflow.providers.standard.operators.empty import EmptyOperator
-from airflow.sdk import TaskGroup
+from airflow.sdk import TaskGroup, task, task_group
 from airflow.utils import timezone
 from airflow.utils.state import State, TaskInstanceState, TerminalTIState
 
@@ -236,6 +236,128 @@ class TestTIRunState:
         )
         assert response.status_code == 409
 
+    def test_dynamic_task_mapping_with_parse_time_value(self, client, 
dag_maker):
+        """
+        Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances
+        """
+
+        with dag_maker("test_dynamic_task_mapping_with_parse_time_value", 
serialized=True):
+
+            @task_group
+            def task_group_1(arg1):
+                @task
+                def group1_task_1(arg1):
+                    return {"a": arg1}
+
+                @task
+                def group1_task_2(arg2):
+                    return arg2
+
+                group1_task_2(group1_task_1(arg1))
+
+            @task
+            def task2():
+                return None
+
+            task_group_1.expand(arg1=[0, 1]) >> task2()
+
+        dr = dag_maker.create_dagrun()
+        for ti in dr.get_task_instances():
+            ti.set_state(State.QUEUED)
+        dag_maker.session.flush()
+
+        # key: (task_id, map_index)
+        # value: result upstream_map_indexes ({task_id: map_indexes})
+        expected_upstream_map_indexes = {
+            # no upstream task for task_group_1.group1_task_1
+            ("task_group_1.group1_task_1", 0): {},
+            ("task_group_1.group1_task_1", 1): {},
+            # the upstream task for task_group_1.group1_task_2 is task_group_1.group1_task_1
+            # since they are in the same task group, the upstream map index 
should be the same as the task
+            ("task_group_1.group1_task_2", 0): {"task_group_1.group1_task_1": 
0},
+            ("task_group_1.group1_task_2", 1): {"task_group_1.group1_task_1": 
1},
+            # the upstream task for task2 is the last task of task_group_1, which is
+            # task_group_1.group1_task_2
+            # since they are not in the same task group, the upstream map 
index should include all the
+            # expanded tasks
+            ("task2", -1): {"task_group_1.group1_task_2": [0, 1]},
+        }
+
+        for ti in dr.get_task_instances():
+            response = client.patch(
+                f"/execution/task-instances/{ti.id}/run",
+                json={
+                    "state": "running",
+                    "hostname": "random-hostname",
+                    "unixname": "random-unixname",
+                    "pid": 100,
+                    "start_date": "2024-09-30T12:00:00Z",
+                },
+            )
+
+            assert response.status_code == 200
+            upstream_map_indexes = response.json()["upstream_map_indexes"]
+            assert upstream_map_indexes == 
expected_upstream_map_indexes[(ti.task_id, ti.map_index)]
+
+    def test_dynamic_task_mapping_with_xcom(self, client, dag_maker, 
create_task_instance, session, run_task):
+        """
+        Test that the Task Instance upstream_map_indexes is correctly fetched when running the Task Instances with xcom
+        """
+        from airflow.models.taskmap import TaskMap
+
+        with dag_maker(session=session):
+
+            @task
+            def task_1():
+                return [0, 1]
+
+            @task_group
+            def tg(x, y):
+                @task
+                def task_2():
+                    pass
+
+                task_2()
+
+            @task
+            def task_3():
+                pass
+
+            tg.expand(x=task_1(), y=[1, 2, 3]) >> task_3()
+
+        dr = dag_maker.create_dagrun()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+
+        # Simulate task_1 execution to produce TaskMap.
+        (ti_1,) = decision.schedulable_tis
+        # ti_1 = dr.get_task_instance(task_id="task_1")
+        ti_1.state = TaskInstanceState.SUCCESS
+        session.add(TaskMap.from_task_instance_xcom(ti_1, [0, 1]))
+        session.flush()
+
+        # Now task_2 in the mapped task group is expanded.
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        for ti in decision.schedulable_tis:
+            ti.state = TaskInstanceState.SUCCESS
+        session.flush()
+
+        decision = dr.task_instance_scheduling_decisions(session=session)
+        (task_3_ti,) = decision.schedulable_tis
+        task_3_ti.set_state(State.QUEUED)
+
+        response = client.patch(
+            f"/execution/task-instances/{task_3_ti.id}/run",
+            json={
+                "state": "running",
+                "hostname": "random-hostname",
+                "unixname": "random-unixname",
+                "pid": 100,
+                "start_date": "2024-09-30T12:00:00Z",
+            },
+        )
+        assert response.json()["upstream_map_indexes"] == {"tg.task_2": [0, 1, 
2, 3, 4, 5]}
+
     def test_next_kwargs_still_encoded(self, client, session, 
create_task_instance, time_machine):
         instant_str = "2024-09-30T12:00:00Z"
         instant = timezone.parse(instant_str)
diff --git a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py 
b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
index b2e3baaeccf..cb24a7cc6bd 100644
--- a/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
+++ b/task-sdk/src/airflow/sdk/definitions/mappedoperator.py
@@ -64,13 +64,13 @@ if TYPE_CHECKING:
         TaskStateChangeCallback,
     )
     from airflow.models.expandinput import (
-        ExpandInput,
         OperatorExpandArgument,
         OperatorExpandKwargsArgument,
     )
     from airflow.models.xcom_arg import XComArg
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.bases.operatorlink import BaseOperatorLink
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
     from airflow.sdk.definitions.dag import DAG
     from airflow.sdk.definitions.param import ParamsDict
     from airflow.sdk.types import Operator
diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py 
b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
index 3363424dee6..03cc4bbad8d 100644
--- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py
+++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py
@@ -40,7 +40,7 @@ from airflow.sdk.definitions._internal.node import DAGNode, 
validate_group_key
 from airflow.utils.trigger_rule import TriggerRule
 
 if TYPE_CHECKING:
-    from airflow.models.expandinput import ExpandInput
+    from airflow.models.expandinput import SchedulerExpandInput
     from airflow.sdk.bases.operator import BaseOperator
     from airflow.sdk.definitions._internal.abstractoperator import 
AbstractOperator
     from airflow.sdk.definitions._internal.mixins import DependencyMixin
@@ -613,7 +613,7 @@ class MappedTaskGroup(TaskGroup):
     a ``@task_group`` function instead.
     """
 
-    def __init__(self, *, expand_input: ExpandInput, **kwargs: Any) -> None:
+    def __init__(self, *, expand_input: SchedulerExpandInput, **kwargs: Any) 
-> None:
         super().__init__(**kwargs)
         self._expand_input = expand_input
 

Reply via email to