Repository: spark Updated Branches: refs/heads/master 7de06a646 -> 2f1d0320c
[SPARK-13363][SQL] support Aggregator in RelationalGroupedDataset ## What changes were proposed in this pull request? set the input encoder for `TypedColumn` in `RelationalGroupedDataset.agg`. ## How was this patch tested? new tests in `DatasetAggregatorSuite` close https://github.com/apache/spark/pull/11269 This PR brings https://github.com/apache/spark/pull/12359 up to date and fix the compile. Author: Wenchen Fan <wenc...@databricks.com> Closes #12451 from cloud-fan/agg. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/2f1d0320 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/2f1d0320 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/2f1d0320 Branch: refs/heads/master Commit: 2f1d0320c97f064556fa1cf98d4e30d2ab2fe661 Parents: 7de06a6 Author: Wenchen Fan <wenc...@databricks.com> Authored: Mon Apr 18 14:27:26 2016 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Mon Apr 18 14:27:26 2016 +0800 ---------------------------------------------------------------------- .../apache/spark/sql/RelationalGroupedDataset.scala | 6 +++++- .../org/apache/spark/sql/DatasetAggregatorSuite.scala | 14 +++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/2f1d0320/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 7dbf2e6..0ffb136 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -208,7 +208,11 @@ class RelationalGroupedDataset protected[sql]( */ @scala.annotation.varargs 
def agg(expr: Column, exprs: Column*): DataFrame = { - toDF((expr +: exprs).map(_.expr)) + toDF((expr +: exprs).map { + case typed: TypedColumn[_, _] => + typed.withInputType(df.unresolvedTEncoder.deserializer, df.logicalPlan.output).expr + case c => c.expr + }) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/2f1d0320/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 3a7215e..0d84a59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import scala.language.postfixOps -import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.expressions.Aggregator import org.apache.spark.sql.expressions.scala.typed import org.apache.spark.sql.functions._ @@ -85,6 +84,15 @@ class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT) override def outputEncoder: Encoder[OUT] = implicitly[Encoder[OUT]] } +object RowAgg extends Aggregator[Row, Int, Int] { + def zero: Int = 0 + def reduce(b: Int, a: Row): Int = a.getInt(0) + b + def merge(b1: Int, b2: Int): Int = b1 + b2 + def finish(r: Int): Int = r + override def bufferEncoder: Encoder[Int] = Encoders.scalaInt + override def outputEncoder: Encoder[Int] = Encoders.scalaInt +} + class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { @@ -200,4 +208,8 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { (1279869254, "Some String")) } + test("aggregator in DataFrame/Dataset[Row]") { + val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j") + checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil) + } } 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org For additional commands, e-mail: commits-help@spark.apache.org