Repository: spark
Updated Branches:
  refs/heads/branch-1.6 7437a7f5b -> 7e06d51d5


[SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL

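In this PR I delete the public `agg(expr: Column, exprs: Column*)` method of
`GroupedDataset`, which breaks type inference for aggregators (only in the
REPL). A private varargs `agg` is kept in its place to prevent us from adding
overloads that would be ambiguous.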

The error when this method is present is:
```
<console>:38: error: missing parameter type for expanded function ((x$2) => 
x$2._2)
              ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
```

Author: Michael Armbrust <mich...@databricks.com>

Closes #9870 from marmbrus/dataset-repl-agg.

(cherry picked from commit 968acf3bd9a502fcad15df3e53e359695ae702cc)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7e06d51d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7e06d51d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7e06d51d

Branch: refs/heads/branch-1.6
Commit: 7e06d51d5637d4f8e042a1a230ee48591d08236f
Parents: 7437a7f
Author: Michael Armbrust <mich...@databricks.com>
Authored: Fri Nov 20 15:36:30 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Fri Nov 20 15:36:39 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/repl/ReplSuite.scala | 24 +++++++++++++++++
 .../org/apache/spark/sql/GroupedDataset.scala   | 27 +++-----------------
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  8 +++---
 3 files changed, 30 insertions(+), 29 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7e06d51d/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git 
a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala 
b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 081aa03..cbcccb1 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
+  test("Datasets agg type-inference") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |/** An `Aggregator` that adds up any numeric type returned by the 
given function. */
+        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] 
with Serializable {
+        |  val numeric = implicitly[Numeric[N]]
+        |  override def zero: N = numeric.zero
+        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+        |  override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+        |  override def finish(reduction: N): N = reduction
+        |}
+        |
+        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new 
SumOf(f).toColumn
+        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+        |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """

http://git-wip-us.apache.org/repos/asf/spark/blob/7e06d51d/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 6de3dd6..263f049 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
     reduce(f.call _)
   }
 
-  /**
-   * Compute aggregates by specifying a series of aggregate columns, and 
return a [[DataFrame]].
-   * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to 
[[Dataset]] again.
-   *
-   * The available aggregate methods are defined in 
[[org.apache.spark.sql.functions]].
-   *
-   * {{{
-   *   // Selects the age of the oldest employee and the aggregate expense for 
each department
-   *
-   *   // Scala:
-   *   import org.apache.spark.sql.functions._
-   *   df.groupBy("department").agg(max("age"), sum("expense"))
-   *
-   *   // Java:
-   *   import static org.apache.spark.sql.functions.*;
-   *   df.groupBy("department").agg(max("age"), sum("expense"));
-   * }}}
-   *
-   * We can also use `Aggregator.toColumn` to pass in typed aggregate 
functions.
-   *
-   * @since 1.6.0
-   */
+  // This is here to prevent us from adding overloads that would be ambiguous.
   @scala.annotation.varargs
-  def agg(expr: Column, exprs: Column*): DataFrame =
-    groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+  private def agg(exprs: Column*): DataFrame =
+    groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
 
   private def withEncoder(c: Column): Column = c match {
     case tc: TypedColumn[_, _] =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7e06d51d/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index ce40dd8..f7249b8 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -404,11 +404,9 @@ public class JavaDatasetSuite implements Serializable {
       grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
     Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), 
agged.collectAsList());
 
-    Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
-      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
-      expr("sum(_2)"),
-      count("*"))
-      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), 
Encoders.LONG()));
+    Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+      new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+      .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
     Assert.assertEquals(
       Arrays.asList(
         new Tuple4<>("a", 3, 3L, 2L),


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to