Repository: spark Updated Branches: refs/heads/branch-2.3 a55de387d -> 470cacd49
[SPARK-23754][PYTHON][FOLLOWUP][BACKPORT-2.3] Move UDF stop iteration wrapping from driver to executor SPARK-23754 was fixed in #21383 by changing the UDF code to wrap the user function, but this required a hack to save its argspec. This PR reverts this change and fixes the `StopIteration` bug in the worker. The root of the problem is that when a user-supplied function raises a `StopIteration`, PySpark might stop processing data if this function is used in a for-loop. The solution is to catch `StopIteration` exceptions and re-raise them as `RuntimeError`s, so that the execution fails and the error is reported to the user. This is done using the `fail_on_stopiteration` wrapper, in different ways depending on where the function is used: - In RDDs, the user function is wrapped in the driver, because this function is also called in the driver itself. - In SQL UDFs, the function is wrapped in the worker, since all processing happens there. Moreover, the worker needs the signature of the user function, which is lost when wrapping it, but passing this signature to the worker requires a not-so-nice hack. HyukjinKwon Author: edorigatti <emilio.doriga...@gmail.com> Author: e-dorigatti <emilio.doriga...@gmail.com> Closes #21538 from e-dorigatti/branch-2.3. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/470cacd4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/470cacd4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/470cacd4 Branch: refs/heads/branch-2.3 Commit: 470cacd4982ca369ffd294ee37abfa1864d39967 Parents: a55de38 Author: edorigatti <emilio.doriga...@gmail.com> Authored: Wed Jun 13 09:06:06 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Wed Jun 13 09:06:06 2018 +0800 ---------------------------------------------------------------------- python/pyspark/sql/tests.py | 54 ++++++++++++++++++++++++++++------------ python/pyspark/sql/udf.py | 4 +-- python/pyspark/tests.py | 37 ++++++++++++++++----------- python/pyspark/util.py | 2 +- python/pyspark/worker.py | 11 +++++--- 5 files changed, 70 insertions(+), 38 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/sql/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 818ba83..aa7d8eb 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -853,22 +853,6 @@ class SQLTests(ReusedSQLTestCase): self.assertEqual(f, f_.func) self.assertEqual(return_type, f_.returnType) - def test_stopiteration_in_udf(self): - # test for SPARK-23754 - from pyspark.sql.functions import udf - from py4j.protocol import Py4JJavaError - - def foo(x): - raise StopIteration() - - with self.assertRaises(Py4JJavaError) as cm: - self.spark.range(0, 1000).withColumn('v', udf(foo)('id')).show() - - self.assertIn( - "Caught StopIteration thrown from user's code; failing the task", - cm.exception.java_exception.toString() - ) - def test_validate_column_types(self): from pyspark.sql.functions import udf, to_json from pyspark.sql.column import _to_java_column @@ -3917,6 
+3901,44 @@ class PandasUDFTests(ReusedSQLTestCase): def foo(k, v): return k + def test_stopiteration_in_udf(self): + from pyspark.sql.functions import udf, pandas_udf, PandasUDFType + from py4j.protocol import Py4JJavaError + + def foo(x): + raise StopIteration() + + def foofoo(x, y): + raise StopIteration() + + exc_message = "Caught StopIteration thrown from user's code; failing the task" + df = self.spark.range(0, 100) + + # plain udf (test for SPARK-23754) + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn('v', udf(foo)('id')).collect + ) + + # pandas scalar udf + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.withColumn( + 'v', pandas_udf(foo, 'double', PandasUDFType.SCALAR)('id') + ).collect + ) + + # pandas grouped map + self.assertRaisesRegexp( + Py4JJavaError, + exc_message, + df.groupBy('id').apply( + pandas_udf(foo, df.schema, PandasUDFType.GROUPED_MAP) + ).collect + ) + @unittest.skipIf( not _have_pandas or not _have_pyarrow, http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/sql/udf.py ---------------------------------------------------------------------- diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 7d813af..671e568 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -24,7 +24,6 @@ from pyspark.rdd import _prepare_for_python_RDD, PythonEvalType, ignore_unicode_ from pyspark.sql.column import Column, _to_java_column, _to_seq from pyspark.sql.types import StringType, DataType, StructType, _parse_datatype_string, \ to_arrow_type, to_arrow_schema -from pyspark.util import fail_on_stopiteration __all__ = ["UDFRegistration"] @@ -155,8 +154,7 @@ class UserDefinedFunction(object): spark = SparkSession.builder.getOrCreate() sc = spark.sparkContext - func = fail_on_stopiteration(self.func) - wrapped_func = _wrap_function(sc, func, self.returnType) + wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = 
spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( self._name, wrapped_func, jdt, self.evalType, self.deterministic) http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index af39450..81bff4b 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -1270,27 +1270,34 @@ class RDDTests(ReusedPySparkTestCase): self.assertRaises(Py4JJavaError, rdd.pipe('grep 4', checkCode=True).collect) self.assertEqual([], rdd.pipe('grep 4').collect()) - def test_stopiteration_in_client_code(self): + def test_stopiteration_in_user_code(self): def stopit(*x): raise StopIteration() seq_rdd = self.sc.parallelize(range(10)) keyed_rdd = self.sc.parallelize((x % 2, x) for x in range(10)) - - self.assertRaises(Py4JJavaError, seq_rdd.map(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.filter(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) - self.assertRaises(Py4JJavaError, seq_rdd.foreach, stopit) - self.assertRaises(Py4JJavaError, keyed_rdd.reduceByKeyLocally, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.reduce, stopit) - self.assertRaises(Py4JJavaError, seq_rdd.fold, 0, stopit) - - # the exception raised is non-deterministic - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, stopit, lambda *x: 1) - self.assertRaises((Py4JJavaError, RuntimeError), - seq_rdd.aggregate, 0, lambda *x: 1, stopit) + msg = "Caught StopIteration thrown from user's code; failing the task" + + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.map(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.filter(stopit).collect) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, 
seq_rdd.reduce, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.fold, 0, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, seq_rdd.foreach, stopit) + self.assertRaisesRegexp(Py4JJavaError, msg, + seq_rdd.cartesian(seq_rdd).flatMap(stopit).collect) + + # these methods call the user function both in the driver and in the executor + # the exception raised is different according to where the StopIteration happens + # RuntimeError is raised if in the driver + # Py4JJavaError is raised if in the executor (wraps the RuntimeError raised in the worker) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + keyed_rdd.reduceByKeyLocally, stopit) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, stopit, lambda *x: 1) + self.assertRaisesRegexp((Py4JJavaError, RuntimeError), msg, + seq_rdd.aggregate, 0, lambda *x: 1, stopit) class ProfilerTests(PySparkTestCase): http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/util.py b/python/pyspark/util.py index 83d528f..94f51ee 100644 --- a/python/pyspark/util.py +++ b/python/pyspark/util.py @@ -48,7 +48,7 @@ def _exception_message(excp): def fail_on_stopiteration(f): """ Wraps the input function to fail on 'StopIteration' by raising a 'RuntimeError' - prevents silent loss of data when 'f' is used in a for loop + prevents silent loss of data when 'f' is used in a for loop in Spark code """ def wrapper(*args, **kwargs): try: http://git-wip-us.apache.org/repos/asf/spark/blob/470cacd4/python/pyspark/worker.py ---------------------------------------------------------------------- diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 44e9106..788b323 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -35,6 +35,7 @@ from pyspark.serializers import write_with_length, write_int, read_long, \ write_long, read_int, 
SpecialLengths, UTF8Deserializer, PickleSerializer, \ BatchedSerializer, ArrowStreamPandasSerializer from pyspark.sql.types import to_arrow_type +from pyspark.util import fail_on_stopiteration from pyspark import shuffle pickleSer = PickleSerializer() @@ -122,13 +123,17 @@ def read_single_udf(pickleSer, infile, eval_type): else: row_func = chain(row_func, f) + # make sure StopIteration's raised in the user code are not ignored + # when they are processed in a for loop, raise them as RuntimeError's instead + func = fail_on_stopiteration(row_func) + # the last returnType will be the return type of UDF if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(row_func, return_type) + return arg_offsets, wrap_scalar_pandas_udf(func, return_type) elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF: - return arg_offsets, wrap_grouped_map_pandas_udf(row_func, return_type) + return arg_offsets, wrap_grouped_map_pandas_udf(func, return_type) else: - return arg_offsets, wrap_udf(row_func, return_type) + return arg_offsets, wrap_udf(func, return_type) def read_udfs(pickleSer, infile, eval_type): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org