Repository: spark
Updated Branches:
  refs/heads/master c964fc101 -> 9c57bc0ef


[SPARK-11656][SQL] support typed aggregate in project list

insert `aEncoder` like we do in `agg`

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9630 from cloud-fan/select.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9c57bc0e
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9c57bc0e
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9c57bc0e

Branch: refs/heads/master
Commit: 9c57bc0efce0ac37d8319666f5a8d3e8dce7651c
Parents: c964fc1
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Wed Nov 11 10:21:53 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Wed Nov 11 10:21:53 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 20 ++++++++++++++++----
 .../spark/sql/DatasetAggregatorSuite.scala      | 11 +++++++++++
 2 files changed, 27 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9c57bc0e/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a7e5ab1..87dae6b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.api.java.function._
 
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.plans.Inner
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.{Queryable, QueryExecution}
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
 import org.apache.spark.sql.types.StructType
 
 /**
@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
    * @since 1.6.0
    */
   def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
-    new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
+    new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
   }
 
   /**
@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
    * that cast appropriately for the user facing interface.
    */
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
-    val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+    val withEncoders = columns.map(withEncoder)
+    val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
     val unresolvedPlan = Project(aliases, logicalPlan)
     val execution = new QueryExecution(sqlContext, unresolvedPlan)
     // Rebind the encoders to the nested schema that will be produced by the select.
-    val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+    val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
       case (e: ExpressionEncoder[_], a) if !e.flat =>
         e.nested(a.toAttribute).resolve(execution.analyzed.output)
       case (e, a) =>
@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
     new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
   }
 
+  private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
+    val e = c.expr transform {
+      case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+        ta.copy(
+          aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
+          children = queryExecution.analyzed.output)
+    }
+    new TypedColumn(e, c.encoder)
+  }
+
   /**
    * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
    * @since 1.6.0

http://git-wip-us.apache.org/repos/asf/spark/blob/9c57bc0e/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 002d5c1..d4f0ab7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
         ComplexResultAgg.toColumn),
       ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
   }
+
+  test("typed aggregation: in project list") {
+    val ds = Seq(1, 3, 2, 5).toDS()
+
+    checkAnswer(
+      ds.select(sum((i: Int) => i)),
+      11)
+    checkAnswer(
+      ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
+      11 -> 22)
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to