This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new 30efbfdcc8c [SPARK-44587][SQL][CONNECT] Increase protobuf marshaller recursion limit 30efbfdcc8c is described below commit 30efbfdcc8c237c536a2320a688675f4e69bb075 Author: Yihong He <yihong...@databricks.com> AuthorDate: Mon Jul 31 09:39:13 2023 +0900 [SPARK-44587][SQL][CONNECT] Increase protobuf marshaller recursion limit. What changes were proposed in this pull request? - Use customized marshallers for Spark Connect gRPC methods. - Increase the Protobuf marshaller recursion limit. Why are the changes needed? Nested DataFrames (DFs) fail easily, because deeply nested plans exceed the default protobuf recursion limit. Does this PR introduce any user-facing change? No. How was this patch tested? `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"` Closes #42212 from heyihong/SPARK-44587-2. Authored-by: Yihong He <yihong...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> (cherry picked from commit 55391f633a43113bdd36b5720f5a5f6d6a9daed8) Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../org/apache/spark/sql/ClientE2ETestSuite.scala | 8 +++ .../apache/spark/sql/connect/config/Connect.scala | 10 ++++ .../sql/connect/service/SparkConnectService.scala | 60 +++++++++++++++++++--- 3 files changed, 72 insertions(+), 6 deletions(-) diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala index 36f47cc1fba..1403d460b51 100644 --- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala +++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala @@ -43,6 +43,14 @@ import org.apache.spark.sql.types._ class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester { + test("spark deep recursion") { + var df = spark.range(1) + for (a <- 1 to 500) { + df = df.union(spark.range(a, a + 1)) + } + assert(df.collect().length == 501) + } + test("many tables") { withSQLConf("spark.sql.execution.arrow.maxRecordsPerBatch" -> "10") { val numTables = 20 diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index 2a805c45392..15288a65a45 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -56,6 +56,16 @@ object Connect { .bytesConf(ByteUnit.BYTE) .createWithDefault(ConnectCommon.CONNECT_GRPC_MAX_MESSAGE_SIZE) + val CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT = + ConfigBuilder("spark.connect.grpc.marshallerRecursionLimit") + .internal() + .doc(""" + |Sets the recursion limit to grpc protobuf messages. + |""".stripMargin) + .version("3.5.0") + .intConf + .createWithDefault(1024) + val CONNECT_EXTENSIONS_RELATION_CLASSES = ConfigBuilder("spark.connect.extensions.relation.classes") .doc(""" diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 121d2accf6b..35a9df82d30 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -20,21 +20,27 @@ package org.apache.spark.sql.connect.service import java.util.UUID import java.util.concurrent.TimeUnit +import scala.jdk.CollectionConverters._ + import com.google.common.base.Ticker import com.google.common.cache.{CacheBuilder, RemovalListener, RemovalNotification} -import io.grpc.Server +import com.google.protobuf.MessageLite +import io.grpc.{BindableService, MethodDescriptor, Server, ServerMethodDefinition, ServerServiceDefinition} +import io.grpc.MethodDescriptor.PrototypeMarshaller import io.grpc.netty.NettyServerBuilder +import io.grpc.protobuf.lite.ProtoLiteUtils 
import io.grpc.protobuf.services.ProtoReflectionService import io.grpc.stub.StreamObserver import org.apache.commons.lang3.StringUtils import org.apache.spark.{SparkContext, SparkEnv, SparkSQLException} import org.apache.spark.connect.proto -import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse} +import org.apache.spark.connect.proto.{AddArtifactsRequest, AddArtifactsResponse, SparkConnectServiceGrpc} +import org.apache.spark.connect.proto.SparkConnectServiceGrpc.AsyncService import org.apache.spark.internal.Logging import org.apache.spark.internal.config.UI.UI_ENABLED import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} +import org.apache.spark.sql.connect.config.Connect.{CONNECT_GRPC_BINDING_PORT, CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT, CONNECT_GRPC_MAX_INBOUND_MESSAGE_SIZE} import org.apache.spark.sql.connect.ui.{SparkConnectServerAppStatusStore, SparkConnectServerListener, SparkConnectServerTab} import org.apache.spark.sql.connect.utils.ErrorUtils import org.apache.spark.status.ElementTrackingStore @@ -47,9 +53,7 @@ import org.apache.spark.status.ElementTrackingStore * @param debug * delegates debug behavior to the handlers. */ -class SparkConnectService(debug: Boolean) - extends proto.SparkConnectServiceGrpc.SparkConnectServiceImplBase - with Logging { +class SparkConnectService(debug: Boolean) extends AsyncService with BindableService with Logging { /** * This is the main entry method for Spark Connect and all calls to execute a plan. 
@@ -163,6 +167,50 @@ class SparkConnectService(debug: Boolean) userId = request.getUserContext.getUserId, sessionId = request.getSessionId) } + + private def methodWithCustomMarshallers(methodDesc: MethodDescriptor[MessageLite, MessageLite]) + : MethodDescriptor[MessageLite, MessageLite] = { + val recursionLimit = + SparkEnv.get.conf.get(CONNECT_GRPC_MARSHALLER_RECURSION_LIMIT) + val requestMarshaller = + ProtoLiteUtils.marshallerWithRecursionLimit( + methodDesc.getRequestMarshaller + .asInstanceOf[PrototypeMarshaller[MessageLite]] + .getMessagePrototype, + recursionLimit) + val responseMarshaller = + ProtoLiteUtils.marshallerWithRecursionLimit( + methodDesc.getResponseMarshaller + .asInstanceOf[PrototypeMarshaller[MessageLite]] + .getMessagePrototype, + recursionLimit) + methodDesc.toBuilder + .setRequestMarshaller(requestMarshaller) + .setResponseMarshaller(responseMarshaller) + .build() + } + + override def bindService(): ServerServiceDefinition = { + // First, get the SparkConnectService ServerServiceDefinition. + val serviceDef = SparkConnectServiceGrpc.bindService(this) + + // Create a new ServerServiceDefinition builder + // using the name of the original service definition. + val builder = io.grpc.ServerServiceDefinition.builder(serviceDef.getServiceDescriptor.getName) + + // Iterate through all the methods of the original service definition. + // For each method, add a customized method descriptor (with updated marshallers) + // and the original server call handler to the builder. + serviceDef.getMethods.asScala + .asInstanceOf[Iterable[ServerMethodDefinition[MessageLite, MessageLite]]] + .foreach(method => + builder.addMethod( + methodWithCustomMarshallers(method.getMethodDescriptor), + method.getServerCallHandler)) + + // Build the final ServerServiceDefinition and return it. 
+ builder.build() + } } /** --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org