kaxil commented on code in PR #62922:
URL: https://github.com/apache/airflow/pull/62922#discussion_r3348559776


##########
task-sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -1934,7 +2053,10 @@ def _execute_task(context: Context, ti: RuntimeTaskInstance, log: Logger):
     ctx.run(ExecutorSafeguard.tracker.set, task)
 
     # Export context in os.environ to make it available for operators to use.
+    # Use thread-safe context variable storage to avoid race conditions in concurrent execution.
     airflow_context_vars = context_to_airflow_vars(context, in_env_var_format=True)
+    # Also update os.environ for backward compatibility with code that reads env vars directly.
+    # The context variable (below) provides thread-safe access without race conditions.
     os.environ.update(airflow_context_vars)

Review Comment:
   The new `airflow_context_vars_context` wrapper handles the thread-safe path, 
but this `os.environ.update` right below it still mutates the process-global 
env, and the comment itself flags it as the path for "code that reads env vars 
directly." Those direct readers are exactly the ones that race: two sub-tasks 
running concurrently through the thread pool clobber each other's 
`AIRFLOW_CTX_*`, so e.g. a BashOperator's templated env or a subprocess hook 
can pick up a sibling's `task_id` / `map_index`. Same race as the earlier 
round, now only half-closed by the contextvar.
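
   One way to close the other half, sketched under assumptions: keep the 
per-task values in a context variable only and build the child environment on 
demand, instead of writing to `os.environ`. The names `_task_env`, `child_env` 
and `run_sub_task` are hypothetical and are not the PR's 
`airflow_context_vars_context` API.

```python
import contextvars
import os
import time
from concurrent.futures import ThreadPoolExecutor

# Hypothetical per-task storage; a stand-in for whatever wrapper the PR uses.
_task_env = contextvars.ContextVar("task_env", default=None)


def child_env():
    """Merge the per-task vars over a snapshot of os.environ without mutating it."""
    return {**os.environ, **(_task_env.get() or {})}


def run_sub_task(index):
    # Each worker thread sets the value in its own context.
    _task_env.set({"AIRFLOW_CTX_TASK_ID": f"sub_task_{index}"})
    time.sleep(0.01)  # give sibling sub-tasks a chance to interleave
    # A real operator would pass child_env() as subprocess.run(..., env=...).
    return child_env()["AIRFLOW_CTX_TASK_ID"]


with ThreadPoolExecutor(max_workers=4) as pool:
    seen = list(pool.map(run_sub_task, range(8)))
```

   Because nothing is written to the process-global environment, a sibling 
sub-task cannot overwrite the value between the `set` and the read.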



##########
task-sdk/src/airflow/sdk/definitions/context.py:
##########
@@ -94,6 +94,68 @@ class Context(TypedDict, total=False):
 KNOWN_CONTEXT_KEYS: set[str] = set(Context.__annotations__.keys())
 
 
+def clone_context(context: Context) -> Context:
+    """
+    Create a safe, per-task copy of an execution ``Context`` for concurrent execution.
+
+    The execution context is a mutable mapping that contains many nested
+    structures (``params``, ``templates_dict``, ``outlet_events``, ``dag_run``,
+    etc.). When running the same logical task concurrently (for example when
+    the ``IterableOperator`` spawns multiple indexed task instances that run in
+    parallel using threads, processes or asyncio tasks), those mutable objects
+    could be mutated by one indexed runtime and unintentionally observed by
+    another. That leads to subtle race conditions, corrupted state, and
+    flakiness in task execution.
+
+    ``clone_context`` returns a new :class:`Context` mapping where the top-level
+    mapping is copied and specific mutable sub-objects that are commonly
+    mutated during execution are deep-copied or shallow-copied as appropriate:
+
+    - ``params`` and ``templates_dict`` are deep-copied because they are
+      dictionaries that users and operators commonly mutate.
+    - ``inlets`` and ``outlets`` are converted to new lists (shallow copy)
+      because the sequence identity must be isolated but the elements are
+      typically read-only accessor objects.
+    - ``outlet_events`` and ``dag_run`` are deep-copied because they carry
+      nested state that must not be shared between concurrent executions.
+
+    Use cases
+    - Multithreading: when using thread-based executors (``concurrent.futures``
+      ThreadPoolExecutor) multiple threads share memory; cloning prevents
+      concurrent mutation of shared structures.
+    - Async concurrency: when running coroutine-based tasks concurrently in
+      the same event loop, tasks may still mutate shared mappings; cloning
+      avoids interference.
+    - Multiprocessing: while processes do not share memory, cloning keeps the
+      semantics consistent and avoids accidentally capturing references that
+      would be pickled.
+
+    Performance
+    - The implementation intentionally copies only a small set of commonly
+      mutated fields rather than performing a blanket deep copy of the entire
+      context to keep the operation cheap. If future code stores additional
+      mutable state in the context that needs isolation, this function should
+      be extended appropriately.
+
+    :param context: The original execution context to clone.
+    :returns: A new :class:`Context` safe to hand to a concurrently running task.
+
+    :meta private:
+    """
+    cloned_context = Context()
+    cloned_context.update(context)
+    cloned_context["params"] = copy.deepcopy(context.get("params", {}))
+    cloned_context["inlets"] = list(context.get("inlets", []))
+    cloned_context["outlets"] = list(context.get("outlets", []))
+    templates_dict = cloned_context.get("templates_dict")
+    if templates_dict is not None:
+        cloned_context["templates_dict"] = copy.deepcopy(templates_dict)
+    cloned_context["inlet_events"] = copy.deepcopy(context["inlet_events"])
+    cloned_context["outlet_events"] = copy.deepcopy(context["outlet_events"])

Review Comment:
   `clone_context` fixes the params/inlets race, but deep-copying 
`outlet_events` here means each sub-task emits into its own throwaway copy. 
`_run_operator` returns only the xcom result and discards the clone, while the 
supervisor serializes the parent IterableOperator's `outlet_events` in 
[`_handle_current_task_success`](https://github.com/apache/airflow/blob/fcfb3f472eb2d3524a3065a5fe917922ac67e44f/task-sdk/src/airflow/sdk/execution_time/task_runner.py#L1709),
 so any `yield Metadata(...)` or `context['outlet_events'][asset]` write from a 
sub-task is silently dropped: no event is emitted and no error is raised. Static 
`outlets=[Asset(...)]` still emit once via the parent, but dynamic per-sub-task 
asset metadata is lost, which breaks asset-driven scheduling downstream. Either 
merge each clone's events back into the parent before discarding, or document 
the limitation.
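
   A minimal sketch of the merge-back option, assuming plain dicts as stand-ins 
for the real outlet event accessors (asset name mapped to an `extra` dict). 
`run_sub_task` and `merge_events` are hypothetical helpers, and the 
last-writer-wins policy on key conflicts is an assumption, not something the PR 
defines.

```python
import copy


def run_sub_task(parent_events, index):
    """Work on an isolated deep copy, as clone_context does for outlet_events."""
    local = copy.deepcopy(parent_events)
    local.setdefault("s3://bucket/out", {})[f"rows_{index}"] = index * 10
    return local


def merge_events(parent, clone):
    """Fold a sub-task's events back into the parent before the clone is dropped."""
    for asset, extra in clone.items():
        parent.setdefault(asset, {}).update(extra)


parent_events = {"s3://bucket/out": {}}
clones = [run_sub_task(parent_events, i) for i in range(3)]
for clone in clones:
    merge_events(parent_events, clone)
```

   Merging on the parent thread after each sub-task returns keeps the isolation 
during execution and still lets the supervisor see every sub-task's metadata.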



##########
task-sdk/src/airflow/sdk/execution_time/executor.py:
##########
@@ -0,0 +1,328 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import contextvars
+import inspect
+import logging
+import time
+from asyncio import (
+    FIRST_COMPLETED,
+    AbstractEventLoop,
+    Future,
+    Semaphore,
+    Task,
+    TimeoutError as AsyncTimeoutError,
+    gather,
+    wait,
+    wait_for,
+    wrap_future,
+)
+from collections.abc import Callable, Iterable, Iterator
+from concurrent.futures import Executor, ThreadPoolExecutor
+from typing import TYPE_CHECKING, Any, cast
+
+from airflow.sdk import BaseAsyncOperator, BaseOperator, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import ExecutorSafeguard
+from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException, TaskDeferred
+from airflow.sdk.execution_time.callback_runner import create_executable_runner
+from airflow.sdk.execution_time.context import context_get_outlet_events
+from airflow.sdk.execution_time.task_runner import (
+    RuntimeTaskInstance,
+    _execute_task,
+    _run_task_state_change_callbacks,
+)
+
+if TYPE_CHECKING:
+    from structlog.typing import FilteringBoundLogger as Logger
+
+    from airflow.sdk import Context
+    from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+
+class AsyncAwareExecutor(Executor):
+    """
+    Executes both sync and async functions concurrently.
+
+    Sync functions run in a ThreadPoolExecutor.
+    Async coroutines run on an asyncio event loop with a semaphore limit.
+
+    :param loop: Event loop used to schedule async tasks and coordinate mixed execution.
+    :param max_workers: Maximum concurrent workers used by both thread pool and async semaphore.
+    """
+
+    def __init__(self, loop: AbstractEventLoop, max_workers: int = 4):
+        self._loop = loop
+        self._max_workers = max_workers
+        self._semaphore = Semaphore(max_workers)
+        self._thread_pool = ThreadPoolExecutor(max_workers=max_workers)
+        self._async_tasks: set[Task[Any]] = set()
+        self._shutdown = False
+
+    def __enter__(self):
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.shutdown(wait=True)

Review Comment:
   `__exit__` hardcodes `shutdown(wait=True)` with the default 
`cancel_futures=False`, so when `map(..., timeout=...)` raises and unwinds this 
block, the `gather(...)` and thread-pool join wait on the very worker that just 
timed out, which defeats the timeout you enforced. `shutdown` already takes 
`cancel_futures`, so calling `shutdown(wait=False, cancel_futures=True)` on the 
error/timeout path would let it return promptly. Note that in-flight threads are 
abandoned, not killed.
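
   To illustrate the idea with the standard library only, here is a sketch 
using a plain `ThreadPoolExecutor` subclass rather than `AsyncAwareExecutor`. 
`BailOnErrorPool` is a hypothetical name, and the asyncio side of the real 
class is not modelled.

```python
import threading
from concurrent.futures import ThreadPoolExecutor


class BailOnErrorPool(ThreadPoolExecutor):
    """Wait on a clean exit; cancel queued work and return at once on error."""

    def __exit__(self, exc_type, exc_val, exc_tb):
        failed = exc_type is not None
        self.shutdown(wait=not failed, cancel_futures=failed)
        return False  # never swallow the exception


release = threading.Event()
started = threading.Event()


def blocker():
    started.set()
    release.wait()  # simulates a worker that outlives the timeout
    return "done"


try:
    with BailOnErrorPool(max_workers=1) as pool:
        running = pool.submit(blocker)
        queued = pool.submit(lambda: "never runs")
        started.wait()
        raise TimeoutError("simulated map() timeout")
except TimeoutError:
    pass

# __exit__ returned without joining the blocked worker.
queued_cancelled = queued.cancelled()
release.set()  # let the abandoned worker finish so the interpreter can exit
running_result = running.result(timeout=5)
```

   The queued future is cancelled, while the running one keeps going until it 
finishes by itself, which is the "abandoned, not killed" caveat.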



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,519 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import copy
+import os
+from collections import deque
+from collections.abc import Iterable, Mapping, Sequence
+from functools import cached_property
+from itertools import repeat
+from typing import TYPE_CHECKING, Any
+
+try:
+    # Python 3.11+
+    BaseExceptionGroup
+except NameError:
+    from exceptiongroup import BaseExceptionGroup
+
+from airflow.sdk import BaseXCom, TaskInstanceState, timezone
+from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions._internal.expandinput import PartitionedExpandInput
+from airflow.sdk.definitions.context import clone_context
+from airflow.sdk.definitions.mappedoperator import MappedOperator
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import (
+    AirflowFailException,
+    AirflowRescheduleTaskInstanceException,
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk.execution_time.executor import AsyncAwareExecutor, TaskExecutor
+from airflow.sdk.execution_time.task_runner import IndexedTaskInstance
+
+if TYPE_CHECKING:
+    import jinja2
+
+    from airflow.providers.standard.triggers.temporal import DateTimeTrigger
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+
+ExternalDateTimeTrigger: type[DateTimeTrigger] | None
+
+try:
+    from airflow.providers.standard.triggers.temporal import DateTimeTrigger as ExternalDateTimeTrigger
+except ModuleNotFoundError:
+    # If the providers package with DateTimeTrigger is not available (e.g. in
+    # minimal installs or tests), set the symbol to None so callers can
+    # explicitly check for availability. Using hasattr(self, DateTimeTrigger)
+    # is incorrect because hasattr expects a string attribute name.
+    ExternalDateTimeTrigger = None
+
+
+class IterableOperator(BaseOperator):
+    """
+    Operator used for Dynamic Task Iteration (DTI) that runs a mapped operator over an iterable input.
+
+    The IterableOperator wraps a :class:`MappedOperator` together with an
+    :class:`ExpandInput` and is responsible for creating and running the
+    per-index runtime task instances. The IterableOperator itself is a
+    lightweight, non-retrying wrapper — retries, timeouts and deferred
+    execution are handled by the individual indexed task instances that the
+    IterableOperator creates for each element produced by the
+    ``expand_input``.
+
+    The IterableOperator executes the mapped operator instances using a
+    concurrent executor with a configurable number of workers. By default
+    the worker count is taken from the mapped operator's ``partial_kwargs``
+    (``task_concurrency``) if present, otherwise falls back to
+    ``os.cpu_count()`` and finally to ``1``.
+
+    :param operator: The :class:`MappedOperator` to unmap and execute for
+        each element of ``expand_input``. Each indexed runtime receives a
+        deep copy/unmapped instance of this operator.
+
+    :param expand_input: Provider of the values (or partitions) to iterate
+        over. Its ``iter_values(context)`` method is used to produce the
+        per-index ``mapped_kwargs`` used to unmap the operator.
+
+    :param kwargs: Additional keyword arguments forwarded to
+        :class:`BaseOperator` when instantiating the IterableOperator
+        (e.g. ``dag``, ``start_date``). Note that the IterableOperator
+        overrides retry-related parameters because retries are managed by
+        the per-index tasks.
+
+    :returns: An :class:`XComIterable` if the mapped operator pushes XComs, otherwise ``None``.
+    """
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, only the indexed runtime ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}

Review Comment:
   `operator.partial_kwargs or {}` hands back the same dict object when it's 
non-empty, so the `.pop("task_concurrency")` on the next line mutates the 
wrapped MappedOperator's own `partial_kwargs`. Copy it before popping. 
Separately, `pop(..., None) or os.cpu_count()` treats an explicit 
`task_concurrency=0` as falsy and silently runs at `os.cpu_count()` instead.
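
   Both points can be handled in one small helper. This is a sketch, not the 
PR's code: `resolve_max_workers` is a hypothetical name, and rejecting values 
below 1 with `ValueError` is an assumption about the desired behaviour (the 
alternative would be to document that `0` means "use the default").

```python
import os


def resolve_max_workers(partial_kwargs):
    """Return (copied kwargs without task_concurrency, resolved worker count)."""
    kwargs = dict(partial_kwargs or {})  # shallow copy: the caller's dict is untouched
    requested = kwargs.pop("task_concurrency", None)
    if requested is None:
        # Only a missing value falls back to the CPU count.
        return kwargs, os.cpu_count() or 1
    if requested < 1:
        raise ValueError(f"task_concurrency must be >= 1, got {requested}")
    return kwargs, requested


original = {"task_concurrency": 2, "retries": 3}
kwargs, workers = resolve_max_workers(original)

try:
    resolve_max_workers({"task_concurrency": 0})
    zero_rejected = False
except ValueError:
    zero_rejected = True
```

   Checking `is None` rather than relying on `or` is what keeps an explicit `0` 
from being silently replaced.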



##########
task-sdk/src/airflow/sdk/definitions/iterableoperator.py:
##########
@@ -0,0 +1,519 @@
+        self.max_workers = self.partial_kwargs.pop("task_concurrency", None) or os.cpu_count() or 1
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @cached_property
+    def timeout(self) -> float | None:
+        if self._operator.execution_timeout:
+            return self._operator.execution_timeout.total_seconds()
+        return None
+
+    def _do_render_template_fields(
+        self,
+        parent: Any,
+        template_fields: Iterable[str],
+        context: Context,
+        jinja_env: jinja2.Environment,
+        seen_oids: set[int],
+    ) -> None:
+        # IterableOperator doesn't need to render template fields as the actual operator's template fields
+        # will be rendered in the TaskExecutor when running each mapped task instance.
+        pass
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(
+        self, context: Context, mapped_kwargs: Context, jinja_env: jinja2.Environment
+    ) -> BaseOperator:
+        from airflow.sdk.execution_time.context import context_update_for_unmapped
+
+        self._number_of_tasks += 1
+        unmapped_task = self._operator.unmap(mapped_kwargs)
+        # Make sure deferred operators will always raise a DeferredTask exception when executed
+        unmapped_task.start_from_trigger = False
+        context_update_for_unmapped(context, unmapped_task)
+
+        unmapped_task._do_render_template_fields(
+            parent=unmapped_task,
+            template_fields=self._operator.template_fields,
+            context=context,
+            jinja_env=jinja_env,
+            seen_oids=set(),
+        )
+        return unmapped_task
+
+    def _xcom_push(self, task: IndexedTaskInstance, value: Any) -> None:
+        if task.xcom_pushed:
+            self.log.debug(
+                "XCom already pushed for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+        else:
+            self.log.debug(
+                "Pushing XCom for task_id %s with index %s",
+                task.task_id,
+                task.index,
+            )
+
+            task.xcom_push(key=BaseXCom.XCOM_RETURN_KEY, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[IndexedTaskInstance],
+    ) -> XComIterable | None:
+        exceptions: list[BaseException] = []
+        reschedule_date = timezone.utcnow()
+        deferred_tasks: deque[IndexedTaskInstance] = deque()
+        failed_tasks: deque[IndexedTaskInstance] = deque()
+        do_xcom_push = True
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        while True:
+            with event_loop() as loop:
+                with AsyncAwareExecutor(loop=loop, max_workers=self.max_workers) as executor:
+                    for task, result, raised in executor.map(
+                        self._run_task,
+                        repeat(executor),
+                        repeat(context),
+                        tasks,
+                        timeout=self.timeout,
+                    ):
+                        do_xcom_push = task.do_xcom_push
+
+                        if raised is None:
+                            self.log.debug("result: %s", result)
+                            if result is not None and task.do_xcom_push:
+                                self._xcom_push(task=task, value=result)
+                            continue
+
+                        if isinstance(raised, TaskDeferred):
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=raised
+                            )
+                            deferred_tasks.append(
+                                self._create_mapped_task(
+                                    run_id=task.run_id,
+                                    index=task.index,
+                                    map_index=task.map_index,  # type: ignore[arg-type]
+                                    try_number=task.try_number,
+                                    operator=operator,
+                                )
+                            )
+                            continue
+
+                        if isinstance(raised, asyncio.TimeoutError):
+                            self.log.warning("A timeout occurred for task_id %s", task.task_id)
+                            if task.next_try_number > (self.retries or 0):
+                                exceptions.append(AirflowTaskTimeout(raised))
+                            else:
+                                reschedule_date = min(reschedule_date, task.next_retry_datetime())
+                                failed_tasks.append(task)
+                            continue
+
+                        if isinstance(raised, AirflowRescheduleTaskInstanceException):
+                            reschedule_date = min(reschedule_date, raised.reschedule_date)
+                            self.log.exception(
+                                "An exception occurred for task_id %s with index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(raised.task)
+                            continue
+
+                        self.log.exception(
+                            "An exception occurred for task_id %s with index %s",
+                            task.task_id,
+                            task.index,
+                            exc_info=raised,
+                        )
+                        exceptions.append(raised)
+
+                    # Deferred tasks are re-fed as a new pass once the current batch completes,
+                    # because the event loop is restarted at the top of the outer while loop.
+                    if deferred_tasks:
+                        tasks = list(deferred_tasks)
+                        deferred_tasks.clear()
+                        continue
+
+            if not failed_tasks:
+                if exceptions:
+                    # If this IterableOperator is backed by a partitioned expand input
+                    # (created from a MappedIterableOperator), the parent mapped
+                    # task should never be retried; retries are handled by the
+                    # individual indexed runtime tasks. In that case raise
+                    # AirflowFailException to mark failure without retrying the
+                    # parent TaskInstance. For regular (non-partitioned) IterableOperator
+                    # behavior, preserve the previous behavior and raise the
+                    # BaseExceptionGroup so callers/tests that expect it keep working.
+                    if isinstance(self.expand_input, PartitionedExpandInput):
+                        raise AirflowFailException(f"Multiple sub-task failures: {exceptions}")
+                    raise BaseExceptionGroup("Multiple sub-task failures", exceptions)
+                if do_xcom_push:
+                    from airflow.sdk.execution_time.lazy_sequence import XComIterable
+
+                    return XComIterable(
+                        task_id=self.task_id,
+                        dag_id=self.dag_id,
+                        run_id=context["run_id"],
+                        length=self._number_of_tasks,
+                        map_index=context["ti"].map_index,
+                    )
+                return None
+
+            # If the retry time is still in the future we defer the operator so the worker
+            # slot is released. If the retry time has already passed we immediately re-run
+            # the failed tasks without deferring.
+            if reschedule_date > timezone.utcnow():
+                if ExternalDateTimeTrigger is not None:
+                    self.defer(
+                        trigger=ExternalDateTimeTrigger(reschedule_date),
+                        method_name=self.execute_failed_tasks.__name__,
+                        kwargs={
+                            "failed_tasks": {failed_task.index for failed_task in failed_tasks},
+                            "try_number": next(iter(failed_tasks)).try_number,

Review Comment:
   The defer payload carries a single `try_number` 
(`next(iter(failed_tasks)).try_number`) for the whole failed batch, and 
`execute_failed_tasks` stamps every recreated task with it. When failed 
sub-tasks in a batch are on different attempt counts, their retry budgets drift 
after a reschedule. Carry per-index try numbers instead.
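
   A sketch of what a per-index payload could look like, assuming a small 
dataclass as a stand-in for `IndexedTaskInstance`. `FailedTask`, 
`build_defer_kwargs` and `recreate_tasks` are hypothetical names, and whether 
the trigger kwargs serializer accepts integer keys is not verified here.

```python
from dataclasses import dataclass


@dataclass
class FailedTask:
    """Stand-in for IndexedTaskInstance: only the two fields the payload needs."""

    index: int
    try_number: int


def build_defer_kwargs(failed_tasks):
    """Carry one try_number per failed index instead of one for the whole batch."""
    return {"try_numbers": {task.index: task.try_number for task in failed_tasks}}


def recreate_tasks(try_numbers):
    """What the resume method would do: restore each task at its own attempt count."""
    return [FailedTask(index=i, try_number=n) for i, n in sorted(try_numbers.items())]


failed = [FailedTask(index=0, try_number=1), FailedTask(index=3, try_number=2)]
payload = build_defer_kwargs(failed)
restored = recreate_tasks(payload["try_numbers"])
```

   The mapping also carries the set of failed indexes as its keys, so the 
separate `failed_tasks` entry would no longer be needed.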


