Repository: spark
Updated Branches:
  refs/heads/master c964fc101 -> 9c57bc0ef
[SPARK-11656][SQL] support typed aggregate in project list

insert `aEncoder` like we do in `agg`

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9630 from cloud-fan/select.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c57bc0e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c57bc0e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c57bc0e

Branch: refs/heads/master
Commit: 9c57bc0efce0ac37d8319666f5a8d3e8dce7651c
Parents: c964fc1
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Nov 11 10:21:53 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Nov 11 10:21:53 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala | 20 ++++++++++++++++----
 .../spark/sql/DatasetAggregatorSuite.scala | 11 +++++++++++
 2 files changed, 27 insertions(+), 4 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/9c57bc0e/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a7e5ab1..87dae6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,14 +21,15 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.Inner import
org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.types.StructType /** @@ -359,7 +360,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan)) } /** @@ -368,11 +369,12 @@ class Dataset[T] private[sql]( * that cast appropriately for the user facing interface. */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } + val withEncoders = columns.map(withEncoder) + val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) // Rebind the encoders to the nested schema that will be produced by the select. 
- val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { + val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { case (e: ExpressionEncoder[_], a) if !e.flat => e.nested(a.toAttribute).resolve(execution.analyzed.output) case (e, a) => @@ -381,6 +383,16 @@ class Dataset[T] private[sql]( new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } + private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = { + val e = c.expr transform { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]), + children = queryExecution.analyzed.output) + } + new TypedColumn(e, c.encoder) + } + /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 http://git-wip-us.apache.org/repos/asf/spark/blob/9c57bc0e/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 002d5c1..d4f0ab7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ComplexResultAgg.toColumn), ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkAnswer( + ds.select(sum((i: Int) => i)), + 11) + checkAnswer( + ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), + 11 -> 22) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org