This is an automated email from the ASF dual-hosted git repository.
jasonliu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 33467ea8e3a spark-pipelines operator (#61681)
33467ea8e3a is described below
commit 33467ea8e3aae05ee1ea6c79ac6dc9d8b98ccca8
Author: Sandy Ryza <[email protected]>
AuthorDate: Wed Mar 4 00:34:43 2026 -0800
spark-pipelines operator (#61681)
* spark-pipelines operator
* fixups and feedback
* fix tests
* more
* static checks
---
providers/apache/spark/docs/operators.rst | 54 ++++++
providers/apache/spark/provider.yaml | 2 +
.../providers/apache/spark/get_provider_info.py | 2 +
.../apache/spark/hooks/spark_pipelines.py | 113 +++++++++++++
.../providers/apache/spark/hooks/spark_submit.py | 75 +++++----
.../apache/spark/operators/spark_pipelines.py | 148 +++++++++++++++++
.../apache/spark/hooks/test_spark_pipelines.py | 185 +++++++++++++++++++++
.../apache/spark/operators/test_spark_pipelines.py | 155 +++++++++++++++++
8 files changed, 702 insertions(+), 32 deletions(-)
diff --git a/providers/apache/spark/docs/operators.rst
b/providers/apache/spark/docs/operators.rst
index 2e1aad8fb38..125039ebdf3 100644
--- a/providers/apache/spark/docs/operators.rst
+++ b/providers/apache/spark/docs/operators.rst
@@ -31,6 +31,8 @@ Prerequisite
gets all the configurations from operator parameters.
* To use
:class:`~airflow.providers.apache.spark.operators.spark_pyspark.PySparkOperator`
you can configure :doc:`SparkConnect Connection <connections/spark-connect>`.
+* To use :class:`~airflow.providers.apache.spark.operators.spark_pipelines.SparkPipelinesOperator`
+  you must configure :doc:`Spark Connection <connections/spark-submit>` and have the ``spark-pipelines`` CLI available.
.. _howto/operator:SparkJDBCOperator:
@@ -81,6 +83,58 @@ Reference
For further information, look at `Running the Spark Connect Python
<https://spark.apache.org/docs/latest/api/python/getting_started/quickstart_connect.html>`_.
+.. _howto/operator:SparkPipelinesOperator:
+
+SparkPipelinesOperator
+----------------------
+
+Execute Spark Declarative Pipelines using the ``spark-pipelines`` CLI. This operator wraps the spark-pipelines binary to execute declarative data pipelines, supporting both pipeline execution and validation through dry-runs.
+
+For parameter definition take a look at :class:`~airflow.providers.apache.spark.operators.spark_pipelines.SparkPipelinesOperator`.
+
+Using the operator
+""""""""""""""""""
+
+The operator can be used to run declarative pipelines:
+
+.. code-block:: python
+
+    from airflow.providers.apache.spark.operators.spark_pipelines import SparkPipelinesOperator
+
+ # Execute the pipeline
+ run_pipeline = SparkPipelinesOperator(
+ task_id="run_pipeline",
+ pipeline_spec="/path/to/pipeline.yml",
+ pipeline_command="run",
+ conn_id="spark_default",
+ num_executors=2,
+ executor_cores=4,
+ executor_memory="2G",
+ driver_memory="1G",
+ )
+
+**Pipeline Specification**
+
+The ``pipeline_spec`` parameter should point to a YAML file defining your declarative pipeline:
+
+.. code-block:: yaml
+
+ name: my_pipeline
+ storage: file:///path/to/pipeline-storage
+ libraries:
+ - glob:
+ include: transformations/**
+
+**Pipeline Commands**
+
+* ``run`` - Execute the pipeline (default)
+* ``dry-run`` - Validate the pipeline without execution
+
+Reference
+"""""""""
+
+For further information, look at `Spark Declarative Pipelines Programming Guide
+<https://spark.apache.org/docs/latest/declarative-pipelines-programming-guide.html>`_.
+
.. _howto/operator:SparkSqlOperator:
SparkSqlOperator
diff --git a/providers/apache/spark/provider.yaml
b/providers/apache/spark/provider.yaml
index e5309934d55..376cec10683 100644
--- a/providers/apache/spark/provider.yaml
+++ b/providers/apache/spark/provider.yaml
@@ -95,6 +95,7 @@ operators:
- integration-name: Apache Spark
python-modules:
- airflow.providers.apache.spark.operators.spark_jdbc
+ - airflow.providers.apache.spark.operators.spark_pipelines
- airflow.providers.apache.spark.operators.spark_sql
- airflow.providers.apache.spark.operators.spark_submit
- airflow.providers.apache.spark.operators.spark_pyspark
@@ -105,6 +106,7 @@ hooks:
- airflow.providers.apache.spark.hooks.spark_connect
- airflow.providers.apache.spark.hooks.spark_jdbc
- airflow.providers.apache.spark.hooks.spark_jdbc_script
+ - airflow.providers.apache.spark.hooks.spark_pipelines
- airflow.providers.apache.spark.hooks.spark_sql
- airflow.providers.apache.spark.hooks.spark_submit
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index bd4169f9f3a..bf9f4b2f8a7 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -40,6 +40,7 @@ def get_provider_info():
"integration-name": "Apache Spark",
"python-modules": [
"airflow.providers.apache.spark.operators.spark_jdbc",
+ "airflow.providers.apache.spark.operators.spark_pipelines",
"airflow.providers.apache.spark.operators.spark_sql",
"airflow.providers.apache.spark.operators.spark_submit",
"airflow.providers.apache.spark.operators.spark_pyspark",
@@ -53,6 +54,7 @@ def get_provider_info():
"airflow.providers.apache.spark.hooks.spark_connect",
"airflow.providers.apache.spark.hooks.spark_jdbc",
"airflow.providers.apache.spark.hooks.spark_jdbc_script",
+ "airflow.providers.apache.spark.hooks.spark_pipelines",
"airflow.providers.apache.spark.hooks.spark_sql",
"airflow.providers.apache.spark.hooks.spark_submit",
],
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
new file mode 100644
index 00000000000..e721433b93b
--- /dev/null
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_pipelines.py
@@ -0,0 +1,113 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import subprocess
+from typing import Any
+
+from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.sdk import AirflowException
+
+
+class SparkPipelinesException(AirflowException):
+ """Exception raised when spark-pipelines command fails."""
+
+
+class SparkPipelinesHook(SparkSubmitHook):
+ """
+ Hook for interacting with Spark Declarative Pipelines via the
spark-pipelines CLI.
+
+ Extends SparkSubmitHook to leverage existing connection management while
providing
+ pipeline-specific functionality.
+
+ :param pipeline_spec: Path to the pipeline specification file (YAML)
+ :param pipeline_command: The spark-pipelines command to run ('run',
'dry-run')
+ """
+
+ def __init__(
+ self,
+ pipeline_spec: str | None = None,
+ pipeline_command: str = "run",
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.pipeline_spec = pipeline_spec
+ self.pipeline_command = pipeline_command
+
+ if pipeline_command not in ["run", "dry-run"]:
+            raise ValueError(f"Invalid pipeline command: {pipeline_command}. Must be 'run' or 'dry-run'")
+
+ def _get_spark_binary_path(self) -> list[str]:
+ return ["spark-pipelines"]
+
+ def _build_spark_pipelines_command(self) -> list[str]:
+ """
+ Construct the spark-pipelines command to execute.
+
+ :return: full command to be executed
+ """
+ # Start with spark-pipelines binary and command
+ connection_cmd = self._get_spark_binary_path()
+ connection_cmd.append(self.pipeline_command)
+
+ # Add pipeline spec if provided
+ if self.pipeline_spec:
+ connection_cmd.extend(["--spec", self.pipeline_spec])
+
+ # Reuse parent's common spark argument building logic
+ connection_cmd.extend(self._build_spark_common_args())
+
+        self.log.info("Spark-Pipelines cmd: %s", self._mask_cmd(connection_cmd))
+ return connection_cmd
+
+ def submit_pipeline(self, **kwargs: Any) -> None:
+ """
+ Execute the spark-pipelines command.
+
+ :param kwargs: extra arguments to Popen (see subprocess.Popen)
+ """
+ pipelines_cmd = self._build_spark_pipelines_command()
+
+ if self._env:
+ import os
+
+ env = os.environ.copy()
+ env.update(self._env)
+ kwargs["env"] = env
+
+ self._submit_sp = subprocess.Popen(
+ pipelines_cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.STDOUT,
+ bufsize=-1,
+ universal_newlines=True,
+ **kwargs,
+ )
+
+ if self._submit_sp.stdout:
+ self._process_spark_submit_log(iter(self._submit_sp.stdout))
+ returncode = self._submit_sp.wait()
+
+ if returncode:
+ raise SparkPipelinesException(
+                f"Cannot execute: {self._mask_cmd(pipelines_cmd)}. Error code is: {returncode}."
+ )
+
+ def submit(self, application: str = "", **kwargs: Any) -> None:
+ """Override submit to use pipeline-specific logic."""
+ self.submit_pipeline(**kwargs)
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
index 0cb93afb861..7870b790535 100644
---
a/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -389,20 +389,19 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
return connection_cmd_masked
- def _build_spark_submit_command(self, application: str) -> list[str]:
+ def _build_spark_common_args(self) -> list[str]:
"""
- Construct the spark-submit command to execute.
+ Build common Spark arguments that are shared between spark-submit and
spark-pipelines.
- :param application: command to append to the spark-submit command
- :return: full command to be executed
+ :return: list of common spark arguments
"""
- connection_cmd = self._get_spark_binary_path()
+ args = []
# The url of the spark master
- connection_cmd += ["--master", self._connection["master"]]
+ args += ["--master", self._connection["master"]]
for key in self._conf:
- connection_cmd += ["--conf", f"{key}={self._conf[key]}"]
+ args += ["--conf", f"{key}={self._conf[key]}"]
if self._env_vars and (self._is_kubernetes or self._is_yarn):
if self._is_yarn:
tmpl = "spark.yarn.appMasterEnv.{}={}"
@@ -411,66 +410,78 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
else:
tmpl = "spark.kubernetes.driverEnv.{}={}"
for key in self._env_vars:
- connection_cmd += ["--conf", tmpl.format(key,
str(self._env_vars[key]))]
+ args += ["--conf", tmpl.format(key, str(self._env_vars[key]))]
elif self._env_vars and self._connection["deploy_mode"] != "cluster":
self._env = self._env_vars # Do it on Popen of the process
elif self._env_vars and self._connection["deploy_mode"] == "cluster":
raise AirflowException("SparkSubmitHook env_vars is not supported
in standalone-cluster mode.")
if self._is_kubernetes and self._connection["namespace"]:
- connection_cmd += [
+ args += [
"--conf",
f"spark.kubernetes.namespace={self._connection['namespace']}",
]
if self._properties_file:
- connection_cmd += ["--properties-file", self._properties_file]
+ args += ["--properties-file", self._properties_file]
if self._files:
- connection_cmd += ["--files", self._files]
+ args += ["--files", self._files]
if self._py_files:
- connection_cmd += ["--py-files", self._py_files]
+ args += ["--py-files", self._py_files]
if self._archives:
- connection_cmd += ["--archives", self._archives]
+ args += ["--archives", self._archives]
if self._driver_class_path:
- connection_cmd += ["--driver-class-path", self._driver_class_path]
+ args += ["--driver-class-path", self._driver_class_path]
if self._jars:
- connection_cmd += ["--jars", self._jars]
+ args += ["--jars", self._jars]
if self._packages:
- connection_cmd += ["--packages", self._packages]
+ args += ["--packages", self._packages]
if self._exclude_packages:
- connection_cmd += ["--exclude-packages", self._exclude_packages]
+ args += ["--exclude-packages", self._exclude_packages]
if self._repositories:
- connection_cmd += ["--repositories", self._repositories]
+ args += ["--repositories", self._repositories]
if self._num_executors:
- connection_cmd += ["--num-executors", str(self._num_executors)]
+ args += ["--num-executors", str(self._num_executors)]
if self._total_executor_cores:
- connection_cmd += ["--total-executor-cores",
str(self._total_executor_cores)]
+ args += ["--total-executor-cores", str(self._total_executor_cores)]
if self._executor_cores:
- connection_cmd += ["--executor-cores", str(self._executor_cores)]
+ args += ["--executor-cores", str(self._executor_cores)]
if self._executor_memory:
- connection_cmd += ["--executor-memory", self._executor_memory]
+ args += ["--executor-memory", self._executor_memory]
if self._driver_memory:
- connection_cmd += ["--driver-memory", self._driver_memory]
+ args += ["--driver-memory", self._driver_memory]
if self._connection["keytab"]:
- connection_cmd += ["--keytab", self._connection["keytab"]]
+ args += ["--keytab", self._connection["keytab"]]
if self._connection["principal"]:
- connection_cmd += ["--principal", self._connection["principal"]]
+ args += ["--principal", self._connection["principal"]]
if self._use_krb5ccache:
if not os.getenv("KRB5CCNAME"):
raise AirflowException(
"KRB5CCNAME environment variable required to use ticket
ccache is missing."
)
- connection_cmd += ["--conf",
"spark.kerberos.renewal.credentials=ccache"]
+ args += ["--conf", "spark.kerberos.renewal.credentials=ccache"]
if self._proxy_user:
- connection_cmd += ["--proxy-user", self._proxy_user]
+ args += ["--proxy-user", self._proxy_user]
if self._name:
- connection_cmd += ["--name", self._name]
+ args += ["--name", self._name]
if self._java_class:
- connection_cmd += ["--class", self._java_class]
+ args += ["--class", self._java_class]
if self._verbose:
- connection_cmd += ["--verbose"]
+ args += ["--verbose"]
if self._connection["queue"]:
- connection_cmd += ["--queue", self._connection["queue"]]
+ args += ["--queue", self._connection["queue"]]
if self._connection["deploy_mode"]:
- connection_cmd += ["--deploy-mode",
self._connection["deploy_mode"]]
+ args += ["--deploy-mode", self._connection["deploy_mode"]]
+
+ return args
+
+ def _build_spark_submit_command(self, application: str) -> list[str]:
+ """
+ Construct the spark-submit command to execute.
+
+ :param application: command to append to the spark-submit command
+ :return: full command to be executed
+ """
+ connection_cmd = self._get_spark_binary_path()
+ connection_cmd.extend(self._build_spark_common_args())
# The actual script to execute
connection_cmd += [application]
diff --git
a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
new file mode 100644
index 00000000000..3e717271229
--- /dev/null
+++
b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_pipelines.py
@@ -0,0 +1,148 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from airflow.providers.apache.spark.hooks.spark_pipelines import
SparkPipelinesHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+ inject_parent_job_information_into_spark_properties,
+ inject_transport_information_into_spark_properties,
+)
+from airflow.providers.common.compat.sdk import BaseOperator, conf
+
+if TYPE_CHECKING:
+ from airflow.providers.common.compat.sdk import Context
+
+
+class SparkPipelinesOperator(BaseOperator):
+ """
+ Execute Spark Declarative Pipelines using the spark-pipelines CLI.
+
+ This operator wraps the spark-pipelines binary to execute declarative data
pipelines.
+ It supports running pipelines, dry-runs for validation, and initializing
new pipeline projects.
+
+ .. seealso::
+ For more information on Spark Declarative Pipelines, see the guide:
+
https://spark.apache.org/docs/latest/declarative-pipelines-programming-guide.html
+
+ :param pipeline_spec: Path to the pipeline specification file (YAML).
(templated)
+ :param pipeline_command: The spark-pipelines command to execute ('run',
'dry-run'). Default is 'run'.
+ :param conf: Arbitrary Spark configuration properties (templated)
+ :param conn_id: The :ref:`spark connection id
<howto/connection:spark-submit>` as configured
+ in Airflow administration. When an invalid connection_id is supplied,
it will default to yarn.
+ :param num_executors: Number of executors to launch
+ :param executor_cores: Number of cores per executor (Default: 2)
+ :param executor_memory: Memory per executor (e.g. 1000M, 2G) (Default: 1G)
+ :param driver_memory: Memory allocated to the driver (e.g. 1000M, 2G)
(Default: 1G)
+ :param verbose: Whether to pass the verbose flag to spark-pipelines
process for debugging
+ :param env_vars: Environment variables for spark-pipelines. (templated)
+ :param deploy_mode: Whether to deploy your driver on the worker nodes
(cluster) or locally as a client.
+ :param yarn_queue: The name of the YARN queue to which the application is
submitted.
+ :param keytab: Full path to the file that contains the keytab (templated)
+ :param principal: The name of the kerberos principal used for keytab
(templated)
+ :param openlineage_inject_parent_job_info: Whether to inject OpenLineage
parent job information
+ :param openlineage_inject_transport_info: Whether to inject OpenLineage
transport information
+ """
+
+ template_fields: Sequence[str] = (
+ "pipeline_spec",
+ "conf",
+ "env_vars",
+ "keytab",
+ "principal",
+ )
+
+ def __init__(
+ self,
+ *,
+ pipeline_spec: str | None = None,
+ pipeline_command: str = "run",
+ conf: dict[Any, Any] | None = None,
+ conn_id: str = "spark_default",
+ num_executors: int | None = None,
+ executor_cores: int | None = None,
+ executor_memory: str | None = None,
+ driver_memory: str | None = None,
+ verbose: bool = False,
+ env_vars: dict[str, Any] | None = None,
+ deploy_mode: str | None = None,
+ yarn_queue: str | None = None,
+ keytab: str | None = None,
+ principal: str | None = None,
+ openlineage_inject_parent_job_info: bool = conf.getboolean(
+ "openlineage", "spark_inject_parent_job_info", fallback=False
+ ),
+ openlineage_inject_transport_info: bool = conf.getboolean(
+ "openlineage", "spark_inject_transport_info", fallback=False
+ ),
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+ self.pipeline_spec = pipeline_spec
+ self.pipeline_command = pipeline_command
+ self.conf = conf
+ self.num_executors = num_executors
+ self.executor_cores = executor_cores
+ self.executor_memory = executor_memory
+ self.driver_memory = driver_memory
+ self.verbose = verbose
+ self.env_vars = env_vars
+ self.deploy_mode = deploy_mode
+ self.yarn_queue = yarn_queue
+ self.keytab = keytab
+ self.principal = principal
+ self._conn_id = conn_id
+ self._openlineage_inject_parent_job_info =
openlineage_inject_parent_job_info
+ self._openlineage_inject_transport_info =
openlineage_inject_transport_info
+
+ def execute(self, context: Context) -> None:
+ """Execute the SparkPipelinesHook to run the specified pipeline
command."""
+ self.conf = self.conf or {}
+ if self._openlineage_inject_parent_job_info:
+ self.log.debug("Injecting OpenLineage parent job information into
Spark properties.")
+ self.conf =
inject_parent_job_information_into_spark_properties(self.conf, context)
+ if self._openlineage_inject_transport_info:
+ self.log.debug("Injecting OpenLineage transport information into
Spark properties.")
+ self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
+
+ self.hook.submit_pipeline()
+
+ def on_kill(self) -> None:
+ self.hook.on_kill()
+
+ @cached_property
+ def hook(self) -> SparkPipelinesHook:
+ return SparkPipelinesHook(
+ pipeline_spec=self.pipeline_spec,
+ pipeline_command=self.pipeline_command,
+ conf=self.conf,
+ conn_id=self._conn_id,
+ num_executors=self.num_executors,
+ executor_cores=self.executor_cores,
+ executor_memory=self.executor_memory,
+ driver_memory=self.driver_memory,
+ verbose=self.verbose,
+ env_vars=self.env_vars,
+ deploy_mode=self.deploy_mode,
+ yarn_queue=self.yarn_queue,
+ keytab=self.keytab,
+ principal=self.principal,
+ )
diff --git
a/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py
new file mode 100644
index 00000000000..abf1d5d94ec
--- /dev/null
+++
b/providers/apache/spark/tests/unit/apache/spark/hooks/test_spark_pipelines.py
@@ -0,0 +1,185 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.apache.spark.hooks.spark_pipelines import
SparkPipelinesHook
+from airflow.providers.common.compat.sdk import AirflowException
+
+
+class TestSparkPipelinesHook:
+ def setup_method(self):
+ self.hook = SparkPipelinesHook(
+ pipeline_spec="test_pipeline.yml", pipeline_command="run",
conn_id="spark_default"
+ )
+
+ def test_init_with_valid_command(self):
+ hook = SparkPipelinesHook(pipeline_command="run")
+ assert hook.pipeline_command == "run"
+
+ def test_init_with_invalid_command(self):
+ with pytest.raises(ValueError, match="Invalid pipeline command"):
+ SparkPipelinesHook(pipeline_command="invalid")
+
+ def test_get_spark_binary_path(self):
+ binary_path = self.hook._get_spark_binary_path()
+ assert binary_path == ["spark-pipelines"]
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+ def test_build_pipelines_command_run(self, mock_resolve_connection):
+ mock_resolve_connection.return_value = {
+ "master": "yarn",
+ "deploy_mode": "client",
+ "queue": "default",
+ "keytab": None,
+ "principal": None,
+ }
+
+ hook = SparkPipelinesHook(
+ pipeline_spec="test_pipeline.yml",
+ pipeline_command="run",
+ num_executors=2,
+ executor_cores=4,
+ executor_memory="2G",
+ driver_memory="1G",
+ verbose=True,
+ )
+ hook._connection = mock_resolve_connection.return_value
+
+ cmd = hook._build_spark_pipelines_command()
+
+ # Verify the command starts correctly and contains expected arguments
+ assert cmd[0] == "spark-pipelines"
+ assert cmd[1] == "run"
+ assert "--spec" in cmd
+ assert "test_pipeline.yml" in cmd
+ assert "--master" in cmd
+ assert "yarn" in cmd
+ assert "--deploy-mode" in cmd
+ assert "client" in cmd
+ assert "--queue" in cmd
+ assert "default" in cmd
+ assert "--num-executors" in cmd
+ assert "2" in cmd
+ assert "--executor-cores" in cmd
+ assert "4" in cmd
+ assert "--executor-memory" in cmd
+ assert "2G" in cmd
+ assert "--driver-memory" in cmd
+ assert "1G" in cmd
+ assert "--verbose" in cmd
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+ def test_build_pipelines_command_dry_run(self, mock_resolve_connection):
+ mock_resolve_connection.return_value = {
+ "master": "local[*]",
+ "deploy_mode": None,
+ "queue": None,
+ "keytab": None,
+ "principal": None,
+ }
+
+ hook = SparkPipelinesHook(pipeline_spec="test_pipeline.yml",
pipeline_command="dry-run")
+ hook._connection = mock_resolve_connection.return_value
+
+ cmd = hook._build_spark_pipelines_command()
+
+ # Verify the command starts correctly and contains expected arguments
+ assert cmd[0] == "spark-pipelines"
+ assert cmd[1] == "dry-run"
+ assert "--spec" in cmd
+ assert "test_pipeline.yml" in cmd
+ assert "--master" in cmd
+ assert "local[*]" in cmd
+
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+ def test_build_pipelines_command_with_conf(self, mock_resolve_connection):
+ mock_resolve_connection.return_value = {
+ "master": "yarn",
+ "deploy_mode": None,
+ "queue": None,
+ "keytab": None,
+ "principal": None,
+ }
+
+ hook = SparkPipelinesHook(
+ pipeline_spec="test_pipeline.yml",
+ pipeline_command="run",
+ conf={
+ "spark.sql.adaptive.enabled": "true",
+ "spark.sql.adaptive.coalescePartitions.enabled": "true",
+ },
+ )
+ hook._connection = mock_resolve_connection.return_value
+
+ cmd = hook._build_spark_pipelines_command()
+
+ assert "--conf" in cmd
+ assert "spark.sql.adaptive.enabled=true" in cmd
+ assert "spark.sql.adaptive.coalescePartitions.enabled=true" in cmd
+
+ @patch("subprocess.Popen")
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+ def test_submit_pipeline_success(self, mock_resolve_connection,
mock_popen):
+ mock_resolve_connection.return_value = {
+ "master": "yarn",
+ "deploy_mode": None,
+ "queue": None,
+ "keytab": None,
+ "principal": None,
+ }
+
+ mock_process = MagicMock()
+ mock_process.wait.return_value = 0
+ mock_process.stdout = ["Pipeline completed successfully"]
+ mock_popen.return_value = mock_process
+
+ self.hook._connection = mock_resolve_connection.return_value
+ self.hook.submit_pipeline()
+
+ mock_popen.assert_called_once()
+ mock_process.wait.assert_called_once()
+
+ @patch("subprocess.Popen")
+    @patch("airflow.providers.apache.spark.hooks.spark_pipelines.SparkPipelinesHook._resolve_connection")
+ def test_submit_pipeline_failure(self, mock_resolve_connection,
mock_popen):
+ mock_resolve_connection.return_value = {
+ "master": "yarn",
+ "deploy_mode": None,
+ "queue": None,
+ "keytab": None,
+ "principal": None,
+ }
+
+ mock_process = MagicMock()
+ mock_process.wait.return_value = 1
+ mock_process.stdout = ["Pipeline failed"]
+ mock_popen.return_value = mock_process
+
+ self.hook._connection = mock_resolve_connection.return_value
+
+ with pytest.raises(AirflowException, match="Cannot execute"):
+ self.hook.submit_pipeline()
+
+ def test_submit_calls_submit_pipeline(self):
+ with patch.object(self.hook, "submit_pipeline") as
mock_submit_pipeline:
+ self.hook.submit("dummy_application")
+ mock_submit_pipeline.assert_called_once()
diff --git
a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
new file mode 100644
index 00000000000..3348692560d
--- /dev/null
+++
b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_pipelines.py
@@ -0,0 +1,155 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+from airflow.providers.apache.spark.operators.spark_pipelines import
SparkPipelinesOperator
+
+
+class TestSparkPipelinesOperator:
+ def test_init_with_run_command(self):
+ operator = SparkPipelinesOperator(
+ task_id="test_task", pipeline_spec="test_pipeline.yml",
pipeline_command="run"
+ )
+ assert operator.pipeline_spec == "test_pipeline.yml"
+ assert operator.pipeline_command == "run"
+
+ def test_init_with_dry_run_command(self):
+ operator = SparkPipelinesOperator(
+ task_id="test_task", pipeline_spec="test_pipeline.yml",
pipeline_command="dry-run"
+ )
+ assert operator.pipeline_command == "dry-run"
+
+ def test_template_fields(self):
+ expected_fields = (
+ "pipeline_spec",
+ "conf",
+ "env_vars",
+ "keytab",
+ "principal",
+ )
+ operator = SparkPipelinesOperator(task_id="test_task")
+ assert operator.template_fields == expected_fields
+
+ def test_execute(self):
+ mock_hook = MagicMock()
+
+ with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+ operator = SparkPipelinesOperator(
+ task_id="test_task", pipeline_spec="test_pipeline.yml",
pipeline_command="run"
+ )
+
+ context = {}
+ operator.execute(context)
+
+ mock_hook.submit_pipeline.assert_called_once()
+
+ def test_on_kill(self):
+ mock_hook = MagicMock()
+
+ with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+ operator = SparkPipelinesOperator(task_id="test_task",
pipeline_spec="test_pipeline.yml")
+
+ operator.on_kill()
+
+ mock_hook.on_kill.assert_called_once()
+
+ def test_get_hook(self):
+ operator = SparkPipelinesOperator(
+ task_id="test_task",
+ pipeline_spec="test_pipeline.yml",
+ pipeline_command="run",
+ conf={"spark.sql.adaptive.enabled": "true"},
+ num_executors=2,
+ executor_cores=4,
+ executor_memory="2G",
+ driver_memory="1G",
+ verbose=True,
+ env_vars={"SPARK_HOME": "/opt/spark"},
+ deploy_mode="client",
+ yarn_queue="default",
+ keytab="/path/to/keytab",
+ principal="[email protected]",
+ )
+
+ hook = operator.hook
+
+ assert hook.pipeline_spec == "test_pipeline.yml"
+ assert hook.pipeline_command == "run"
+ assert hook._conf == {"spark.sql.adaptive.enabled": "true"}
+ assert hook._num_executors == 2
+ assert hook._executor_cores == 4
+ assert hook._executor_memory == "2G"
+ assert hook._driver_memory == "1G"
+ assert hook._verbose is True
+ assert hook._env_vars == {"SPARK_HOME": "/opt/spark"}
+ assert hook._deploy_mode == "client"
+ assert hook._yarn_queue == "default"
+ assert hook._keytab == "/path/to/keytab"
+ assert hook._principal == "[email protected]"
+
+ @patch(
+
"airflow.providers.apache.spark.operators.spark_pipelines.inject_parent_job_information_into_spark_properties"
+ )
+ def test_execute_with_openlineage_parent_job_info(self,
mock_inject_parent):
+ mock_hook = MagicMock()
+
+ with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+ original_conf = {"spark.sql.adaptive.enabled": "true"}
+ modified_conf = {**original_conf,
"spark.openlineage.parentJobName": "test_job"}
+ mock_inject_parent.return_value = modified_conf
+
+ operator = SparkPipelinesOperator(
+ task_id="test_task",
+ pipeline_spec="test_pipeline.yml",
+ conf=original_conf,
+ openlineage_inject_parent_job_info=True,
+ )
+
+ context = {"task_instance": MagicMock()}
+ operator.execute(context)
+
+ mock_inject_parent.assert_called_once_with(original_conf, context)
+ assert operator.conf == modified_conf
+ mock_hook.submit_pipeline.assert_called_once()
+
+ @patch(
+
"airflow.providers.apache.spark.operators.spark_pipelines.inject_transport_information_into_spark_properties"
+ )
+ def test_execute_with_openlineage_transport_info(self,
mock_inject_transport):
+ mock_hook = MagicMock()
+
+ with patch.object(SparkPipelinesOperator, "hook", mock_hook):
+ original_conf = {"spark.sql.adaptive.enabled": "true"}
+ modified_conf = {**original_conf,
"spark.openlineage.transport.type": "http"}
+ mock_inject_transport.return_value = modified_conf
+
+ operator = SparkPipelinesOperator(
+ task_id="test_task",
+ pipeline_spec="test_pipeline.yml",
+ conf=original_conf,
+ openlineage_inject_transport_info=True,
+ )
+
+ context = {"task_instance": MagicMock()}
+ operator.execute(context)
+
+ mock_inject_transport.assert_called_once_with(original_conf,
context)
+ assert operator.conf == modified_conf
+ mock_hook.submit_pipeline.assert_called_once()