This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 506803070e6 fix: spark operator label (#45353)
506803070e6 is described below

commit 506803070e61c00cef7d4d6ae8e02324d2b3c2ef
Author: Chongchen Chen <chenkov...@qq.com>
AuthorDate: Thu Feb 6 06:24:20 2025 +0800

    fix: spark operator label (#45353)
    
    * fix: spark operator label
    
    * update spark operator
    
    * update spark kube
    
    * make ci happy
    
    * update test
    
    * format
    
    * format
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  | 26 +++++++++----------
 .../kubernetes/operators/test_spark_kubernetes.py  | 29 ++++++++++++++++++++++
 2 files changed, 41 insertions(+), 14 deletions(-)
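
[Editor's note: a minimal sketch of the idea behind this change, for context.
The patch drops the operator's private label-string helpers in favor of the
shared _build_find_pod_label_selector from KubernetesPodOperator; the sketch
below only illustrates the general technique of turning task-instance labels
into a Kubernetes label selector for the Spark driver pod. build_selector and
the label values are illustrative, not the actual helper or data.]

    # Sketch: task-instance labels -> "k1=v1,k2=v2" selector string, then
    # query for the driver pod via the Kubernetes Python client.
    from kubernetes import client, config

    def build_selector(labels: dict[str, str]) -> str:
        # Comma-joined key=value pairs, the form the API's labelSelector accepts.
        return ",".join(f"{k}={v}" for k, v in sorted(labels.items()))

    ti_labels = {"dag_id": "my_dag", "task_id": "my_task", "run_id": "manual__2025-02-06"}
    selector = build_selector(ti_labels) + ",spark-role=driver"

    config.load_kube_config()  # or load_incluster_config() when running in-cluster
    pods = client.CoreV1Api().list_namespaced_pod("default", label_selector=selector).items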

diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index c0b90ebacb9..583388c6d14 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -19,7 +19,7 @@ from __future__ import annotations
 
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
@@ -177,12 +177,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         return self._set_name(updated_name)
 
     @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
-        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
-        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
-
-    @staticmethod
-    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash.
 
@@ -193,8 +188,9 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         if not context:
             return {}
 
-        ti = context["ti"]
-        run_id = context["run_id"]
+        context_dict = cast(dict, context)
+        ti = context_dict["ti"]
+        run_id = context_dict["run_id"]
 
         labels = {
             "dag_id": ti.dag_id,
@@ -213,8 +209,8 @@ class SparkKubernetesOperator(KubernetesPodOperator):
 
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context["dag"], "is_subdag", False):
-            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        if getattr(context_dict["dag"], "is_subdag", False):
+            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
         for label_id, label in labels.items():
@@ -235,9 +231,11 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         """Templated body for CustomObjectLauncher."""
         return self.manage_template_specs()
 
-    def find_spark_job(self, context):
-        labels = self.create_labels_for_pod(context, include_try_number=False)
-        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+    def find_spark_job(self, context, exclude_checked: bool = True):
+        label_selector = (
+            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+            + ",spark-role=driver"
+        )
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
 
         pod = None
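
[Editor's note: the cast(dict, context) change above exists purely for the
type checker. Airflow's Context is not typed as a plain dict, and
typing.cast is a runtime no-op that lets key access type-check. A
self-contained sketch of that behavior, with illustrative data:]

    from typing import cast

    # cast() performs no conversion; it only tells the type checker to
    # treat the value as the given type so dict-style access type-checks.
    context = {"ti": object(), "run_id": "manual__2025-02-06"}
    context_dict = cast(dict, context)
    assert context_dict is context  # same object; nothing changed at runtime
    print(context_dict["run_id"])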
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index e1e7e85bcda..c6814222c47 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -701,6 +701,35 @@ class TestSparkKubernetesOperator:
             follow_logs=True,
         )
 
+    def test_find_custom_pod_labels(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_start,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        task_name = "test_find_custom_pod_labels"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+        )
+        context = create_context(op)
+        op.execute(context)
+        label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
+        op.find_spark_job(context)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+
 
 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
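
[Editor's note: the new test above verifies the selector by asserting on a
mocked Kubernetes client. A minimal sketch of that mock-assertion pattern,
with illustrative names and selector values:]

    from unittest import mock

    # Record calls on a mocked client, then verify the exact label selector
    # that was passed, mirroring test_find_custom_pod_labels above.
    kube_client = mock.MagicMock()
    kube_client.list_namespaced_pod("default", label_selector="dag_id=my_dag,spark-role=driver")
    kube_client.list_namespaced_pod.assert_called_with(
        "default", label_selector="dag_id=my_dag,spark-role=driver"
    )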
