This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 5151b5ba963 [SPARK-44877][CONNECT][PYTHON] Support python protobuf functions for Spark Connect
5151b5ba963 is described below

commit 5151b5ba9630b042c659e98b9fd3f7bdb6fc19bd
Author: bogao007 <bo....@databricks.com>
AuthorDate: Mon Aug 21 09:12:14 2023 +0800

    [SPARK-44877][CONNECT][PYTHON] Support python protobuf functions for Spark Connect
    
    ### What changes were proposed in this pull request?
    
    Support Python protobuf functions (`from_protobuf`, `to_protobuf`) for Spark Connect by adding `pyspark.sql.connect.protobuf.functions` and routing the existing API to it under a remote session.
    
    ### Why are the changes needed?
    
    The existing `pyspark.sql.protobuf.functions` implementation relies on py4j and a local JVM, so it does not work with Spark Connect; this change adds a Connect-native implementation.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the Python protobuf functions (`from_protobuf`, `to_protobuf`) can now be used with Spark Connect.
    
    ### How was this patch tested?
    
    Added doctests and tested manually:
    ```
    bo.gaoPF2WXGJ3KT spark % bin/pyspark --remote "local[*]" --jars connector/protobuf/target/scala-2.12/spark-protobuf_2.12-4.0.0-SNAPSHOT.jar
    Python 3.9.6 (default, May  7 2023, 23:32:44)
    [Clang 14.0.3 (clang-1403.0.22.14.1)] on darwin
    Type "help", "copyright", "credits" or "license" for more information.
    23/08/18 10:47:19 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
    Setting default log level to "WARN".
    To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
    /Users/bo.gao/workplace/spark/python/pyspark/pandas/__init__.py:50: UserWarning: 'PYARROW_IGNORE_TIMEZONE' environment variable was not set. It is required to set this environment variable to '1' in both driver and executor sides if you use pyarrow>=2.0.0. pandas-on-Spark will set it for you but it does not work if there is a Spark context already launched.
      warnings.warn(
    Welcome to
          ____              __
         / __/__  ___ _____/ /__
        _\ \/ _ \/ _ `/ __/  '_/
       /__ / .__/\_,_/_/ /_/\_\   version 4.0.0.dev0
          /_/
    
    Using Python version 3.9.6 (default, May  7 2023 23:32:44)
    Client connected to the Spark Connect server at localhost
    SparkSession available as 'spark'.
    >>> from pyspark.sql.protobuf.functions import from_protobuf, to_protobuf
    >>> import tempfile
    >>> data = [([(2, "Alice", 13093020)])]
    >>> ddl_schema = "value struct<age: INTEGER, name: STRING, score: LONG>"
    >>> df = spark.createDataFrame(data, ddl_schema)
    >>> desc_hex = str('0ACE010A41636F6E6E6563746F722F70726F746F6275662F7372632F746573742F726'
    ...     '5736F75726365732F70726F746F6275662F7079737061726B5F746573742E70726F746F121D6F72672E61'
    ...     '70616368652E737061726B2E73716C2E70726F746F627566224B0A0D53696D706C654D657373616765121'
    ...     '00A03616765180120012805520361676512120A046E616D6518022001280952046E616D6512140A057363'
    ...     '6F7265180320012803520573636F72654215421353696D706C654D65737361676550726F746F736206707'
    ...     '26F746F33')
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...      desc_file_path = "%s/pyspark_test.desc" % tmp_dir
    ...      with open(desc_file_path, "wb") as f:
    ...          _ = f.write(bytearray.fromhex(desc_hex))
    ...          f.flush()
    ...          message_name = 'SimpleMessage'
    ...          proto_df = df.select(
    ...              to_protobuf(df.value, message_name, desc_file_path).alias("value"))
    ...          proto_df.show(truncate=False)
    ...          proto_df_1 = proto_df.select( # With file name for descriptor
    ...              from_protobuf(proto_df.value, message_name, desc_file_path).alias("value"))
    ...          proto_df_1.show(truncate=False)
    ...          proto_df_2 = proto_df.select( # With binary for descriptor
    ...              from_protobuf(proto_df.value, message_name,
    ...                            binaryDescriptorSet = bytearray.fromhex(desc_hex))
    ...              .alias("value"))
    ...          proto_df_2.show(truncate=False)
    ...
    +-------------------------------------------+
    |value                                      |
    +-------------------------------------------+
    |[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
    +-------------------------------------------+
    
    +--------------------+
    |value               |
    +--------------------+
    |{2, Alice, 13093020}|
    +--------------------+
    
    +--------------------+
    |value               |
    +--------------------+
    |{2, Alice, 13093020}|
    +--------------------+
    ```
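    
    Both functions also take an optional `options` dict, which the Connect implementation forwards through `_options_to_col` (see `functions.py` in the diff below); that path is not exercised in the transcript above. A minimal sketch of what it could look like, reusing `df`, `message_name`, and `desc_file_path` from the session above; the `"mode": "FAILFAST"` entry is only an illustrative option:
    ```
    # Sketch only: assumes the SimpleMessage descriptor file and DataFrame
    # created in the transcript above are still in scope.
    proto_df = df.select(
        to_protobuf(df.value, message_name, desc_file_path).alias("value"))
    parsed_df = proto_df.select(
        from_protobuf(proto_df.value, message_name, desc_file_path,
                      options={"mode": "FAILFAST"}).alias("value"))
    parsed_df.show(truncate=False)
    ```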
    
    ```
    >>> data = [([(1668035962, 2020)])]
    >>> ddl_schema = "value struct<seconds: LONG, nanos: INT>"
    >>> df = spark.createDataFrame(data, ddl_schema)
    >>> message_class_name = "org.sparkproject.spark_protobuf.protobuf.Timestamp"
    >>> to_proto_df = df.select(to_protobuf(df.value, message_class_name).alias("value"))
    >>> from_proto_df = to_proto_df.select(
    ...      from_protobuf(to_proto_df.value, message_class_name).alias("value"))
    >>> from_proto_df.show(truncate=False)
    +------------------+
    |value             |
    +------------------+
    |{1668035962, 2020}|
    +------------------+
    ```
    
    ```
    >>> import tempfile
    >>> data = [([(2, "Alice", 13093020)])]
    >>> ddl_schema = "value struct<age: INTEGER, name: STRING, score: LONG>"
    >>> df = spark.createDataFrame(data, ddl_schema)
    >>> desc_hex = str('0ACE010A41636F6E6E6563746F722F70726F746F6275662F7372632F746573742F726'
    ...     '5736F75726365732F70726F746F6275662F7079737061726B5F746573742E70726F746F121D6F72672E61'
    ...     '70616368652E737061726B2E73716C2E70726F746F627566224B0A0D53696D706C654D657373616765121'
    ...     '00A03616765180120012805520361676512120A046E616D6518022001280952046E616D6512140A057363'
    ...     '6F7265180320012803520573636F72654215421353696D706C654D65737361676550726F746F736206707'
    ...     '26F746F33')
    >>> with tempfile.TemporaryDirectory() as tmp_dir:
    ...      desc_file_path = "%s/pyspark_test.desc" % tmp_dir
    ...      with open(desc_file_path, "wb") as f:
    ...          _ = f.write(bytearray.fromhex(desc_hex))
    ...          f.flush()
    ...          message_name = 'SimpleMessage'
    ...          proto_df = df.select( # With file name for descriptor
    ...              to_protobuf(df.value, message_name, desc_file_path).alias("suite"))
    ...          proto_df.show(truncate=False)
    ...          proto_df_2 = df.select( # With binary for descriptor
    ...              to_protobuf(df.value, message_name,
    ...                          binaryDescriptorSet=bytearray.fromhex(desc_hex))
    ...              .alias("suite"))
    ...          proto_df_2.show(truncate=False)
    ...
    +-------------------------------------------+
    |suite                                      |
    +-------------------------------------------+
    |[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
    +-------------------------------------------+
    
    +-------------------------------------------+
    |suite                                      |
    +-------------------------------------------+
    |[08 02 12 05 41 6C 69 63 65 18 9C 91 9F 06]|
    +-------------------------------------------+
    ```
    
    ```
    >>> data = [([(1668035962, 2020)])]
    >>> ddl_schema = "value struct<seconds: LONG, nanos: INT>"
    >>> df = spark.createDataFrame(data, ddl_schema)
    >>> message_class_name = "org.sparkproject.spark_protobuf.protobuf.Timestamp"
    >>> proto_df = df.select(to_protobuf(df.value, message_class_name).alias("suite"))
    >>> proto_df.show(truncate=False)
    +----------------------------+
    |suite                       |
    +----------------------------+
    |[08 FA EA B0 9B 06 10 E4 0F]|
    +----------------------------+
    ```
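    
    For reference, the `desc_hex` blob used above is a hex-encoded protobuf `FileDescriptorSet`. A hedged sketch of how such a descriptor set could be produced and passed via `binaryDescriptorSet` instead of a file path (the `protoc` invocation and file names are assumptions for illustration; `proto_df` and `from_protobuf` are reused from the SimpleMessage session above):
    ```
    # Generate the descriptor set outside Spark, e.g.:
    #   protoc --include_imports --descriptor_set_out=pyspark_test.desc pyspark_test.proto
    with open("pyspark_test.desc", "rb") as f:
        descriptor_bytes = f.read()
    parsed_df = proto_df.select(
        from_protobuf(proto_df.value, "SimpleMessage",
                      binaryDescriptorSet=descriptor_bytes).alias("value"))
    ```
    Passing `descFilePath` instead has the same effect; the Connect implementation simply reads the file into bytes via `_read_descriptor_set_file`.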
    
    Closes #42563 from bogao007/python-connect-protobuf.
    
    Authored-by: bogao007 <bo....@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 assembly/pom.xml                                 |   6 +
 dev/sparktestsupport/modules.py                  |  25 ++--
 python/pyspark/sql/connect/protobuf/__init__.py  |  18 +++
 python/pyspark/sql/connect/protobuf/functions.py | 166 +++++++++++++++++++++++
 python/pyspark/sql/protobuf/functions.py         |   6 +-
 python/pyspark/sql/utils.py                      |  16 +++
 6 files changed, 224 insertions(+), 13 deletions(-)
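
A short orientation for the diff below: `assembly/pom.xml` gains a `spark-protobuf` dependency, `dev/sparktestsupport/modules.py` makes the `connect` test module depend on `protobuf`, the new `python/pyspark/sql/connect/protobuf/functions.py` implements `from_protobuf`/`to_protobuf` on top of `_invoke_function`, and the `try_remote_protobuf_functions` decorator added to `python/pyspark/sql/utils.py` routes the existing `pyspark.sql.protobuf.functions` entry points to that module under a remote session. The new file also carries its own doctest harness (`_test()`), which locates the spark-protobuf assembly jar, injects it via `PYSPARK_SUBMIT_ARGS`, and runs the doctests against a `local[2]` Connect session. A hedged sketch of invoking that harness directly, assuming a source checkout where the jar has already been built (e.g. `build/sbt package` or `build/mvn package`):

```
# Sketch only: equivalent to `python python/pyspark/sql/connect/protobuf/functions.py`.
# _test() searches connector/protobuf for the spark-protobuf assembly jar,
# prepends "--jars <jar>" to PYSPARK_SUBMIT_ARGS, and runs doctest.testmod()
# against a SparkSession created with .remote("local[2]").
import runpy

runpy.run_module("pyspark.sql.connect.protobuf.functions", run_name="__main__")
```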

diff --git a/assembly/pom.xml b/assembly/pom.xml
index 4f80719afd4..69952a7ccfa 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -166,6 +166,12 @@
           <version>${project.version}</version>
           <scope>provided</scope>
         </dependency>
+        <dependency>
+          <groupId>org.apache.spark</groupId>
+          <artifactId>spark-protobuf_${scala.binary.version}</artifactId>
+          <version>${project.version}</version>
+          <scope>provided</scope>
+        </dependency>
       </dependencies>
     </profile>
     <profile>
diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py
index bca70018625..64ccf600ef0 100644
--- a/dev/sparktestsupport/modules.py
+++ b/dev/sparktestsupport/modules.py
@@ -271,27 +271,27 @@ sql_kafka = Module(
     ],
 )
 
-connect = Module(
-    name="connect",
-    dependencies=[hive, avro],
+protobuf = Module(
+    name="protobuf",
+    dependencies=[sql],
     source_file_regexes=[
-        "connector/connect",
+        "connector/protobuf",
     ],
-    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "connect/test",
-        "connect-client-jvm/test",
+        "protobuf/test",
     ],
 )
 
-protobuf = Module(
-    name="protobuf",
-    dependencies=[sql],
+connect = Module(
+    name="connect",
+    dependencies=[hive, avro, protobuf],
     source_file_regexes=[
-        "connector/protobuf",
+        "connector/connect",
     ],
+    build_profile_flags=["-Pconnect"],
     sbt_test_goals=[
-        "protobuf/test",
+        "connect/test",
+        "connect-client-jvm/test",
     ],
 )
 
@@ -832,6 +832,7 @@ pyspark_connect = Module(
         "pyspark.sql.connect.dataframe",
         "pyspark.sql.connect.functions",
         "pyspark.sql.connect.avro.functions",
+        "pyspark.sql.connect.protobuf.functions",
         "pyspark.sql.connect.streaming.readwriter",
         "pyspark.sql.connect.streaming.query",
         # sql unittests
diff --git a/python/pyspark/sql/connect/protobuf/__init__.py b/python/pyspark/sql/connect/protobuf/__init__.py
new file mode 100644
index 00000000000..dc81e9f515e
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/__init__.py
@@ -0,0 +1,18 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""Spark Connect Python Client - Protobuf Functions"""
diff --git a/python/pyspark/sql/connect/protobuf/functions.py b/python/pyspark/sql/connect/protobuf/functions.py
new file mode 100644
index 00000000000..56119f4bc4e
--- /dev/null
+++ b/python/pyspark/sql/connect/protobuf/functions.py
@@ -0,0 +1,166 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+A collections of builtin protobuf functions
+"""
+
+from pyspark.sql.connect.utils import check_dependencies
+
+check_dependencies(__name__)
+
+from typing import Dict, Optional, TYPE_CHECKING
+
+from pyspark.sql.protobuf import functions as PyProtobufFunctions
+
+from pyspark.sql.connect.column import Column
+from pyspark.sql.connect.functions import _invoke_function, _to_col, _options_to_col, lit
+
+if TYPE_CHECKING:
+    from pyspark.sql.connect._typing import ColumnOrName
+
+
+def from_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "from_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("from_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "from_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+from_protobuf.__doc__ = PyProtobufFunctions.from_protobuf.__doc__
+
+
+def to_protobuf(
+    data: "ColumnOrName",
+    messageName: str,
+    descFilePath: Optional[str] = None,
+    options: Optional[Dict[str, str]] = None,
+    binaryDescriptorSet: Optional[bytes] = None,
+) -> Column:
+    binary_proto = None
+    if binaryDescriptorSet is not None:
+        binary_proto = binaryDescriptorSet
+    elif descFilePath is not None:
+        binary_proto = _read_descriptor_set_file(descFilePath)
+
+    # TODO: simplify the code when _invoke_function() supports None as input.
+    if binary_proto is not None:
+        if options is None:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), lit(binary_proto)
+            )
+        else:
+            return _invoke_function(
+                "to_protobuf",
+                _to_col(data),
+                lit(messageName),
+                lit(binary_proto),
+                _options_to_col(options),
+            )
+    else:
+        if options is None:
+            return _invoke_function("to_protobuf", _to_col(data), lit(messageName))
+        else:
+            return _invoke_function(
+                "to_protobuf", _to_col(data), lit(messageName), _options_to_col(options)
+            )
+
+
+to_protobuf.__doc__ = PyProtobufFunctions.to_protobuf.__doc__
+
+
+def _read_descriptor_set_file(filePath: str) -> bytes:
+    with open(filePath, "rb") as f:
+        return f.read()
+
+
+def _test() -> None:
+    import os
+    import sys
+    from pyspark.testing.utils import search_jar
+
+    protobuf_jar = search_jar("connector/protobuf", "spark-protobuf-assembly-", "spark-protobuf")
+    if protobuf_jar is None:
+        print(
+            "Skipping all Protobuf Python tests as the optional Protobuf project was "
+            "not compiled into a JAR. To run these tests, "
+            "you need to build Spark with 'build/sbt package' or "
+            "'build/mvn package' before running this test."
+        )
+        sys.exit(0)
+    else:
+        existing_args = os.environ.get("PYSPARK_SUBMIT_ARGS", "pyspark-shell")
+        jars_args = "--jars %s" % protobuf_jar
+        os.environ["PYSPARK_SUBMIT_ARGS"] = " ".join([jars_args, existing_args])
+
+    import doctest
+    from pyspark.sql import SparkSession as PySparkSession
+    import pyspark.sql.connect.protobuf.functions
+
+    globs = pyspark.sql.connect.protobuf.functions.__dict__.copy()
+
+    globs["spark"] = (
+        PySparkSession.builder.appName("sql.protobuf.functions tests")
+        .remote("local[2]")
+        .getOrCreate()
+    )
+
+    (failure_count, test_count) = doctest.testmod(
+        pyspark.sql.connect.protobuf.functions,
+        globs=globs,
+        optionflags=doctest.ELLIPSIS
+        | doctest.NORMALIZE_WHITESPACE
+        | doctest.IGNORE_EXCEPTION_DETAIL,
+    )
+
+    globs["spark"].stop()
+
+    if failure_count:
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    _test()
diff --git a/python/pyspark/sql/protobuf/functions.py b/python/pyspark/sql/protobuf/functions.py
index 42165938eb7..acb1a17efbd 100644
--- a/python/pyspark/sql/protobuf/functions.py
+++ b/python/pyspark/sql/protobuf/functions.py
@@ -25,13 +25,14 @@ from typing import Dict, Optional, TYPE_CHECKING, cast
 from py4j.java_gateway import JVMView
 
 from pyspark.sql.column import Column, _to_java_column
-from pyspark.sql.utils import get_active_spark_context
+from pyspark.sql.utils import get_active_spark_context, try_remote_protobuf_functions
 from pyspark.util import _print_missing_jar
 
 if TYPE_CHECKING:
     from pyspark.sql._typing import ColumnOrName
 
 
+@try_remote_protobuf_functions
 def from_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -58,6 +59,7 @@ def from_protobuf(
 
     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.
 
     Parameters
     ----------
@@ -161,6 +163,7 @@ def from_protobuf(
     return Column(jc)
 
 
+@try_remote_protobuf_functions
 def to_protobuf(
     data: "ColumnOrName",
     messageName: str,
@@ -187,6 +190,7 @@ def to_protobuf(
 
     .. versionchanged:: 3.5.0
         Supports `binaryDescriptorSet` arg to pass binary descriptor directly.
+        Supports Spark Connect.
 
     Parameters
     ----------
diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py
index cb262a14cbe..45df4433916 100644
--- a/python/pyspark/sql/utils.py
+++ b/python/pyspark/sql/utils.py
@@ -211,6 +211,22 @@ def try_remote_avro_functions(f: FuncT) -> FuncT:
     return cast(FuncT, wrapped)
 
 
+def try_remote_protobuf_functions(f: FuncT) -> FuncT:
+    """Mark API supported from Spark Connect."""
+
+    @functools.wraps(f)
+    def wrapped(*args: Any, **kwargs: Any) -> Any:
+
+        if is_remote() and "PYSPARK_NO_NAMESPACE_SHARE" not in os.environ:
+            from pyspark.sql.connect.protobuf import functions
+
+            return getattr(functions, f.__name__)(*args, **kwargs)
+        else:
+            return f(*args, **kwargs)
+
+    return cast(FuncT, wrapped)
+
+
 def try_remote_window(f: FuncT) -> FuncT:
     """Mark API supported from Spark Connect."""
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
