This is an automated email from the ASF dual-hosted git repository.
ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new 78d9eae49a82 [SPARK-54599][PYTHON] Refactor PythonException so it can take errorClass with sqlstate
78d9eae49a82 is described below
commit 78d9eae49a8228bf9af36de1d59e3ade3295d88e
Author: Tian Gao <[email protected]>
AuthorDate: Wed Feb 4 11:11:05 2026 +0800
[SPARK-54599][PYTHON] Refactor PythonException so it can take errorClass with sqlstate
### What changes were proposed in this pull request?
In general we want `PythonException` to carry a `sqlstate` so it can be categorized properly. This PR rewrites `PythonException` in the `SparkPythonException` style. For now we report `38000` for all existing exceptions, but this leaves flexibility for future changes.
I also take this chance to make the user-side exception output a bit more Pythonic.
### Why are the changes needed?
Many reported `PythonException`s can't be categorized by `sqlstate`, which makes triage and metrics building harder.
### Does this PR introduce _any_ user-facing change?
Yes. Some `Py4JJavaError`s - those raised by Python workers - will now become `PythonException`, which I think is an improvement.
Also, the exception report is different.
Code:
```python
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.getOrCreate()

@F.udf
def add_one(x):
    raise ValueError()

df = spark.createDataFrame([(1,), (2,), (3,)], ["value"])
df = df.withColumn("value_plus_one", add_one("value"))
df.show()
```
Current traceback:
```
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 15, in <module>
    df.show()
  File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 285, in show
    print(self._show_string(n, truncate, vertical))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 303, in _show_string
    return self._jdf.showString(n, 20, vertical)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1362, in __call__
    return_value = get_return_value(
                   ^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
    raise converted from None
  File "/Users/gaotian/programs/spark/python/example.py", line 9, in add_one
    raise ValueError()
    ^^^^^^^^^^^^^^^^^
pyspark.errors.exceptions.captured.PythonException:
  An exception was thrown from the Python worker. Please see the stack trace below.
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
```
Without any further change, the worker part of the traceback would look like:
```
pyspark.errors.exceptions.captured.PythonException: [PYTHON_EXCEPTION] Python worker failed with the following error:
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
SQLSTATE: 38000
```
I think moving the exception explanation right after the exception type is good - that's intuitive to Python users, compared to putting it on a new line. But that line becomes too long, so I think we should give `PythonException` a special `__str__` method instead of relying on the default `getMessage()`.
Option 1: Remove the error class and SQLSTATE, because end users are not that interested in them.
```
pyspark.errors.exceptions.captured.PythonException: Python worker failed with the following error:
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
```
Option 2: Do some magic to print the exception class that users are supposed to catch, instead of exposing our implementation details.
```
pyspark.errors.PythonException: Python worker failed with the following error:
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
```
Option 3: We actually print the worker traceback twice, which is weird. I think we should keep just one; we should not artificially combine the tracebacks from the driver and the worker.
```
  File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
    raise converted from None
  File "/Users/gaotian/programs/spark/python/example.py", line 9, in add_one   -> THIS SHOULD BE REMOVED!
    raise ValueError()
    ^^^^^^^^^^^^^^^^^
pyspark.errors.PythonException: Python worker failed with the following error:
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
```
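To get there, the patch gives the Python-side `PythonException` a `__str__` override in `captured.py` (the exact code is in the hunk below); in sketch form:
```python
def __str__(self) -> str:
    messageParameters = self.getMessageParameters()
    if (
        messageParameters is None
        or "msg" not in messageParameters
        or "traceback" not in messageParameters
    ):
        # Not a PYTHON_EXCEPTION-style error: fall back to the default formatting.
        return super().__str__()
    # Print the short message followed by the worker traceback, Python-style.
    return f"{messageParameters['msg']}:\n{messageParameters['traceback'].strip()}"
```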
New output:
```
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 15, in <module>
    df.show()
  File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 285, in show
    print(self._show_string(n, truncate, vertical))
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/python/pyspark/sql/classic/dataframe.py", line 303, in _show_string
    return self._jdf.showString(n, 20, vertical)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/venv/lib/python3.12/site-packages/py4j/java_gateway.py", line 1362, in __call__
    return_value = get_return_value(
                   ^^^^^^^^^^^^^^^^^
  File "/Users/gaotian/programs/spark/python/pyspark/errors/exceptions/captured.py", line 269, in deco
    raise converted from None
pyspark.errors.PythonException: Python worker failed with the following error:
Traceback (most recent call last):
  File "/Users/gaotian/programs/spark/python/example.py", line 8, in add_one
    raise ValueError()
ValueError
```
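For callers that want to branch on the new metadata, a minimal sketch reusing `df` from the example above (the accessor names are the ones `CapturedException` currently exposes; treat them as an assumption if they change):
```python
from pyspark.errors import PythonException

try:
    df.show()
except PythonException as e:
    # The captured exception now carries an error condition and SQLSTATE
    # (assumed accessors; see the captured.py and PythonRDD.scala hunks below).
    print(e.getErrorClass())          # expected: "PYTHON_EXCEPTION"
    print(e.getSqlState())            # expected: "38000"
    params = e.getMessageParameters()
    if params is not None:
        # The worker traceback is kept as a message parameter.
        print(params.get("traceback", ""))
```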
### How was this patch tested?
CI
### Was this patch authored or co-authored using generative AI tooling?
No.
Closes #53099 from gaogaotiantian/add-sqlstate-python-exception.
Authored-by: Tian Gao <[email protected]>
Signed-off-by: Ruifeng Zheng <[email protected]>
---
.../src/main/resources/error/error-conditions.json | 6 ++++
.../org/apache/spark/api/python/PythonRDD.scala | 30 ++++++++++++++++--
.../org/apache/spark/api/python/PythonRunner.scala | 8 +++--
python/pyspark/errors/exceptions/captured.py | 37 +++++++++++-----------
python/pyspark/sql/context.py | 2 +-
.../sql/tests/arrow/test_arrow_cogrouped_map.py | 4 +--
.../sql/tests/arrow/test_arrow_grouped_map.py | 4 +--
.../sql/tests/pandas/test_pandas_cogrouped_map.py | 6 ++--
.../sql/tests/pandas/test_pandas_grouped_map.py | 10 +++---
python/pyspark/sql/tests/pandas/test_pandas_map.py | 16 +++++-----
.../sql/tests/streaming/test_streaming_foreach.py | 1 -
.../streaming/test_streaming_foreach_batch.py | 2 --
.../sql/tests/streaming/test_streaming_listener.py | 4 ++-
.../planner/StreamingForeachBatchHelper.scala | 10 +++---
.../planner/StreamingQueryListenerHelper.scala | 9 +++---
.../org/apache/spark/sql/SQLQueryTestHelper.scala | 4 +++
.../execution/python/PythonDataSourceSuite.scala | 8 ++---
17 files changed, 101 insertions(+), 60 deletions(-)
diff --git a/common/utils/src/main/resources/error/error-conditions.json
b/common/utils/src/main/resources/error/error-conditions.json
index 8e3acfdb2c0a..258fcfeda15e 100644
--- a/common/utils/src/main/resources/error/error-conditions.json
+++ b/common/utils/src/main/resources/error/error-conditions.json
@@ -5262,6 +5262,12 @@
],
"sqlState" : "38000"
},
+ "PYTHON_EXCEPTION" : {
+ "message" : [
+ "<msg>: <traceback>"
+ ],
+ "sqlState" : "38000"
+ },
"PYTHON_STREAMING_DATA_SOURCE_RUNTIME_ERROR" : {
"message" : [
"Failed when Python streaming data source perform <action>: <msg>"
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index cf0169fed60c..45bf30675148 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -126,8 +126,34 @@ private[spark] case class SimplePythonFunction(
private[spark] case class ChainedPythonFunctions(funcs: Seq[PythonFunction])
/** Thrown for exceptions in user Python code. */
-private[spark] class PythonException(msg: String, cause: Throwable)
- extends RuntimeException(msg, cause)
+private[spark] class PythonException(
+ msg: String,
+ cause: Throwable,
+ errorClass: Option[String],
+ messageParameters: Map[String, String],
+ context: Array[QueryContext])
+ extends RuntimeException(msg, cause) with SparkThrowable {
+
+ def this(
+ errorClass: String,
+ messageParameters: Map[String, String],
+ cause: Throwable = null,
+ context: Array[QueryContext] = Array.empty,
+ summary: String = "") = {
+ this(
+ SparkThrowableHelper.getMessage(errorClass, messageParameters, summary),
+ cause,
+ Option(errorClass),
+ messageParameters,
+ context
+ )
+ }
+
+ override def getMessageParameters: java.util.Map[String, String] =
messageParameters.asJava
+
+ override def getCondition: String = errorClass.orNull
+ override def getQueryContext: Array[QueryContext] = context
+}
/**
* Form an RDD[(Array[Byte], Array[Byte])] from key-value pairs returned from
Python.
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index c3ee3853ce0f..6bc64f54d3a2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -659,8 +659,12 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
protected def handlePythonException(): PythonException = {
// Signals that an exception has been thrown in python
- val msg = PythonWorkerUtils.readUTF(stream)
- new PythonException(msg, writer.exception.orNull)
+ val traceback = PythonWorkerUtils.readUTF(stream)
+ val msg = "An exception was thrown from the Python worker"
+ new PythonException(
+ errorClass = "PYTHON_EXCEPTION",
+ messageParameters = Map("msg" -> msg, "traceback" -> traceback),
+ cause = writer.exception.orNull)
}
protected def handleEndOfDataSection(): Unit = {
diff --git a/python/pyspark/errors/exceptions/captured.py
b/python/pyspark/errors/exceptions/captured.py
index 0f76e3b5f6a0..adde27bf8a5d 100644
--- a/python/pyspark/errors/exceptions/captured.py
+++ b/python/pyspark/errors/exceptions/captured.py
@@ -234,25 +234,13 @@ def _convert_exception(e: "Py4JJavaError") ->
CapturedException:
return SparkUpgradeException(origin=e)
elif is_instance_of(gw, e, "org.apache.spark.SparkNoSuchElementException"):
return SparkNoSuchElementException(origin=e)
-
- c: "Py4JJavaError" = e.getCause()
- stacktrace: str = getattr(jvm,
"org.apache.spark.util.Utils").exceptionString(e)
- if c is not None and (
- is_instance_of(gw, c, "org.apache.spark.api.python.PythonException")
- # To make sure this only catches Python UDFs.
- and any(
- map(
- lambda v: "org.apache.spark.sql.execution.python" in
v.toString(), c.getStackTrace()
- )
- )
- ):
- msg = (
- "\n An exception was thrown from the Python worker. "
- "Please see the stack trace below.\n%s" % c.getMessage()
- )
- return PythonException(msg, stacktrace)
-
- return UnknownException(desc=e.toString(), stackTrace=stacktrace, cause=c)
+ elif is_instance_of(gw, e, "org.apache.spark.api.python.PythonException"):
+ return PythonException(origin=e)
+ return UnknownException(
+ desc=e.toString(),
+ stackTrace=getattr(jvm,
"org.apache.spark.util.Utils").exceptionString(e),
+ cause=e.getCause(),
+ )
def capture_sql_exception(f: Callable[..., Any]) -> Callable[..., Any]:
@@ -348,6 +336,17 @@ class PythonException(CapturedException,
BasePythonException):
Exceptions thrown from Python workers.
"""
+ def __str__(self) -> str:
+ messageParameters = self.getMessageParameters()
+
+ if (
+ messageParameters is None
+ or "msg" not in messageParameters
+ or "traceback" not in messageParameters
+ ):
+ return super().__str__()
+ return
f"{messageParameters['msg']}:\n{messageParameters['traceback'].strip()}"
+
class ArithmeticException(CapturedException, BaseArithmeticException):
"""
diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py
index c51b9e063a77..34f488244d78 100644
--- a/python/pyspark/sql/context.py
+++ b/python/pyspark/sql/context.py
@@ -472,7 +472,7 @@ class SQLContext:
>>> sqlContext.createDataFrame(rdd, "boolean").collect() # doctest:
+IGNORE_EXCEPTION_DETAIL
Traceback (most recent call last):
...
- Py4JJavaError: ...
+ pyspark.errors.exceptions.captured.PythonException: ...
"""
return self.sparkSession.createDataFrame( # type:
ignore[call-overload]
data, schema, samplingRatio, verifySchema
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
index 98362a44d3eb..45c8a455a44a 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_cogrouped_map.py
@@ -196,7 +196,7 @@ class CogroupedMapInArrowTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pyarrow.Table do not match
specified schema. "
- "Missing: m. Unexpected: v, v2.\n",
+ "Missing: m. Unexpected: v, v2.",
):
# stats returns three columns while here we set schema with
two columns
self.cogrouped.applyInArrow(stats, schema="id long, m
double").collect()
@@ -232,7 +232,7 @@ class CogroupedMapInArrowTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pyarrow.Table do not match
specified schema. "
- "Missing: m.\n",
+ "Missing: m.",
):
# stats returns one column for even keys while here we set
schema with two columns
self.cogrouped.applyInArrow(odd_means, schema="id long, m
double").collect()
diff --git a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
index 9c1b14676ecc..d34c31220be6 100644
--- a/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
+++ b/python/pyspark/sql/tests/arrow/test_arrow_grouped_map.py
@@ -224,7 +224,7 @@ class ApplyInArrowTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pyarrow.Table do not match
specified schema. "
- "Missing: m. Unexpected: v, v2.\n",
+ "Missing: m. Unexpected: v, v2.",
):
# stats returns three columns while here we set schema
with two columns
df.groupby("id").applyInArrow(
@@ -265,7 +265,7 @@ class ApplyInArrowTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pyarrow.Table do not match
specified schema. "
- "Missing: m.\n",
+ "Missing: m.",
):
# stats returns one column for even keys while here we set
schema with two columns
df.groupby("id").applyInArrow(odd_means, schema="id long, m
double").collect()
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
index aa0a8479b3cd..88b5b9138f71 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_cogrouped_map.py
@@ -209,7 +209,7 @@ class CogroupedApplyInPandasTestsMixin:
fn=merge_pandas,
errorClass=PythonException,
error_message_regex="Column names of the returned pandas.DataFrame
"
- "do not match specified schema. Unexpected: add, more.\n",
+ "do not match specified schema. Unexpected: add, more.",
)
def test_apply_in_pandas_returning_no_column_names_and_wrong_amount(self):
@@ -230,7 +230,7 @@ class CogroupedApplyInPandasTestsMixin:
fn=merge_pandas,
errorClass=PythonException,
error_message_regex="Number of columns of the returned
pandas.DataFrame "
- "doesn't match specified schema. Expected: 4 Actual: 6\n",
+ "doesn't match specified schema. Expected: 4 Actual: 6",
)
def test_apply_in_pandas_returning_empty_dataframe(self):
@@ -277,7 +277,7 @@ class CogroupedApplyInPandasTestsMixin:
with self.subTest(convert="double to string"):
expected = (
r"TypeError: Exception thrown when converting
pandas.Series \(float64\) "
- r"with name 'k' to Arrow Array \(string\).\n"
+ r"with name 'k' to Arrow Array \(string\)."
)
self._test_merge_error(
fn=lambda lft, rgt: pd.DataFrame({"id": [1], "k":
[2.0]}),
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
index efbf1c42b77f..76def82729b9 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_grouped_map.py
@@ -305,7 +305,7 @@ class ApplyInPandasTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match
specified schema. "
- "Missing: mean. Unexpected: median, std.\n",
+ "Missing: mean. Unexpected: median, std.",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame(
@@ -321,7 +321,7 @@ class ApplyInPandasTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Number of columns of the returned pandas.DataFrame doesn't match "
- "specified schema. Expected: 2 Actual: 3\n",
+ "specified schema. Expected: 2 Actual: 3",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key + (pdf.v.mean(),
pdf.v.std())])
@@ -359,7 +359,7 @@ class ApplyInPandasTestsMixin:
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
- with self.assertRaisesRegex(PythonException, expected +
"\n"):
+ with self.assertRaisesRegex(PythonException, expected):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key +
("test_string",)]),
output_schema="id long, mean double",
@@ -370,7 +370,7 @@ class ApplyInPandasTestsMixin:
with self.assertRaisesRegex(
PythonException,
r"TypeError: Exception thrown when converting
pandas.Series \(float64\) "
- r"with name 'mean' to Arrow Array \(string\).\n",
+ r"with name 'mean' to Arrow Array \(string\).",
):
self._test_apply_in_pandas(
lambda key, pdf: pd.DataFrame([key +
(pdf.v.mean(),)]),
@@ -667,7 +667,7 @@ class ApplyInPandasTestsMixin:
with self.assertRaisesRegex(
PythonException,
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: id. Unexpected: iid.\n",
+ "specified schema. Missing: id. Unexpected: iid.",
):
grouped_df.apply(column_name_typo).collect()
with self.assertRaisesRegex(Exception, "[D|d]ecimal.*got.*date"):
diff --git a/python/pyspark/sql/tests/pandas/test_pandas_map.py
b/python/pyspark/sql/tests/pandas/test_pandas_map.py
index 1b07ec5630ec..43c41ed54ebd 100644
--- a/python/pyspark/sql/tests/pandas/test_pandas_map.py
+++ b/python/pyspark/sql/tests/pandas/test_pandas_map.py
@@ -212,7 +212,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\]
"
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: id. Unexpected: iid.\n",
+ "specified schema. Missing: id. Unexpected: iid.",
):
(
self.spark.range(10, numPartitions=3)
@@ -234,7 +234,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\]
"
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: id2.\n",
+ "specified schema. Missing: id2.",
):
(
self.spark.range(10, numPartitions=3)
@@ -255,7 +255,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\]
"
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: id2.\n",
+ "specified schema. Missing: id2.",
):
f = self.identity_dataframes_iter("id", "value")
(df.mapInPandas(f, "id int, id2 long, value int").collect())
@@ -264,7 +264,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_LENGTH_MISMATCH_FOR_PANDAS_UDF\\] "
"Number of columns of the returned pandas.DataFrame doesn't match "
- "specified schema. Expected: 3 Actual: 2\n",
+ "specified schema. Expected: 3 Actual: 2",
):
f = self.identity_dataframes_wo_column_names_iter("id", "value")
(df.mapInPandas(f, "id int, id2 long, value int").collect())
@@ -312,7 +312,7 @@ class MapInPandasTestsMixin:
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
- with self.assertRaisesRegex(PythonException, expected +
"\n"):
+ with self.assertRaisesRegex(PythonException, expected):
(
self.spark.range(10, numPartitions=3)
.mapInPandas(func, "id double")
@@ -339,7 +339,7 @@ class MapInPandasTestsMixin:
"can be disabled by using SQL config "
"`spark.sql.execution.pandas.convertToArrowArraySafely`."
)
- with self.assertRaisesRegex(PythonException, expected
+ "\n"):
+ with self.assertRaisesRegex(PythonException, expected):
df.collect()
else:
self.assertEqual(
@@ -375,7 +375,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\]
"
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: value.\n",
+ "specified schema. Missing: value.",
):
f = self.dataframes_and_empty_dataframe_iter("id")
(
@@ -404,7 +404,7 @@ class MapInPandasTestsMixin:
PythonException,
"PySparkRuntimeError: \\[RESULT_COLUMNS_MISMATCH_FOR_PANDAS_UDF\\]
"
"Column names of the returned pandas.DataFrame do not match "
- "specified schema. Missing: id. Unexpected: iid.\n",
+ "specified schema. Missing: id. Unexpected: iid.",
):
(
self.spark.range(10, numPartitions=3)
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
index 9a13c9c8ad0c..39e2bfcc8d5c 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach.py
@@ -223,7 +223,6 @@ class StreamingTestsForeachMixin:
except StreamingQueryException as e:
err_msg = str(e)
self.assertTrue("test error" in err_msg)
- self.assertTrue("FOREACH_USER_FUNCTION_ERROR" in err_msg)
self.assertEqual(len(tester.process_events()), 0) # no row was
processed
close_events = tester.close_events()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
index cbfc0e0123da..818d8361537d 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_foreach_batch.py
@@ -81,8 +81,6 @@ class StreamingTestsForeachBatchMixin:
except StreamingQueryException as e:
err_msg = str(e)
self.assertTrue("this should fail" in err_msg)
- # check for foreachBatch error class
- self.assertTrue("FOREACH_BATCH_USER_FUNCTION_ERROR" in err_msg)
finally:
if q:
q.stop()
diff --git a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
index 6771df97eb4d..620186f70b66 100644
--- a/python/pyspark/sql/tests/streaming/test_streaming_listener.py
+++ b/python/pyspark/sql/tests/streaming/test_streaming_listener.py
@@ -413,7 +413,9 @@ class StreamingListenerTests(StreamingListenerTestsMixin,
ReusedSQLTestCase):
self.fail("Not getting terminated event after 50
seconds")
q.stop()
self.spark.sparkContext._jsc.sc().listenerBus().waitUntilEmpty()
- self.check_terminated_event(terminated_event,
"ZeroDivisionError")
+ self.check_terminated_event(
+ terminated_event, "ZeroDivisionError",
errorClass="PYTHON_EXCEPTION"
+ )
finally:
self.spark.streams.removeListener(test_listener)
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
index a4da5ea99838..72662d2cb048 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingForeachBatchHelper.scala
@@ -176,11 +176,13 @@ object StreamingForeachBatchHelper extends Logging {
log"Python foreach batch for dfId ${MDC(DATAFRAME_ID,
args.dfId)} " +
log"completed (ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- val msg = PythonWorkerUtils.readUTF(dataIn)
- throw new PythonException(
+ val traceback = PythonWorkerUtils.readUTF(dataIn)
+ val msg =
s"[session: ${sessionHolder.sessionId}] [userId:
${sessionHolder.userId}] " +
- s"Found error inside foreachBatch Python process: $msg",
- null)
+ s"Found error inside foreachBatch Python process"
+ throw new PythonException(
+ errorClass = "PYTHON_EXCEPTION",
+ messageParameters = Map("msg" -> msg, "traceback" -> traceback))
case otherValue =>
throw new IllegalStateException(
s"[session: ${sessionHolder.sessionId}] [userId:
${sessionHolder.userId}] " +
diff --git
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
index f994ada920ec..faab81778482 100644
---
a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
+++
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/StreamingQueryListenerHelper.scala
@@ -91,11 +91,12 @@ class PythonStreamingQueryListener(listener:
SimplePythonFunction, sessionHolder
log"Streaming query listener function ${MDC(FUNCTION_NAME,
functionName)} " +
log"completed (ret: 0)")
case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
- val msg = PythonWorkerUtils.readUTF(dataIn)
+ val traceback = PythonWorkerUtils.readUTF(dataIn)
+ val msg = s"Found error inside Streaming query listener Python " +
+ s"process for function $functionName"
throw new PythonException(
- s"Found error inside Streaming query listener Python " +
- s"process for function $functionName: $msg",
- null)
+ errorClass = "PYTHON_EXCEPTION",
+ messageParameters = Map("msg" -> msg, "traceback" -> traceback))
case otherValue =>
throw new IllegalStateException(
s"Unexpected return value $otherValue from the " +
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
index 23a43dbd641d..899fd85a8bb8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestHelper.scala
@@ -172,6 +172,10 @@ trait SQLQueryTestHelper extends SQLConfHelper with
Logging {
try {
result
} catch {
+ case e: SparkThrowable with Throwable
+ if Option(e.getCondition).contains("PYTHON_EXCEPTION") =>
+ val msg = Option(e.getMessageParameters.get("traceback")).getOrElse("")
+ (emptySchema, Seq(e.getClass.getName, msg))
case e: SparkThrowable with Throwable if e.getCondition != null =>
(emptySchema, Seq(e.getClass.getName, getMessage(e, format)))
case a: AnalysisException =>
diff --git
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
index d201f1890dbd..bac9849381a3 100644
---
a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
+++
b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.python
import java.io.{File, FileWriter}
-import org.apache.spark.SparkException
+import org.apache.spark.api.python.PythonException
import org.apache.spark.api.python.PythonUtils
import org.apache.spark.sql.{AnalysisException, IntegratedUDFTestUtils,
QueryTest, Row}
import org.apache.spark.sql.execution.FilterExec
@@ -755,14 +755,14 @@ class PythonDataSourceSuite extends
PythonDataSourceSuiteBase {
createUserDefinedPythonDataSource(dataSourceName, dataSourceScript))
withClue("user error") {
- val error = intercept[SparkException] {
+ val error = intercept[PythonException] {
spark.range(10).write.format(dataSourceName).mode("append").save()
}
assert(error.getMessage.contains("something is wrong"))
}
withClue("no commit message") {
- val error = intercept[SparkException] {
+ val error = intercept[PythonException] {
spark.range(1).write.format(dataSourceName).mode("append").save()
}
assert(error.getMessage.contains("DATA_SOURCE_TYPE_MISMATCH"))
@@ -926,7 +926,7 @@ class PythonDataSourceSuite extends
PythonDataSourceSuiteBase {
}
withClue("abort") {
- intercept[SparkException] {
+ intercept[PythonException] {
sql("SELECT * FROM range(8, 12, 1, 4)")
.write.format(dataSourceName)
.mode("append")