This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 27a3944da67 add OpenLineage configuration injection to SparkSubmitOperator (#47508)
27a3944da67 is described below

commit 27a3944da6781d8564c5f1d9da7c97ae7173b633
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Mon Mar 24 19:35:37 2025 +0100

    add OpenLineage configuration injection to SparkSubmitOperator (#47508)
    
    Signed-off-by: Maciej Obuchowski <maciej.obuchow...@datadoghq.com>
---
 generated/provider_dependencies.json               |   1 +
 providers/apache/spark/README.rst                  |  15 +-
 providers/apache/spark/pyproject.toml              |   4 +-
 .../providers/apache/spark/get_provider_info.py    |  12 +-
 .../apache/spark/operators/spark_submit.py         |  19 +++
 .../apache/spark/operators/test_spark_submit.py    | 179 +++++++++++++++++++++
 .../airflow/providers/openlineage/utils/spark.py   |   2 +-
 7 files changed, 216 insertions(+), 16 deletions(-)
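
A minimal usage sketch (hypothetical DAG snippet, not part of this commit) of the two operator arguments added below; both fall back to the [openlineage] config options spark_inject_parent_job_info and spark_inject_transport_info when not passed explicitly:

    from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

    # Hypothetical task: the application path is a placeholder.
    submit = SparkSubmitOperator(
        task_id="spark_submit_job",
        application="/path/to/app.py",
        openlineage_inject_parent_job_info=True,  # injects spark.openlineage.parentJobName / parentRunId properties
        openlineage_inject_transport_info=True,  # injects spark.openlineage.transport.* properties
    )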

diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 5f765c9f10e..004201c55c6 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -284,6 +284,7 @@
   },
   "apache.spark": {
     "deps": [
+      "apache-airflow-providers-common-compat>=1.5.0",
       "apache-airflow>=2.9.0",
       "grpcio-status>=1.59.0",
       "pyspark>=3.1.3"
diff --git a/providers/apache/spark/README.rst b/providers/apache/spark/README.rst
index 7a7c81415ac..74786005344 100644
--- a/providers/apache/spark/README.rst
+++ b/providers/apache/spark/README.rst
@@ -50,13 +50,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12
 Requirements
 ------------
 
-==================  ==================
-PIP package         Version required
-==================  ==================
-``apache-airflow``  ``>=2.9.0``
-``pyspark``         ``>=3.1.3``
-``grpcio-status``   ``>=1.59.0``
-==================  ==================
+==========================================  ==================
+PIP package                                 Version required
+==========================================  ==================
+``apache-airflow``                          ``>=2.9.0``
+``apache-airflow-providers-common-compat``  ``>=1.5.0``
+``pyspark``                                 ``>=3.1.3``
+``grpcio-status``                           ``>=1.59.0``
+==========================================  ==================
 
 Cross provider package dependencies
 -----------------------------------
diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml
index 00158c07778..a3f8ddf803e 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@ requires-python = "~=3.9"
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.9.0",
+    "apache-airflow-providers-common-compat>=1.5.0",
     "pyspark>=3.1.3",
     "grpcio-status>=1.59.0",
 ]
@@ -68,9 +69,6 @@ dependencies = [
 "cncf.kubernetes" = [
     "apache-airflow-providers-cncf-kubernetes>=7.4.0",
 ]
-"common.compat" = [
-    "apache-airflow-providers-common-compat"
-]
 
 [dependency-groups]
 dev = [
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index a30a51ca493..0074188bd66 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -125,10 +125,12 @@ def get_provider_info():
                 "name": "pyspark",
             }
         ],
-        "dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3", 
"grpcio-status>=1.59.0"],
-        "optional-dependencies": {
-            "cncf.kubernetes": 
["apache-airflow-providers-cncf-kubernetes>=7.4.0"],
-            "common.compat": ["apache-airflow-providers-common-compat"],
-        },
+        "dependencies": [
+            "apache-airflow>=2.9.0",
+            "apache-airflow-providers-common-compat>=1.5.0",
+            "pyspark>=3.1.3",
+            "grpcio-status>=1.59.0",
+        ],
+        "optional-dependencies": {"cncf.kubernetes": 
["apache-airflow-providers-cncf-kubernetes>=7.4.0"]},
         "devel-dependencies": [],
     }
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
index 3ad4ff0fe6e..0ba57eb857e 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py
@@ -20,8 +20,13 @@ from __future__ import annotations
 from collections.abc import Sequence
 from typing import TYPE_CHECKING, Any
 
+from airflow.configuration import conf
 from airflow.models import BaseOperator
 from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook
+from airflow.providers.common.compat.openlineage.utils.spark import (
+    inject_parent_job_information_into_spark_properties,
+    inject_transport_information_into_spark_properties,
+)
 from airflow.settings import WEB_COLORS
 
 if TYPE_CHECKING:
@@ -135,6 +140,12 @@ class SparkSubmitOperator(BaseOperator):
         yarn_queue: str | None = None,
         deploy_mode: str | None = None,
         use_krb5ccache: bool = False,
+        openlineage_inject_parent_job_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_parent_job_info", fallback=False
+        ),
+        openlineage_inject_transport_info: bool = conf.getboolean(
+            "openlineage", "spark_inject_transport_info", fallback=False
+        ),
         **kwargs: Any,
     ) -> None:
         super().__init__(**kwargs)
@@ -169,9 +180,17 @@ class SparkSubmitOperator(BaseOperator):
         self._hook: SparkSubmitHook | None = None
         self._conn_id = conn_id
         self._use_krb5ccache = use_krb5ccache
+        self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info
+        self._openlineage_inject_transport_info = openlineage_inject_transport_info
 
     def execute(self, context: Context) -> None:
         """Call the SparkSubmitHook to run the provided spark job."""
+        if self._openlineage_inject_parent_job_info:
+            self.log.debug("Injecting OpenLineage parent job information into 
Spark properties.")
+            self.conf = 
inject_parent_job_information_into_spark_properties(self.conf, context)
+        if self._openlineage_inject_transport_info:
+            self.log.debug("Injecting OpenLineage transport information into 
Spark properties.")
+            self.conf = 
inject_transport_information_into_spark_properties(self.conf, context)
         if self._hook is None:
             self._hook = self._get_hook()
         self._hook.submit(self.application)
diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
index 94344d54068..b339093bc54 100644
--- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
+++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py
@@ -17,7 +17,10 @@
 # under the License.
 from __future__ import annotations
 
+import logging
 from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -281,3 +284,179 @@ class TestSparkSubmitOperator:
         assert task.application_args == "application_args"
         assert task.env_vars == "env_vars"
         assert task.properties_file == "properties_file"
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.http import (
+            ApiKeyTokenProvider,
+            HttpCompression,
+            HttpConfig,
+            HttpTransport,
+        )
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
+            config=HttpConfig(
+                url="http://localhost:5000";,
+                endpoint="api/v2/lineage",
+                timeout=5050,
+                auth=ApiKeyTokenProvider({"api_key": "12345"}),
+                compression=HttpCompression.GZIP,
+                custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
+            )
+        )
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=False,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute(MagicMock())
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.transport.type": "http",
+            "spark.openlineage.transport.url": "http://localhost:5000";,
+            "spark.openlineage.transport.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.compression": "gzip",
+            "spark.openlineage.transport.auth.type": "api_key",
+            "spark.openlineage.transport.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.headers.X-OpenLineage-Custom-Header": 
"airflow",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict(
+                {
+                    "transports": {
+                        "test1": {
+                            "type": "http",
+                            "url": "http://localhost:5000";,
+                            "endpoint": "api/v2/lineage",
+                            "timeout": "5050",
+                            "auth": {
+                                "type": "api_key",
+                                "api_key": "12345",
+                            },
+                            "compression": "gzip",
+                            "custom_headers": {
+                                "X-OpenLineage-Custom-Header": "airflow",
+                            },
+                        },
+                        "test2": {
+                            "type": "http",
+                            "url": "https://example.com:1234";,
+                        },
+                        "test3": {"type": "console"},
+                    }
+                }
+            )
+        )
+
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag_id"
+        mock_ti.task_id = "spark_submit_job"
+        mock_ti.try_number = 1
+        mock_ti.dag_run.logical_date = DEFAULT_DATE
+        mock_ti.dag_run.run_after = DEFAULT_DATE
+        mock_ti.logical_date = DEFAULT_DATE
+        mock_ti.map_index = -1
+
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=True,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute({"ti": mock_ti})
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
+            "spark.openlineage.parentJobNamespace": "default",
+            "spark.openlineage.parentRunId": 
"01595753-6400-710b-8a12-9e978335a56d",
+            "spark.openlineage.transport.type": "composite",
+            "spark.openlineage.transport.continueOnFailure": "True",
+            "spark.openlineage.transport.transports.test1.type": "http",
+            "spark.openlineage.transport.transports.test1.url": 
"http://localhost:5000";,
+            "spark.openlineage.transport.transports.test1.endpoint": 
"api/v2/lineage",
+            "spark.openlineage.transport.transports.test1.timeoutInMillis": 
"5050000",
+            "spark.openlineage.transport.transports.test1.auth.type": 
"api_key",
+            "spark.openlineage.transport.transports.test1.auth.apiKey": 
"Bearer 12345",
+            "spark.openlineage.transport.transports.test1.compression": "gzip",
+            
"spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header":
 "airflow",
+            "spark.openlineage.transport.transports.test2.type": "http",
+            "spark.openlineage.transport.transports.test2.url": 
"https://example.com:1234";,
+            "spark.openlineage.transport.transports.test2.endpoint": 
"api/v1/lineage",
+            "spark.openlineage.transport.transports.test2.timeoutInMillis": 
"5000",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_composite_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}})
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+
+            assert (
+                "OpenLineage transport type `composite` does not contain http 
transport. Skipping injection of OpenLineage transport information into Spark 
properties."
+                in caplog.text
+            )
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_openlineage_simple_config_wrong_transport_to_spark(
+        self, mock_get_openlineage_listener, mock_get_hook, caplog
+    ):
+        # Given / When
+        from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport(
+            config=ConsoleConfig()
+        )
+
+        with caplog.at_level(logging.INFO):
+            operator = SparkSubmitOperator(
+                task_id="spark_submit_job",
+                spark_binary="sparky",
+                dag=self.dag,
+                openlineage_inject_parent_job_info=False,
+                openlineage_inject_transport_info=True,
+                **self._config,
+            )
+            operator.execute(MagicMock())
+
+            assert "OpenLineage transport type `console` does not support 
automatic injection of OpenLineage transport information into Spark properties."
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index 9f0fef84be0..becb4bd7670 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -60,7 +60,7 @@ def _get_transport_information_as_spark_properties() -> dict:
             "url": tp.url,
             "endpoint": tp.endpoint,
             "timeoutInMillis": str(
-                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+                int(tp.timeout) * 1000  # convert to milliseconds, as required by Spark integration
             ),
         }
         if hasattr(tp, "compression") and tp.compression:
