This is an automated email from the ASF dual-hosted git repository.

cutlerb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f9ca8ab  [SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with arrow enabled
f9ca8ab is described below

commit f9ca8ab196b1967a7603ca36d62fc15d1391842e
Author: David Vogelbacher <dvogelbac...@palantir.com>
AuthorDate: Tue Jun 4 10:10:27 2019 -0700

    [SPARK-27805][PYTHON] Propagate SparkExceptions during toPandas with arrow enabled
    
    ## What changes were proposed in this pull request?
    Similar to https://github.com/apache/spark/pull/24070, we now propagate SparkExceptions that are encountered during the collect in the Java process to the Python process.
    
    Fixes https://jira.apache.org/jira/browse/SPARK-27805
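    
    As a rough, user-facing sketch of the behaviour this enables (illustrative only, not part of the commit; it assumes a local SparkSession and uses the Arrow config name that appears in the test below):
    
        from pyspark.sql import SparkSession
        from pyspark.sql.functions import udf
        from pyspark.sql.types import IntegerType
    
        spark = SparkSession.builder.getOrCreate()
        spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", "true")
    
        def fail():
            raise Exception("My error")
    
        # Zero-argument UDF that fails whenever the column is evaluated
        df = spark.range(3).withColumn("error", udf(fail, IntegerType())())
    
        try:
            df.toPandas()
        except RuntimeError as e:
            # With this change the message carries the SparkException raised in
            # the JVM, so the original cause ("My error") is not lost.
            print(e)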
    
    ## How was this patch tested?
    Added a new unit test
    
    Closes #24677 from dvogelbacher/dv/betterErrorMsgWhenUsingArrow.
    
    Authored-by: David Vogelbacher <dvogelbac...@palantir.com>
    Signed-off-by: Bryan Cutler <cutl...@gmail.com>
---
 python/pyspark/serializers.py                      |  6 +++-
 python/pyspark/sql/tests/test_arrow.py             | 12 +++++++
 .../main/scala/org/apache/spark/sql/Dataset.scala  | 40 +++++++++++++++-------
 3 files changed, 44 insertions(+), 14 deletions(-)

diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 6058e94..516ee7e 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -206,8 +206,12 @@ class ArrowCollectSerializer(Serializer):
         for batch in self.serializer.load_stream(stream):
             yield batch
 
-        # load the batch order indices
+        # load the batch order indices or propagate any error that occurred in the JVM
         num = read_int(stream)
+        if num == -1:
+            error_msg = UTF8Deserializer().loads(stream)
+            raise RuntimeError("An error occurred while calling "
+                               "ArrowCollectSerializer.load_stream: {}".format(error_msg))
         batch_order = []
         for i in xrange(num):
             index = read_int(stream)
diff --git a/python/pyspark/sql/tests/test_arrow.py b/python/pyspark/sql/tests/test_arrow.py
index 7871af4..cb51241 100644
--- a/python/pyspark/sql/tests/test_arrow.py
+++ b/python/pyspark/sql/tests/test_arrow.py
@@ -23,6 +23,7 @@ import unittest
 import warnings
 
 from pyspark.sql import Row
+from pyspark.sql.functions import udf
 from pyspark.sql.types import *
 from pyspark.testing.sqlutils import ReusedSQLTestCase, have_pandas, have_pyarrow, \
     pandas_requirement_message, pyarrow_requirement_message
@@ -205,6 +206,17 @@ class ArrowTests(ReusedSQLTestCase):
         self.assertEqual(pdf.columns[0], "field1")
         self.assertTrue(pdf.empty)
 
+    def test_propagates_spark_exception(self):
+        df = self.spark.range(3).toDF("i")
+
+        def raise_exception():
+            raise Exception("My error")
+        exception_udf = udf(raise_exception, IntegerType())
+        df = df.withColumn("error", exception_udf())
+        with QuietTest(self.sc):
+            with self.assertRaisesRegexp(RuntimeError, 'My error'):
+                df.toPandas()
+
     def _createDataFrame_toggle(self, pdf, schema=None):
         with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
             df_no_arrow = self.spark.createDataFrame(pdf, schema=schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index f3377f3..a80aade 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -26,7 +26,7 @@ import scala.util.control.NonFatal
 
 import org.apache.commons.lang3.StringUtils
 
-import org.apache.spark.TaskContext
+import org.apache.spark.{SparkException, TaskContext}
 import org.apache.spark.annotation.{DeveloperApi, Evolving, Experimental, Stable, Unstable}
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.api.java.function._
@@ -3321,20 +3321,34 @@ class Dataset[T] private[sql](
             }
           }
 
-        val arrowBatchRdd = toArrowBatchRdd(plan)
-        sparkSession.sparkContext.runJob(
-          arrowBatchRdd,
-          (it: Iterator[Array[Byte]]) => it.toArray,
-          handlePartitionBatches)
+        var sparkException: Option[SparkException] = None
+        try {
+          val arrowBatchRdd = toArrowBatchRdd(plan)
+          sparkSession.sparkContext.runJob(
+            arrowBatchRdd,
+            (it: Iterator[Array[Byte]]) => it.toArray,
+            handlePartitionBatches)
+        } catch {
+          case e: SparkException =>
+            sparkException = Some(e)
+        }
 
-        // After processing all partitions, end the stream and write batch order indices
+        // After processing all partitions, end the batch stream
         batchWriter.end()
-        out.writeInt(batchOrder.length)
-        // Sort by (index of partition, batch index in that partition) tuple to get the
-        // overall_batch_index from 0 to N-1 batches, which can be used to put the
-        // transferred batches in the correct order
-        batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
-          out.writeInt(overallBatchIndex)
+        sparkException match {
+          case Some(exception) =>
+            // Signal failure and write error message
+            out.writeInt(-1)
+            PythonRDD.writeUTF(exception.getMessage, out)
+          case None =>
+            // Write batch order indices
+            out.writeInt(batchOrder.length)
+            // Sort by (index of partition, batch index in that partition) tuple to get the
+            // overall_batch_index from 0 to N-1 batches, which can be used to put the
+            // transferred batches in the correct order
+            batchOrder.zipWithIndex.sortBy(_._1).foreach { case (_, overallBatchIndex) =>
+              out.writeInt(overallBatchIndex)
+            }
         }
       }
     }
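
As a hedged, standalone illustration of the end-of-stream framing the patch adds to the Arrow collect stream (to my understanding, read_int and PythonRDD.writeUTF use big-endian ints and a length-prefixed UTF-8 message; the helper names below are made up for this sketch and are not Spark APIs):

    import io
    import struct

    def write_footer(out, batch_order=None, error_msg=None):
        # JVM side of the protocol: -1 plus the error message on failure,
        # otherwise the number of batch-order indices followed by one int each.
        if error_msg is not None:
            data = error_msg.encode("utf-8")
            out.write(struct.pack("!i", -1))
            out.write(struct.pack("!i", len(data)))
            out.write(data)
        else:
            out.write(struct.pack("!i", len(batch_order)))
            for index in batch_order:
                out.write(struct.pack("!i", index))

    def read_footer(stream):
        # Python side: mirrors the tail of ArrowCollectSerializer.load_stream.
        num = struct.unpack("!i", stream.read(4))[0]
        if num == -1:
            length = struct.unpack("!i", stream.read(4))[0]
            raise RuntimeError(stream.read(length).decode("utf-8"))
        return [struct.unpack("!i", stream.read(4))[0] for _ in range(num)]

    buf = io.BytesIO()
    write_footer(buf, error_msg="Job aborted due to stage failure")
    buf.seek(0)
    try:
        read_footer(buf)
    except RuntimeError as e:
        print(e)  # -> Job aborted due to stage failure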


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
