This is an automated email from the ASF dual-hosted git repository.
shahar pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new e50ce94cdc4 Handle multiple pods to prevent
```KubernetesJobOperator``` from failing with parallelism option (#49899)
e50ce94cdc4 is described below
commit e50ce94cdc427b3f4930d59e5a7390e5eb2bfab5
Author: Nitochkin <[email protected]>
AuthorDate: Thu Jul 24 11:56:03 2025 +0200
Handle multiple pods to prevent ```KubernetesJobOperator``` from failing with
parallelism option (#49899)
Co-authored-by: Anton Nitochkin <[email protected]>
---
.../providers/cncf/kubernetes/operators/job.py | 132 ++++++++++++-----
.../providers/cncf/kubernetes/triggers/job.py | 63 +++++---
.../unit/cncf/kubernetes/operators/test_job.py | 161 +++++++++++++++++++--
.../unit/cncf/kubernetes/triggers/test_job.py | 16 +-
.../google/cloud/operators/kubernetes_engine.py | 4 +-
.../google/cloud/triggers/kubernetes_engine.py | 66 ++++++---
.../example_kubernetes_engine_job.py | 101 ++++++++++++-
.../cloud/operators/test_kubernetes_engine.py | 68 ++++++++-
.../cloud/triggers/test_kubernetes_engine.py | 16 +-
9 files changed, 517 insertions(+), 110 deletions(-)
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
index 7c19421594a..79e9b835812 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/job.py
@@ -22,16 +22,17 @@ import copy
import json
import logging
import os
+import warnings
from collections.abc import Sequence
from functools import cached_property
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal
from kubernetes.client import BatchV1Api, models as k8s
from kubernetes.client.api_client import ApiClient
from kubernetes.client.rest import ApiException
from airflow.configuration import conf
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import KubernetesHook
from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import (
add_unique_suffix,
@@ -81,6 +82,17 @@ class KubernetesJobOperator(KubernetesPodOperator):
Used if the parameter `wait_until_job_complete` set True.
:param deferrable: Run operator in the deferrable mode. Note that the
parameter
`wait_until_job_complete` must be set True.
+ :param on_kill_propagation_policy: Whether and how garbage collection will
be performed. Default is 'Foreground'.
+ Acceptable values are:
+ 'Orphan' - orphan the dependents;
+ 'Background' - allow the garbage collector to delete the dependents in
the background;
+ 'Foreground' - a cascading policy that deletes all dependents in the
foreground.
+ Default value is 'Foreground'.
:param discover_pods_retry_number: Number of times list_namespaced_pod will
be called to discover
already-running pods.
+ :param unwrap_single: Unwrap single result from the pod. For example, when
set to `True` - if the XCom
+ result should be `['res']`, the final result would be `'res'`. Default
is True to support backward
+ compatibility.
"""
template_fields: Sequence[str] = tuple({"job_template_file"} |
set(KubernetesPodOperator.template_fields))
@@ -101,8 +113,12 @@ class KubernetesJobOperator(KubernetesPodOperator):
wait_until_job_complete: bool = False,
job_poll_interval: float = 10,
deferrable: bool = conf.getboolean("operators", "default_deferrable",
fallback=False),
+ on_kill_propagation_policy: Literal["Foreground", "Background",
"Orphan"] = "Foreground",
+ discover_pods_retry_number: int = 3,
+ unwrap_single: bool = True,
**kwargs,
) -> None:
+ self._pod = None
super().__init__(**kwargs)
self.job_template_file = job_template_file
self.full_job_spec = full_job_spec
@@ -119,6 +135,22 @@ class KubernetesJobOperator(KubernetesPodOperator):
self.wait_until_job_complete = wait_until_job_complete
self.job_poll_interval = job_poll_interval
self.deferrable = deferrable
+ self.on_kill_propagation_policy = on_kill_propagation_policy
+ self.discover_pods_retry_number = discover_pods_retry_number
+ self.unwrap_single = unwrap_single
+
+ @property
+ def pod(self):
+ warnings.warn(
+ "`pod` parameter is deprecated, please use `pods`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self.pods[0] if self.pods else None
+
+ @pod.setter
+ def pod(self, value):
+ self._pod = value
@cached_property
def _incluster_namespace(self):
@@ -167,12 +199,16 @@ class KubernetesJobOperator(KubernetesPodOperator):
ti.xcom_push(key="job_name", value=self.job.metadata.name)
ti.xcom_push(key="job_namespace", value=self.job.metadata.namespace)
- self.pod: k8s.V1Pod | None
- if self.pod is None:
- self.pod = self.get_or_create_pod( # must set `self.pod` for
`on_kill`
- pod_request_obj=self.pod_request_obj,
- context=context,
- )
+ self.pods: Sequence[k8s.V1Pod] | None = None
+ if self.parallelism is None and self.pod is None:
+ self.pods = [
+ self.get_or_create_pod(
+ pod_request_obj=self.pod_request_obj,
+ context=context,
+ )
+ ]
+ else:
+ self.pods = self.get_pods(pod_request_obj=self.pod_request_obj,
context=context)
if self.wait_until_job_complete and self.deferrable:
self.execute_deferrable()
@@ -180,22 +216,25 @@ class KubernetesJobOperator(KubernetesPodOperator):
if self.wait_until_job_complete:
if self.do_xcom_push:
- self.pod_manager.await_container_completion(
- pod=self.pod, container_name=self.base_container_name
- )
-
self.pod_manager.await_xcom_sidecar_container_start(pod=self.pod)
- xcom_result = self.extract_xcom(pod=self.pod)
+ xcom_result = []
+ for pod in self.pods:
+ self.pod_manager.await_container_completion(
+ pod=pod, container_name=self.base_container_name
+ )
+
self.pod_manager.await_xcom_sidecar_container_start(pod=pod)
+ xcom_result.append(self.extract_xcom(pod=pod))
self.job = self.hook.wait_until_job_complete(
job_name=self.job.metadata.name,
namespace=self.job.metadata.namespace,
job_poll_interval=self.job_poll_interval,
)
if self.get_logs:
- self.pod_manager.fetch_requested_container_logs(
- pod=self.pod,
- containers=self.container_logs,
- follow_logs=True,
- )
+ for pod in self.pods:
+ self.pod_manager.fetch_requested_container_logs(
+ pod=pod,
+ containers=self.container_logs,
+ follow_logs=True,
+ )
ti.xcom_push(key="job", value=self.job.to_dict())
if self.wait_until_job_complete:
@@ -211,8 +250,8 @@ class KubernetesJobOperator(KubernetesPodOperator):
trigger=KubernetesJobTrigger(
job_name=self.job.metadata.name,
job_namespace=self.job.metadata.namespace,
- pod_name=self.pod.metadata.name,
- pod_namespace=self.pod.metadata.namespace,
+ pod_names=[pod.metadata.name for pod in self.pods],
+ pod_namespace=self.pods[0].metadata.namespace,
base_container_name=self.base_container_name,
kubernetes_conn_id=self.kubernetes_conn_id,
cluster_context=self.cluster_context,
@@ -232,20 +271,23 @@ class KubernetesJobOperator(KubernetesPodOperator):
raise AirflowException(event["message"])
if self.get_logs:
- pod_name = event["pod_name"]
- pod_namespace = event["pod_namespace"]
- self.pod = self.hook.get_pod(pod_name, pod_namespace)
- if not self.pod:
- raise PodNotFoundException("Could not find pod after resuming
from deferral")
- self._write_logs(self.pod)
+ for pod_name in event["pod_names"]:
+ pod_namespace = event["pod_namespace"]
+ pod = self.hook.get_pod(pod_name, pod_namespace)
+ if not pod:
+ raise PodNotFoundException("Could not find pod after
resuming from deferral")
+ self._write_logs(pod)
if self.do_xcom_push:
- xcom_result = event["xcom_result"]
- if isinstance(xcom_result, str) and xcom_result.rstrip() ==
EMPTY_XCOM_RESULT:
- self.log.info("xcom result file is empty.")
- return None
- self.log.info("xcom result: \n%s", xcom_result)
- return json.loads(xcom_result)
+ xcom_results: list[Any | None] = []
+ for xcom_result in event["xcom_result"]:
+ if isinstance(xcom_result, str) and xcom_result.rstrip() ==
EMPTY_XCOM_RESULT:
+ self.log.info("xcom result file is empty.")
+ xcom_results.append(None)
+ continue
+ self.log.info("xcom result: \n%s", xcom_result)
+ xcom_results.append(json.loads(xcom_result))
+ return xcom_results[0] if self.unwrap_single and len(xcom_results)
== 1 else xcom_results
@staticmethod
def deserialize_job_template_file(path: str) -> k8s.V1Job:
@@ -275,12 +317,11 @@ class KubernetesJobOperator(KubernetesPodOperator):
kwargs = {
"name": job.metadata.name,
"namespace": job.metadata.namespace,
+ "propagation_policy": self.on_kill_propagation_policy,
}
if self.termination_grace_period is not None:
kwargs.update(grace_period_seconds=self.termination_grace_period)
self.job_client.delete_namespaced_job(**kwargs)
- if self.pod:
- super().on_kill()
def build_job_request_obj(self, context: Context | None = None) ->
k8s.V1Job:
"""
@@ -400,6 +441,29 @@ class KubernetesJobOperator(KubernetesPodOperator):
return None
+ def get_pods(
+ self, pod_request_obj: k8s.V1Pod, context: Context, *,
exclude_checked: bool = True
+ ) -> Sequence[k8s.V1Pod]:
+ """Return already-running pods if they exist."""
+ label_selector = self._build_find_pod_label_selector(context,
exclude_checked=exclude_checked)
+ pod_list: Sequence[k8s.V1Pod] = []
+ retry_number: int = 0
+
+ while len(pod_list) != self.parallelism or retry_number <=
self.discover_pods_retry_number:
+ pod_list = self.client.list_namespaced_pod(
+ namespace=pod_request_obj.metadata.namespace,
+ label_selector=label_selector,
+ ).items
+ retry_number += 1
+
+ if len(pod_list) == 0:
+ raise AirflowException(f"No pods running with labels
{label_selector}")
+
+ for pod_instance in pod_list:
+ self.log_matching_pod(pod=pod_instance, context=context)
+
+ return pod_list
+
class KubernetesDeleteJobOperator(BaseOperator):
"""
diff --git
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
index 359c3054788..b60373c2d53 100644
---
a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
+++
b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/job.py
@@ -17,10 +17,12 @@
from __future__ import annotations
import asyncio
+import warnings
from collections.abc import AsyncIterator
from functools import cached_property
from typing import TYPE_CHECKING, Any
+from airflow.exceptions import AirflowProviderDeprecationWarning
from airflow.providers.cncf.kubernetes.hooks.kubernetes import
AsyncKubernetesHook, KubernetesHook
from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
from airflow.providers.cncf.kubernetes.utils.xcom_sidecar import PodDefaults
@@ -36,7 +38,8 @@ class KubernetesJobTrigger(BaseTrigger):
:param job_name: The name of the job.
:param job_namespace: The namespace of the job.
- :param pod_name: The name of the Pod.
+ :param pod_name: The name of the Pod. Parameter is deprecated, please use
pod_names instead.
+ :param pod_names: The names of the Pods.
:param pod_namespace: The namespace of the Pod.
:param base_container_name: The name of the base container in the pod.
:param kubernetes_conn_id: The :ref:`kubernetes connection id
<howto/connection:kubernetes>`
@@ -55,9 +58,10 @@ class KubernetesJobTrigger(BaseTrigger):
self,
job_name: str,
job_namespace: str,
- pod_name: str,
+ pod_names: list[str],
pod_namespace: str,
base_container_name: str,
+ pod_name: str | None = None,
kubernetes_conn_id: str | None = None,
poll_interval: float = 10.0,
cluster_context: str | None = None,
@@ -69,7 +73,13 @@ class KubernetesJobTrigger(BaseTrigger):
super().__init__()
self.job_name = job_name
self.job_namespace = job_namespace
- self.pod_name = pod_name
+ if pod_name is not None:
+ self._pod_name = pod_name
+ self.pod_names = [
+ self.pod_name,
+ ]
+ else:
+ self.pod_names = pod_names
self.pod_namespace = pod_namespace
self.base_container_name = base_container_name
self.kubernetes_conn_id = kubernetes_conn_id
@@ -80,6 +90,15 @@ class KubernetesJobTrigger(BaseTrigger):
self.get_logs = get_logs
self.do_xcom_push = do_xcom_push
+ @property
+ def pod_name(self):
+ warnings.warn(
+ "`pod_name` parameter is deprecated, please use `pod_names`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self._pod_name
+
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
return (
@@ -87,7 +106,7 @@ class KubernetesJobTrigger(BaseTrigger):
{
"job_name": self.job_name,
"job_namespace": self.job_namespace,
- "pod_name": self.pod_name,
+ "pod_names": self.pod_names,
"pod_namespace": self.pod_namespace,
"base_container_name": self.base_container_name,
"kubernetes_conn_id": self.kubernetes_conn_id,
@@ -102,21 +121,23 @@ class KubernetesJobTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Get current job status and yield a TriggerEvent."""
- if self.get_logs or self.do_xcom_push:
- pod = await self.hook.get_pod(name=self.pod_name,
namespace=self.pod_namespace)
if self.do_xcom_push:
- await self.hook.wait_until_container_complete(
- name=self.pod_name, namespace=self.pod_namespace,
container_name=self.base_container_name
- )
- self.log.info("Checking if xcom sidecar container is started.")
- await self.hook.wait_until_container_started(
- name=self.pod_name,
- namespace=self.pod_namespace,
- container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
- )
- self.log.info("Extracting result from xcom sidecar container.")
- loop = asyncio.get_running_loop()
- xcom_result = await loop.run_in_executor(None,
self.pod_manager.extract_xcom, pod)
+ xcom_results = []
+ for pod_name in self.pod_names:
+ pod = await self.hook.get_pod(name=pod_name,
namespace=self.pod_namespace)
+ await self.hook.wait_until_container_complete(
+ name=pod_name, namespace=self.pod_namespace,
container_name=self.base_container_name
+ )
+ self.log.info("Checking if xcom sidecar container is started.")
+ await self.hook.wait_until_container_started(
+ name=pod_name,
+ namespace=self.pod_namespace,
+ container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+ )
+ self.log.info("Extracting result from xcom sidecar container.")
+ loop = asyncio.get_running_loop()
+ xcom_result = await loop.run_in_executor(None,
self.pod_manager.extract_xcom, pod)
+ xcom_results.append(xcom_result)
job: V1Job = await
self.hook.wait_until_job_complete(name=self.job_name,
namespace=self.job_namespace)
job_dict = job.to_dict()
error_message = self.hook.is_job_failed(job=job)
@@ -124,14 +145,14 @@ class KubernetesJobTrigger(BaseTrigger):
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
- "pod_name": pod.metadata.name if self.get_logs else None,
- "pod_namespace": pod.metadata.namespace if self.get_logs else
None,
+ "pod_names": [pod_name for pod_name in self.pod_names] if
self.get_logs else None,
+ "pod_namespace": self.pod_namespace if self.get_logs else None,
"status": "error" if error_message else "success",
"message": f"Job failed with error: {error_message}"
if error_message
else "Job completed successfully",
"job": job_dict,
- "xcom_result": xcom_result if self.do_xcom_push else None,
+ "xcom_result": xcom_results if self.do_xcom_push else None,
}
)
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
index 130a87cbbe3..6caf69e0481 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_job.py
@@ -26,7 +26,7 @@ import pendulum
import pytest
from kubernetes.client import ApiClient, models as k8s
-from airflow.exceptions import AirflowException
+from airflow.exceptions import AirflowException,
AirflowProviderDeprecationWarning
from airflow.models import DAG, DagModel, DagRun, TaskInstance
from airflow.models.serialized_dag import SerializedDagModel
from airflow.providers.cncf.kubernetes.operators.job import (
@@ -52,6 +52,7 @@ POD_NAME = "test-pod"
POD_NAMESPACE = "test-namespace"
TEST_XCOM_RESULT = '{"result": "test-xcom-result"}'
POD_MANAGER_CLASS =
"airflow.providers.cncf.kubernetes.utils.pod_manager.PodManager"
+ON_KILL_PROPAGATION_POLICY = "Foreground"
def create_context(task, persist_to_db=False, map_index=None):
@@ -508,6 +509,45 @@ class TestKubernetesJobOperator:
op = KubernetesJobOperator(
task_id="test_task_id",
)
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ execute_result = op.execute(context=context)
+
+ mock_build_job_request_obj.assert_called_once_with(context)
+
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
+ mock_ti.xcom_push.assert_has_calls(
+ [
+ mock.call(key="job_name",
value=mock_job_expected.metadata.name),
+ mock.call(key="job_namespace",
value=mock_job_expected.metadata.namespace),
+ mock.call(key="job",
value=mock_job_expected.to_dict.return_value),
+ ]
+ )
+
+ assert op.job_request_obj == mock_job_request_obj
+ assert op.job == mock_job_expected
+ assert not op.wait_until_job_complete
+ assert execute_result is None
+ assert not mock_hook.wait_until_job_complete.called
+
+ @pytest.mark.non_db_test_override
+ @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
+
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
+ @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
+ @patch(HOOK_CLASS)
+ def test_execute_with_parallelism(
+ self, mock_hook, mock_create_job, mock_build_job_request_obj,
mock_get_pods
+ ):
+ mock_hook.return_value.is_job_failed.return_value = False
+ mock_job_request_obj = mock_build_job_request_obj.return_value
+ mock_job_expected = mock_create_job.return_value
+ mock_get_pods.return_value = [mock.MagicMock(), mock.MagicMock()]
+ mock_pods_expected = mock_get_pods.return_value
+ mock_ti = mock.MagicMock()
+ context = dict(ti=mock_ti)
+
+ op = KubernetesJobOperator(
+ task_id="test_task_id",
+ parallelism=2,
+ )
execute_result = op.execute(context=context)
mock_build_job_request_obj.assert_called_once_with(context)
@@ -522,6 +562,9 @@ class TestKubernetesJobOperator:
assert op.job_request_obj == mock_job_request_obj
assert op.job == mock_job_expected
+ assert op.pods == mock_pods_expected
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ assert op.pod is mock_pods_expected[0]
assert not op.wait_until_job_complete
assert execute_result is None
assert not mock_hook.wait_until_job_complete.called
@@ -551,7 +594,8 @@ class TestKubernetesJobOperator:
wait_until_job_complete=True,
deferrable=True,
)
- actual_result = op.execute(context=context)
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ actual_result = op.execute(context=context)
mock_build_job_request_obj.assert_called_once_with(context)
mock_create_job.assert_called_once_with(job_request_obj=mock_job_request_obj)
@@ -583,8 +627,9 @@ class TestKubernetesJobOperator:
wait_until_job_complete=True,
)
- with pytest.raises(AirflowException):
- op.execute(context=dict(ti=mock.MagicMock()))
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ with pytest.raises(AirflowException):
+ op.execute(context=dict(ti=mock.MagicMock()))
@pytest.mark.non_db_test_override
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer"))
@@ -616,6 +661,71 @@ class TestKubernetesJobOperator:
)
op.job = mock_job
op.pod = mock_pod
+ op.pods = [
+ mock_pod,
+ ]
+
+ actual_result = op.execute_deferrable()
+
+ mock_execute_deferrable.assert_called_once_with(
+ trigger=mock_trigger_instance,
+ method_name="execute_complete",
+ )
+ mock_trigger.assert_called_once_with(
+ job_name=JOB_NAME,
+ job_namespace=JOB_NAMESPACE,
+ pod_names=[
+ POD_NAME,
+ ],
+ pod_namespace=POD_NAMESPACE,
+ base_container_name=op.BASE_CONTAINER_NAME,
+ kubernetes_conn_id=KUBERNETES_CONN_ID,
+ cluster_context=mock_cluster_context,
+ config_file=mock_config_file,
+ in_cluster=mock_in_cluster,
+ poll_interval=POLL_INTERVAL,
+ get_logs=True,
+ do_xcom_push=False,
+ )
+ assert actual_result is None
+
+ @pytest.mark.non_db_test_override
+ @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.defer"))
+ @patch(JOB_OPERATORS_PATH.format("KubernetesJobTrigger"))
+ def test_execute_deferrable_with_parallelism(self, mock_trigger,
mock_execute_deferrable):
+ mock_cluster_context = mock.MagicMock()
+ mock_config_file = mock.MagicMock()
+ mock_in_cluster = mock.MagicMock()
+
+ mock_job = mock.MagicMock()
+ mock_job.metadata.name = JOB_NAME
+ mock_job.metadata.namespace = JOB_NAMESPACE
+
+ pod_name_1 = POD_NAME + "-1"
+ mock_pod_1 = mock.MagicMock()
+ mock_pod_1.metadata.name = pod_name_1
+ mock_pod_1.metadata.namespace = POD_NAMESPACE
+
+ pod_name_2 = POD_NAME + "-2"
+ mock_pod_2 = mock.MagicMock()
+ mock_pod_2.metadata.name = pod_name_2
+ mock_pod_2.metadata.namespace = POD_NAMESPACE
+
+ mock_trigger_instance = mock_trigger.return_value
+
+ op = KubernetesJobOperator(
+ task_id="test_task_id",
+ kubernetes_conn_id=KUBERNETES_CONN_ID,
+ cluster_context=mock_cluster_context,
+ config_file=mock_config_file,
+ in_cluster=mock_in_cluster,
+ job_poll_interval=POLL_INTERVAL,
+ parallelism=2,
+ wait_until_job_complete=True,
+ deferrable=True,
+ )
+ op.job = mock_job
+ op.pods = [mock_pod_1, mock_pod_2]
actual_result = op.execute_deferrable()
@@ -626,7 +736,7 @@ class TestKubernetesJobOperator:
mock_trigger.assert_called_once_with(
job_name=JOB_NAME,
job_namespace=JOB_NAMESPACE,
- pod_name=POD_NAME,
+ pod_names=[pod_name_1, pod_name_2],
pod_namespace=POD_NAMESPACE,
base_container_name=op.BASE_CONTAINER_NAME,
kubernetes_conn_id=KUBERNETES_CONN_ID,
@@ -656,7 +766,8 @@ class TestKubernetesJobOperator:
op = KubernetesJobOperator(
task_id="test_task_id", wait_until_job_complete=True,
job_poll_interval=POLL_INTERVAL
)
- op.execute(context=dict(ti=mock_ti))
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ op.execute(context=dict(ti=mock_ti))
assert op.wait_until_job_complete
assert op.job_poll_interval == POLL_INTERVAL
@@ -676,9 +787,17 @@ class TestKubernetesJobOperator:
event = {
"job": mock_job,
"status": "success",
- "pod_name": POD_NAME if get_logs else None,
+ "pod_names": [
+ POD_NAME,
+ ]
+ if get_logs
+ else None,
"pod_namespace": POD_NAMESPACE if get_logs else None,
- "xcom_result": TEST_XCOM_RESULT if do_xcom_push else None,
+ "xcom_result": [
+ TEST_XCOM_RESULT,
+ ]
+ if do_xcom_push
+ else None,
}
KubernetesJobOperator(
@@ -718,6 +837,7 @@ class TestKubernetesJobOperator:
mock_client.delete_namespaced_job.assert_called_once_with(
name=JOB_NAME,
namespace=JOB_NAMESPACE,
+ propagation_policy=ON_KILL_PROPAGATION_POLICY,
)
@pytest.mark.non_db_test_override
@@ -737,6 +857,7 @@ class TestKubernetesJobOperator:
mock_client.delete_namespaced_job.assert_called_once_with(
name=JOB_NAME,
namespace=JOB_NAMESPACE,
+ propagation_policy=ON_KILL_PROPAGATION_POLICY,
grace_period_seconds=mock_termination_grace_period,
)
@@ -752,9 +873,11 @@ class TestKubernetesJobOperator:
mock_client.delete_namespaced_job.assert_not_called()
mock_serialize.assert_not_called()
+ @pytest.mark.parametrize("parallelism", [None, 2])
@pytest.mark.parametrize("do_xcom_push", [True, False])
@pytest.mark.parametrize("get_logs", [True, False])
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.extract_xcom"))
+ @patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_pods"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.get_or_create_pod"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.build_job_request_obj"))
@patch(JOB_OPERATORS_PATH.format("KubernetesJobOperator.create_job"))
@@ -771,10 +894,16 @@ class TestKubernetesJobOperator:
mock_create_job,
mock_build_job_request_obj,
mock_get_or_create_pod,
+ mock_get_pods,
mock_extract_xcom,
get_logs,
do_xcom_push,
+ parallelism,
):
+ if parallelism == 2:
+ mock_pod_1 = mock.MagicMock()
+ mock_pod_2 = mock.MagicMock()
+ mock_get_pods.return_value = [mock_pod_1, mock_pod_2]
mock_ti = mock.MagicMock()
op = KubernetesJobOperator(
task_id="test_task_id",
@@ -782,16 +911,26 @@ class TestKubernetesJobOperator:
job_poll_interval=POLL_INTERVAL,
get_logs=get_logs,
do_xcom_push=do_xcom_push,
+ parallelism=parallelism,
)
- op.execute(context=dict(ti=mock_ti))
- if do_xcom_push:
+ if not parallelism:
+ with pytest.warns(AirflowProviderDeprecationWarning):
+ op.execute(context=dict(ti=mock_ti))
+ else:
+ op.execute(context=dict(ti=mock_ti))
+
+ if do_xcom_push and not parallelism:
mock_extract_xcom.assert_called_once()
+ elif do_xcom_push and parallelism is not None:
+ assert mock_extract_xcom.call_count == parallelism
else:
mock_extract_xcom.assert_not_called()
- if get_logs:
+ if get_logs and not parallelism:
mocked_fetch_logs.assert_called_once()
+ elif get_logs and parallelism is not None:
+ assert mocked_fetch_logs.call_count == parallelism
else:
mocked_fetch_logs.assert_not_called()
diff --git
a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
index 7a85e090536..4f6f6597fcf 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_job.py
@@ -45,7 +45,9 @@ def trigger():
return KubernetesJobTrigger(
job_name=JOB_NAME,
job_namespace=NAMESPACE,
- pod_name=POD_NAME,
+ pod_names=[
+ POD_NAME,
+ ],
pod_namespace=NAMESPACE,
base_container_name=CONTAINER_NAME,
kubernetes_conn_id=CONN_ID,
@@ -66,7 +68,9 @@ class TestKubernetesJobTrigger:
assert kwargs_dict == {
"job_name": JOB_NAME,
"job_namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"base_container_name": CONTAINER_NAME,
"kubernetes_conn_id": CONN_ID,
@@ -105,7 +109,9 @@ class TestKubernetesJobTrigger:
{
"name": JOB_NAME,
"namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"status": "success",
"message": "Job completed successfully",
@@ -141,7 +147,9 @@ class TestKubernetesJobTrigger:
{
"name": JOB_NAME,
"namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"status": "error",
"message": "Job failed with error: Error",
diff --git
a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
index fc044568a4c..b2303374a52 100644
---
a/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
+++
b/providers/google/src/airflow/providers/google/cloud/operators/kubernetes_engine.py
@@ -789,8 +789,8 @@ class GKEStartJobOperator(GKEOperatorMixin,
KubernetesJobOperator):
ssl_ca_cert=self.ssl_ca_cert,
job_name=self.job.metadata.name,
job_namespace=self.job.metadata.namespace,
- pod_name=self.pod.metadata.name,
- pod_namespace=self.pod.metadata.namespace,
+ pod_names=[pod.metadata.name for pod in self.pods],
+ pod_namespace=self.pods[0].metadata.namespace,
base_container_name=self.base_container_name,
gcp_conn_id=self.gcp_conn_id,
poll_interval=self.job_poll_interval,
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
b/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
index 08d30fcb3d6..3e458931652 100644
---
a/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
+++
b/providers/google/src/airflow/providers/google/cloud/triggers/kubernetes_engine.py
@@ -260,9 +260,10 @@ class GKEJobTrigger(BaseTrigger):
ssl_ca_cert: str,
job_name: str,
job_namespace: str,
- pod_name: str,
+ pod_names: list[str],
pod_namespace: str,
base_container_name: str,
+ pod_name: str | None = None,
gcp_conn_id: str = "google_cloud_default",
poll_interval: float = 2,
impersonation_chain: str | Sequence[str] | None = None,
@@ -274,7 +275,13 @@ class GKEJobTrigger(BaseTrigger):
self.ssl_ca_cert = ssl_ca_cert
self.job_name = job_name
self.job_namespace = job_namespace
- self.pod_name = pod_name
+ if pod_name is not None:
+ self._pod_name = pod_name
+ self.pod_names = [
+ self.pod_name,
+ ]
+ else:
+ self.pod_names = pod_names
self.pod_namespace = pod_namespace
self.base_container_name = base_container_name
self.gcp_conn_id = gcp_conn_id
@@ -283,6 +290,15 @@ class GKEJobTrigger(BaseTrigger):
self.get_logs = get_logs
self.do_xcom_push = do_xcom_push
+ @property
+ def pod_name(self):
+ warnings.warn(
+ "`pod_name` parameter is deprecated, please use `pod_names`",
+ AirflowProviderDeprecationWarning,
+ stacklevel=2,
+ )
+ return self._pod_name
+
def serialize(self) -> tuple[str, dict[str, Any]]:
"""Serialize KubernetesCreateJobTrigger arguments and classpath."""
return (
@@ -292,7 +308,7 @@ class GKEJobTrigger(BaseTrigger):
"ssl_ca_cert": self.ssl_ca_cert,
"job_name": self.job_name,
"job_namespace": self.job_namespace,
- "pod_name": self.pod_name,
+ "pod_names": self.pod_names,
"pod_namespace": self.pod_namespace,
"base_container_name": self.base_container_name,
"gcp_conn_id": self.gcp_conn_id,
@@ -305,8 +321,6 @@ class GKEJobTrigger(BaseTrigger):
async def run(self) -> AsyncIterator[TriggerEvent]:
"""Get current job status and yield a TriggerEvent."""
- if self.get_logs or self.do_xcom_push:
- pod = await self.hook.get_pod(name=self.pod_name,
namespace=self.pod_namespace)
if self.do_xcom_push:
kubernetes_provider =
ProvidersManager().providers["apache-airflow-providers-cncf-kubernetes"]
kubernetes_provider_name = kubernetes_provider.data["package-name"]
@@ -318,22 +332,26 @@ class GKEJobTrigger(BaseTrigger):
f"package
{kubernetes_provider_name}=={kubernetes_provider_version} which doesn't "
f"support this feature. Please upgrade it to version
higher than or equal to {min_version}."
)
- await self.hook.wait_until_container_complete(
- name=self.pod_name,
- namespace=self.pod_namespace,
- container_name=self.base_container_name,
- poll_interval=self.poll_interval,
- )
- self.log.info("Checking if xcom sidecar container is started.")
- await self.hook.wait_until_container_started(
- name=self.pod_name,
- namespace=self.pod_namespace,
- container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
- poll_interval=self.poll_interval,
- )
- self.log.info("Extracting result from xcom sidecar container.")
- loop = asyncio.get_running_loop()
- xcom_result = await loop.run_in_executor(None,
self.pod_manager.extract_xcom, pod)
+ xcom_results = []
+ for pod_name in self.pod_names:
+ pod = await self.hook.get_pod(name=pod_name,
namespace=self.pod_namespace)
+ await self.hook.wait_until_container_complete(
+ name=pod_name,
+ namespace=self.pod_namespace,
+ container_name=self.base_container_name,
+ poll_interval=self.poll_interval,
+ )
+ self.log.info("Checking if xcom sidecar container is started.")
+ await self.hook.wait_until_container_started(
+ name=pod_name,
+ namespace=self.pod_namespace,
+ container_name=PodDefaults.SIDECAR_CONTAINER_NAME,
+ poll_interval=self.poll_interval,
+ )
+ self.log.info("Extracting result from xcom sidecar container.")
+ loop = asyncio.get_running_loop()
+ xcom_result = await loop.run_in_executor(None,
self.pod_manager.extract_xcom, pod)
+ xcom_results.append(xcom_result)
job: V1Job = await self.hook.wait_until_job_complete(
name=self.job_name, namespace=self.job_namespace,
poll_interval=self.poll_interval
)
@@ -345,12 +363,12 @@ class GKEJobTrigger(BaseTrigger):
{
"name": job.metadata.name,
"namespace": job.metadata.namespace,
- "pod_name": pod.metadata.name if self.get_logs else None,
- "pod_namespace": pod.metadata.namespace if self.get_logs else
None,
+ "pod_names": [pod_name for pod_name in self.pod_names] if
self.get_logs else None,
+ "pod_namespace": self.pod_namespace if self.get_logs else None,
"status": status,
"message": message,
"job": job_dict,
- "xcom_result": xcom_result if self.do_xcom_push else None,
+ "xcom_result": xcom_results if self.do_xcom_push else None,
}
)
diff --git
a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
index 578a8e62b84..ccf1fbabeaf 100644
---
a/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
+++
b/providers/google/tests/system/google/cloud/kubernetes_engine/example_kubernetes_engine_job.py
@@ -48,8 +48,13 @@ CLUSTER = {"name": CLUSTER_NAME, "initial_node_count": 1}
JOB_NAME = "test-pi"
JOB_NAME_DEF = "test-pi-def"
+JOB_NAME_WITH_PARALLELISM = "test-pi-with-parallelism"
+JOB_NAME_DEF_WITH_PARALLELISM = "test-pi-def-with-parallelism"
JOB_NAMESPACE = "default"
+PARALLELISM = 2
+COMPLETION_MODE = "Indexed"
+
with DAG(
DAG_ID,
schedule="@once", # Override to match your needs
@@ -92,6 +97,45 @@ with DAG(
)
# [END howto_operator_gke_start_job_def]
+ # [START howto_operator_gke_start_job_with_parallelism]
+ job_task_with_parallelism = GKEStartJobOperator(
+ task_id="job_task_with_parallelism",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ namespace=JOB_NAMESPACE,
+ image="perl:5.34.0",
+ cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+ name=JOB_NAME_WITH_PARALLELISM,
+ wait_until_job_complete=True,
+ parallelism=PARALLELISM,
+ completions=PARALLELISM,
+ completion_mode=COMPLETION_MODE,
+ get_logs=True,
+ do_xcom_push=True,
+ )
+ # [END howto_operator_gke_start_job_with_parallelism]
+
+ # [START howto_operator_gke_start_job_def_with_parallelism]
+ job_task_def_with_parallelism = GKEStartJobOperator(
+ task_id="job_task_def_with_parallelism",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ namespace=JOB_NAMESPACE,
+ image="perl:5.34.0",
+ cmds=["perl", "-Mbignum=bpi", "-wle", "print bpi(2000)"],
+ name=JOB_NAME_DEF_WITH_PARALLELISM,
+ wait_until_job_complete=True,
+ deferrable=True,
+ parallelism=PARALLELISM,
+ completions=PARALLELISM,
+ completion_mode=COMPLETION_MODE,
+ get_logs=True,
+ do_xcom_push=True,
+ )
+ # [END howto_operator_gke_start_job_def_with_parallelism]
+
# [START howto_operator_gke_list_jobs]
list_job_task = GKEListJobsOperator(
task_id="list_job_task", project_id=GCP_PROJECT_ID,
location=GCP_LOCATION, cluster_name=CLUSTER_NAME
@@ -104,7 +148,7 @@ with DAG(
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
job_name=job_task.output["job_name"],
- namespace="default",
+ namespace=JOB_NAMESPACE,
cluster_name=CLUSTER_NAME,
)
# [END howto_operator_gke_describe_job]
@@ -114,7 +158,25 @@ with DAG(
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
job_name=job_task_def.output["job_name"],
- namespace="default",
+ namespace=JOB_NAMESPACE,
+ cluster_name=CLUSTER_NAME,
+ )
+
+ describe_job_with_parallelism_task = GKEDescribeJobOperator(
+ task_id="describe_job_with_parallelism_task",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ job_name=job_task_with_parallelism.output["job_name"],
+ namespace=JOB_NAMESPACE,
+ cluster_name=CLUSTER_NAME,
+ )
+
+ describe_job_task_def_with_parallelism = GKEDescribeJobOperator(
+ task_id="describe_job_task_def_with_parallelism",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ job_name=job_task_def_with_parallelism.output["job_name"],
+ namespace=JOB_NAMESPACE,
cluster_name=CLUSTER_NAME,
)
@@ -125,7 +187,7 @@ with DAG(
location=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
name=job_task.output["job_name"],
- namespace="default",
+ namespace=JOB_NAMESPACE,
)
# [END howto_operator_gke_suspend_job]
@@ -136,7 +198,7 @@ with DAG(
location=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
name=job_task.output["job_name"],
- namespace="default",
+ namespace=JOB_NAMESPACE,
)
# [END howto_operator_gke_resume_job]
@@ -156,7 +218,25 @@ with DAG(
project_id=GCP_PROJECT_ID,
location=GCP_LOCATION,
cluster_name=CLUSTER_NAME,
- name=JOB_NAME,
+ name=JOB_NAME_DEF,
+ namespace=JOB_NAMESPACE,
+ )
+
+ delete_job_with_parallelism = GKEDeleteJobOperator(
+ task_id="delete_job_with_parallelism",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ name=JOB_NAME_WITH_PARALLELISM,
+ namespace=JOB_NAMESPACE,
+ )
+
+ delete_job_def_with_parallelism = GKEDeleteJobOperator(
+ task_id="delete_job_def_with_parallelism",
+ project_id=GCP_PROJECT_ID,
+ location=GCP_LOCATION,
+ cluster_name=CLUSTER_NAME,
+ name=JOB_NAME_DEF_WITH_PARALLELISM,
namespace=JOB_NAMESPACE,
)
@@ -170,12 +250,17 @@ with DAG(
chain(
create_cluster,
- [job_task, job_task_def],
+ [job_task, job_task_def, job_task_with_parallelism,
job_task_def_with_parallelism],
list_job_task,
- [describe_job_task, describe_job_task_def],
+ [
+ describe_job_task,
+ describe_job_task_def,
+ describe_job_with_parallelism_task,
+ describe_job_task_def_with_parallelism,
+ ],
suspend_job,
resume_job,
- [delete_job, delete_job_def],
+ [delete_job, delete_job_def, delete_job_with_parallelism,
delete_job_def_with_parallelism],
delete_cluster,
)
diff --git
a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
index a8fafd6af5d..3c3775025e1 100644
---
a/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
+++
b/providers/google/tests/unit/google/cloud/operators/test_kubernetes_engine.py
@@ -862,7 +862,9 @@ class TestGKEStartJobOperator:
mock_pod_metadata = mock.MagicMock()
mock_pod_metadata.name = K8S_POD_NAME
mock_pod_metadata.namespace = K8S_NAMESPACE
- self.operator.pod = mock.MagicMock(metadata=mock_pod_metadata)
+ self.operator.pods = [
+ mock.MagicMock(metadata=mock_pod_metadata),
+ ]
mock_job_metadata = mock.MagicMock()
mock_job_metadata.name = K8S_JOB_NAME
@@ -880,7 +882,69 @@ class TestGKEStartJobOperator:
ssl_ca_cert=GKE_SSL_CA_CERT,
job_name=K8S_JOB_NAME,
job_namespace=K8S_NAMESPACE,
- pod_name=K8S_POD_NAME,
+ pod_names=[
+ K8S_POD_NAME,
+ ],
+ pod_namespace=K8S_NAMESPACE,
+ base_container_name="base",
+ gcp_conn_id=TEST_CONN_ID,
+ poll_interval=10.0,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ get_logs=mock_get_logs,
+ do_xcom_push=False,
+ )
+ mock_defer.assert_called_once_with(
+ trigger=mock_trigger.return_value,
+ method_name="execute_complete",
+ )
+
+ @mock.patch(GKE_OPERATORS_PATH.format("GKEStartJobOperator.defer"))
+
@mock.patch(GKE_OPERATORS_PATH.format("GKEClusterAuthDetails.fetch_cluster_info"))
+ @mock.patch(GKE_OPERATORS_PATH.format("GKEHook"))
+ @mock.patch(GKE_OPERATORS_PATH.format("GKEJobTrigger"))
+ def test_execute_deferrable_with_parallelism(
+ self, mock_trigger, mock_cluster_hook, mock_fetch_cluster_info,
mock_defer
+ ):
+ op = GKEStartJobOperator(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_name=GKE_CLUSTER_NAME,
+ task_id=TEST_TASK_ID,
+ name=K8S_JOB_NAME,
+ namespace=K8S_NAMESPACE,
+ image=TEST_IMAGE,
+ gcp_conn_id=TEST_CONN_ID,
+ impersonation_chain=TEST_IMPERSONATION_CHAIN,
+ parallelism=2,
+ )
+ mock_pod_name_1 = K8S_POD_NAME + "-1"
+ mock_pod_metadata_1 = mock.MagicMock()
+ mock_pod_metadata_1.name = mock_pod_name_1
+ mock_pod_metadata_1.namespace = K8S_NAMESPACE
+
+ mock_pod_name_2 = K8S_POD_NAME + "-2"
+ mock_pod_metadata_2 = mock.MagicMock()
+ mock_pod_metadata_2.name = mock_pod_name_2
+ mock_pod_metadata_2.namespace = K8S_NAMESPACE
+ op.pods = [mock.MagicMock(metadata=mock_pod_metadata_1),
mock.MagicMock(metadata=mock_pod_metadata_2)]
+
+ mock_job_metadata = mock.MagicMock()
+ mock_job_metadata.name = K8S_JOB_NAME
+ mock_job_metadata.namespace = K8S_NAMESPACE
+ op.job = mock.MagicMock(metadata=mock_job_metadata)
+
+ mock_fetch_cluster_info.return_value = GKE_CLUSTER_URL, GKE_SSL_CA_CERT
+ mock_get_logs = mock.MagicMock()
+ op.get_logs = mock_get_logs
+
+ op.execute_deferrable()
+
+ mock_trigger.assert_called_once_with(
+ cluster_url=GKE_CLUSTER_URL,
+ ssl_ca_cert=GKE_SSL_CA_CERT,
+ job_name=K8S_JOB_NAME,
+ job_namespace=K8S_NAMESPACE,
+ pod_names=[mock_pod_name_1, mock_pod_name_2],
pod_namespace=K8S_NAMESPACE,
base_container_name="base",
gcp_conn_id=TEST_CONN_ID,
diff --git
a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
index ffe749ef397..3704c5c3206 100644
---
a/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
+++
b/providers/google/tests/unit/google/cloud/triggers/test_kubernetes_engine.py
@@ -98,7 +98,9 @@ def job_trigger():
ssl_ca_cert=SSL_CA_CERT,
job_name=JOB_NAME,
job_namespace=NAMESPACE,
- pod_name=POD_NAME,
+ pod_names=[
+ POD_NAME,
+ ],
pod_namespace=NAMESPACE,
base_container_name=BASE_CONTAINER_NAME,
gcp_conn_id=GCP_CONN_ID,
@@ -482,7 +484,9 @@ class TestGKEStartJobTrigger:
"ssl_ca_cert": SSL_CA_CERT,
"job_name": JOB_NAME,
"job_namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"base_container_name": BASE_CONTAINER_NAME,
"gcp_conn_id": GCP_CONN_ID,
@@ -521,7 +525,9 @@ class TestGKEStartJobTrigger:
{
"name": JOB_NAME,
"namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"status": "success",
"message": "Job completed successfully",
@@ -559,7 +565,9 @@ class TestGKEStartJobTrigger:
{
"name": JOB_NAME,
"namespace": NAMESPACE,
- "pod_name": POD_NAME,
+ "pod_names": [
+ POD_NAME,
+ ],
"pod_namespace": NAMESPACE,
"status": "error",
"message": "Job failed with error: Error",