This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 506803070e6  fix: spark operator label (#45353)
506803070e6 is described below

commit 506803070e61c00cef7d4d6ae8e02324d2b3c2ef
Author: Chongchen Chen <chenkov...@qq.com>
AuthorDate: Thu Feb 6 06:24:20 2025 +0800

    fix: spark operator label (#45353)

    * fix: spark operator label

    * update spark operator

    * update spark kube

    * make ci happy

    * update test

    * format

    * format
---
 .../cncf/kubernetes/operators/spark_kubernetes.py  | 26 +++++++++----------
 .../kubernetes/operators/test_spark_kubernetes.py  | 29 ++++++++++++++++++++++
 2 files changed, 41 insertions(+), 14 deletions(-)

diff --git a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
index c0b90ebacb9..583388c6d14 100644
--- a/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
+++ b/providers/src/airflow/providers/cncf/kubernetes/operators/spark_kubernetes.py
@@ -19,7 +19,7 @@
 from __future__ import annotations
 
 from functools import cached_property
 from pathlib import Path
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Any, cast
 
 from kubernetes.client import CoreV1Api, CustomObjectsApi, models as k8s
 
@@ -177,12 +177,7 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         return self._set_name(updated_name)
 
     @staticmethod
-    def _get_pod_identifying_label_string(labels) -> str:
-        filtered_labels = {label_id: label for label_id, label in labels.items() if label_id != "try_number"}
-        return ",".join([label_id + "=" + label for label_id, label in sorted(filtered_labels.items())])
-
-    @staticmethod
-    def create_labels_for_pod(context: dict | None = None, include_try_number: bool = True) -> dict:
+    def _get_ti_pod_labels(context: Context | None = None, include_try_number: bool = True) -> dict[str, str]:
         """
         Generate labels for the pod to track the pod in case of Operator crash.
 
@@ -193,8 +188,9 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         if not context:
             return {}
 
-        ti = context["ti"]
-        run_id = context["run_id"]
+        context_dict = cast(dict, context)
+        ti = context_dict["ti"]
+        run_id = context_dict["run_id"]
 
         labels = {
             "dag_id": ti.dag_id,
@@ -213,8 +209,8 @@ class SparkKubernetesOperator(KubernetesPodOperator):
 
         # In the case of sub dags this is just useful
         # TODO: Remove this when the minimum version of Airflow is bumped to 3.0
-        if getattr(context["dag"], "is_subdag", False):
-            labels["parent_dag_id"] = context["dag"].parent_dag.dag_id
+        if getattr(context_dict["dag"], "is_subdag", False):
+            labels["parent_dag_id"] = context_dict["dag"].parent_dag.dag_id
         # Ensure that label is valid for Kube,
         # and if not truncate/remove invalid chars and replace with short hash.
         for label_id, label in labels.items():
@@ -235,9 +231,11 @@ class SparkKubernetesOperator(KubernetesPodOperator):
         """Templated body for CustomObjectLauncher."""
         return self.manage_template_specs()
 
-    def find_spark_job(self, context):
-        labels = self.create_labels_for_pod(context, include_try_number=False)
-        label_selector = self._get_pod_identifying_label_string(labels) + ",spark-role=driver"
+    def find_spark_job(self, context, exclude_checked: bool = True):
+        label_selector = (
+            self._build_find_pod_label_selector(context, exclude_checked=exclude_checked)
+            + ",spark-role=driver"
+        )
         pod_list = self.client.list_namespaced_pod(self.namespace, label_selector=label_selector).items
 
         pod = None
diff --git a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
index e1e7e85bcda..c6814222c47 100644
--- a/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
+++ b/providers/tests/cncf/kubernetes/operators/test_spark_kubernetes.py
@@ -701,6 +701,35 @@ class TestSparkKubernetesOperator:
             follow_logs=True,
         )
 
+    def test_find_custom_pod_labels(
+        self,
+        mock_create_namespaced_crd,
+        mock_get_namespaced_custom_object_status,
+        mock_cleanup,
+        mock_create_job_name,
+        mock_get_kube_client,
+        mock_create_pod,
+        mock_await_pod_start,
+        mock_await_pod_completion,
+        mock_fetch_requested_container_logs,
+        data_file,
+    ):
+        task_name = "test_find_custom_pod_labels"
+        job_spec = yaml.safe_load(data_file("spark/application_template.yaml").read_text())
+
+        mock_create_job_name.return_value = task_name
+        op = SparkKubernetesOperator(
+            template_spec=job_spec,
+            kubernetes_conn_id="kubernetes_default_kube_config",
+            task_id=task_name,
+            get_logs=True,
+        )
+        context = create_context(op)
+        op.execute(context)
+        label_selector = op._build_find_pod_label_selector(context) + ",spark-role=driver"
+        op.find_spark_job(context)
+        mock_get_kube_client.list_namespaced_pod.assert_called_with("default", label_selector=label_selector)
+
 
 @pytest.mark.db_test
 def test_template_body_templating(create_task_instance_of_operator, session):
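
A note on what the refactor preserves: the removed _get_pod_identifying_label_string built a comma-joined Kubernetes label selector from the task-instance labels, dropping try_number (so a retry can still locate the pod) and sorting keys for a deterministic string, and find_spark_job appended spark-role=driver so only the Spark driver pod matches, never the executors. The inherited _build_find_pod_label_selector is expected to yield an equivalent selector. Below is a minimal standalone sketch of that selector construction; the label values are invented for illustration and this is not the Airflow implementation:

    from __future__ import annotations

    # Hypothetical task-instance labels, mirroring the keys that
    # _get_ti_pod_labels() derives from the Airflow context.
    ti_labels = {
        "dag_id": "spark_pi",
        "task_id": "compute_pi",
        "run_id": "manual__2025-02-05T222420",
        "try_number": "1",
    }

    def build_driver_selector(labels: dict[str, str]) -> str:
        """Join labels into a Kubernetes label selector: drop try_number so a
        retry still finds the pod, sort keys for a stable string, and pin
        spark-role=driver so executor pods are excluded."""
        filtered = {k: v for k, v in labels.items() if k != "try_number"}
        selector = ",".join(f"{k}={v}" for k, v in sorted(filtered.items()))
        return selector + ",spark-role=driver"

    print(build_driver_selector(ti_labels))
    # dag_id=spark_pi,run_id=manual__2025-02-05T222420,task_id=compute_pi,spark-role=driver

Against the real operator, a string of this shape is exactly what the new test_find_custom_pod_labels asserts is passed to list_namespaced_pod.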