This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new be18718  [SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
be18718 is described below

commit be18718380fc501fe2a780debd089a1df91c1699
Author: schintap <schin...@verizonmedia.com>
AuthorDate: Mon May 25 10:29:08 2020 +0900

    [SPARK-31788][CORE][PYTHON] Fix UnionRDD of PairRDDs
    
    ### What changes were proposed in this pull request?
    `sc.union` fails on pair RDDs (for example, the result of `zip`) because
    their underlying `_jrdd` is a `JavaPairRDD`, which cannot be stored in a
    `JavaRDD` array on the JVM side. The fix checks the instance type before
    assigning and, if needed, wraps the RDD in an identity `map` so that its
    `_jrdd` is a `JavaRDD`, as sketched below.
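
    A minimal, illustrative sketch of the type check in a PySpark shell
    (assumes the default `sc`; `is_instance_of` is py4j's runtime type check,
    the same helper the fix uses in context.py):

    >>> from py4j.java_gateway import is_instance_of
    >>> pair_rdd = sc.parallelize([1, 2]).zip(sc.parallelize([3, 4]))
    >>> java_rdd_cls = sc._jvm.org.apache.spark.api.java.JavaRDD
    >>> is_instance_of(sc._gateway, pair_rdd._jrdd, java_rdd_cls)
    False
    >>> is_instance_of(sc._gateway, pair_rdd.map(lambda x: x)._jrdd, java_rdd_cls)
    True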
    
    ### Why are the changes needed?
    Without this change, users hit a Py4JError whenever `union` is called on
    RDDs whose underlying Java RDD is any type other than `JavaRDD` (for
    example `JavaPairRDD`).
    
    ### Does this PR introduce _any_ user-facing change?
    Yes (see the before/after output below).
    
    Before:
    SparkSession available as 'spark'.
    >>> rdd1 = sc.parallelize([1,2,3,4,5])
    >>> rdd2 = sc.parallelize([6,7,8,9,10])
    >>> pairRDD1 = rdd1.zip(rdd2)
    >>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/gs/spark/latest/python/pyspark/context.py", line 870, in union
        jrdds[i] = rdds[i]._jrdd
      File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 238, in __setitem__
      File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/java_collections.py", line 221, in __set_item
      File "/home/gs/spark/latest/python/lib/py4j-0.10.9-src.zip/py4j/protocol.py", line 332, in get_return_value
    py4j.protocol.Py4JError: An error occurred while calling None.None. Trace:
    py4j.Py4JException: Cannot convert org.apache.spark.api.java.JavaPairRDD to org.apache.spark.api.java.JavaRDD
        at py4j.commands.ArrayCommand.convertArgument(ArrayCommand.java:166)
        at py4j.commands.ArrayCommand.setArray(ArrayCommand.java:144)
        at py4j.commands.ArrayCommand.execute(ArrayCommand. [...]
    
    After:
    >>> rdd1 = sc.parallelize([1,2,3,4,5])
    >>> rdd2 = sc.parallelize([6,7,8,9,10])
    >>> pairRDD1 = rdd1.zip(rdd2)
    >>> unionRDD1 = sc.union([pairRDD1, pairRDD1])
    >>> unionRDD1.collect()
    [(1, 6), (2, 7), (3, 8), (4, 9), (5, 10), (1, 6), (2, 7), (3, 8), (4, 9), (5, 10)]
    
    ### How was this patch tested?
    Tested manually with the reproduction above; a regression test is also
    added in python/pyspark/tests/test_rdd.py.
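
    A possible way to run the new regression test locally (assuming a built
    Spark checkout; the module, class and method names match the test added
    in the diff below):

        python/run-tests --testnames 'pyspark.tests.test_rdd RDDTests.test_union_pair_rdd'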
    
    Closes #28603 from redsanket/SPARK-31788.
    
    Authored-by: schintap <schin...@verizonmedia.com>
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
    (cherry picked from commit a61911c50c391e61038cf01611629d2186d17a76)
    Signed-off-by: HyukjinKwon <gurwls...@apache.org>
---
 python/pyspark/context.py        | 12 ++++++++++--
 python/pyspark/tests/test_rdd.py |  9 +++++++++
 2 files changed, 19 insertions(+), 2 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index d5f1506..3199aa7 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -25,6 +25,7 @@ from threading import RLock
 from tempfile import NamedTemporaryFile
 
 from py4j.protocol import Py4JError
+from py4j.java_gateway import is_instance_of
 
 from pyspark import accumulators
 from pyspark.accumulators import Accumulator
@@ -864,10 +865,17 @@ class SparkContext(object):
         first_jrdd_deserializer = rdds[0]._jrdd_deserializer
         if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
             rdds = [x._reserialize() for x in rdds]
+        gw = SparkContext._gateway
         cls = SparkContext._jvm.org.apache.spark.api.java.JavaRDD
-        jrdds = SparkContext._gateway.new_array(cls, len(rdds))
+        is_jrdd = is_instance_of(gw, rdds[0]._jrdd, cls)
+        jrdds = gw.new_array(cls, len(rdds))
         for i in range(0, len(rdds)):
-            jrdds[i] = rdds[i]._jrdd
+            if is_jrdd:
+                jrdds[i] = rdds[i]._jrdd
+            else:
+                # zip could return JavaPairRDD hence we ensure `_jrdd`
+                # to be `JavaRDD` by wrapping it in a `map`
+                jrdds[i] = rdds[i].map(lambda x: x)._jrdd
         return RDD(self._jsc.union(jrdds), self, rdds[0]._jrdd_deserializer)
 
     def broadcast(self, value):
diff --git a/python/pyspark/tests/test_rdd.py b/python/pyspark/tests/test_rdd.py
index e2d910c..0f1ee5b 100644
--- a/python/pyspark/tests/test_rdd.py
+++ b/python/pyspark/tests/test_rdd.py
@@ -166,6 +166,15 @@ class RDDTests(ReusedPySparkTestCase):
             set([(x, (x, x)) for x in 'abc'])
         )
 
+    def test_union_pair_rdd(self):
+        # Regression test for SPARK-31788
+        rdd = self.sc.parallelize([1, 2])
+        pair_rdd = rdd.zip(rdd)
+        self.assertEqual(
+            self.sc.union([pair_rdd, pair_rdd]).collect(),
+            [(1, 1), (2, 2), (1, 1), (2, 2)]
+        )
+
     def test_deleting_input_files(self):
         # Regression test for SPARK-1025
         tempFile = tempfile.NamedTemporaryFile(delete=False)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
