Repository: spark Updated Branches: refs/heads/branch-1.6 1585f559d -> b9adfdf9c
[SPARK-11564][SQL][FOLLOW-UP] improve java api for GroupedDataset created `MapGroupFunction`, `FlatMapGroupFunction`, `CoGroupFunction` Author: Wenchen Fan <wenc...@databricks.com> Closes #9564 from cloud-fan/map. (cherry picked from commit fcb57e9c7323e24b8563800deb035f94f616474e) Signed-off-by: Michael Armbrust <mich...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b9adfdf9 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b9adfdf9 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b9adfdf9 Branch: refs/heads/branch-1.6 Commit: b9adfdf9ca18292799e684c8510028c75fbf2808 Parents: 1585f55 Author: Wenchen Fan <wenc...@databricks.com> Authored: Mon Nov 9 15:16:47 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Nov 9 15:17:08 2015 -0800 ---------------------------------------------------------------------- .../api/java/function/CoGroupFunction.java | 29 ++++++++++++++++ .../api/java/function/FlatMapFunction.java | 2 +- .../api/java/function/FlatMapFunction2.java | 2 +- .../api/java/function/FlatMapGroupFunction.java | 28 +++++++++++++++ .../api/java/function/MapGroupFunction.java | 28 +++++++++++++++ .../catalyst/plans/logical/basicOperators.scala | 4 +-- .../org/apache/spark/sql/GroupedDataset.scala | 12 +++---- .../spark/sql/execution/basicOperators.scala | 2 +- .../org/apache/spark/sql/JavaDatasetSuite.java | 36 +++++++++++++------- 9 files changed, 118 insertions(+), 25 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java new file mode 100644 index 0000000..279639a --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/CoGroupFunction.java @@ -0,0 +1,29 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values from 2 + * Datasets. + */ +public interface CoGroupFunction<K, V1, V2, R> extends Serializable { + Iterable<R> call(K key, Iterator<V1> left, Iterator<V2> right) throws Exception; +} http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java index 23f5fdd..ef0d182 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction.java @@ -23,5 +23,5 @@ import java.io.Serializable; * A function that returns zero or more output records from each input record. */ public interface FlatMapFunction<T, R> extends Serializable { - public Iterable<R> call(T t) throws Exception; + Iterable<R> call(T t) throws Exception; } http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java index c48e92f..14a98a3 100644 --- a/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapFunction2.java @@ -23,5 +23,5 @@ import java.io.Serializable; * A function that takes two inputs and returns zero or more output records. */ public interface FlatMapFunction2<T1, T2, R> extends Serializable { - public Iterable<R> call(T1 t1, T2 t2) throws Exception; + Iterable<R> call(T1 t1, T2 t2) throws Exception; } http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java new file mode 100644 index 0000000..18a2d73 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/FlatMapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * A function that returns zero or more output records from each grouping key and its values. + */ +public interface FlatMapGroupFunction<K, V, R> extends Serializable { + Iterable<R> call(K key, Iterator<V> values) throws Exception; +} http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java new file mode 100644 index 0000000..2935f99 --- /dev/null +++ b/core/src/main/java/org/apache/spark/api/java/function/MapGroupFunction.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.api.java.function; + +import java.io.Serializable; +import java.util.Iterator; + +/** + * Base interface for a map function used in GroupedDataset's map function. + */ +public interface MapGroupFunction<K, V, R> extends Serializable { + R call(K key, Iterator<V> values) throws Exception; +} http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index e151ac0..d771088 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -527,7 +527,7 @@ case class MapGroups[K, T, U]( /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], left: LogicalPlan, @@ -551,7 +551,7 @@ object CoGroup { * right children. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index 5c3f626..850315e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -108,9 +108,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(f, groupingAttributes, logicalPlan)) } - def flatMap[U]( - f: JFunction2[K, JIterator[T], JIterator[U]], - encoder: Encoder[U]): Dataset[U] = { + def flatMap[U](f: FlatMapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { flatMap((key, data) => f.call(key, data.asJava).asScala)(encoder) } @@ -131,9 +129,7 @@ class GroupedDataset[K, T] private[sql]( MapGroups(func, groupingAttributes, logicalPlan)) } - def map[U]( - f: JFunction2[K, JIterator[T], U], - encoder: Encoder[U]): Dataset[U] = { + def map[U](f: MapGroupFunction[K, T, U], encoder: Encoder[U]): Dataset[U] = { map((key, data) => f.call(key, data.asJava))(encoder) } @@ -218,7 +214,7 @@ class GroupedDataset[K, T] private[sql]( */ def cogroup[U, R : Encoder]( other: GroupedDataset[K, U])( - f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = { + f: (K, Iterator[T], Iterator[U]) => TraversableOnce[R]): Dataset[R] = { implicit def uEnc: Encoder[U] = other.tEncoder new Dataset[R]( sqlContext, @@ -232,7 +228,7 @@ class GroupedDataset[K, T] private[sql]( def cogroup[U, R]( other: GroupedDataset[K, U], - f: JFunction3[K, JIterator[T], JIterator[U], JIterator[R]], + f: CoGroupFunction[K, T, U, R], encoder: Encoder[R]): Dataset[R] = { cogroup(other)((key, left, right) => f.call(key, left.asJava, right.asJava).asScala)(encoder) } http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index 2593b16..145de0d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -391,7 +391,7 @@ case class MapGroups[K, T, U]( * The result of this function is encoded and flattened before being output. */ case class CoGroup[K, Left, Right, R]( - func: (K, Iterator[Left], Iterator[Right]) => Iterator[R], + func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R], kEncoder: ExpressionEncoder[K], leftEnc: ExpressionEncoder[Left], rightEnc: ExpressionEncoder[Right], http://git-wip-us.apache.org/repos/asf/spark/blob/b9adfdf9/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 0f90de7..312cf33 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -29,7 +29,6 @@ import org.junit.*; import org.apache.spark.Accumulator; import org.apache.spark.SparkContext; import org.apache.spark.api.java.function.*; -import org.apache.spark.api.java.function.Function; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.catalyst.encoders.Encoder; import org.apache.spark.sql.catalyst.encoders.Encoder$; @@ -170,20 +169,33 @@ public class JavaDatasetSuite implements Serializable { } }, e.INT()); - Dataset<String> mapped = grouped.map( - new Function2<Integer, Iterator<String>, String>() { + Dataset<String> mapped = grouped.map(new MapGroupFunction<Integer, String, String>() { + @Override + public String call(Integer key, Iterator<String> values) throws Exception { + StringBuilder sb = new StringBuilder(key.toString()); + while (values.hasNext()) { + sb.append(values.next()); + } + return sb.toString(); + } + }, e.STRING()); + + Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + + Dataset<String> flatMapped = grouped.flatMap( + new FlatMapGroupFunction<Integer, String, String>() { @Override - public String call(Integer key, Iterator<String> data) throws Exception { + public Iterable<String> call(Integer key, Iterator<String> values) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); - while (data.hasNext()) { - sb.append(data.next()); + while (values.hasNext()) { + sb.append(values.next()); } - return sb.toString(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); - Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList()); + Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList()); List<Integer> data2 = Arrays.asList(2, 6, 10); Dataset<Integer> ds2 = context.createDataset(data2, e.INT()); @@ -196,9 +208,9 @@ public class JavaDatasetSuite implements Serializable { Dataset<String> cogrouped = grouped.cogroup( grouped2, - new Function3<Integer, Iterator<String>, Iterator<Integer>, Iterator<String>>() { + new CoGroupFunction<Integer, String, Integer, String>() { @Override - public Iterator<String> call( + public Iterable<String> call( Integer key, Iterator<String> left, Iterator<Integer> right) throws Exception { @@ -210,7 +222,7 @@ public class JavaDatasetSuite implements Serializable { while (right.hasNext()) { sb.append(right.next()); } - return Collections.singletonList(sb.toString()).iterator(); + return Collections.singletonList(sb.toString()); } }, e.STRING()); @@ -225,7 +237,7 @@ public class JavaDatasetSuite implements Serializable { GroupedDataset<Integer, String> grouped = ds.groupBy(length(col("value"))).asKey(e.INT()); Dataset<String> mapped = grouped.map( - new Function2<Integer, Iterator<String>, String>() { + new MapGroupFunction<Integer, String, String>() { @Override public String call(Integer key, Iterator<String> data) throws Exception { StringBuilder sb = new StringBuilder(key.toString()); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org