This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new a73c6268a3b Fix KubernetesPodTrigger.get_task_state KeyError on mapped TIs (#67296) (#67297)
a73c6268a3b is described below

commit a73c6268a3b2b718cb3363051a80e5e4e9511072
Author: Paul Mathew <[email protected]>
AuthorDate: Thu May 21 17:21:31 2026 -0400

    Fix KubernetesPodTrigger.get_task_state KeyError on mapped TIs (#67296) (#67297)
    
    The execution API's /states endpoint encodes the response key as
    ``f"{task_id}_{map_index}"`` for mapped TIs, but the trigger looked
    the value up by the plain ``task_id``. For any mapped deferrable
    KubernetesPodOperator task, that lookup raised KeyError (re-raised by
    get_task_state as AirflowException). cleanup()'s broad
    ``except Exception`` swallowed it and skipped ``hook.delete_pod()``,
    so Mark Failed in the UI left the pod running until
    ``active_deadline_seconds`` expired.
    
    Compose the lookup key with the ``_{map_index}`` suffix when the TI is
    mapped (``map_index >= 0``), matching how the API serialises the
    response. cleanup() now sees the real state, ``safe_to_cancel()``
    returns the right value, and mark-failed actually deletes the pod
    within the grace period.
    
    Co-authored-by: Cursor <[email protected]>
---
 .../providers/cncf/kubernetes/triggers/pod.py      | 10 ++-
 .../unit/cncf/kubernetes/triggers/test_pod.py      | 86 +++++++++++++++++++++-
 2 files changed, 94 insertions(+), 2 deletions(-)

diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
index e801a2f0566..65b1b45bb04 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -397,8 +397,16 @@ class KubernetesPodTrigger(BaseTrigger):
                 run_ids=[self.task_instance.run_id],
                 map_index=self.task_instance.map_index,
             )
+            # The /states endpoint suffixes the response key with 
``_{map_index}`` for mapped TIs
+            # (see ``get_task_instance_states`` in airflow-core's 
execution_api routes); non-mapped
+            # TIs keep the plain ``task_id``.
+            ti_key = (
+                f"{self.task_instance.task_id}_{self.task_instance.map_index}"
+                if self.task_instance.map_index >= 0
+                else self.task_instance.task_id
+            )
             try:
-                return 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+                return task_states_response[self.task_instance.run_id][ti_key]
             except KeyError:
                 raise AirflowException(
                     "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and 
map_index: %s is not found",
diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py 
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
index 26b2db90e16..765a3f35e3d 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py
@@ -34,7 +34,7 @@ from airflow.providers.cncf.kubernetes.utils.pod_manager 
import PodPhase
 from airflow.triggers.base import TriggerEvent
 from airflow.utils.state import TaskInstanceState
 
-from tests_common.test_utils.version_compat import AIRFLOW_V_3_3_PLUS
+from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS, 
AIRFLOW_V_3_3_PLUS
 
 TRIGGER_PATH = 
"airflow.providers.cncf.kubernetes.triggers.pod.KubernetesPodTrigger"
 HOOK_PATH = 
"airflow.providers.cncf.kubernetes.hooks.kubernetes.AsyncKubernetesHook"
@@ -827,6 +827,90 @@ class TestKubernetesPodTrigger:
         )
         assert await trigger.safe_to_cancel() is False
 
+    @pytest.mark.skipif(
+        not AIRFLOW_V_3_0_PLUS,
+        reason="get_task_state uses RuntimeTaskInstance.get_task_states on 
Airflow 3.0+",
+    )
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+    async def test_get_task_state_uses_task_id_for_non_mapped_ti(self, 
mock_get_task_states):
+        # Non-mapped TIs (``map_index < 0``) are keyed by plain ``task_id`` in 
the
+        # response, matching the dict-key construction in the execution API's
+        # ``get_task_instance_states`` handler.
+        run_id = "manual__2026-05-21T00:00:00+00:00"
+        mock_get_task_states.return_value = {run_id: {"my_task": 
TaskInstanceState.SUCCESS}}
+
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+        )
+        trigger.task_instance = MagicMock(dag_id="my_dag", task_id="my_task", 
run_id=run_id, map_index=-1)
+
+        assert await trigger.get_task_state() == TaskInstanceState.SUCCESS
+
+    @pytest.mark.skipif(
+        not AIRFLOW_V_3_0_PLUS,
+        reason="get_task_state uses RuntimeTaskInstance.get_task_states on 
Airflow 3.0+",
+    )
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+    async def test_get_task_state_uses_composite_key_for_mapped_ti(self, 
mock_get_task_states):
+        # Regression guard for #67296: mapped TIs (``map_index >= 0``) are
+        # keyed by ``f"{task_id}_{map_index}"`` in the response. Without the
+        # suffix this lookup would KeyError, which ``cleanup()`` would
+        # defensively swallow and skip ``hook.delete_pod()`` -- leaking the
+        # pod until ``active_deadline_seconds`` expires on user mark-failed.
+        run_id = "manual__2026-05-21T00:00:00+00:00"
+        mock_get_task_states.return_value = {run_id: {"map_group.task_a_2": 
TaskInstanceState.FAILED}}
+
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+        )
+        trigger.task_instance = MagicMock(
+            dag_id="my_dag", task_id="map_group.task_a", run_id=run_id, 
map_index=2
+        )
+
+        assert await trigger.get_task_state() == TaskInstanceState.FAILED
+
+    @pytest.mark.skipif(
+        not AIRFLOW_V_3_0_PLUS,
+        reason="get_task_state uses RuntimeTaskInstance.get_task_states on 
Airflow 3.0+",
+    )
+    @pytest.mark.asyncio
+    
@mock.patch("airflow.sdk.execution_time.task_runner.RuntimeTaskInstance.get_task_states")
+    async def test_get_task_state_raises_when_mapped_key_missing(self, 
mock_get_task_states):
+        # The wrapped ``AirflowException`` shape is preserved when the
+        # response is missing the expected (composite) key, so callers
+        # like ``safe_to_cancel`` keep the same behaviour they had before
+        # the lookup was fixed.
+        from airflow.exceptions import AirflowException
+
+        run_id = "manual__2026-05-21T00:00:00+00:00"
+        # Response has the run_id but not the (``map_group.task_a``, ``2``)
+        # entry -- e.g. supervisor has not observed the TI yet.
+        mock_get_task_states.return_value = {run_id: {"map_group.task_a_5": 
"running"}}
+
+        trigger = KubernetesPodTrigger(
+            pod_name=POD_NAME,
+            pod_namespace=NAMESPACE,
+            base_container_name=BASE_CONTAINER_NAME,
+            trigger_start_time=TRIGGER_START_TIME,
+            schedule_timeout=STARTUP_TIMEOUT_SECS,
+        )
+        trigger.task_instance = MagicMock(
+            dag_id="my_dag", task_id="map_group.task_a", run_id=run_id, 
map_index=2
+        )
+
+        with pytest.raises(AirflowException, match="TaskInstance with dag_id"):
+            await trigger.get_task_state()
+
     @pytest.mark.skipif(
         AIRFLOW_V_3_3_PLUS,
         reason="Legacy cleanup path runs only on Airflow < 3.3",

Reply via email to