This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 4863be5632f [SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client
4863be5632f is described below

commit 4863be5632f3165a5699a525235ea118c1e1f7eb
Author: Yihong He <yihong...@databricks.com>
AuthorDate: Mon Sep 25 09:35:33 2023 +0900

    [SPARK-45207][SQL][CONNECT] Implement Error Enrichment for Scala Client

    ### What changes were proposed in this pull request?

    - Implemented the reconstruction of the complete exception (un-truncated error messages,
      cause exceptions, server-side stacktraces) based on the responses of the FetchErrorDetails RPC.

    ### Why are the changes needed?

    - Cause exceptions play an important role in the current control flow, such as in
      StreamingQueryException. They are also valuable for debugging.
    - Un-truncated error messages are useful for debugging.
    - Providing server-side stack traces aids in effectively diagnosing server-related issues.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    - `build/sbt "connect-client-jvm/testOnly *ClientE2ETestSuite"`
    - `build/sbt "connect-client-jvm/testOnly *ClientStreamingQuerySuite"`

    ### Was this patch authored or co-authored using generative AI tooling?

    No

    Closes #42987 from heyihong/SPARK-45207.

    Authored-by: Yihong He <yihong...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |  59 ++++++-
 .../sql/streaming/ClientStreamingQuerySuite.scala  |  41 ++++-
 .../client/CustomSparkConnectBlockingStub.scala    |  44 ++++-
 .../connect/client/GrpcExceptionConverter.scala    | 192 +++++++++++++++++----
 4 files changed, 292 insertions(+), 44 deletions(-)
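For context before the diff: the sketch below illustrates the user-visible effect of this patch, based on the config keys and the failing query used in the new tests. It is illustrative only; the remote address and the SparkSession setup are assumptions, not part of the patch.

    import org.apache.spark.SparkUpgradeException
    import org.apache.spark.sql.SparkSession

    // Assumption: a Spark Connect server is reachable at this address.
    val spark = SparkSession.builder().remote("sc://localhost").getOrCreate()

    // Reconstruct full error messages and cause chains via FetchErrorDetails.
    spark.conf.set("spark.sql.connect.enrichError.enabled", "true")
    // Additionally copy the server-side stack trace onto client exceptions.
    spark.conf.set("spark.sql.connect.serverStacktrace.enabled", "true")

    try {
      // Same failing query as the new "cause exception" test below: parsing fails
      // server-side and surfaces as a SparkUpgradeException.
      spark
        .sql("""select from_json('{"d": "02-29"}', 'd date', map('dateFormat', 'MM-dd'))""")
        .collect()
    } catch {
      case e: SparkUpgradeException =>
        // With enrichment on, the server-side cause is reconstructed;
        // with it off, getCause would be null (see the test below).
        println(e.getCause) // a java.time.DateTimeException
    }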
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index 21892542eab..ec9b1698a4e 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.sql
 
 import java.io.{ByteArrayOutputStream, PrintStream}
 import java.nio.file.Files
+import java.time.DateTimeException
 import java.util.Properties
 
 import scala.collection.JavaConverters._
@@ -29,7 +30,7 @@ import org.apache.commons.lang3.{JavaVersion, SystemUtils}
 import org.scalactic.TolerantNumerics
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkArithmeticException, SparkException}
+import org.apache.spark.{SparkArithmeticException, SparkException, SparkUpgradeException}
 import org.apache.spark.SparkBuildInfo.{spark_version => SPARK_VERSION}
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.StringEncoder
@@ -44,6 +45,62 @@ import org.apache.spark.sql.types._
 
 class ClientE2ETestSuite extends RemoteSparkSession with SQLHelper with PrivateMethodTester {
 
+  for (enrichErrorEnabled <- Seq(false, true)) {
+    test(s"cause exception - ${enrichErrorEnabled}") {
+      withSQLConf("spark.sql.connect.enrichError.enabled" -> enrichErrorEnabled.toString) {
+        val ex = intercept[SparkUpgradeException] {
+          spark
+            .sql("""
+              |select from_json(
+              |  '{"d": "02-29"}',
+              |  'd date',
+              |  map('dateFormat', 'MM-dd'))
+              |""".stripMargin)
+            .collect()
+        }
+        if (enrichErrorEnabled) {
+          assert(ex.getCause.isInstanceOf[DateTimeException])
+        } else {
+          assert(ex.getCause == null)
+        }
+      }
+    }
+  }
+
+  test(s"throw SparkException with large cause exception") {
+    withSQLConf("spark.sql.connect.enrichError.enabled" -> "true") {
+      val session = spark
+      import session.implicits._
+
+      val throwException =
+        udf((_: String) => throw new SparkException("test" * 10000))
+
+      val ex = intercept[SparkException] {
+        Seq("1").toDS.withColumn("udf_val", throwException($"value")).collect()
+      }
+
+      assert(ex.getCause.isInstanceOf[SparkException])
+      assert(ex.getCause.getMessage.contains("test" * 10000))
+    }
+  }
+
+  for (isServerStackTraceEnabled <- Seq(false, true)) {
+    test(s"server-side stack trace is set in exceptions - ${isServerStackTraceEnabled}") {
+      withSQLConf(
+        "spark.sql.connect.serverStacktrace.enabled" -> isServerStackTraceEnabled.toString,
+        "spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+        val ex = intercept[AnalysisException] {
+          spark.sql("select x").collect()
+        }
+        assert(
+          ex.getStackTrace
+            .find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
+            .isDefined
+            == isServerStackTraceEnabled)
+      }
+    }
+  }
+
   test("throw SparkArithmeticException") {
     withSQLConf("spark.sql.ansi.enabled" -> "true") {
       intercept[SparkArithmeticException] {
diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
index dc4d441ec30..5d281cfbfeb 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/streaming/ClientStreamingQuerySuite.scala
@@ -27,11 +27,11 @@ import org.scalatest.concurrent.Eventually.eventually
 import org.scalatest.concurrent.Futures.timeout
 import org.scalatest.time.SpanSugar._
 
+import org.apache.spark.SparkException
 import org.apache.spark.api.java.function.VoidFunction2
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.{DataFrame, ForeachWriter, Row, SparkSession}
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.functions.window
+import org.apache.spark.sql.functions.{col, udf, window}
 import org.apache.spark.sql.streaming.StreamingQueryListener.{QueryIdleEvent, QueryStartedEvent, QueryTerminatedEvent}
 import org.apache.spark.sql.test.{QueryTest, SQLHelper}
 import org.apache.spark.util.SparkFileUtils
@@ -175,6 +175,43 @@ class ClientStreamingQuerySuite extends QueryTest with SQLHelper with Logging {
     }
   }
 
+  test("throw exception in streaming") {
+    // Disable spark.sql.pyspark.jvmStacktrace.enabled to avoid hitting the
+    // netty header limit.
+    withSQLConf("spark.sql.pyspark.jvmStacktrace.enabled" -> "false") {
+      val session = spark
+      import session.implicits._
+
+      val checkForTwo = udf((value: Int) => {
+        if (value == 2) {
+          throw new RuntimeException("Number 2 encountered!")
+        }
+        value
+      })
+
+      val query = spark.readStream
+        .format("rate")
+        .option("rowsPerSecond", "1")
+        .load()
+        .select(checkForTwo($"value").as("checkedValue"))
+        .writeStream
+        .outputMode("append")
+        .format("console")
+        .start()
+
+      val exception = intercept[SparkException] {
+        query.awaitTermination()
+      }
+
+      assert(exception.getCause.isInstanceOf[SparkException])
+      assert(exception.getCause.getCause.isInstanceOf[SparkException])
+      assert(exception.getCause.getCause.getCause.isInstanceOf[SparkException])
+      assert(
+        exception.getCause.getCause.getCause.getMessage
+          .contains("java.lang.RuntimeException: Number 2 encountered!"))
+    }
+  }
+
   test("foreach Row") {
     val writer = new TestForeachWriter[Row]
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
index 80edcfa8be1..f02704b2a02 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/CustomSparkConnectBlockingStub.scala
@@ -27,11 +27,21 @@ private[connect] class CustomSparkConnectBlockingStub(
     retryPolicy: GrpcRetryHandler.RetryPolicy) {
 
   private val stub = SparkConnectServiceGrpc.newBlockingStub(channel)
+
   private val retryHandler = new GrpcRetryHandler(retryPolicy)
 
+  // GrpcExceptionConverter with a GRPC stub for fetching error details from server.
+  private val grpcExceptionConverter = new GrpcExceptionConverter(stub)
+
   def executePlan(request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
-    GrpcExceptionConverter.convert {
-      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+        request.getSessionId,
+        request.getUserContext,
+        request.getClientType,
         retryHandler.RetryIterator[ExecutePlanRequest, ExecutePlanResponse](
           request,
           r => CloseableIterator(stub.executePlan(r).asScala)))
@@ -40,15 +50,24 @@ private[connect] class CustomSparkConnectBlockingStub(
 
   def executePlanReattachable(
       request: ExecutePlanRequest): CloseableIterator[ExecutePlanResponse] = {
-    GrpcExceptionConverter.convert {
-      GrpcExceptionConverter.convertIterator[ExecutePlanResponse](
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
+      grpcExceptionConverter.convertIterator[ExecutePlanResponse](
+        request.getSessionId,
+        request.getUserContext,
+        request.getClientType,
         // Don't use retryHandler - own retry handling is inside.
         new ExecutePlanResponseReattachableIterator(request, channel, retryPolicy))
     }
   }
 
   def analyzePlan(request: AnalyzePlanRequest): AnalyzePlanResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.analyzePlan(request)
       }
@@ -56,7 +75,10 @@ private[connect] class CustomSparkConnectBlockingStub(
   }
 
   def config(request: ConfigRequest): ConfigResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.config(request)
       }
@@ -64,7 +86,10 @@
   }
 
   def interrupt(request: InterruptRequest): InterruptResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.interrupt(request)
       }
@@ -72,7 +97,10 @@
   }
 
   def artifactStatus(request: ArtifactStatusesRequest): ArtifactStatusesResponse = {
-    GrpcExceptionConverter.convert {
+    grpcExceptionConverter.convert(
+      request.getSessionId,
+      request.getUserContext,
+      request.getClientType) {
       retryHandler.retry {
         stub.artifactStatus(request)
       }
diff --git a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
index fe9f6dc2b4a..edbc434ef96 100644
--- a/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
+++ b/connector/connect/common/src/main/scala/org/apache/spark/sql/connect/client/GrpcExceptionConverter.scala
@@ -24,49 +24,145 @@ import scala.reflect.ClassTag
 import com.google.rpc.ErrorInfo
 import io.grpc.StatusRuntimeException
 import io.grpc.protobuf.StatusProto
+import org.json4s.DefaultFormats
+import org.json4s.jackson.JsonMethods
 
 import org.apache.spark.{SparkArithmeticException, SparkArrayIndexOutOfBoundsException, SparkDateTimeException, SparkException, SparkIllegalArgumentException, SparkNumberFormatException, SparkRuntimeException, SparkUnsupportedOperationException, SparkUpgradeException}
+import org.apache.spark.connect.proto.{FetchErrorDetailsRequest, FetchErrorDetailsResponse, UserContext}
+import org.apache.spark.connect.proto.SparkConnectServiceGrpc.SparkConnectServiceBlockingStub
+import org.apache.spark.internal.Logging
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, NoSuchDatabaseException, NoSuchTableException, TableAlreadyExistsException, TempTableAlreadyExistsException}
 import org.apache.spark.sql.catalyst.parser.ParseException
 import org.apache.spark.sql.catalyst.trees.Origin
-import org.apache.spark.util.JsonUtils
 
-private[client] object GrpcExceptionConverter extends JsonUtils {
-  def convert[T](f: => T): T = {
+/**
+ * GrpcExceptionConverter handles the conversion of StatusRuntimeExceptions into Spark exceptions.
+ * It does so by utilizing the ErrorInfo defined in error_details.proto and making an additional
+ * FetchErrorDetails RPC call to retrieve the full error message and optionally the server-side
+ * stacktrace.
+ *
+ * If the FetchErrorDetails RPC call succeeds, the exceptions will be constructed based on the
+ * response.
+ * If the RPC call fails, the exception will be constructed based on the ErrorInfo.
+ * If the ErrorInfo is missing, the exception will be constructed based on the
+ * StatusRuntimeException itself.
+ */
+private[client] class GrpcExceptionConverter(grpcStub: SparkConnectServiceBlockingStub)
+    extends Logging {
+  import GrpcExceptionConverter._
+
+  def convert[T](sessionId: String, userContext: UserContext, clientType: String)(f: => T): T = {
     try {
       f
     } catch {
       case e: StatusRuntimeException =>
-        throw toThrowable(e)
+        throw toThrowable(e, sessionId, userContext, clientType)
     }
   }
 
-  def convertIterator[T](iter: CloseableIterator[T]): CloseableIterator[T] = {
+  def convertIterator[T](
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String,
+      iter: CloseableIterator[T]): CloseableIterator[T] = {
     new WrappedCloseableIterator[T] {
 
       override def innerIterator: Iterator[T] = iter
 
       override def hasNext: Boolean = {
-        convert {
+        convert(sessionId, userContext, clientType) {
           iter.hasNext
         }
       }
 
      override def next(): T = {
-        convert {
+        convert(sessionId, userContext, clientType) {
          iter.next()
        }
      }

      override def close(): Unit = {
-        convert {
+        convert(sessionId, userContext, clientType) {
          iter.close()
        }
      }
    }
  }

+  /**
+   * Fetches enriched errors with full exception message and optionally stacktrace by issuing an
+   * additional RPC call to fetch error details. The RPC call is best-effort at-most-once.
+   */
+  private def fetchEnrichedError(
+      info: ErrorInfo,
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String): Option[Throwable] = {
+    val errorId = info.getMetadataOrDefault("errorId", null)
+    if (errorId == null) {
+      logWarning("Unable to fetch enriched error since errorId is missing")
+      return None
+    }
+
+    try {
+      val errorDetailsResponse = grpcStub.fetchErrorDetails(
+        FetchErrorDetailsRequest
+          .newBuilder()
+          .setSessionId(sessionId)
+          .setErrorId(errorId)
+          .setUserContext(userContext)
+          .setClientType(clientType)
+          .build())
+
+      if (!errorDetailsResponse.hasRootErrorIdx) {
+        logWarning("Unable to fetch enriched error since error is not found")
+        return None
+      }
+
+      Some(
+        errorsToThrowable(
+          errorDetailsResponse.getRootErrorIdx,
+          errorDetailsResponse.getErrorsList.asScala.toSeq))
+    } catch {
+      case e: StatusRuntimeException =>
+        logWarning("Unable to fetch enriched error", e)
+        None
+    }
+  }
+
+  private def toThrowable(
+      ex: StatusRuntimeException,
+      sessionId: String,
+      userContext: UserContext,
+      clientType: String): Throwable = {
+    val status = StatusProto.fromThrowable(ex)
+
+    // Extract the ErrorInfo from the StatusProto, if present.
+    val errorInfoOpt = status.getDetailsList.asScala
+      .find(_.is(classOf[ErrorInfo]))
+      .map(_.unpack(classOf[ErrorInfo]))
+
+    if (errorInfoOpt.isDefined) {
+      // If ErrorInfo is found, try to fetch enriched error details by an additional RPC.
+      val enrichedErrorOpt =
+        fetchEnrichedError(errorInfoOpt.get, sessionId, userContext, clientType)
+      if (enrichedErrorOpt.isDefined) {
+        return enrichedErrorOpt.get
+      }
+
+      // If fetching enriched error details fails, convert ErrorInfo to a Throwable.
+      // Unlike enriched errors above, the message from status may be truncated,
+      // and no cause exceptions or server-side stack traces will be reconstructed.
+      return errorInfoToThrowable(errorInfoOpt.get, status.getMessage)
+    }
+
+    // If no ErrorInfo is found, create a SparkException based on the StatusRuntimeException.
+    new SparkException(ex.toString, ex.getCause)
+  }
+}
+
+private object GrpcExceptionConverter {
+
   private def errorConstructor[T <: Throwable: ClassTag](
       throwableCtr: (String, Option[Throwable]) => T)
     : (String, (String, Option[Throwable]) => Throwable) = {
@@ -93,33 +189,63 @@ private[client] object GrpcExceptionConverter extends JsonUtils {
       new SparkArrayIndexOutOfBoundsException(message)),
     errorConstructor[DateTimeException]((message, _) => new SparkDateTimeException(message)),
     errorConstructor((message, cause) => new SparkRuntimeException(message, cause)),
-    errorConstructor((message, cause) => new SparkUpgradeException(message, cause)))
-
-  private def errorInfoToThrowable(info: ErrorInfo, message: String): Option[Throwable] = {
-    val classes =
-      mapper.readValue(info.getMetadataOrDefault("classes", "[]"), classOf[Array[String]])
+    errorConstructor((message, cause) => new SparkUpgradeException(message, cause)),
+    errorConstructor((message, cause) => new SparkException(message, cause.orNull)))
+
+  /**
+   * errorsToThrowable reconstructs the exception based on a list of protobuf messages
+   * FetchErrorDetailsResponse.Error with un-truncated error messages and server-side stacktrace
+   * (if set).
+   */
+  private def errorsToThrowable(
+      errorIdx: Int,
+      errors: Seq[FetchErrorDetailsResponse.Error]): Throwable = {
+
+    val error = errors(errorIdx)
+
+    val classHierarchy = error.getErrorTypeHierarchyList.asScala
+
+    val constructor =
+      classHierarchy
+        .flatMap(errorFactory.get)
+        .headOption
+        .getOrElse((message: String, cause: Option[Throwable]) =>
+          new SparkException(s"${classHierarchy.head}: ${message}", cause.orNull))
+
+    val causeOpt =
+      if (error.hasCauseIdx) Some(errorsToThrowable(error.getCauseIdx, errors)) else None
+
+    val exception = constructor(error.getMessage, causeOpt)
+
+    if (!error.getStackTraceList.isEmpty) {
+      exception.setStackTrace(error.getStackTraceList.asScala.toArray.map { stackTraceElement =>
+        new StackTraceElement(
+          stackTraceElement.getDeclaringClass,
+          stackTraceElement.getMethodName,
+          stackTraceElement.getFileName,
+          stackTraceElement.getLineNumber)
+      })
+    }
 
-    classes
-      .find(errorFactory.contains)
-      .map { cls =>
-        val constructor = errorFactory.get(cls).get
-        constructor(message, None)
-      }
+    exception
   }
 
-  private def toThrowable(ex: StatusRuntimeException): Throwable = {
-    val status = StatusProto.fromThrowable(ex)
-
-    val fallbackEx = new SparkException(ex.toString, ex.getCause)
-
-    val errorInfoOpt = status.getDetailsList.asScala
-      .find(_.is(classOf[ErrorInfo]))
-
-    if (errorInfoOpt.isEmpty) {
-      return fallbackEx
-    }
-
-    errorInfoToThrowable(errorInfoOpt.get.unpack(classOf[ErrorInfo]), status.getMessage)
-      .getOrElse(fallbackEx)
+  /**
+   * errorInfoToThrowable reconstructs the exception based on the error classes hierarchy and the
+   * truncated error message.
+   */
+  private def errorInfoToThrowable(info: ErrorInfo, message: String): Throwable = {
+    implicit val formats = DefaultFormats
+    val classes =
+      JsonMethods.parse(info.getMetadataOrDefault("classes", "[]")).extract[Array[String]]
+
+    errorsToThrowable(
+      0,
+      Seq(
+        FetchErrorDetailsResponse.Error
+          .newBuilder()
+          .setMessage(message)
+          .addAllErrorTypeHierarchy(classes.toIterable.asJava)
+          .build()))
   }
 }
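To make the causeIdx recursion in errorsToThrowable concrete, here is a hand-built FetchErrorDetailsResponse with two linked errors and the exception chain the patch would reconstruct from it. This is an illustrative sketch, not part of the patch; the messages are invented.

    import org.apache.spark.connect.proto.FetchErrorDetailsResponse

    // errors(1) is the cause of errors(0); rootErrorIdx points at errors(0).
    val cause = FetchErrorDetailsResponse.Error
      .newBuilder()
      .setMessage("Number 2 encountered!")
      .addErrorTypeHierarchy("java.lang.RuntimeException")
      .build()

    val root = FetchErrorDetailsResponse.Error
      .newBuilder()
      .setMessage("Job aborted due to stage failure")
      .addErrorTypeHierarchy("org.apache.spark.SparkException")
      .setCauseIdx(1)
      .build()

    val response = FetchErrorDetailsResponse
      .newBuilder()
      .setRootErrorIdx(0)
      .addErrors(root)
      .addErrors(cause)
      .build()

    // errorsToThrowable(response.getRootErrorIdx, response.getErrorsList.asScala.toSeq)
    // yields SparkException("Job aborted due to stage failure") whose getCause is
    // SparkException("java.lang.RuntimeException: Number 2 encountered!"):
    // SparkException is registered in errorFactory, while RuntimeException is not,
    // so the cause falls back to a SparkException prefixed with the original class name.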

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org