Repository: spark
Updated Branches:
  refs/heads/master f07c5064a -> 3e5b4ae63


[SPARK-23754][PYTHON][FOLLOWUP] Move UDF stop iteration wrapping from driver to executor

## What changes were proposed in this pull request?
SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user 
function, but this required a hack to save its argspec. This PR reverts that 
change and instead fixes the `StopIteration` bug in the worker.

## How does this work?

The root of the problem is that when a user-supplied function raises a 
`StopIteration`, PySpark may silently stop processing data if that function 
is called inside a for loop. The fix is to catch `StopIteration` exceptions 
and re-raise them as `RuntimeError`s, so that execution fails and the error 
is reported to the user. This is done with the `fail_on_stopiteration` 
wrapper (a minimal sketch is shown after the list below), applied in 
different places depending on where the function is used:
 - In RDDs, the user function is wrapped in the driver, because the function 
is also called in the driver itself.
 - In SQL UDFs, the function is wrapped in the worker, since all processing 
happens there. Moreover, the worker needs the signature of the user function, 
which is lost when the function is wrapped; passing this signature from the 
driver to the worker would require a rather ugly hack, so wrapping in the 
worker avoids it.
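
For reference, a minimal sketch of such a wrapper (the actual implementation 
lives in `python/pyspark/util.py`; this sketch only illustrates the idea and 
is not the exact code):

    def fail_on_stopiteration(f):
        """
        Wrap `f` so that a StopIteration raised inside it is re-raised as a
        RuntimeError, instead of silently terminating the loop that calls it.
        """
        def wrapper(*args, **kwargs):
            try:
                return f(*args, **kwargs)
            except StopIteration as exc:
                # Without this, the StopIteration would simply end the for
                # loop (or generator) in Spark's worker that is consuming
                # f's output, silently dropping the remaining rows.
                raise RuntimeError(
                    "Caught StopIteration thrown from user's code; "
                    "failing the task", exc)
        return wrapper

RDD operations apply this wrapper on the driver side before the function is 
serialized; for SQL UDFs the worker applies it in `read_single_udf` after 
deserializing the user function (see the `worker.py` diff below).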

## How was this patch tested?

Same tests as before, plus new tests for pandas UDFs.

Author: edorigatti <emilio.doriga...@gmail.com>

Closes #21467 from e-dorigatti/fix_udf_hack.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3e5b4ae6
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3e5b4ae6
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3e5b4ae6

Branch: refs/heads/master
Commit: 3e5b4ae63a468858ff8b9f7f3231cc877846a0af
Parents: f07c506
Author: edorigatti <emilio.doriga...@gmail.com>
Authored: Mon Jun 11 10:15:42 2018 +0800
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Mon Jun 11 10:15:42 2018 +0800

----------------------------------------------------------------------
 python/pyspark/sql/tests.py | 71 +++++++++++++++++++++++++++++++---------
 python/pyspark/sql/udf.py   | 14 ++------
 python/pyspark/tests.py     | 37 ++++++++++++---------
 python/pyspark/util.py      |  9 ++---
 python/pyspark/worker.py    | 18 ++++++----
 5 files changed, 92 insertions(+), 57 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3e5b4ae6/python/pyspark/sql/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 487eb19..4a3941d 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -900,22 +900,6 @@ class SQLTests(ReusedSQLTestCase):
         self.assertEqual(f, f_.func)
         self.assertEqual(return_type, f_.returnType)
 
-    def test_stopiteration_in_udf(self):
-        # test for SPARK-23754
-        from pyspark.sql.functions import udf
-        from py4j.protocol import Py4JJavaError
-
-        def foo(x):
-            raise StopIteration()
-
-        with self.assertRaises(Py4JJavaError) as cm:
-            self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show()
-
-        self.assertIn(
-            "Caught StopIteration thrown from user's code; failing the task",
-            cm.exception.java_exception.toString()
-        )
-
     def test_validate_column_types(self):
         from pyspark.sql.functions import udf, to_json
         from pyspark.sql.column import _to_java_column
@@ -4144,6 +4128,61 @@ class PandasUDFTests(ReusedSQLTestCase):
                 def foo(k, v, w):
                     return k
 
+    def test_stopiteration_in_udf(self):
+        from pyspark.sql.functions import udf, pandas_udf, PandasUDFType
+        from py4j.protocol import Py4JJavaError
+
+        def foo(x):
+            raise StopIteration()
+
+        def foofoo(x, y):
+            raise StopIteration()
+
+        exc_message = "Caught StopIteration thrown from user's code; failing the task"
+        df = self.spark.range(0, 100)
+
+        # plain udf (test for SPARK-23754)
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn('v', udf(foo)('id')).collect
+        )
+
+        # pandas scalar udf
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.withColumn(
+                'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id')
+            ).collect
+        )
+
+        # pandas grouped map
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').apply(
+                pandas_udf(foofoo, df.schema, PandasUDFType.GROUPED_MAP)
+            ).collect
+        )
+
+        # pandas grouped agg
+        self.assertRaisesRegexp(
+            Py4JJavaError,
+            exc_message,
+            df.groupBy('id').agg(
+                pandas_udf(foo, 'double', PandasUDFType.GROUPED_AGG)('id')
+            ).collect
+        )
+
 
 @unittest.skipIf(
     not _have_pandas or not _have_pyarrow,

http://git-wip-us.apache.org/repos/asf/spark/blob/3e5b4ae6/python/pyspark/sql/udf.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py
index c8fb49d..9dbe49b 100644
--- a/python/pyspark/sql/udf.py
+++ b/python/pyspark/sql/udf.py
@@ -25,7 +25,7 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_
 from pyspark.sql.column import Column, _to_java_column, _to_seq
 from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string,\
     to_arrow_type, to_arrow_schema
-from pyspark.util import _get_argspec, fail_on_stopiteration
+from pyspark.util import _get_argspec
 
 __all__ = ["UDFRegistration"]
 
@@ -157,17 +157,7 @@ class UserDefinedFunction(object):
         spark = SparkSession.builder.getOrCreate()
         sc = spark.sparkContext
 
-        func = fail_on_stopiteration(self.func)
-
-        # for pandas UDFs the worker needs to know if the function takes
-        # one or two arguments, but the signature is lost when wrapping with
-        # fail_on_stopiteration, so we store it here
-        if self.evalType in (PythonEvalType.SQL_SCALAR_PANDAS_UDF,
-                             PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF,
-                             PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF):
-            func._argspec = _get_argspec(self.func)
-
-        wrapped_func = _wrap_function(sc, func, self.returnType)
+        wrapped_func = _wrap_function(sc, self.func, self.returnType)
         jdt = spark._jsparkSession.parseDataType(self.returnType.json())
         judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction(
             self._name, wrapped_func, jdt, self.evalType, self.deterministic)

http://git-wip-us.apache.org/repos/asf/spark/blob/3e5b4ae6/python/pyspark/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 30723b8..18b2f25 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -1291,27 +1291,34 @@ class RDDTests(ReusedPySparkTestCase):
         result = rdd.pipe('cat').collect()
         self.assertEqual(data, result)
 
-    def test_stopiteration_in_client_code(self):
+    def test_stopiteration_in_user_code(self):
 
         def stopit(*x):
             raise StopIteration()
 
         seq_rdd = self.sc.parallelize(range(10))
         keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10))
-
-        self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
-        self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit)
-        self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit)
-        self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit)
-
-        # the exception raised is non-deterministic
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, stopit, lambda *x: 1)
-        self.assertRaises((Py4JJavaError, RuntimeError),
-                          seq_rdd.aggregate, 0, lambda *x: 1, stopit)
+        msg = "Caught StopIteration thrown from user's code; failing the task"
+
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.reduce, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit)
+        self.assertRaisesRegexp(Py4JJavaError, msg,
+                                seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect)
+
+        # these methods call the user function both in the driver and in the executor
+        # the exception raised is different according to where the StopIteration happens
+        # RuntimeError is raised if in the driver
+        # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                keyed_rdd.reduceByKeyLocally, stopit)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, stopit, lambda *x: 1)
+        self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg,
+                                seq_rdd.aggregate, 0, lambda *x: 1, stopit)
 
 
 class ProfilerTests(PySparkTestCase):

http://git-wip-us.apache.org/repos/asf/spark/blob/3e5b4ae6/python/pyspark/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/util.py b/python/pyspark/util.py
index e95a9b5..f015542 100644
--- a/python/pyspark/util.py
+++ b/python/pyspark/util.py
@@ -53,12 +53,7 @@ def _get_argspec(f):
     """
     Get argspec of a function. Supports both Python 2 and Python 3.
     """
-
-    if hasattr(f, '_argspec'):
-        # only used for pandas UDF: they wrap the user function, losing its signature
-        # workers need this signature, so UDF saves it here
-        argspec = f._argspec
-    elif sys.version_info[0] < 3:
+    if sys.version_info[0] < 3:
         argspec = inspect.getargspec(f)
     else:
         # `getargspec` is deprecated since python3.0 (incompatible with function annotations).
@@ -97,7 +92,7 @@ class VersionUtils(object):
 def fail_on_stopiteration(f):
     """
     Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError'
-    prevents silent loss of data when 'f' is used in a for loop
+    prevents silent loss of data when 'f' is used in a for loop in Spark code
     """
     def wrapper(*args, **kwargs):
         try:

http://git-wip-us.apache.org/repos/asf/spark/blob/3e5b4ae6/python/pyspark/worker.py
----------------------------------------------------------------------
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index fbcb8af..a30d6bf 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -35,7 +35,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \
     write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
     BatchedSerializer, ArrowStreamPandasSerializer
 from pyspark.sql.types import to_arrow_type
-from pyspark.util import _get_argspec
+from pyspark.util import _get_argspec, fail_on_stopiteration
 from pyspark import shuffle
 
 pickleSer = PickleSerializer()
@@ -92,10 +92,9 @@ def wrap_scalar_pandas_udf(f, return_type):
     return lambda *a: (verify_result_length(*a), arrow_return_type)
 
 
-def wrap_grouped_map_pandas_udf(f, return_type):
+def wrap_grouped_map_pandas_udf(f, return_type, argspec):
     def wrapped(key_series, value_series):
         import pandas as pd
-        argspec = _get_argspec(f)
 
         if len(argspec.args) == 1:
             result = f(pd.concat(value_series, axis=1))
@@ -140,15 +139,20 @@ def read_single_udf(pickleSer, infile, eval_type):
         else:
             row_func = chain(row_func, f)
 
+    # make sure StopIteration's raised in the user code are not ignored
+    # when they are processed in a for loop, raise them as RuntimeError's instead
+    func = fail_on_stopiteration(row_func)
+
     # the last returnType will be the return type of UDF
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type)
+        argspec = _get_argspec(row_func)  # signature was lost when wrapping it
+        return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type, argspec)
     elif eval_type == PythonEvalType.SQL_GROUPED_AGG_PANDAS_UDF:
-        return arg_offsets, wrap_grouped_agg_pandas_udf(row_func, return_type)
+        return arg_offsets, wrap_grouped_agg_pandas_udf(func, return_type)
     elif eval_type == PythonEvalType.SQL_BATCHED_UDF:
-        return arg_offsets, wrap_udf(row_func, return_type)
+        return arg_offsets, wrap_udf(func, return_type)
     else:
         raise ValueError("Unknown eval type: {}".format(eval_type))
 

