This is an automated email from the ASF dual-hosted git repository. hvanhovell pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 48b1a283a2eb [SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails RPC 48b1a283a2eb is described below commit 48b1a283a2eba9f70149d5980d074fad2743c4ff Author: Yihong He <yihong...@databricks.com> AuthorDate: Wed Sep 20 00:14:44 2023 -0400 [SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails RPC ### What changes were proposed in this pull request? - Introduced the FetchErrorDetails RPC to retrieve comprehensive error details. FetchErrorDetails enriches an error by issuing a separate RPC keyed on the `errorId` field in the ErrorInfo. - Introduced error enrichment, which uses an additional RPC to fetch untruncated exception messages and server-side stack traces. This enrichment is controlled by the flag `spark.sql.connect.enrichError.enabled`, which defaults to true. - Implemented attaching server-side stack traces to exceptions on the client side via the FetchErrorDetails RPC, for debugging. This feature is controlled by the flag `spark.sql.connect.serverStacktrace.enabled`, which defaults to true. ### Why are the changes needed? - Attaching full exception messages to the error details protobuf can quickly hit the 8 KB gRPC Netty header limit. Fetching the full error information through a separate RPC is more reliable. - Providing server-side stack traces helps diagnose server-side issues. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? - `build/sbt "connect/testOnly *FetchErrorDetailsHandlerSuite"` ### Was this patch authored or co-authored using generative AI tooling? No Closes #42377 from heyihong/SPARK-44622. 
Authored-by: Yihong He <yihong...@databricks.com> Signed-off-by: Herman van Hovell <her...@databricks.com> --- .../src/main/protobuf/spark/connect/base.proto | 57 +++++++ .../apache/spark/sql/connect/config/Connect.scala | 18 +++ .../spark/sql/connect/service/SessionHolder.scala | 21 ++- .../SparkConnectFetchErrorDetailsHandler.scala | 59 +++++++ .../sql/connect/service/SparkConnectService.scala | 14 ++ .../spark/sql/connect/utils/ErrorUtils.scala | 103 ++++++++++-- .../service/FetchErrorDetailsHandlerSuite.scala | 166 +++++++++++++++++++ python/pyspark/sql/connect/proto/base_pb2.py | 14 +- python/pyspark/sql/connect/proto/base_pb2.pyi | 180 +++++++++++++++++++++ python/pyspark/sql/connect/proto/base_pb2_grpc.py | 45 ++++++ 10 files changed, 659 insertions(+), 18 deletions(-) diff --git a/connector/connect/common/src/main/protobuf/spark/connect/base.proto b/connector/connect/common/src/main/protobuf/spark/connect/base.proto index 65e2493f8368..cf1355f7ebc1 100644 --- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto +++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto @@ -778,6 +778,60 @@ message ReleaseExecuteResponse { optional string operation_id = 2; } +message FetchErrorDetailsRequest { + + // (Required) + // The session_id specifies a Spark session for a user identified by user_context.user_id. + // The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`. + string session_id = 1; + + // User context + UserContext user_context = 2; + + // (Required) + // The id of the error. + string error_id = 3; +} + +message FetchErrorDetailsResponse { + + message StackTraceElement { + // The fully qualified name of the class containing the execution point. + string declaring_class = 1; + + // The name of the method containing the execution point. + string method_name = 2; + + // The name of the file containing the execution point. 
+ string file_name = 3; + + // The line number of the source line containing the execution point. + int32 line_number = 4; + } + + // Error defines the schema for the representing exception. + message Error { + // The fully qualified names of the exception class and its parent classes. + repeated string error_type_hierarchy = 1; + + // The detailed message of the exception. + string message = 2; + + // The stackTrace of the exception. It will be set + // if the SQLConf spark.sql.connect.serverStacktrace.enabled is true. + repeated StackTraceElement stack_trace = 3; + + // The index of the cause error in errors. + optional int32 cause_idx = 4; + } + + // The index of the root error in errors. The field will not be set if the error is not found. + optional int32 root_error_idx = 1; + + // A list of errors. + repeated Error errors = 2; +} + // Main interface for the SparkConnect service. service SparkConnectService { @@ -813,5 +867,8 @@ service SparkConnectService { // Non reattachable executions are released automatically and immediately after the ExecutePlan // RPC and ReleaseExecute may not be used. rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {} + + // FetchErrorDetails retrieves the matched exception with details based on a provided error id. 
+ rpc FetchErrorDetails(FetchErrorDetailsRequest) returns (FetchErrorDetailsResponse) {} } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala index dfd6008ac09a..248444e710d2 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala @@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.internal.SQLConf.buildConf object Connect { import org.apache.spark.sql.internal.SQLConf.buildStaticConf @@ -213,4 +214,21 @@ object Connect { .version("3.5.0") .intConf .createWithDefault(200) + + val CONNECT_ENRICH_ERROR_ENABLED = + buildConf("spark.sql.connect.enrichError.enabled") + .doc(""" + |When true, it enriches errors with full exception messages and optionally server-side + |stacktrace on the client side via an additional RPC. 
+ |""".stripMargin) + .version("4.0.0") + .booleanConf + .createWithDefault(true) + + val CONNECT_SERVER_STACKTRACE_ENABLED = + buildConf("spark.sql.connect.serverStacktrace.enabled") + .doc("When true, it sets the server-side stacktrace in the user-facing Spark exception.") + .version("4.0.0") + .booleanConf + .createWithDefault(true) } diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 1cef02d7e346..0748cd237bf0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -19,11 +19,14 @@ package org.apache.spark.sql.connect.service import java.nio.file.Path import java.util.UUID -import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit} import scala.collection.JavaConverters._ import scala.collection.mutable +import com.google.common.base.Ticker +import com.google.common.cache.CacheBuilder + import org.apache.spark.{JobArtifactSet, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql.DataFrame @@ -32,6 +35,7 @@ import org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper +import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} import org.apache.spark.sql.streaming.StreamingQueryListener import org.apache.spark.util.SystemClock import org.apache.spark.util.Utils @@ -45,6 +49,15 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio private val executions: 
ConcurrentMap[String, ExecuteHolder] = new ConcurrentHashMap[String, ExecuteHolder]() + // The cache that maps an error id to a throwable. The throwable in cache is independent to + // each other. + private[connect] val errorIdToError = CacheBuilder + .newBuilder() + .ticker(Ticker.systemTicker()) + .maximumSize(ERROR_CACHE_SIZE) + .expireAfterAccess(ERROR_CACHE_TIMEOUT_SEC, TimeUnit.SECONDS) + .build[String, Throwable]() + val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock()) // Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like @@ -265,6 +278,12 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio object SessionHolder { + // The maximum number of distinct errors in the cache. + private val ERROR_CACHE_SIZE = 20 + + // The maximum time for an error to stay in the cache. + private val ERROR_CACHE_TIMEOUT_SEC = 60 + /** Creates a dummy session holder for use in tests. */ def forTesting(session: SparkSession): SessionHolder = { val ret = diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala new file mode 100644 index 000000000000..17a6e9e434f3 --- /dev/null +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala @@ -0,0 +1,59 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.service + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.FetchErrorDetailsResponse +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.utils.ErrorUtils +import org.apache.spark.sql.internal.SQLConf + +/** + * Handles [[proto.FetchErrorDetailsRequest]]s for the [[SparkConnectService]]. The handler + * retrieves the matched error with details from the cache based on a provided error id. + * + * @param responseObserver + */ +class SparkConnectFetchErrorDetailsHandler( + responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]) { + + def handle(v: proto.FetchErrorDetailsRequest): Unit = { + val sessionHolder = + SparkConnectService + .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId) + + val response = Option(sessionHolder.errorIdToError.getIfPresent(v.getErrorId)) + .map { error => + // This error can only be fetched once, + // if a connection dies in the middle you cannot repeat. 
+ sessionHolder.errorIdToError.invalidate(v.getErrorId) + + ErrorUtils.throwableToFetchErrorDetailsResponse( + st = error, + serverStackTraceEnabled = sessionHolder.session.conf.get( + Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || sessionHolder.session.conf.get( + SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)) + } + .getOrElse(FetchErrorDetailsResponse.newBuilder().build()) + + responseObserver.onNext(response) + + responseObserver.onCompleted() + } +} diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala index 269e47609dbf..e82c9cba5626 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala +++ b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala @@ -201,6 +201,20 @@ class SparkConnectService(debug: Boolean) extends AsyncService with BindableServ sessionId = request.getSessionId) } + override def fetchErrorDetails( + request: proto.FetchErrorDetailsRequest, + responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit = { + try { + new SparkConnectFetchErrorDetailsHandler(responseObserver).handle(request) + } catch { + ErrorUtils.handleError( + "getErrorInfo", + observer = responseObserver, + userId = request.getUserContext.getUserId, + sessionId = request.getSessionId) + } + } + private def methodWithCustomMarshallers(methodDesc: MethodDescriptor[MessageLite, MessageLite]) : MethodDescriptor[MessageLite, MessageLite] = { val recursionLimit = diff --git a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala index 2050ebc01aa0..1abd44608cd0 100644 --- a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala +++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala @@ -17,8 +17,12 @@ package org.apache.spark.sql.connect.utils +import java.util.UUID + import scala.annotation.tailrec +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer +import scala.jdk.CollectionConverters._ import scala.util.control.NonFatal import com.google.protobuf.{Any => ProtoAny} @@ -33,13 +37,14 @@ import org.json4s.jackson.JsonMethods import org.apache.spark.{SparkEnv, SparkException, SparkThrowable} import org.apache.spark.api.python.PythonException +import org.apache.spark.connect.proto.FetchErrorDetailsResponse import org.apache.spark.internal.Logging import org.apache.spark.sql.connect.config.Connect -import org.apache.spark.sql.connect.service.ExecuteEventsManager -import org.apache.spark.sql.connect.service.SparkConnectService +import org.apache.spark.sql.connect.service.{ExecuteEventsManager, SessionHolder, SparkConnectService} import org.apache.spark.sql.internal.SQLConf private[connect] object ErrorUtils extends Logging { + private def allClasses(cl: Class[_]): Seq[Class[_]] = { val classes = ArrayBuffer.empty[Class[_]] if (cl != null && !cl.equals(classOf[java.lang.Object])) { @@ -57,7 +62,67 @@ private[connect] object ErrorUtils extends Logging { classes.toSeq } - private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: Boolean): RPCStatus = { + // The maximum length of the error chain. + private[connect] val MAX_ERROR_CHAIN_LENGTH = 5 + + /** + * Convert Throwable to a protobuf message FetchErrorDetailsResponse. + * @param st + * the Throwable to be converted + * @param serverStackTraceEnabled + * whether to return the server stack trace. 
+ * @return + * FetchErrorDetailsResponse + */ + private[connect] def throwableToFetchErrorDetailsResponse( + st: Throwable, + serverStackTraceEnabled: Boolean = false): FetchErrorDetailsResponse = { + + var currentError = st + val buffer = mutable.Buffer.empty[FetchErrorDetailsResponse.Error] + + while (buffer.size < MAX_ERROR_CHAIN_LENGTH && currentError != null) { + val builder = FetchErrorDetailsResponse.Error + .newBuilder() + .setMessage(currentError.getMessage) + .addAllErrorTypeHierarchy( + ErrorUtils.allClasses(currentError.getClass).map(_.getName).asJava) + + if (serverStackTraceEnabled) { + builder.addAllStackTrace( + currentError.getStackTrace + .map { stackTraceElement => + FetchErrorDetailsResponse.StackTraceElement + .newBuilder() + .setDeclaringClass(stackTraceElement.getClassName) + .setMethodName(stackTraceElement.getMethodName) + .setFileName(stackTraceElement.getFileName) + .setLineNumber(stackTraceElement.getLineNumber) + .build() + } + .toIterable + .asJava) + } + + val causeIdx = buffer.size + 1 + + if (causeIdx < MAX_ERROR_CHAIN_LENGTH && currentError.getCause != null) { + builder.setCauseIdx(causeIdx) + } + + buffer.append(builder.build()) + + currentError = currentError.getCause + } + + FetchErrorDetailsResponse + .newBuilder() + .setRootErrorIdx(0) + .addAllErrors(buffer.asJava) + .build() + } + + private def buildStatusFromThrowable(st: Throwable, sessionHolder: SessionHolder): RPCStatus = { val errorInfo = ErrorInfo .newBuilder() .setReason(st.getClass.getName) @@ -66,14 +131,26 @@ private[connect] object ErrorUtils extends Logging { "classes", JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName)))) - lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) - val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) { - val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) - errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) - } else { - 
errorInfo + if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) { + // Generate a new unique key for this exception. + val errorId = UUID.randomUUID().toString + + errorInfo.putMetadata("errorId", errorId) + + sessionHolder.errorIdToError + .put(errorId, st) } + lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st)) + val withStackTrace = + if (sessionHolder.session.conf.get( + SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) { + val maxSize = SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE) + errorInfo.putMetadata("stackTrace", StringUtils.abbreviate(stackTrace.get, maxSize)) + } else { + errorInfo + } + RPCStatus .newBuilder() .setCode(RPCCode.INTERNAL_VALUE) @@ -107,21 +184,19 @@ private[connect] object ErrorUtils extends Logging { sessionId: String, events: Option[ExecuteEventsManager] = None, isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = { - val session = + val sessionHolder = SparkConnectService .getOrCreateIsolatedSession(userId, sessionId) - .session - val stackTraceEnabled = session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) val partial: PartialFunction[Throwable, (Throwable, Throwable)] = { case se: SparkException if isPythonExecutionException(se) => ( se, StatusProto.toStatusRuntimeException( - buildStatusFromThrowable(se.getCause, stackTraceEnabled))) + buildStatusFromThrowable(se.getCause, sessionHolder))) case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) => - (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, stackTraceEnabled))) + (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, sessionHolder))) case e: Throwable => ( diff --git a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala new file mode 100644 index 
000000000000..c0591dcc9c7b --- /dev/null +++ b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connect.service + +import java.util.UUID + +import scala.concurrent.Promise +import scala.concurrent.duration._ + +import io.grpc.stub.StreamObserver + +import org.apache.spark.connect.proto +import org.apache.spark.connect.proto.FetchErrorDetailsResponse +import org.apache.spark.sql.connect.ResourceHelper +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.utils.ErrorUtils +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.util.ThreadUtils + +private class FetchErrorDetailsResponseObserver(p: Promise[FetchErrorDetailsResponse]) + extends StreamObserver[FetchErrorDetailsResponse] { + override def onNext(v: FetchErrorDetailsResponse): Unit = p.success(v) + override def onError(throwable: Throwable): Unit = throw throwable + override def onCompleted(): Unit = {} +} + +class FetchErrorDetailsHandlerSuite extends SharedSparkSession with ResourceHelper { + + 
private val userId = "user1" + + private val sessionId = UUID.randomUUID().toString + + private def fetchErrorDetails( + userId: String, + sessionId: String, + errorId: String): FetchErrorDetailsResponse = { + val promise = Promise[FetchErrorDetailsResponse] + val handler = + new SparkConnectFetchErrorDetailsHandler(new FetchErrorDetailsResponseObserver(promise)) + val context = proto.UserContext + .newBuilder() + .setUserId(userId) + .build() + val request = proto.FetchErrorDetailsRequest + .newBuilder() + .setUserContext(context) + .setSessionId(sessionId) + .setErrorId(errorId) + .build() + handler.handle(request) + ThreadUtils.awaitResult(promise.future, 5.seconds) + } + + for (serverStacktraceEnabled <- Seq(false, true)) { + test(s"error chain is properly constructed - $serverStacktraceEnabled") { + val testError = + new Exception("test1", new Exception("test2")) + val errorId = UUID.randomUUID().toString() + + val sessionHolder = SparkConnectService + .getOrCreateIsolatedSession(userId, sessionId) + + sessionHolder.errorIdToError.put(errorId, testError) + + sessionHolder.session.conf + .set(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key, serverStacktraceEnabled) + sessionHolder.session.conf + .set(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key, false) + try { + val response = fetchErrorDetails(userId, sessionId, errorId) + assert(response.hasRootErrorIdx) + assert(response.getRootErrorIdx == 0) + + assert(response.getErrorsCount == 2) + assert(response.getErrors(0).getMessage == "test1") + assert(response.getErrors(0).getErrorTypeHierarchyCount == 3) + assert(response.getErrors(0).getErrorTypeHierarchy(0) == classOf[Exception].getName) + assert(response.getErrors(0).getErrorTypeHierarchy(1) == classOf[Throwable].getName) + assert(response.getErrors(0).getErrorTypeHierarchy(2) == classOf[Object].getName) + assert(response.getErrors(0).hasCauseIdx) + assert(response.getErrors(0).getCauseIdx == 1) + + assert(response.getErrors(1).getMessage == "test2") + 
assert(response.getErrors(1).getErrorTypeHierarchyCount == 3) + assert(response.getErrors(1).getErrorTypeHierarchy(0) == classOf[Exception].getName) + assert(response.getErrors(1).getErrorTypeHierarchy(1) == classOf[Throwable].getName) + assert(response.getErrors(1).getErrorTypeHierarchy(2) == classOf[Object].getName) + assert(!response.getErrors(1).hasCauseIdx) + if (serverStacktraceEnabled) { + assert(response.getErrors(0).getStackTraceCount == testError.getStackTrace.length) + assert( + response.getErrors(1).getStackTraceCount == + testError.getCause.getStackTrace.length) + } else { + assert(response.getErrors(0).getStackTraceCount == 0) + assert(response.getErrors(1).getStackTraceCount == 0) + } + } finally { + sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key) + sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key) + } + } + } + + test("error not found") { + val response = fetchErrorDetails(userId, sessionId, UUID.randomUUID().toString()) + assert(!response.hasRootErrorIdx) + } + + test("invalidate cached exceptions after first request") { + val testError = new Exception("test1") + val errorId = UUID.randomUUID().toString() + + SparkConnectService + .getOrCreateIsolatedSession(userId, sessionId) + .errorIdToError + .put(errorId, testError) + + val response = fetchErrorDetails(userId, sessionId, errorId) + assert(response.hasRootErrorIdx) + assert(response.getRootErrorIdx == 0) + + assert(response.getErrorsCount == 1) + assert(response.getErrors(0).getMessage == "test1") + + assert( + SparkConnectService + .getOrCreateIsolatedSession(userId, sessionId) + .errorIdToError + .size() == 0) + } + + test("error chain is truncated after reaching max depth") { + var testError = new Exception("test") + for (i <- 0 until 2 * ErrorUtils.MAX_ERROR_CHAIN_LENGTH) { + val errorId = UUID.randomUUID().toString() + + SparkConnectService + .getOrCreateIsolatedSession(userId, sessionId) + .errorIdToError + .put(errorId, 
testError) + + val response = fetchErrorDetails(userId, sessionId, errorId) + val expectedErrorCount = Math.min(i + 1, ErrorUtils.MAX_ERROR_CHAIN_LENGTH) + assert(response.getErrorsCount == expectedErrorCount) + assert(response.getErrors(expectedErrorCount - 1).hasCauseIdx == false) + + testError = new Exception(s"test$i", testError) + } + } +} diff --git a/python/pyspark/sql/connect/proto/base_pb2.py b/python/pyspark/sql/connect/proto/base_pb2.py index 731f4445e150..2bde0677e4b7 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.py +++ b/python/pyspark/sql/connect/proto/base_pb2.py @@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as spark_dot_connect_dot_types__ DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( - b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] + b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01 \x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02 \x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17 [...] 
) _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals()) @@ -197,6 +197,14 @@ if _descriptor._USE_C_DESCRIPTORS == False: _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11157 _RELEASEEXECUTERESPONSE._serialized_start = 11186 _RELEASEEXECUTERESPONSE._serialized_end = 11298 - _SPARKCONNECTSERVICE._serialized_start = 11301 - _SPARKCONNECTSERVICE._serialized_end = 12044 + _FETCHERRORDETAILSREQUEST._serialized_start = 11301 + _FETCHERRORDETAILSREQUEST._serialized_end = 11448 + _FETCHERRORDETAILSRESPONSE._serialized_start = 11451 + _FETCHERRORDETAILSRESPONSE._serialized_end = 11997 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11596 + _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11751 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 11754 + _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 11978 + _SPARKCONNECTSERVICE._serialized_start = 12000 + _SPARKCONNECTSERVICE._serialized_end = 12849 # @@protoc_insertion_point(module_scope) diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi b/python/pyspark/sql/connect/proto/base_pb2.pyi index 3dca29230ef2..43254ceb2560 100644 --- a/python/pyspark/sql/connect/proto/base_pb2.pyi +++ b/python/pyspark/sql/connect/proto/base_pb2.pyi @@ -2728,3 +2728,183 @@ class ReleaseExecuteResponse(google.protobuf.message.Message): ) -> typing_extensions.Literal["operation_id"] | None: ... global___ReleaseExecuteResponse = ReleaseExecuteResponse + +class FetchErrorDetailsRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + SESSION_ID_FIELD_NUMBER: builtins.int + USER_CONTEXT_FIELD_NUMBER: builtins.int + ERROR_ID_FIELD_NUMBER: builtins.int + session_id: builtins.str + """(Required) + The session_id specifies a Spark session for a user identified by user_context.user_id. + The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`. 
+ """ + @property + def user_context(self) -> global___UserContext: + """User context""" + error_id: builtins.str + """(Required) + The id of the error. + """ + def __init__( + self, + *, + session_id: builtins.str = ..., + user_context: global___UserContext | None = ..., + error_id: builtins.str = ..., + ) -> None: ... + def HasField( + self, field_name: typing_extensions.Literal["user_context", b"user_context"] + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "error_id", b"error_id", "session_id", b"session_id", "user_context", b"user_context" + ], + ) -> None: ... + +global___FetchErrorDetailsRequest = FetchErrorDetailsRequest + +class FetchErrorDetailsResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class StackTraceElement(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + DECLARING_CLASS_FIELD_NUMBER: builtins.int + METHOD_NAME_FIELD_NUMBER: builtins.int + FILE_NAME_FIELD_NUMBER: builtins.int + LINE_NUMBER_FIELD_NUMBER: builtins.int + declaring_class: builtins.str + """The fully qualified name of the class containing the execution point.""" + method_name: builtins.str + """The name of the method containing the execution point.""" + file_name: builtins.str + """The name of the file containing the execution point.""" + line_number: builtins.int + """The line number of the source line containing the execution point.""" + def __init__( + self, + *, + declaring_class: builtins.str = ..., + method_name: builtins.str = ..., + file_name: builtins.str = ..., + line_number: builtins.int = ..., + ) -> None: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "declaring_class", + b"declaring_class", + "file_name", + b"file_name", + "line_number", + b"line_number", + "method_name", + b"method_name", + ], + ) -> None: ... 
+ + class Error(google.protobuf.message.Message): + """Error defines the schema for the representing exception.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + ERROR_TYPE_HIERARCHY_FIELD_NUMBER: builtins.int + MESSAGE_FIELD_NUMBER: builtins.int + STACK_TRACE_FIELD_NUMBER: builtins.int + CAUSE_IDX_FIELD_NUMBER: builtins.int + @property + def error_type_hierarchy( + self, + ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """The fully qualified names of the exception class and its parent classes.""" + message: builtins.str + """The detailed message of the exception.""" + @property + def stack_trace( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___FetchErrorDetailsResponse.StackTraceElement + ]: + """The stackTrace of the exception. It will be set + if the SQLConf spark.sql.connect.serverStacktrace.enabled is true. + """ + cause_idx: builtins.int + """The index of the cause error in errors.""" + def __init__( + self, + *, + error_type_hierarchy: collections.abc.Iterable[builtins.str] | None = ..., + message: builtins.str = ..., + stack_trace: collections.abc.Iterable[ + global___FetchErrorDetailsResponse.StackTraceElement + ] + | None = ..., + cause_idx: builtins.int | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_cause_idx", b"_cause_idx", "cause_idx", b"cause_idx" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_cause_idx", + b"_cause_idx", + "cause_idx", + b"cause_idx", + "error_type_hierarchy", + b"error_type_hierarchy", + "message", + b"message", + "stack_trace", + b"stack_trace", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_cause_idx", b"_cause_idx"] + ) -> typing_extensions.Literal["cause_idx"] | None: ... 
+ + ROOT_ERROR_IDX_FIELD_NUMBER: builtins.int + ERRORS_FIELD_NUMBER: builtins.int + root_error_idx: builtins.int + """The index of the root error in errors. The field will not be set if the error is not found.""" + @property + def errors( + self, + ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[ + global___FetchErrorDetailsResponse.Error + ]: + """A list of errors.""" + def __init__( + self, + *, + root_error_idx: builtins.int | None = ..., + errors: collections.abc.Iterable[global___FetchErrorDetailsResponse.Error] | None = ..., + ) -> None: ... + def HasField( + self, + field_name: typing_extensions.Literal[ + "_root_error_idx", b"_root_error_idx", "root_error_idx", b"root_error_idx" + ], + ) -> builtins.bool: ... + def ClearField( + self, + field_name: typing_extensions.Literal[ + "_root_error_idx", + b"_root_error_idx", + "errors", + b"errors", + "root_error_idx", + b"root_error_idx", + ], + ) -> None: ... + def WhichOneof( + self, oneof_group: typing_extensions.Literal["_root_error_idx", b"_root_error_idx"] + ) -> typing_extensions.Literal["root_error_idx"] | None: ... 
+ +global___FetchErrorDetailsResponse = FetchErrorDetailsResponse diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py index e6bfda8a40a8..f6c5573ded6b 100644 --- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py +++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py @@ -70,6 +70,11 @@ class SparkConnectServiceStub(object): request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString, response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString, ) + self.FetchErrorDetails = channel.unary_unary( + "/spark.connect.SparkConnectService/FetchErrorDetails", + request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, + response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString, + ) class SparkConnectServiceServicer(object): @@ -136,6 +141,12 @@ class SparkConnectServiceServicer(object): context.set_details("Method not implemented!") raise NotImplementedError("Method not implemented!") + def FetchErrorDetails(self, request, context): + """FetchErrorDetails retrieves the matched exception with details based on a provided error id.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details("Method not implemented!") + raise NotImplementedError("Method not implemented!") + def add_SparkConnectServiceServicer_to_server(servicer, server): rpc_method_handlers = { @@ -179,6 +190,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server): request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString, response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString, ), + "FetchErrorDetails": grpc.unary_unary_rpc_method_handler( + servicer.FetchErrorDetails, + request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString, + 
response_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( "spark.connect.SparkConnectService", rpc_method_handlers @@ -421,3 +437,32 @@ class SparkConnectService(object): timeout, metadata, ) + + @staticmethod + def FetchErrorDetails( + request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None, + ): + return grpc.experimental.unary_unary( + request, + target, + "/spark.connect.SparkConnectService/FetchErrorDetails", + spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString, + spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org