Repository: spark
Updated Branches:
  refs/heads/master 797f8a000 -> 3c4e486b9


[SPARK-5843] [API] Allowing map-side combine to be specified in Java.

Specifically, when calling JavaPairRDD.combineByKey(), there is a new
six-parameter method that exposes the map-side-combine boolean as the
fifth parameter and the serializer as the sixth parameter.
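
For illustration, here is a minimal sketch of calling the new overload from
Java. It is modeled on the test added in this patch; the class name, sample
data, local master, and HashPartitioner(2) are illustrative choices, not part
of the change itself.

    import java.util.Arrays;

    import scala.Tuple2;

    import org.apache.spark.HashPartitioner;
    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaPairRDD;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.api.java.function.Function;
    import org.apache.spark.api.java.function.Function2;
    import org.apache.spark.serializer.KryoSerializer;

    public class CombineByKeyExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "combineByKey-example");
        JavaPairRDD<Integer, Integer> pairs = sc.parallelizePairs(Arrays.asList(
            new Tuple2<Integer, Integer>(0, 3),
            new Tuple2<Integer, Integer>(1, 1),
            new Tuple2<Integer, Integer>(0, 6)));

        // Starts a per-key combiner from a single value.
        Function<Integer, Integer> createCombiner = new Function<Integer, Integer>() {
          @Override
          public Integer call(Integer v) {
            return v;
          }
        };
        // Folds a value into a combiner / merges two combiners (summing in both cases).
        Function2<Integer, Integer, Integer> mergeSum = new Function2<Integer, Integer, Integer>() {
          @Override
          public Integer call(Integer a, Integer b) {
            return a + b;
          }
        };

        // The new six-parameter overload: the fifth argument toggles map-side
        // combine, the sixth supplies the serializer used for the shuffle.
        JavaPairRDD<Integer, Integer> sums = pairs.combineByKey(
            createCombiner,
            mergeSum,                              // mergeValue
            mergeSum,                              // mergeCombiners
            new HashPartitioner(2),
            false,                                 // mapSideCombine
            new KryoSerializer(new SparkConf()));

        System.out.println(sums.collectAsMap());   // prints {0=9, 1=1} (order may vary)
        sc.stop();
      }
    }

Passing true for mapSideCombine (as the shorter overloads now do internally)
preserves the existing default of combining values on the map side before the
shuffle.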

Author: mcheah <mch...@palantir.com>

Closes #4634 from mccheah/pair-rdd-map-side-combine and squashes the following commits:

5c58319 [mcheah] Fixing compiler errors.
3ce7deb [mcheah] Addressing style and documentation comments.
7455c7a [mcheah] Allowing Java combineByKey to specify Serializer as well.
6ddd729 [mcheah] [SPARK-5843] Allowing map-side combine to be specified in Java.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3c4e486b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3c4e486b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3c4e486b

Branch: refs/heads/master
Commit: 3c4e486b9c8d3f256e801db7c176ab650c976135
Parents: 797f8a0
Author: mcheah <mch...@palantir.com>
Authored: Thu Mar 19 08:51:49 2015 -0400
Committer: Sean Owen <so...@cloudera.com>
Committed: Thu Mar 19 08:51:49 2015 -0400

----------------------------------------------------------------------
 .../org/apache/spark/api/java/JavaPairRDD.scala | 46 +++++++++++++----
 .../java/org/apache/spark/JavaAPISuite.java     | 53 ++++++++++++++++++--
 2 files changed, 87 insertions(+), 12 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3c4e486b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 4eadc9a..a023712 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -39,6 +39,7 @@ import org.apache.spark.api.java.function.{Function => JFunction, Function2 => J
 import org.apache.spark.partial.{BoundedDouble, PartialResult}
 import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
 import org.apache.spark.rdd.RDD.rddToPairRDDFunctions
+import org.apache.spark.serializer.Serializer
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils
 
@@ -227,24 +228,51 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
    * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
    * - `mergeCombiners`, to combine two C's into a single one.
    *
-   * In addition, users can control the partitioning of the output RDD, and whether to perform
-   * map-side aggregation (if a mapper can produce multiple items with the same key).
+   * In addition, users can control the partitioning of the output RDD, the serializer that is used
+   * for the shuffle, and whether to perform map-side aggregation (if a mapper can produce multiple
+   * items with the same key).
    */
   def combineByKey[C](createCombiner: JFunction[V, C],
-    mergeValue: JFunction2[C, V, C],
-    mergeCombiners: JFunction2[C, C, C],
-    partitioner: Partitioner): JavaPairRDD[K, C] = {
-    implicit val ctag: ClassTag[C] = fakeClassTag
+      mergeValue: JFunction2[C, V, C],
+      mergeCombiners: JFunction2[C, C, C],
+      partitioner: Partitioner,
+      mapSideCombine: Boolean,
+      serializer: Serializer): JavaPairRDD[K, C] = {
+    implicit val ctag: ClassTag[C] = fakeClassTag
     fromRDD(rdd.combineByKey(
       createCombiner,
       mergeValue,
       mergeCombiners,
-      partitioner
+      partitioner,
+      mapSideCombine,
+      serializer
     ))
   }
 
   /**
-   * Simplified version of combineByKey that hash-partitions the output RDD.
+   * Generic function to combine the elements for each key using a custom set of aggregation
+   * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+   * "combined type" C. Note that V and C can be different -- for example, one might group an
+   * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+   * functions:
+   *
+   * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
+   * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
+   * - `mergeCombiners`, to combine two C's into a single one.
+   *
+   * In addition, users can control the partitioning of the output RDD. This method automatically
+   * uses map-side aggregation when shuffling the RDD.
+   */
+  def combineByKey[C](createCombiner: JFunction[V, C],
+      mergeValue: JFunction2[C, V, C],
+      mergeCombiners: JFunction2[C, C, C],
+      partitioner: Partitioner): JavaPairRDD[K, C] = {
+    combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, true, null)
+  }
+
+  /**
+   * Simplified version of combineByKey that hash-partitions the output RDD and uses map-side
+   * aggregation.
    */
   def combineByKey[C](createCombiner: JFunction[V, C],
       mergeValue: JFunction2[C, V, C],
@@ -488,7 +516,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
 
   /**
    * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
-   * partitioner/parallelism level.
+   * partitioner/parallelism level and using map-side aggregation.
    */
   def combineByKey[C](createCombiner: JFunction[V, C],
     mergeValue: JFunction2[C, V, C],

http://git-wip-us.apache.org/repos/asf/spark/blob/3c4e486b/core/src/test/java/org/apache/spark/JavaAPISuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 8ec5436..d4b5bb5 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -24,11 +24,12 @@ import java.net.URI;
 import java.util.*;
 import java.util.concurrent.*;
 
-import org.apache.spark.input.PortableDataStream;
+import scala.collection.JavaConversions;
 import scala.Tuple2;
 import scala.Tuple3;
 import scala.Tuple4;
 
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Iterators;
 import com.google.common.collect.Lists;
@@ -51,8 +52,11 @@ import org.junit.Test;
 import org.apache.spark.api.java.*;
 import org.apache.spark.api.java.function.*;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.input.PortableDataStream;
 import org.apache.spark.partial.BoundedDouble;
 import org.apache.spark.partial.PartialResult;
+import org.apache.spark.rdd.RDD;
+import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.storage.StorageLevel;
 import org.apache.spark.util.StatCounter;
 
@@ -726,8 +730,8 @@ public class JavaAPISuite implements Serializable {
     Tuple2<double[], long[]> results = rdd.histogram(2);
     double[] expected_buckets = {1.0, 2.5, 4.0};
     long[] expected_counts = {2, 2};
-    Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
-    Assert.assertArrayEquals(expected_counts, results._2);
+    Assert.assertArrayEquals(expected_buckets, results._1(), 0.1);
+    Assert.assertArrayEquals(expected_counts, results._2());
     // Test with provided buckets
     long[] histogram = rdd.histogram(expected_buckets);
     Assert.assertArrayEquals(expected_counts, histogram);
@@ -1424,6 +1428,49 @@ public class JavaAPISuite implements Serializable {
     Assert.assertEquals(Arrays.asList(1, 2, 3, 4, 5), recovered.collect());
   }
 
+  @Test
+  public void combineByKey() {
+    JavaRDD<Integer> originalRDD = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6));
+    Function<Integer, Integer> keyFunction = new Function<Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1) throws Exception {
+        return v1 % 3;
+      }
+    };
+    Function<Integer, Integer> createCombinerFunction = new Function<Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1) throws Exception {
+        return v1;
+      }
+    };
+
+    Function2<Integer, Integer, Integer> mergeValueFunction = new Function2<Integer, Integer, Integer>() {
+      @Override
+      public Integer call(Integer v1, Integer v2) throws Exception {
+        return v1 + v2;
+      }
+    };
+
+    JavaPairRDD<Integer, Integer> combinedRDD = originalRDD.keyBy(keyFunction)
+        .combineByKey(createCombinerFunction, mergeValueFunction, mergeValueFunction);
+    Map<Integer, Integer> results = combinedRDD.collectAsMap();
+    ImmutableMap<Integer, Integer> expected = ImmutableMap.of(0, 9, 1, 5, 2, 7);
+    Assert.assertEquals(expected, results);
+
+    Partitioner defaultPartitioner = Partitioner.defaultPartitioner(
+        combinedRDD.rdd(), JavaConversions.asScalaBuffer(Lists.<RDD<?>>newArrayList()));
+    combinedRDD = originalRDD.keyBy(keyFunction)
+        .combineByKey(
+             createCombinerFunction,
+             mergeValueFunction,
+             mergeValueFunction,
+             defaultPartitioner,
+             false,
+             new KryoSerializer(new SparkConf()));
+    results = combinedRDD.collectAsMap();
+    Assert.assertEquals(expected, results);
+  }
+
   @SuppressWarnings("unchecked")
   @Test
   public void mapOnPairRDD() {

