This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 78d9eae49a82 [SPARK-54599][PYTHON] Refactor PythonException so it can take errorClass with sqlstate
78d9eae49a82 is described below

commit 78d9eae49a8228bf9af36de1d59e3ade3295d88e
Author: Tian Gao <[email protected]>
AuthorDate: Wed Feb 4 11:11:05 2026 +0800

    [SPARK-54599][PYTHON] Refactor PythonException so it can take errorClass with sqlstate
    
    ### What changes were proposed in this pull request?
    
    In general, we want `PythonException` to carry a `sqlstate` so that it can be categorized better. This PR rewrites `PythonException` in the `SparkPythonException` style. For now we report `38000` for all existing exceptions, but we leave the flexibility for future changes.
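A rough, hypothetical sketch of how a condition such as the new `PYTHON_EXCEPTION` entry (message template `<msg>: <traceback>`, `sqlState` `38000`) turns into the `[CONDITION] message SQLSTATE: code` text shown further down. `ERROR_CONDITIONS` and `render_error` are illustrative names that only approximate the layout; the real message is built on the JVM side by `SparkThrowableHelper.getMessage`.

```python
from typing import Dict

# Hypothetical copy of the PYTHON_EXCEPTION entry from error-conditions.json.
ERROR_CONDITIONS = {
    "PYTHON_EXCEPTION": {
        "message": ["<msg>: <traceback>"],
        "sqlState": "38000",
    },
}


def render_error(condition: str, params: Dict[str, str]) -> str:
    """Roughly mimic the '[CONDITION] message SQLSTATE: code' layout.

    Illustrative only; this is not Spark's SparkThrowableHelper.
    """
    entry = ERROR_CONDITIONS[condition]
    message = "\n".join(entry["message"])
    for name, value in params.items():
        # Replace each <name> placeholder with its parameter value.
        message = message.replace(f"<{name}>", value)
    return f"[{condition}] {message} SQLSTATE: {entry['sqlState']}"


print(render_error(
    "PYTHON_EXCEPTION",
    {"msg": "An exception was thrown from the Python worker",
     "traceback": "ValueError"},
))
# [PYTHON_EXCEPTION] An exception was thrown from the Python worker: ValueError SQLSTATE: 38000
```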
    
    I also want to take this chance to make the user-facing exception output a bit more Pythonic.
    
    ### Why are the changes needed?
    
    Many reported `PythonException`s can't be categorized by `sqlstate`, which makes them hard to triage and to build metrics on.
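To illustrate the triage and metrics point, a minimal sketch that buckets failures by their two-character SQLSTATE class (class `38` is the SQL standard's external routine exception class). The failure records and `count_by_sqlstate_class` are hypothetical; failures without a `sqlstate` fall into an `uncategorized` bucket, which is where uncategorized `PythonException`s would end up.

```python
from collections import Counter
from typing import Iterable, Optional, Tuple

# Hypothetical failure record: (exception class name, sqlstate or None).
Failure = Tuple[str, Optional[str]]


def count_by_sqlstate_class(failures: Iterable[Failure]) -> Counter:
    """Count failures per SQLSTATE class (the first two characters)."""
    counts: Counter = Counter()
    for _name, sqlstate in failures:
        # Failures that carry no sqlstate cannot be classified.
        key = sqlstate[:2] if sqlstate else "uncategorized"
        counts[key] += 1
    return counts


before = [("PythonException", None), ("PythonException", None),
          ("SparkArithmeticException", "22012")]
after = [("PythonException", "38000"), ("PythonException", "38000"),
         ("SparkArithmeticException", "22012")]

print(count_by_sqlstate_class(before))  # Counter({'uncategorized': 2, '22': 1})
print(count_by_sqlstate_class(after))   # Counter({'38': 2, '22': 1})
```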
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Some `Py4JJavaError`s, namely those raised by Python workers, will now become `PythonException`s. I think that's an improvement.
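Where that change comes from is visible in `_convert_exception` in the diff below: the old code only inspected the cause of the Java exception and additionally required a stack frame from `org.apache.spark.sql.execution.python`, while the new code matches the exception itself against `org.apache.spark.api.python.PythonException`. A rough model of both predicates follows; `FakeJavaError` and the frame name `org.example.Caller.run` are hypothetical, and plain class-name equality stands in for Py4J's `is_instance_of`, which also accepts subclasses.

```python
from dataclasses import dataclass, field
from typing import List, Optional

PYTHON_EXCEPTION = "org.apache.spark.api.python.PythonException"


@dataclass
class FakeJavaError:
    """Hypothetical stand-in for a Java exception seen through Py4J."""

    class_name: str
    stack: List[str] = field(default_factory=list)
    cause: Optional["FakeJavaError"] = None


def old_rule(e: FakeJavaError) -> bool:
    # Old: look at the cause only, and accept it only if its stack trace
    # went through the SQL Python execution package (Python UDFs).
    c = e.cause
    return (
        c is not None
        and c.class_name == PYTHON_EXCEPTION
        and any("org.apache.spark.sql.execution.python" in f for f in c.stack)
    )


def new_rule(e: FakeJavaError) -> bool:
    # New: match the exception itself by class.
    return e.class_name == PYTHON_EXCEPTION


# A worker failure reported directly, with a hypothetical unrelated frame.
direct = FakeJavaError(PYTHON_EXCEPTION, stack=["org.example.Caller.run"])
print(old_rule(direct), new_rule(direct))  # False True
```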
    
    Also, the exception report will be different.
    
    code:
    ```python
    from pyspark.sql import SparkSession
    from pyspark.sql import functions as F
    
    spark = SparkSession.builder.getOrCreate()
    
    @F.udf
    def add_one(x):
        raise ValueError()
    
    df = spark.createDataFrame([(1,), (2,), (3,)], ["value"])
    df = df.withColumn("value_plus_one", add_one("value"))
    
    df.show()
    ```
    
    Current traceback:
    
    ```
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 15, in <module>
        df.show()
      File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 285, in show
        print(self._show_string(n, truncate, vertical))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 303, in _show_string
        return self._jdf.showString(n, 20, vertical)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1362, in __call__
        return_value = get_return_value(
                       ^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
        raise converted from None
      File "/Users/gaotian/programs/spark/python/example.py", line 9, in add_one
        raise ValueError()
          ^^^^^^^^^^^^^^^^^
    pyspark.errors.exceptions.captured.PythonException:
      An exception was thrown from the Python worker. Please see the stack trace below.
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
    ```
    
    Without any formatting changes, the worker part of the traceback would look like this:
    
    ```
    pyspark.errors.exceptions.captured.PythonException: [PYTHON_EXCEPTION] Python worker failed with the following error:
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
     SQLSTATE: 38000
    ```
    
    I think moving the exception explanation right after the exception type is good, since that is more intuitive to Python users than starting it on a new line. But that line is too long, so I think `PythonException` should have a special `__str__` method instead of using the default `getMessage()`.
    
    Option 1: Remove the errorClass and SQLSTATE, because end users are not that interested in them.
    
    ```
    pyspark.errors.exceptions.captured.PythonException: Python worker failed with the following error:
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
    ```
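The special `__str__` proposed above yields this Option 1 layout; `captured.py` gains it in the diff below. A minimal sketch that runs without Spark puts the same logic on hypothetical stubs: `StubCapturedException` stands in for PySpark's `CapturedException` and only stores a fallback message and the message parameters. The `.strip()` drops the traceback's trailing newline, which appears to be why the test regexes in this commit lose their trailing `\n`.

```python
from typing import Dict, Optional


class StubCapturedException(Exception):
    """Hypothetical stand-in for pyspark's CapturedException."""

    def __init__(self, message: str, params: Optional[Dict[str, str]] = None):
        super().__init__(message)
        self._params = params

    def getMessageParameters(self) -> Optional[Dict[str, str]]:
        return self._params


class StubPythonException(StubCapturedException):
    def __str__(self) -> str:
        params = self.getMessageParameters()
        # Fall back to the default message if either parameter is missing.
        if params is None or "msg" not in params or "traceback" not in params:
            return super().__str__()
        # Explanation first, then the worker traceback without the
        # trailing newline.
        return f"{params['msg']}:\n{params['traceback'].strip()}"


worker_tb = "Traceback (most recent call last):\n  ...\nValueError\n"
err = StubPythonException(
    "[PYTHON_EXCEPTION] ... SQLSTATE: 38000",
    {"msg": "An exception was thrown from the Python worker",
     "traceback": worker_tb},
)
print(str(err))
print(str(StubPythonException("fallback message")))  # fallback message
```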
    
    Option 2: Do some magic to print the exception that users are supposed to catch, instead of exposing our implementation details.
    
    ```
    pyspark.errors.PythonException: Python worker failed with the following error:
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
    ```
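Python prints the last line of a traceback as `module.QualifiedName: message`, taking the module from the class's `__module__` attribute (and omitting it for `__main__` and `builtins`). One plain-Python way to get the Option 2 effect is therefore to point `__module__` at the public import path. This is only an illustration; the diff in this commit does not show how PySpark produces the `pyspark.errors.PythonException` name, and `WorkerError` and the `mylib.errors` module names are hypothetical.

```python
import traceback


class WorkerError(Exception):
    """Hypothetical exception class defined in an internal module."""


def last_line(exc: BaseException) -> str:
    # format_exception_only renders the final "module.Class: message" line
    # that appears at the bottom of a printed traceback.
    return "".join(traceback.format_exception_only(type(exc), exc)).strip()


# Pretend the class was defined in an internal module.
WorkerError.__module__ = "mylib.errors.internal"
print(last_line(WorkerError("boom")))  # mylib.errors.internal.WorkerError: boom

# Re-point it at the public import path that users are told to catch.
WorkerError.__module__ = "mylib.errors"
print(last_line(WorkerError("boom")))  # mylib.errors.WorkerError: boom
```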
    
    Option 3: We actually print two tracebacks from the worker, which is weird. I think we should keep just one: we should not artificially combine the tracebacks from the driver and the worker.
    
    ```
      File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
        raise converted from None
      File "/Users/gaotian/programs/spark/python/example.py", line 9, in add_one -> THIS SHOULD BE REMOVED!
        raise ValueError()
          ^^^^^^^^^^^^^^^^^
    pyspark.errors.PythonException: Python worker failed with the following error:
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
    ```
    
    New output:
    
    ```
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 15, in <module>
        df.show()
      File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 285, in show
        print(self._show_string(n, truncate, vertical))
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 303, in _show_string
        return self._jdf.showString(n, 20, vertical)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1362, in __call__
        return_value = get_return_value(
                       ^^^^^^^^^^^^^^^^^
      File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
        raise converted from None
    pyspark.errors.PythonException: Python worker failed with the following error:
    Traceback (most recent call last):
      File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
        raise ValueError()
    ValueError
    ```
    
    ### How was this patch tested?
    
    CI
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #53099 from gaogaotiantian/add-sqlstate-python-exception.
    
    Authored-by: Tian Gao <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../src/main/resources/error/error-conditions.json |  6 ++++
 .../org/apache/spark/api/python/PythonRDD.scala    | 30 ++++++++++++++++--
 .../org/apache/spark/api/python/PythonRunner.scala |  8 +++--
 python/pyspark/errors/exceptions/captured.py       | 37 +++++++++++-----------
 python/pyspark/sql/context.py                      |  2 +-
 .../sql/tests/arrow/test_arrow_cogrouped_map.py    |  4 +--
 .../sql/tests/arrow/test_arrow_grouped_map.py      |  4 +--
 .../sql/tests/pandas/test_pandas_cogrouped_map.py  |  6 ++--
 .../sql/tests/pandas/test_pandas_grouped_map.py    | 10 +++---
 python/pyspark/sql/tests/pandas/test_pandas_map.py | 16 +++++-----
 .../sql/tests/streaming/test_streaming_foreach.py  |  1 -
 .../streaming/test_streaming_foreach_batch.py      |  2 --
 .../sql/tests/streaming/test_streaming_listener.py |  4 ++-
 .../planner/StreamingForeachBatchHelper.scala      | 10 +++---
 .../planner/StreamingQueryListenerHelper.scala     |  9 +++---
 .../org/apache/spark/sql/SQLQueryTestHelper.scala  |  4 +++
 .../execution/python/PythonDataSourceSuite.scala   |  8 ++---
 17 files changed, 101 insertions(+), 60 deletions(-)

diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json
index 8e3acfdb2c0a..258fcfeda15e 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5262,6 +5262,12 @@
     ],
     "sqlState" : "38000"
   },
+  "PYTHON_EXCEPTION" : {
+    "message" : [
+      "<msg>: <traceback>"
+    ],
+    "sqlState" : "38000"
+  },
   "PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : {
     "message" : [
       "Failed when Python streaming data source perform <action>: <msg>"
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index cf0169fed60c..45bf30675148 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -126,8 +126,34 @@ private[spark] case class SimplePythonFunction(
 private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
 
 /** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Throwable)
-  extends RuntimeException(msg, cause)
+private[spark] class PythonException(
+    msg: String,
+    cause: Throwable,
+    errorClass: Option[String],
+    messageParameters: Map[String, String],
+    context: Array[QueryContext])
+  extends RuntimeException(msg, cause) with SparkThrowable {
+
+  def this(
+      errorClass: String,
+      messageParameters: Map[String, String],
+      cause: Throwable = null,
+      context: Array[QueryContext] = Array.empty,
+      summary: String = "") = {
+    this(
+      SparkThrowableHelper.getMessage(errorClass, messageParameters, summary),
+      cause,
+      Option(errorClass),
+      messageParameters,
+      context
+    )
+  }
+
+  override def getMessageParameters: java.util.Map[String, String] = messageParameters.asJava
+
+  override def getCondition: String = errorClass.orNull
+  override def getQueryContext: Array[QueryContext] = context
+}
 
 /**
  * Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from Python.
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index c3ee3853ce0f..6bc64f54d3a2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -659,8 +659,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
 
     protected def handlePythonException(): PythonException = {
       // Signals that an exception has been thrown in python
-      val msg = PythonWorkerUtils.readUTF(stream)
-      new PythonException(msg, writer.exception.orNull)
+      val traceback = PythonWorkerUtils.readUTF(stream)
+      val msg = "An exception was thrown from the Python worker"
+      new PythonException(
+        errorClass = "PYTHON_EXCEPTION",
+        messageParameters = Map("msg" -> msg, "traceback" -> traceback),
+        cause = writer.exception.orNull)
     }
 
     protected def handleEndOfDataSection(): Unit = {
diff --git a/python/pyspark/errors/exceptions/captured.py b/python/pyspark/errors/exceptions/captured.py
index 0f76e3b5f6a0..adde27bf8a5d 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -234,25 +234,13 @@ def _convert_exception(e: "Py4JJavaError") -> CapturedException:
         return SparkUpgradeException(origin=e)
     elif is_instance_of(gw, e, "org.apache.spark.SparkNoSuchElementException"):
         return SparkNoSuchElementException(origin=e)
-
-    c: "Py4JJavaError" = e.getCause()
-    stacktrace: str = getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e)
-    if c is not None and (
-        is_instance_of(gw, c, "org.apache.spark.api.python.PythonException")
-        # To make sure this only catches Python UDFs.
-        and any(
-            map(
-                lambda v: "org.apache.spark.sql.execution.python" in v.toString(), c.getStackTrace()
-            )
-        )
-    ):
-        msg = (
-            "\n  An exception was thrown from the Python worker. "
-            "Please see the stack trace below.\n%s" % c.getMessage()
-        )
-        return PythonException(msg, stacktrace)
-
-    return UnknownException(desc=e.toString(), stackTrace=stacktrace, cause=c)
+    elif is_instance_of(gw, e, "org.apache.spark.api.python.PythonException"):
+        return PythonException(origin=e)
+    return UnknownException(
+        desc=e.toString(),
+        stackTrace=getattr(jvm, "org.apache.spark.util.Utils").exceptionString(e),
+        cause=e.getCause(),
+    )
 
 
 def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
@@ -348,6 +336,17 @@ class PythonException(CapturedException, BasePythonException):
     Exceptions thrown from Python workers.
     """
 
+    def __str__(self) -> str:
+        messageParameters = self.getMessageParameters()
+
+        if (
+            messageParameters is None
+            or "msg" not in messageParameters
+            or "traceback" not in messageParameters
+        ):
+            return super().__str__()
+        return f"{messageParameters['msg']}:\n{messageParameters['traceback'].strip()}"
+
 
 class ArithmeticException(CapturedException, BaseArithmeticException):
     """
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c51b9e063a77..34f488244d78 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -472,7 +472,7 @@ class SQLContext:
         >>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest: +IGNORE_EXCEPTION_DETAIL
         Traceback (most recent call last):
             ...
-        Py4JJavaError: ...
+        pyspark.errors.exceptions.captured.PythonException: ...
         """
         return self.sparkSession.createDataFrame(  # type: ignore[call-overload]
             data, schema, samplingRatio, verifySchema
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 98362a44d3eb..45c8a455a44a 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -196,7 +196,7 @@ class CogroupedMapInArrowTestsMixin:
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
-                "Missing: m. Unexpected: v, v2.\n",
+                "Missing: m. Unexpected: v, v2.",
             ):
                 # stats returns three columns while here we set schema with two columns
                 self.cogrouped.applyInArrow(stats, schema="id long, m double").collect()
@@ -232,7 +232,7 @@ class CogroupedMapInArrowTestsMixin:
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
-                "Missing: m.\n",
+                "Missing: m.",
             ):
                 # stats returns one column for even keys while here we set schema with two columns
                 self.cogrouped.applyInArrow(odd_means, schema="id long, m double").collect()
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 9c1b14676ecc..d34c31220be6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -224,7 +224,7 @@ class ApplyInArrowTestsMixin:
                 with self.assertRaisesRegex(
                     PythonException,
                     "Column names of the returned pyarrow.Table do not match specified schema. "
-                    "Missing: m. Unexpected: v, v2.\n",
+                    "Missing: m. Unexpected: v, v2.",
                 ):
                     # stats returns three columns while here we set schema with two columns
                     df.groupby("id").applyInArrow(
@@ -265,7 +265,7 @@ class ApplyInArrowTestsMixin:
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pyarrow.Table do not match specified schema. "
-                "Missing: m.\n",
+                "Missing: m.",
             ):
                 # stats returns one column for even keys while here we set schema with two columns
                 df.groupby("id").applyInArrow(odd_means, schema="id long, m double").collect()
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index aa0a8479b3cd..88b5b9138f71 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -209,7 +209,7 @@ class CogroupedApplyInPandasTestsMixin:
             fn=merge_pandas,
             errorClass=PythonException,
             error_message_regex="Column names of the returned pandas.DataFrame "
-            "do not match specified schema. Unexpected: add, more.\n",
+            "do not match specified schema. Unexpected: add, more.",
         )
 
     def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
@@ -230,7 +230,7 @@ class CogroupedApplyInPandasTestsMixin:
             fn=merge_pandas,
             errorClass=PythonException,
             error_message_regex="Number of columns of the returned pandas.DataFrame "
-            "doesn't match specified schema. Expected: 4 Actual: 6\n",
+            "doesn't match specified schema. Expected: 4 Actual: 6",
         )
 
     def test_apply_in_pandas_returning_empty_dataframe(self):
@@ -277,7 +277,7 @@ class CogroupedApplyInPandasTestsMixin:
                 with self.subTest(convert="double to string"):
                     expected = (
                         r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
-                        r"with name 'k' to Arrow Array \(string\).\n"
+                        r"with name 'k' to Arrow Array \(string\)."
                     )
                     self._test_merge_error(
                         fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k": [2.0]}),
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index efbf1c42b77f..76def82729b9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -305,7 +305,7 @@ class ApplyInPandasTestsMixin:
         with self.assertRaisesRegex(
             PythonException,
             "Column names of the returned pandas.DataFrame do not match specified schema. "
-            "Missing: mean. Unexpected: median, std.\n",
+            "Missing: mean. Unexpected: median, std.",
         ):
             self._test_apply_in_pandas(
                 lambda key, pdf: pd.DataFrame(
@@ -321,7 +321,7 @@ class ApplyInPandasTestsMixin:
         with self.assertRaisesRegex(
             PythonException,
             "Number of columns of the returned pandas.DataFrame doesn't match "
-            "specified schema. Expected: 2 Actual: 3\n",
+            "specified schema. Expected: 2 Actual: 3",
         ):
             self._test_apply_in_pandas(
                 lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(), pdf.v.std())])
@@ -359,7 +359,7 @@ class ApplyInPandasTestsMixin:
                             "can be disabled by using SQL config "
                             "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                         )
-                    with self.assertRaisesRegex(PythonException, expected + "\n"):
+                    with self.assertRaisesRegex(PythonException, expected):
                         self._test_apply_in_pandas(
                             lambda key, pdf: pd.DataFrame([key + ("test_string",)]),
                             output_schema="id long, mean double",
@@ -370,7 +370,7 @@ class ApplyInPandasTestsMixin:
                     with self.assertRaisesRegex(
                         PythonException,
                         r"TypeError: Exception thrown when converting pandas.Series \(float64\) "
-                        r"with name 'mean' to Arrow Array \(string\).\n",
+                        r"with name 'mean' to Arrow Array \(string\).",
                     ):
                         self._test_apply_in_pandas(
                             lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),)]),
@@ -667,7 +667,7 @@ class ApplyInPandasTestsMixin:
             with self.assertRaisesRegex(
                 PythonException,
                 "Column names of the returned pandas.DataFrame do not match "
-                "specified schema. Missing: id. Unexpected: iid.\n",
+                "specified schema. Missing: id. Unexpected: iid.",
             ):
                 grouped_df.apply(column_name_typo).collect()
             with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 1b07ec5630ec..43c41ed54ebd 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -212,7 +212,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
             "Column names of the returned pandas.DataFrame do not match "
-            "specified schema. Missing: id. Unexpected: iid.\n",
+            "specified schema. Missing: id. Unexpected: iid.",
         ):
             (
                 self.spark.range(10, numPartitions=3)
@@ -234,7 +234,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
             "Column names of the returned pandas.DataFrame do not match "
-            "specified schema. Missing: id2.\n",
+            "specified schema. Missing: id2.",
         ):
             (
                 self.spark.range(10, numPartitions=3)
@@ -255,7 +255,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
             "Column names of the returned pandas.DataFrame do not match "
-            "specified schema. Missing: id2.\n",
+            "specified schema. Missing: id2.",
         ):
             f = self.identity_dataframes_iter("id", "value")
             (df.mapInPandas(f, "id int, id2 long, value int").collect())
@@ -264,7 +264,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF\\] "
             "Number of columns of the returned pandas.DataFrame doesn't match "
-            "specified schema. Expected: 3 Actual: 2\n",
+            "specified schema. Expected: 3 Actual: 2",
         ):
             f = self.identity_dataframes_wo_column_names_iter("id", "value")
             (df.mapInPandas(f, "id int, id2 long, value int").collect())
@@ -312,7 +312,7 @@ class MapInPandasTestsMixin:
                             "can be disabled by using SQL config "
                             "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                         )
-                    with self.assertRaisesRegex(PythonException, expected + "\n"):
+                    with self.assertRaisesRegex(PythonException, expected):
                         (
                             self.spark.range(10, numPartitions=3)
                             .mapInPandas(func, "id double")
@@ -339,7 +339,7 @@ class MapInPandasTestsMixin:
                             "can be disabled by using SQL config "
                             "`spark.sql.execution.pandas.convertToArrowArraySafely`."
                         )
-                        with self.assertRaisesRegex(PythonException, expected + "\n"):
+                        with self.assertRaisesRegex(PythonException, expected):
                             df.collect()
                     else:
                         self.assertEqual(
@@ -375,7 +375,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
             "Column names of the returned pandas.DataFrame do not match "
-            "specified schema. Missing: value.\n",
+            "specified schema. Missing: value.",
         ):
             f = self.dataframes_and_empty_dataframe_iter("id")
             (
@@ -404,7 +404,7 @@ class MapInPandasTestsMixin:
             PythonException,
             "PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\] "
             "Column names of the returned pandas.DataFrame do not match "
-            "specified schema. Missing: id. Unexpected: iid.\n",
+            "specified schema. Missing: id. Unexpected: iid.",
         ):
             (
                 self.spark.range(10, numPartitions=3)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
index 9a13c9c8ad0c..39e2bfcc8d5c 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
@@ -223,7 +223,6 @@ class StreamingTestsForeachMixin:
         except StreamingQueryException as e:
             err_msg = str(e)
             self.assertTrue("test error" in err_msg)
-            self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg)
 
         self.assertEqual(len(tester.process_events()), 0)  # no row was processed
         close_events = tester.close_events()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
index cbfc0e0123da..818d8361537d 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
@@ -81,8 +81,6 @@ class StreamingTestsForeachBatchMixin:
         except StreamingQueryException as e:
             err_msg = str(e)
             self.assertTrue("this should fail" in err_msg)
-            # check for foreachBatch error class
-            self.assertTrue("FOREACH_BATCH_USER_FUNCTION_ERROR" in err_msg)
         finally:
             if q:
                 q.stop()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 6771df97eb4d..620186f70b66 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -413,7 +413,9 @@ class StreamingListenerTests(StreamingListenerTestsMixin, ReusedSQLTestCase):
                         self.fail("Not getting terminated event after 50 seconds")
                 q.stop()
                 self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
-                self.check_terminated_event(terminated_event, "ZeroDivisionError")
+                self.check_terminated_event(
+                    terminated_event, "ZeroDivisionError", errorClass="PYTHON_EXCEPTION"
+                )
 
             finally:
                 self.spark.streams.removeListener(test_listener)
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index a4da5ea99838..72662d2cb048 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -176,11 +176,13 @@ object StreamingForeachBatchHelper extends Logging {
                 log"Python foreach batch for dfId ${MDC(DATAFRAME_ID, args.dfId)} " +
                 log"completed (ret: 0)")
           case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-            val msg = PythonWorkerUtils.readUTF(dataIn)
-            throw new PythonException(
+            val traceback = PythonWorkerUtils.readUTF(dataIn)
+            val msg =
               s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
-                s"Found error inside foreachBatch Python process: $msg",
-              null)
+                s"Found error inside foreachBatch Python process"
+            throw new PythonException(
+              errorClass = "PYTHON_EXCEPTION",
+              messageParameters = Map("msg" -> msg, "traceback" -> traceback))
           case otherValue =>
             throw new IllegalStateException(
               s"[session: ${sessionHolder.sessionId}] [userId: ${sessionHolder.userId}] " +
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index f994ada920ec..faab81778482 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -91,11 +91,12 @@ class PythonStreamingQueryListener(listener: SimplePythonFunction, sessionHolder
             log"Streaming query listener function ${MDC(FUNCTION_NAME, functionName)} " +
               log"completed (ret: 0)")
         case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-          val msg = PythonWorkerUtils.readUTF(dataIn)
+          val traceback = PythonWorkerUtils.readUTF(dataIn)
+          val msg = s"Found error inside Streaming query listener Python " +
+            s"process for function $functionName"
           throw new PythonException(
-            s"Found error inside Streaming query listener Python " +
-              s"process for function $functionName: $msg",
-            null)
+            errorClass = "PYTHON_EXCEPTION",
+            messageParameters = Map("msg" -> msg, "traceback" -> traceback))
         case otherValue =>
           throw new IllegalStateException(
             s"Unexpected return value $otherValue from the " +
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index 23a43dbd641d..899fd85a8bb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -172,6 +172,10 @@ trait SQLQueryTestHelper extends SQLConfHelper with Logging {
     try {
       result
     } catch {
+      case e: SparkThrowable with Throwable
+          if Option(e.getCondition).contains("PYTHON_EXCEPTION") =>
+        val msg = Option(e.getMessageParameters.get("traceback")).getOrElse("")
+        (emptySchema, Seq(e.getClass.getName, msg))
       case e: SparkThrowable with Throwable if e.getCondition != null =>
         (emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
       case a: AnalysisException =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index d201f1890dbd..bac9849381a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python
 
 import java.io.{File, FileWriter}
 
-import org.apache.spark.SparkException
+import org.apache.spark.api.python.PythonException
 import org.apache.spark.api.python.PythonUtils
 import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils, QueryTest, Row}
 import org.apache.spark.sql.execution.FilterExec
@@ -755,14 +755,14 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
       createUserDefinedPythonDataSource(dataSourceName, dataSourceScript))
 
     withClue("user error") {
-      val error = intercept[SparkException] {
+      val error = intercept[PythonException] {
         spark.range(10).write.format(dataSourceName).mode("append").save()
       }
       assert(error.getMessage.contains("something is wrong"))
     }
 
     withClue("no commit message") {
-      val error = intercept[SparkException] {
+      val error = intercept[PythonException] {
         spark.range(1).write.format(dataSourceName).mode("append").save()
       }
       assert(error.getMessage.contains("DATA_SOURCE_TYPE_MISMATCH"))
@@ -926,7 +926,7 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
       }
 
       withClue("abort") {
-        intercept[SparkException] {
+        intercept[PythonException] {
           sql("SELECT * FROM range(8, 12, 1, 4)")
             .write.format(dataSourceName)
             .mode("append")


