Repository: spark
Updated Branches:
  refs/heads/branch-1.6 e042780cd -> 4f8c7e18f


[SPARK-11625][SQL] add java test for typed aggregate

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9591 from cloud-fan/agg-test.

(cherry picked from commit fd14936be7beff543dbbcf270f2f9749f7a803c4)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4f8c7e18
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4f8c7e18
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4f8c7e18

Branch: refs/heads/branch-1.6
Commit: 4f8c7e18f3103ee1fcf5a79c1d39cf5a81e78c87
Parents: e042780
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Nov 16 15:32:49 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 16 15:35:21 2015 -0800

----------------------------------------------------------------------
 .../spark/api/java/function/Function.java       |  2 +-
 .../org/apache/spark/sql/GroupedDataset.scala   | 34 ++++++++++--
 .../spark/sql/expressions/Aggregator.scala      |  2 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 56 ++++++++++++++++++++
 .../spark/sql/DatasetAggregatorSuite.scala      |  7 ++-
 5 files changed, 92 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/4f8c7e18/core/src/main/java/org/apache/spark/api/java/function/Function.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/api/java/function/Function.java b/core/src/main/java/org/apache/spark/api/java/function/Function.java
index d00551b..b9d9777 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/java/org/apache/spark/api/java/function/Function.java
@@ -25,5 +25,5 @@ import java.io.Serializable;
  * when mapping RDDs of other types.
  */
 public interface Function<T1, R> extends Serializable {
-  public R call(T1 v1) throws Exception;
+  R call(T1 v1) throws Exception;
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f8c7e18/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index ebcf4c8..467cd42 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -145,9 +145,37 @@ class GroupedDataset[K, T] private[sql](
     reduce(f.call _)
   }
 
-  // To ensure valid overloading.
-  protected def agg(expr: Column, exprs: Column*): DataFrame =
-    groupedData.agg(expr, exprs: _*)
+  /**
+   * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
+   * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
+   *
+   * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
+   *
+   * {{{
+   *   // Selects the age of the oldest employee and the aggregate expense for each department
+   *
+   *   // Scala:
+   *   import org.apache.spark.sql.functions._
+   *   df.groupBy("department").agg(max("age"), sum("expense"))
+   *
+   *   // Java:
+   *   import static org.apache.spark.sql.functions.*;
+   *   df.groupBy("department").agg(max("age"), sum("expense"));
+   * }}}
+   *
+   * We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
+   *
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def agg(expr: Column, exprs: Column*): DataFrame =
+    groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+
+  private def withEncoder(c: Column): Column = c match {
+    case tc: TypedColumn[_, _] =>
+      tc.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes)
+    case _ => c
+  }
 
   /**
    * Internal helper function for building typed aggregations that return tuples.  For simplicity

http://git-wip-us.apache.org/repos/asf/spark/blob/4f8c7e18/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
index 360c9a5..72610e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala
@@ -47,7 +47,7 @@ import org.apache.spark.sql.{Dataset, DataFrame, TypedColumn}
  * @tparam B The type of the intermediate value of the reduction.
  * @tparam C The type of the final result.
  */
-abstract class Aggregator[-A, B, C] {
+abstract class Aggregator[-A, B, C] extends Serializable {
 
   /** A zero value for this aggregation. Should satisfy the property that any b + zero = b */
   def zero: B

http://git-wip-us.apache.org/repos/asf/spark/blob/4f8c7e18/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index eb6fa1e..d9b2250 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -34,6 +34,7 @@ import org.apache.spark.sql.Encoder;
 import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.GroupedDataset;
+import org.apache.spark.sql.expressions.Aggregator;
 import org.apache.spark.sql.test.TestSQLContext;
 
 import static org.apache.spark.sql.functions.*;
@@ -381,4 +382,59 @@ public class JavaDatasetSuite implements Serializable {
       context.createDataset(data3, encoder3);
     Assert.assertEquals(data3, ds3.collectAsList());
   }
+
+  @Test
+  public void testTypedAggregation() {
+    Encoder<Tuple2<String, Integer>> encoder = Encoders.tuple(Encoders.STRING(), Encoders.INT());
+    List<Tuple2<String, Integer>> data =
+      Arrays.asList(tuple2("a", 1), tuple2("a", 2), tuple2("b", 3));
+    Dataset<Tuple2<String, Integer>> ds = context.createDataset(data, encoder);
+
+    GroupedDataset<String, Tuple2<String, Integer>> grouped = ds.groupBy(
+      new MapFunction<Tuple2<String, Integer>, String>() {
+        @Override
+        public String call(Tuple2<String, Integer> value) throws Exception {
+          return value._1();
+        }
+      },
+      Encoders.STRING());
+
+    Dataset<Tuple2<String, Integer>> agged =
+      grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
+    Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
+
+    Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
+      expr("sum(_2)"),
+      count("*"))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
+    Assert.assertEquals(
+      Arrays.asList(
+        new Tuple4<String, Integer, Long, Long>("a", 3, 3L, 2L),
+        new Tuple4<String, Integer, Long, Long>("b", 3, 3L, 1L)),
+      agged2.collectAsList());
+  }
+
+  static class IntSumOf extends Aggregator<Tuple2<String, Integer>, Integer, Integer> {
+
+    @Override
+    public Integer zero() {
+      return 0;
+    }
+
+    @Override
+    public Integer reduce(Integer l, Tuple2<String, Integer> t) {
+      return l + t._2();
+    }
+
+    @Override
+    public Integer merge(Integer b1, Integer b2) {
+      return b1 + b2;
+    }
+
+    @Override
+    public Integer finish(Integer reduction) {
+      return reduction;
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/4f8c7e18/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 46f9f07..9377589 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.expressions.Aggregator
 
 /** An `Aggregator` that adds up any numeric type returned by the given function. */
-class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
   val numeric = implicitly[Numeric[N]]
 
   override def zero: N = numeric.zero
@@ -37,7 +37,7 @@ class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializ
   override def finish(reduction: N): N = reduction
 }
 
-object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with Serializable {
+object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] {
   override def zero: (Long, Long) = (0, 0)
 
   override def reduce(countAndSum: (Long, Long), input: (String, Int)): (Long, Long) = {
@@ -51,8 +51,7 @@ object TypedAverage extends Aggregator[(String, Int), (Long, Long), Double] with
   override def finish(countAndSum: (Long, Long)): Double = countAndSum._2 / countAndSum._1
 }
 
-object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)]
-  with Serializable {
+object ComplexResultAgg extends Aggregator[(String, Int), (Long, Long), (Long, Long)] {
 
   override def zero: (Long, Long) = (0, 0)
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to