Dev-iL commented on code in PR #61975:
URL: https://github.com/apache/airflow/pull/61975#discussion_r2810964569


##########
airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -921,6 +926,98 @@ def _iter_breadcrumbs() -> Iterator[dict[str, Any]]:
     return TaskBreadcrumbsResponse(breadcrumbs=_iter_breadcrumbs())
 
 
+def _populate_task_group_map_index_context(
+    context: TIRunContext,
+    dag_id: str,
+    task_id: str,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+    dag_bag: DagBagDep,
+) -> None:
+    """
+    Populate task group ``map_index_template`` and expanded args on the TIRunContext.
+
+    Mutates ``context`` in place. It sets ``task_group_map_index_template`` and
+    ``task_group_expanded_args`` from the first mapped task group, as yielded by
+    ``task.iter_mapped_task_groups()``, that defines a ``map_index_template``.
+    ``context`` is left untouched in three cases: the DAG lookup raises
+    ``HTTPException``, the task is not in the DAG, or no such task group exists.
+
+    :param context: Run context to update in place.
+    :param dag_id: DAG containing the task.
+    :param task_id: Task whose enclosing mapped task groups are inspected.
+    :param map_index: Map index of the running task instance.
+    :param run_id: Run whose XComs are used to resolve XComArg expand inputs.
+    :param session: Database session.
+    :param dag_bag: DagBag passed to ``get_latest_version_of_dag``.
+    """
+    try:
+        dag = get_latest_version_of_dag(dag_bag, dag_id, session)
+    except HTTPException:
+        # Best effort: a failed DAG lookup just leaves the context unpopulated.
+        return
+
+    task = dag.task_dict.get(task_id)
+    if not task:
+        return
+
+    # NOTE(review): only the first group with a template is used. Presumably
+    # iter_mapped_task_groups() yields the innermost group first; confirm.
+    for mtg in task.iter_mapped_task_groups():
+        if not mtg.map_index_template:
+            continue
+
+        context.task_group_map_index_template = mtg.map_index_template
+        # Reads the group's private ``_expand_input`` attribute.
+        context.task_group_expanded_args = _resolve_task_group_expand_args(
+            mtg._expand_input, map_index, run_id, session
+        )
+        break
+
+
+def _resolve_task_group_expand_args(
+    expand_input: Any,
+    map_index: int,
+    run_id: str,
+    session: SessionDep,
+) -> dict[str, Any] | None:
+    """
+    Resolve the expand_input for a specific map_index to get the expanded arguments.
+
+    Handles ``SchedulerDictOfListsExpandInput`` and ``SchedulerListOfDictsExpandInput``.
+    Their values may be literal lists/tuples or ``SchedulerXComArg`` objects, which
+    are resolved through ``_resolve_xcom_arg_value``.
+
+    :param expand_input: Expand input of the mapped task group.
+    :param map_index: Map index of the running task instance.
+    :param run_id: Run whose XComs back any XComArg values.
+    :param session: Database session passed through to XCom resolution.
+    :return: The expanded arguments for ``map_index``. Returns ``None`` if nothing
+        could be resolved or the expand input type is not handled.
+    """
+    from airflow.models.expandinput import SchedulerDictOfListsExpandInput, 
SchedulerListOfDictsExpandInput
+    from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
+
+    # NOTE(review): a negative map_index (e.g. -1) satisfies every
+    # ``map_index < len(...)`` check below and would select the last element.
+    # Presumably callers only pass map_index >= 0; confirm or add a guard.
+    if isinstance(expand_input, SchedulerDictOfListsExpandInput):
+        # NOTE(review): every mapped value is indexed with the same map_index.
+        # Dict-of-lists expansion may produce the cross product of all keys, as
+        # ``expand()`` with several keyword arguments does for tasks. If so, this
+        # is only correct when a single key is mapped. Verify.
+        resolved: dict[str, Any] = {}
+        for key, value in expand_input.value.items():
+            if isinstance(value, SchedulerXComArg):
+                # Only a list-typed XCom result is indexed; anything else skips the key.
+                xcom_result = _resolve_xcom_arg_value(value, run_id, session)
+                if isinstance(xcom_result, list) and map_index < 
len(xcom_result):
+                    resolved[key] = xcom_result[map_index]
+            elif isinstance(value, (list, tuple)):
+                # An out-of-range index silently skips the key.
+                if map_index < len(value):
+                    resolved[key] = value[map_index]
+        # An empty result is reported as "nothing resolved".
+        return resolved if resolved else None
+
+    if isinstance(expand_input, SchedulerListOfDictsExpandInput):
+        # The whole value is a sequence of kwargs entries, either literal or from
+        # XCom. Only an entry that is a dict is returned.
+        if isinstance(expand_input.value, (list, tuple)):
+            if map_index < len(expand_input.value):
+                item = expand_input.value[map_index]
+                if isinstance(item, dict):
+                    return item
+        elif isinstance(expand_input.value, SchedulerXComArg):
+            xcom_result = _resolve_xcom_arg_value(expand_input.value, run_id, 
session)
+            if isinstance(xcom_result, list) and map_index < len(xcom_result):
+                item = xcom_result[map_index]
+                if isinstance(item, dict):
+                    return item
+
+    return None

Review Comment:
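   How about this? It folds the duplicated "resolve list/tuple or XComArg, then index" logic into a single helper. It also dispatches on the expand-input type with `match`, instead of keeping parallel `isinstance` branches.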
   
   ```python
   def _resolve_task_group_expand_args(
       expand_input: Any,
       map_index: int,
       run_id: str,
       session: SessionDep,
   ) -> dict[str, Any] | None:
       """Resolve the expand_input for a specific map_index to get the expanded 
arguments."""
       from airflow.models.expandinput import SchedulerDictOfListsExpandInput, 
SchedulerListOfDictsExpandInput
       from airflow.serialization.definitions.xcom_arg import SchedulerXComArg
   
       def _resolve_at_index(value: Any) -> Any | None:
           """Resolve a single value (list/tuple or XComArg) at the given 
map_index."""
           match value:
               case SchedulerXComArg():
                   value = _resolve_xcom_arg_value(value, run_id, session)
               case list() | tuple():
                   pass
               case _:
                   return None
           if isinstance(value, (list, tuple)) and map_index < len(value):
               return value[map_index]
           return None
   
       match expand_input:
           case SchedulerDictOfListsExpandInput(value=mapping):
               resolved = {}
               for key, val in mapping.items():
                   if (item := _resolve_at_index(val)) is not None:
                       resolved[key] = item
               return resolved or None
   
           case SchedulerListOfDictsExpandInput(value=val):
               item = _resolve_at_index(val)
               if isinstance(item, dict):
                   return item
   
       return None
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to