Repository: spark
Updated Branches:
  refs/heads/master eec74ba8b -> 363a476c3


[SPARK-11528] [SQL] Typed aggregations for Datasets

This PR adds the ability to do typed SQL aggregations.  We will likely also 
want to provide an interface to allow users to do aggregations on objects, but 
this is deferred to another PR.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
ds.groupBy(_._1).agg(sum("_2").as[Int]).collect()

res0: Array(("a", 30), ("b", 3), ("c", 1))
```

Author: Michael Armbrust <mich...@databricks.com>

Closes #9499 from marmbrus/dataset-agg.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/363a476c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/363a476c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/363a476c

Branch: refs/heads/master
Commit: 363a476c3fefb0263e63fd24df0b2779a64f79ec
Parents: eec74ba
Author: Michael Armbrust <mich...@databricks.com>
Authored: Thu Nov 5 21:42:32 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Thu Nov 5 21:42:32 2015 -0800

----------------------------------------------------------------------
 .../catalyst/expressions/namedExpressions.scala |  4 +
 .../scala/org/apache/spark/sql/Dataset.scala    |  2 +-
 .../org/apache/spark/sql/GroupedDataset.scala   | 93 +++++++++++++++++++-
 .../org/apache/spark/sql/DatasetSuite.scala     | 36 ++++++++
 4 files changed, 132 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/363a476c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8957df0..9ab5c29 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -254,6 +254,10 @@ case class AttributeReference(
   }
 
   override def toString: String = s"$name#${exprId.id}$typeSuffix"
+
+  // Since the expression id is not in the first constructor it is missing from the default
+  // tree string.
+  override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/363a476c/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 500227e..4bca9c3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType
  * @since 1.6.0
  */
 @Experimental
-class Dataset[T] private(
+class Dataset[T] private[sql](
     @transient val sqlContext: SQLContext,
     @transient val queryExecution: QueryExecution,
     unresolvedEncoder: Encoder[T]) extends Serializable {

http://git-wip-us.apache.org/repos/asf/spark/blob/363a476c/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 96d6e9d..b8fc373 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,16 +17,25 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
 
 /**
+ * :: Experimental ::
  * A [[Dataset]] has been logically grouped by a user specified grouping key.  Users should not
  * construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
  * [[Dataset]].
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset]] extend `GroupedData`.  However,
+ * making this change to the class hierarchy would break some function signatures. As such, this
+ * class should be considered a preview of the final API.  Changes will be made to the interface
+ * after Spark 1.6.
  */
+@Experimental
 class GroupedDataset[K, T] private[sql](
     private val kEncoder: Encoder[K],
     private val tEncoder: Encoder[T],
@@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql](
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
   private implicit val kEnc = kEncoder match {
-    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
     case other =>
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }
@@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql](
       throw new UnsupportedOperationException("Only expression encoders are currently supported")
   }
 
+  /** Encoders for built in aggregations. */
+  private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
+  private def groupedData =
+    new GroupedData(
+      new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+
   /**
    * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
    * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
@@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql](
       MapGroups(f, groupingAttributes, logicalPlan))
   }
 
+  // To ensure valid overloading.
+  protected def agg(expr: Column, exprs: Column*): DataFrame =
+    groupedData.agg(expr, exprs: _*)
+
+  /**
+   * Internal helper function for building typed aggregations that return tuples.  For simplicity
+   * and code reuse, we do this without the help of the type system and then use helper functions
+   * that cast appropriately for the user facing interface.
+   * TODO: does not handle aggregations that return non-flat results.
+   */
+  protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+    val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
+      case u: UnresolvedAttribute => UnresolvedAlias(u)
+      case expr: NamedExpression => expr
+      case expr: Expression => Alias(expr, expr.prettyString)()
+    }
+
+    val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
+    val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+    val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
+
+    // Rebind the encoders to the nested schema that will be produced by the aggregation.
+    val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
+      case (e: ExpressionEncoder[_], a) if !e.flat =>
+        e.nested(a).resolve(execution.analyzed.output)
+      case (e, a) =>
+        e.unbind(a :: Nil).resolve(execution.analyzed.output)
+    }
+    new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+  }
+
+  /**
+   * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing this aggregation over all elements in the group.
+   */
+  def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
+    aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
+    aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
+    aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+
+  /**
+   * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+   * and the result of computing these aggregations over all elements in the group.
+   */
+  def agg[A1, A2, A3, A4](
+      col1: TypedColumn[A1],
+      col2: TypedColumn[A2],
+      col3: TypedColumn[A3],
+      col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
+    aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
+
+  /**
+   * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
+   * for that key.
+   */
+  def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+
   /**
    * Applies the given function to each cogrouped data.  For each unique group, the function will
    * be passed the grouping key and 2 iterators containing all elements in the group from

http://git-wip-us.apache.org/repos/asf/spark/blob/363a476c/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3e9b621..d61e17e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
   }
 
+  test("typed aggregation: expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int]),
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("typed aggregation: expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
+      ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+  }
+
+  test("typed aggregation: expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
+      ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+  }
+
+  test("typed aggregation: expr, expr, expr, expr") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+    checkAnswer(
+      ds.groupBy(_._1).agg(
+        sum("_2").as[Int],
+        sum($"_2" + 1).as[Long],
+        count("*").as[Long],
+        avg("_2").as[Double]),
+      ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to