Repository: spark
Updated Branches:
  refs/heads/master 1a644afba -> b070ded28


[SPARK-17756][PYTHON][STREAMING] Workaround to avoid return type mismatch in PythonTransformFunction

## What changes were proposed in this pull request?

This PR proposes to wrap the transformed RDD within `TransformFunction`. `PythonTransformFunction` appears to require `_jrdd` to be a `JavaRDD`:

https://github.com/apache/spark/blob/39e2bad6a866d27c3ca594d15e574a1da3ee84cc/python/pyspark/streaming/util.py#L67

https://github.com/apache/spark/blob/6ee28423ad1b2e6089b82af64a31d77d3552bb38/streaming/src/main/scala/org/apache/spark/streaming/api/python/PythonDStream.scala#L43

However, some APIs, for example `zip` in PySpark's RDD API, can make it a `JavaPairRDD` instead. This can be checked via `_jrdd` as below:

```python
>>> rdd.zip(rdd)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaPairRDD'
```

So, here, the result is wrapped with an identity `map` to ensure that a `JavaRDD` is returned:

```python
>>> rdd.zip(rdd).map(lambda x: x)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaRDD'
```
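
The same idea can be written as a small standalone helper. This is a minimal sketch mirroring what the patch does inside `TransformFunction`; `ensure_java_rdd` is an illustrative name, not part of the patch, and it assumes a live PySpark session:

```python
from py4j.java_gateway import is_instance_of

def ensure_java_rdd(rdd):
    # Illustrative helper (not part of this patch): return a `_jrdd` that is
    # guaranteed to be a JavaRDD, as PythonTransformFunction expects.
    if is_instance_of(rdd.ctx._gateway, rdd._jrdd, "org.apache.spark.api.java.JavaRDD"):
        return rdd._jrdd
    # e.g. after `zip`, `_jrdd` is a JavaPairRDD; an identity `map` rebuilds a JavaRDD.
    return rdd.map(lambda x: x)._jrdd
```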

Some cases that fail without this fix are elaborated below (a check confirming their common root cause follows after the snippets):

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]) \
    .transform(lambda rdd: rdd.cartesian(rdd)) \
    .pprint()
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.cartesian(rdd))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).union(rdd.zip(rdd)))
ssc.start()
```

```python
from pyspark.streaming import StreamingContext
ssc = StreamingContext(spark.sparkContext, 10)
ssc.queueStream([sc.range(10)]).foreachRDD(lambda rdd: rdd.zip(rdd).coalesce(1))
ssc.start()
```
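
Each of these cases boils down to the same root cause: the transformation leaves a `JavaPairRDD` behind `_jrdd`. This can be confirmed with the same check as before; the transcript below is what I would expect to see in a PySpark shell with a live `sc`:

```python
>>> rdd = sc.range(10)
>>> rdd.cartesian(rdd)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaPairRDD'
>>> rdd.zip(rdd).coalesce(1)._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaPairRDD'
>>> rdd.zip(rdd).union(rdd.zip(rdd))._jrdd.getClass().toString()
u'class org.apache.spark.api.java.JavaPairRDD'
```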

## How was this patch tested?

Unit tests were added in `python/pyspark/streaming/tests.py`, and the change was also tested manually.

Author: hyukjinkwon <gurwls...@gmail.com>

Closes #19498 from HyukjinKwon/SPARK-17756.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b070ded2
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b070ded2
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b070ded2

Branch: refs/heads/master
Commit: b070ded2843e88131c90cb9ef1b4f8d533f8361d
Parents: 1a644af
Author: hyukjinkwon <gurwls...@gmail.com>
Authored: Sat Jun 9 01:27:51 2018 +0700
Committer: hyukjinkwon <gurwls...@apache.org>
Committed: Sat Jun 9 01:27:51 2018 +0700

----------------------------------------------------------------------
 python/pyspark/streaming/context.py |  2 +-
 python/pyspark/streaming/tests.py   |  6 ++++++
 python/pyspark/streaming/util.py    | 11 ++++++++++-
 3 files changed, 17 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 17c34f8..dd924ef 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -338,7 +338,7 @@ class StreamingContext(object):
         jdstreams = [d._jdstream for d in dstreams]
         # change the final serializer to sc.serializer
         func = TransformFunction(self._sc,
-                                 lambda t, *rdds: transformFunc(rdds).map(lambda x: x),
+                                 lambda t, *rdds: transformFunc(rdds),
                                  *[d._jrdd_deserializer for d in dstreams])
         jfunc = self._jvm.TransformFunction(func)
         jdstream = self._jssc.transform(jdstreams, jfunc)

http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/tests.py b/python/pyspark/streaming/tests.py
index e4a428a..373784f 100644
--- a/python/pyspark/streaming/tests.py
+++ b/python/pyspark/streaming/tests.py
@@ -779,6 +779,12 @@ class StreamingContextTests(PySparkStreamingTestCase):
 
         self.assertEqual([2, 3, 1], self._take(dstream, 3))
 
+    def test_transform_pairrdd(self):
+        # This regression test case is for SPARK-17756.
+        dstream = self.ssc.queueStream(
+            [[1], [2], [3]]).transform(lambda rdd: rdd.cartesian(rdd))
+        self.assertEqual([(1, 1), (2, 2), (3, 3)], self._take(dstream, 3))
+
     def test_get_active(self):
         self.assertEqual(StreamingContext.getActive(), None)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b070ded2/python/pyspark/streaming/util.py
----------------------------------------------------------------------
diff --git a/python/pyspark/streaming/util.py b/python/pyspark/streaming/util.py
index df18447..b4b9f97 100644
--- a/python/pyspark/streaming/util.py
+++ b/python/pyspark/streaming/util.py
@@ -20,6 +20,8 @@ from datetime import datetime
 import traceback
 import sys
 
+from py4j.java_gateway import is_instance_of
+
 from pyspark import SparkContext, RDD
 
 
@@ -65,7 +67,14 @@ class TransformFunction(object):
             t = datetime.fromtimestamp(milliseconds / 1000.0)
             r = self.func(t, *rdds)
             if r:
-                return r._jrdd
+                # Here, we work around to ensure `_jrdd` is `JavaRDD` by wrapping it by `map`.
+                # org.apache.spark.streaming.api.python.PythonTransformFunction requires to return
+                # `JavaRDD`; however, this could be `JavaPairRDD` by some APIs, for example, `zip`.
+                # See SPARK-17756.
+                if is_instance_of(self.ctx._gateway, r._jrdd, "org.apache.spark.api.java.JavaRDD"):
+                    return r._jrdd
+                else:
+                    return r.map(lambda x: x)._jrdd
         except:
             self.failure = traceback.format_exc()
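
For reference, `is_instance_of` from py4j checks a JVM object against a fully qualified class name, so it distinguishes `JavaRDD` from `JavaPairRDD` (which does not extend `JavaRDD`). A minimal sketch, assuming a live `sc`:

```python
from py4j.java_gateway import is_instance_of

rdd = sc.range(10)

# True: `_jrdd` here is backed by a JavaRDD.
print(is_instance_of(sc._gateway, rdd._jrdd, "org.apache.spark.api.java.JavaRDD"))
# False: after `zip`, `_jrdd` is backed by a JavaPairRDD.
print(is_instance_of(sc._gateway, rdd.zip(rdd)._jrdd, "org.apache.spark.api.java.JavaRDD"))
```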
 

