This is an automated email from the ASF dual-hosted git repository.
uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new d3dc592ac3a Render custom `map_index_template` on task completion
(#49809)
d3dc592ac3a is described below
commit d3dc592ac3a663250aa0ef0a5aa9ac61c26894dd
Author: Purna Chander Dharam <[email protected]>
AuthorDate: Mon May 5 14:25:17 2025 +0530
Render custom `map_index_template` on task completion (#49809)
---
.../execution_api/datamodels/taskinstance.py | 4 ++
.../api_fastapi/execution_api/versions/__init__.py | 3 ++
.../versions/{__init__.py => v2025_04_28.py} | 23 ++++++++--
.../unit/api_fastapi/execution_api/test_app.py | 2 +-
.../versions/head/test_task_instances.py | 34 +++++++++++++++
task-sdk/src/airflow/sdk/api/client.py | 19 +++++---
.../src/airflow/sdk/api/datamodels/_generated.py | 6 ++-
task-sdk/src/airflow/sdk/execution_time/comms.py | 1 +
.../src/airflow/sdk/execution_time/supervisor.py | 12 ++++-
.../src/airflow/sdk/execution_time/task_runner.py | 51 +++++++++++++++++++---
task-sdk/tests/task_sdk/api/test_client.py | 10 ++++-
.../task_sdk/execution_time/test_supervisor.py | 15 +++++--
.../task_sdk/execution_time/test_task_runner.py | 31 +++++++++++++
13 files changed, 188 insertions(+), 23 deletions(-)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
index cd8287be97b..c3a8be3134c 100644
---
a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
+++
b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py
@@ -75,6 +75,7 @@ class TITerminalStatePayload(StrictBaseModel):
end_date: UtcDateTime
"""When the task completed executing"""
+ rendered_map_index: str | None = None
class TISuccessStatePayload(StrictBaseModel):
@@ -97,6 +98,7 @@ class TISuccessStatePayload(StrictBaseModel):
task_outlets: Annotated[list[AssetProfile], Field(default_factory=list)]
outlet_events: Annotated[list[dict[str, Any]], Field(default_factory=list)]
+ rendered_map_index: str | None = None
class TITargetStatePayload(StrictBaseModel):
@@ -136,6 +138,7 @@ class TIDeferredStatePayload(StrictBaseModel):
Both forms will be passed along to the TaskSDK upon resume, the server
will not handle either.
"""
+ rendered_map_index: str | None = None
class TIRescheduleStatePayload(StrictBaseModel):
@@ -171,6 +174,7 @@ class TIRetryStatePayload(StrictBaseModel):
),
]
end_date: UtcDateTime
+ rendered_map_index: str | None = None
class TISkippedDownstreamTasksStatePayload(StrictBaseModel):
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
index 0d3a225305b..54329054c12 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
@@ -19,7 +19,10 @@ from __future__ import annotations
from cadwyn import HeadVersion, Version, VersionBundle
+from airflow.api_fastapi.execution_api.versions.v2025_04_28 import
AddRenderedMapIndexField
+
bundle = VersionBundle(
HeadVersion(),
+ Version("2025-04-28", AddRenderedMapIndexField),
Version("2025-04-11"),
)
diff --git
a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
similarity index 52%
copy from
airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
copy to
airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
index 0d3a225305b..e0916b4c93d 100644
--- a/airflow-core/src/airflow/api_fastapi/execution_api/versions/__init__.py
+++ b/airflow-core/src/airflow/api_fastapi/execution_api/versions/v2025_04_28.py
@@ -17,9 +17,24 @@
from __future__ import annotations
-from cadwyn import HeadVersion, Version, VersionBundle
+from cadwyn import VersionChange, schema
-bundle = VersionBundle(
- HeadVersion(),
- Version("2025-04-11"),
+from airflow.api_fastapi.execution_api.datamodels.taskinstance import (
+ TIDeferredStatePayload,
+ TIRetryStatePayload,
+ TISuccessStatePayload,
+ TITerminalStatePayload,
)
+
+
+class AddRenderedMapIndexField(VersionChange):
+ """Add the `rendered_map_index` field to payload models."""
+
+ description = __doc__
+
+ instructions_to_migrate_to_previous_version = (
+ schema(TITerminalStatePayload).field("rendered_map_index").didnt_exist,
+ schema(TISuccessStatePayload).field("rendered_map_index").didnt_exist,
+ schema(TIDeferredStatePayload).field("rendered_map_index").didnt_exist,
+ schema(TIRetryStatePayload).field("rendered_map_index").didnt_exist,
+ )
diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
index adaa93bc202..32c53ae0db9 100644
--- a/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
+++ b/airflow-core/tests/unit/api_fastapi/execution_api/test_app.py
@@ -26,7 +26,7 @@ pytestmark = pytest.mark.db_test
def test_custom_openapi_includes_extra_schemas(client):
"""Test to ensure that extra schemas are correctly included in the OpenAPI
schema."""
- response = client.get("/execution/openapi.json?version=2025-04-11")
+ response = client.get("/execution/openapi.json?version=2025-04-28")
assert response.status_code == 200
openapi_schema = response.json()
diff --git
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
index 8af11c04df9..41bba2c2bd5 100644
---
a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
+++
b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py
@@ -44,6 +44,7 @@ pytestmark = pytest.mark.db_test
DEFAULT_START_DATE = timezone.parse("2024-10-31T11:00:00Z")
DEFAULT_END_DATE = timezone.parse("2024-10-31T12:00:00Z")
+DEFAULT_RENDERED_MAP_INDEX = "test rendered map index"
def _create_asset_aliases(session, num: int = 2) -> None:
@@ -465,6 +466,39 @@ class TestTIUpdateState:
assert ti.state == expected_state
assert ti.end_date == end_date
+ @pytest.mark.parametrize(
+ ("state", "end_date", "expected_state", "rendered_map_index"),
+ [
+ (State.SUCCESS, DEFAULT_END_DATE, State.SUCCESS,
DEFAULT_RENDERED_MAP_INDEX),
+ (State.FAILED, DEFAULT_END_DATE, State.FAILED,
DEFAULT_RENDERED_MAP_INDEX),
+ (State.SKIPPED, DEFAULT_END_DATE, State.SKIPPED,
DEFAULT_RENDERED_MAP_INDEX),
+ ],
+ )
+ def test_ti_update_state_to_terminal_with_rendered_map_index(
+ self, client, session, create_task_instance, state, end_date,
expected_state, rendered_map_index
+ ):
+ ti = create_task_instance(
+ task_id="test_ti_update_state_to_terminal_with_rendered_map_index",
+ start_date=DEFAULT_START_DATE,
+ state=State.RUNNING,
+ )
+ session.commit()
+
+ response = client.patch(
+ f"/execution/task-instances/{ti.id}/state",
+ json={"state": state, "end_date": end_date.isoformat(),
"rendered_map_index": rendered_map_index},
+ )
+
+ assert response.status_code == 204
+ assert response.text == ""
+
+ session.expire_all()
+
+ ti = session.get(TaskInstance, ti.id)
+ assert ti.state == expected_state
+ assert ti.end_date == end_date
+ assert ti.rendered_map_index == rendered_map_index
+
@pytest.mark.parametrize(
"task_outlets",
[
diff --git a/task-sdk/src/airflow/sdk/api/client.py
b/task-sdk/src/airflow/sdk/api/client.py
index 46e4fe6907a..5f76e8360be 100644
--- a/task-sdk/src/airflow/sdk/api/client.py
+++ b/task-sdk/src/airflow/sdk/api/client.py
@@ -146,22 +146,29 @@ class TaskInstanceOperations:
resp = self.client.patch(f"task-instances/{id}/run",
content=body.model_dump_json())
return TIRunContext.model_validate_json(resp.read())
- def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when:
datetime):
+ def finish(self, id: uuid.UUID, state: TerminalStateNonSuccess, when:
datetime, rendered_map_index):
"""Tell the API server that this TI has reached a terminal state."""
if state == TaskInstanceState.SUCCESS:
raise ValueError("Logic error. SUCCESS state should call the
`succeed` function instead")
# TODO: handle the naming better. finish sounds wrong as "even"
deferred is essentially finishing.
- body = TITerminalStatePayload(end_date=when,
state=TerminalStateNonSuccess(state))
+ body = TITerminalStatePayload(
+ end_date=when, state=TerminalStateNonSuccess(state),
rendered_map_index=rendered_map_index
+ )
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
- def retry(self, id: uuid.UUID, end_date: datetime):
+ def retry(self, id: uuid.UUID, end_date: datetime, rendered_map_index):
"""Tell the API server that this TI has failed and reached a
up_for_retry state."""
- body = TIRetryStatePayload(end_date=end_date)
+ body = TIRetryStatePayload(end_date=end_date,
rendered_map_index=rendered_map_index)
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
- def succeed(self, id: uuid.UUID, when: datetime, task_outlets,
outlet_events):
+ def succeed(self, id: uuid.UUID, when: datetime, task_outlets,
outlet_events, rendered_map_index):
"""Tell the API server that this TI has succeeded."""
- body = TISuccessStatePayload(end_date=when, task_outlets=task_outlets,
outlet_events=outlet_events)
+ body = TISuccessStatePayload(
+ end_date=when,
+ task_outlets=task_outlets,
+ outlet_events=outlet_events,
+ rendered_map_index=rendered_map_index,
+ )
self.client.patch(f"task-instances/{id}/state",
content=body.model_dump_json())
def heartbeat(self, id: uuid.UUID, pid: int):
diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
index 7c7647635ea..a477c3ffc41 100644
--- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
+++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py
@@ -27,7 +27,7 @@ from uuid import UUID
from pydantic import AwareDatetime, BaseModel, ConfigDict, Field, JsonValue
-API_VERSION: Final[str] = "2025-04-11"
+API_VERSION: Final[str] = "2025-04-28"
class AssetAliasReferenceAssetEventDagRun(BaseModel):
@@ -193,6 +193,7 @@ class TIDeferredStatePayload(BaseModel):
trigger_timeout: Annotated[timedelta | None, Field(title="Trigger
Timeout")] = None
next_method: Annotated[str, Field(title="Next Method")]
next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next
Kwargs")] = None
+ rendered_map_index: Annotated[str | None, Field(title="Rendered Map
Index")] = None
class TIEnterRunningPayload(BaseModel):
@@ -245,6 +246,7 @@ class TIRetryStatePayload(BaseModel):
)
state: Annotated[Literal["up_for_retry"] | None, Field(title="State")] =
"up_for_retry"
end_date: Annotated[AwareDatetime, Field(title="End Date")]
+ rendered_map_index: Annotated[str | None, Field(title="Rendered Map
Index")] = None
class TISkippedDownstreamTasksStatePayload(BaseModel):
@@ -270,6 +272,7 @@ class TISuccessStatePayload(BaseModel):
end_date: Annotated[AwareDatetime, Field(title="End Date")]
task_outlets: Annotated[list[AssetProfile] | None, Field(title="Task
Outlets")] = None
outlet_events: Annotated[list[dict[str, Any]] | None, Field(title="Outlet
Events")] = None
+ rendered_map_index: Annotated[str | None, Field(title="Rendered Map
Index")] = None
class TITargetStatePayload(BaseModel):
@@ -494,3 +497,4 @@ class TITerminalStatePayload(BaseModel):
)
state: TerminalStateNonSuccess
end_date: Annotated[AwareDatetime, Field(title="End Date")]
+ rendered_map_index: Annotated[str | None, Field(title="Rendered Map
Index")] = None
diff --git a/task-sdk/src/airflow/sdk/execution_time/comms.py
b/task-sdk/src/airflow/sdk/execution_time/comms.py
index 699fe5c0370..a25ba574582 100644
--- a/task-sdk/src/airflow/sdk/execution_time/comms.py
+++ b/task-sdk/src/airflow/sdk/execution_time/comms.py
@@ -376,6 +376,7 @@ class TaskState(BaseModel):
]
end_date: datetime | None = None
type: Literal["TaskState"] = "TaskState"
+ rendered_map_index: str | None = None
class SucceedTask(TISuccessStatePayload):
diff --git a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
index 001f8deb701..b5cf977488b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/supervisor.py
+++ b/task-sdk/src/airflow/sdk/execution_time/supervisor.py
@@ -767,6 +767,7 @@ class ActivitySubprocess(WatchedSubprocess):
# TODO: This should come from airflow.cfg: [core] task_success_overtime
TASK_OVERTIME_THRESHOLD: ClassVar[float] = 20.0
_task_end_time_monotonic: float | None = attrs.field(default=None,
init=False)
+ _rendered_map_index: str | None = attrs.field(default=None, init=False)
decoder: ClassVar[TypeAdapter[ToSupervisor]] = TypeAdapter(ToSupervisor)
@@ -842,7 +843,10 @@ class ActivitySubprocess(WatchedSubprocess):
# by the subprocess in the `handle_requests` method.
if self.final_state not in STATES_SENT_DIRECTLY:
self.client.task_instances.finish(
- id=self.id, state=self.final_state,
when=datetime.now(tz=timezone.utc)
+ id=self.id,
+ state=self.final_state,
+ when=datetime.now(tz=timezone.utc),
+ rendered_map_index=self._rendered_map_index,
)
# Now at the last possible moment, when all logs and comms with the
subprocess has finished, lets
@@ -988,21 +992,26 @@ class ActivitySubprocess(WatchedSubprocess):
if isinstance(msg, TaskState):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
+ self._rendered_map_index = msg.rendered_map_index
elif isinstance(msg, SucceedTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
+ self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.succeed(
id=self.id,
when=msg.end_date,
task_outlets=msg.task_outlets,
outlet_events=msg.outlet_events,
+ rendered_map_index=self._rendered_map_index,
)
elif isinstance(msg, RetryTask):
self._terminal_state = msg.state
self._task_end_time_monotonic = time.monotonic()
+ self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.retry(
id=self.id,
end_date=msg.end_date,
+ rendered_map_index=self._rendered_map_index,
)
elif isinstance(msg, GetConnection):
conn = self.client.connections.get(msg.conn_id)
@@ -1045,6 +1054,7 @@ class ActivitySubprocess(WatchedSubprocess):
resp = xcom
elif isinstance(msg, DeferTask):
self._terminal_state = TaskInstanceState.DEFERRED
+ self._rendered_map_index = msg.rendered_map_index
self.client.task_instances.defer(self.id, msg)
elif isinstance(msg, RescheduleTask):
self._terminal_state = TaskInstanceState.UP_FOR_RESCHEDULE
diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 6e1e5f885f5..9092ee86f0b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -19,6 +19,7 @@
from __future__ import annotations
+import contextlib
import contextvars
import functools
import os
@@ -134,6 +135,8 @@ class RuntimeTaskInstance(TaskInstance):
is_mapped: bool | None = None
"""True if the original task was mapped."""
+ rendered_map_index: str | None = None
+
def __rich_repr__(self):
yield "id", self.id
yield "task_id", self.task_id
@@ -831,7 +834,17 @@ def run(
ti.state = state = TaskInstanceState.FAILED
return state, msg, error
- result = _execute_task(context, ti, log)
+ try:
+ result = _execute_task(context, ti, log)
+ except Exception:
+ import jinja2
+
+ # If the task failed, swallow rendering error so it doesn't
mask the main error.
+ with contextlib.suppress(jinja2.TemplateSyntaxError,
jinja2.UndefinedError):
+ ti.rendered_map_index = _render_map_index(context, ti=ti,
log=log)
+ raise
+ else: # If the task succeeded, render normally to let rendering
error bubble up.
+ ti.rendered_map_index = _render_map_index(context, ti=ti,
log=log)
_push_xcom_if_needed(result, ti, log)
@@ -851,6 +864,7 @@ def run(
msg = TaskState(
state=TaskInstanceState.SKIPPED,
end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.SKIPPED
except AirflowRescheduleException as reschedule:
@@ -868,6 +882,7 @@ def run(
msg = TaskState(
state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED
error = e
@@ -884,6 +899,7 @@ def run(
msg = TaskState(
state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
)
state = TaskInstanceState.FAILED
error = e
@@ -915,6 +931,7 @@ def _handle_current_task_success(
end_date=datetime.now(tz=timezone.utc),
task_outlets=task_outlets,
outlet_events=outlet_events,
+ rendered_map_index=ti.rendered_map_index,
)
return msg, TaskInstanceState.SUCCESS
@@ -925,7 +942,9 @@ def _handle_current_task_failed(
end_date = datetime.now(tz=timezone.utc)
if ti._ti_context_from_server and ti._ti_context_from_server.should_retry:
return RetryTask(end_date=end_date), TaskInstanceState.UP_FOR_RETRY
- return TaskState(state=TaskInstanceState.FAILED, end_date=end_date),
TaskInstanceState.FAILED
+ return TaskState(
+ state=TaskInstanceState.FAILED, end_date=end_date,
rendered_map_index=ti.rendered_map_index
+ ), TaskInstanceState.FAILED
def _handle_trigger_dag_run(
@@ -951,11 +970,19 @@ def _handle_trigger_dag_run(
"Dag Run already exists, skipping task as
skip_when_already_exists is set to True.",
dag_id=drte.trigger_dag_id,
)
- msg = TaskState(state=TaskInstanceState.SKIPPED,
end_date=datetime.now(tz=timezone.utc))
+ msg = TaskState(
+ state=TaskInstanceState.SKIPPED,
+ end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
+ )
state = TaskInstanceState.SKIPPED
else:
log.error("Dag Run already exists, marking task as failed.",
dag_id=drte.trigger_dag_id)
- msg = TaskState(state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc))
+ msg = TaskState(
+ state=TaskInstanceState.FAILED,
+ end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
+ )
state = TaskInstanceState.FAILED
return msg, state
@@ -1001,7 +1028,11 @@ def _handle_trigger_dag_run(
log.error(
"DagRun finished with failed state.",
dag_id=drte.trigger_dag_id, state=comms_msg.state
)
- msg = TaskState(state=TaskInstanceState.FAILED,
end_date=datetime.now(tz=timezone.utc))
+ msg = TaskState(
+ state=TaskInstanceState.FAILED,
+ end_date=datetime.now(tz=timezone.utc),
+ rendered_map_index=ti.rendered_map_index,
+ )
state = TaskInstanceState.FAILED
return msg, state
if comms_msg.state in drte.allowed_states:
@@ -1106,6 +1137,16 @@ def _execute_task(context: Context, ti:
RuntimeTaskInstance, log: Logger):
return result
+def _render_map_index(context: Context, ti: RuntimeTaskInstance, log: Logger)
-> str | None:
+ """Render named map index if the DAG author defined map_index_template at
the task level."""
+ if (template := context.get("map_index_template")) is None:
+ return None
+ jinja_env = ti.task.dag.get_template_env()
+ rendered_map_index = jinja_env.from_string(template).render(context)
+ log.info("Map index rendered as %s", rendered_map_index)
+ return rendered_map_index
+
+
def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
"""Push XCom values when task has ``do_xcom_push`` set to ``True`` and the
task returns a result."""
if ti.task.do_xcom_push:
diff --git a/task-sdk/tests/task_sdk/api/test_client.py
b/task-sdk/tests/task_sdk/api/test_client.py
index e1f678bb1f7..ce532919749 100644
--- a/task-sdk/tests/task_sdk/api/test_client.py
+++ b/task-sdk/tests/task_sdk/api/test_client.py
@@ -295,13 +295,16 @@ class TestTaskInstanceOperations:
actual_body = json.loads(request.read())
assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
assert actual_body["state"] == state
+ assert actual_body["rendered_map_index"] == "test"
return httpx.Response(
status_code=204,
)
return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
client = make_client(transport=httpx.MockTransport(handle_request))
- client.task_instances.finish(ti_id, state=state,
when="2024-10-31T12:00:00Z")
+ client.task_instances.finish(
+ ti_id, state=state, when="2024-10-31T12:00:00Z",
rendered_map_index="test"
+ )
def test_task_instance_heartbeat(self):
# Simulate a successful response from the server that sends a
heartbeat for a ti
@@ -383,13 +386,16 @@ class TestTaskInstanceOperations:
actual_body = json.loads(request.read())
assert actual_body["state"] == "up_for_retry"
assert actual_body["end_date"] == "2024-10-31T12:00:00Z"
+ assert actual_body["rendered_map_index"] == "test"
return httpx.Response(
status_code=204,
)
return httpx.Response(status_code=400, json={"detail": "Bad
Request"})
client = make_client(transport=httpx.MockTransport(handle_request))
- client.task_instances.retry(ti_id,
end_date=timezone.parse("2024-10-31T12:00:00Z"))
+ client.task_instances.retry(
+ ti_id, end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test"
+ )
@pytest.mark.parametrize(
"rendered_fields",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
index 0795e1155c9..5690f9b418e 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_supervisor.py
@@ -1193,11 +1193,17 @@ class TestHandleRequest:
id="patch_task_instance_to_skipped",
),
pytest.param(
- RetryTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ RetryTask(
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test retry task"
+ ),
b"",
"task_instances.retry",
(),
- {"id": TI_ID, "end_date":
timezone.parse("2024-10-31T12:00:00Z")},
+ {
+ "id": TI_ID,
+ "end_date": timezone.parse("2024-10-31T12:00:00Z"),
+ "rendered_map_index": "test retry task",
+ },
"",
id="up_for_retry",
),
@@ -1317,7 +1323,9 @@ class TestHandleRequest:
id="get_asset_events_by_asset_alias",
),
pytest.param(
- SucceedTask(end_date=timezone.parse("2024-10-31T12:00:00Z")),
+ SucceedTask(
+ end_date=timezone.parse("2024-10-31T12:00:00Z"),
rendered_map_index="test success task"
+ ),
b"",
"task_instances.succeed",
(),
@@ -1326,6 +1334,7 @@ class TestHandleRequest:
"outlet_events": None,
"task_outlets": None,
"when": timezone.parse("2024-10-31T12:00:00Z"),
+ "rendered_map_index": "test success task",
},
"",
id="succeed_task",
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f9ad46f357e..8df65b88ebf 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -970,6 +970,37 @@ def test_execute_task_exports_env_vars(
assert os.environ["AIRFLOW_CTX_TASK_ID"] == "test_env_task"
+def test_execute_success_task_with_rendered_map_index(create_runtime_ti,
mock_supervisor_comms):
+ """Test that the map index is rendered in the task context."""
+
+ def test_function():
+ return "test function"
+
+ task = PythonOperator(
+ task_id="test_task",
+ python_callable=test_function,
+ map_index_template="Hello! {{ run_id }}",
+ )
+
+ ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+ run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+ assert ti.rendered_map_index == "Hello! test_run"
+
+
+def test_execute_failed_task_with_rendered_map_index(create_runtime_ti,
mock_supervisor_comms):
+ """Test that the map index is rendered in the task context."""
+
+ task = BaseOperator(task_id="test_task", map_index_template="Hello! {{
run_id }}")
+
+ ti = create_runtime_ti(task=task, dag_id="dag_with_map_index_template")
+
+ run(ti, ti.get_template_context(), log=mock.MagicMock())
+
+ assert ti.rendered_map_index == "Hello! test_run"
+
+
class TestRuntimeTaskInstance:
def test_get_context_without_ti_context_from_server(self, mocked_parse,
make_ti_context):
"""Test get_template_context without ti_context_from_server."""