This is an automated email from the ASF dual-hosted git repository.

jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 33467ea8e3a spark-pipelines operator (#61681)
33467ea8e3a is described below

commit 33467ea8e3aae05ee1ea6c79ac6dc9d8b98ccca8
Author: Sandy Ryza <[email protected]>
AuthorDate: Wed Mar 4 00:34:43 2026 -0800

    spark-pipelines operator (#61681)
    
    * spark-pipelines operator
    
    * fixups and feedback
    
    * fix tests
    
    * more
    
    * static checks
---
 providers/apache/spark/docs/operators.rst          |  54 ++++++
 providers/apache/spark/provider.yaml               |   2 +
 .../providers/apache/spark/get_provider_info.py    |   2 +
 .../apache/spark/hooks/spark_pipelines.py          | 113 +++++++++++++
 .../providers/apache/spark/hooks/spark_submit.py   |  75 +++++----
 .../apache/spark/operators/spark_pipelines.py      | 148 +++++++++++++++++
 .../apache/spark/hooks/test_spark_pipelines.py     | 185 +++++++++++++++++++++
 .../apache/spark/operators/test_spark_pipelines.py | 155 +++++++++++++++++
 8 files changed, 702 insertions(+), 32 deletions(-)

diff --git a/providers/apache/spark/docs/operators.rst 
b/providers/apache/spark/docs/operators.rst
index 2e1aad8fb38..125039ebdf3 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -31,6 +31,8 @@ Prerequisite
   gets all the configurations from operator parameters.
 * To use 
:class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`
   you can configure :doc:`SparkConnect Connection <connections/spark-connect>`.
+* To use :class:`~airflow.providers.apache.spark.operators.spark_pipelines.SparkPipelinesOperator`
+  you must configure :doc:`Spark Connection <connections/spark-submit>` and have the ``spark-pipelines`` CLI available on the ``PATH`` of the worker that runs the task.
 
 .. _howto/operator:SparkJDBCOperator:
 
@@ -81,6 +83,58 @@ Reference
 
 For further information, look at `Running the Spark Connect Python 
<https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html>`_.
 
+.. _howto/operator:SparkPipelinesOperator:
+
+SparkPipelinesOperator
+----------------------
+
+Execute Spark Declarative Pipelines using the ``spark-pipelines`` CLI. This operator wraps the ``spark-pipelines`` binary and supports both running a pipeline (``run``) and validating it without executing it (``dry-run``).
+
+For parameter definition take a look at 
:class:`~airflow.providers.apache.spark.operators.spark_pipelines.SparkPipelinesOperator`.
+
+Using the operator
+""""""""""""""""""
+
+The operator can be used to run declarative pipelines:
+
+.. code-block:: python
+
+   from airflow.providers.apache.spark.operators.spark_pipelines import SparkPipelinesOperator
+
+   # Execute the pipeline
+   run_pipeline = SparkPipelinesOperator(
+       task_id="run_pipeline",
+       pipeline_spec="/path/to/pipeline.yml",
+       pipeline_command="run",
+       conn_id="spark_default",
+       num_executors=2,
+       executor_cores=4,
+       executor_memory="2G",
+       driver_memory="1G",
+   )
+
+**Pipeline Specification**
+
+The ``pipeline_spec`` parameter should point to a YAML file defining your 
declarative pipeline:
+
+.. code-block:: yaml
+
+   name: my_pipeline
+   storage: file:///path/to/pipeline-storage
+   libraries:
+     - glob:
+         include: transformations/**
+
+**Pipeline Commands**
+
+* ``run`` - Execute the pipeline (default)
+* ``dry-run`` - Validate the pipeline without execution
+
+Reference
+"""""""""
+
+For further information, look at `Spark Declarative Pipelines Programming Guide <https://spark.apache.org/docs/latest/declarative-pipelines-programming-guide.html>`_.
+
 .. _howto/operator:SparkSqlOperator:
 
 SparkSqlOperator
diff --git a/providers/apache/spark/provider.yaml 
b/providers/apache/spark/provider.yaml
index e5309934d55..376cec10683 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -95,6 +95,7 @@ operators:
   - integration-name: Apache Spark
     python-modules:
       - airflow.providers.apache.spark.operators.spark_jdbc
+      - airflow.providers.apache.spark.operators.spark_pipelines
       - airflow.providers.apache.spark.operators.spark_sql
       - airflow.providers.apache.spark.operators.spark_submit
       - airflow.providers.apache.spark.operators.spark_pyspark
@@ -105,6 +106,7 @@ hooks:
       - airflow.providers.apache.spark.hooks.spark_connect
       - airflow.providers.apache.spark.hooks.spark_jdbc
       - airflow.providers.apache.spark.hooks.spark_jdbc_script
+      - airflow.providers.apache.spark.hooks.spark_pipelines
       - airflow.providers.apache.spark.hooks.spark_sql
       - airflow.providers.apache.spark.hooks.spark_submit
 
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index bd4169f9f3a..bf9f4b2f8a7 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -40,6 +40,7 @@ def get_provider_info():
                 "integration-name": "Apache Spark",
                 "python-modules": [
                     "airflow.providers.apache.spark.operators.spark_jdbc",
+                    "airflow.providers.apache.spark.operators.spark_pipelines",
                     "airflow.providers.apache.spark.operators.spark_sql",
                     "airflow.providers.apache.spark.operators.spark_submit",
                     "airflow.providers.apache.spark.operators.spark_pyspark",
@@ -53,6 +54,7 @@ def get_provider_info():
                     "airflow.providers.apache.spark.hooks.spark_connect",
                     "airflow.providers.apache.spark.hooks.spark_jdbc",
                     "airflow.providers.apache.spark.hooks.spark_jdbc_script",
+                    "airflow.providers.apache.spark.hooks.spark_pipelines",
                     "airflow.providers.apache.spark.hooks.spark_sql",
                     "airflow.providers.apache.spark.hooks.spark_submit",
                 ],
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
new file mode 100644
index 00000000000..e721433b93b
--- /dev/null
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import subprocess
+from typing import Any
+
+from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.sdk import AirflowException
+
+
+class SparkPipelinesException(AirflowException):
+    """Raised when the ``spark-pipelines`` process exits with a non-zero return code."""
+
+
+class SparkPipelinesHook(SparkSubmitHook):
+    """
+    Hook for interacting with Spark Declarative Pipelines via the spark-pipelines CLI.
+
+    Extends SparkSubmitHook to reuse its connection handling and common Spark argument
+    building, while invoking the ``spark-pipelines`` binary instead of ``spark-submit``.
+
+    :param pipeline_spec: Path to the pipeline specification file (YAML). When omitted,
+        no ``--spec`` flag is passed to the CLI.
+    :param pipeline_command: The spark-pipelines command to run ('run' or 'dry-run').
+    :param kwargs: Passed through unchanged to :class:`SparkSubmitHook`.
+    :raises ValueError: If ``pipeline_command`` is not 'run' or 'dry-run'.
+    """
+
+    def __init__(
+        self,
+        pipeline_spec: str | None = None,
+        pipeline_command: str = "run",
+        **kwargs: Any,
+    ) -> None:
+        # Validate first so an invalid command is rejected before SparkSubmitHook's
+        # own initialization runs.
+        if pipeline_command not in ("run", "dry-run"):
+            raise ValueError(f"Invalid pipeline command: {pipeline_command}. Must be 'run' or 'dry-run'")
+        super().__init__(**kwargs)
+        self.pipeline_spec = pipeline_spec
+        self.pipeline_command = pipeline_command
+
+    def _get_spark_binary_path(self) -> list[str]:
+        """Return the binary to invoke as a fresh list, since callers append to it."""
+        return ["spark-pipelines"]
+
+    def _build_spark_pipelines_command(self) -> list[str]:
+        """
+        Construct the spark-pipelines command to execute.
+
+        Building the common Spark arguments may also set ``self._env`` (env vars to apply
+        to the subprocess), so this must run before the process is started.
+
+        :return: full command to be executed
+        """
+        # Binary first, then the pipeline sub-command ('run' / 'dry-run').
+        connection_cmd = self._get_spark_binary_path()
+        connection_cmd.append(self.pipeline_command)
+
+        if self.pipeline_spec:
+            connection_cmd.extend(["--spec", self.pipeline_spec])
+
+        # Shared --master / --conf / resource / kerberos arguments from SparkSubmitHook.
+        connection_cmd.extend(self._build_spark_common_args())
+
+        self.log.info("Spark-Pipelines cmd: %s", self._mask_cmd(connection_cmd))
+        return connection_cmd
+
+    def submit_pipeline(self, **kwargs: Any) -> None:
+        """
+        Execute the spark-pipelines command and wait for it to finish.
+
+        The process's combined stdout/stderr is passed to ``_process_spark_submit_log``.
+
+        :param kwargs: extra arguments to Popen (see subprocess.Popen)
+        :raises SparkPipelinesException: if the process exits with a non-zero return code.
+        """
+        # Build first: it may populate self._env, which is consumed just below.
+        pipelines_cmd = self._build_spark_pipelines_command()
+
+        if self._env:
+            import os
+
+            # Layer the configured env vars on top of the current process environment.
+            env = os.environ.copy()
+            env.update(self._env)
+            kwargs["env"] = env
+
+        self._submit_sp = subprocess.Popen(
+            pipelines_cmd,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            bufsize=-1,
+            universal_newlines=True,
+            **kwargs,
+        )
+
+        if self._submit_sp.stdout:
+            self._process_spark_submit_log(iter(self._submit_sp.stdout))
+        returncode = self._submit_sp.wait()
+
+        if returncode:
+            raise SparkPipelinesException(
+                f"Cannot execute: {self._mask_cmd(pipelines_cmd)}. Error code is: {returncode}."
+            )
+
+    def submit(self, application: str = "", **kwargs: Any) -> None:
+        """
+        Override SparkSubmitHook.submit to run the configured pipeline instead.
+
+        :param application: Ignored; accepted so the signature matches SparkSubmitHook.submit.
+        :param kwargs: extra arguments to Popen, forwarded to :meth:`submit_pipeline`.
+        """
+        self.submit_pipeline(**kwargs)
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 0cb93afb861..7870b790535 100644
--- 
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -389,20 +389,19 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
         return connection_cmd_masked
 
-    def _build_spark_submit_command(self, application: str) -> list[str]:
+    def _build_spark_common_args(self) -> list[str]:
         """
-        Construct the spark-submit command to execute.
+        Build common Spark arguments that are shared between spark-submit and spark-pipelines.
 
-        :param application: command to append to the spark-submit command
-        :return: full command to be executed
+        Side effect: when env vars are configured, the master is neither YARN nor Kubernetes,
+        and the deploy mode is not "cluster", they are stored in ``self._env`` so the caller
+        can apply them to the subprocess environment instead of passing ``--conf`` entries.
+
+        :return: list of common spark arguments
+        :raises AirflowException: if env vars are used in standalone cluster mode, or if
+            ``use_krb5ccache`` is enabled but ``KRB5CCNAME`` is not set.
         """
-        connection_cmd = self._get_spark_binary_path()
+        args = []
 
         # The url of the spark master
-        connection_cmd += ["--master", self._connection["master"]]
+        args += ["--master", self._connection["master"]]
 
         for key in self._conf:
-            connection_cmd += ["--conf", f"{key}={self._conf[key]}"]
+            args += ["--conf", f"{key}={self._conf[key]}"]
         if self._env_vars and (self._is_kubernetes or self._is_yarn):
             if self._is_yarn:
                 tmpl = "spark.yarn.appMasterEnv.{}={}"
@@ -411,66 +410,78 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             else:
                 tmpl = "spark.kubernetes.driverEnv.{}={}"
             for key in self._env_vars:
-                connection_cmd += ["--conf", tmpl.format(key, 
str(self._env_vars[key]))]
+                args += ["--conf", tmpl.format(key, str(self._env_vars[key]))]
         elif self._env_vars and self._connection["deploy_mode"] != "cluster":
             self._env = self._env_vars  # Do it on Popen of the process
         elif self._env_vars and self._connection["deploy_mode"] == "cluster":
             raise AirflowException("SparkSubmitHook env_vars is not supported 
in standalone-cluster mode.")
         if self._is_kubernetes and self._connection["namespace"]:
-            connection_cmd += [
+            args += [
                 "--conf",
                 f"spark.kubernetes.namespace={self._connection['namespace']}",
             ]
         if self._properties_file:
-            connection_cmd += ["--properties-file", self._properties_file]
+            args += ["--properties-file", self._properties_file]
         if self._files:
-            connection_cmd += ["--files", self._files]
+            args += ["--files", self._files]
         if self._py_files:
-            connection_cmd += ["--py-files", self._py_files]
+            args += ["--py-files", self._py_files]
         if self._archives:
-            connection_cmd += ["--archives", self._archives]
+            args += ["--archives", self._archives]
         if self._driver_class_path:
-            connection_cmd += ["--driver-class-path", self._driver_class_path]
+            args += ["--driver-class-path", self._driver_class_path]
         if self._jars:
-            connection_cmd += ["--jars", self._jars]
+            args += ["--jars", self._jars]
         if self._packages:
-            connection_cmd += ["--packages", self._packages]
+            args += ["--packages", self._packages]
         if self._exclude_packages:
-            connection_cmd += ["--exclude-packages", self._exclude_packages]
+            args += ["--exclude-packages", self._exclude_packages]
         if self._repositories:
-            connection_cmd += ["--repositories", self._repositories]
+            args += ["--repositories", self._repositories]
         if self._num_executors:
-            connection_cmd += ["--num-executors", str(self._num_executors)]
+            args += ["--num-executors", str(self._num_executors)]
         if self._total_executor_cores:
-            connection_cmd += ["--total-executor-cores", 
str(self._total_executor_cores)]
+            args += ["--total-executor-cores", str(self._total_executor_cores)]
         if self._executor_cores:
-            connection_cmd += ["--executor-cores", str(self._executor_cores)]
+            args += ["--executor-cores", str(self._executor_cores)]
         if self._executor_memory:
-            connection_cmd += ["--executor-memory", self._executor_memory]
+            args += ["--executor-memory", self._executor_memory]
         if self._driver_memory:
-            connection_cmd += ["--driver-memory", self._driver_memory]
+            args += ["--driver-memory", self._driver_memory]
         if self._connection["keytab"]:
-            connection_cmd += ["--keytab", self._connection["keytab"]]
+            args += ["--keytab", self._connection["keytab"]]
         if self._connection["principal"]:
-            connection_cmd += ["--principal", self._connection["principal"]]
+            args += ["--principal", self._connection["principal"]]
         if self._use_krb5ccache:
             if not os.getenv("KRB5CCNAME"):
                 raise AirflowException(
                     "KRB5CCNAME environment variable required to use ticket 
ccache is missing."
                 )
-            connection_cmd += ["--conf", 
"spark.kerberos.renewal.credentials=ccache"]
+            args += ["--conf", "spark.kerberos.renewal.credentials=ccache"]
         if self._proxy_user:
-            connection_cmd += ["--proxy-user", self._proxy_user]
+            args += ["--proxy-user", self._proxy_user]
         if self._name:
-            connection_cmd += ["--name", self._name]
+            args += ["--name", self._name]
         if self._java_class:
-            connection_cmd += ["--class", self._java_class]
+            args += ["--class", self._java_class]
         if self._verbose:
-            connection_cmd += ["--verbose"]
+            args += ["--verbose"]
         if self._connection["queue"]:
-            connection_cmd += ["--queue", self._connection["queue"]]
+            args += ["--queue", self._connection["queue"]]
         if self._connection["deploy_mode"]:
-            connection_cmd += ["--deploy-mode", 
self._connection["deploy_mode"]]
+            args += ["--deploy-mode", self._connection["deploy_mode"]]
+
+        return args
+
+    def _build_spark_submit_command(self, application: str) -> list[str]:
+        """
+        Construct the spark-submit command to execute.
+
+        :param application: command to append to the spark-submit command
+        :return: full command to be executed
+        """
+        connection_cmd = self._get_spark_binary_path()
+        connection_cmd.extend(self._build_spark_common_args())
 
         # The actual script to execute
         connection_cmd += [application]
diff --git 
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
new file mode 100644
index 00000000000..3e717271229
--- /dev/null
+++ 
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.apache.spark.hooks.spark_pipelines import 
SparkPipelinesHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+    inject_parent_job_information_into_spark_properties,
+    inject_transport_information_into_spark_properties,
+)
+from airflow.providers.common.compat.sdk import BaseOperator, conf
+
+if TYPE_CHECKING:
+    from airflow.providers.common.compat.sdk import Context
+
+
+class SparkPipelinesOperator(BaseOperator):
+    """
+    Execute Spark Declarative Pipelines using the spark-pipelines CLI.
+
+    This operator wraps the spark-pipelines binary to execute declarative data pipelines.
+    It supports running pipelines (``run``) and validating them without execution (``dry-run``).
+
+    .. seealso::
+        For more information on Spark Declarative Pipelines, see the guide:
+        https://spark.apache.org/docs/latest/declarative-pipelines-programming-guide.html
+
+    :param pipeline_spec: Path to the pipeline specification file (YAML). (templated)
+    :param pipeline_command: The spark-pipelines command to execute ('run', 'dry-run'). Default is 'run'.
+        The value is validated by the hook when it is created at execution time.
+    :param conf: Arbitrary Spark configuration properties, passed as ``--conf key=value`` (templated)
+    :param conn_id: The :ref:`spark connection id <howto/connection:spark-submit>` as configured
+        in Airflow administration. When an invalid connection_id is supplied, it will default to yarn.
+    :param num_executors: Number of executors to launch. If unset, ``--num-executors`` is not passed.
+    :param executor_cores: Number of cores per executor. If unset, ``--executor-cores`` is not passed
+        and Spark's own default applies.
+    :param executor_memory: Memory per executor (e.g. 1000M, 2G). If unset, ``--executor-memory``
+        is not passed and Spark's own default applies.
+    :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G). If unset,
+        ``--driver-memory`` is not passed and Spark's own default applies.
+    :param verbose: Whether to pass the verbose flag to spark-pipelines process for debugging
+    :param env_vars: Environment variables for spark-pipelines. On YARN/Kubernetes masters they are
+        passed as application-master/driver environment ``--conf`` entries; otherwise they are set
+        on the spark-pipelines process itself. Not supported in standalone cluster mode. (templated)
+    :param deploy_mode: Whether to deploy your driver on the worker nodes (cluster) or locally as a client.
+    :param yarn_queue: The name of the YARN queue to which the application is submitted.
+    :param keytab: Full path to the file that contains the keytab (templated)
+    :param principal: The name of the kerberos principal used for keytab (templated)
+    :param openlineage_inject_parent_job_info: Whether to inject OpenLineage parent job information
+        into ``conf`` before execution. Defaults to ``[openlineage] spark_inject_parent_job_info``
+        as read when this module is imported.
+    :param openlineage_inject_transport_info: Whether to inject OpenLineage transport information
+        into ``conf`` before execution. Defaults to ``[openlineage] spark_inject_transport_info``
+        as read when this module is imported.
+    """
+
+    template_fields: Sequence[str] = (
+        "pipeline_spec",
+        "conf",
+        "env_vars",
+        "keytab",
+        "principal",
+    )
+
+    def __init__(
+        self,
+        *,
+        pipeline_spec: str | None = None,
+        pipeline_command: str = "run",
+        conf: dict[Any, Any] | None = None,
+        conn_id: str = "spark_default",
+        num_executors: int | None = None,
+        executor_cores: int | None = None,
+        executor_memory: str | None = None,
+        driver_memory: str | None = None,
+        verbose: bool = False,
+        env_vars: dict[str, Any] | None = None,
+        deploy_mode: str | None = None,
+        yarn_queue: str | None = None,
+        keytab: str | None = None,
+        principal: str | None = None,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
+        openlineage_inject_transport_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_transport_info", fallback=False
+        ),
+        **kwargs: Any,
+    ) -> None:
+        # NOTE: inside this method ``conf`` is the Spark-properties argument. The
+        # ``conf.getboolean`` defaults in the signature refer to the module-level Airflow
+        # config, because default values are evaluated once, when the function is defined.
+        super().__init__(**kwargs)
+        self.pipeline_spec = pipeline_spec
+        self.pipeline_command = pipeline_command
+        self.conf = conf
+        self.num_executors = num_executors
+        self.executor_cores = executor_cores
+        self.executor_memory = executor_memory
+        self.driver_memory = driver_memory
+        self.verbose = verbose
+        self.env_vars = env_vars
+        self.deploy_mode = deploy_mode
+        self.yarn_queue = yarn_queue
+        self.keytab = keytab
+        self.principal = principal
+        self._conn_id = conn_id
+        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
+        self._openlineage_inject_transport_info = openlineage_inject_transport_info
+
+    def execute(self, context: Context) -> None:
+        """Inject OpenLineage properties into ``conf`` if enabled, then run the pipeline command."""
+        self.conf = self.conf or {}
+        if self._openlineage_inject_parent_job_info:
+            self.log.debug("Injecting OpenLineage parent job information into Spark properties.")
+            self.conf = inject_parent_job_information_into_spark_properties(self.conf, context)
+        if self._openlineage_inject_transport_info:
+            self.log.debug("Injecting OpenLineage transport information into Spark properties.")
+            self.conf = inject_transport_information_into_spark_properties(self.conf, context)
+
+        # ``hook`` is a cached_property; its first access here builds it from the updated conf.
+        self.hook.submit_pipeline()
+
+    def on_kill(self) -> None:
+        """Delegate to the hook's ``on_kill``."""
+        self.hook.on_kill()
+
+    @cached_property
+    def hook(self) -> SparkPipelinesHook:
+        """Build the hook lazily from the operator's (rendered) attributes; cached after first use."""
+        return SparkPipelinesHook(
+            pipeline_spec=self.pipeline_spec,
+            pipeline_command=self.pipeline_command,
+            conf=self.conf,
+            conn_id=self._conn_id,
+            num_executors=self.num_executors,
+            executor_cores=self.executor_cores,
+            executor_memory=self.executor_memory,
+            driver_memory=self.driver_memory,
+            verbose=self.verbose,
+            env_vars=self.env_vars,
+            deploy_mode=self.deploy_mode,
+            yarn_queue=self.yarn_queue,
+            keytab=self.keytab,
+            principal=self.principal,
+        )
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py
new file mode 100644
index 00000000000..abf1d5d94ec
--- /dev/null
+++ 
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py
@@ -0,0 +1,185 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.apache.spark.hooks.spark_pipelines import 
SparkPipelinesHook
+from airflow.providers.common.compat.sdk import AirflowException
+
+
+class TestSparkPipelinesHook:
+    @staticmethod
+    def _flag_value(cmd: list[str], flag: str) -> str:
+        """Return the argument that immediately follows ``flag`` in ``cmd``."""
+        assert flag in cmd, f"{flag} missing from {cmd}"
+        return cmd[cmd.index(flag) + 1]
+
+    def setup_method(self):
+        self.hook = SparkPipelinesHook(
+            pipeline_spec="test_pipeline.yml", pipeline_command="run", conn_id="spark_default"
+        )
+
+    def test_init_with_valid_command(self):
+        hook = SparkPipelinesHook(pipeline_command="run")
+        assert hook.pipeline_command == "run"
+
+    def test_init_with_invalid_command(self):
+        with pytest.raises(ValueError, match="Invalid pipeline command"):
+            SparkPipelinesHook(pipeline_command="invalid")
+
+    def test_get_spark_binary_path(self):
+        binary_path = self.hook._get_spark_binary_path()
+        assert binary_path == ["spark-pipelines"]
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+    def test_build_pipelines_command_run(self, mock_resolve_connection):
+        mock_resolve_connection.return_value = {
+            "master": "yarn",
+            "deploy_mode": "client",
+            "queue": "default",
+            "keytab": None,
+            "principal": None,
+        }
+
+        hook = SparkPipelinesHook(
+            pipeline_spec="test_pipeline.yml",
+            pipeline_command="run",
+            num_executors=2,
+            executor_cores=4,
+            executor_memory="2G",
+            driver_memory="1G",
+            verbose=True,
+        )
+        hook._connection = mock_resolve_connection.return_value
+
+        cmd = hook._build_spark_pipelines_command()
+
+        # Check flag/value adjacency, not bare membership: ``"2" in cmd`` would still pass
+        # if the value were attached to a different flag.
+        assert cmd[:2] == ["spark-pipelines", "run"]
+        assert self._flag_value(cmd, "--spec") == "test_pipeline.yml"
+        assert self._flag_value(cmd, "--master") == "yarn"
+        assert self._flag_value(cmd, "--deploy-mode") == "client"
+        assert self._flag_value(cmd, "--queue") == "default"
+        assert self._flag_value(cmd, "--num-executors") == "2"
+        assert self._flag_value(cmd, "--executor-cores") == "4"
+        assert self._flag_value(cmd, "--executor-memory") == "2G"
+        assert self._flag_value(cmd, "--driver-memory") == "1G"
+        assert "--verbose" in cmd
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+    def test_build_pipelines_command_dry_run(self, mock_resolve_connection):
+        mock_resolve_connection.return_value = {
+            "master": "local[*]",
+            "deploy_mode": None,
+            "queue": None,
+            "keytab": None,
+            "principal": None,
+        }
+
+        hook = SparkPipelinesHook(pipeline_spec="test_pipeline.yml", pipeline_command="dry-run")
+        hook._connection = mock_resolve_connection.return_value
+
+        cmd = hook._build_spark_pipelines_command()
+
+        assert cmd[:2] == ["spark-pipelines", "dry-run"]
+        assert self._flag_value(cmd, "--spec") == "test_pipeline.yml"
+        assert self._flag_value(cmd, "--master") == "local[*]"
+        # Unset connection fields must not produce flags.
+        assert "--deploy-mode" not in cmd
+        assert "--queue" not in cmd
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+    def test_build_pipelines_command_with_conf(self, mock_resolve_connection):
+        mock_resolve_connection.return_value = {
+            "master": "yarn",
+            "deploy_mode": None,
+            "queue": None,
+            "keytab": None,
+            "principal": None,
+        }
+
+        hook = SparkPipelinesHook(
+            pipeline_spec="test_pipeline.yml",
+            pipeline_command="run",
+            conf={
+                "spark.sql.adaptive.enabled": "true",
+                "spark.sql.adaptive.coalescePartitions.enabled": "true",
+            },
+        )
+        hook._connection = mock_resolve_connection.return_value
+
+        cmd = hook._build_spark_pipelines_command()
+
+        # Every property must appear as the value of a --conf flag.
+        conf_values = [cmd[i + 1] for i, arg in enumerate(cmd) if arg == "--conf"]
+        assert "spark.sql.adaptive.enabled=true" in conf_values
+        assert "spark.sql.adaptive.coalescePartitions.enabled=true" in conf_values
+
+    @patch("subprocess.Popen")
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+    def test_submit_pipeline_success(self, mock_resolve_connection, mock_popen):
+        mock_resolve_connection.return_value = {
+            "master": "yarn",
+            "deploy_mode": None,
+            "queue": None,
+            "keytab": None,
+            "principal": None,
+        }
+
+        mock_process = MagicMock()
+        mock_process.wait.return_value = 0
+        mock_process.stdout = ["Pipeline completed successfully"]
+        mock_popen.return_value = mock_process
+
+        self.hook._connection = mock_resolve_connection.return_value
+        self.hook.submit_pipeline()
+
+        mock_popen.assert_called_once()
+        popen_cmd = mock_popen.call_args.args[0]
+        assert popen_cmd[:2] == ["spark-pipelines", "run"]
+        assert self._flag_value(popen_cmd, "--spec") == "test_pipeline.yml"
+        mock_process.wait.assert_called_once()
+
+    @patch("subprocess.Popen")
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+    def test_submit_pipeline_failure(self, mock_resolve_connection, mock_popen):
+        mock_resolve_connection.return_value = {
+            "master": "yarn",
+            "deploy_mode": None,
+            "queue": None,
+            "keytab": None,
+            "principal": None,
+        }
+
+        mock_process = MagicMock()
+        mock_process.wait.return_value = 1
+        mock_process.stdout = ["Pipeline failed"]
+        mock_popen.return_value = mock_process
+
+        self.hook._connection = mock_resolve_connection.return_value
+
+        with pytest.raises(AirflowException, match=r"Cannot execute: .*Error code is: 1\."):
+            self.hook.submit_pipeline()
+
+    def test_submit_calls_submit_pipeline(self):
+        with patch.object(self.hook, "submit_pipeline") as mock_submit_pipeline:
+            self.hook.submit("dummy_application")
+            # The application argument is ignored and no extra Popen kwargs are forwarded.
+            mock_submit_pipeline.assert_called_once_with()
diff --git 
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
new file mode 100644
index 00000000000..3348692560d
--- /dev/null
+++ 
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
@@ -0,0 +1,155 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+from airflow.providers.apache.spark.operators.spark_pipelines import 
SparkPipelinesOperator
+
+
+class TestSparkPipelinesOperator:
+    def test_init_with_run_command(self):
+        operator = SparkPipelinesOperator(
+            task_id="test_task", pipeline_spec="test_pipeline.yml", 
pipeline_command="run"
+        )
+        assert operator.pipeline_spec == "test_pipeline.yml"
+        assert operator.pipeline_command == "run"
+
+    def test_init_with_dry_run_command(self):
+        operator = SparkPipelinesOperator(
+            task_id="test_task", pipeline_spec="test_pipeline.yml", 
pipeline_command="dry-run"
+        )
+        assert operator.pipeline_command == "dry-run"
+
+    def test_template_fields(self):
+        expected_fields = (
+            "pipeline_spec",
+            "conf",
+            "env_vars",
+            "keytab",
+            "principal",
+        )
+        operator = SparkPipelinesOperator(task_id="test_task")
+        assert operator.template_fields == expected_fields
+
+    def test_execute(self):
+        mock_hook = MagicMock()
+
+        with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+            operator = SparkPipelinesOperator(
+                task_id="test_task", pipeline_spec="test_pipeline.yml", 
pipeline_command="run"
+            )
+
+            context = {}
+            operator.execute(context)
+
+            mock_hook.submit_pipeline.assert_called_once()
+
+    def test_on_kill(self):
+        mock_hook = MagicMock()
+
+        with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+            operator = SparkPipelinesOperator(task_id="test_task", 
pipeline_spec="test_pipeline.yml")
+
+            operator.on_kill()
+
+            mock_hook.on_kill.assert_called_once()
+
+    def test_get_hook(self):
+        operator = SparkPipelinesOperator(
+            task_id="test_task",
+            pipeline_spec="test_pipeline.yml",
+            pipeline_command="run",
+            conf={"spark.sql.adaptive.enabled": "true"},
+            num_executors=2,
+            executor_cores=4,
+            executor_memory="2G",
+            driver_memory="1G",
+            verbose=True,
+            env_vars={"SPARK_HOME": "/opt/spark"},
+            deploy_mode="client",
+            yarn_queue="default",
+            keytab="/path/to/keytab",
+            principal="[email protected]",
+        )
+
+        hook = operator.hook
+
+        assert hook.pipeline_spec == "test_pipeline.yml"
+        assert hook.pipeline_command == "run"
+        assert hook._conf == {"spark.sql.adaptive.enabled": "true"}
+        assert hook._num_executors == 2
+        assert hook._executor_cores == 4
+        assert hook._executor_memory == "2G"
+        assert hook._driver_memory == "1G"
+        assert hook._verbose is True
+        assert hook._env_vars == {"SPARK_HOME": "/opt/spark"}
+        assert hook._deploy_mode == "client"
+        assert hook._yarn_queue == "default"
+        assert hook._keytab == "/path/to/keytab"
+        assert hook._principal == "[email protected]"
+
+    @patch(
+        
"airflow.providers.apache.spark.operators.spark_pipelines.inject_parent_job_information_into_spark_properties"
+    )
+    def test_execute_with_openlineage_parent_job_info(self, 
mock_inject_parent):
+        mock_hook = MagicMock()
+
+        with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+            original_conf = {"spark.sql.adaptive.enabled": "true"}
+            modified_conf = {**original_conf, 
"spark.openlineage.parentJobName": "test_job"}
+            mock_inject_parent.return_value = modified_conf
+
+            operator = SparkPipelinesOperator(
+                task_id="test_task",
+                pipeline_spec="test_pipeline.yml",
+                conf=original_conf,
+                openlineage_inject_parent_job_info=True,
+            )
+
+            context = {"task_instance": MagicMock()}
+            operator.execute(context)
+
+            mock_inject_parent.assert_called_once_with(original_conf, context)
+            assert operator.conf == modified_conf
+            mock_hook.submit_pipeline.assert_called_once()
+
+    @patch(
+        
"airflow.providers.apache.spark.operators.spark_pipelines.inject_transport_information_into_spark_properties"
+    )
+    def test_execute_with_openlineage_transport_info(self, 
mock_inject_transport):
+        mock_hook = MagicMock()
+
+        with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+            original_conf = {"spark.sql.adaptive.enabled": "true"}
+            modified_conf = {**original_conf, 
"spark.openlineage.transport.type": "http"}
+            mock_inject_transport.return_value = modified_conf
+
+            operator = SparkPipelinesOperator(
+                task_id="test_task",
+                pipeline_spec="test_pipeline.yml",
+                conf=original_conf,
+                openlineage_inject_transport_info=True,
+            )
+
+            context = {"task_instance": MagicMock()}
+            operator.execute(context)
+
+            mock_inject_transport.assert_called_once_with(original_conf, 
context)
+            assert operator.conf == modified_conf
+            mock_hook.submit_pipeline.assert_called_once()


Reply via email to