This is an automated email from the ASF dual-hosted git repository.
jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new aee89644d9a Ensure deterministic Spark driver pod selection during
reattach. Added unit tests. (#60717)
aee89644d9a is described below
commit aee89644d9a7702de179dd75f09f9b6179d20270
Author: SameerMesiah97 <[email protected]>
AuthorDate: Mon Jan 26 21:10:50 2026 +0000
Ensure deterministic Spark driver pod selection during reattach. Added unit
tests. (#60717)
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../cncf/kubernetes/operators/spark_kubernetes.py | 32 +++-
.../kubernetes/operators/test_spark_kubernetes.py | 165 +++++++++++++++++++++
2 files changed, 194 insertions(+), 3 deletions(-)
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 5e1b38305b0..a01f05ae82d 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations

+from datetime import datetime, timezone
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -29,7 +30,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_un
 from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
 from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
-from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager, PodPhase
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.helpers import prune_dict
@@ -235,6 +236,14 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         return self.manage_template_specs()

     def find_spark_job(self, context, exclude_checked: bool = True):
+        """
+        Find an existing Spark driver pod for this task instance.
+
+        The pod is identified using Airflow task context labels. If multiple
+        driver pods match the same labels (which can occur if cleanup did not
+        run after an abrupt failure), a single pod is selected deterministically
+        for reattachment, preferring a Running driver pod when present.
+        """
         label_selector = (
             self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
             + ",spark-role=driver"
@@ -242,8 +251,25 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
         pod = None
-        if len(pod_list) > 1:  # and self.reattach_on_restart:
-            raise AirflowException(f"More than one pod running with labels: {label_selector}")
+        if len(pod_list) > 1:
+            # When multiple pods match the same labels, select one deterministically,
+            # preferring a Running pod, then creation time, with name as a tie-breaker.
+            pod = max(
+                pod_list,
+                key=lambda p: (
+                    p.status.phase == PodPhase.RUNNING,
+                    p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
+                    p.metadata.name or "",
+                ),
+            )
+            self.log.warning(
+                "Found %d Spark driver pods matching labels %s; "
+                "selecting pod %s for reattachment based on status and creation time.",
+                len(pod_list),
+                label_selector,
+                pod.metadata.name,
+            )
+
         if len(pod_list) == 1:
             pod = pod_list[0]
             self.log.info(
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
index bf573ba51b3..1e7f5c1da23 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -37,6 +37,7 @@ from airflow.models import Connection, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
 from airflow.providers.common.compat.sdk import TaskDeferred
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
@@ -944,6 +945,170 @@ class TestSparkKubernetesOperator:

         mock_create_namespaced_crd.assert_not_called()

+    def test_find_spark_job_picks_running_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job picks a Running Spark driver pod over a non-Running pod.
+        """
+
+        task_name = "test_find_spark_job_prefers_running_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Running pod should be selected.
+        running_pod = mock.MagicMock()
+        running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        running_pod.metadata.name = "spark-driver-running"
+        running_pod.metadata.labels = {"try_number": "1"}
+        running_pod.status.phase = "Running"
+
+        # Pending pod should not be selected.
+        pending_pod = mock.MagicMock()
+        pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        pending_pod.metadata.name = "spark-driver-pending"
+        pending_pod.metadata.labels = {"try_number": "1"}
+        pending_pod.status.phase = "Pending"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            running_pod,
+            pending_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is running_pod
+
+    def test_find_spark_job_picks_latest_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job selects the most recently created Spark driver pod
+        when multiple candidate driver pods are present and status does not disambiguate.
+        """
+
+        task_name = "test_find_spark_job_picks_latest_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Older pod that should be ignored.
+        old_mock_pod = mock.MagicMock()
+        old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        old_mock_pod.metadata.name = "spark-driver-old"
+        old_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Newer pod that should be picked up.
+        new_mock_pod = mock.MagicMock()
+        new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        new_mock_pod.metadata.name = "spark-driver-new"
+        new_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
+        # where cleanup did not occur and multiple pods share identical labels.
+        old_mock_pod.metadata.labels = {"try_number": "1"}
+        new_mock_pod.metadata.labels = {"try_number": "1"}
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is new_mock_pod
+
+    def test_find_spark_job_tiebreaks_by_name(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job uses pod name as a deterministic tie-breaker
+        when multiple running Spark driver pods share the same creation_timestamp.
+        """
+
+        task_name = "test_find_spark_job_tiebreaks_by_name"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Use identical creation timestamps to force name-based tie-breaking.
+        ts = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+
+        # Pod with lexicographically smaller name should not be selected.
+        invalid_mock_pod = mock.MagicMock()
+        invalid_mock_pod.metadata.creation_timestamp = ts
+        invalid_mock_pod.metadata.name = "spark-driver-abc"
+        invalid_mock_pod.metadata.labels = {"try_number": "1"}
+        invalid_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Pod with lexicographically greater name should be selected.
+        valid_mock_pod = mock.MagicMock()
+        valid_mock_pod.metadata.creation_timestamp = ts
+        valid_mock_pod.metadata.name = "spark-driver-xyz"
+        valid_mock_pod.metadata.labels = {"try_number": "1"}
+        valid_mock_pod.status.phase = PodPhase.RUNNING
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [invalid_mock_pod, valid_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is valid_mock_pod
+
@pytest.mark.asyncio
def test_execute_deferrable(
self,