This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 83fe9b16ab5a [SPARK-47694][CONNECT] Make max message size configurable on the client side
83fe9b16ab5a is described below

commit 83fe9b16ab5a2eec5f844d1e30488fe48223e29b
Author: Robert Dillitz <robert.dill...@databricks.com>
AuthorDate: Mon Apr 15 14:52:52 2024 -0400

    [SPARK-47694][CONNECT] Make max message size configurable on the client side
    
    ### What changes were proposed in this pull request?
    Follow-up to https://github.com/apache/spark/pull/40447.
    Allows configuring the currently hardcoded maximum message size of 128MB on the client side, for both the Scala and Python clients.
    Adds the option to the Scala client and improves the way we handle `channelOptions` in Python's `ChannelBuilder`.
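    
    For illustration, a minimal sketch of the new Scala builder option (the connection string and the 256MB value are placeholders, not part of this change):
    
    ```scala
    import org.apache.spark.sql.connect.client.SparkConnectClient
    
    // Raise the client-side gRPC message limit above the 128MB default.
    val client = SparkConnectClient
      .builder()
      .connectionString("sc://localhost:15002")
      .grpcMaxMessageSize(256 * 1024 * 1024) // size in bytes
      .build()
    ```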
    
    ### Why are the changes needed?
    Usability: I am aware of two different cases where these limits are hit:
    
    1. The user is trying to create a large dataframe from local data. We either hit the `grpc.max_send_message_length` in the Python client ([currently hardcoded](https://github.com/apache/spark/pull/40447/files)) or the `maxInboundMessageSize` on the cluster side ([now configurable](https://github.com/apache/spark/pull/40447/files)).
    2. The result from the cluster has a single row that is larger than 128MB, causing an `ExecutePlanResponse` that is larger than the client's `grpc.max_receive_message_length` (Python) or `channel.maxInboundMessageSize` (Scala) ([both hardcoded](https://github.com/apache/spark/pull/40447/files)).
    
    This gives the option to increase these limits on the client side.
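    
    For example, the limit can also be raised through the new `grpc_max_message_size` connection-string parameter (a sketch; host, port, and the 64MB value are illustrative):
    
    ```scala
    import org.apache.spark.sql.connect.client.SparkConnectClient
    
    // The grpc_max_message_size URI parameter (in bytes) is applied as the
    // channel's maxInboundMessageSize.
    val client = SparkConnectClient
      .builder()
      .connectionString("sc://localhost:15002/;grpc_max_message_size=67108864")
      .build()
    ```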
    
    ### Does this PR introduce _any_ user-facing change?
    Scala: Adds an option to set `grpcMaxMessageSize` on `SparkConnectClient.Builder`.
    Python: No.
    
    ### How was this patch tested?
    Tests added.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No.
    
    Closes #45842 from dillitz/SPARK-47694.
    
    Authored-by: Robert Dillitz <robert.dill...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../connect/client/SparkConnectClientSuite.scala   |  9 +++++++-
 .../sql/connect/client/SparkConnectClient.scala    | 14 +++++++++++--
 .../connect/client/SparkConnectClientParser.scala  | 24 +++++++++++++---------
 python/pyspark/sql/connect/client/core.py          | 20 +++++++++++++-----
 .../sql/tests/connect/test_connect_session.py      | 20 ++++++++++++++++++
 5 files changed, 69 insertions(+), 18 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
index 5a43cf014bdc..55f962b2a52c 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala
@@ -310,7 +310,14 @@ class SparkConnectClientSuite extends ConnectFunSuite with BeforeAndAfterEach {
         assert(client.userAgent.contains("scala/"))
         assert(client.userAgent.contains("jvm/"))
         assert(client.userAgent.contains("os/"))
-      }))
+      }),
+    TestPackURI(
+      "sc://SPARK-47694:123/;grpc_max_message_size=1860",
+      isCorrect = true,
+      client => {
+        assert(client.configuration.grpcMaxMessageSize == 1860)
+      }),
+    TestPackURI("sc://SPARK-47694:123/;grpc_max_message_size=abc", isCorrect = 
false))
 
   private def checkTestPack(testPack: TestPackURI): Unit = {
     val client = SparkConnectClient.builder().connectionString(testPack.connectionString).build()
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
index 746aaca6f559..d9d51c15a880 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala
@@ -510,6 +510,7 @@ object SparkConnectClient {
       val PARAM_TOKEN = "token"
       val PARAM_USER_AGENT = "user_agent"
       val PARAM_SESSION_ID = "session_id"
+      val PARAM_GRPC_MAX_MESSAGE_SIZE = "grpc_max_message_size"
     }
 
     private def verifyURI(uri: URI): Unit = {
@@ -558,6 +559,13 @@ object SparkConnectClient {
 
     def userAgent: String = _configuration.userAgent
 
+    def grpcMaxMessageSize(messageSize: Int): Builder = {
+      _configuration = _configuration.copy(grpcMaxMessageSize = messageSize)
+      this
+    }
+
+    def grpcMaxMessageSize: Int = _configuration.grpcMaxMessageSize
+
     def option(key: String, value: String): Builder = {
       _configuration = _configuration.copy(metadata = _configuration.metadata + ((key, value)))
       this
@@ -584,6 +592,7 @@ object SparkConnectClient {
           case URIParams.PARAM_USE_SSL =>
             if (java.lang.Boolean.valueOf(value)) enableSsl() else disableSsl()
           case URIParams.PARAM_SESSION_ID => sessionId(value)
+          case URIParams.PARAM_GRPC_MAX_MESSAGE_SIZE => grpcMaxMessageSize(value.toInt)
           case _ => option(key, value)
         }
       }
@@ -693,7 +702,8 @@ object SparkConnectClient {
       retryPolicies: Seq[RetryPolicy] = RetryPolicy.defaultPolicies(),
       useReattachableExecute: Boolean = true,
       interceptors: List[ClientInterceptor] = List.empty,
-      sessionId: Option[String] = None) {
+      sessionId: Option[String] = None,
+      grpcMaxMessageSize: Int = ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) {
 
     def userContext: proto.UserContext = {
       val builder = proto.UserContext.newBuilder()
@@ -731,7 +741,7 @@ object SparkConnectClient {
 
       interceptors.foreach(channelBuilder.intercept(_))
 
-      channelBuilder.maxInboundMessageSize(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
+      channelBuilder.maxInboundMessageSize(grpcMaxMessageSize)
       channelBuilder.build()
     }
 
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
index cfb5823ee43e..7e137a6a3e05 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClientParser.scala
@@ -32,16 +32,17 @@ private[sql] object SparkConnectClientParser {
   def usage(): String =
     s"""
        |Options:
-       |   --remote REMOTE          URI of the Spark Connect Server to connect to.
-       |   --host HOST              Host where the Spark Connect Server is running.
-       |   --port PORT              Port where the Spark Connect Server is running.
-       |   --use_ssl                Connect to the server using SSL.
-       |   --token TOKEN            Token to use for authentication.
-       |   --user_id USER_ID        Id of the user connecting.
-       |   --user_name USER_NAME    Name of the user connecting.
-       |   --user_agent USER_AGENT  The User-Agent Client information (only intended for logging purposes by the server).
-       |   --session_id SESSION_ID  Session Id of the user connecting.
-       |   --option KEY=VALUE       Key-value pair that is used to further configure the session.
+       |   --remote REMOTE              URI of the Spark Connect Server to connect to.
+       |   --host HOST                  Host where the Spark Connect Server is running.
+       |   --port PORT                  Port where the Spark Connect Server is running.
+       |   --use_ssl                    Connect to the server using SSL.
+       |   --token TOKEN                Token to use for authentication.
+       |   --user_id USER_ID            Id of the user connecting.
+       |   --user_name USER_NAME        Name of the user connecting.
+       |   --user_agent USER_AGENT      The User-Agent Client information (only intended for logging purposes by the server).
+       |   --session_id SESSION_ID      Session Id of the user connecting.
+       |   --grpc_max_message_size SIZE Maximum message size allowed for gRPC messages in bytes.
+       |   --option KEY=VALUE           Key-value pair that is used to further configure the session.
      """.stripMargin
   // scalastyle:on line.size.limit
 
@@ -88,6 +89,9 @@ private[sql] object SparkConnectClientParser {
             s"--option should contain key=value, found ${tail.head} instead")
         }
         parse(tail.tail, builder.option(key, value))
+      case "--grpc_max_message_size" :: tail =>
+        val (value, remainder) = extract("--grpc_max_message_size", tail)
+        parse(remainder, builder.grpcMaxMessageSize(value.toInt))
       case unsupported :: _ =>
         throw new IllegalArgumentException(s"$unsupported is an unsupported argument.")
     }
diff --git a/python/pyspark/sql/connect/client/core.py b/python/pyspark/sql/connect/client/core.py
index 532d490d925e..667b93596c5f 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -115,11 +115,12 @@ class ChannelBuilder:
     PARAM_USER_ID = "user_id"
     PARAM_USER_AGENT = "user_agent"
     PARAM_SESSION_ID = "session_id"
-    MAX_MESSAGE_LENGTH = 128 * 1024 * 1024
+
+    GRPC_MAX_MESSAGE_LENGTH_DEFAULT = 128 * 1024 * 1024
 
     GRPC_DEFAULT_OPTIONS = [
-        ("grpc.max_send_message_length", MAX_MESSAGE_LENGTH),
-        ("grpc.max_receive_message_length", MAX_MESSAGE_LENGTH),
+        ("grpc.max_send_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT),
+        ("grpc.max_receive_message_length", GRPC_MAX_MESSAGE_LENGTH_DEFAULT),
     ]
 
     def __init__(
@@ -129,10 +130,11 @@ class ChannelBuilder:
     ):
         self._interceptors: List[grpc.UnaryStreamClientInterceptor] = []
         self._params: Dict[str, str] = params or dict()
-        self._channel_options: List[Tuple[str, Any]] = ChannelBuilder.GRPC_DEFAULT_OPTIONS
+        self._channel_options: List[Tuple[str, Any]] = ChannelBuilder.GRPC_DEFAULT_OPTIONS.copy()
 
         if channelOptions is not None:
-            self._channel_options = self._channel_options + channelOptions
+            for key, value in channelOptions:
+                self.setChannelOption(key, value)
 
     def get(self, key: str) -> Any:
         """
@@ -152,6 +154,14 @@ class ChannelBuilder:
     def set(self, key: str, value: Any) -> None:
         self._params[key] = value
 
+    def setChannelOption(self, key: str, value: Any) -> None:
+        # overwrite option if it exists already else append it
+        for i, option in enumerate(self._channel_options):
+            if option[0] == key:
+                self._channel_options[i] = (key, value)
+                return
+        self._channel_options.append((key, value))
+
     def add_interceptor(self, interceptor: grpc.UnaryStreamClientInterceptor) -> None:
         self._interceptors.append(interceptor)
 
diff --git a/python/pyspark/sql/tests/connect/test_connect_session.py b/python/pyspark/sql/tests/connect/test_connect_session.py
index 4d6127b5be8b..1caf3525cfbb 100644
--- a/python/pyspark/sql/tests/connect/test_connect_session.py
+++ b/python/pyspark/sql/tests/connect/test_connect_session.py
@@ -498,6 +498,26 @@ class ChannelBuilderTests(unittest.TestCase):
         chan = DefaultChannelBuilder("sc://host/")
         self.assertIsNone(chan.session_id)
 
+    def test_channel_options(self):
+        # SPARK-47694
+        chan = DefaultChannelBuilder(
+            "sc://host", [("grpc.max_send_message_length", 1860), ("test", 
"robert")]
+        )
+        options = chan._channel_options
+        self.assertEqual(
+            [k for k, _ in options].count("grpc.max_send_message_length"),
+            1,
+            "only one occurrence for defaults",
+        )
+        self.assertEqual(
+            next(v for k, v in options if k == "grpc.max_send_message_length"),
+            1860,
+            "overwrites defaults",
+        )
+        self.assertEqual(
+            next(v for k, v in options if k == "test"), "robert", "new values are picked up"
+        )
+
 
 if __name__ == "__main__":
     from pyspark.sql.tests.connect.test_connect_session import *  # noqa: F401


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
