This is an automated email from the ASF dual-hosted git repository. gurwls223 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 637cf4f4ab8 [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path 637cf4f4ab8 is described below commit 637cf4f4ab84708e58f7265b8dea928e1964a95f Author: Herman van Hovell <her...@databricks.com> AuthorDate: Fri Nov 18 09:43:48 2022 +0900 [SPARK-41165][CONNECT] Avoid hangs in the arrow collect code path ### What changes were proposed in this pull request? Two changes: 1. Make sure connect's arrow result path properly deals with errors, and avoids hangs. 2. Fix a common source of non-serializable exceptions in `SparkConnectStreamHandler`. ### Why are the changes needed? The current Arrow result code path for connect assumes no error can happen during execution. As a result it will hang when an error occurs. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a UT. Closes #38681 from hvanhovell/SPARK-41165. Authored-by: Herman van Hovell <her...@databricks.com> Signed-off-by: Hyukjin Kwon <gurwls...@apache.org> --- .../service/SparkConnectStreamHandler.scala | 55 ++++++++++++++++++---- .../connect/planner/SparkConnectServiceSuite.scala | 48 +++++++++++++++++++ 2 files changed, 94 insertions(+), 9 deletions(-) diff --git a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala index ec2db3efa96..a780858d55c 100644 --- a/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala +++ b/connector/connect/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectStreamHandler.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.connect.service import scala.collection.JavaConverters._ +import scala.util.control.NonFatal import com.google.protobuf.ByteString import io.grpc.stub.StreamObserver @@ -27,13 +28,15 @@ import org.apache.spark.connect.proto import org.apache.spark.connect.proto.{Request, Response} import org.apache.spark.internal.Logging import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{SparkPlan, SQLExecution} import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec} import org.apache.spark.sql.execution.arrow.ArrowConverters +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.ThreadUtils class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) extends Logging { - // The maximum batch size in bytes for a single batch of data to be returned via proto. private val MAX_BATCH_SIZE: Long = 4 * 1024 * 1024 @@ -139,14 +142,13 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte if (numPartitions > 0) { type Batch = (Array[Byte], Long) - val batches = rows.mapPartitionsInternal { iter => - val newIter = ArrowConverters - .toBatchWithSchemaIterator(iter, schema, maxRecordsPerBatch, maxBatchSize, timeZoneId) - newIter.map { batch: Array[Byte] => (batch, newIter.rowCountInLastBatch) } - } + val batches = rows.mapPartitionsInternal( + SparkConnectStreamHandler + .rowToArrowConverter(schema, maxRecordsPerBatch, maxBatchSize, timeZoneId)) val signal = new Object val partitions = collection.mutable.Map.empty[Int, Array[Batch]] + var error: Throwable = null val processPartition = (iter: Iterator[Batch]) => iter.toArray @@ -161,13 +163,23 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte () } - spark.sparkContext.submitJob( + val future = spark.sparkContext.submitJob( rdd = batches, processPartition = processPartition, partitions = Seq.range(0, numPartitions), resultHandler = resultHandler, resultFunc = () => ()) + // Collect errors and propagate them to the main thread. + future.onComplete { result => + result.failed.foreach { throwable => + signal.synchronized { + error = throwable + signal.notify() + } + } + }(ThreadUtils.sameThread) + // The main thread will wait until 0-th partition is available, // then send it to client and wait for the next partition. // Different from the implementation of [[Dataset#collectAsArrowToPython]], it sends @@ -178,11 +190,18 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte while (currentPartitionId < numPartitions) { val partition = signal.synchronized { var result = partitions.remove(currentPartitionId) - while (result.isEmpty) { + while (result.isEmpty && error == null) { signal.wait() result = partitions.remove(currentPartitionId) } - result.get + error match { + case NonFatal(e) => + responseObserver.onError(error) + logError("Error while processing query.", e) + return + case fatal: Throwable => throw fatal + case null => result.get + } } partition.foreach { case (bytes, count) => @@ -236,6 +255,24 @@ class SparkConnectStreamHandler(responseObserver: StreamObserver[Response]) exte } } +object SparkConnectStreamHandler { + type Batch = (Array[Byte], Long) + + private[service] def rowToArrowConverter( + schema: StructType, + maxRecordsPerBatch: Int, + maxBatchSize: Long, + timeZoneId: String): Iterator[InternalRow] => Iterator[Batch] = { rows => + val batches = ArrowConverters.toBatchWithSchemaIterator( + rows, + schema, + maxRecordsPerBatch, + maxBatchSize, + timeZoneId) + batches.map(b => b -> batches.rowCountInLastBatch) + } +} + object MetricGenerator extends AdaptiveSparkPlanHelper { def buildMetrics(p: SparkPlan): Response.Metrics = { val b = Response.Metrics.newBuilder diff --git a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala index 4be8d1705b9..7ff3a823fa1 100644 --- a/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala +++ b/connector/connect/src/test/scala/org/apache/spark/sql/connect/planner/SparkConnectServiceSuite.scala @@ -16,9 +16,18 @@ */ package org.apache.spark.sql.connect.planner +import scala.concurrent.Promise +import scala.concurrent.duration._ + +import io.grpc.stub.StreamObserver + +import org.apache.spark.SparkException import org.apache.spark.connect.proto +import org.apache.spark.sql.connect.dsl.MockRemoteSession +import org.apache.spark.sql.connect.dsl.plans._ import org.apache.spark.sql.connect.service.SparkConnectService import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ThreadUtils /** * Testing Connect Service implementation. @@ -55,4 +64,43 @@ class SparkConnectServiceSuite extends SharedSparkSession { && schema.getFields(1).getType.getKindCase == proto.DataType.KindCase.STRING) } } + + test("SPARK-41165: failures in the arrow collect path should not cause hangs") { + val instance = new SparkConnectService(false) + + // Add an always crashing UDF + val session = SparkConnectService.getOrCreateIsolatedSession("c1").session + val instaKill: Long => Long = { _ => + throw new Exception("Kaboom") + } + session.udf.register("insta_kill", instaKill) + + val connect = new MockRemoteSession() + val context = proto.Request.UserContext + .newBuilder() + .setUserId("c1") + .build() + val plan = proto.Plan + .newBuilder() + .setRoot(connect.sql("select insta_kill(id) from range(10)")) + .build() + val request = proto.Request + .newBuilder() + .setPlan(plan) + .setUserContext(context) + .build() + + val promise = Promise[Seq[proto.Response]] + instance.executePlan( + request, + new StreamObserver[proto.Response] { + private val responses = Seq.newBuilder[proto.Response] + override def onNext(v: proto.Response): Unit = responses += v + override def onError(throwable: Throwable): Unit = promise.failure(throwable) + override def onCompleted(): Unit = promise.success(responses.result()) + }) + intercept[SparkException] { + ThreadUtils.awaitResult(promise.future, 2.seconds) + } + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org