This is an automated email from the ASF dual-hosted git repository.
mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 27a3944da67 add OpenLineage configuration injection to SparkSubmitOperator (#47508)
27a3944da67 is described below
commit 27a3944da6781d8564c5f1d9da7c97ae7173b633
Author: Maciej Obuchowski <[email protected]>
AuthorDate: Mon Mar 24 19:35:37 2025 +0100
add OpenLineage configuration injection to SparkSubmitOperator (#47508)
Signed-off-by: Maciej Obuchowski <[email protected]>
---
generated/provider_dependencies.json | 1 +
providers/apache/spark/README.rst | 15 +-
providers/apache/spark/pyproject.toml | 4 +-
.../providers/apache/spark/get_provider_info.py | 12 +-
.../apache/spark/operators/spark_submit.py | 19 +++
.../apache/spark/operators/test_spark_submit.py | 179 +++++++++++++++++++++
.../airflow/providers/openlineage/utils/spark.py | 2 +-
7 files changed, 216 insertions(+), 16 deletions(-)
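For context, the two flags introduced below can be toggled per task. A minimal usage sketch (the DAG id, schedule, and application path are illustrative placeholders, not part of this commit):

    # Illustrative usage of the new operator flags added by this commit.
    from airflow import DAG
    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    with DAG(dag_id="openlineage_spark_example", schedule=None) as dag:
        submit = SparkSubmitOperator(
            task_id="spark_submit_job",
            application="/path/to/app.py",  # placeholder application
            openlineage_inject_parent_job_info=True,   # adds spark.openlineage.parentJob* properties
            openlineage_inject_transport_info=True,    # adds spark.openlineage.transport.* properties
        )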
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 5f765c9f10e..004201c55c6 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -284,6 +284,7 @@
},
"apache.spark": {
"deps": [
+ "apache-airflow-providers-common-compat>=1.5.0",
"apache-airflow>=2.9.0",
"grpcio-status>=1.59.0",
"pyspark>=3.1.3"
diff --git a/providers/apache/spark/README.rst b/providers/apache/spark/README.rst
index 7a7c81415ac..74786005344 100644
--- a/providers/apache/spark/README.rst
+++ b/providers/apache/spark/README.rst
@@ -50,13 +50,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12
Requirements
------------
-================== ==================
-PIP package Version required
-================== ==================
-``apache-airflow`` ``>=2.9.0``
-``pyspark`` ``>=3.1.3``
-``grpcio-status`` ``>=1.59.0``
-================== ==================
+========================================== ==================
+PIP package Version required
+========================================== ==================
+``apache-airflow`` ``>=2.9.0``
+``apache-airflow-providers-common-compat`` ``>=1.5.0``
+``pyspark`` ``>=3.1.3``
+``grpcio-status`` ``>=1.59.0``
+========================================== ==================
Cross provider package dependencies
-----------------------------------
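If you want to confirm that an existing environment already satisfies the new lower bound, a quick check could look like this (the package name is taken from the table above; the snippet itself is illustrative):

    # Illustrative check of the new minimum requirement.
    from importlib.metadata import version

    print(version("apache-airflow-providers-common-compat"))  # expected: >= 1.5.0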
diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml
index 00158c07778..a3f8ddf803e 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@ requires-python = "~=3.9"
# After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
dependencies = [
"apache-airflow>=2.9.0",
+ "apache-airflow-providers-common-compat>=1.5.0",
"pyspark>=3.1.3",
"grpcio-status>=1.59.0",
]
@@ -68,9 +69,6 @@ dependencies = [
"cncf.kubernetes" = [
"apache-airflow-providers-cncf-kubernetes>=7.4.0",
]
-"common.compat" = [
- "apache-airflow-providers-common-compat"
-]
[dependency-groups]
dev = [
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index a30a51ca493..0074188bd66 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -125,10 +125,12 @@ def get_provider_info():
"name": "pyspark",
}
],
- "dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3",
"grpcio-status>=1.59.0"],
- "optional-dependencies": {
- "cncf.kubernetes":
["apache-airflow-providers-cncf-kubernetes>=7.4.0"],
- "common.compat": ["apache-airflow-providers-common-compat"],
- },
+ "dependencies": [
+ "apache-airflow>=2.9.0",
+ "apache-airflow-providers-common-compat>=1.5.0",
+ "pyspark>=3.1.3",
+ "grpcio-status>=1.59.0",
+ ],
+ "optional-dependencies": {"cncf.kubernetes":
["apache-airflow-providers-cncf-kubernetes>=7.4.0"]},
"devel-dependencies": [],
}
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 3ad4ff0fe6e..0ba57eb857e 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -20,8 +20,13 @@ from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any
+from airflow.configuration import conf
from airflow.models import BaseOperator
from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+ inject_parent_job_information_into_spark_properties,
+ inject_transport_information_into_spark_properties,
+)
from airflow.settings import WEB_COLORS
if TYPE_CHECKING:
@@ -135,6 +140,12 @@ class SparkSubmitOperator(BaseOperator):
yarn_queue: str | None = None,
deploy_mode: str | None = None,
use_krb5ccache: bool = False,
+ openlineage_inject_parent_job_info: bool = conf.getboolean(
+ "openlineage", "spark_inject_parent_job_info", fallback=False
+ ),
+ openlineage_inject_transport_info: bool = conf.getboolean(
+ "openlineage", "spark_inject_transport_info", fallback=False
+ ),
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
@@ -169,9 +180,17 @@ class SparkSubmitOperator(BaseOperator):
self._hook: SparkSubmitHook | None = None
self._conn_id = conn_id
self._use_krb5ccache = use_krb5ccache
+ self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
+ self._openlineage_inject_transport_info = openlineage_inject_transport_info
def execute(self, context: Context) -> None:
"""Call the SparkSubmitHook to run the provided spark job."""
+ if self._openlineage_inject_parent_job_info:
+ self.log.debug("Injecting OpenLineage parent job information into
Spark properties.")
+ self.conf =
inject_parent_job_information_into_spark_properties(self.conf, context)
+ if self._openlineage_inject_transport_info:
+ self.log.debug("Injecting OpenLineage transport information into
Spark properties.")
+ self.conf =
inject_transport_information_into_spark_properties(self.conf, context)
if self._hook is None:
self._hook = self._get_hook()
self._hook.submit(self.application)
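To make the effect of execute() above concrete: both helpers return the operator's conf dict extended with spark.openlineage.* properties. A rough sketch of the parent-job case, with values mirroring the test expectations further down (this shows the shape of the result, not the helpers' actual implementation):

    # Shape of the injected parent-job properties (values mirror the tests below).
    conf = {"parquet.compression": "SNAPPY"}
    conf = {
        **conf,
        "spark.openlineage.parentJobNamespace": "default",
        "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",  # dag_id.task_id
        "spark.openlineage.parentRunId": "01595753-6400-710b-8a12-9e978335a56d",
    }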
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 94344d54068..b339093bc54 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -17,7 +17,10 @@
# under the License.
from __future__ import annotations
+import logging
from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock
import pytest
@@ -281,3 +284,179 @@ class TestSparkSubmitOperator:
assert task.application_args == "application_args"
assert task.env_vars == "env_vars"
assert task.properties_file == "properties_file"
+
+ @mock.patch.object(SparkSubmitOperator, "_get_hook")
+ @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+ def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+ # Given / When
+ from openlineage.client.transport.http import (
+ ApiKeyTokenProvider,
+ HttpCompression,
+ HttpConfig,
+ HttpTransport,
+ )
+
+ mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
+ config=HttpConfig(
+ url="http://localhost:5000",
+ endpoint="api/v2/lineage",
+ timeout=5050,
+ auth=ApiKeyTokenProvider({"api_key": "12345"}),
+ compression=HttpCompression.GZIP,
+ custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
+ )
+ )
+ operator = SparkSubmitOperator(
+ task_id="spark_submit_job",
+ spark_binary="sparky",
+ dag=self.dag,
+ openlineage_inject_parent_job_info=False,
+ openlineage_inject_transport_info=True,
+ **self._config,
+ )
+ operator.execute(MagicMock())
+
+ assert operator.conf == {
+ "parquet.compression": "SNAPPY",
+ "spark.openlineage.transport.type": "http",
+ "spark.openlineage.transport.url": "http://localhost:5000",
+ "spark.openlineage.transport.endpoint": "api/v2/lineage",
+ "spark.openlineage.transport.timeoutInMillis": "5050000",
+ "spark.openlineage.transport.compression": "gzip",
+ "spark.openlineage.transport.auth.type": "api_key",
+ "spark.openlineage.transport.auth.apiKey": "Bearer 12345",
+ "spark.openlineage.transport.headers.X-OpenLineage-Custom-Header":
"airflow",
+ }
+
+ @mock.patch.object(SparkSubmitOperator, "_get_hook")
+ @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+ def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+ # Given / When
+ from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+ mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+ CompositeConfig.from_dict(
+ {
+ "transports": {
+ "test1": {
+ "type": "http",
+ "url": "http://localhost:5000",
+ "endpoint": "api/v2/lineage",
+ "timeout": "5050",
+ "auth": {
+ "type": "api_key",
+ "api_key": "12345",
+ },
+ "compression": "gzip",
+ "custom_headers": {
+ "X-OpenLineage-Custom-Header": "airflow",
+ },
+ },
+ "test2": {
+ "type": "http",
+ "url": "https://example.com:1234",
+ },
+ "test3": {"type": "console"},
+ }
+ }
+ )
+ )
+
+ mock_ti = MagicMock()
+ mock_ti.dag_id = "test_dag_id"
+ mock_ti.task_id = "spark_submit_job"
+ mock_ti.try_number = 1
+ mock_ti.dag_run.logical_date = DEFAULT_DATE
+ mock_ti.dag_run.run_after = DEFAULT_DATE
+ mock_ti.logical_date = DEFAULT_DATE
+ mock_ti.map_index = -1
+
+ operator = SparkSubmitOperator(
+ task_id="spark_submit_job",
+ spark_binary="sparky",
+ dag=self.dag,
+ openlineage_inject_parent_job_info=True,
+ openlineage_inject_transport_info=True,
+ **self._config,
+ )
+ operator.execute({"ti": mock_ti})
+
+ assert operator.conf == {
+ "parquet.compression": "SNAPPY",
+ "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
+ "spark.openlineage.parentJobNamespace": "default",
+ "spark.openlineage.parentRunId":
"01595753-6400-710b-8a12-9e978335a56d",
+ "spark.openlineage.transport.type": "composite",
+ "spark.openlineage.transport.continueOnFailure": "True",
+ "spark.openlineage.transport.transports.test1.type": "http",
+ "spark.openlineage.transport.transports.test1.url":
"http://localhost:5000",
+ "spark.openlineage.transport.transports.test1.endpoint":
"api/v2/lineage",
+ "spark.openlineage.transport.transports.test1.timeoutInMillis":
"5050000",
+ "spark.openlineage.transport.transports.test1.auth.type":
"api_key",
+ "spark.openlineage.transport.transports.test1.auth.apiKey":
"Bearer 12345",
+ "spark.openlineage.transport.transports.test1.compression": "gzip",
+
"spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header":
"airflow",
+ "spark.openlineage.transport.transports.test2.type": "http",
+ "spark.openlineage.transport.transports.test2.url":
"https://example.com:1234",
+ "spark.openlineage.transport.transports.test2.endpoint":
"api/v1/lineage",
+ "spark.openlineage.transport.transports.test2.timeoutInMillis":
"5000",
+ }
+
+ @mock.patch.object(SparkSubmitOperator, "_get_hook")
+ @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+ def test_inject_openlineage_composite_config_wrong_transport_to_spark(
+ self, mock_get_openlineage_listener, mock_get_hook, caplog
+ ):
+ # Given / When
+ from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+ mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+ CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}})
+ )
+
+ with caplog.at_level(logging.INFO):
+ operator = SparkSubmitOperator(
+ task_id="spark_submit_job",
+ spark_binary="sparky",
+ dag=self.dag,
+ openlineage_inject_parent_job_info=False,
+ openlineage_inject_transport_info=True,
+ **self._config,
+ )
+ operator.execute(MagicMock())
+
+ assert (
+ "OpenLineage transport type `composite` does not contain http transport. Skipping injection of OpenLineage transport information into Spark properties."
+ in caplog.text
+ )
+ assert operator.conf == {
+ "parquet.compression": "SNAPPY",
+ }
+
+ @mock.patch.object(SparkSubmitOperator, "_get_hook")
+ @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+ def test_inject_openlineage_simple_config_wrong_transport_to_spark(
+ self, mock_get_openlineage_listener, mock_get_hook, caplog
+ ):
+ # Given / When
+ from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport
+
+ mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport(
+ config=ConsoleConfig()
+ )
+
+ with caplog.at_level(logging.INFO):
+ operator = SparkSubmitOperator(
+ task_id="spark_submit_job",
+ spark_binary="sparky",
+ dag=self.dag,
+ openlineage_inject_parent_job_info=False,
+ openlineage_inject_transport_info=True,
+ **self._config,
+ )
+ operator.execute(MagicMock())
+
+ assert "OpenLineage transport type `console` does not support
automatic injection of OpenLineage transport information into Spark properties."
+ assert operator.conf == {
+ "parquet.compression": "SNAPPY",
+ }
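Note that the new operator arguments default to conf.getboolean("openlineage", ...) as shown in the spark_submit.py hunk above, so injection can also be enabled globally rather than per task. A sketch using Airflow's standard AIRFLOW__<SECTION>__<KEY> environment-variable mapping for config options:

    # Enabling both injections globally via Airflow's config env-var convention.
    import os

    os.environ["AIRFLOW__OPENLINEAGE__SPARK_INJECT_PARENT_JOB_INFO"] = "true"
    os.environ["AIRFLOW__OPENLINEAGE__SPARK_INJECT_TRANSPORT_INFO"] = "true"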
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index 9f0fef84be0..becb4bd7670 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -60,7 +60,7 @@ def _get_transport_information_as_spark_properties() -> dict:
"url": tp.url,
"endpoint": tp.endpoint,
"timeoutInMillis": str(
- int(tp.timeout * 1000) # convert to milliseconds, as required by Spark integration
+ int(tp.timeout) * 1000 # convert to milliseconds, as required by Spark integration
),
}
if hasattr(tp, "compression") and tp.compression:
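The reordering in this final hunk matters when tp.timeout arrives as a string, which is exactly what the composite test above exercises ("timeout": "5050"). A quick illustration:

    # int() placement matters once timeout can be a string:
    timeout = "5050"         # seconds, as parsed from config
    int(timeout) * 1000      # 5050000 -- parse first, then scale (new code)
    # int(timeout * 1000) would repeat the string 1000 times before parsing,
    # producing a huge bogus number instead of 5050000 (old behavior's bug).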