This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new d3dc592ac3a Render custom `map_index_template` on task completion (#49809)
d3dc592ac3a is described below

commit d3dc592ac3a663250aa0ef0a5aa9ac61c26894dd
Author: Purna Chander Dharam <[email protected]>
AuthorDate: Mon May 5 14:25:17 2025 +0530

    Render custom `map_index_template` on task completion (#49809)
---
 .../execution_api/datamodels/taskinstance.py       |  4 ++
 .../api_fastapi/execution_api/versions/__init__.py |  3 ++
 .../versions/{__init__.py => v2025_04_28.py}       | 23 ++++++++--
 .../unit/api_fastapi/execution_api/test_app.py     |  2 +-
 .../versions/head/test_task_instances.py           | 34 +++++++++++++++
 task-sdk/src/airflow/sdk/api/client.py             | 19 +++++---
 .../src/airflow/sdk/api/datamodels/_generated.py   |  6 ++-
 task-sdk/src/airflow/sdk/execution_time/comms.py   |  1 +
 .../src/airflow/sdk/execution_time/supervisor.py   | 12 ++++-
 .../src/airflow/sdk/execution_time/task_runner.py  | 51 +++++++++++++++++++---
 task-sdk/tests/task_sdk/api/test_client.py         | 10 ++++-
 .../task_sdk/execution_time/test_supervisor.py     | 15 +++++--
 .../task_sdk/execution_time/test_task_runner.py    | 31 +++++++++++++
 13 files changed, 188 insertions(+), 23 deletions(-)

diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index cd8287be97b..c3a8be3134c 100644
--- 
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++ 
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -75,6 +75,7 @@ class TITerminalStatePayload(StrictBaseModel):
 
     end_date: UtcDateTime
     """When the task completed executing"""
+    rendered_map_index: str | None = None
 
 
 class TISuccessStatePayload(StrictBaseModel):
@@ -97,6 +98,7 @@ class TISuccessStatePayload(StrictBaseModel):
 
     task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
     outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)]
+    rendered_map_index: str | None = None
 
 
 class TITargetStatePayload(StrictBaseModel):
@@ -136,6 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel):
 
     Both forms will be passed along to the TaskSDK upon resume, the server 
will not handle either.
     """
+    rendered_map_index: str | None = None
 
 
 class TIRescheduleStatePayload(StrictBaseModel):
@@ -171,6 +174,7 @@ class TIRetryStatePayload(StrictBaseModel):
         ),
     ]
     end_date: UtcDateTime
+    rendered_map_index: str | None = None
 
 
 class TISkippedDownstreamTasksStatePayload(StrictBaseModel):
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 0d3a225305b..54329054c12 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -19,7 +19,10 @@ from __future__ import annotations
 
 from cadwyn import HeadVersion, Version, VersionBundle
 
+from airflow.api_fastapi.execution_api.versions.v2025_04_28 import 
AddRenderedMapIndexField
+
 bundle = VersionBundle(
     HeadVersion(),
+    Version("2025-04-28", AddRenderedMapIndexField),
     Version("2025-04-11"),
 )
diff --git 
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py 
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
similarity index 52%
copy from 
airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
copy to 
airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
index 0d3a225305b..e0916b4c93d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
@@ -17,9 +17,24 @@
 
 from __future__ import annotations
 
-from cadwyn import HeadVersion, Version, VersionBundle
+from cadwyn import VersionChange, schema
 
-bundle = VersionBundle(
-    HeadVersion(),
-    Version("2025-04-11"),
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+    TIDeferredStatePayload,
+    TIRetryStatePayload,
+    TISuccessStatePayload,
+    TITerminalStatePayload,
 )
+
+
+class AddRenderedMapIndexField(VersionChange):
+    """Add the `rendered_map_index` field to payload models."""
+
+    description = __doc__
+
+    instructions_to_migrate_to_previous_version = (
+        schema(TITerminalStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TISuccessStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TIDeferredStatePayload).field("rendered_map_index").didnt_exist,
+        schema(TIRetryStatePayload).field("rendered_map_index").didnt_exist,
+    )
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py 
b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index adaa93bc202..32c53ae0db9 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -26,7 +26,7 @@ pytestmark = pytest.mark.db_test
 
 def test_custom_openapi_includes_extra_schemas(client):
     """Test to ensure that extra schemas are correctly included in the OpenAPI 
