Repository: spark Updated Branches: refs/heads/master 1a644afba -> b070ded28
[SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mismatch in PythonTransformFunction ## What changes were proposed in this pull request? This PR proposes to wrap the transformed RDD within `TransformFunction`. `PythonTransformFunction` appears to require that `_jrdd` is a `JavaRDD`. https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67 https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43 However, some APIs produce a `JavaPairRDD` instead, for example, `zip` in PySpark's RDD API. The class of `_jrdd` can be checked as below: ```python >>> rdd.zip(rdd)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaPairRDD' ``` So, here, when `_jrdd` is not already a `JavaRDD`, the result is wrapped with an identity `map`, which guarantees that a `JavaRDD` is returned. ```python >>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString() u'class org.apache.spark.api.java.JavaRDD' ``` Some failure cases are shown below: ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]) \ .transform(lambda rdd: rdd.cartesian(rdd)) \ .pprint() ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd)) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd))) ssc.start() ``` ```python from pyspark.streaming import StreamingContext ssc = StreamingContext(spark.sparkContext, 10) 
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1)) ssc.start() ``` ## How was this patch tested? Unit tests were added in `python/pyspark/streaming/tests.py` and manually tested. Author: hyukjinkwon <gurwls...@gmail.com> Closes #19498 from HyukjinKwon/SPARK-17756. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b070ded2 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b070ded2 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b070ded2 Branch: refs/heads/master Commit: b070ded2843e88131c90cb9ef1b4f8d533f8361d Parents: 1a644af Author: hyukjinkwon <gurwls...@gmail.com> Authored: Sat Jun 9 01:27:51 2018 +0700 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Sat Jun 9 01:27:51 2018 +0700 ---------------------------------------------------------------------- python/pyspark/streaming/context.py | 2 +- python/pyspark/streaming/tests.py | 6 ++++++ python/pyspark/streaming/util.py | 11 ++++++++++- 3 files changed, 17 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/context.py ---------------------------------------------------------------------- diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py index 17c34f8..dd924ef 100644 --- a/python/pyspark/streaming/context.py +++ b/python/pyspark/streaming/context.py @@ -338,7 +338,7 @@ class StreamingContext(object): jdstreams = [d._jdstream for d in dstreams] # change the final serializer to sc.serializer func = TransformFunction(self._sc, - lambda t, *rdds: transformFunc(rdds).map(lambda x: x), + lambda t, *rdds: transformFunc(rdds), *[d._jrdd_deserializer for d in dstreams]) jfunc = self._jvm.TransformFunction(func) jdstream = self._jssc.transform(jdstreams, jfunc) 
http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py index e4a428a..373784f 100644 --- a/python/pyspark/streaming/tests.py +++ b/python/pyspark/streaming/tests.py @@ -779,6 +779,12 @@ class StreamingContextTests(PySparkStreamingTestCase): self.assertEqual([2, 3, 1], self._take(dstream, 3)) + def test_transform_pairrdd(self): + # This regression test case is for SPARK-17756. + dstream = self.ssc.queueStream( + [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd)) + self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3)) + def test_get_active(self): self.assertEqual(StreamingContext.getActive(), None) http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py index df18447..b4b9f97 100644 --- a/python/pyspark/streaming/util.py +++ b/python/pyspark/streaming/util.py @@ -20,6 +20,8 @@ from datetime import datetime import traceback import sys +from py4j.java_gateway import is_instance_of + from pyspark import SparkContext, RDD @@ -65,7 +67,14 @@ class TransformFunction(object): t = datetime.fromtimestamp(milliseconds / 1000.0) r = self.func(t, *rdds) if r: - return r._jrdd + # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`. + # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return + # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`. + # See SPARK-17756. 
+ if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"): + return r._jrdd + else: + return r.map(lambda x: x)._jrdd except: self.failure = traceback.format_exc() --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org