This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-0-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit f001dc2d9b20ee6714164c04ce873cc65ab56848 Author: Wei Lee <[email protected]> AuthorDate: Mon May 5 22:42:19 2025 +0800 Handle MappedTaskGroup map indexes (#49996) * fix(task_group): correctly set upstream_map_indexes for mapped task group * refactor(task_instance): refactor how upstream_map_indexes is handled * test: fix broken tests * fix(task_group): fix mapping handling * refactor(task_instances): rewrite _get_upstream_map_indexes (cherry picked from commit 5adeac2c83e4c063e4599efa23575d2be60b2d13) --- .../execution_api/datamodels/taskinstance.py | 2 +- .../execution_api/routes/task_instances.py | 40 ++++++++++++++++++++-- .../versions/head/test_task_instances.py | 3 ++ devel-common/src/tests_common/pytest_plugin.py | 3 +- .../src/airflow/sdk/api/datamodels/_generated.py | 4 ++- task-sdk/src/airflow/sdk/definitions/xcom_arg.py | 23 +++++-------- .../task_sdk/definitions/test_mappedoperator.py | 15 +++++++- 7 files changed, 70 insertions(+), 20 deletions(-) diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py index c3a8be3134c..b83d731a54e 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/datamodels/taskinstance.py @@ -314,7 +314,7 @@ class TIRunContext(BaseModel): connections: Annotated[list[ConnectionResponse], Field(default_factory=list)] """Connections that can be accessed by the task instance.""" - upstream_map_indexes: dict[str, int] | None = None + upstream_map_indexes: dict[str, int | list[int] | None] | None = None next_method: str | None = None """Method to call. 
Set when task resumes from a trigger.""" diff --git a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py index f97a684e3ff..00a5ea9a5e6 100644 --- a/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow-core/src/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -19,7 +19,8 @@ from __future__ import annotations import json from collections import defaultdict -from typing import Annotated, Any +from collections.abc import Iterator +from typing import TYPE_CHECKING, Annotated, Any from uuid import UUID import structlog @@ -55,9 +56,14 @@ from airflow.models.taskinstance import TaskInstance as TI, _stop_remaining_task from airflow.models.taskreschedule import TaskReschedule from airflow.models.trigger import Trigger from airflow.models.xcom import XComModel +from airflow.sdk.definitions.taskgroup import MappedTaskGroup from airflow.utils import timezone from airflow.utils.state import DagRunState, TaskInstanceState, TerminalTIState +if TYPE_CHECKING: + from airflow.sdk.types import Operator + + router = VersionedAPIRouter() ti_id_router = VersionedAPIRouter( @@ -82,7 +88,10 @@ log = structlog.get_logger(__name__) response_model_exclude_unset=True, ) def ti_run( - task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep + task_instance_id: UUID, + ti_run_payload: Annotated[TIEnterRunningPayload, Body()], + session: SessionDep, + request: Request, ) -> TIRunContext: """ Run a TaskInstance. 
@@ -233,6 +242,11 @@ def ti_run( or 0 ) + if dag := request.app.state.dag_bag.get_dag(ti.dag_id): + upstream_map_indexes = dict(_get_upstream_map_indexes(dag.get_task(ti.task_id), ti.map_index)) + else: + upstream_map_indexes = None + context = TIRunContext( dag_run=dr, task_reschedule_count=task_reschedule_count, @@ -242,6 +256,7 @@ def ti_run( connections=[], xcom_keys_to_clear=xcom_keys, should_retry=_is_eligible_to_retry(previous_state, ti.try_number, ti.max_tries), + upstream_map_indexes=upstream_map_indexes, ) # Only set if they are non-null @@ -257,6 +272,27 @@ def ti_run( ) +def _get_upstream_map_indexes( + task: Operator, ti_map_index: int +) -> Iterator[tuple[str, int | list[int] | None]]: + for upstream_task in task.upstream_list: + map_indexes: int | list[int] | None + if not isinstance(upstream_task.task_group, MappedTaskGroup): + # regular tasks or non-mapped task groups + map_indexes = None + elif task.task_group == upstream_task.task_group: + # tasks in the same mapped task group + # the task should use the map_index as the previous task in the same mapped task group + map_indexes = ti_map_index + else: + # tasks not in the same mapped task group + # the upstream mapped task group should combine the xcom as a list and return it + mapped_ti_count: int = upstream_task.task_group.get_parse_time_mapped_ti_count() + map_indexes = list(range(mapped_ti_count)) if mapped_ti_count is not None else None + + yield upstream_task.task_id, map_indexes + + @ti_id_router.patch( "/{task_instance_id}/state", status_code=status.HTTP_204_NO_CONTENT, diff --git a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py index 41bba2c2bd5..95631484087 100644 --- a/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py +++ b/airflow-core/tests/unit/api_fastapi/execution_api/versions/head/test_task_instances.py @@ -179,6 
+179,7 @@ class TestTIRunState: "consumed_asset_events": [], }, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": max_tries, "should_retry": should_retry, "variables": [], @@ -257,6 +258,7 @@ class TestTIRunState: assert response.json() == { "dag_run": mock.ANY, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": 0, "should_retry": False, "variables": [], @@ -318,6 +320,7 @@ class TestTIRunState: assert response.json() == { "dag_run": mock.ANY, "task_reschedule_count": 0, + "upstream_map_indexes": None, "max_tries": 0, "should_retry": False, "variables": [], diff --git a/devel-common/src/tests_common/pytest_plugin.py b/devel-common/src/tests_common/pytest_plugin.py index b461ddc409d..2b31152719d 100644 --- a/devel-common/src/tests_common/pytest_plugin.py +++ b/devel-common/src/tests_common/pytest_plugin.py @@ -2088,7 +2088,7 @@ def create_runtime_ti(mocked_parse): run_type: str = "manual", try_number: int = 1, map_index: int | None = -1, - upstream_map_indexes: dict[str, int] | None = None, + upstream_map_indexes: dict[str, int | list[int] | None] | None = None, task_reschedule_count: int = 0, ti_id: UUID | None = None, conf: dict[str, Any] | None = None, @@ -2143,6 +2143,7 @@ def create_runtime_ti(mocked_parse): task_reschedule_count=task_reschedule_count, max_tries=task_retries if max_tries is None else max_tries, should_retry=should_retry if should_retry is not None else try_number <= task_retries, + upstream_map_indexes=upstream_map_indexes, ) if upstream_map_indexes is not None: diff --git a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py index a477c3ffc41..dbf39fa4ae1 100644 --- a/task-sdk/src/airflow/sdk/api/datamodels/_generated.py +++ b/task-sdk/src/airflow/sdk/api/datamodels/_generated.py @@ -480,7 +480,9 @@ class TIRunContext(BaseModel): max_tries: Annotated[int, Field(title="Max Tries")] variables: Annotated[list[VariableResponse] | None, 
Field(title="Variables")] = None connections: Annotated[list[ConnectionResponse] | None, Field(title="Connections")] = None - upstream_map_indexes: Annotated[dict[str, int] | None, Field(title="Upstream Map Indexes")] = None + upstream_map_indexes: Annotated[ + dict[str, int | list[int] | None] | None, Field(title="Upstream Map Indexes") + ] = None next_method: Annotated[str | None, Field(title="Next Method")] = None next_kwargs: Annotated[dict[str, Any] | str | None, Field(title="Next Kwargs")] = None xcom_keys_to_clear: Annotated[list[str] | None, Field(title="Xcom Keys To Clear")] = None diff --git a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py index 1adcb7efaa7..2a93585304c 100644 --- a/task-sdk/src/airflow/sdk/definitions/xcom_arg.py +++ b/task-sdk/src/airflow/sdk/definitions/xcom_arg.py @@ -339,22 +339,17 @@ class PlainXComArg(XComArg): if self.operator.is_mapped: return LazyXComSequence(xcom_arg=self, ti=ti) tg = self.operator.get_closest_mapped_task_group() - result = None if tg is None: - # regular task - result = ti.xcom_pull( - task_ids=task_id, - key=self.key, - default=NOTSET, - map_indexes=None, - ) + map_indexes = None else: - # task from a task group - result = ti.xcom_pull( - task_ids=task_id, - key=self.key, - default=NOTSET, - ) + upstream_map_indexes = getattr(ti, "_upstream_map_indexes", {}) + map_indexes = upstream_map_indexes.get(task_id, None) + result = ti.xcom_pull( + task_ids=task_id, + key=self.key, + default=NOTSET, + map_indexes=map_indexes, + ) if not isinstance(result, ArgNotSet): return result if self.key == XCOM_RETURN_KEY: diff --git a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py index d6b8ca8da4e..6cdb4b520f2 100644 --- a/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py +++ b/task-sdk/tests/task_sdk/definitions/test_mappedoperator.py @@ -627,9 +627,22 @@ def 
test_operator_mapped_task_group_receives_value(create_runtime_ti, mock_super "tg.t2": range(3), "t3": [None], } + upstream_map_indexes_per_task_id = { + ("tg.t1", 0): {}, + ("tg.t1", 1): {}, + ("tg.t1", 2): {}, + ("tg.t2", 0): {"tg.t1": 0}, + ("tg.t2", 1): {"tg.t1": 1}, + ("tg.t2", 2): {"tg.t1": 2}, + ("t3", None): {"tg.t2": [0, 1, 2]}, + } for task in dag.tasks: for map_index in expansion_per_task_id[task.task_id]: - mapped_ti = create_runtime_ti(task=task.prepare_for_execution(), map_index=map_index) + mapped_ti = create_runtime_ti( + task=task.prepare_for_execution(), + map_index=map_index, + upstream_map_indexes=upstream_map_indexes_per_task_id[(task.task_id, map_index)], + ) context = mapped_ti.get_template_context() mapped_ti.task.render_template_fields(context) mapped_ti.task.execute(context)
