This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 48b1a283a2eb [SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails 
RPC
48b1a283a2eb is described below

commit 48b1a283a2eba9f70149d5980d074fad2743c4ff
Author: Yihong He <yihong...@databricks.com>
AuthorDate: Wed Sep 20 00:14:44 2023 -0400

    [SPARK-44622][SQL][CONNECT] Implement FetchErrorDetails RPC
    
    ### What changes were proposed in this pull request?
    
    - Introduced the FetchErrorDetails RPC to retrieve comprehensive error 
details. FetchErrorDetails is used for enriching the error by issuing a 
separate RPC call based on the `errorId` field in the ErrorInfo.
    - Introduced error enrichment that utilizes an additional RPC to fetch 
untruncated exception messages and server-side stack traces. This enrichment 
can be enabled or disabled using the flag 
`spark.sql.connect.enrichError.enabled`, and it's true by default.
    - Implemented setting server-side stack traces for exceptions on the client 
side via FetchErrorDetails RPC for debugging. The feature is enabled or 
disabled using the flag `spark.sql.connect.serverStacktrace.enabled`, and it's 
true by default.
    
    ### Why are the changes needed?
    
    - Attaching full exception messages to the error details protobuf can 
quickly hit the 8K gRPC Netty header limit. Utilizing a separate RPC to fetch 
comprehensive error information is more dependable.
    - Providing server-side stack traces aids in effectively diagnosing 
server-related issues.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    - `build/sbt "connect/testOnly *FetchErrorDetailsHandlerSuite"`
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #42377 from heyihong/SPARK-44622.
    
    Authored-by: Yihong He <yihong...@databricks.com>
    Signed-off-by: Herman van Hovell <her...@databricks.com>
---
 .../src/main/protobuf/spark/connect/base.proto     |  57 +++++++
 .../apache/spark/sql/connect/config/Connect.scala  |  18 +++
 .../spark/sql/connect/service/SessionHolder.scala  |  21 ++-
 .../SparkConnectFetchErrorDetailsHandler.scala     |  59 +++++++
 .../sql/connect/service/SparkConnectService.scala  |  14 ++
 .../spark/sql/connect/utils/ErrorUtils.scala       | 103 ++++++++++--
 .../service/FetchErrorDetailsHandlerSuite.scala    | 166 +++++++++++++++++++
 python/pyspark/sql/connect/proto/base_pb2.py       |  14 +-
 python/pyspark/sql/connect/proto/base_pb2.pyi      | 180 +++++++++++++++++++++
 python/pyspark/sql/connect/proto/base_pb2_grpc.py  |  45 ++++++
 10 files changed, 659 insertions(+), 18 deletions(-)

diff --git 
a/connector/connect/common/src/main/protobuf/spark/connect/base.proto 
b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
index 65e2493f8368..cf1355f7ebc1 100644
--- a/connector/connect/common/src/main/protobuf/spark/connect/base.proto
+++ b/connector/connect/common/src/main/protobuf/spark/connect/base.proto
@@ -778,6 +778,60 @@ message ReleaseExecuteResponse {
   optional string operation_id = 2;
 }
 
+message FetchErrorDetailsRequest {
+
+  // (Required)
+  // The session_id specifies a Spark session for a user identified by 
user_context.user_id.
+  // The id should be a UUID string of the format 
`00112233-4455-6677-8899-aabbccddeeff`.
+  string session_id = 1;
+
+  // User context
+  UserContext user_context = 2;
+
+  // (Required)
+  // The id of the error.
+  string error_id = 3;
+}
+
+message FetchErrorDetailsResponse {
+
+  message StackTraceElement {
+    // The fully qualified name of the class containing the execution point.
+    string declaring_class = 1;
+
+    // The name of the method containing the execution point.
+    string method_name = 2;
+
+    // The name of the file containing the execution point.
+    string file_name = 3;
+
+    // The line number of the source line containing the execution point.
+    int32 line_number = 4;
+  }
+
+  // Error defines the schema for the representing exception.
+  message Error {
+    // The fully qualified names of the exception class and its parent classes.
+    repeated string error_type_hierarchy = 1;
+
+    // The detailed message of the exception.
+    string message = 2;
+
+    // The stackTrace of the exception. It will be set
+    // if the SQLConf spark.sql.connect.serverStacktrace.enabled is true.
+    repeated StackTraceElement stack_trace = 3;
+
+    // The index of the cause error in errors.
+    optional int32 cause_idx = 4;
+  }
+
+  // The index of the root error in errors. The field will not be set if the 
error is not found.
+  optional int32 root_error_idx = 1;
+
+  // A list of errors.
+  repeated Error errors = 2;
+}
+
 // Main interface for the SparkConnect service.
 service SparkConnectService {
 
@@ -813,5 +867,8 @@ service SparkConnectService {
   // Non reattachable executions are released automatically and immediately 
after the ExecutePlan
   // RPC and ReleaseExecute may not be used.
   rpc ReleaseExecute(ReleaseExecuteRequest) returns (ReleaseExecuteResponse) {}
+
+  // FetchErrorDetails retrieves the matched exception with details based on a 
provided error id.
+  rpc FetchErrorDetails(FetchErrorDetailsRequest) returns 
(FetchErrorDetailsResponse) {}
 }
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
index dfd6008ac09a..248444e710d2 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/config/Connect.scala
@@ -20,6 +20,7 @@ import java.util.concurrent.TimeUnit
 
 import org.apache.spark.network.util.ByteUnit
 import org.apache.spark.sql.connect.common.config.ConnectCommon
+import org.apache.spark.sql.internal.SQLConf.buildConf
 
 object Connect {
   import org.apache.spark.sql.internal.SQLConf.buildStaticConf
@@ -213,4 +214,21 @@ object Connect {
     .version("3.5.0")
     .intConf
     .createWithDefault(200)
+
+  val CONNECT_ENRICH_ERROR_ENABLED =
+    buildConf("spark.sql.connect.enrichError.enabled")
+      .doc("""
+          |When true, it enriches errors with full exception messages and 
optionally server-side
+          |stacktrace on the client side via an additional RPC.
+          |""".stripMargin)
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
+
+  val CONNECT_SERVER_STACKTRACE_ENABLED =
+    buildConf("spark.sql.connect.serverStacktrace.enabled")
+      .doc("When true, it sets the server-side stacktrace in the user-facing 
Spark exception.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(true)
 }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 1cef02d7e346..0748cd237bf0 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -19,11 +19,14 @@ package org.apache.spark.sql.connect.service
 
 import java.nio.file.Path
 import java.util.UUID
-import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap}
+import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, TimeUnit}
 
 import scala.collection.JavaConverters._
 import scala.collection.mutable
 
+import com.google.common.base.Ticker
+import com.google.common.cache.CacheBuilder
+
 import org.apache.spark.{JobArtifactSet, SparkException}
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.DataFrame
@@ -32,6 +35,7 @@ import 
org.apache.spark.sql.connect.artifact.SparkConnectArtifactManager
 import org.apache.spark.sql.connect.common.InvalidPlanInput
 import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
+import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, 
ERROR_CACHE_TIMEOUT_SEC}
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.SystemClock
 import org.apache.spark.util.Utils
@@ -45,6 +49,15 @@ case class SessionHolder(userId: String, sessionId: String, 
session: SparkSessio
   private val executions: ConcurrentMap[String, ExecuteHolder] =
     new ConcurrentHashMap[String, ExecuteHolder]()
 
+  // The cache that maps an error id to a throwable. The throwable in cache is 
independent to
+  // each other.
+  private[connect] val errorIdToError = CacheBuilder
+    .newBuilder()
+    .ticker(Ticker.systemTicker())
+    .maximumSize(ERROR_CACHE_SIZE)
+    .expireAfterAccess(ERROR_CACHE_TIMEOUT_SEC, TimeUnit.SECONDS)
+    .build[String, Throwable]()
+
   val eventManager: SessionEventsManager = SessionEventsManager(this, new 
SystemClock())
 
   // Mapping from relation ID (passed to client) to runtime dataframe. Used 
for callbacks like
@@ -265,6 +278,12 @@ case class SessionHolder(userId: String, sessionId: 
String, session: SparkSessio
 
 object SessionHolder {
 
+  // The maximum number of distinct errors in the cache.
+  private val ERROR_CACHE_SIZE = 20
+
+  // The maximum time for an error to stay in the cache.
+  private val ERROR_CACHE_TIMEOUT_SEC = 60
+
   /** Creates a dummy session holder for use in tests. */
   def forTesting(session: SparkSession): SessionHolder = {
     val ret =
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
new file mode 100644
index 000000000000..17a6e9e434f3
--- /dev/null
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * Handles [[proto.FetchErrorDetailsRequest]]s for the 
[[SparkConnectService]]. The handler
+ * retrieves the matched error with details from the cache based on a provided 
error id.
+ *
+ * @param responseObserver
+ */
+class SparkConnectFetchErrorDetailsHandler(
+    responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]) {
+
+  def handle(v: proto.FetchErrorDetailsRequest): Unit = {
+    val sessionHolder =
+      SparkConnectService
+        .getOrCreateIsolatedSession(v.getUserContext.getUserId, v.getSessionId)
+
+    val response = 
Option(sessionHolder.errorIdToError.getIfPresent(v.getErrorId))
+      .map { error =>
+        // This error can only be fetched once,
+        // if a connection dies in the middle you cannot repeat.
+        sessionHolder.errorIdToError.invalidate(v.getErrorId)
+
+        ErrorUtils.throwableToFetchErrorDetailsResponse(
+          st = error,
+          serverStackTraceEnabled = sessionHolder.session.conf.get(
+            Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || 
sessionHolder.session.conf.get(
+            SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+      }
+      .getOrElse(FetchErrorDetailsResponse.newBuilder().build())
+
+    responseObserver.onNext(response)
+
+    responseObserver.onCompleted()
+  }
+}
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
index 269e47609dbf..e82c9cba5626 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectService.scala
@@ -201,6 +201,20 @@ class SparkConnectService(debug: Boolean) extends 
AsyncService with BindableServ
         sessionId = request.getSessionId)
   }
 
+  override def fetchErrorDetails(
+      request: proto.FetchErrorDetailsRequest,
+      responseObserver: StreamObserver[proto.FetchErrorDetailsResponse]): Unit 
= {
+    try {
+      new 
SparkConnectFetchErrorDetailsHandler(responseObserver).handle(request)
+    } catch {
+      ErrorUtils.handleError(
+        "getErrorInfo",
+        observer = responseObserver,
+        userId = request.getUserContext.getUserId,
+        sessionId = request.getSessionId)
+    }
+  }
+
   private def methodWithCustomMarshallers(methodDesc: 
MethodDescriptor[MessageLite, MessageLite])
       : MethodDescriptor[MessageLite, MessageLite] = {
     val recursionLimit =
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 2050ebc01aa0..1abd44608cd0 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark.sql.connect.utils
 
+import java.util.UUID
+
 import scala.annotation.tailrec
+import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
+import scala.jdk.CollectionConverters._
 import scala.util.control.NonFatal
 
 import com.google.protobuf.{Any => ProtoAny}
@@ -33,13 +37,14 @@ import org.json4s.jackson.JsonMethods
 
 import org.apache.spark.{SparkEnv, SparkException, SparkThrowable}
 import org.apache.spark.api.python.PythonException
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.connect.config.Connect
-import org.apache.spark.sql.connect.service.ExecuteEventsManager
-import org.apache.spark.sql.connect.service.SparkConnectService
+import org.apache.spark.sql.connect.service.{ExecuteEventsManager, 
SessionHolder, SparkConnectService}
 import org.apache.spark.sql.internal.SQLConf
 
 private[connect] object ErrorUtils extends Logging {
+
   private def allClasses(cl: Class[_]): Seq[Class[_]] = {
     val classes = ArrayBuffer.empty[Class[_]]
     if (cl != null && !cl.equals(classOf[java.lang.Object])) {
@@ -57,7 +62,67 @@ private[connect] object ErrorUtils extends Logging {
     classes.toSeq
   }
 
-  private def buildStatusFromThrowable(st: Throwable, stackTraceEnabled: 
Boolean): RPCStatus = {
+  // The maximum length of the error chain.
+  private[connect] val MAX_ERROR_CHAIN_LENGTH = 5
+
+  /**
+   * Convert Throwable to a protobuf message FetchErrorDetailsResponse.
+   * @param st
+   *   the Throwable to be converted
+   * @param serverStackTraceEnabled
+   *   whether to return the server stack trace.
+   * @return
+   *   FetchErrorDetailsResponse
+   */
+  private[connect] def throwableToFetchErrorDetailsResponse(
+      st: Throwable,
+      serverStackTraceEnabled: Boolean = false): FetchErrorDetailsResponse = {
+
+    var currentError = st
+    val buffer = mutable.Buffer.empty[FetchErrorDetailsResponse.Error]
+
+    while (buffer.size < MAX_ERROR_CHAIN_LENGTH && currentError != null) {
+      val builder = FetchErrorDetailsResponse.Error
+        .newBuilder()
+        .setMessage(currentError.getMessage)
+        .addAllErrorTypeHierarchy(
+          ErrorUtils.allClasses(currentError.getClass).map(_.getName).asJava)
+
+      if (serverStackTraceEnabled) {
+        builder.addAllStackTrace(
+          currentError.getStackTrace
+            .map { stackTraceElement =>
+              FetchErrorDetailsResponse.StackTraceElement
+                .newBuilder()
+                .setDeclaringClass(stackTraceElement.getClassName)
+                .setMethodName(stackTraceElement.getMethodName)
+                .setFileName(stackTraceElement.getFileName)
+                .setLineNumber(stackTraceElement.getLineNumber)
+                .build()
+            }
+            .toIterable
+            .asJava)
+      }
+
+      val causeIdx = buffer.size + 1
+
+      if (causeIdx < MAX_ERROR_CHAIN_LENGTH && currentError.getCause != null) {
+        builder.setCauseIdx(causeIdx)
+      }
+
+      buffer.append(builder.build())
+
+      currentError = currentError.getCause
+    }
+
+    FetchErrorDetailsResponse
+      .newBuilder()
+      .setRootErrorIdx(0)
+      .addAllErrors(buffer.asJava)
+      .build()
+  }
+
+  private def buildStatusFromThrowable(st: Throwable, sessionHolder: 
SessionHolder): RPCStatus = {
     val errorInfo = ErrorInfo
       .newBuilder()
       .setReason(st.getClass.getName)
@@ -66,14 +131,26 @@ private[connect] object ErrorUtils extends Logging {
         "classes",
         
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
 
-    lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
-    val withStackTrace = if (stackTraceEnabled && stackTrace.nonEmpty) {
-      val maxSize = 
SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
-      errorInfo.putMetadata("stackTrace", 
StringUtils.abbreviate(stackTrace.get, maxSize))
-    } else {
-      errorInfo
+    if (sessionHolder.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)) {
+      // Generate a new unique key for this exception.
+      val errorId = UUID.randomUUID().toString
+
+      errorInfo.putMetadata("errorId", errorId)
+
+      sessionHolder.errorIdToError
+        .put(errorId, st)
     }
 
+    lazy val stackTrace = Option(ExceptionUtils.getStackTrace(st))
+    val withStackTrace =
+      if (sessionHolder.session.conf.get(
+          SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED) && stackTrace.nonEmpty) {
+        val maxSize = 
SparkEnv.get.conf.get(Connect.CONNECT_JVM_STACK_TRACE_MAX_SIZE)
+        errorInfo.putMetadata("stackTrace", 
StringUtils.abbreviate(stackTrace.get, maxSize))
+      } else {
+        errorInfo
+      }
+
     RPCStatus
       .newBuilder()
       .setCode(RPCCode.INTERNAL_VALUE)
@@ -107,21 +184,19 @@ private[connect] object ErrorUtils extends Logging {
       sessionId: String,
       events: Option[ExecuteEventsManager] = None,
       isInterrupted: Boolean = false): PartialFunction[Throwable, Unit] = {
-    val session =
+    val sessionHolder =
       SparkConnectService
         .getOrCreateIsolatedSession(userId, sessionId)
-        .session
-    val stackTraceEnabled = 
session.conf.get(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED)
 
     val partial: PartialFunction[Throwable, (Throwable, Throwable)] = {
       case se: SparkException if isPythonExecutionException(se) =>
         (
           se,
           StatusProto.toStatusRuntimeException(
-            buildStatusFromThrowable(se.getCause, stackTraceEnabled)))
+            buildStatusFromThrowable(se.getCause, sessionHolder)))
 
       case e: Throwable if e.isInstanceOf[SparkThrowable] || NonFatal.apply(e) 
=>
-        (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, 
stackTraceEnabled)))
+        (e, StatusProto.toStatusRuntimeException(buildStatusFromThrowable(e, 
sessionHolder)))
 
       case e: Throwable =>
         (
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
new file mode 100644
index 000000000000..c0591dcc9c7b
--- /dev/null
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.connect.service
+
+import java.util.UUID
+
+import scala.concurrent.Promise
+import scala.concurrent.duration._
+
+import io.grpc.stub.StreamObserver
+
+import org.apache.spark.connect.proto
+import org.apache.spark.connect.proto.FetchErrorDetailsResponse
+import org.apache.spark.sql.connect.ResourceHelper
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.utils.ErrorUtils
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.util.ThreadUtils
+
+private class FetchErrorDetailsResponseObserver(p: 
Promise[FetchErrorDetailsResponse])
+    extends StreamObserver[FetchErrorDetailsResponse] {
+  override def onNext(v: FetchErrorDetailsResponse): Unit = p.success(v)
+  override def onError(throwable: Throwable): Unit = throw throwable
+  override def onCompleted(): Unit = {}
+}
+
+class FetchErrorDetailsHandlerSuite extends SharedSparkSession with 
ResourceHelper {
+
+  private val userId = "user1"
+
+  private val sessionId = UUID.randomUUID().toString
+
+  private def fetchErrorDetails(
+      userId: String,
+      sessionId: String,
+      errorId: String): FetchErrorDetailsResponse = {
+    val promise = Promise[FetchErrorDetailsResponse]
+    val handler =
+      new SparkConnectFetchErrorDetailsHandler(new 
FetchErrorDetailsResponseObserver(promise))
+    val context = proto.UserContext
+      .newBuilder()
+      .setUserId(userId)
+      .build()
+    val request = proto.FetchErrorDetailsRequest
+      .newBuilder()
+      .setUserContext(context)
+      .setSessionId(sessionId)
+      .setErrorId(errorId)
+      .build()
+    handler.handle(request)
+    ThreadUtils.awaitResult(promise.future, 5.seconds)
+  }
+
+  for (serverStacktraceEnabled <- Seq(false, true)) {
+    test(s"error chain is properly constructed - $serverStacktraceEnabled") {
+      val testError =
+        new Exception("test1", new Exception("test2"))
+      val errorId = UUID.randomUUID().toString()
+
+      val sessionHolder = SparkConnectService
+        .getOrCreateIsolatedSession(userId, sessionId)
+
+      sessionHolder.errorIdToError.put(errorId, testError)
+
+      sessionHolder.session.conf
+        .set(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key, 
serverStacktraceEnabled)
+      sessionHolder.session.conf
+        .set(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key, false)
+      try {
+        val response = fetchErrorDetails(userId, sessionId, errorId)
+        assert(response.hasRootErrorIdx)
+        assert(response.getRootErrorIdx == 0)
+
+        assert(response.getErrorsCount == 2)
+        assert(response.getErrors(0).getMessage == "test1")
+        assert(response.getErrors(0).getErrorTypeHierarchyCount == 3)
+        assert(response.getErrors(0).getErrorTypeHierarchy(0) == 
classOf[Exception].getName)
+        assert(response.getErrors(0).getErrorTypeHierarchy(1) == 
classOf[Throwable].getName)
+        assert(response.getErrors(0).getErrorTypeHierarchy(2) == 
classOf[Object].getName)
+        assert(response.getErrors(0).hasCauseIdx)
+        assert(response.getErrors(0).getCauseIdx == 1)
+
+        assert(response.getErrors(1).getMessage == "test2")
+        assert(response.getErrors(1).getErrorTypeHierarchyCount == 3)
+        assert(response.getErrors(1).getErrorTypeHierarchy(0) == 
classOf[Exception].getName)
+        assert(response.getErrors(1).getErrorTypeHierarchy(1) == 
classOf[Throwable].getName)
+        assert(response.getErrors(1).getErrorTypeHierarchy(2) == 
classOf[Object].getName)
+        assert(!response.getErrors(1).hasCauseIdx)
+        if (serverStacktraceEnabled) {
+          assert(response.getErrors(0).getStackTraceCount == 
testError.getStackTrace.length)
+          assert(
+            response.getErrors(1).getStackTraceCount ==
+              testError.getCause.getStackTrace.length)
+        } else {
+          assert(response.getErrors(0).getStackTraceCount == 0)
+          assert(response.getErrors(1).getStackTraceCount == 0)
+        }
+      } finally {
+        
sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
+        
sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)
+      }
+    }
+  }
+
+  test("error not found") {
+    val response = fetchErrorDetails(userId, sessionId, 
UUID.randomUUID().toString())
+    assert(!response.hasRootErrorIdx)
+  }
+
+  test("invalidate cached exceptions after first request") {
+    val testError = new Exception("test1")
+    val errorId = UUID.randomUUID().toString()
+
+    SparkConnectService
+      .getOrCreateIsolatedSession(userId, sessionId)
+      .errorIdToError
+      .put(errorId, testError)
+
+    val response = fetchErrorDetails(userId, sessionId, errorId)
+    assert(response.hasRootErrorIdx)
+    assert(response.getRootErrorIdx == 0)
+
+    assert(response.getErrorsCount == 1)
+    assert(response.getErrors(0).getMessage == "test1")
+
+    assert(
+      SparkConnectService
+        .getOrCreateIsolatedSession(userId, sessionId)
+        .errorIdToError
+        .size() == 0)
+  }
+
+  test("error chain is truncated after reaching max depth") {
+    var testError = new Exception("test")
+    for (i <- 0 until 2 * ErrorUtils.MAX_ERROR_CHAIN_LENGTH) {
+      val errorId = UUID.randomUUID().toString()
+
+      SparkConnectService
+        .getOrCreateIsolatedSession(userId, sessionId)
+        .errorIdToError
+        .put(errorId, testError)
+
+      val response = fetchErrorDetails(userId, sessionId, errorId)
+      val expectedErrorCount = Math.min(i + 1, 
ErrorUtils.MAX_ERROR_CHAIN_LENGTH)
+      assert(response.getErrorsCount == expectedErrorCount)
+      assert(response.getErrors(expectedErrorCount - 1).hasCauseIdx == false)
+
+      testError = new Exception(s"test$i", testError)
+    }
+  }
+}
diff --git a/python/pyspark/sql/connect/proto/base_pb2.py 
b/python/pyspark/sql/connect/proto/base_pb2.py
index 731f4445e150..2bde0677e4b7 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.py
+++ b/python/pyspark/sql/connect/proto/base_pb2.py
@@ -37,7 +37,7 @@ from pyspark.sql.connect.proto import types_pb2 as 
spark_dot_connect_dot_types__
 
 
 DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
-    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
+    
b'\n\x18spark/connect/base.proto\x12\rspark.connect\x1a\x19google/protobuf/any.proto\x1a\x1cspark/connect/commands.proto\x1a\x1aspark/connect/common.proto\x1a\x1fspark/connect/expressions.proto\x1a\x1dspark/connect/relations.proto\x1a\x19spark/connect/types.proto"t\n\x04Plan\x12-\n\x04root\x18\x01
 
\x01(\x0b\x32\x17.spark.connect.RelationH\x00R\x04root\x12\x32\n\x07\x63ommand\x18\x02
 
\x01(\x0b\x32\x16.spark.connect.CommandH\x00R\x07\x63ommandB\t\n\x07op_type"z\n\x0bUserContext\x12\x17
 [...]
 )
 
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
@@ -197,6 +197,14 @@ if _descriptor._USE_C_DESCRIPTORS == False:
     _RELEASEEXECUTEREQUEST_RELEASEUNTIL._serialized_end = 11157
     _RELEASEEXECUTERESPONSE._serialized_start = 11186
     _RELEASEEXECUTERESPONSE._serialized_end = 11298
-    _SPARKCONNECTSERVICE._serialized_start = 11301
-    _SPARKCONNECTSERVICE._serialized_end = 12044
+    _FETCHERRORDETAILSREQUEST._serialized_start = 11301
+    _FETCHERRORDETAILSREQUEST._serialized_end = 11448
+    _FETCHERRORDETAILSRESPONSE._serialized_start = 11451
+    _FETCHERRORDETAILSRESPONSE._serialized_end = 11997
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_start = 11596
+    _FETCHERRORDETAILSRESPONSE_STACKTRACEELEMENT._serialized_end = 11751
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_start = 11754
+    _FETCHERRORDETAILSRESPONSE_ERROR._serialized_end = 11978
+    _SPARKCONNECTSERVICE._serialized_start = 12000
+    _SPARKCONNECTSERVICE._serialized_end = 12849
 # @@protoc_insertion_point(module_scope)
diff --git a/python/pyspark/sql/connect/proto/base_pb2.pyi 
b/python/pyspark/sql/connect/proto/base_pb2.pyi
index 3dca29230ef2..43254ceb2560 100644
--- a/python/pyspark/sql/connect/proto/base_pb2.pyi
+++ b/python/pyspark/sql/connect/proto/base_pb2.pyi
@@ -2728,3 +2728,183 @@ class ReleaseExecuteResponse(google.protobuf.message.Message):
     ) -> typing_extensions.Literal["operation_id"] | None: ...
 
 global___ReleaseExecuteResponse = ReleaseExecuteResponse
+
+class FetchErrorDetailsRequest(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    SESSION_ID_FIELD_NUMBER: builtins.int
+    USER_CONTEXT_FIELD_NUMBER: builtins.int
+    ERROR_ID_FIELD_NUMBER: builtins.int
+    session_id: builtins.str
+    """(Required)
+    The session_id specifies a Spark session for a user identified by user_context.user_id.
+    The id should be a UUID string of the format `00112233-4455-6677-8899-aabbccddeeff`.
+    """
+    @property
+    def user_context(self) -> global___UserContext:
+        """User context"""
+    error_id: builtins.str
+    """(Required)
+    The id of the error.
+    """
+    def __init__(
+        self,
+        *,
+        session_id: builtins.str = ...,
+        user_context: global___UserContext | None = ...,
+        error_id: builtins.str = ...,
+    ) -> None: ...
+    def HasField(
+        self, field_name: typing_extensions.Literal["user_context", b"user_context"]
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "error_id", b"error_id", "session_id", b"session_id", "user_context", b"user_context"
+        ],
+    ) -> None: ...
+
+global___FetchErrorDetailsRequest = FetchErrorDetailsRequest
+
+class FetchErrorDetailsResponse(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+    class StackTraceElement(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        DECLARING_CLASS_FIELD_NUMBER: builtins.int
+        METHOD_NAME_FIELD_NUMBER: builtins.int
+        FILE_NAME_FIELD_NUMBER: builtins.int
+        LINE_NUMBER_FIELD_NUMBER: builtins.int
+        declaring_class: builtins.str
+        """The fully qualified name of the class containing the execution point."""
+        method_name: builtins.str
+        """The name of the method containing the execution point."""
+        file_name: builtins.str
+        """The name of the file containing the execution point."""
+        line_number: builtins.int
+        """The line number of the source line containing the execution point."""
+        def __init__(
+            self,
+            *,
+            declaring_class: builtins.str = ...,
+            method_name: builtins.str = ...,
+            file_name: builtins.str = ...,
+            line_number: builtins.int = ...,
+        ) -> None: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "declaring_class",
+                b"declaring_class",
+                "file_name",
+                b"file_name",
+                "line_number",
+                b"line_number",
+                "method_name",
+                b"method_name",
+            ],
+        ) -> None: ...
+
+    class Error(google.protobuf.message.Message):
+        """Error defines the schema for the representing exception."""
+
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+
+        ERROR_TYPE_HIERARCHY_FIELD_NUMBER: builtins.int
+        MESSAGE_FIELD_NUMBER: builtins.int
+        STACK_TRACE_FIELD_NUMBER: builtins.int
+        CAUSE_IDX_FIELD_NUMBER: builtins.int
+        @property
+        def error_type_hierarchy(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]:
+            """The fully qualified names of the exception class and its parent classes."""
+        message: builtins.str
+        """The detailed message of the exception."""
+        @property
+        def stack_trace(
+            self,
+        ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+            global___FetchErrorDetailsResponse.StackTraceElement
+        ]:
+            """The stackTrace of the exception. It will be set
+            if the SQLConf spark.sql.connect.serverStacktrace.enabled is true.
+            """
+        cause_idx: builtins.int
+        """The index of the cause error in errors."""
+        def __init__(
+            self,
+            *,
+            error_type_hierarchy: collections.abc.Iterable[builtins.str] | None = ...,
+            message: builtins.str = ...,
+            stack_trace: collections.abc.Iterable[
+                global___FetchErrorDetailsResponse.StackTraceElement
+            ]
+            | None = ...,
+            cause_idx: builtins.int | None = ...,
+        ) -> None: ...
+        def HasField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_cause_idx", b"_cause_idx", "cause_idx", b"cause_idx"
+            ],
+        ) -> builtins.bool: ...
+        def ClearField(
+            self,
+            field_name: typing_extensions.Literal[
+                "_cause_idx",
+                b"_cause_idx",
+                "cause_idx",
+                b"cause_idx",
+                "error_type_hierarchy",
+                b"error_type_hierarchy",
+                "message",
+                b"message",
+                "stack_trace",
+                b"stack_trace",
+            ],
+        ) -> None: ...
+        def WhichOneof(
+            self, oneof_group: typing_extensions.Literal["_cause_idx", b"_cause_idx"]
+        ) -> typing_extensions.Literal["cause_idx"] | None: ...
+
+    ROOT_ERROR_IDX_FIELD_NUMBER: builtins.int
+    ERRORS_FIELD_NUMBER: builtins.int
+    root_error_idx: builtins.int
+    """The index of the root error in errors. The field will not be set if the error is not found."""
+    @property
+    def errors(
+        self,
+    ) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[
+        global___FetchErrorDetailsResponse.Error
+    ]:
+        """A list of errors."""
+    def __init__(
+        self,
+        *,
+        root_error_idx: builtins.int | None = ...,
+        errors: collections.abc.Iterable[global___FetchErrorDetailsResponse.Error] | None = ...,
+    ) -> None: ...
+    def HasField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_root_error_idx", b"_root_error_idx", "root_error_idx", b"root_error_idx"
+        ],
+    ) -> builtins.bool: ...
+    def ClearField(
+        self,
+        field_name: typing_extensions.Literal[
+            "_root_error_idx",
+            b"_root_error_idx",
+            "errors",
+            b"errors",
+            "root_error_idx",
+            b"root_error_idx",
+        ],
+    ) -> None: ...
+    def WhichOneof(
+        self, oneof_group: typing_extensions.Literal["_root_error_idx", b"_root_error_idx"]
+    ) -> typing_extensions.Literal["root_error_idx"] | None: ...
+
+global___FetchErrorDetailsResponse = FetchErrorDetailsResponse
diff --git a/python/pyspark/sql/connect/proto/base_pb2_grpc.py b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
index e6bfda8a40a8..f6c5573ded6b 100644
--- a/python/pyspark/sql/connect/proto/base_pb2_grpc.py
+++ b/python/pyspark/sql/connect/proto/base_pb2_grpc.py
@@ -70,6 +70,11 @@ class SparkConnectServiceStub(object):
             request_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.SerializeToString,
             response_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.FromString,
         )
+        self.FetchErrorDetails = channel.unary_unary(
+            "/spark.connect.SparkConnectService/FetchErrorDetails",
+            request_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
+            response_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
+        )
 
 
 class SparkConnectServiceServicer(object):
@@ -136,6 +141,12 @@ class SparkConnectServiceServicer(object):
         context.set_details("Method not implemented!")
         raise NotImplementedError("Method not implemented!")
 
+    def FetchErrorDetails(self, request, context):
+        """FetchErrorDetails retrieves the matched exception with details based on a provided error id."""
+        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
+        context.set_details("Method not implemented!")
+        raise NotImplementedError("Method not implemented!")
+
 
 def add_SparkConnectServiceServicer_to_server(servicer, server):
     rpc_method_handlers = {
@@ -179,6 +190,11 @@ def add_SparkConnectServiceServicer_to_server(servicer, server):
             request_deserializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteRequest.FromString,
             response_serializer=spark_dot_connect_dot_base__pb2.ReleaseExecuteResponse.SerializeToString,
         ),
+        "FetchErrorDetails": grpc.unary_unary_rpc_method_handler(
+            servicer.FetchErrorDetails,
+            request_deserializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.FromString,
+            response_serializer=spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.SerializeToString,
+        ),
     }
     generic_handler = grpc.method_handlers_generic_handler(
         "spark.connect.SparkConnectService", rpc_method_handlers
@@ -421,3 +437,32 @@ class SparkConnectService(object):
             timeout,
             metadata,
         )
+
+    @staticmethod
+    def FetchErrorDetails(
+        request,
+        target,
+        options=(),
+        channel_credentials=None,
+        call_credentials=None,
+        insecure=False,
+        compression=None,
+        wait_for_ready=None,
+        timeout=None,
+        metadata=None,
+    ):
+        return grpc.experimental.unary_unary(
+            request,
+            target,
+            "/spark.connect.SparkConnectService/FetchErrorDetails",
+            spark_dot_connect_dot_base__pb2.FetchErrorDetailsRequest.SerializeToString,
+            spark_dot_connect_dot_base__pb2.FetchErrorDetailsResponse.FromString,
+            options,
+            channel_credentials,
+            insecure,
+            call_credentials,
+            compression,
+            wait_for_ready,
+            timeout,
+            metadata,
+        )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org


Reply via email to