Repository: spark
Updated Branches:
  refs/heads/branch-1.6 1585f559d -> b9adfdf9c


[SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset

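Created `MapGroupFunction`, `FlatMapGroupFunction`, and `CoGroupFunction`, and used them in place of `Function2`/`Function3` in the Java-facing `GroupedDataset.map`, `flatMap`, and `cogroup` overloads.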

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9564 from cloud-fan/map.

(cherry picked from commit fcb57e9c7323e24b8563800deb035f94f616474e)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9adfdf9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9adfdf9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9adfdf9

Branch: refs/heads/branch-1.6
Commit: b9adfdf9ca18292799e684c8510028c75fbf2808
Parents: 1585f55
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Nov 9 15:16:47 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 9 15:17:08 2015 -0800

----------------------------------------------------------------------
 .../api/java/function/CoGroupFunction.java      | 29 ++++++++++++++++
 .../api/java/function/FlatMapFunction.java      |  2 +-
 .../api/java/function/FlatMapFunction2.java     |  2 +-
 .../api/java/function/FlatMapGroupFunction.java | 28 +++++++++++++++
 .../api/java/function/MapGroupFunction.java     | 28 +++++++++++++++
 .../catalyst/plans/logical/basicOperators.scala |  4 +--
 .../org/apache/spark/sql/GroupedDataset.scala   | 12 +++----
 .../spark/sql/execution/basicOperators.scala    |  2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 36 +++++++++++++-------
 9 files changed, 118 insertions(+), 25 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
new file mode 100644
index 0000000..279639a
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A function that returns zero or more output records from each grouping key and its values from 2
+ * Datasets.
+ */
+public interface CoGroupFunction<K, V1, V2, R> extends Serializable {
+  Iterable<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
index 23f5fdd..ef0d182 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java
@@ -23,5 +23,5 @@ import java.io.Serializable;
  * A function that returns zero or more output records from each input record.
  */
 public interface FlatMapFunction<T, R> extends Serializable {
-  public Iterable<R> call(T t) throws Exception;
+  Iterable<R> call(T t) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
index c48e92f..14a98a3 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java
@@ -23,5 +23,5 @@ import java.io.Serializable;
  * A function that takes two inputs and returns zero or more output records.
  */
 public interface FlatMapFunction2<T1, T2, R> extends Serializable {
-  public Iterable<R> call(T1 t1, T2 t2) throws Exception;
+  Iterable<R> call(T1 t1, T2 t2) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java
new file mode 100644
index 0000000..18a2d73
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * A function that returns zero or more output records from each grouping key and its values.
+ */
+public interface FlatMapGroupFunction<K, V, R> extends Serializable {
+  Iterable<R> call(K key, Iterator<V> values) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
new file mode 100644
index 0000000..2935f99
--- /dev/null
+++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java.function;
+
+import java.io.Serializable;
+import java.util.Iterator;
+
+/**
+ * Base interface for a map function used in GroupedDataset's map function.
+ */
+public interface MapGroupFunction<K, V, R> extends Serializable {
+  R call(K key, Iterator<V> values) throws Exception;
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e151ac0..d771088 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -527,7 +527,7 @@ case class MapGroups[K, T, U](
 /** Factory for constructing new `CoGroup` nodes. */
 object CoGroup {
   def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
-      func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+      func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
       leftGroup: Seq[Attribute],
       rightGroup: Seq[Attribute],
       left: LogicalPlan,
@@ -551,7 +551,7 @@ object CoGroup {
  * right children.
  */
 case class CoGroup[K, Left, Right, R](
-    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+    func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
     kEncoder: ExpressionEncoder[K],
     leftEnc: ExpressionEncoder[Left],
     rightEnc: ExpressionEncoder[Right],

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 5c3f626..850315e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
-  def flatMap[U](
-      f: JFunction2[K, JIterator[T], JIterator[U]],
-      encoder: Encoder[U]): Dataset[U] = {
+  def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
     flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder)
   }
 
@@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(func, groupingAttributes, logicalPlan))
   }
 
-  def map[U](
-      f: JFunction2[K, JIterator[T], U],
-      encoder: Encoder[U]): Dataset[U] = {
+  def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = {
     map((key, data) => f.call(key, data.asJava))(encoder)
   }
 
@@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql](
    */
   def cogroup[U, R : Encoder](
       other: GroupedDataset[K, U])(
-      f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
+      f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
     implicit def uEnc: Encoder[U] = other.tEncoder
     new Dataset[R](
       sqlContext,
@@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql](
 
   def cogroup[U, R](
       other: GroupedDataset[K, U],
-      f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]],
+      f: CoGroupFunction[K, T, U, R],
       encoder: Encoder[R]): Dataset[R] = {
     cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder)
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 2593b16..145de0d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -391,7 +391,7 @@ case class MapGroups[K, T, U](
  * The result of this function is encoded and flattened before being output.
  */
 case class CoGroup[K, Left, Right, R](
-    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+    func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
     kEncoder: ExpressionEncoder[K],
     leftEnc: ExpressionEncoder[Left],
     rightEnc: ExpressionEncoder[Right],

http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 0f90de7..312cf33 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -29,7 +29,6 @@ import org.junit.*;
 import org.apache.spark.Accumulator;
 import org.apache.spark.SparkContext;
 import org.apache.spark.api.java.function.*;
-import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.sql.catalyst.encoders.Encoder;
 import org.apache.spark.sql.catalyst.encoders.Encoder$;
@@ -170,20 +169,33 @@ public class JavaDatasetSuite implements Serializable {
       }
     }, e.INT());
 
-    Dataset<String> mapped = grouped.map(
-      new Function2<Integer, Iterator<String>, String>() {
+    Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() {
+      @Override
+      public String call(Integer key, Iterator<String> values) throws Exception {
+        StringBuilder sb = new StringBuilder(key.toString());
+        while (values.hasNext()) {
+          sb.append(values.next());
+        }
+        return sb.toString();
+      }
+    }, e.STRING());
+
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+
+    Dataset<String> flatMapped = grouped.flatMap(
+      new FlatMapGroupFunction<Integer, String, String>() {
         @Override
-        public String call(Integer key, Iterator<String> data) throws Exception {
+        public Iterable<String> call(Integer key, Iterator<String> values) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());
-          while (data.hasNext()) {
-            sb.append(data.next());
+          while (values.hasNext()) {
+            sb.append(values.next());
           }
-          return sb.toString();
+          return Collections.singletonList(sb.toString());
         }
       },
       e.STRING());
 
-    Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+    Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
 
     List<Integer> data2 = Arrays.asList(2, 6, 10);
     Dataset<Integer> ds2 = context.createDataset(data2, e.INT());
@@ -196,9 +208,9 @@ public class JavaDatasetSuite implements Serializable {
 
     Dataset<String> cogrouped = grouped.cogroup(
       grouped2,
-      new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() {
+      new CoGroupFunction<Integer, String, Integer, String>() {
         @Override
-        public Iterator<String> call(
+        public Iterable<String> call(
             Integer key,
             Iterator<String> left,
             Iterator<Integer> right) throws Exception {
@@ -210,7 +222,7 @@ public class JavaDatasetSuite implements Serializable {
           while (right.hasNext()) {
             sb.append(right.next());
           }
-          return Collections.singletonList(sb.toString()).iterator();
+          return Collections.singletonList(sb.toString());
         }
       },
       e.STRING());
@@ -225,7 +237,7 @@ public class JavaDatasetSuite implements Serializable {
     GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT());
 
     Dataset<String> mapped = grouped.map(
-      new Function2<Integer, Iterator<String>, String>() {
+      new MapGroupFunction<Integer, String, String>() {
         @Override
         public String call(Integer key, Iterator<String> data) throws Exception {
           StringBuilder sb = new StringBuilder(key.toString());


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to