Repository: spark
Updated Branches:
  refs/heads/master 425ff03f5 -> b86f2cab6


[SPARK-11404] [SQL] Support for groupBy using column expressions

This PR adds a new method, `groupBy(cols: Column*)`, to `Dataset` that allows 
users to group using column expressions instead of a lambda function. Since 
the return type of these expressions is not known at compile time, the key 
type defaults to a generic `Row`. Users who would like to work with the key in 
a type-safe way can call `grouped.asKey[Type]`, which is also added in this 
PR.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").asKey[String]
val agged = grouped.mapGroups { case (g, iter) =>
  Iterator((g, iter.map(_._2).sum))
}

agged.collect()

res0: Array(("a", 30), ("b", 3), ("c", 1))
```
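
For comparison, here is a minimal sketch of the untyped path and of mapping the key onto a case class, adapted from the tests added in this PR. It assumes `import sqlContext.implicits._` and `org.apache.spark.sql.functions.lit` are in scope, and that `ClassData` is a simple `case class ClassData(a: String, b: Int)` as used in `DatasetSuite`:

```scala
import org.apache.spark.sql.functions.lit

// Untyped key: without asKey, the grouping key is a generic Row,
// so its fields are read positionally.
val rowAgged = ds.groupBy($"_1").mapGroups { case (g, iter) =>
  Iterator((g.getString(0), iter.map(_._2).sum))
}

// Typed key: multiple grouping columns can be mapped onto a tuple or a
// case class by name, following the same rules as Dataset.as[U].
case class ClassData(a: String, b: Int)  // assumed definition, as in DatasetSuite
val classAgged = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
  .mapGroups { case (g, iter) => Iterator((g, iter.map(_._2).sum)) }
```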

Author: Michael Armbrust <mich...@databricks.com>

Closes #9359 from marmbrus/columnGroupBy and squashes the following commits:

bbcb03b [Michael Armbrust] Update DatasetSuite.scala
8fd2908 [Michael Armbrust] Update DatasetSuite.scala
0b0e2f8 [Michael Armbrust] [SPARK-11404] [SQL] Support for groupBy using column expressions


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b86f2cab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b86f2cab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b86f2cab

Branch: refs/heads/master
Commit: b86f2cab67989f09ba1ba8604e52cd4b1e44e436
Parents: 425ff03
Author: Michael Armbrust <mich...@databricks.com>
Authored: Tue Nov 3 13:02:17 2015 +0100
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 3 13:02:17 2015 +0100

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala    | 36 +++++++++++++--
 .../org/apache/spark/sql/GroupedDataset.scala   | 28 ++++++++++--
 .../org/apache/spark/sql/DatasetSuite.scala     | 48 ++++++++++++++++++++
 3 files changed, 106 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ed98a25..7b75aee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -78,9 +79,17 @@ class Dataset[T] private(
    * ************* */
 
   /**
-   * Returns a new `Dataset` where each record has been mapped on to the specified type.
-   * TODO: should bind here...
-   * TODO: document binding rules
+   * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+   * method used to map columns depends on the type of `U`:
+   *  - When `U` is a class, fields for the class will be mapped to columns of the same name
+   *    (case sensitivity is determined by `spark.sql.caseSensitive`)
+   *  - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
+   *    be assigned to `_1`).
+   *  - When `U` is a primitive type (i.e. String, Int, etc.), then the first column of the
+   *    [[DataFrame]] will be used.
+   *
+   * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
+   * along with `alias` or `as` to rearrange or rename as required.
    * @since 1.6.0
    */
   def as[U : Encoder]: Dataset[U] = {
@@ -225,6 +234,27 @@ class Dataset[T] private(
       withGroupingKey.newColumns)
   }
 
+  /**
+   * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def groupBy(cols: Column*): GroupedDataset[Row, T] = {
+    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+    val withKey = Project(withKeyColumns, logicalPlan)
+    val executed = sqlContext.executePlan(withKey)
+
+    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+    new GroupedDataset(
+      RowEncoder(keyAttributes.toStructType),
+      encoderFor[T],
+      executed,
+      dataAttributes,
+      keyAttributes)
+  }
+
   /* ****************** *
    *  Typed Relational  *
    * ****************** */

http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 612f2b6..96d6e9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
@@ -34,12 +34,34 @@ class GroupedDataset[K, T] private[sql](
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  private implicit def kEnc = kEncoder
-  private implicit def tEnc = tEncoder
+  private implicit val kEnc = kEncoder match {
+    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are 
currently supported")
+  }
+
+  private implicit val tEnc = tEncoder match {
+    case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are 
currently supported")
+  }
+
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
   /**
+   * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
+   * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+   */
+  def asKey[L : Encoder]: GroupedDataset[L, T] =
+    new GroupedDataset(
+      encoderFor[L],
+      tEncoder,
+      queryExecution,
+      dataAttributes,
+      groupingAttributes)
+
+  /**
    * Returns a [[Dataset]] that contains each unique key.
    */
   def keys: Dataset[K] = {

http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 95b8d05..5973fa7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -203,6 +203,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }
 
+  test("groupBy columns, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1")
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g.getString(0), iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1").asKey[String]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey tuple, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
+  }
+
+  test("groupBy columns asKey class, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()

