This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-0-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit f001dc2d9b20ee6714164c04ce873cc65ab56848
Author: Wei Lee <[email protected]>
AuthorDate: Mon May 5 22:42:19 2025 +0800

    Handle MappedTaskGroup map indexes (#49996)
    
    * fix(task_group): correctly set upstream_map_indexes for mapped task group
    
    * refactor(task_instance): refactor how upstream_map_indexes is handled
    
    * test: fix broken tests
    
    * fix(task_group): fix mapping handling
    
    * refactor(task_instances): rewrite _get_upstream_map_indexes
    
    (cherry picked from commit 5adeac2c83e4c063e4599efa23575d2be60b2d13)
---
 .../execution_api/datamodels/taskinstance.py       |  2 +-
 .../execution_api/routes/task_instances.py         | 40 ++++++++++++++++++++--
 .../versions/head/test_task_instances.py           |  3 ++
 devel-common/src/tests_common/pytest_plugin.py     |  3 +-
 .../src/airflow/sdk/api/datamodels/_generated.py   |  4 ++-
 task-sdk/src/airflow/sdk/definitions/xcom_arg.py   | 23 +++++--------
 .../task_sdk/definitions/test_mappedoperator.py    | 15 +++++++-
 7 files changed, 70 insertions(+), 20 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index c3a8be3134c..b83d731a54e 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -314,7 +314,7 @@ class TIRunContext(BaseModel):
     connections: Annotated[list[ConnectionResponse], 
Field(default_factory=list)]
     """Connections that can be accessed by the task instance."""
 
-    upstream_map_indexes: dict[str, int] | None = None
+    upstream_map_indexes: dict[str, int | list[int] | None] | None = None
 
     next_method: str | None = None
     """Method to call. Set when task resumes from a trigger."""
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
index f97a684e3ff..00a5ea9a5e6 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py
@@ -19,7 +19,8 @@ from __future__ import annotations
 
 import json
 from collections import defaultdict
-from typing import Annotated, Any
+from collections.abc import Iterator
+from typing import TYPE_CHECKING, Annotated, Any
 from uuid import UUID
 
 import structlog
@@ -55,9 +56,14 @@ from airflow.models.taskinstance import TaskInstance as TI, 
_stop_remaining_task
 from airflow.models.taskreschedule import TaskReschedule
 from airflow.models.trigger import Trigger
 from airflow.models.xcom import XComModel
+from airflow.sdk.definitions.taskgroup import MappedTaskGroup
 from airflow.utils import timezone
 from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState
 
+if TYPE_CHECKING:
+    from airflow.sdk.types import Operator
+
+
 router = VersionedAPIRouter()
 
 ti_id_router = VersionedAPIRouter(
@@ -82,7 +88,10 @@ log = structlog.get_logger(__name__)
     response_model_exclude_unset=True,
 )
 def ti_run(
-    task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, 
Body()], session: SessionDep
+    task_instance_id: UUID,
+    ti_run_payload: Annotated[TIEnterRunningPayload, Body()],
+    session: SessionDep,
+    request: Request,
 ) -> TIRunContext:
     """
     Run a TaskInstance.
@@ -233,6 +242,11 @@ def ti_run(
             or 0
         )
 
+        if dag := request.app.state.dag_bag.get_dag(ti.dag_id):
+            upstream_map_indexes = 
dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index))
+        else:
+            upstream_map_indexes = None
+
         context = TIRunContext(
             dag_run=dr,
             task_reschedule_count=task_reschedule_count,
@@ -242,6 +256,7 @@ def ti_run(
             connections=[],
             xcom_keys_to_clear=xcom_keys,
             should_retry=_is_eligible_to_retry(previous_state, ti.try_number, 
ti.max_tries),
+            upstream_map_indexes=upstream_map_indexes,
         )
 
         # Only set if they are non-null
@@ -257,6 +272,27 @@ def ti_run(
         )
 
 
+def _get_upstream_map_indexes(
+    task: Operator, ti_map_index: int
+) -> Iterator[tuple[str, int | list[int] | None]]:
+    for upstream_task in task.upstream_list:
+        map_indexes: int | list[int] | None
+        if not isinstance(upstream_task.task_group, MappedTaskGroup):
+            # regular tasks or non-mapped task groups
+            map_indexes = None
+        elif task.task_group == upstream_task.task_group:
+            # tasks in the same mapped task group
+            # the task should use the map_index as the previous task in the 
same mapped task group
+            map_indexes = ti_map_index
+        else:
+            # tasks not in the same mapped task group
+            # the upstream mapped task group should combine the xcom as a list 
and return it
+            mapped_ti_count: int = 
upstream_task.task_group.get_parse_time_mapped_ti_count()
+            map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is 
not None else None
+
+        yield upstream_task.task_id, map_indexes
+
+
 @ti_id_router.patch(
     "/{task_instance_id}/state",
     status_code=status.HTTP_204_NO_CONTENT,
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 41bba2c2bd5..95631484087 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -179,6 +179,7 @@ class TestTIRunState:
                 "consumed_asset_events": [],
             },
             "task_reschedule_count": 0,
+            "upstream_map_indexes": None,
             "max_tries": max_tries,
             "should_retry": should_retry,
             "variables": [],
@@ -257,6 +258,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
+            "upstream_map_indexes": None,
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
@@ -318,6 +320,7 @@ class TestTIRunState:
         assert response.json() == {
             "dag_run": mock.ANY,
             "task_reschedule_count": 0,
+            "upstream_map_indexes": None,
             "max_tries": 0,
             "should_retry": False,
             "variables": [],
diff --git a/devel-common/src/tests_common/pytest_plugin.py 
b/devel-common/src/tests_common/pytest_plugin.py
index b461ddc409d..2b31152719d 100644
--- a/devel-common/src/tests_common/pytest_plugin.py
+++ b/devel-common/src/tests_common/pytest_plugin.py
@@ -2088,7 +2088,7 @@ def create_runtime_ti(mocked_parse):
         run_type: str = "manual",
         try_number: int = 1,
         map_index: int | None = -1,
-        upstream_map_indexes: dict[str, int] | None = None,
+        upstream_map_indexes: dict[str, int | list[int] | None] | None = None,
         task_reschedule_count: int = 0,
         ti_id: UUID | None = None,
         conf: dict[str, Any] | None = None,
@@ -2143,6 +2143,7 @@ def create_runtime_ti(mocked_parse):
             task_reschedule_count=task_reschedule_count,
             max_tries=task_retries if max_tries is None else max_tries,
             should_retry=should_retry if should_retry is not None else 
try_number <= task_retries,
+            upstream_map_indexes=upstream_map_indexes,
         )
 
         if upstream_map_indexes is not None:
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index a477c3ffc41..dbf39fa4ae1 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -480,7 +480,9 @@ class TIRunContext(BaseModel):
     max_tries: Annotated[int, Field(title="Max Tries")]
     variables: Annotated[list[VariableResponse] | None, 
Field(title="Variables")] = None
     connections: Annotated[list[ConnectionResponse] | None, 
Field(title="Connections")] = None
-    upstream_map_indexes: Annotated[dict[str, int] | None, 
Field(title="Upstream Map Indexes")] = None
+    upstream_map_indexes: Annotated[
+        dict[str, int | list[int] | None] | None, Field(title="Upstream Map 
Indexes")
+    ] = None
     next_method: Annotated[str | None, Field(title="Next Method")] = None
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next 
Kwargs")] = None
     xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To 
Clear")] = None
diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py 
b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
index 1adcb7efaa7..2a93585304c 100644
--- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
+++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py
@@ -339,22 +339,17 @@ class PlainXComArg(XComArg):
         if self.operator.is_mapped:
             return LazyXComSequence(xcom_arg=self, ti=ti)
         tg = self.operator.get_closest_mapped_task_group()
-        result = None
         if tg is None:
-            # regular task
-            result = ti.xcom_pull(
-                task_ids=task_id,
-                key=self.key,
-                default=NOTSET,
-                map_indexes=None,
-            )
+            map_indexes = None
         else:
-            # task from a task group
-            result = ti.xcom_pull(
-                task_ids=task_id,
-                key=self.key,
-                default=NOTSET,
-            )
+            upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {})
+            map_indexes = upstream_map_indexes.get(task_id, None)
+        result = ti.xcom_pull(
+            task_ids=task_id,
+            key=self.key,
+            default=NOTSET,
+            map_indexes=map_indexes,
+        )
         if not isinstance(result, ArgNotSet):
             return result
         if self.key == XCOM_RETURN_KEY:
diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py 
b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
index d6b8ca8da4e..6cdb4b520f2 100644
--- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
+++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py
@@ -627,9 +627,22 @@ def 
test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super
         "tg.t2": range(3),
         "t3": [None],
     }
+    upstream_map_indexes_per_task_id = {
+        ("tg.t1", 0): {},
+        ("tg.t1", 1): {},
+        ("tg.t1", 2): {},
+        ("tg.t2", 0): {"tg.t1": 0},
+        ("tg.t2", 1): {"tg.t1": 1},
+        ("tg.t2", 2): {"tg.t1": 2},
+        ("t3", None): {"tg.t2": [0, 1, 2]},
+    }
     for task in dag.tasks:
         for map_index in expansion_per_task_id[task.task_id]:
-            mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), 
map_index=map_index)
+            mapped_ti = create_runtime_ti(
+                task=task.prepare_for_execution(),
+                map_index=map_index,
+                
upstream_map_indexes=upstream_map_indexes_per_task_id[(task.task_id, 
map_index)],
+            )
             context = mapped_ti.get_template_context()
             mapped_ti.task.render_template_fields(context)
             mapped_ti.task.execute(context)

Reply via email to