This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch branch-3.5
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.5 by this push:
     new b5ab247aa6e [SPARK-44644][PYTHON][3.5] Improve error messages for Python UDTFs with pickling errors
b5ab247aa6e is described below

commit b5ab247aa6e45180c2e826da74fcb615f3da3335
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Mon Aug 7 13:03:03 2023 +0800

    [SPARK-44644][PYTHON][3.5] Improve error messages for Python UDTFs with pickling errors
    
    ### What changes were proposed in this pull request?
    
    Cherry-pick of https://github.com/apache/spark/commit/62415dc59627e1f7b4e3449ae728e93c1fc0b74f
    
    This PR improves the error messages raised when a Python UDTF fails to pickle.
    
    ### Why are the changes needed?
    
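    To make the error message more user-friendly: the previous exception came straight from the pickler and did not say which UDTF failed to serialize.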
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Before this PR, when a UDTF failed to pickle, it threw this confusing exception:
    ```
    _pickle.PicklingError: Cannot pickle files that are not opened for reading: w
    ```
    After this PR, the error is clearer:
    `[UDTF_SERIALIZATION_ERROR] Cannot serialize the UDTF 'TestUDTF': Please check the stack trace and make sure the function is serializable.`
    
    And for SparkSession access inside a UDTF:
    `[UDTF_SERIALIZATION_ERROR] Cannot serialize the UDTF 'TestUDTF': it appears that you are attempting to reference SparkSession inside a UDTF. SparkSession can only be used on the driver, not in code that runs on workers. Please remove the reference and try again.`
    
    ### How was this patch tested?
    
    New unit tests in `python/pyspark/sql/tests/test_udtf.py`.
    
    Closes #42349 from allisonwang-db/spark-44644-3.5.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 python/pyspark/cloudpickle/cloudpickle_fast.py |  2 +-
 python/pyspark/errors/error_classes.py         |  5 +++++
 python/pyspark/sql/connect/plan.py             | 15 +++++++++++--
 python/pyspark/sql/tests/test_udtf.py          | 30 +++++++++++++++++++++++++-
 python/pyspark/sql/udtf.py                     | 25 ++++++++++++++++++++-
 5 files changed, 72 insertions(+), 5 deletions(-)

diff --git a/python/pyspark/cloudpickle/cloudpickle_fast.py 
b/python/pyspark/cloudpickle/cloudpickle_fast.py
index 63aaffa096b..ee1f4b8ee96 100644
--- a/python/pyspark/cloudpickle/cloudpickle_fast.py
+++ b/python/pyspark/cloudpickle/cloudpickle_fast.py
@@ -631,7 +631,7 @@ class CloudPickler(Pickler):
         try:
             return Pickler.dump(self, obj)
         except RuntimeError as e:
-            if "recursion" in e.args[0]:
+            if len(e.args) > 0 and "recursion" in e.args[0]:
                 msg = (
                     "Could not pickle object as excessively deep recursion "
                     "required."
diff --git a/python/pyspark/errors/error_classes.py 
b/python/pyspark/errors/error_classes.py
index 4ea3e678810..971dc59bbb2 100644
--- a/python/pyspark/errors/error_classes.py
+++ b/python/pyspark/errors/error_classes.py
@@ -743,6 +743,11 @@ ERROR_CLASSES_JSON = """
       "Mismatch in return type for the UDTF '<name>'. Expected a 'StructType', 
but got '<return_type>'. Please ensure the return type is a correctly formatted 
StructType."
     ]
   },
+  "UDTF_SERIALIZATION_ERROR" : {
+    "message" : [
+      "Cannot serialize the UDTF '<name>': <message>"
+    ]
+  },
   "UNEXPECTED_RESPONSE_FROM_SERVER" : {
     "message" : [
       "Unexpected response from iterator server."
diff --git a/python/pyspark/sql/connect/plan.py 
b/python/pyspark/sql/connect/plan.py
index 3390faa04de..2e918700848 100644
--- a/python/pyspark/sql/connect/plan.py
+++ b/python/pyspark/sql/connect/plan.py
@@ -21,6 +21,7 @@ check_dependencies(__name__)
 from typing import Any, List, Optional, Type, Sequence, Union, cast, 
TYPE_CHECKING, Mapping, Dict
 import functools
 import json
+import pickle
 from threading import Lock
 from inspect import signature, isclass
 
@@ -40,7 +41,7 @@ from pyspark.sql.connect.expressions import (
     LiteralExpression,
 )
 from pyspark.sql.connect.types import pyspark_types_to_proto_types, 
UnparsedDataType
-from pyspark.errors import PySparkTypeError, PySparkNotImplementedError
+from pyspark.errors import PySparkTypeError, PySparkNotImplementedError, 
PySparkRuntimeError
 
 if TYPE_CHECKING:
     from pyspark.sql.connect._typing import ColumnOrName
@@ -2200,7 +2201,17 @@ class PythonUDTF:
         assert self._return_type is not None
         
udtf.return_type.CopyFrom(pyspark_types_to_proto_types(self._return_type))
         udtf.eval_type = self._eval_type
-        udtf.command = CloudPickleSerializer().dumps(self._func)
+        try:
+            udtf.command = CloudPickleSerializer().dumps(self._func)
+        except pickle.PicklingError:
+            raise PySparkRuntimeError(
+                error_class="UDTF_SERIALIZATION_ERROR",
+                message_parameters={
+                    "name": self._name,
+                    "message": "Please check the stack trace and "
+                    "make sure the function is serializable.",
+                },
+            )
         udtf.python_ver = self._python_ver
         return udtf
 
diff --git a/python/pyspark/sql/tests/test_udtf.py 
b/python/pyspark/sql/tests/test_udtf.py
index 4a65a9bd2e4..9384a6bc011 100644
--- a/python/pyspark/sql/tests/test_udtf.py
+++ b/python/pyspark/sql/tests/test_udtf.py
@@ -14,7 +14,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 #
-
+import os
+import tempfile
 import unittest
 
 from typing import Iterator
@@ -24,6 +25,7 @@ from pyspark.errors import (
     PythonException,
     PySparkTypeError,
     AnalysisException,
+    PySparkRuntimeError,
 )
 from pyspark.rdd import PythonEvalType
 from pyspark.sql.functions import lit, udf, udtf
@@ -715,6 +717,32 @@ class BaseUDTFTestsMixin:
             },
         )
 
+    def test_udtf_pickle_error(self):
+        with tempfile.TemporaryDirectory() as d:
+            file = os.path.join(d, "file.txt")
+            file_obj = open(file, "w")
+
+            @udtf(returnType="x: int")
+            class TestUDTF:
+                def eval(self):
+                    file_obj
+                    yield 1,
+
+            with self.assertRaisesRegex(PySparkRuntimeError, 
"UDTF_SERIALIZATION_ERROR"):
+                TestUDTF().collect()
+
+    def test_udtf_access_spark_session(self):
+        df = self.spark.range(10)
+
+        @udtf(returnType="x: int")
+        class TestUDTF:
+            def eval(self):
+                df.collect()
+                yield 1,
+
+        with self.assertRaisesRegex(PySparkRuntimeError, 
"UDTF_SERIALIZATION_ERROR"):
+            TestUDTF().collect()
+
     def test_udtf_no_eval(self):
         with self.assertRaises(PySparkAttributeError) as e:
 
diff --git a/python/pyspark/sql/udtf.py b/python/pyspark/sql/udtf.py
index 7cbf4732ba9..bf85b55fea3 100644
--- a/python/pyspark/sql/udtf.py
+++ b/python/pyspark/sql/udtf.py
@@ -17,6 +17,7 @@
 """
 User-defined table function related classes and functions
 """
+import pickle
 import sys
 import warnings
 from functools import wraps
@@ -240,7 +241,29 @@ class UserDefinedTableFunction:
         spark = SparkSession._getActiveSessionOrCreate()
         sc = spark.sparkContext
 
-        wrapped_func = _wrap_function(sc, func)
+        try:
+            wrapped_func = _wrap_function(sc, func)
+        except pickle.PicklingError as e:
+            if "CONTEXT_ONLY_VALID_ON_DRIVER" in str(e):
+                raise PySparkRuntimeError(
+                    error_class="UDTF_SERIALIZATION_ERROR",
+                    message_parameters={
+                        "name": self._name,
+                        "message": "it appears that you are attempting to 
reference SparkSession "
+                        "inside a UDTF. SparkSession can only be used on the 
driver, "
+                        "not in code that runs on workers. Please remove the 
reference "
+                        "and try again.",
+                    },
+                ) from None
+            raise PySparkRuntimeError(
+                error_class="UDTF_SERIALIZATION_ERROR",
+                message_parameters={
+                    "name": self._name,
+                    "message": "Please check the stack trace and make sure the 
"
+                    "function is serializable.",
+                },
+            )
+
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         assert sc._jvm is not None
         judtf = 
sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonTableFunction(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to