This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1ca543b5595e [SPARK-45808][CONNECT][PYTHON] Better error handling for 
SQL Exceptions
1ca543b5595e is described below

commit 1ca543b5595ebfff4c46500df0ef7715c440c050
Author: Martin Grund <martin.gr...@databricks.com>
AuthorDate: Tue Nov 7 10:12:16 2023 -0800

    [SPARK-45808][CONNECT][PYTHON] Better error handling for SQL Exceptions
    
    ### What changes were proposed in this pull request?
    This patch improves the handling of errors reported back to Python. First,
    it properly allows the extraction of the `ERROR_CLASS` and the `SQL_STATE`,
    and gives simpler access to the stack trace.
    
    As a result, whether the stack trace is displayed is no longer decided only
    on the server side; it becomes a local, client-side usability setting.
    
    In addition, the following methods on `SparkConnectGrpcException` now
    return meaningful values:
    
    * `getSqlState()`
    * `getErrorClass()`
    * `getStackTrace()`
    
    ### Why are the changes needed?
    Compatibility
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Updated the existing tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #43667 from grundprinzip/SPARK-XXXX-ex.
    
    Authored-by: Martin Grund <martin.gr...@databricks.com>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../org/apache/spark/sql/ClientE2ETestSuite.scala  |   3 +-
 .../SparkConnectFetchErrorDetailsHandler.scala     |   6 +-
 .../spark/sql/connect/utils/ErrorUtils.scala       |  14 ++
 .../service/FetchErrorDetailsHandlerSuite.scala    |  14 +-
 .../service/SparkConnectSessionHolderSuite.scala   | 102 ++++++------
 python/pyspark/errors/exceptions/base.py           |   2 +-
 python/pyspark/errors/exceptions/captured.py       |   2 +-
 python/pyspark/errors/exceptions/connect.py        | 178 ++++++++++++++++++---
 python/pyspark/sql/connect/client/core.py          |  13 +-
 .../sql/tests/connect/test_connect_basic.py        |  25 +--
 10 files changed, 258 insertions(+), 101 deletions(-)

diff --git 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
index b9fa415034c3..10c928f13041 100644
--- 
a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
+++ 
b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/ClientE2ETestSuite.scala
@@ -136,8 +136,7 @@ class ClientE2ETestSuite extends RemoteSparkSession with 
SQLHelper with PrivateM
         assert(
           ex.getStackTrace
             
.find(_.getClassName.contains("org.apache.spark.sql.catalyst.analysis.CheckAnalysis"))
-            .isDefined
-            == isServerStackTraceEnabled)
+            .isDefined)
       }
     }
   }
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
index 17a6e9e434f3..b5a3c986d169 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectFetchErrorDetailsHandler.scala
@@ -20,9 +20,7 @@ import io.grpc.stub.StreamObserver
 
 import org.apache.spark.connect.proto
 import org.apache.spark.connect.proto.FetchErrorDetailsResponse
