Updated Branches:
  refs/heads/branch-0.8 f3cc3a7b8 -> c89b71ac7
Add collectPartition to JavaRDD interface.

Also remove takePartition from PythonRDD and use collectPartition in rdd.py.

Conflicts:
	core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
	python/pyspark/context.py
	python/pyspark/rdd.py

Project: http://git-wip-us.apache.org/repos/asf/incubator-spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-spark/commit/5092baed
Tree: http://git-wip-us.apache.org/repos/asf/incubator-spark/tree/5092baed
Diff: http://git-wip-us.apache.org/repos/asf/incubator-spark/diff/5092baed

Branch: refs/heads/branch-0.8
Commit: 5092baedc20a7372b4238a7468470a9c5c60deeb
Parents: 5c443ad
Author: Shivaram Venkataraman <shiva...@eecs.berkeley.edu>
Authored: Wed Dec 18 11:40:07 2013 -0800
Committer: Shivaram Venkataraman <shiva...@eecs.berkeley.edu>
Committed: Thu Jan 16 19:20:57 2014 -0800

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaRDDLike.scala | 11 +++++++-
 .../org/apache/spark/api/python/PythonRDD.scala |  4 ---
 .../scala/org/apache/spark/JavaAPISuite.java    | 28 ++++++++++++++++++++
 python/pyspark/context.py                       |  3 ---
 python/pyspark/rdd.py                           |  4 +--
 5 files changed, 40 insertions(+), 10 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7a3568c..0e46876 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -25,7 +25,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,6 +247,15 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   /**
+   * Return an array that contains all of the elements in a specific partition of this RDD.
+   */
+  def collectPartition(partitionId: Int): JList[T] = {
+    import scala.collection.JavaConversions._
+    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
+    new java.util.ArrayList(partition.collect().toSeq)
+  }
+
+  /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
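For context on the approach: collectPartition wraps the underlying RDD in a PartitionPruningRDD whose predicate keeps only the requested partition, so the collect() that follows schedules a task for that single partition instead of running a job over the whole RDD. A minimal standalone Scala sketch of the same idea (the local master string, app name, and example values are illustrative, not part of this commit):

    import org.apache.spark.SparkContext
    import org.apache.spark.rdd.PartitionPruningRDD

    object CollectPartitionSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "collect-partition-sketch")
        // Three partitions: [1, 2], [3, 4], [5, 6, 7].
        val rdd = sc.parallelize(1 to 7, 3)
        // Keep only partition 1; collect() then launches a job whose single
        // task reads that partition rather than touching all three.
        val pruned = new PartitionPruningRDD[Int](rdd, _ == 1)
        println(pruned.collect().toSeq)  // should print the contents of partition 1: 3, 4
        sc.stop()
      }
    }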
http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94..f7b38b4 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -283,10 +283,6 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
-    implicit val cm : ClassManifest[T] = rdd.elementClassManifest
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
-  }
 }
 
 private object Pickle {

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/core/src/test/scala/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f..d7c673a 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -883,4 +883,32 @@ public class JavaAPISuite implements Serializable {
         new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
   }
+
+  @Test
+  public void collectPartition() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+      @Override
+      public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+        return new Tuple2<Integer, Integer>(i, i % 2);
+      }
+    });
+
+    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+                                      new Tuple2<Integer, Integer>(2, 0)),
+                        rdd2.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+                                      new Tuple2<Integer, Integer>(4, 0)),
+                        rdd2.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+                                      new Tuple2<Integer, Integer>(6, 0),
+                                      new Tuple2<Integer, Integer>(7, 1)),
+                        rdd2.collectPartition(2));
+  }
 }

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/python/pyspark/context.py
----------------------------------------------------------------------
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc..3d47589 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -43,7 +43,6 @@ class SparkContext(object):
     _gateway = None
     _jvm = None
     _writeIteratorToPickleFile = None
-    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -127,8 +126,6 @@ class SparkContext(object):
         SparkContext._jvm = SparkContext._gateway.jvm
         SparkContext._writeIteratorToPickleFile = \
             SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
-        SparkContext._takePartition = \
-            SparkContext._jvm.PythonRDD.takePartition
 
         if instance:
             if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:

http://git-wip-us.apache.org/repos/asf/incubator-spark/blob/5092baed/python/pyspark/rdd.py
----------------------------------------------------------------------
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 0c599e0..22cddf0 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -570,8 +570,8 @@ class RDD(object):
         mapped = self.mapPartitions(takeUpToNum)
         items = []
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
-            items.extend(self._collect_iterator_through_file(iterator))
+            iterator = mapped._jrdd.collectPartition(partition).iterator()
+            items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
         return items[:num]
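The revised take() above now drives the same Java-side collectPartition from Python, pulling one partition at a time and stopping early once it has gathered enough elements. A rough Scala rendering of that loop against the Java API (the helper name takeViaCollectPartition is hypothetical, not part of this commit):

    import scala.collection.JavaConversions._
    import scala.collection.mutable.ArrayBuffer

    import org.apache.spark.api.java.JavaRDD

    object TakeSketch {
      // Collect partitions one by one via collectPartition and stop as
      // soon as `num` elements have been gathered, mirroring rdd.py's take().
      def takeViaCollectPartition[T](rdd: JavaRDD[T], num: Int): Seq[T] = {
        val items = new ArrayBuffer[T]()
        val numPartitions = rdd.splits.size()
        var partition = 0
        while (items.size < num && partition < numPartitions) {
          // Each call runs a job over just this one partition.
          items ++= rdd.collectPartition(partition)
          partition += 1
        }
        items.take(num)
      }
    }

Compared with the old takePartition path, which went through a dedicated Py4J-exposed helper on PythonRDD, routing through the shared JavaRDD method keeps one implementation of per-partition collection for both the Java and Python APIs.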