This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch branch-3.5 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.5 by this push: new b5ab247aa6e [SPARK-44644][PYTHON][3.5] Improve error messages for Python UDTFs with pickling errors b5ab247aa6e is described below commit b5ab247aa6e45180c2e826da74fcb615f3da3335 Author: allisonwang-db <allison.w...@databricks.com> AuthorDate: Mon Aug 7 13:03:03 2023 +0800 [SPARK-44644][PYTHON][3.5] Improve error messages for Python UDTFs with pickling errors ### What changes were proposed in this pull request? Cherry-pick https://github.com/apache/spark/commit/62415dc59627e1f7b4e3449ae728e93c1fc0b74f This PR improves the error messages when a Python UDTF fails to pickle. ### Why are the changes needed? To make the error message more user-friendly. ### Does this PR introduce _any_ user-facing change? Yes, before this PR, when a UDTF fails to pickle, it throws this confusing exception: ``` _pickle.PicklingError: Cannot pickle files that are not opened for reading: w ``` After this PR, the error is clearer: `[UDTF_SERIALIZATION_ERROR] Cannot serialize the UDTF 'TestUDTF': Please check the stack trace and make sure the function is serializable.` And for SparkSession access inside a UDTF: `[UDTF_SERIALIZATION_ERROR] it appears that you are attempting to reference SparkSession inside a UDTF. SparkSession can only be used on the driver, not in code that runs on workers. Please remove the reference and try again.` ### How was this patch tested? New UTs. Closes #42349 from allisonwang-db/spark-44644-3.5. 
Authored-by: allisonwang-db <allison.w...@databricks.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- python/pyspark/cloudpickle/cloudpickle_fast.py | 2 +- python/pyspark/errors/error_classes.py | 5 +++++ python/pyspark/sql/connect/plan.py | 15 +++++++++++-- python/pyspark/sql/tests/test_udtf.py | 30 +++++++++++++++++++++++++- python/pyspark/sql/udtf.py | 25 ++++++++++++++++++++- 5 files changed, 72 insertions(+), 5 deletions(-) diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py b/python/pyspark/cloudpickle/cloudpickle_fast.py index 63aaffa096b..ee1f4b8ee96 100644 --- a/python/pyspark/cloudpickle/cloudpickle_fast.py +++ b/python/pyspark/cloudpickle/cloudpickle_fast.py @@ -631,7 +631,7 @@ class CloudPickler(Pickler): try: return Pickler.dump(self, obj) except RuntimeError as e: - if "recursion" in e.args[0]: + if len(e.args) > 0 and "recursion" in e.args[0]: msg = ( "Could not pickle object as excessively deep recursion " "required." diff --git a/python/pyspark/errors/error_classes.py b/python/pyspark/errors/error_classes.py index 4ea3e678810..971dc59bbb2 100644 --- a/python/pyspark/errors/error_classes.py +++ b/python/pyspark/errors/error_classes.py @@ -743,6 +743,11 @@ ERROR_CLASSES_JSON = """ "Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', but got '<return_type>'. Please ensure the return type is a correctly formatted StructType." ] }, + "UDTF_SERIALIZATION_ERROR" : { + "message" : [ + "Cannot serialize the UDTF '<name>': <message>" + ] + }, "UNEXPECTED_RESPONSE_FROM_SERVER" : { "message" : [ "Unexpected response from iterator server." 
diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 3390faa04de..2e918700848 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -21,6 +21,7 @@ check_dependencies(__name__) from typing import Any, List, Optional, Type, Sequence, Union, cast, TYPE_CHECKING, Mapping, Dict import functools import json +import pickle from threading import Lock from inspect import signature, isclass @@ -40,7 +41,7 @@ from pyspark.sql.connect.expressions import ( LiteralExpression, ) from pyspark.sql.connect.types import pyspark_types_to_proto_types, UnparsedDataType -from pyspark.errors import PySparkTypeError, PySparkNotImplementedError +from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, PySparkRuntimeError if TYPE_CHECKING: from pyspark.sql.connect._typing import ColumnOrName @@ -2200,7 +2201,17 @@ class PythonUDTF: assert self._return_type is not None udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type)) udtf.eval_type = self._eval_type - udtf.command = CloudPickleSerializer().dumps(self._func) + try: + udtf.command = CloudPickleSerializer().dumps(self._func) + except pickle.PicklingError: + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and " + "make sure the function is serializable.", + }, + ) udtf.python_ver = self._python_ver return udtf diff --git a/python/pyspark/sql/tests/test_udtf.py b/python/pyspark/sql/tests/test_udtf.py index 4a65a9bd2e4..9384a6bc011 100644 --- a/python/pyspark/sql/tests/test_udtf.py +++ b/python/pyspark/sql/tests/test_udtf.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# - +import os +import tempfile import unittest from typing import Iterator @@ -24,6 +25,7 @@ from pyspark.errors import ( PythonException, PySparkTypeError, AnalysisException, + PySparkRuntimeError, ) from pyspark.rdd import PythonEvalType from pyspark.sql.functions import lit, udf, udtf @@ -715,6 +717,32 @@ class BaseUDTFTestsMixin: }, ) + def test_udtf_pickle_error(self): + with tempfile.TemporaryDirectory() as d: + file = os.path.join(d, "file.txt") + file_obj = open(file, "w") + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + file_obj + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + + def test_udtf_access_spark_session(self): + df = self.spark.range(10) + + @udtf(returnType="x: int") + class TestUDTF: + def eval(self): + df.collect() + yield 1, + + with self.assertRaisesRegex(PySparkRuntimeError, "UDTF_SERIALIZATION_ERROR"): + TestUDTF().collect() + def test_udtf_no_eval(self): with self.assertRaises(PySparkAttributeError) as e: diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py index 7cbf4732ba9..bf85b55fea3 100644 --- a/python/pyspark/sql/udtf.py +++ b/python/pyspark/sql/udtf.py @@ -17,6 +17,7 @@ """ User-defined table function related classes and functions """ +import pickle import sys import warnings from functools import wraps @@ -240,7 +241,29 @@ class UserDefinedTableFunction: spark = SparkSession._getActiveSessionOrCreate() sc = spark.sparkContext - wrapped_func = _wrap_function(sc, func) + try: + wrapped_func = _wrap_function(sc, func) + except pickle.PicklingError as e: + if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e): + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "it appears that you are attempting to reference SparkSession " + "inside a UDTF. SparkSession can only be used on the driver, " + "not in code that runs on workers. 
Please remove the reference " + "and try again.", + }, + ) from None + raise PySparkRuntimeError( + error_class="UDTF_SERIALIZATION_ERROR", + message_parameters={ + "name": self._name, + "message": "Please check the stack trace and make sure the " + "function is serializable.", + }, + ) + jdt = spark._jsparkSession.parseDataType(self.returnType.json()) assert sc._jvm is not None judtf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction( --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org