This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5151b5ba963  [SPARK-44877][CONNECT][PYTHON] Support python protobuf functions for Spark Connect
5151b5ba963 is described below

commit 5151b5ba9630b042c659e98b9fd3f7bdb6fc19bd
Author: bogao007 <bo....@databricks.com>
AuthorDate: Mon Aug 21 09:12:14 2023 +0800

[SPARK-44877][CONNECT][PYTHON] Support python protobuf functions for Spark Connect

### What changes were proposed in this pull request?

Support the Python protobuf functions (`from_protobuf` and `to_protobuf`) for Spark Connect.

### Why are the changes needed?

So that the existing `pyspark.sql.protobuf.functions` API can also be used from Spark Connect clients.

### Does this PR introduce _any_ user-facing change?

Yes, the protobuf functions now work with Spark Connect.

### How was this patch tested?

Added a doctest and did manual testing:

```
bo.gaoPF2WXGJ3KT spark % bin/pyspark --remote "local[*]" --jars connector/protobuf/target/scala-2.12/spark-protobuf_2.12-4.0.0-SNAPSHOT.jar
Python 3.9.6 (default, May 7 2023, 23:32:44)
[Clang 14.0.3 (clang-1403.0.22.14.1)] on darwin
Type "help", "copyright", "credits" or "license" for more information.
23/08/18 10:47:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
/Users/bo.gao/workplace/spark/python/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.
  warnings.warn(
Welcome to
      ____              __
     / __/__  ___ _____/ /__
    _\ \/ _ \/ _ `/ __/  '_/
   /__ / .__/\_,_/_/ /_/\_\   version 4.0.0.dev0
      /_/

Using Python version 3.9.6 (default, May 7 2023 23:32:44)
Client connected to the Spark Connect server at localhost
SparkSession available as 'spark'.
>>> from pyspark.sql.protobuf.functions import from_protobuf, to_protobuf
>>> import tempfile
>>> data = [([(2, "Alice", 13093020)])]
>>> ddl_schema = "value struct<age: INTEGER, name: STRING, score: LONG>"
>>> df = spark.createDataFrame(data, ddl_schema)
>>> desc_hex = str('0ACE010A41636F6E6E6563746F722F70726F746F6275662F7372632F746573742F726'
...     '5736F75726365732F70726F746F6275662F7079737061726B5F746573742E70726F746F121D6F72672E61'
...     '70616368652E737061726B2E73716C2E70726F746F627566224B0A0D53696D706C654D657373616765121'
...     '00A03616765180120012805520361676512120A046E616D6518022001280952046E616D6512140A057363'
...     '6F7265180320012803520573636F72654215421353696D706C654D65737361676550726F746F736206707'
...     '26F746F33')
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     desc_file_path = "%s/pyspark_test.desc" % tmp_dir
...     with open(desc_file_path, "wb") as f:
...         _ = f.write(bytearray.fromhex(desc_hex))
...         f.flush()
...     message_name = 'SimpleMessage'
...     proto_df = df.select(
...         to_protobuf(df.value, message_name, desc_file_path).alias("value"))
...     proto_df.show(truncate=False)
...     proto_df_1 = proto_df.select(  # With file name for descriptor
...         from_protobuf(proto_df.value, message_name, desc_file_path).alias("value"))
...     proto_df_1.show(truncate=False)
...     proto_df_2 = proto_df.select(  # With binary for descriptor
...         from_protobuf(proto_df.value, message_name,
...                       binaryDescriptorSet=bytearray.fromhex(desc_hex))
...         .alias("value"))
...     proto_df_2.show(truncate=False)
...
+-------------------------------------------+
|value                                      |
+-------------------------------------------+
|[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
+-------------------------------------------+

+--------------------+
|value               |
+--------------------+
|{2, Alice, 13093020}|
+--------------------+

+--------------------+
|value               |
+--------------------+
|{2, Alice, 13093020}|
+--------------------+
```

```
>>> data = [([(1668035962, 2020)])]
>>> ddl_schema = "value struct<seconds: LONG, nanos: INT>"
>>> df = spark.createDataFrame(data, ddl_schema)
>>> message_class_name = "org.sparkproject.spark_protobuf.protobuf.Timestamp"
>>> to_proto_df = df.select(to_protobuf(df.value, message_class_name).alias("value"))
>>> from_proto_df = to_proto_df.select(
...     from_protobuf(to_proto_df.value, message_class_name).alias("value"))
>>> from_proto_df.show(truncate=False)
+------------------+
|value             |
+------------------+
|{1668035962, 2020}|
+------------------+
```

```
>>> import tempfile
>>> data = [([(2, "Alice", 13093020)])]
>>> ddl_schema = "value struct<age: INTEGER, name: STRING, score: LONG>"
>>> df = spark.createDataFrame(data, ddl_schema)
>>> desc_hex = str('0ACE010A41636F6E6E6563746F722F70726F746F6275662F7372632F746573742F726'
...     '5736F75726365732F70726F746F6275662F7079737061726B5F746573742E70726F746F121D6F72672E61'
...     '70616368652E737061726B2E73716C2E70726F746F627566224B0A0D53696D706C654D657373616765121'
...     '00A03616765180120012805520361676512120A046E616D6518022001280952046E616D6512140A057363'
...     '6F7265180320012803520573636F72654215421353696D706C654D65737361676550726F746F736206707'
...     '26F746F33')
>>> with tempfile.TemporaryDirectory() as tmp_dir:
...     desc_file_path = "%s/pyspark_test.desc" % tmp_dir
...     with open(desc_file_path, "wb") as f:
...         _ = f.write(bytearray.fromhex(desc_hex))
...         f.flush()
...     message_name = 'SimpleMessage'
...     proto_df = df.select(  # With file name for descriptor
...         to_protobuf(df.value, message_name, desc_file_path).alias("suite"))
...     proto_df.show(truncate=False)
...     proto_df_2 = df.select(  # With binary for descriptor
...         to_protobuf(df.value, message_name,
...                     binaryDescriptorSet=bytearray.fromhex(desc_hex))
...         .alias("suite"))
...     proto_df_2.show(truncate=False)
...
+-------------------------------------------+
|suite                                      |
+-------------------------------------------+
|[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
+-------------------------------------------+

+-------------------------------------------+
|suite                                      |
+-------------------------------------------+
|[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
+-------------------------------------------+
```

```
>>> data = [([(1668035962, 2020)])]
>>> ddl_schema = "value struct<seconds: LONG, nanos: INT>"
>>> df = spark.createDataFrame(data, ddl_schema)
>>> message_class_name = "org.sparkproject.spark_protobuf.protobuf.Timestamp"
>>> proto_df = df.select(to_protobuf(df.value, message_class_name).alias("suite"))
>>> proto_df.show(truncate=False)
+----------------------------+
|suite                       |
+----------------------------+
|[08 FA EA B0 9B 06 10 E4 0F]|
+----------------------------+
```

Closes #42563 from bogao007/python-connect-protobuf.
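For readers skimming the diff below: the existing `pyspark.sql.protobuf.functions` entry points are routed to the new Spark Connect implementations by a decorator, `try_remote_protobuf_functions` (added to `python/pyspark/sql/utils.py`). A minimal, self-contained sketch of that dispatch pattern (not the actual Spark source; the `IS_REMOTE` flag and `REMOTE_IMPLS` table are stand-ins for PySpark's real session detection and its `pyspark.sql.connect.protobuf.functions` module):

```
import functools
from typing import Any, Callable

# Stand-ins for this sketch only: real PySpark checks is_remote() and
# imports pyspark.sql.connect.protobuf.functions instead of these.
IS_REMOTE = False
REMOTE_IMPLS = {"from_protobuf": lambda *a, **kw: ("remote", a, kw)}


def try_remote(f: Callable[..., Any]) -> Callable[..., Any]:
    """Forward a call to the same-named remote implementation when remote."""

    @functools.wraps(f)
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        if IS_REMOTE:
            # Same lookup-by-name dispatch as try_remote_protobuf_functions:
            # getattr(connect_functions, f.__name__)(*args, **kwargs)
            return REMOTE_IMPLS[f.__name__](*args, **kwargs)
        return f(*args, **kwargs)

    return wrapped


@try_remote
def from_protobuf(data: bytes, message_name: str) -> Any:
    # The classic (py4j-backed) implementation would run here.
    # from_protobuf(b"\x08\x02", "SimpleMessage") -> ("local", b"\x08\x02", "SimpleMessage")
    return ("local", data, message_name)
```

Because the wrapper dispatches by `f.__name__`, the Connect side only has to define same-named functions with matching signatures, which is what the new `python/pyspark/sql/connect/protobuf/functions.py` provides.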
Authored-by: bogao007 <bo....@databricks.com>
Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 assembly/pom.xml                                 |   6 +
 dev/sparktestsupport/modules.py                  |  25 ++--
 python/pyspark/sql/connect/protobuf/__init__.py  |  18 +++
 python/pyspark/sql/connect/protobuf/functions.py | 166 +++++++++++++++++++++++
 python/pyspark/sql/protobuf/functions.py         |   6 +-
 python/pyspark/sql/utils.py                      |  16 +++
 6 files changed, 224 insertions(+), 13 deletions(-)

diff --git a/assembly/pom.xml b/assembly/pom.xml
index 4f80719afd4..69952a7ccfa 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -166,6 +166,12 @@
         <version>${project.version}</version>
         <scope>provided</scope>
       </dependency>
+      <dependency>
+        <groupId>org.apache.spark</groupId>
+        <artifactId>spark-protobuf_${scala.binary.version}</artifactId>
+        <version>${project.version}</version>
+        <scope>provided</scope>
+      </dependency>
     </dependencies>
   </profile>
   <profile>
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bca70018625..64ccf600ef0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -271,27 +271,27 @@ sql_kafka = Module(
     ],
 )

-connect = Module(
-    name="connect",
-    dependencies=[hive, avro],
+protobuf = Module(
+    name="protobuf",
+    dependencies=[sql],
     source_file_regexes=[
-        "connector/connect",
+        "connector/protobuf",
     ],
-    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "connect/test",
-        "connect-client-jvm/test",
+        "protobuf/test",
     ],
 )

-protobuf = Module(
-    name="protobuf",
-    dependencies=[sql],
+connect = Module(
+    name="connect",
+    dependencies=[hive, avro, protobuf],
     source_file_regexes=[
-        "connector/protobuf",
+        "connector/connect",
     ],
+    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "protobuf/test",
+        "connect/test",
+        "connect-client-jvm/test",
     ],
 )

@@ -832,6 +832,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
         "pyspark.sql.connect.avro.functions",
+        "pyspark.sql.connect.protobuf.functions",
         "pyspark.sql.connect.streaming.readwriter",
         "pyspark.sql.connect.streaming.query",
         # sql unittests
diff --git a/python/pyspark/sql/connect/protobuf/__init__.py b/python/pyspark/sql/connect/protobuf/__init__.py
new file mode 100644
index 00000000000..dc81e9f515e
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Protobuf Functions"""
diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py
new file mode 100644
index 00000000000..56119f4bc4e
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/functions.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collection of builtin protobuf functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.protobuf import functions as PyProtobufFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "from_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("from_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+from_protobuf.__doc__ = PyProtobufFunctions.from_protobuf.__doc__
+
+
+def to_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "to_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("to_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+to_protobuf.__doc__ = PyProtobufFunctions.to_protobuf.__doc__
+
+
+def _read_descriptor_set_file(filePath: str) -> bytes:
+    with open(filePath, "rb") as f:
+        return f.read()
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
+    if protobuf_jar is None:
+        print(
+            "Skipping all Protobuf Python tests as the optional Protobuf project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt package' or "
+            "'build/mvn package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % protobuf_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.protobuf.functions
+
+    globs = pyspark.sql.connect.protobuf.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.protobuf.functions tests")
+        .remote("local[2]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.protobuf.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 42165938eb7..acb1a17efbd 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -25,13 +25,14 @@ from typing import Dict, Optional, TYPE_CHECKING, cast
 from py4j.java_gateway import JVMView

 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_protobuf_functions
 from pyspark.util import _print_missing_jar

 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName


+@try_remote_protobuf_functions
 def from_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -58,6 +59,7 @@ def from_protobuf(

     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.

     Parameters
     ----------
@@ -161,6 +163,7 @@ def from_protobuf(
     return Column(jc)


+@try_remote_protobuf_functions
 def to_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -187,6 +190,7 @@ def to_protobuf(

     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.

     Parameters
     ----------
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cb262a14cbe..45df4433916 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -211,6 +211,22 @@ def try_remote_avro_functions(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)


+def try_remote_protobuf_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.protobuf import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def try_remote_window(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org