This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 55391f633a4 [SPARK-44587][SQL][CONNECT] Increase protobuf marshaller recursion limit
55391f633a4 is described below

commit 55391f633a43113bdd36b5720f5a5f6d6a9daed8
Author: Yihong He <yihong...@databricks.com>
AuthorDate: Mon Jul 31 09:39:13 2023 +0900

    [SPARK-44587][SQL][CONNECT] Increase protobuf marshaller recursion limit
    
    ### What changes were proposed in this pull request?
    
    - Use customized marshallers for the Spark Connect gRPC methods
    - Increase the protobuf marshaller recursion limit
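
    A minimal sketch of the per-method rewrite (the standalone helper below is
    illustrative; `ProtoLiteUtils.marshallerWithRecursionLimit`,
    `PrototypeMarshaller`, and `MethodDescriptor.toBuilder` are the gRPC APIs
    the patch builds on):

    ```scala
    import com.google.protobuf.MessageLite
    import io.grpc.MethodDescriptor
    import io.grpc.MethodDescriptor.PrototypeMarshaller
    import io.grpc.protobuf.lite.ProtoLiteUtils

    // Rebuild a method descriptor so both of its marshallers parse
    // messages with a higher protobuf recursion limit.
    def withRecursionLimit(
        method: MethodDescriptor[MessageLite, MessageLite],
        limit: Int): MethodDescriptor[MessageLite, MessageLite] = {
      def relaxed(m: MethodDescriptor.Marshaller[MessageLite]) =
        ProtoLiteUtils.marshallerWithRecursionLimit(
          m.asInstanceOf[PrototypeMarshaller[MessageLite]].getMessagePrototype,
          limit)
      method.toBuilder
        .setRequestMarshaller(relaxed(method.getRequestMarshaller))
        .setResponseMarshaller(relaxed(method.getResponseMarshaller))
        .build()
    }
    ```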
    
    ### Why are the changes needed?
    
    - Deeply nested DataFrames (e.g. long `union` chains) serialize to deeply
      nested protobuf plans that exceed the marshaller's default recursion
      limit, so they fail easily (see the sketch below)
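
    A sketch of the failure mode, mirroring the new `ClientE2ETestSuite` test
    (each `union` adds one more level of `Relation` nesting to the plan sent
    over gRPC, and protobuf's default parse recursion limit is only 100):

    ```scala
    // Builds a plan whose protobuf encoding nests hundreds of levels deep;
    // before this change the server-side marshaller rejected such requests.
    var df = spark.range(1)
    for (a <- 1 to 500) {
      df = df.union(spark.range(a, a + 1))
    }
    assert(df.collect().length == 501) // one initial row plus 500 unioned rows
    ```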
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"`
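
    To run only the new test, ScalaTest's `-z` substring filter can be passed
    through sbt (illustrative invocation):

    `build/sbt 'connect-client-jvm/testOnly *ClientE2ETestSuite -- -z "spark deep recursion"'`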
    
    Closes #42212 from heyihong/SPARK-44587-2.
    
    Authored-by: Yihong He <yihong...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  8 +++
 .../apache/spark/sql/connect/config/Connect.scala  | 10 ++++
 .../sql/connect/service/SparkConnectService.scala  | 60 +++++++++++++++++++---
 3 files changed, 72 insertions(+), 6 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 36f47cc1fba..1403d460b51 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -43,6 +43,14 @@ import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
+  test("spark deep recursion") {
+    var df = spark.range(1)
+    for (a <- 1 to 500) {
+      df = df.union(spark.range(a, a + 1))
+    }
+    assert(df.collect().length == 501)
+  }
+
   test("many tables") {
     withSQLConf("spark.sql.execution.arrow.maxRecordsPerBatch" -> "10") {
       val numTables = 20
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index 23aa42bad30..142b206fbf4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -62,6 +62,16 @@ object Connect {
       .bytesConf(ByteUnit.BYTE)
       .createWithDefault(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE)
 
+  val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT =
+    ConfigBuilder("spark.connect.grpc.marshallerRecursionLimit")
+      .internal()
+      .doc("""
+          |Sets the recursion limit for gRPC protobuf messages.
+          |""".stripMargin)
+      .version("3.5.0")
+      .intConf
+      .createWithDefault(1024)
+
   val CONNECT_EXTENSIONS_RELATION_CLASSES =
     ConfigBuilder("spark.connect.extensions.relation.classes")
       .doc("""
diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 206e24714fe..8f93d5083f4 100644
--- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -21,21 +21,27 @@ import java.net.InetSocketAddress
 import java.util.UUID
 import java.util.concurrent.TimeUnit
 
+import scala.jdk.CollectionConverters._
+
 import com.google.common.base.Ticker
 import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification}
-import io.grpc.Server
+import com.google.protobuf.MessageLite
+import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition}
+import io.grpc.MethodDescriptor.PrototypeMarshaller
 import io.grpc.netty.NettyServerBuilder
+import io.grpc.protobuf.lite.ProtoLiteUtils
 import io.grpc.protobuf.services.ProtoReflectionService
 import io.grpc.stub.StreamObserver
 import org.apache.commons.lang3.StringUtils
 
 import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException}
 import org.apache.spark.connect.proto
-import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse}
+import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc}
+import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.UI.UI_ENABLED
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
+import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_ADDRESS, CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE}
 import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab}
 import org.apache.spark.sql.connect.utils.ErrorUtils
 import org.apache.spark.status.ElementTrackingStore
@@ -48,9 +54,7 @@ import org.apache.spark.status.ElementTrackingStore
  * @param debug
  *   delegates debug behavior to the handlers.
  */
-class SparkConnectService(debug: Boolean)
-    extends proto.SparkConnectServiceGrpc.SparkConnectServiceImplBase
-    with Logging {
+class SparkConnectService(debug: Boolean) extends AsyncService with BindableService with Logging {
 
   /**
    * This is the main entry method for Spark Connect and all calls to execute a plan.
@@ -164,6 +168,50 @@ class SparkConnectService(debug: Boolean)
         userId = request.getUserContext.getUserId,
         sessionId = request.getSessionId)
   }
+
+  private def methodWithCustomMarshallers(methodDesc: MethodDescriptor[MessageLite, MessageLite])
+      : MethodDescriptor[MessageLite, MessageLite] = {
+    val recursionLimit =
+      SparkEnv.get.conf.get(CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT)
+    val requestMarshaller =
+      ProtoLiteUtils.marshallerWithRecursionLimit(
+        methodDesc.getRequestMarshaller
+          .asInstanceOf[PrototypeMarshaller[MessageLite]]
+          .getMessagePrototype,
+        recursionLimit)
+    val responseMarshaller =
+      ProtoLiteUtils.marshallerWithRecursionLimit(
+        methodDesc.getResponseMarshaller
+          .asInstanceOf[PrototypeMarshaller[MessageLite]]
+          .getMessagePrototype,
+        recursionLimit)
+    methodDesc.toBuilder
+      .setRequestMarshaller(requestMarshaller)
+      .setResponseMarshaller(responseMarshaller)
+      .build()
+  }
+
+  override def bindService(): ServerServiceDefinition = {
+    // First, get the SparkConnectService ServerServiceDefinition.
+    val serviceDef = SparkConnectServiceGrpc.bindService(this)
+
+    // Create a new ServerServiceDefinition builder
+    // using the name of the original service definition.
+    val builder = io.grpc.ServerServiceDefinition.builder(serviceDef.getServiceDescriptor.getName)
+
+    // Iterate through all the methods of the original service definition.
+    // For each method, add a customized method descriptor (with updated marshallers)
+    // and the original server call handler to the builder.
+    serviceDef.getMethods.asScala
+      .asInstanceOf[Iterable[ServerMethodDefinition[MessageLite, MessageLite]]]
+      .foreach(method =>
+        builder.addMethod(
+          methodWithCustomMarshallers(method.getMethodDescriptor),
+          method.getServerCallHandler))
+
+    // Build the final ServerServiceDefinition and return it.
+    builder.build()
+  }
 }
 
 /**
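
The new limit is read server-side from the internal configuration
`spark.connect.grpc.marshallerRecursionLimit` (default 1024). A minimal sketch
of raising it, assuming the conf is applied before the Connect service binds
(the deployment wiring is illustrative):

```scala
import org.apache.spark.SparkConf

// Illustrative only: raise the marshaller recursion limit above the
// 1024 default. The value is read via SparkEnv.get.conf when
// SparkConnectService rebuilds its gRPC method descriptors.
val conf = new SparkConf()
  .set("spark.connect.grpc.marshallerRecursionLimit", "2048")
```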

