This is an automated email from the ASF dual-hosted git repository.

mobuchowski pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new 27a3944da67 add OpenLineage configuration injection to SparkSubmitOperator (#47508)
27a3944da67 is described below

commit 27a3944da6781d8564c5f1d9da7c97ae7173b633
Author: Maciej Obuchowski <obuchowski.mac...@gmail.com>
AuthorDate: Mon Mar 24 19:35:37 2025 +0100

    add OpenLineage configuration injection to SparkSubmitOperator (#47508)

    Signed-off-by: Maciej Obuchowski <maciej.obuchow...@datadoghq.com>
---
 generated/provider_dependencies.json             |   1 +
 providers/apache/spark/README.rst                |  15 +-
 providers/apache/spark/pyproject.toml            |   4 +-
 .../providers/apache/spark/get_provider_info.py  |  12 +-
 .../apache/spark/operators/spark_submit.py       |  19 +++
 .../apache/spark/operators/test_spark_submit.py  | 179 +++++++++++++++++++++
 .../airflow/providers/openlineage/utils/spark.py |   2 +-
 7 files changed, 216 insertions(+), 16 deletions(-)

diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 5f765c9f10e..004201c55c6 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -284,6 +284,7 @@
   },
   "apache.spark": {
     "deps": [
+      "apache-airflow-providers-common-compat>=1.5.0",
       "apache-airflow>=2.9.0",
       "grpcio-status>=1.59.0",
       "pyspark>=3.1.3"
diff --git a/providers/apache/spark/README.rst b/providers/apache/spark/README.rst
index 7a7c81415ac..74786005344 100644
--- a/providers/apache/spark/README.rst
+++ b/providers/apache/spark/README.rst
@@ -50,13 +50,14 @@ The package supports the following python versions: 3.9,3.10,3.11,3.12
 Requirements
 ------------
 
-================== ==================
-PIP package        Version required
-================== ==================
-``apache-airflow`` ``>=2.9.0``
-``pyspark``        ``>=3.1.3``
-``grpcio-status``  ``>=1.59.0``
-================== ==================
+========================================== ==================
+PIP package                                Version required
+========================================== ==================
+``apache-airflow``                         ``>=2.9.0``
+``apache-airflow-providers-common-compat`` ``>=1.5.0``
+``pyspark``                                ``>=3.1.3``
+``grpcio-status``                          ``>=1.59.0``
+========================================== ==================
 
 Cross provider package dependencies
 -----------------------------------
diff --git a/providers/apache/spark/pyproject.toml b/providers/apache/spark/pyproject.toml
index 00158c07778..a3f8ddf803e 100644
--- a/providers/apache/spark/pyproject.toml
+++ b/providers/apache/spark/pyproject.toml
@@ -58,6 +58,7 @@ requires-python = "~=3.9"
 # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build``
 dependencies = [
     "apache-airflow>=2.9.0",
+    "apache-airflow-providers-common-compat>=1.5.0",
     "pyspark>=3.1.3",
     "grpcio-status>=1.59.0",
 ]
@@ -68,9 +69,6 @@ dependencies = [
 "cncf.kubernetes" = [
     "apache-airflow-providers-cncf-kubernetes>=7.4.0",
 ]
-"common.compat" = [
-    "apache-airflow-providers-common-compat"
-]
 
 [dependency-groups]
 dev = [
diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
index a30a51ca493..0074188bd66 100644
--- a/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
+++ b/providers/apache/spark/src/airflow/providers/apache/spark/get_provider_info.py
@@ -125,10 +125,12 @@ def get_provider_info():
                 "name": "pyspark",
             }
         ],
-        "dependencies": ["apache-airflow>=2.9.0", "pyspark>=3.1.3", "grpcio-status>=1.59.0"],
-        "optional-dependencies": {
"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"], - "common.compat": ["apache-airflow-providers-common-compat"], - }, + "dependencies": [ + "apache-airflow>=2.9.0", + "apache-airflow-providers-common-compat>=1.5.0", + "pyspark>=3.1.3", + "grpcio-status>=1.59.0", + ], + "optional-dependencies": {"cncf.kubernetes": ["apache-airflow-providers-cncf-kubernetes>=7.4.0"]}, "devel-dependencies": [], } diff --git a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py index 3ad4ff0fe6e..0ba57eb857e 100644 --- a/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py +++ b/providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py @@ -20,8 +20,13 @@ from __future__ import annotations from collections.abc import Sequence from typing import TYPE_CHECKING, Any +from airflow.configuration import conf from airflow.models import BaseOperator from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook +from airflow.providers.common.compat.openlineage.utils.spark import ( + inject_parent_job_information_into_spark_properties, + inject_transport_information_into_spark_properties, +) from airflow.settings import WEB_COLORS if TYPE_CHECKING: @@ -135,6 +140,12 @@ class SparkSubmitOperator(BaseOperator): yarn_queue: str | None = None, deploy_mode: str | None = None, use_krb5ccache: bool = False, + openlineage_inject_parent_job_info: bool = conf.getboolean( + "openlineage", "spark_inject_parent_job_info", fallback=False + ), + openlineage_inject_transport_info: bool = conf.getboolean( + "openlineage", "spark_inject_transport_info", fallback=False + ), **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -169,9 +180,17 @@ class SparkSubmitOperator(BaseOperator): self._hook: SparkSubmitHook | None = None self._conn_id = conn_id self._use_krb5ccache = use_krb5ccache + self._openlineage_inject_parent_job_info = openlineage_inject_parent_job_info + self._openlineage_inject_transport_info = openlineage_inject_transport_info def execute(self, context: Context) -> None: """Call the SparkSubmitHook to run the provided spark job.""" + if self._openlineage_inject_parent_job_info: + self.log.debug("Injecting OpenLineage parent job information into Spark properties.") + self.conf = inject_parent_job_information_into_spark_properties(self.conf, context) + if self._openlineage_inject_transport_info: + self.log.debug("Injecting OpenLineage transport information into Spark properties.") + self.conf = inject_transport_information_into_spark_properties(self.conf, context) if self._hook is None: self._hook = self._get_hook() self._hook.submit(self.application) diff --git a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py index 94344d54068..b339093bc54 100644 --- a/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py +++ b/providers/apache/spark/tests/unit/apache/spark/operators/test_spark_submit.py @@ -17,7 +17,10 @@ # under the License. 
 from __future__ import annotations
 
+import logging
 from datetime import timedelta
+from unittest import mock
+from unittest.mock import MagicMock
 
 import pytest
 
@@ -281,3 +284,179 @@ class TestSparkSubmitOperator:
         assert task.application_args == "application_args"
         assert task.env_vars == "env_vars"
         assert task.properties_file == "properties_file"
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_simple_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.http import (
+            ApiKeyTokenProvider,
+            HttpCompression,
+            HttpConfig,
+            HttpTransport,
+        )
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = HttpTransport(
+            config=HttpConfig(
+                url="http://localhost:5000",
+                endpoint="api/v2/lineage",
+                timeout=5050,
+                auth=ApiKeyTokenProvider({"api_key": "12345"}),
+                compression=HttpCompression.GZIP,
+                custom_headers={"X-OpenLineage-Custom-Header": "airflow"},
+            )
+        )
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=False,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute(MagicMock())
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.transport.type": "http",
+            "spark.openlineage.transport.url": "http://localhost:5000",
+            "spark.openlineage.transport.endpoint": "api/v2/lineage",
+            "spark.openlineage.transport.timeoutInMillis": "5050000",
+            "spark.openlineage.transport.compression": "gzip",
+            "spark.openlineage.transport.auth.type": "api_key",
+            "spark.openlineage.transport.auth.apiKey": "Bearer 12345",
+            "spark.openlineage.transport.headers.X-OpenLineage-Custom-Header": "airflow",
+        }
+
+    @mock.patch.object(SparkSubmitOperator, "_get_hook")
+    @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener")
+    def test_inject_composite_openlineage_config_to_spark(self, mock_get_openlineage_listener, mock_get_hook):
+        # Given / When
+        from openlineage.client.transport.composite import CompositeConfig, CompositeTransport
+
+        mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport(
+            CompositeConfig.from_dict(
+                {
+                    "transports": {
+                        "test1": {
+                            "type": "http",
+                            "url": "http://localhost:5000",
+                            "endpoint": "api/v2/lineage",
+                            "timeout": "5050",
+                            "auth": {
+                                "type": "api_key",
+                                "api_key": "12345",
+                            },
+                            "compression": "gzip",
+                            "custom_headers": {
+                                "X-OpenLineage-Custom-Header": "airflow",
+                            },
+                        },
+                        "test2": {
+                            "type": "http",
+                            "url": "https://example.com:1234",
+                        },
+                        "test3": {"type": "console"},
+                    }
+                }
+            )
+        )
+
+        mock_ti = MagicMock()
+        mock_ti.dag_id = "test_dag_id"
+        mock_ti.task_id = "spark_submit_job"
+        mock_ti.try_number = 1
+        mock_ti.dag_run.logical_date = DEFAULT_DATE
+        mock_ti.dag_run.run_after = DEFAULT_DATE
+        mock_ti.logical_date = DEFAULT_DATE
+        mock_ti.map_index = -1
+
+        operator = SparkSubmitOperator(
+            task_id="spark_submit_job",
+            spark_binary="sparky",
+            dag=self.dag,
+            openlineage_inject_parent_job_info=True,
+            openlineage_inject_transport_info=True,
+            **self._config,
+        )
+        operator.execute({"ti": mock_ti})
+
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+            "spark.openlineage.parentJobName": "test_dag_id.spark_submit_job",
+            "spark.openlineage.parentJobNamespace": "default",
"default", + "spark.openlineage.parentRunId": "01595753-6400-710b-8a12-9e978335a56d", + "spark.openlineage.transport.type": "composite", + "spark.openlineage.transport.continueOnFailure": "True", + "spark.openlineage.transport.transports.test1.type": "http", + "spark.openlineage.transport.transports.test1.url": "http://localhost:5000", + "spark.openlineage.transport.transports.test1.endpoint": "api/v2/lineage", + "spark.openlineage.transport.transports.test1.timeoutInMillis": "5050000", + "spark.openlineage.transport.transports.test1.auth.type": "api_key", + "spark.openlineage.transport.transports.test1.auth.apiKey": "Bearer 12345", + "spark.openlineage.transport.transports.test1.compression": "gzip", + "spark.openlineage.transport.transports.test1.headers.X-OpenLineage-Custom-Header": "airflow", + "spark.openlineage.transport.transports.test2.type": "http", + "spark.openlineage.transport.transports.test2.url": "https://example.com:1234", + "spark.openlineage.transport.transports.test2.endpoint": "api/v1/lineage", + "spark.openlineage.transport.transports.test2.timeoutInMillis": "5000", + } + + @mock.patch.object(SparkSubmitOperator, "_get_hook") + @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") + def test_inject_openlineage_composite_config_wrong_transport_to_spark( + self, mock_get_openlineage_listener, mock_get_hook, caplog + ): + # Given / When + from openlineage.client.transport.composite import CompositeConfig, CompositeTransport + + mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = CompositeTransport( + CompositeConfig.from_dict({"transports": {"test1": {"type": "console"}}}) + ) + + with caplog.at_level(logging.INFO): + operator = SparkSubmitOperator( + task_id="spark_submit_job", + spark_binary="sparky", + dag=self.dag, + openlineage_inject_parent_job_info=False, + openlineage_inject_transport_info=True, + **self._config, + ) + operator.execute(MagicMock()) + + assert ( + "OpenLineage transport type `composite` does not contain http transport. Skipping injection of OpenLineage transport information into Spark properties." + in caplog.text + ) + assert operator.conf == { + "parquet.compression": "SNAPPY", + } + + @mock.patch.object(SparkSubmitOperator, "_get_hook") + @mock.patch("airflow.providers.openlineage.utils.spark.get_openlineage_listener") + def test_inject_openlineage_simple_config_wrong_transport_to_spark( + self, mock_get_openlineage_listener, mock_get_hook, caplog + ): + # Given / When + from openlineage.client.transport.console import ConsoleConfig, ConsoleTransport + + mock_get_openlineage_listener.return_value.adapter.get_or_create_openlineage_client.return_value.transport = ConsoleTransport( + config=ConsoleConfig() + ) + + with caplog.at_level(logging.INFO): + operator = SparkSubmitOperator( + task_id="spark_submit_job", + spark_binary="sparky", + dag=self.dag, + openlineage_inject_parent_job_info=False, + openlineage_inject_transport_info=True, + **self._config, + ) + operator.execute(MagicMock()) + + assert "OpenLineage transport type `console` does not support automatic injection of OpenLineage transport information into Spark properties." 
+        assert operator.conf == {
+            "parquet.compression": "SNAPPY",
+        }
diff --git a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
index 9f0fef84be0..becb4bd7670 100644
--- a/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
+++ b/providers/openlineage/src/airflow/providers/openlineage/utils/spark.py
@@ -60,7 +60,7 @@ def _get_transport_information_as_spark_properties() -> dict:
             "url": tp.url,
             "endpoint": tp.endpoint,
             "timeoutInMillis": str(
-                int(tp.timeout * 1000)  # convert to milliseconds, as required by Spark integration
+                int(tp.timeout) * 1000  # convert to milliseconds, as required by Spark integration
             ),
         }
         if hasattr(tp, "compression") and tp.compression:
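
For reviewers who want to try the new operator arguments, here is a minimal usage sketch (not part of the commit; the DAG id, application path, and connection id are illustrative placeholders). Both flags default to the [openlineage] spark_inject_parent_job_info / spark_inject_transport_info configuration options, with a fallback of False:

from __future__ import annotations

from datetime import datetime

from airflow import DAG
from airflow.providers.apache.spark.operators.spark_submit import SparkSubmitOperator

# Hypothetical DAG demonstrating the parameters added by this commit.
with DAG(dag_id="example_spark_openlineage", start_date=datetime(2025, 1, 1), schedule=None):
    SparkSubmitOperator(
        task_id="spark_submit_job",
        application="/opt/spark/jobs/my_job.py",  # placeholder application path
        conn_id="spark_default",
        conf={"parquet.compression": "SNAPPY"},
        # Inject OpenLineage parent job and transport info into spark.openlineage.* properties
        # before the job is submitted; both settings can also be enabled globally via airflow.cfg.
        openlineage_inject_parent_job_info=True,
        openlineage_inject_transport_info=True,
    )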