schema."""
-    response = client.get("/execution/openapi.json?version=2025-04-11")
+    response = client.get("/execution/openapi.json?version=2025-04-28")
     assert response.status_code == 200
 
     openapi_schema = response.json()
diff --git 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 8af11c04df9..41bba2c2bd5 100644
--- 
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++ 
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -44,6 +44,7 @@ pytestmark = pytest.mark.db_test
 
 DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
 DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
+DEFAULT_RENDERED_MAP_INDEX = "test rendered map index"
 
 
 def _create_asset_aliases(session, num: int = 2) -> None:
@@ -465,6 +466,39 @@ class TestTIUpdateState:
         assert ti.state == expected_state
         assert ti.end_date == end_date
 
+    @pytest.mark.parametrize(
+        ("state", "end_date", "expected_state", "rendered_map_index"),
+        [
+            (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS, 
DEFAULT_RENDERED_MAP_INDEX),
+            (State.FAILED, DEFAULT_END_DATE, State.FAILED, 
DEFAULT_RENDERED_MAP_INDEX),
+            (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED, 
DEFAULT_RENDERED_MAP_INDEX),
+        ],
+    )
+    def test_ti_update_state_to_terminal_with_rendered_map_index(
+        self, client, session, create_task_instance, state, end_date, 
expected_state, rendered_map_index
+    ):
+        ti = create_task_instance(
+            task_id="test_ti_update_state_to_terminal_with_rendered_map_index",
+            start_date=DEFAULT_START_DATE,
+            state=State.RUNNING,
+        )
+        session.commit()
+
+        response = client.patch(
+            f"/execution/task-instances/{ti.id}/state",
+            json={"state": state, "end_date": end_date.isoformat(), 
"rendered_map_index": rendered_map_index},
+        )
+
+        assert response.status_code == 204
+        assert response.text == ""
+
+        session.expire_all()
+
+        ti = session.get(TaskInstance, ti.id)
+        assert ti.state == expected_state
+        assert ti.end_date == end_date
+        assert ti.rendered_map_index == rendered_map_index
+
     @pytest.mark.parametrize(
         "task_outlets",
         [
diff --git a/task-sdk/src/airflow/sdk/api/client.py 
b/task-sdk/src/airflow/sdk/api/client.py
index 46e4fe6907a..5f76e8360be 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -146,22 +146,29 @@ class TaskInstanceOperations:
         resp = self.client.patch(f"task-instances/{id}/run", 
content=body.model_dump_json())
         return TIRunContext.model_validate_json(resp.read())
 
-    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: 
datetime):
+    def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when: 
datetime, rendered_map_index):
         """Tell the API server that this TI has reached a terminal state."""
         if state == TaskInstanceState.SUCCESS:
             raise ValueError("Logic error. SUCCESS state should call the 
`succeed` function instead")
         # TODO: handle the naming better. finish sounds wrong as "even" 
deferred is essentially finishing.
-        body = TITerminalStatePayload(end_date=when, 
state=TerminalStateNonSuccess(state))
+        body = TITerminalStatePayload(
+            end_date=when, state=TerminalStateNonSuccess(state), 
rendered_map_index=rendered_map_index
+        )
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
-    def retry(self, id: uuid.UUID, end_date: datetime):
+    def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
         """Tell the API server that this TI has failed and reached a 
