Updated Branches:
  refs/heads/branch-0.8 f3cc3a7b8 -> c89b71ac7

Add collectPartition to JavaRDD interface.
Also remove takePartition from PythonRDD and use collectPartition in rdd.py.
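
For context, a minimal Scala sketch of the partition-pruning approach the new
method relies on (it assumes an existing SparkContext `sc`, such as the one the
Spark shell provides; the sample RDD and values are illustrative, not part of
this commit):

    import org.apache.spark.rdd.PartitionPruningRDD

    // Illustrative RDD with three partitions; assumes an existing SparkContext `sc`.
    val rdd = sc.parallelize(1 to 7, 3)

    // Keep only partition 1; collect() then runs a job over just that partition
    // instead of the whole RDD, which is the same approach the new
    // collectPartition method uses for the Java API.
    val onePartition = new PartitionPruningRDD[Int](rdd, _ == 1)
    println(onePartition.collect().toSeq)   // WrappedArray(3, 4) with this split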

Conflicts:
        core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
        python/pyspark/context.py
        python/pyspark/rdd.py


Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/5092baed
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/5092baed
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/5092baed

Branch: refs/heads/branch-0.8
Commit: 5092baedc20a7372b4238a7468470a9c5c60deeb
Parents: 5c443ad
Author: Shivaram Venkataraman <shiva...@eecs.berkeley.edu>
Authored: Wed Dec 18 11:40:07 2013 -0800
Committer: Shivaram Venkataraman <shiva...@eecs.berkeley.edu>
Committed: Thu Jan 16 19:20:57 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaRDDLike.scala | 11 +++++++-
 .../org/apache/spark/api/python/PythonRDD.scala |  4 ---
 .../scala/org/apache/spark/JavaAPISuite.java    | 28 ++++++++++++++++++++
 python/pyspark/context.py                       |  3 ---
 python/pyspark/rdd.py                           |  4 +--
 5 files changed, 40 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7a3568c..0e46876 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -25,7 +25,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,6 +247,15 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   /**
+   * Return an array that contains all of the elements in a specific partition of this RDD.
+   */
+  def collectPartition(partitionId: Int): JList[T] = {
+    import scala.collection.JavaConversions._
+    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
+    new java.util.ArrayList(partition.collect().toSeq)
+  }
+
+  /**
   * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
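
Note: PartitionPruningRDD restricts the parent RDD to the partitions whose index
passes the supplied filter, so the collect() above launches a task for the single
requested partition rather than scanning the whole RDD.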

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94..f7b38b4 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -283,10 +283,6 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
-    implicit val cm : ClassManifest[T] = rdd.elementClassManifest
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
-  }
 }
 
 private object Pickle {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/test/scala/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f..d7c673a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -883,4 +883,32 @@ public class JavaAPISuite implements Serializable {
         new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
 
   }
+
+  @Test
+  public void collectPartition() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+      @Override
+      public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+        return new Tuple2<Integer, Integer>(i, i % 2);
+      }
+    });
+
+    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+                                      new Tuple2<Integer, Integer>(2, 0)),
+                        rdd2.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+                                      new Tuple2<Integer, Integer>(4, 0)),
+                        rdd2.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+                                      new Tuple2<Integer, Integer>(6, 0),
+                                      new Tuple2<Integer, Integer>(7, 1)),
+                        rdd2.collectPartition(2));
+  }
+
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc..3d47589 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -43,7 +43,6 @@ class SparkContext(object):
     _gateway = None
     _jvm = None
     _writeIteratorToPickleFile = None
-    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -127,8 +126,6 @@ class SparkContext(object):
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeIteratorToPickleFile = \
                     SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
-                SparkContext._takePartition = \
-                    SparkContext._jvm.PythonRDD.takePartition
 
             if instance:
                 if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 0c599e0..22cddf0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -570,8 +570,8 @@ class RDD(object):
         mapped = self.mapPartitions(takeUpToNum)
         items = []
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            iterator = mapped._jrdd.collectPartition(partition).iterator()
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
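
Note: collectPartition returns a java.util.List over the Py4J gateway, so the
Python side calls iterator() on it and feeds the result to
_collect_iterator_through_file, just as the iterator from the old _takePartition
helper was consumed; take(num) still stops requesting further partitions once
enough elements have been gathered.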