-import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.utils.ErrorUtils
-import org.apache.spark.sql.internal.SQLConf
 
 /**
  * Handles [[proto.FetchErrorDetailsRequest]]s for the 
[[SparkConnectService]]. The handler
@@ -46,9 +44,7 @@ class SparkConnectFetchErrorDetailsHandler(
 
         ErrorUtils.throwableToFetchErrorDetailsResponse(
           st = error,
-          serverStackTraceEnabled = sessionHolder.session.conf.get(
-            Connect.CONNECT_SERVER_STACKTRACE_ENABLED) || 
sessionHolder.session.conf.get(
-            SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED))
+          serverStackTraceEnabled = true)
       }
       .getOrElse(FetchErrorDetailsResponse.newBuilder().build())
 
diff --git 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
index 744fa3c8aa1a..7cb555ca47ec 100644
--- 
a/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
+++ 
b/connector/connect/server/src/main/scala/org/apache/spark/sql/connect/utils/ErrorUtils.scala
@@ -164,6 +164,20 @@ private[connect] object ErrorUtils extends Logging {
         "classes",
         
JsonMethods.compact(JsonMethods.render(allClasses(st.getClass).map(_.getName))))
 
+    // Add the SQL State and Error Class to the response metadata of the 
ErrorInfoObject.
+    st match {
+      case e: SparkThrowable =>
+        val state = e.getSqlState
+        if (state != null && state.nonEmpty) {
+          errorInfo.putMetadata("sqlState", state)
+        }
+        val errorClass = e.getErrorClass
+        if (errorClass != null && errorClass.nonEmpty) {
+          errorInfo.putMetadata("errorClass", errorClass)
+        }
+      case _ =>
+    }
+
     if 
(sessionHolderOpt.exists(_.session.conf.get(Connect.CONNECT_ENRICH_ERROR_ENABLED)))
 {
       // Generate a new unique key for this exception.
       val errorId = UUID.randomUUID().toString
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
index 40439a217230..ebcd1de60057 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/FetchErrorDetailsHandlerSuite.scala
@@ -103,15 +103,11 @@ class FetchErrorDetailsHandlerSuite extends 
SharedSparkSession with ResourceHelp
         assert(response.getErrors(1).getErrorTypeHierarchy(1) == 
classOf[Throwable].getName)
         assert(response.getErrors(1).getErrorTypeHierarchy(2) == 
classOf[Object].getName)
         assert(!response.getErrors(1).hasCauseIdx)
-        if (serverStacktraceEnabled) {
-          assert(response.getErrors(0).getStackTraceCount == 
testError.getStackTrace.length)
-          assert(
-            response.getErrors(1).getStackTraceCount ==
-              testError.getCause.getStackTrace.length)
-        } else {
-          assert(response.getErrors(0).getStackTraceCount == 0)
-          assert(response.getErrors(1).getStackTraceCount == 0)
-        }
+        assert(response.getErrors(0).getStackTraceCount == 
testError.getStackTrace.length)
+        assert(
+          response.getErrors(1).getStackTraceCount ==
+            testError.getCause.getStackTrace.length)
+
       } finally {
         
sessionHolder.session.conf.unset(Connect.CONNECT_SERVER_STACKTRACE_ENABLED.key)
         
sessionHolder.session.conf.unset(SQLConf.PYSPARK_JVM_STACKTRACE_ENABLED.key)
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 910c2a2650c6..9845cee31037 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -169,6 +169,56 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
       accumulator = null)
   }
 
+  test("python listener process: process terminates after listener is 
removed") {
+    // scalastyle:off assume
+    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
+    // scalastyle:on assume
+
+    val sessionHolder = SessionHolder.forTesting(spark)
+    try {
+      SparkConnectService.start(spark.sparkContext)
+
+      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
+
+      val id1 = "listener_removeListener_test_1"
+      val id2 = "listener_removeListener_test_2"
+      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
+
+      sessionHolder.cacheListenerById(id1, listener1)
+      spark.streams.addListener(listener1)
+      sessionHolder.cacheListenerById(id2, listener2)
+      spark.streams.addListener(listener2)
+
+      val (runner1, runner2) = (listener1.runner, listener2.runner)
+
+      // assert both python processes are running
+      assert(!runner1.isWorkerStopped().get)
+      assert(!runner2.isWorkerStopped().get)
+
+      // remove listener1
+      spark.streams.removeListener(listener1)
+      sessionHolder.removeCachedListener(id1)
+      // assert listener1's python process is not running
+      eventually(timeout(30.seconds)) {
+        assert(runner1.isWorkerStopped().get)
+        assert(!runner2.isWorkerStopped().get)
+      }
+
+      // remove listener2
+      spark.streams.removeListener(listener2)
+      sessionHolder.removeCachedListener(id2)
+      eventually(timeout(30.seconds)) {
+        // assert listener2's python process is not running
+        assert(runner2.isWorkerStopped().get)
+        // all listeners are removed
+        assert(spark.streams.listListeners().isEmpty)
+      }
+    } finally {
+      SparkConnectService.stop()
+    }
+  }
+
   test("python foreachBatch process: process terminates after query is 
stopped") {
     // scalastyle:off assume
     assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
@@ -232,58 +282,10 @@ class SparkConnectSessionHolderSuite extends 
SharedSparkSession {
       assert(spark.streams.listListeners().length == 1) // only process 
termination listener
     } finally {
       SparkConnectService.stop()
+      // Wait for things to calm down.
+      Thread.sleep(4.seconds.toMillis)
       // remove process termination listener
       spark.streams.listListeners().foreach(spark.streams.removeListener)
     }
   }
-
-  test("python listener process: process terminates after listener is 
removed") {
-    // scalastyle:off assume
-    assume(IntegratedUDFTestUtils.shouldTestPandasUDFs)
-    // scalastyle:on assume
-
-    val sessionHolder = SessionHolder.forTesting(spark)
-    try {
-      SparkConnectService.start(spark.sparkContext)
-
-      val pythonFn = 
dummyPythonFunction(sessionHolder)(streamingQueryListenerFunction)
-
-      val id1 = "listener_removeListener_test_1"
-      val id2 = "listener_removeListener_test_2"
-      val listener1 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
-      val listener2 = new PythonStreamingQueryListener(pythonFn, sessionHolder)
-
-      sessionHolder.cacheListenerById(id1, listener1)
-      spark.streams.addListener(listener1)
-      sessionHolder.cacheListenerById(id2, listener2)
-      spark.streams.addListener(listener2)
-
-      val (runner1, runner2) = (listener1.runner, listener2.runner)
-
-      // assert both python processes are running
-      assert(!runner1.isWorkerStopped().get)
-      assert(!runner2.isWorkerStopped().get)
-
-      // remove listener1
-      spark.streams.removeListener(listener1)
-      sessionHolder.removeCachedListener(id1)
-      // assert listener1's python process is not running
-      eventually(timeout(30.seconds)) {
-        assert(runner1.isWorkerStopped().get)
-        assert(!runner2.isWorkerStopped().get)
-      }
-
-      // remove listener2
-      spark.streams.removeListener(listener2)
-      sessionHolder.removeCachedListener(id2)
-      eventually(timeout(30.seconds)) {
-        // assert listener2's python process is not running
-        assert(runner2.isWorkerStopped().get)
-        // all listeners are removed
-        assert(spark.streams.listListeners().isEmpty)
-      }
-    } finally {
-      SparkConnectService.stop()
-    }
-  }
 }
diff --git a/python/pyspark/errors/exceptions/base.py 
b/python/pyspark/errors/exceptions/base.py
index 1d09a68dffbf..518a2d99ce88 100644
--- a/python/pyspark/errors/exceptions/base.py
+++ b/python/pyspark/errors/exceptions/base.py
@@ -75,7 +75,7 @@ class PySparkException(Exception):
         """
         return self.message_parameters
 
-    def getSqlState(self) -> None:
+    def getSqlState(self) -> Optional[str]:
         """
         Returns an SQLSTATE as a string.
 
diff --git a/python/pyspark/errors/exceptions/captured.py 
b/python/pyspark/errors/exceptions/captured.py
index d62b7d24347e..55ed7ab3a6d5 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -107,7 +107,7 @@ class CapturedException(PySparkException):
         else:
             return None
 
-    def getSqlState(self) -> Optional[str]:  # type: ignore[override]
+    def getSqlState(self) -> Optional[str]:
         assert SparkContext._gateway is not None
         gw = SparkContext._gateway
         if self._origin is not None and is_instance_of(
diff --git a/python/pyspark/errors/exceptions/connect.py 
b/python/pyspark/errors/exceptions/connect.py
index 423fb2c6f0ac..2558c425469a 100644
--- a/python/pyspark/errors/exceptions/connect.py
+++ b/python/pyspark/errors/exceptions/connect.py
@@ -46,55 +46,155 @@ class SparkConnectException(PySparkException):
 
 
 def convert_exception(
-    info: "ErrorInfo", truncated_message: str, resp: 
Optional[pb2.FetchErrorDetailsResponse]
+    info: "ErrorInfo",
+    truncated_message: str,
+    resp: Optional[pb2.FetchErrorDetailsResponse],
+    display_server_stacktrace: bool = False,
 ) -> SparkConnectException:
     classes = []
+    sql_state = None
+    error_class = None
+
+    stacktrace: Optional[str] = None
+
     if "classes" in info.metadata:
         classes = json.loads(info.metadata["classes"])
 
+    if "sqlState" in info.metadata:
+        sql_state = info.metadata["sqlState"]
+
+    if "errorClass" in info.metadata:
+        error_class = info.metadata["errorClass"]
+
     if resp is not None and resp.HasField("root_error_idx"):
         message = resp.errors[resp.root_error_idx].message
         stacktrace = _extract_jvm_stacktrace(resp)
     else:
         message = truncated_message
-        stacktrace = info.metadata["stackTrace"] if "stackTrace" in 
info.metadata else ""
-
-    if len(stacktrace) > 0:
-        message += f"\n\nJVM stacktrace:\n{stacktrace}"
+        stacktrace = info.metadata["stackTrace"] if "stackTrace" in 
info.metadata else None
+        display_server_stacktrace = display_server_stacktrace if stacktrace is 
not None else False
 
     if "org.apache.spark.sql.catalyst.parser.ParseException" in classes:
-        return ParseException(message)
+        return ParseException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     # Order matters. ParseException inherits AnalysisException.
     elif "org.apache.spark.sql.AnalysisException" in classes:
-        return AnalysisException(message)
+        return AnalysisException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "org.apache.spark.sql.streaming.StreamingQueryException" in classes:
-        return StreamingQueryException(message)
+        return StreamingQueryException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "org.apache.spark.sql.execution.QueryExecutionException" in classes:
-        return QueryExecutionException(message)
+        return QueryExecutionException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     # Order matters. NumberFormatException inherits IllegalArgumentException.
     elif "java.lang.NumberFormatException" in classes:
-        return NumberFormatException(message)
+        return NumberFormatException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "java.lang.IllegalArgumentException" in classes:
-        return IllegalArgumentException(message)
+        return IllegalArgumentException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "java.lang.ArithmeticException" in classes:
-        return ArithmeticException(message)
+        return ArithmeticException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "java.lang.UnsupportedOperationException" in classes:
-        return UnsupportedOperationException(message)
+        return UnsupportedOperationException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "java.lang.ArrayIndexOutOfBoundsException" in classes:
-        return ArrayIndexOutOfBoundsException(message)
+        return ArrayIndexOutOfBoundsException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "java.time.DateTimeException" in classes:
-        return DateTimeException(message)
+        return DateTimeException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "org.apache.spark.SparkRuntimeException" in classes:
-        return SparkRuntimeException(message)
+        return SparkRuntimeException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "org.apache.spark.SparkUpgradeException" in classes:
-        return SparkUpgradeException(message)
+        return SparkUpgradeException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     elif "org.apache.spark.api.python.PythonException" in classes:
         return PythonException(
             "\n  An exception was thrown from the Python worker. "
             "Please see the stack trace below.\n%s" % message
         )
+    # Make sure that the generic SparkException is handled last.
+    elif "org.apache.spark.SparkException" in classes:
+        return SparkException(
+            message,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
     else:
-        return SparkConnectGrpcException(message, reason=info.reason)
+        return SparkConnectGrpcException(
+            message,
+            reason=info.reason,
+            error_class=error_class,
+            sql_state=sql_state,
+            server_stacktrace=stacktrace,
+            display_server_stacktrace=display_server_stacktrace,
+        )
 
 
 def _extract_jvm_stacktrace(resp: pb2.FetchErrorDetailsResponse) -> str:
@@ -106,7 +206,7 @@ def _extract_jvm_stacktrace(resp: 
pb2.FetchErrorDetailsResponse) -> str:
     def format_stacktrace(error: pb2.FetchErrorDetailsResponse.Error) -> None:
         message = f"{error.error_type_hierarchy[0]}: {error.message}"
         if len(lines) == 0:
-            lines.append(message)
+            lines.append(error.error_type_hierarchy[0])
         else:
             lines.append(f"Caused by: {message}")
         for elem in error.stack_trace:
@@ -135,16 +235,48 @@ class SparkConnectGrpcException(SparkConnectException):
         error_class: Optional[str] = None,
         message_parameters: Optional[Dict[str, str]] = None,
         reason: Optional[str] = None,
+        sql_state: Optional[str] = None,
+        server_stacktrace: Optional[str] = None,
+        display_server_stacktrace: bool = False,
     ) -> None:
         self.message = message  # type: ignore[assignment]
         if reason is not None:
             self.message = f"({reason}) {self.message}"
 
+        # PySparkException has the assumption that error_class and 
message_parameters are
+        # only occurring together. If only one is set, we assume the message 
to be fully
+        # parsed.
+        tmp_error_class = error_class
+        tmp_message_parameters = message_parameters
+        if error_class is not None and message_parameters is None:
+            tmp_error_class = None
+        elif error_class is None and message_parameters is not None:
+            tmp_message_parameters = None
+
         super().__init__(
             message=self.message,
-            error_class=error_class,
-            message_parameters=message_parameters,
+            error_class=tmp_error_class,
+            message_parameters=tmp_message_parameters,
         )
+        self.error_class = error_class
+        self._sql_state: Optional[str] = sql_state
+        self._stacktrace: Optional[str] = server_stacktrace
+        self._display_stacktrace: bool = display_server_stacktrace
+
+    def getSqlState(self) -> Optional[str]:
+        if self._sql_state is not None:
+            return self._sql_state
+        else:
+            return super().getSqlState()
+
+    def getStackTrace(self) -> Optional[str]:
+        return self._stacktrace
+
+    def __str__(self) -> str:
+        desc = self.message
+        if self._display_stacktrace:
+            desc += "\n\nJVM stacktrace:\n%s" % self._stacktrace
+        return desc
 
 
 class AnalysisException(SparkConnectGrpcException, BaseAnalysisException):
@@ -223,3 +355,7 @@ class SparkUpgradeException(SparkConnectGrpcException, 
BaseSparkUpgradeException
     """
     Exception thrown because of Spark upgrade from Spark Connect.
     """
+
+
+class SparkException(SparkConnectGrpcException):
+    """ """
diff --git a/python/pyspark/sql/connect/client/core.py 
b/python/pyspark/sql/connect/client/core.py
index 11a1112ad1fe..cef0ea4f305d 100644
--- a/python/pyspark/sql/connect/client/core.py
+++ b/python/pyspark/sql/connect/client/core.py
@@ -1564,6 +1564,14 @@ class SparkConnectClient(object):
         except grpc.RpcError:
             return None
 
+    def _display_server_stack_trace(self) -> bool:
+        from pyspark.sql.connect.conf import RuntimeConf
+
+        conf = RuntimeConf(self)
+        if conf.get("spark.sql.connect.serverStacktrace.enabled") == "true":
+            return True
+        return conf.get("spark.sql.pyspark.jvmStacktrace.enabled") == "true"
+
     def _handle_rpc_error(self, rpc_error: grpc.RpcError) -> NoReturn:
         """
         Error handling helper for dealing with GRPC Errors. On the server 
side, certain
@@ -1594,7 +1602,10 @@ class SparkConnectClient(object):
                     d.Unpack(info)
 
                     raise convert_exception(
-                        info, status.message, self._fetch_enriched_error(info)
+                        info,
+                        status.message,
+                        self._fetch_enriched_error(info),
+                        self._display_server_stack_trace(),
                     ) from None
 
             raise SparkConnectGrpcException(status.message) from None
diff --git a/python/pyspark/sql/tests/connect/test_connect_basic.py 
b/python/pyspark/sql/tests/connect/test_connect_basic.py
index f024a03c2686..daf6772e52bf 100755
--- a/python/pyspark/sql/tests/connect/test_connect_basic.py
+++ b/python/pyspark/sql/tests/connect/test_connect_basic.py
@@ -3378,35 +3378,37 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
                         """select from_json(
                             '{"d": "02-29"}', 'd date', map('dateFormat', 
'MM-dd'))"""
                     ).collect()
-                self.assertTrue("JVM stacktrace" in e.exception.message)
-                self.assertTrue("org.apache.spark.SparkUpgradeException:" in 
e.exception.message)
+                self.assertTrue("JVM stacktrace" in str(e.exception))
+                self.assertTrue("org.apache.spark.SparkUpgradeException" in 
str(e.exception))
                 self.assertTrue(
                     "at org.apache.spark.sql.errors.ExecutionErrors"
-                    ".failToParseDateTimeInNewParserError" in 
e.exception.message
+                    ".failToParseDateTimeInNewParserError" in str(e.exception)
                 )
-                self.assertTrue("Caused by: java.time.DateTimeException:" in 
e.exception.message)
+                self.assertTrue("Caused by: java.time.DateTimeException:" in 
str(e.exception))
 
     def test_not_hitting_netty_header_limit(self):
         with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": True}):
             with self.assertRaises(AnalysisException):
-                self.spark.sql("select " + "test" * 10000).collect()
+                self.spark.sql("select " + "test" * 1).collect()
 
     def test_error_stack_trace(self):
         with self.sql_conf({"spark.sql.connect.enrichError.enabled": False}):
             with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": 
True}):
                 with self.assertRaises(AnalysisException) as e:
                     self.spark.sql("select x").collect()
-                self.assertTrue("JVM stacktrace" in e.exception.message)
+                self.assertTrue("JVM stacktrace" in str(e.exception))
+                self.assertIsNotNone(e.exception.getStackTrace())
                 self.assertTrue(
-                    "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" 
in e.exception.message
+                    "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" 
in str(e.exception)
                 )
 
             with self.sql_conf({"spark.sql.pyspark.jvmStacktrace.enabled": 
False}):
                 with self.assertRaises(AnalysisException) as e:
                     self.spark.sql("select x").collect()
-                self.assertFalse("JVM stacktrace" in e.exception.message)
+                self.assertFalse("JVM stacktrace" in str(e.exception))
+                self.assertIsNone(e.exception.getStackTrace())
                 self.assertFalse(
-                    "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" 
in e.exception.message
+                    "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" 
in str(e.exception)
                 )
 
         # Create a new session with a different stack trace size.
@@ -3421,9 +3423,10 @@ class SparkConnectSessionTests(ReusedConnectTestCase):
         spark.conf.set("spark.sql.pyspark.jvmStacktrace.enabled", True)
         with self.assertRaises(AnalysisException) as e:
             spark.sql("select x").collect()
-        self.assertTrue("JVM stacktrace" in e.exception.message)
+        self.assertTrue("JVM stacktrace" in str(e.exception))
+        self.assertIsNotNone(e.exception.getStackTrace())
         self.assertFalse(
-            "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in 
e.exception.message
+            "at org.apache.spark.sql.catalyst.analysis.CheckAnalysis" in 
str(e.exception)
         )
         spark.stop()
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to