up_for_retry state."""
-        body = TIRetryStatePayload(end_date=end_date)
+        body = TIRetryStatePayload(end_date=end_date, 
rendered_map_index=rendered_map_index)
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
-    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, 
outlet_events):
+    def succeed(self, id: uuid.UUID, when: datetime, task_outlets, 
outlet_events, rendered_map_index):
         """Tell the API server that this TI has succeeded."""
-        body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets, 
outlet_events=outlet_events)
+        body = TISuccessStatePayload(
+            end_date=when,
+            task_outlets=task_outlets,
+            outlet_events=outlet_events,
+            rendered_map_index=rendered_map_index,
+        )
         self.client.patch(f"task-instances/{id}/state", 
content=body.model_dump_json())
 
     def heartbeat(self, id: uuid.UUID, pid: int):
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py 
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 7c7647635ea..a477c3ffc41 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@ from uuid import UUID
 
 from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue
 
-API_VERSION: Final[str] = "2025-04-11"
+API_VERSION: Final[str] = "2025-04-28"
 
 
 class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -193,6 +193,7 @@ class TIDeferredStatePayload(BaseModel):
     trigger_timeout: Annotated[timedelta | None, Field(title="Trigger 
Timeout")] = None
     next_method: Annotated[str, Field(title="Next Method")]
     next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next 
Kwargs")] = None
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map 
Index")] = None
 
 
 class TIEnterRunningPayload(BaseModel):
@@ -245,6 +246,7 @@ class TIRetryStatePayload(BaseModel):
     )
     state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] = 
"up_for_retry"
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map 
Index")] = None
 
 
 class TISkippedDownstreamTasksStatePayload(BaseModel):
@@ -270,6 +272,7 @@ class TISuccessStatePayload(BaseModel):
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
     task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task 
Outlets")] = None
     outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet 
Events")] = None
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map 
Index")] = None
 
 
 class TITargetStatePayload(BaseModel):
@@ -494,3 +497,4 @@ class TITerminalStatePayload(BaseModel):
     )
     state: TerminalStateNonSuccess
     end_date: Annotated[AwareDatetime, Field(title="End Date")]
+    rendered_map_index: Annotated[str | None, Field(title="Rendered Map 
Index")] = None
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py 
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 699fe5c0370..a25ba574582 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -376,6 +376,7 @@ class TaskState(BaseModel):
     ]
     end_date: datetime | None = None
     type: Literal["TaskState"] = "TaskState"
+    rendered_map_index: str | None = None
 
 
 class SucceedTask(TISuccessStatePayload):
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py 
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 001f8deb701..b5cf977488b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -767,6 +767,7 @@ class ActivitySubprocess(WatchedSubprocess):
     # TODO: This should come from airflow.cfg: [core] task_success_overtime
     TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0
     _task_end_time_monotonic: float | None = attrs.field(default=None, 
init=False)
+    _rendered_map_index: str | None = attrs.field(default=None, init=False)
 
     decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)
 
@@ -842,7 +843,10 @@ class ActivitySubprocess(WatchedSubprocess):
         # by the subprocess in the `handle_requests` method.
         if self.final_state not in STATES_SENT_DIRECTLY:
             self.client.task_instances.finish(
-                id=self.id, state=self.final_state, 
when=datetime.now(tz=timezone.utc)
+                id=self.id,
+                state=self.final_state,
+                when=datetime.now(tz=timezone.utc),
+                rendered_map_index=self._rendered_map_index,
             )
 
         # Now at the last possible moment, when all logs and comms with the 
subprocess has finished, lets
@@ -988,21 +992,26 @@ class ActivitySubprocess(WatchedSubprocess):
         if isinstance(msg, TaskState):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
         elif isinstance(msg, SucceedTask):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.succeed(
                 id=self.id,
                 when=msg.end_date,
                 task_outlets=msg.task_outlets,
                 outlet_events=msg.outlet_events,
+                rendered_map_index=self._rendered_map_index,
             )
         elif isinstance(msg, RetryTask):
             self._terminal_state = msg.state
             self._task_end_time_monotonic = time.monotonic()
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.retry(
                 id=self.id,
                 end_date=msg.end_date,
+                rendered_map_index=self._rendered_map_index,
             )
         elif isinstance(msg, GetConnection):
             conn = self.client.connections.get(msg.conn_id)
@@ -1045,6 +1054,7 @@ class ActivitySubprocess(WatchedSubprocess):
                 resp = xcom
         elif isinstance(msg, DeferTask):
             self._terminal_state = TaskInstanceState.DEFERRED
+            self._rendered_map_index = msg.rendered_map_index
             self.client.task_instances.defer(self.id, msg)
         elif isinstance(msg, RescheduleTask):
             self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 6e1e5f885f5..9092ee86f0b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -19,6 +19,7 @@
 
 from __future__ import annotations
 
+import contextlib
 import contextvars
 import functools
 import os
@@ -134,6 +135,8 @@ class RuntimeTaskInstance(TaskInstance):
     is_mapped: bool | None = None
     """True if the original task was mapped."""
 
+    rendered_map_index: str | None = None
+
     def __rich_repr__(self):
         yield "id", self.id
         yield "task_id", self.task_id
@@ -831,7 +834,17 @@ def run(
                 ti.state = state = TaskInstanceState.FAILED
                 return state, msg, error
 
-            result = _execute_task(context, ti, log)
+            try:
+                result = _execute_task(context, ti, log)
+            except Exception:
+                import jinja2
+
+                # If the task failed, swallow rendering error so it doesn't 
mask the main error.
+                with contextlib.suppress(jinja2.TemplateSyntaxError, 
jinja2.UndefinedError):
+                    ti.rendered_map_index = _render_map_index(context, ti=ti, 
log=log)
+                raise
+            else:  # If the task succeeded, render normally to let rendering 
error bubble up.
+                ti.rendered_map_index = _render_map_index(context, ti=ti, 
log=log)
 
         _push_xcom_if_needed(result, ti, log)
 
@@ -851,6 +864,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.SKIPPED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.SKIPPED
     except AirflowRescheduleException as reschedule:
@@ -868,6 +882,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.FAILED
         error = e
@@ -884,6 +899,7 @@ def run(
         msg = TaskState(
             state=TaskInstanceState.FAILED,
             end_date=datetime.now(tz=timezone.utc),
+            rendered_map_index=ti.rendered_map_index,
         )
         state = TaskInstanceState.FAILED
         error = e
@@ -915,6 +931,7 @@ def _handle_current_task_success(
         end_date=datetime.now(tz=timezone.utc),
         task_outlets=task_outlets,
         outlet_events=outlet_events,
+        rendered_map_index=ti.rendered_map_index,
     )
     return msg, TaskInstanceState.SUCCESS
 
@@ -925,7 +942,9 @@ def _handle_current_task_failed(
     end_date = datetime.now(tz=timezone.utc)
     if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
         return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
-    return TaskState(state=TaskInstanceState.FAILED, end_date=end_date), 
TaskInstanceState.FAILED
+    return TaskState(
+        state=TaskInstanceState.FAILED, end_date=end_date, 
rendered_map_index=ti.rendered_map_index
+    ), TaskInstanceState.FAILED
 
 
 def _handle_trigger_dag_run(
@@ -951,11 +970,19 @@ def _handle_trigger_dag_run(
                 "Dag Run already exists, skipping task as 
skip_when_already_exists is set to True.",
                 dag_id=drte.trigger_dag_id,
             )
-            msg = TaskState(state=TaskInstanceState.SKIPPED, 
end_date=datetime.now(tz=timezone.utc))
+            msg = TaskState(
+                state=TaskInstanceState.SKIPPED,
+                end_date=datetime.now(tz=timezone.utc),
+                rendered_map_index=ti.rendered_map_index,
+            )
             state = TaskInstanceState.SKIPPED
         else:
             log.error("Dag Run already exists, marking task as failed.", 
dag_id=drte.trigger_dag_id)
-            msg = TaskState(state=TaskInstanceState.FAILED, 
end_date=datetime.now(tz=timezone.utc))
+            msg = TaskState(
+                state=TaskInstanceState.FAILED,
+                end_date=datetime.now(tz=timezone.utc),
+                rendered_map_index=ti.rendered_map_index,
+            )
             state = TaskInstanceState.FAILED
 
         return msg, state
@@ -1001,7 +1028,11 @@ def _handle_trigger_dag_run(
                 log.error(
                     "DagRun finished with failed state.", 
dag_id=drte.trigger_dag_id, state=comms_msg.state
                 )
-                msg = TaskState(state=TaskInstanceState.FAILED, 
end_date=datetime.now(tz=timezone.utc))
+                msg = TaskState(
+                    state=TaskInstanceState.FAILED,
+                    end_date=datetime.now(tz=timezone.utc),
+                    rendered_map_index=ti.rendered_map_index,
+                )
                 state = TaskInstanceState.FAILED
                 return msg, state
             if comms_msg.state in drte.allowed_states:
@@ -1106,6 +1137,16 @@ def _execute_task(context: Context, ti: 
RuntimeTaskInstance, log: Logger):
     return result
 
 
+def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger) 
-> str | None:
+    """Render named map index if the DAG author defined map_index_template at 
the task level."""
+    if (template := context.get("map_index_template")) is None:
+        return None
+    jinja_env = ti.task.dag.get_template_env()
+    rendered_map_index = jinja_env.from_string(template).render(context)
+    log.info("Map index rendered as %s", rendered_map_index)
+    return rendered_map_index
+
+
 def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
     """Push XCom values when task has ``do_xcom_push`` set to ``True`` and the 
