This is an automated email from the ASF dual-hosted git repository.

shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e50ce94cdc4 Handle multiple pods to prevent `KubernetesJobOperator` failures with the parallelism option (#49899)
e50ce94cdc4 is described below

commit e50ce94cdc427b3f4930d59e5a7390e5eb2bfab5
Author: Nitochkin <[email protected]>
AuthorDate: Thu Jul 24 11:56:03 2025 +0200

    Handle multiple pods to prevent `KubernetesJobOperator` failures with the parallelism option (#49899)
    
    Co-authored-by: Anton Nitochkin <[email protected]>
---
 .../providers/cncf/kubernetes/operators/job.py     | 132 ++++++++++++-----
 .../providers/cncf/kubernetes/triggers/job.py      |  63 +++++---
 .../unit/cncf/kubernetes/operators/test_job.py     | 161 +++++++++++++++++++--
 .../unit/cncf/kubernetes/triggers/test_job.py      |  16 +-
 .../google/cloud/operators/kubernetes_engine.py    |   4 +-
 .../google/cloud/triggers/kubernetes_engine.py     |  66 ++++++---
 .../example_kubernetes_engine_job.py               | 101 ++++++++++++-
 .../cloud/operators/test_kubernetes_engine.py      |  68 ++++++++-
 .../cloud/triggers/test_kubernetes_engine.py       |  16 +-
 9 files changed, 517 insertions(+), 110 deletions(-)

diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
index 7c19421594a..79e9b835812 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
@@ -22,16 +22,17 @@ import copy
 import json
 import logging
 import os
+import warnings
 from collections.abc import Sequence
 from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal
 
 from kubernetes.client import BatchV1Api, models as k8s
 from kubernetes.client.api_client import ApiClient
 from kubernetes.client.rest import ApiException
 
 from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
 from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
     add_unique_suffix,
@@ -81,6 +82,17 @@ class KubernetesJobOperator(KubernetesPodOperator):
         Used if the parameter `wait_until_job_complete` set True.
     :param deferrable: Run operator in the deferrable mode. Note that the 
parameter
         `wait_until_job_complete` must be set True.
+    :param on_kill_propagation_policy: Whether and how garbage collection will 
be performed. Default is 'Foreground'.
+        Acceptable values are:
+        'Orphan' - orphan the dependents;
+        'Background' - allow the garbage collector to delete the dependents in 
the background;
+        'Foreground' - a cascading policy that deletes all dependents in the 
foreground.
+        Default value is 'Foreground'.
+    :param discover_pods_retry_number: Number of time list_namespaced_pod will 
be performed to discover
+        already running pods.
+    :param unwrap_single: Unwrap single result from the pod. For example, when 
set to `True` - if the XCom
+        result should be `['res']`, the final result would be `'res'`. Default 
is True to support backward
+        compatibility.
     """
 
     template_fields: Sequence[str] = tuple({"job_template_file"} | 
set(KubernetesPodOperator.template_fields))
@@ -101,8 +113,12 @@ class KubernetesJobOperator(KubernetesPodOperator):
         wait_until_job_complete: bool = False,
         job_poll_interval: float = 10,
         deferrable: bool = conf.getboolean("operators", "default_deferrable", 
fallback=False),
+        on_kill_propagation_policy: Literal["Foreground", "Background", 
"Orphan"] = "Foreground",
+        discover_pods_retry_number: int = 3,
+        unwrap_single: bool = True,
         **kwargs,
     ) -> None:
+        self._pod = None
         super().__init__(**kwargs)
         self.job_template_file = job_template_file
         self.full_job_spec = full_job_spec
@@ -119,6 +135,22 @@ class KubernetesJobOperator(KubernetesPodOperator):
         self.wait_until_job_complete = wait_until_job_complete
         self.job_poll_interval = job_poll_interval
         self.deferrable = deferrable
+        self.on_kill_propagation_policy = on_kill_propagation_policy
+        self.discover_pods_retry_number = discover_pods_retry_number
+        self.unwrap_single = unwrap_single
+
+    @property
+    def pod(self):
+        warnings.warn(
+            "`pod` parameter is deprecated, please use `pods`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self.pods[0] if self.pods else None
+
+    @pod.setter
+    def pod(self, value):
+        self._pod = value
 
     @cached_property
     def _incluster_namespace(self):
@@ -167,12 +199,16 @@ class KubernetesJobOperator(KubernetesPodOperator):
         ti.xcom_push(key="job_name", value=self.job.metadata.name)
         ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)
 
-        self.pod: k8s.V1Pod | None
-        if self.pod is None:
-            self.pod = self.get_or_create_pod(  # must set `self.pod` for 
`on_kill`
-                pod_request_obj=self.pod_request_obj,
-                context=context,
-            )
+        self.pods: Sequence[k8s.V1Pod] | None = None
+        if self.parallelism is None and self.pod is None:
+            self.pods = [
+                self.get_or_create_pod(
+                    pod_request_obj=self.pod_request_obj,
+                    context=context,
+                )
+            ]
+        else:
+            self.pods = self.get_pods(pod_request_obj=self.pod_request_obj, 
context=context)
 
         if self.wait_until_job_complete and self.deferrable:
             self.execute_deferrable()
@@ -180,22 +216,25 @@ class KubernetesJobOperator(KubernetesPodOperator):
 
         if self.wait_until_job_complete:
             if self.do_xcom_push:
-                self.pod_manager.await_container_completion(
-                    pod=self.pod, container_name=self.base_container_name
-                )
-                
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
-                xcom_result = self.extract_xcom(pod=self.pod)
+                xcom_result = []
+                for pod in self.pods:
+                    self.pod_manager.await_container_completion(
+                        pod=pod, container_name=self.base_container_name
+                    )
+                    
self.pod_manager.await_xcom_sidecar_container_start(pod=pod)
+                    xcom_result.append(self.extract_xcom(pod=pod))
             self.job = self.hook.wait_until_job_complete(
                 job_name=self.job.metadata.name,
                 namespace=self.job.metadata.namespace,
                 job_poll_interval=self.job_poll_interval,
             )
             if self.get_logs:
-                self.pod_manager.fetch_requested_container_logs(
-                    pod=self.pod,
-                    containers=self.container_logs,
-                    follow_logs=True,
-                )
+                for pod in self.pods:
+                    self.pod_manager.fetch_requested_container_logs(
+                        pod=pod,
+                        containers=self.container_logs,
+                        follow_logs=True,
+                    )
 
         ti.xcom_push(key="job", value=self.job.to_dict())
         if self.wait_until_job_complete:
@@ -211,8 +250,8 @@ class KubernetesJobOperator(KubernetesPodOperator):
             trigger=KubernetesJobTrigger(
                 job_name=self.job.metadata.name,
                 job_namespace=self.job.metadata.namespace,
-                pod_name=self.pod.metadata.name,
-                pod_namespace=self.pod.metadata.namespace,
+                pod_names=[pod.metadata.name for pod in self.pods],
+                pod_namespace=self.pods[0].metadata.namespace,
                 base_container_name=self.base_container_name,
                 kubernetes_conn_id=self.kubernetes_conn_id,
                 cluster_context=self.cluster_context,
@@ -232,20 +271,23 @@ class KubernetesJobOperator(KubernetesPodOperator):
             raise AirflowException(event["message"])
 
         if self.get_logs:
-            pod_name = event["pod_name"]
-            pod_namespace = event["pod_namespace"]
-            self.pod = self.hook.get_pod(pod_name, pod_namespace)
-            if not self.pod:
-                raise PodNotFoundException("Could not find pod after resuming 
from deferral")
-            self._write_logs(self.pod)
+            for pod_name in event["pod_names"]:
+                pod_namespace = event["pod_namespace"]
+                pod = self.hook.get_pod(pod_name, pod_namespace)
+                if not pod:
+                    raise PodNotFoundException("Could not find pod after 
resuming from deferral")
+                self._write_logs(pod)
 
         if self.do_xcom_push:
-            xcom_result = event["xcom_result"]
-            if isinstance(xcom_result, str) and xcom_result.rstrip() == 
EMPTY_XCOM_RESULT:
-                self.log.info("xcom result file is empty.")
-                return None
-            self.log.info("xcom result: \n%s", xcom_result)
-            return json.loads(xcom_result)
+            xcom_results: list[Any | None] = []
+            for xcom_result in event["xcom_result"]:
+                if isinstance(xcom_result, str) and xcom_result.rstrip() == 
EMPTY_XCOM_RESULT:
+                    self.log.info("xcom result file is empty.")
+                    xcom_results.append(None)
+                    continue
+                self.log.info("xcom result: \n%s", xcom_result)
+                xcom_results.append(json.loads(xcom_result))
+            return xcom_results[0] if self.unwrap_single and len(xcom_results) 
== 1 else xcom_results
 
     @staticmethod
     def deserialize_job_template_file(path: str) -> k8s.V1Job:
@@ -275,12 +317,11 @@ class KubernetesJobOperator(KubernetesPodOperator):
             kwargs = {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
+                "propagation_policy": self.on_kill_propagation_policy,
             }
             if self.termination_grace_period is not None:
                 
kwargs.update(grace_period_seconds=self.termination_grace_period)
             self.job_client.delete_namespaced_job(**kwargs)
-        if self.pod:
-            super().on_kill()
 
     def build_job_request_obj(self, context: Context | None = None) -> 
k8s.V1Job:
         """
@@ -400,6 +441,29 @@ class KubernetesJobOperator(KubernetesPodOperator):
 
         return None
 
+    def get_pods(
+        self, pod_request_obj: k8s.V1Pod, context: Context, *, exclude_checked: bool = True
+    ) -> Sequence[k8s.V1Pod]:
+        """Return already-running pods, re-listing until ``parallelism`` are found or retries run out."""
+        label_selector = self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+        pod_list: Sequence[k8s.V1Pod] = []
+        retry_number: int = 0
+
+        while len(pod_list) != self.parallelism and retry_number <= self.discover_pods_retry_number:
+            pod_list = self.client.list_namespaced_pod(
+                namespace=pod_request_obj.metadata.namespace,
+                label_selector=label_selector,
+            ).items
+            retry_number += 1
+
+        if len(pod_list) == 0:
+            raise AirflowException(f"No pods running with labels {label_selector}")
+
+        for pod_instance in pod_list:
+            self.log_matching_pod(pod=pod_instance, context=context)
+
+        return pod_list
+
 
 class KubernetesDeleteJobOperator(BaseOperator):
     """
diff --git 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
index 359c3054788..b60373c2d53 100644
--- 
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
+++ 
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
@@ -17,10 +17,12 @@
 from __future__ import annotations
 
 import asyncio
+import warnings
 from collections.abc import AsyncIterator
 from functools import cached_property
 from typing import TYPE_CHECKING, Any
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.providers.cncf.kubernetes.hooks.kubernetes import 
AsyncKubernetesHook, KubernetesHook
 from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
 from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -36,7 +38,8 @@ class KubernetesJobTrigger(BaseTrigger):
 
     :param job_name: The name of the job.
     :param job_namespace: The namespace of the job.
-    :param pod_name: The name of the Pod.
+    :param pod_name: The name of the Pod. Parameter is deprecated, please use 
pod_names instead.
+    :param pod_names: The name of the Pods.
     :param pod_namespace: The namespace of the Pod.
     :param base_container_name: The name of the base container in the pod.
     :param kubernetes_conn_id: The :ref:`kubernetes connection id 
<howto/connection:kubernetes>`
@@ -55,9 +58,10 @@ class KubernetesJobTrigger(BaseTrigger):
         self,
         job_name: str,
         job_namespace: str,
-        pod_name: str,
+        pod_names: list[str],
         pod_namespace: str,
         base_container_name: str,
+        pod_name: str | None = None,
         kubernetes_conn_id: str | None = None,
         poll_interval: float = 10.0,
         cluster_context: str | None = None,
@@ -69,7 +73,13 @@ class KubernetesJobTrigger(BaseTrigger):
         super().__init__()
         self.job_name = job_name
         self.job_namespace = job_namespace
-        self.pod_name = pod_name
+        if pod_name is not None:
+            self._pod_name = pod_name
+            self.pod_names = [
+                self.pod_name,
+            ]
+        else:
+            self.pod_names = pod_names
         self.pod_namespace = pod_namespace
         self.base_container_name = base_container_name
         self.kubernetes_conn_id = kubernetes_conn_id
@@ -80,6 +90,15 @@ class KubernetesJobTrigger(BaseTrigger):
         self.get_logs = get_logs
         self.do_xcom_push = do_xcom_push
 
+    @property
+    def pod_name(self):
+        warnings.warn(
+            "`pod_name` parameter is deprecated, please use `pod_names`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self._pod_name
+
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize KubernetesCreateJobTrigger arguments and classpath."""
         return (
@@ -87,7 +106,7 @@ class KubernetesJobTrigger(BaseTrigger):
             {
                 "job_name": self.job_name,
                 "job_namespace": self.job_namespace,
-                "pod_name": self.pod_name,
+                "pod_names": self.pod_names,
                 "pod_namespace": self.pod_namespace,
                 "base_container_name": self.base_container_name,
                 "kubernetes_conn_id": self.kubernetes_conn_id,
@@ -102,21 +121,23 @@ class KubernetesJobTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
-        if self.get_logs or self.do_xcom_push:
-            pod = await self.hook.get_pod(name=self.pod_name, 
namespace=self.pod_namespace)
         if self.do_xcom_push:
-            await self.hook.wait_until_container_complete(
-                name=self.pod_name, namespace=self.pod_namespace, 
container_name=self.base_container_name
-            )
-            self.log.info("Checking if xcom sidecar container is started.")
-            await self.hook.wait_until_container_started(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
-            )
-            self.log.info("Extracting result from xcom sidecar container.")
-            loop = asyncio.get_running_loop()
-            xcom_result = await loop.run_in_executor(None, 
self.pod_manager.extract_xcom, pod)
+            xcom_results = []
+            for pod_name in self.pod_names:
+                pod = await self.hook.get_pod(name=pod_name, 
namespace=self.pod_namespace)
+                await self.hook.wait_until_container_complete(
+                    name=pod_name, namespace=self.pod_namespace, 
container_name=self.base_container_name
+                )
+                self.log.info("Checking if xcom sidecar container is started.")
+                await self.hook.wait_until_container_started(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+                )
+                self.log.info("Extracting result from xcom sidecar container.")
+                loop = asyncio.get_running_loop()
+                xcom_result = await loop.run_in_executor(None, 
self.pod_manager.extract_xcom, pod)
+                xcom_results.append(xcom_result)
         job: V1Job = await 
self.hook.wait_until_job_complete(name=self.job_name, 
namespace=self.job_namespace)
         job_dict = job.to_dict()
         error_message = self.hook.is_job_failed(job=job)
@@ -124,14 +145,14 @@ class KubernetesJobTrigger(BaseTrigger):
             {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
-                "pod_name": pod.metadata.name if self.get_logs else None,
-                "pod_namespace": pod.metadata.namespace if self.get_logs else 
None,
+                "pod_names": [pod_name for pod_name in self.pod_names] if 
self.get_logs else None,
+                "pod_namespace": self.pod_namespace if self.get_logs else None,
                 "status": "error" if error_message else "success",
                 "message": f"Job failed with error: {error_message}"
                 if error_message
                 else "Job completed successfully",
                 "job": job_dict,
-                "xcom_result": xcom_result if self.do_xcom_push else None,
+                "xcom_result": xcom_results if self.do_xcom_push else None,
             }
         )
 
diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py 
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
index 130a87cbbe3..6caf69e0481 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
@@ -26,7 +26,7 @@ import pendulum
 import pytest
 from kubernetes.client import ApiClient, models as k8s
 
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException, 
AirflowProviderDeprecationWarning
 from airflow.models import DAG, DagModel, DagRun, TaskInstance
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.providers.cncf.kubernetes.operators.job import (
@@ -52,6 +52,7 @@ POD_NAME = "test-pod"
 POD_NAMESPACE = "test-namespace"
 TEST_XCOM_RESULT = '{"result": "test-xcom-result"}'
 POD_MANAGER_CLASS = 
"airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+ON_KILL_PROPAGATION_POLICY = "Foreground"
 
 
 def create_context(task, persist_to_db=False, map_index=None):
@@ -508,6 +509,45 @@ class TestKubernetesJobOperator:
         op = KubernetesJobOperator(
             task_id="test_task_id",
         )
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            execute_result = op.execute(context=context)
+
+        mock_build_job_request_obj.assert_called_once_with(context)
+        
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
+        mock_ti.xcom_push.assert_has_calls(
+            [
+                mock.call(key="job_name", 
value=mock_job_expected.metadata.name),
+                mock.call(key="job_namespace", 
value=mock_job_expected.metadata.namespace),
+                mock.call(key="job", 
value=mock_job_expected.to_dict.return_value),
+            ]
+        )
+
+        assert op.job_request_obj == mock_job_request_obj
+        assert op.job == mock_job_expected
+        assert not op.wait_until_job_complete
+        assert execute_result is None
+        assert not mock_hook.wait_until_job_complete.called
+
+    @pytest.mark.non_db_test_override
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
+    
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
+    @patch(HOOK_CLASS)
+    def test_execute_with_parallelism(
+        self, mock_hook, mock_create_job, mock_build_job_request_obj, 
mock_get_pods
+    ):
+        mock_hook.return_value.is_job_failed.return_value = False
+        mock_job_request_obj = mock_build_job_request_obj.return_value
+        mock_job_expected = mock_create_job.return_value
+        mock_get_pods.return_value = [mock.MagicMock(), mock.MagicMock()]
+        mock_pods_expected = mock_get_pods.return_value
+        mock_ti = mock.MagicMock()
+        context = dict(ti=mock_ti)
+
+        op = KubernetesJobOperator(
+            task_id="test_task_id",
+            parallelism=2,
+        )
         execute_result = op.execute(context=context)
 
         mock_build_job_request_obj.assert_called_once_with(context)
@@ -522,6 +562,9 @@ class TestKubernetesJobOperator:
 
         assert op.job_request_obj == mock_job_request_obj
         assert op.job == mock_job_expected
+        assert op.pods == mock_pods_expected
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            assert op.pod is mock_pods_expected[0]
         assert not op.wait_until_job_complete
         assert execute_result is None
         assert not mock_hook.wait_until_job_complete.called
@@ -551,7 +594,8 @@ class TestKubernetesJobOperator:
             wait_until_job_complete=True,
             deferrable=True,
         )
-        actual_result = op.execute(context=context)
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            actual_result = op.execute(context=context)
 
         mock_build_job_request_obj.assert_called_once_with(context)
         
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
@@ -583,8 +627,9 @@ class TestKubernetesJobOperator:
             wait_until_job_complete=True,
         )
 
-        with pytest.raises(AirflowException):
-            op.execute(context=dict(ti=mock.MagicMock()))
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            with pytest.raises(AirflowException):
+                op.execute(context=dict(ti=mock.MagicMock()))
 
     @pytest.mark.non_db_test_override
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer"))
@@ -616,6 +661,71 @@ class TestKubernetesJobOperator:
         )
         op.job = mock_job
         op.pod = mock_pod
+        op.pods = [
+            mock_pod,
+        ]
+
+        actual_result = op.execute_deferrable()
+
+        mock_execute_deferrable.assert_called_once_with(
+            trigger=mock_trigger_instance,
+            method_name="execute_complete",
+        )
+        mock_trigger.assert_called_once_with(
+            job_name=JOB_NAME,
+            job_namespace=JOB_NAMESPACE,
+            pod_names=[
+                POD_NAME,
+            ],
+            pod_namespace=POD_NAMESPACE,
+            base_container_name=op.BASE_CONTAINER_NAME,
+            kubernetes_conn_id=KUBERNETES_CONN_ID,
+            cluster_context=mock_cluster_context,
+            config_file=mock_config_file,
+            in_cluster=mock_in_cluster,
+            poll_interval=POLL_INTERVAL,
+            get_logs=True,
+            do_xcom_push=False,
+        )
+        assert actual_result is None
+
+    @pytest.mark.non_db_test_override
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobTrigger"))
+    def test_execute_deferrable_with_parallelism(self, mock_trigger, 
mock_execute_deferrable):
+        mock_cluster_context = mock.MagicMock()
+        mock_config_file = mock.MagicMock()
+        mock_in_cluster = mock.MagicMock()
+
+        mock_job = mock.MagicMock()
+        mock_job.metadata.name = JOB_NAME
+        mock_job.metadata.namespace = JOB_NAMESPACE
+
+        pod_name_1 = POD_NAME + "-1"
+        mock_pod_1 = mock.MagicMock()
+        mock_pod_1.metadata.name = pod_name_1
+        mock_pod_1.metadata.namespace = POD_NAMESPACE
+
+        pod_name_2 = POD_NAME + "-2"
+        mock_pod_2 = mock.MagicMock()
+        mock_pod_2.metadata.name = pod_name_2
+        mock_pod_2.metadata.namespace = POD_NAMESPACE
+
+        mock_trigger_instance = mock_trigger.return_value
+
+        op = KubernetesJobOperator(
+            task_id="test_task_id",
+            kubernetes_conn_id=KUBERNETES_CONN_ID,
+            cluster_context=mock_cluster_context,
+            config_file=mock_config_file,
+            in_cluster=mock_in_cluster,
+            job_poll_interval=POLL_INTERVAL,
+            parallelism=2,
+            wait_until_job_complete=True,
+            deferrable=True,
+        )
+        op.job = mock_job
+        op.pods = [mock_pod_1, mock_pod_2]
 
         actual_result = op.execute_deferrable()
 
@@ -626,7 +736,7 @@ class TestKubernetesJobOperator:
         mock_trigger.assert_called_once_with(
             job_name=JOB_NAME,
             job_namespace=JOB_NAMESPACE,
-            pod_name=POD_NAME,
+            pod_names=[pod_name_1, pod_name_2],
             pod_namespace=POD_NAMESPACE,
             base_container_name=op.BASE_CONTAINER_NAME,
             kubernetes_conn_id=KUBERNETES_CONN_ID,
@@ -656,7 +766,8 @@ class TestKubernetesJobOperator:
         op = KubernetesJobOperator(
             task_id="test_task_id", wait_until_job_complete=True, 
job_poll_interval=POLL_INTERVAL
         )
-        op.execute(context=dict(ti=mock_ti))
+        with pytest.warns(AirflowProviderDeprecationWarning):
+            op.execute(context=dict(ti=mock_ti))
 
         assert op.wait_until_job_complete
         assert op.job_poll_interval == POLL_INTERVAL
@@ -676,9 +787,17 @@ class TestKubernetesJobOperator:
         event = {
             "job": mock_job,
             "status": "success",
-            "pod_name": POD_NAME if get_logs else None,
+            "pod_names": [
+                POD_NAME,
+            ]
+            if get_logs
+            else None,
             "pod_namespace": POD_NAMESPACE if get_logs else None,
-            "xcom_result": TEST_XCOM_RESULT if do_xcom_push else None,
+            "xcom_result": [
+                TEST_XCOM_RESULT,
+            ]
+            if do_xcom_push
+            else None,
         }
 
         KubernetesJobOperator(
@@ -718,6 +837,7 @@ class TestKubernetesJobOperator:
         mock_client.delete_namespaced_job.assert_called_once_with(
             name=JOB_NAME,
             namespace=JOB_NAMESPACE,
+            propagation_policy=ON_KILL_PROPAGATION_POLICY,
         )
 
     @pytest.mark.non_db_test_override
@@ -737,6 +857,7 @@ class TestKubernetesJobOperator:
         mock_client.delete_namespaced_job.assert_called_once_with(
             name=JOB_NAME,
             namespace=JOB_NAMESPACE,
+            propagation_policy=ON_KILL_PROPAGATION_POLICY,
             grace_period_seconds=mock_termination_grace_period,
         )
 
@@ -752,9 +873,11 @@ class TestKubernetesJobOperator:
         mock_client.delete_namespaced_job.assert_not_called()
         mock_serialize.assert_not_called()
 
+    @pytest.mark.parametrize("parallelism", [None, 2])
     @pytest.mark.parametrize("do_xcom_push", [True, False])
     @pytest.mark.parametrize("get_logs", [True, False])
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom"))
+    @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
     
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
     
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
     @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@@ -771,10 +894,16 @@ class TestKubernetesJobOperator:
         mock_create_job,
         mock_build_job_request_obj,
         mock_get_or_create_pod,
+        mock_get_pods,
         mock_extract_xcom,
         get_logs,
         do_xcom_push,
+        parallelism,
     ):
+        if parallelism == 2:
+            mock_pod_1 = mock.MagicMock()
+            mock_pod_2 = mock.MagicMock()
+            mock_get_pods.return_value = [mock_pod_1, mock_pod_2]
         mock_ti = mock.MagicMock()
         op = KubernetesJobOperator(
             task_id="test_task_id",
@@ -782,16 +911,26 @@ class TestKubernetesJobOperator:
             job_poll_interval=POLL_INTERVAL,
             get_logs=get_logs,
             do_xcom_push=do_xcom_push,
+            parallelism=parallelism,
         )
-        op.execute(context=dict(ti=mock_ti))
 
-        if do_xcom_push:
+        if not parallelism:
+            with pytest.warns(AirflowProviderDeprecationWarning):
+                op.execute(context=dict(ti=mock_ti))
+        else:
+            op.execute(context=dict(ti=mock_ti))
+
+        if do_xcom_push and not parallelism:
             mock_extract_xcom.assert_called_once()
+        elif do_xcom_push and parallelism is not None:
+            assert mock_extract_xcom.call_count == parallelism
         else:
             mock_extract_xcom.assert_not_called()
 
-        if get_logs:
+        if get_logs and not parallelism:
             mocked_fetch_logs.assert_called_once()
+        elif get_logs and parallelism is not None:
+            assert mocked_fetch_logs.call_count == parallelism
         else:
             mocked_fetch_logs.assert_not_called()
 
diff --git 
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py 
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
index 7a85e090536..4f6f6597fcf 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
@@ -45,7 +45,9 @@ def trigger():
     return KubernetesJobTrigger(
         job_name=JOB_NAME,
         job_namespace=NAMESPACE,
-        pod_name=POD_NAME,
+        pod_names=[
+            POD_NAME,
+        ],
         pod_namespace=NAMESPACE,
         base_container_name=CONTAINER_NAME,
         kubernetes_conn_id=CONN_ID,
@@ -66,7 +68,9 @@ class TestKubernetesJobTrigger:
         assert kwargs_dict == {
             "job_name": JOB_NAME,
             "job_namespace": NAMESPACE,
-            "pod_name": POD_NAME,
+            "pod_names": [
+                POD_NAME,
+            ],
             "pod_namespace": NAMESPACE,
             "base_container_name": CONTAINER_NAME,
             "kubernetes_conn_id": CONN_ID,
@@ -105,7 +109,9 @@ class TestKubernetesJobTrigger:
             {
                 "name": JOB_NAME,
                 "namespace": NAMESPACE,
-                "pod_name": POD_NAME,
+                "pod_names": [
+                    POD_NAME,
+                ],
                 "pod_namespace": NAMESPACE,
                 "status": "success",
                 "message": "Job completed successfully",
@@ -141,7 +147,9 @@ class TestKubernetesJobTrigger:
             {
                 "name": JOB_NAME,
                 "namespace": NAMESPACE,
-                "pod_name": POD_NAME,
+                "pod_names": [
+                    POD_NAME,
+                ],
                 "pod_namespace": NAMESPACE,
                 "status": "error",
                 "message": "Job failed with error: Error",
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
index fc044568a4c..b2303374a52 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -789,8 +789,8 @@ class GKEStartJobOperator(GKEOperatorMixin, KubernetesJobOperator):
                 ssl_ca_cert=self.ssl_ca_cert,
                 job_name=self.job.metadata.name,
                 job_namespace=self.job.metadata.namespace,
-                pod_name=self.pod.metadata.name,
-                pod_namespace=self.pod.metadata.namespace,
+                pod_names=[pod.metadata.name for pod in self.pods],
+                pod_namespace=self.pods[0].metadata.namespace,
                 base_container_name=self.base_container_name,
                 gcp_conn_id=self.gcp_conn_id,
                 poll_interval=self.job_poll_interval,
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py b/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 08d30fcb3d6..3e458931652 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -260,9 +260,10 @@ class GKEJobTrigger(BaseTrigger):
         ssl_ca_cert: str,
         job_name: str,
         job_namespace: str,
-        pod_name: str,
+        pod_names: list[str],
         pod_namespace: str,
         base_container_name: str,
+        pod_name: str | None = None,
         gcp_conn_id: str = "google_cloud_default",
         poll_interval: float = 2,
         impersonation_chain: str | Sequence[str] | None = None,
@@ -274,7 +275,13 @@ class GKEJobTrigger(BaseTrigger):
         self.ssl_ca_cert = ssl_ca_cert
         self.job_name = job_name
         self.job_namespace = job_namespace
-        self.pod_name = pod_name
+        if pod_name is not None:
+            self._pod_name = pod_name
+            self.pod_names = [
+                self.pod_name,
+            ]
+        else:
+            self.pod_names = pod_names
         self.pod_namespace = pod_namespace
         self.base_container_name = base_container_name
         self.gcp_conn_id = gcp_conn_id
@@ -283,6 +290,15 @@ class GKEJobTrigger(BaseTrigger):
         self.get_logs = get_logs
         self.do_xcom_push = do_xcom_push
 
+    @property
+    def pod_name(self):
+        warnings.warn(
+            "`pod_name` parameter is deprecated, please use `pod_names`",
+            AirflowProviderDeprecationWarning,
+            stacklevel=2,
+        )
+        return self._pod_name
+
     def serialize(self) -> tuple[str, dict[str, Any]]:
         """Serialize KubernetesCreateJobTrigger arguments and classpath."""
         return (
@@ -292,7 +308,7 @@ class GKEJobTrigger(BaseTrigger):
                 "ssl_ca_cert": self.ssl_ca_cert,
                 "job_name": self.job_name,
                 "job_namespace": self.job_namespace,
-                "pod_name": self.pod_name,
+                "pod_names": self.pod_names,
                 "pod_namespace": self.pod_namespace,
                 "base_container_name": self.base_container_name,
                 "gcp_conn_id": self.gcp_conn_id,
@@ -305,8 +321,6 @@ class GKEJobTrigger(BaseTrigger):
 
     async def run(self) -> AsyncIterator[TriggerEvent]:
         """Get current job status and yield a TriggerEvent."""
-        if self.get_logs or self.do_xcom_push:
-            pod = await self.hook.get_pod(name=self.pod_name, 
namespace=self.pod_namespace)
         if self.do_xcom_push:
             kubernetes_provider = 
ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
             kubernetes_provider_name = kubernetes_provider.data["package-name"]
@@ -318,22 +332,26 @@ class GKEJobTrigger(BaseTrigger):
                     f"package 
{kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
                     f"support this feature. Please upgrade it to version 
higher than or equal to {min_version}."
                 )
-            await self.hook.wait_until_container_complete(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=self.base_container_name,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Checking if xcom sidecar container is started.")
-            await self.hook.wait_until_container_started(
-                name=self.pod_name,
-                namespace=self.pod_namespace,
-                container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
-                poll_interval=self.poll_interval,
-            )
-            self.log.info("Extracting result from xcom sidecar container.")
-            loop = asyncio.get_running_loop()
-            xcom_result = await loop.run_in_executor(None, 
self.pod_manager.extract_xcom, pod)
+            xcom_results = []
+            for pod_name in self.pod_names:
+                pod = await self.hook.get_pod(name=pod_name, 
namespace=self.pod_namespace)
+                await self.hook.wait_until_container_complete(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=self.base_container_name,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Checking if xcom sidecar container is started.")
+                await self.hook.wait_until_container_started(
+                    name=pod_name,
+                    namespace=self.pod_namespace,
+                    container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+                    poll_interval=self.poll_interval,
+                )
+                self.log.info("Extracting result from xcom sidecar container.")
+                loop = asyncio.get_running_loop()
+                xcom_result = await loop.run_in_executor(None, 
self.pod_manager.extract_xcom, pod)
+                xcom_results.append(xcom_result)
         job: V1Job = await self.hook.wait_until_job_complete(
             name=self.job_name, namespace=self.job_namespace, 
poll_interval=self.poll_interval
         )
@@ -345,12 +363,12 @@ class GKEJobTrigger(BaseTrigger):
             {
                 "name": job.metadata.name,
                 "namespace": job.metadata.namespace,
-                "pod_name": pod.metadata.name if self.get_logs else None,
-                "pod_namespace": pod.metadata.namespace if self.get_logs else 
None,
+                "pod_names": [pod_name for pod_name in self.pod_names] if 
self.get_logs else None,
+                "pod_namespace": self.pod_namespace if self.get_logs else None,
                 "status": status,
                 "message": message,
                 "job": job_dict,
-                "xcom_result": xcom_result if self.do_xcom_push else None,
+                "xcom_result": xcom_results if self.do_xcom_push else None,
             }
         )
 
diff --git a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
index 578a8e62b84..ccf1fbabeaf 100644
--- a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+++ b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
@@ -48,8 +48,13 @@ CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1}
 
 JOB_NAME = "test-pi"
 JOB_NAME_DEF = "test-pi-def"
+JOB_NAME_WITH_PARALLELISM = "test-pi-with-parallelism"
+JOB_NAME_DEF_WITH_PARALLELISM = "test-pi-def-with-parallelism"
 JOB_NAMESPACE = "default"
 
+PARALLELISM = 2
+COMPLETION_MODE = "Indexed"
+
 with DAG(
     DAG_ID,
     schedule="@once",  # Override to match your needs
@@ -92,6 +97,45 @@ with DAG(
     )
     # [END howto_operator_gke_start_job_def]
 
+    # [START howto_operator_gke_start_job_parallelism]
+    job_task_with_parallelism = GKEStartJobOperator(
+        task_id="job_task_with_parallelism",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        namespace=JOB_NAMESPACE,
+        image="perl:5.34.0",
+        cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+        name=JOB_NAME_WITH_PARALLELISM,
+        wait_until_job_complete=True,
+        parallelism=PARALLELISM,
+        completions=PARALLELISM,
+        completion_mode=COMPLETION_MODE,
+        get_logs=True,
+        do_xcom_push=True,
+    )
+    # [END howto_operator_gke_start_job_with_parallelism]
+
+    # [START howto_operator_gke_start_job_def_with_parallelism]
+    job_task_def_with_parallelism = GKEStartJobOperator(
+        task_id="job_task_def_with_parallelism",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        namespace=JOB_NAMESPACE,
+        image="perl:5.34.0",
+        cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+        name=JOB_NAME_DEF_WITH_PARALLELISM,
+        wait_until_job_complete=True,
+        deferrable=True,
+        parallelism=PARALLELISM,
+        completions=PARALLELISM,
+        completion_mode=COMPLETION_MODE,
+        get_logs=True,
+        do_xcom_push=True,
+    )
+    # [END howto_operator_gke_start_job_def_with_parallelism]
+
     # [START howto_operator_gke_list_jobs]
     list_job_task = GKEListJobsOperator(
         task_id="list_job_task", project_id=GCP_PROJECT_ID, 
location=GCP_LOCATION, cluster_name=CLUSTER_NAME
@@ -104,7 +148,7 @@ with DAG(
         project_id=GCP_PROJECT_ID,
         location=GCP_LOCATION,
         job_name=job_task.output["job_name"],
-        namespace="default",
+        namespace=JOB_NAMESPACE,
         cluster_name=CLUSTER_NAME,
     )
     # [END howto_operator_gke_describe_job]
@@ -114,7 +158,25 @@ with DAG(
         project_id=GCP_PROJECT_ID,
         location=GCP_LOCATION,
         job_name=job_task_def.output["job_name"],
-        namespace="default",
+        namespace=JOB_NAMESPACE,
+        cluster_name=CLUSTER_NAME,
+    )
+
+    describe_job_with_parallelism_task = GKEDescribeJobOperator(
+        task_id="describe_job_with_parallelism_task",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        job_name=job_task_with_parallelism.output["job_name"],
+        namespace=JOB_NAMESPACE,
+        cluster_name=CLUSTER_NAME,
+    )
+
+    describe_job_task_def_with_parallelism = GKEDescribeJobOperator(
+        task_id="describe_job_task_def_with_parallelism",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        job_name=job_task_def_with_parallelism.output["job_name"],
+        namespace=JOB_NAMESPACE,
         cluster_name=CLUSTER_NAME,
     )
 
@@ -125,7 +187,7 @@ with DAG(
         location=GCP_LOCATION,
         cluster_name=CLUSTER_NAME,
         name=job_task.output["job_name"],
-        namespace="default",
+        namespace=JOB_NAMESPACE,
     )
     # [END howto_operator_gke_suspend_job]
 
@@ -136,7 +198,7 @@ with DAG(
         location=GCP_LOCATION,
         cluster_name=CLUSTER_NAME,
         name=job_task.output["job_name"],
-        namespace="default",
+        namespace=JOB_NAMESPACE,
     )
     # [END howto_operator_gke_resume_job]
 
@@ -156,7 +218,25 @@ with DAG(
         project_id=GCP_PROJECT_ID,
         location=GCP_LOCATION,
         cluster_name=CLUSTER_NAME,
-        name=JOB_NAME,
+        name=JOB_NAME_DEF,
+        namespace=JOB_NAMESPACE,
+    )
+
+    delete_job_with_parallelism = GKEDeleteJobOperator(
+        task_id="delete_job_with_parallelism",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        name=JOB_NAME_WITH_PARALLELISM,
+        namespace=JOB_NAMESPACE,
+    )
+
+    delete_job_def_with_parallelism = GKEDeleteJobOperator(
+        task_id="delete_job_def_with_parallelism",
+        project_id=GCP_PROJECT_ID,
+        location=GCP_LOCATION,
+        cluster_name=CLUSTER_NAME,
+        name=JOB_NAME_DEF_WITH_PARALLELISM,
         namespace=JOB_NAMESPACE,
     )
 
@@ -170,12 +250,17 @@ with DAG(
 
     chain(
         create_cluster,
-        [job_task, job_task_def],
+        [job_task, job_task_def, job_task_with_parallelism, 
job_task_def_with_parallelism],
         list_job_task,
-        [describe_job_task, describe_job_task_def],
+        [
+            describe_job_task,
+            describe_job_task_def,
+            describe_job_with_parallelism_task,
+            describe_job_task_def_with_parallelism,
+        ],
         suspend_job,
         resume_job,
-        [delete_job, delete_job_def],
+        [delete_job, delete_job_def, delete_job_with_parallelism, 
delete_job_def_with_parallelism],
         delete_cluster,
     )
 
diff --git a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
index a8fafd6af5d..3c3775025e1 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
@@ -862,7 +862,9 @@ class TestGKEStartJobOperator:
         mock_pod_metadata = mock.MagicMock()
         mock_pod_metadata.name = K8S_POD_NAME
         mock_pod_metadata.namespace = K8S_NAMESPACE
-        self.operator.pod = mock.MagicMock(metadata=mock_pod_metadata)
+        self.operator.pods = [
+            mock.MagicMock(metadata=mock_pod_metadata),
+        ]
 
         mock_job_metadata = mock.MagicMock()
         mock_job_metadata.name = K8S_JOB_NAME
@@ -880,7 +882,69 @@ class TestGKEStartJobOperator:
             ssl_ca_cert=GKE_SSL_CA_CERT,
             job_name=K8S_JOB_NAME,
             job_namespace=K8S_NAMESPACE,
-            pod_name=K8S_POD_NAME,
+            pod_names=[
+                K8S_POD_NAME,
+            ],
+            pod_namespace=K8S_NAMESPACE,
+            base_container_name="base",
+            gcp_conn_id=TEST_CONN_ID,
+            poll_interval=10.0,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+            get_logs=mock_get_logs,
+            do_xcom_push=False,
+        )
+        mock_defer.assert_called_once_with(
+            trigger=mock_trigger.return_value,
+            method_name="execute_complete",
+        )
+
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEStartJobOperator.defer"))
+    
@mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEHook"))
+    @mock.patch(GKE_OPERATORS_PATH.format("GKEJobTrigger"))
+    def test_execute_deferrable_with_parallelism(
+        self, mock_trigger, mock_cluster_hook, mock_fetch_cluster_info, 
mock_defer
+    ):
+        op = GKEStartJobOperator(
+            project_id=TEST_PROJECT_ID,
+            location=TEST_LOCATION,
+            cluster_name=GKE_CLUSTER_NAME,
+            task_id=TEST_TASK_ID,
+            name=K8S_JOB_NAME,
+            namespace=K8S_NAMESPACE,
+            image=TEST_IMAGE,
+            gcp_conn_id=TEST_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+            parallelism=2,
+        )
+        mock_pod_name_1 = K8S_POD_NAME + "-1"
+        mock_pod_metadata_1 = mock.MagicMock()
+        mock_pod_metadata_1.name = mock_pod_name_1
+        mock_pod_metadata_1.namespace = K8S_NAMESPACE
+
+        mock_pod_name_2 = K8S_POD_NAME + "-2"
+        mock_pod_metadata_2 = mock.MagicMock()
+        mock_pod_metadata_2.name = mock_pod_name_2
+        mock_pod_metadata_2.namespace = K8S_NAMESPACE
+        op.pods = [mock.MagicMock(metadata=mock_pod_metadata_1), 
mock.MagicMock(metadata=mock_pod_metadata_2)]
+
+        mock_job_metadata = mock.MagicMock()
+        mock_job_metadata.name = K8S_JOB_NAME
+        mock_job_metadata.namespace = K8S_NAMESPACE
+        op.job = mock.MagicMock(metadata=mock_job_metadata)
+
+        mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT
+        mock_get_logs = mock.MagicMock()
+        op.get_logs = mock_get_logs
+
+        op.execute_deferrable()
+
+        mock_trigger.assert_called_once_with(
+            cluster_url=GKE_CLUSTER_URL,
+            ssl_ca_cert=GKE_SSL_CA_CERT,
+            job_name=K8S_JOB_NAME,
+            job_namespace=K8S_NAMESPACE,
+            pod_names=[mock_pod_name_1, mock_pod_name_2],
             pod_namespace=K8S_NAMESPACE,
             base_container_name="base",
             gcp_conn_id=TEST_CONN_ID,
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
index ffe749ef397..3704c5c3206 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
@@ -98,7 +98,9 @@ def job_trigger():
         ssl_ca_cert=SSL_CA_CERT,
         job_name=JOB_NAME,
         job_namespace=NAMESPACE,
-        pod_name=POD_NAME,
+        pod_names=[
+            POD_NAME,
+        ],
         pod_namespace=NAMESPACE,
         base_container_name=BASE_CONTAINER_NAME,
         gcp_conn_id=GCP_CONN_ID,
@@ -482,7 +484,9 @@ class TestGKEStartJobTrigger:
             "ssl_ca_cert": SSL_CA_CERT,
             "job_name": JOB_NAME,
             "job_namespace": NAMESPACE,
-            "pod_name": POD_NAME,
+            "pod_names": [
+                POD_NAME,
+            ],
             "pod_namespace": NAMESPACE,
             "base_container_name": BASE_CONTAINER_NAME,
             "gcp_conn_id": GCP_CONN_ID,
@@ -521,7 +525,9 @@ class TestGKEStartJobTrigger:
             {
                 "name": JOB_NAME,
                 "namespace": NAMESPACE,
-                "pod_name": POD_NAME,
+                "pod_names": [
+                    POD_NAME,
+                ],
                 "pod_namespace": NAMESPACE,
                 "status": "success",
                 "message": "Job completed successfully",
@@ -559,7 +565,9 @@ class TestGKEStartJobTrigger:
             {
                 "name": JOB_NAME,
                 "namespace": NAMESPACE,
-                "pod_name": POD_NAME,
+                "pod_names": [
+                    POD_NAME,
+                ],
                 "pod_namespace": NAMESPACE,
                 "status": "error",
                 "message": "Job failed with error: Error",

Reply via email to