This is an automated email from the ASF dual-hosted git repository.

jscheffl pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new aee89644d9a Ensure deterministic Spark driver pod selection during reattach. Added unit tests. (#60717)
aee89644d9a is described below

commit aee89644d9a7702de179dd75f09f9b6179d20270
Author: SameerMesiah97 <[email protected]>
AuthorDate: Mon Jan 26 21:10:50 2026 +0000

    Ensure deterministic Spark driver pod selection during reattach. Added unit tests. (#60717)
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  |  32 +++-
 .../kubernetes/operators/test_spark_kubernetes.py  | 165 +++++++++++++++++++++
 2 files changed, 194 insertions(+), 3 deletions(-)

diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index 5e1b38305b0..a01f05ae82d 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -17,6 +17,7 @@
 # under the License.
 from __future__ import annotations
 
+from datetime import datetime, timezone
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING, Any, cast
@@ -29,7 +30,7 @@ from airflow.providers.cncf.kubernetes.kubernetes_helper_functions import add_un
 from airflow.providers.cncf.kubernetes.operators.custom_object_launcher import CustomObjectLauncher
 from airflow.providers.cncf.kubernetes.operators.pod import KubernetesPodOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN, PodGenerator
-from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodManager, PodPhase
 from airflow.providers.common.compat.sdk import AirflowException
 from airflow.utils.helpers import prune_dict
 
@@ -235,6 +236,14 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         return self.manage_template_specs()
 
     def find_spark_job(self, context, exclude_checked: bool = True):
+        """
+        Find an existing Spark driver pod for this task instance.
+
+        The pod is identified using Airflow task context labels. If multiple
+        driver pods match the same labels (which can occur if cleanup did not
+        run after an abrupt failure), a single pod is selected deterministically
+        for reattachment, preferring a Running driver pod when present.
+        """
         label_selector = (
             self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
             + ",spark-role=driver"
@@ -242,8 +251,25 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
 
         pod = None
-        if len(pod_list) > 1:  # and self.reattach_on_restart:
-            raise AirflowException(f"More than one pod running with labels: {label_selector}")
+        if len(pod_list) > 1:
+            # When multiple pods match the same labels, select one deterministically,
+            # preferring a Running pod, then creation time, with name as a tie-breaker.
+            pod = max(
+                pod_list,
+                key=lambda p: (
+                    p.status.phase == PodPhase.RUNNING,
+                    p.metadata.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
+                    p.metadata.name or "",
+                ),
+            )
+            self.log.warning(
+                "Found %d Spark driver pods matching labels %s; "
+                "selecting pod %s for reattachment based on status and 
creation time.",
+                len(pod_list),
+                label_selector,
+                pod.metadata.name,
+            )
+
         if len(pod_list) == 1:
             pod = pod_list[0]
             self.log.info(
diff --git a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
index bf573ba51b3..1e7f5c1da23 100644
--- a/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/cncf/kubernetes/tests/unit/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -37,6 +37,7 @@ from airflow.models import Connection, DagRun, TaskInstance
 from airflow.providers.cncf.kubernetes.operators.spark_kubernetes import SparkKubernetesOperator
 from airflow.providers.cncf.kubernetes.pod_generator import MAX_LABEL_LEN
 from airflow.providers.cncf.kubernetes.triggers.pod import KubernetesPodTrigger
+from airflow.providers.cncf.kubernetes.utils.pod_manager import PodPhase
 from airflow.providers.common.compat.sdk import TaskDeferred
 from airflow.utils import timezone
 from airflow.utils.types import DagRunType
@@ -944,6 +945,170 @@ class TestSparkKubernetesOperator:
 
         mock_create_namespaced_crd.assert_not_called()
 
+    def test_find_spark_job_picks_running_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job picks a Running Spark driver pod over a non-Running pod.
+        """
+
+        task_name = "test_find_spark_job_prefers_running_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Running pod should be selected.
+        running_pod = mock.MagicMock()
+        running_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        running_pod.metadata.name = "spark-driver-running"
+        running_pod.metadata.labels = {"try_number": "1"}
+        running_pod.status.phase = "Running"
+
+        # Pending pod should not be selected.
+        pending_pod = mock.MagicMock()
+        pending_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        pending_pod.metadata.name = "spark-driver-pending"
+        pending_pod.metadata.labels = {"try_number": "1"}
+        pending_pod.status.phase = "Pending"
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [
+            running_pod,
+            pending_pod,
+        ]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is running_pod
+
+    def test_find_spark_job_picks_latest_pod(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job selects the most recently created Spark driver pod
+        when multiple candidate driver pods are present and status does not disambiguate.
+        """
+
+        task_name = "test_find_spark_job_picks_latest_pod"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Older pod that should be ignored.
+        old_mock_pod = mock.MagicMock()
+        old_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+        old_mock_pod.metadata.name = "spark-driver-old"
+        old_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Newer pod that should be picked up.
+        new_mock_pod = mock.MagicMock()
+        new_mock_pod.metadata.creation_timestamp = timezone.datetime(2025, 1, 2, tzinfo=timezone.utc)
+        new_mock_pod.metadata.name = "spark-driver-new"
+        new_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Same try_number to simulate abrupt failure scenarios (e.g. scheduler crash)
+        # where cleanup did not occur and multiple pods share identical labels.
+        old_mock_pod.metadata.labels = {"try_number": "1"}
+        new_mock_pod.metadata.labels = {"try_number": "1"}
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [old_mock_pod, new_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is new_mock_pod
+
+    def test_find_spark_job_tiebreaks_by_name(
+        self,
+        mock_is_in_cluster,
+        mock_parent_execute,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        """
+        Verifies that find_spark_job uses pod name as a deterministic tie-breaker
+        when multiple running Spark driver pods share the same creation_timestamp.
+        """
+
+        task_name = "test_find_spark_job_tiebreaks_by_name"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+            reattach_on_restart=True,
+        )
+        context = create_context(op)
+
+        # Use identical creation timestamps to force name-based tie-breaking.
+        ts = timezone.datetime(2025, 1, 1, tzinfo=timezone.utc)
+
+        # Pod with lexicographically smaller name should not be selected.
+        invalid_mock_pod = mock.MagicMock()
+        invalid_mock_pod.metadata.creation_timestamp = ts
+        invalid_mock_pod.metadata.name = "spark-driver-abc"
+        invalid_mock_pod.metadata.labels = {"try_number": "1"}
+        invalid_mock_pod.status.phase = PodPhase.RUNNING
+
+        # Pod with lexicographically greater name should be selected.
+        valid_mock_pod = mock.MagicMock()
+        valid_mock_pod.metadata.creation_timestamp = ts
+        valid_mock_pod.metadata.name = "spark-driver-xyz"
+        valid_mock_pod.metadata.labels = {"try_number": "1"}
+        valid_mock_pod.status.phase = PodPhase.RUNNING
+
+        mock_get_kube_client.list_namespaced_pod.return_value.items = [invalid_mock_pod, valid_mock_pod]
+
+        returned_pod = op.find_spark_job(context)
+
+        assert returned_pod is valid_mock_pod
+
     @pytest.mark.asyncio
     def test_execute_deferrable(
         self,

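For context, a minimal standalone sketch of the selection rule this commit introduces, using plain stand-in objects instead of the Kubernetes client; the names below are illustrative only and do not appear in the committed code.

from datetime import datetime, timezone
from types import SimpleNamespace

RUNNING = "Running"  # stand-in for PodPhase.RUNNING

def pick_driver_pod(pods):
    # Prefer a Running pod, then the latest creation time, then the
    # lexicographically greatest name as the final tie-breaker.
    return max(
        pods,
        key=lambda p: (
            p.phase == RUNNING,
            p.creation_timestamp or datetime.min.replace(tzinfo=timezone.utc),
            p.name or "",
        ),
    )

pods = [
    SimpleNamespace(name="spark-driver-old", phase=RUNNING,
                    creation_timestamp=datetime(2025, 1, 1, tzinfo=timezone.utc)),
    SimpleNamespace(name="spark-driver-new", phase=RUNNING,
                    creation_timestamp=datetime(2025, 1, 2, tzinfo=timezone.utc)),
]
assert pick_driver_pod(pods).name == "spark-driver-new"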