task returns a result."""
     if ti.task.do_xcom_push:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py 
b/task-sdk/tests/task_sdk/api/test_client.py
index e1f678bb1f7..ce532919749 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -295,13 +295,16 @@ class TestTaskInstanceOperations:
                 actual_body = json.loads(request.read())
                 assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
                 assert actual_body["state"] == state
+                assert actual_body["rendered_map_index"] == "test"
                 return httpx.Response(
                     status_code=204,
                 )
             return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
 
         client = make_client(transport=httpx.MockTransport(handle_request))
-        client.task_instances.finish(ti_id, state=state, 
when="2024-10-31T12:00:00Z")
+        client.task_instances.finish(
+            ti_id, state=state, when="2024-10-31T12:00:00Z", 
rendered_map_index="test"
+        )
 
     def test_task_instance_heartbeat(self):
         # Simulate a successful response from the server that sends a 
heartbeat for a ti
@@ -383,13 +386,16 @@ class TestTaskInstanceOperations:
                 actual_body = json.loads(request.read())
                 assert actual_body["state"] == "up_for_retry"
                 assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
+                assert actual_body["rendered_map_index"] == "test"
                 return httpx.Response(
                     status_code=204,
                 )
             return httpx.Response(status_code=400, json={"detail": "Bad 
Request"})
 
         client = make_client(transport=httpx.MockTransport(handle_request))
-        client.task_instances.retry(ti_id, 
end_date=timezone.parse("2024-10-31T12:00:00Z"))
+        client.task_instances.retry(
+            ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test"
+        )
 
     @pytest.mark.parametrize(
         "rendered_fields",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py 
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0795e1155c9..5690f9b418e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -1193,11 +1193,17 @@ class TestHandleRequest:
                 id="patch_task_instance_to_skipped",
             ),
             pytest.param(
-                RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                RetryTask(
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test retry task"
+                ),
                 b"",
                 "task_instances.retry",
                 (),
-                {"id": TI_ID, "end_date": 
timezone.parse("2024-10-31T12:00:00Z")},
+                {
+                    "id": TI_ID,
+                    "end_date": timezone.parse("2024-10-31T12:00:00Z"),
+                    "rendered_map_index": "test retry task",
+                },
                 "",
                 id="up_for_retry",
             ),
@@ -1317,7 +1323,9 @@ class TestHandleRequest:
                 id="get_asset_events_by_asset_alias",
             ),
             pytest.param(
-                SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+                SucceedTask(
+                    end_date=timezone.parse("2024-10-31T12:00:00Z"), 
rendered_map_index="test success task"
+                ),
                 b"",
                 "task_instances.succeed",
                 (),
@@ -1326,6 +1334,7 @@ class TestHandleRequest:
                     "outlet_events": None,
                     "task_outlets": None,
                     "when": timezone.parse("2024-10-31T12:00:00Z"),
+                    "rendered_map_index": "test success task",
                 },
                 "",
                 id="succeed_task",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f9ad46f357e..8df65b88ebf 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -970,6 +970,37 @@ def test_execute_task_exports_env_vars(
     assert os.environ["AIRFLOW_CTX_TASK_ID"] == "test_env_task"
 
 
+def test_execute_success_task_with_rendered_map_index(create_runtime_ti, 
mock_supervisor_comms):
+    """Test that the map index is rendered in the task context."""
+
+    def test_function():
+        return "test function"
+
+    task = PythonOperator(
+        task_id="test_task",
+        python_callable=test_function,
+        map_index_template="Hello! {{ run_id }}",
+    )
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+    run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+    assert ti.rendered_map_index == "Hello! test_run"
+
+
+def test_execute_failed_task_with_rendered_map_index(create_runtime_ti, 
mock_supervisor_comms):
+    """Test that the map index is rendered in the task context."""
+
+    task = BaseOperator(task_id="test_task", map_index_template="Hello! {{ 
run_id }}")
+
+    ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+    run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+    assert ti.rendered_map_index == "Hello! test_run"
+
+
 class TestRuntimeTaskInstance:
     def test_get_context_without_ti_context_from_server(self, mocked_parse, 
make_ti_context):
         """Test get_template_context without ti_context_from_server."""

Reply via email to