Repository: spark
Updated Branches:
  refs/heads/master 88875d941 -> d29e2ef4c


[SPARK-11935][PYSPARK] Send the Python exceptions in TransformFunction and 
TransformFunctionSerializer to Java

The Python exception traceback in TransformFunction and
TransformFunctionSerializer is not sent back to Java. Py4j just throws a very
general exception, which is hard to debug.

This PR adds a `getLastFailure` method to get the failure message on the Java side.

Author: Shixiong Zhu <shixi...@databricks.com>

Closes #9922 from zsxwing/SPARK-11935.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d29e2ef4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d29e2ef4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d29e2ef4

Branch: refs/heads/master
Commit: d29e2ef4cf43c7f7c5aa40d305cf02be44ce19e0
Parents: 88875d9
Author: Shixiong Zhu <shixi...@databricks.com>
Authored: Wed Nov 25 11:47:21 2015 -0800
Committer: Tathagata Das <tathagata.das1...@gmail.com>
Committed: Wed Nov 25 11:47:21 2015 -0800

----------------------------------------------------------------------
 python/pyspark/streaming/tests.py               | 82 +++++++++++++++++++-
 python/pyspark/streaming/util.py                | 29 ++++---
 .../streaming/api/python/PythonDStream.scala    | 52 ++++++++++---
 3 files changed, 144 insertions(+), 19 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d29e2ef4/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py 
b/python/pyspark/streaming/tests.py
index a0e0267..d380d69 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -404,17 +404,69 @@ class BasicOperationTests(PySparkStreamingTestCase):
         self._test_func(input, func, expected)
 
     def test_failed_func(self):
+        # Test failure in
+        # TransformFunction.apply(rdd: Option[RDD[_]], time: Time)
         input = [self.sc.parallelize([d], 1) for d in range(4)]
         input_stream = self.ssc.queueStream(input)
 
         def failed_func(i):
-            raise ValueError("failed")
+            raise ValueError("This is a special error")
 
         input_stream.map(failed_func).pprint()
         self.ssc.start()
         try:
             self.ssc.awaitTerminationOrTimeout(10)
         except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
+            return
+
+        self.fail("a failed func should throw an error")
+
+    def test_failed_func2(self):
+        # Test failure in
+        # TransformFunction.apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], 
time: Time)
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream1 = self.ssc.queueStream(input)
+        input_stream2 = self.ssc.queueStream(input)
+
+        def failed_func(rdd1, rdd2):
+            raise ValueError("This is a special error")
+
+        input_stream1.transformWith(failed_func, input_stream2, True).pprint()
+        self.ssc.start()
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
+            return
+
+        self.fail("a failed func should throw an error")
+
+    def test_failed_func_with_reseting_failure(self):
+        input = [self.sc.parallelize([d], 1) for d in range(4)]
+        input_stream = self.ssc.queueStream(input)
+
+        def failed_func(i):
+            if i == 1:
+                # Make it fail in the second batch
+                raise ValueError("This is a special error")
+            else:
+                return i
+
+        # We should be able to see the results of the 3rd and 4th batches even 
if the second batch
+        # fails
+        expected = [[0], [2], [3]]
+        self.assertEqual(expected, 
self._collect(input_stream.map(failed_func), 3))
+        try:
+            self.ssc.awaitTerminationOrTimeout(10)
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue("This is a special error" in failure)
             return
 
         self.fail("a failed func should throw an error")
@@ -780,6 +832,34 @@ class CheckpointTests(unittest.TestCase):
         if self.cpd is not None:
             shutil.rmtree(self.cpd)
 
