This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new f9ca8ab  [SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with arrow enabled
f9ca8ab is described below

commit f9ca8ab196b1967a7603ca36d62fc15d1391842e
Author: David Vogelbacher <dvogelbac...@palantir.com>
AuthorDate: Tue Jun 4 10:10:27 2019 -0700

    [SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with arrow enabled

    ## What changes were proposed in this pull request?
    Similar to https://github.com/apache/spark/pull/24070, we now propagate SparkExceptions that are
    encountered during the collect in the java process to the python process.
    Fixes https://jira.apache.org/jira/browse/SPARK-27805

    ## How was this patch tested?
    Added a new unit test

    Closes #24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.

    Authored-by: David Vogelbacher <dvogelbac...@palantir.com>
    Signed-off-by: Bryan Cutler <cutl...@gmail.com>
---
 python/pyspark/serializers.py                      |  6 +++-
 python/pyspark/sql/tests/test_arrow.py             | 12 +++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 40 +++++++++++++++-------
 3 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 6058e94..516ee7e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -206,8 +206,12 @@ class ArrowCollectSerializer(Serializer):
         for batch in self.serializer.load_stream(stream):
             yield batch

-        # load the batch order indices
+        # load the batch order indices or propagate any error that occurred in the JVM
         num = read_int(stream)
+        if num == -1:
+            error_msg = UTF8Deserializer().loads(stream)
+            raise RuntimeError("An error occurred while calling "
+                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
         batch_order = []
         for i in xrange(num):
             index = read_int(stream)

diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 7871af4..cb51241 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -23,6 +23,7 @@ import unittest
 import warnings

 from pyspark.sql import Row
+from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -205,6 +206,17 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf.columns[0], "field1")
         self.assertTrue(pdf.empty)

+    def test_propagates_spark_exception(self):
+        df = self.spark.range(3).toDF("i")
+
+        def raise_exception():
+            raise Exception("My error")
+        exception_udf = udf(raise_exception, IntegerType())
+        df = df.withColumn("error", exception_udf())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+                df.toPandas()
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
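The serializers.py hunk above changes the framing of the Arrow collect stream: after the record batches, the Python side reads one signed 32-bit int, and a value of -1 now means a UTF-8 error message follows instead of the count of batch-order indices. The sketch below illustrates that trailer framing with plain struct/io helpers; the helper names and byte layout are assumptions made for illustration, not PySpark internals.

```python
# Minimal sketch of the post-batch trailer framing introduced by this change
# (assumed layout: 4-byte big-endian ints and a length-prefixed UTF-8 string;
# these helpers are illustrative stand-ins for PySpark's read_int/UTF8Deserializer).
import io
import struct


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


def read_utf8(stream):
    length = read_int(stream)
    return stream.read(length).decode("utf-8")


def read_trailer(stream):
    """Read the trailer: batch-order indices, or a propagated JVM error."""
    num = read_int(stream)
    if num == -1:
        # The JVM signalled a failure: an error message follows instead of indices.
        raise RuntimeError("An error occurred during collect: {}".format(read_utf8(stream)))
    return [read_int(stream) for _ in range(num)]


# Success case: two batch-order indices.
print(read_trailer(io.BytesIO(struct.pack("!iii", 2, 1, 0))))  # [1, 0]

# Failure case: -1 sentinel followed by the error message written by the JVM.
msg = "My error".encode("utf-8")
failed = io.BytesIO(struct.pack("!ii", -1, len(msg)) + msg)
try:
    read_trailer(failed)
except RuntimeError as exc:
    print(exc)  # An error occurred during collect: My error
```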
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f3377f3..a80aade 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal

 import org.apache.commons.lang3.StringUtils

-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,20 +3321,34 @@ class Dataset[T] private[sql](
             }
           }

-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        sparkSession.sparkContext.runJob(
-          arrowBatchRdd,
-          (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches)
+        var sparkException: Option[SparkException] = None
+        try {
+          val arrowBatchRdd = toArrowBatchRdd(plan)
+          sparkSession.sparkContext.runJob(
+            arrowBatchRdd,
+            (it: Iterator[Array[Byte]]) => it.toArray,
+            handlePartitionBatches)
+        } catch {
+          case e: SparkException =>
+            sparkException = Some(e)
+        }

-        // After processing all partitions, end the stream and write batch order indices
+        // After processing all partitions, end the batch stream
         batchWriter.end()
-        out.writeInt(batchOrder.length)
-        // Sort by (index of partition, batch index in that partition) tuple to get the
-        // overall_batch_index from 0 to N-1 batches, which can be used to put the
-        // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-          out.writeInt(overallBatchIndex)
+        sparkException match {
+          case Some(exception) =>
+            // Signal failure and write error message
+            out.writeInt(-1)
+            PythonRDD.writeUTF(exception.getMessage, out)
+          case None =>
+            // Write batch order indices
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+              out.writeInt(overallBatchIndex)
+            }
         }
       }
     }
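Taken together, the Scala and Python changes mean a SparkException raised while collecting Arrow batches now reaches the Python caller as a RuntimeError carrying the original message. Below is a usage-level sketch along the lines of the new test_propagates_spark_exception test; the local session setup and config values here are assumptions for illustration, not part of the patch.

```python
# Sketch of the user-visible behavior after this commit (assumes a local PySpark
# build that includes the change, with pandas and pyarrow installed).
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

spark = (SparkSession.builder
         .master("local[2]")
         .config("spark.sql.execution.arrow.pyspark.enabled", "true")
         .getOrCreate())


def raise_exception():
    raise Exception("My error")


exception_udf = udf(raise_exception, IntegerType())
df = spark.range(3).withColumn("error", exception_udf())

try:
    df.toPandas()
except RuntimeError as exc:
    # Before this change the JVM-side SparkException was not propagated through the
    # Arrow stream; now the original message ("My error") reaches the Python caller.
    print(exc)

spark.stop()
```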