+    def test_transform_function_serializer_failure(self):
+        inputd = tempfile.mkdtemp()
+        self.cpd = 
tempfile.mkdtemp("test_transform_function_serializer_failure")
+
+        def setup():
+            conf = SparkConf().set("spark.default.parallelism", 1)
+            sc = SparkContext(conf=conf)
+            ssc = StreamingContext(sc, 0.5)
+
+            # A function that cannot be serialized
+            def process(time, rdd):
+                sc.parallelize(range(1, 10))
+
+            ssc.textFileStream(inputd).foreachRDD(process)
+            return ssc
+
+        self.ssc = StreamingContext.getOrCreate(self.cpd, setup)
+        try:
+            self.ssc.start()
+        except:
+            import traceback
+            failure = traceback.format_exc()
+            self.assertTrue(
+                "It appears that you are attempting to reference SparkContext" 
in failure)
+            return
+
+        self.fail("using SparkContext in process should fail because it's not 
Serializable")
+
     def test_get_or_create_and_get_active_or_create(self):
         inputd = tempfile.mkdtemp()
         outputd = tempfile.mkdtemp() + "/"

http://git-wip-us.apache.org/repos/asf/spark/blob/d29e2ef4/python/pyspark/streaming/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index 767c732..c7f02bc 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -38,12 +38,15 @@ class TransformFunction(object):
         self.func = func
         self.deserializers = deserializers
         self._rdd_wrapper = lambda jrdd, ctx, ser: RDD(jrdd, ctx, ser)
+        self.failure = None
 
     def rdd_wrapper(self, func):
         self._rdd_wrapper = func
         return self
 
     def call(self, milliseconds, jrdds):
+        # Clear the failure
+        self.failure = None
         try:
             if self.ctx is None:
                 self.ctx = SparkContext._active_spark_context
@@ -62,9 +65,11 @@ class TransformFunction(object):
             r = self.func(t, *rdds)
             if r:
                 return r._jrdd
-        except Exception:
-            traceback.print_exc()
-            raise
+        except:
+            self.failure = traceback.format_exc()
+
+    def getLastFailure(self):
+        return self.failure
 
     def __repr__(self):
         return "TransformFunction(%s)" % self.func
@@ -89,22 +94,28 @@ class TransformFunctionSerializer(object):
         self.serializer = serializer
         self.gateway = gateway or self.ctx._gateway
         self.gateway.jvm.PythonDStream.registerSerializer(self)
+        self.failure = None
 
     def dumps(self, id):
+        # Clear the failure
+        self.failure = None
         try:
             func = self.gateway.gateway_property.pool[id]
             return bytearray(self.serializer.dumps((func.func, 
func.deserializers)))
-        except Exception:
-            traceback.print_exc()
-            raise
+        except:
+            self.failure = traceback.format_exc()
 
     def loads(self, data):
+        # Clear the failure
+        self.failure = None
         try:
             f, deserializers = self.serializer.loads(bytes(data))
             return TransformFunction(self.ctx, f, *deserializers)
-        except Exception:
-            traceback.print_exc()
-            raise
+        except:
+            self.failure = traceback.format_exc()
+
+    def getLastFailure(self):
+        return self.failure
 
     def __repr__(self):
         return "TransformFunctionSerializer(%s)" % self.serializer

http://git-wip-us.apache.org/repos/asf/spark/blob/d29e2ef4/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
----------------------------------------------------------------------
diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
 
b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
index dfc5694..994309d 100644
--- 
a/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
+++ 
b/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala
@@ -26,6 +26,7 @@ import scala.language.existentials
 
 import py4j.GatewayServer
 
+import org.apache.spark.SparkException
 import org.apache.spark.api.java._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
@@ -40,6 +41,13 @@ import org.apache.spark.util.Utils
  */
 private[python] trait PythonTransformFunction {
   def call(time: Long, rdds: JList[_]): JavaRDD[Array[Byte]]
+
+  /**
+   * Get the failure, if any, in the last call to `call`.
+   *
+   * @return the failure message if there was a failure, or `null` if there 
was no failure.
+   */
+  def getLastFailure: String
 }
 
 /**
@@ -48,6 +56,13 @@ private[python] trait PythonTransformFunction {
 private[python] trait PythonTransformFunctionSerializer {
   def dumps(id: String): Array[Byte]
   def loads(bytes: Array[Byte]): PythonTransformFunction
+
+  /**
+   * Get the failure, if any, in the last call to `dumps` or `loads`.
+   *
+   * @return the failure message if there was a failure, or `null` if there 
was no failure.
+   */
+  def getLastFailure: String
 }
 
 /**
@@ -59,18 +74,27 @@ private[python] class TransformFunction(@transient var 
pfunc: PythonTransformFun
   extends function.Function2[JList[JavaRDD[_]], Time, JavaRDD[Array[Byte]]] {
 
   def apply(rdd: Option[RDD[_]], time: Time): Option[RDD[Array[Byte]]] = {
-    Option(pfunc.call(time.milliseconds, 
List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava))
-      .map(_.rdd)
+    val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull).asJava
+    Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
   }
 
   def apply(rdd: Option[RDD[_]], rdd2: Option[RDD[_]], time: Time): 
Option[RDD[Array[Byte]]] = {
     val rdds = List(rdd.map(JavaRDD.fromRDD(_)).orNull, 
rdd2.map(JavaRDD.fromRDD(_)).orNull).asJava
-    Option(pfunc.call(time.milliseconds, rdds)).map(_.rdd)
+    Option(callPythonTransformFunction(time.milliseconds, rdds)).map(_.rdd)
   }
 
   // for function.Function2
   def call(rdds: JList[JavaRDD[_]], time: Time): JavaRDD[Array[Byte]] = {
-    pfunc.call(time.milliseconds, rdds)
+    callPythonTransformFunction(time.milliseconds, rdds)
+  }
+
+  private def callPythonTransformFunction(time: Long, rdds: JList[_]): 
JavaRDD[Array[Byte]] = {
+    val resultRDD = pfunc.call(time, rdds)
+    val failure = pfunc.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + 
failure)
+    }
+    resultRDD
   }
 
   private def writeObject(out: ObjectOutputStream): Unit = 
Utils.tryOrIOException {
@@ -103,23 +127,33 @@ private[python] object PythonTransformFunctionSerializer {
   /*
    * Register a serializer from Python, should be called during initialization
    */
-  def register(ser: PythonTransformFunctionSerializer): Unit = {
+  def register(ser: PythonTransformFunctionSerializer): Unit = synchronized {
     serializer = ser
   }
 
-  def serialize(func: PythonTransformFunction): Array[Byte] = {
+  def serialize(func: PythonTransformFunction): Array[Byte] = synchronized {
     require(serializer != null, "Serializer has not been registered!")
     // get the id of PythonTransformFunction in py4j
     val h = Proxy.getInvocationHandler(func.asInstanceOf[Proxy])
     val f = h.getClass().getDeclaredField("id")
     f.setAccessible(true)
     val id = f.get(h).asInstanceOf[String]
-    serializer.dumps(id)
+    val results = serializer.dumps(id)
+    val failure = serializer.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + 
failure)
+    }
+    results
   }
 
-  def deserialize(bytes: Array[Byte]): PythonTransformFunction = {
+  def deserialize(bytes: Array[Byte]): PythonTransformFunction = synchronized {
     require(serializer != null, "Serializer has not been registered!")
-    serializer.loads(bytes)
+    val pfunc = serializer.loads(bytes)
+    val failure = serializer.getLastFailure
+    if (failure != null) {
+      throw new SparkException("An exception was raised by Python:\n" + 
failure)
+    }
+    pfunc
   }
 }
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to