Repository: spark
Updated Branches:
  refs/heads/master 425ff03f5 -> b86f2cab6
[SPARK-11404] [SQL] Support for groupBy using column expressions

This PR adds a new method `groupBy(cols: Column*)` to `Dataset` that allows users to group using column expressions instead of a lambda function. Since the return type of these expressions is not known at compile time, we just set the key type to a generic `Row`. If the user would like to work with the key in a type-safe way, they can call `grouped.asKey[Type]`, which is also added in this PR.

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
val grouped = ds.groupBy($"_1").asKey[String]
val agged = grouped.mapGroups { case (g, iter) =>
  Iterator((g, iter.map(_._2).sum))
}

agged.collect()
res0: Array(("a", 30), ("b", 3), ("c", 1))
```
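For comparison, the same aggregation also works without `asKey`, with the key exposed as a generic `Row` (a sketch mirroring the `groupBy columns, mapGroups` test added below; the `res1` output line is illustrative and result order is not guaranteed):

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()

// Without asKey the key type is Row, so key fields are read positionally.
val grouped = ds.groupBy($"_1") // GroupedDataset[Row, (String, Int)]
val agged = grouped.mapGroups { case (g, iter) =>
  Iterator((g.getString(0), iter.map(_._2).sum))
}

agged.collect()
res1: Array(("a", 30), ("b", 3), ("c", 1))
```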
Author: Michael Armbrust <mich...@databricks.com>

Closes #9359 from marmbrus/columnGroupBy and squashes the following commits:

bbcb03b [Michael Armbrust] Update DatasetSuite.scala
8fd2908 [Michael Armbrust] Update DatasetSuite.scala
0b0e2f8 [Michael Armbrust] [SPARK-11404] [SQL] Support for groupBy using column expressions


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b86f2cab
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b86f2cab
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b86f2cab

Branch: refs/heads/master
Commit: b86f2cab67989f09ba1ba8604e52cd4b1e44e436
Parents: 425ff03
Author: Michael Armbrust <mich...@databricks.com>
Authored: Tue Nov 3 13:02:17 2015 +0100
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Tue Nov 3 13:02:17 2015 +0100

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Dataset.scala   | 36 +++++++++++++--
 .../org/apache/spark/sql/GroupedDataset.scala  | 28 ++++++++++--
 .../org/apache/spark/sql/DatasetSuite.scala    | 48 ++++++++++++++++++++
 3 files changed, 106 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ed98a25..7b75aee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
 
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.Inner
@@ -78,9 +79,17 @@ class Dataset[T] private(
    * ************* */
 
   /**
-   * Returns a new `Dataset` where each record has been mapped on to the specified type.
-   * TODO: should bind here...
-   * TODO: document binding rules
+   * Returns a new `Dataset` where each record has been mapped on to the specified type. The
+   * method used to map columns depends on the type of `U`:
+   *  - When `U` is a class, fields for the class will be mapped to columns of the same name
+   *    (case sensitivity is determined by `spark.sql.caseSensitive`)
+   *  - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
+   *    be assigned to `_1`).
+   *  - When `U` is a primitive type (i.e. String, Int, etc.), the first column of the
+   *    [[DataFrame]] will be used.
+   *
+   * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
+   * along with `alias` or `as` to rearrange or rename as required.
    * @since 1.6.0
   */
   def as[U : Encoder]: Dataset[U] = {
@@ -225,6 +234,27 @@ class Dataset[T] private(
       withGroupingKey.newColumns)
   }
 
+  /**
+   * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+   * @since 1.6.0
+   */
+  @scala.annotation.varargs
+  def groupBy(cols: Column*): GroupedDataset[Row, T] = {
+    val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+    val withKey = Project(withKeyColumns, logicalPlan)
+    val executed = sqlContext.executePlan(withKey)
+
+    val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+    val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+    new GroupedDataset(
+      RowEncoder(keyAttributes.toStructType),
+      encoderFor[T],
+      executed,
+      dataAttributes,
+      keyAttributes)
+  }
+
   /* ****************** *
    *  Typed Relational  *
    * ****************** */


http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 612f2b6..96d6e9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.execution.QueryExecution
@@ -34,12 +34,34 @@ class GroupedDataset[K, T] private[sql](
     private val dataAttributes: Seq[Attribute],
     private val groupingAttributes: Seq[Attribute]) extends Serializable {
 
-  private implicit def kEnc = kEncoder
-  private implicit def tEnc = tEncoder
+  private implicit val kEnc = kEncoder match {
+    case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are currently supported")
+  }
+
+  private implicit val tEnc = tEncoder match {
+    case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
+    case other =>
+      throw new UnsupportedOperationException("Only expression encoders are currently supported")
+  }
+
   private def logicalPlan = queryExecution.analyzed
   private def sqlContext = queryExecution.sqlContext
 
   /**
+   * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
+   * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+   */
+  def asKey[L : Encoder]: GroupedDataset[L, T] =
+    new GroupedDataset(
+      encoderFor[L],
+      tEncoder,
+      queryExecution,
+      dataAttributes,
+      groupingAttributes)
+
+  /**
    * Returns a [[Dataset]] that contains each unique key.
    */
   def keys: Dataset[K] = {
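A brief usage sketch for the two members added above, `asKey` and `keys` (not part of the patch; assumes a `SQLContext` with `sqlContext.implicits._` imported, and output order is not guaranteed):

```scala
val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

// asKey swaps in a new key encoder; it is resolved against groupingAttributes
// by the kEnc val above when the new GroupedDataset is constructed.
val typed: GroupedDataset[String, (String, Int)] = ds.groupBy($"_1").asKey[String]

// keys yields one element per distinct grouping key.
val ks: Dataset[String] = typed.keys
// ks.collect() would contain "a" and "b" (in no particular order)
```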
http://git-wip-us.apache.org/repos/asf/spark/blob/b86f2cab/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 95b8d05..5973fa7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -203,6 +203,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("a", 30), ("b", 3), ("c", 1))
   }
 
+  test("groupBy columns, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1")
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g.getString(0), iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1").asKey[String]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      ("a", 30), ("b", 3), ("c", 1))
+  }
+
+  test("groupBy columns asKey tuple, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
+  }
+
+  test("groupBy columns asKey class, mapGroups") {
+    val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+    val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
+    val agged = grouped.mapGroups { case (g, iter) =>
+      Iterator((g, iter.map(_._2).sum))
+    }
+
+    checkAnswer(
+      agged,
+      (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
+  }
+
   test("cogroup") {
     val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
     val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
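For reference, a sketch of the `as[U]` mapping rules documented in the Dataset.scala hunk above, with hypothetical data (`Person` is an illustrative case class, and `sqlContext.implicits._` is assumed to be in scope):

```scala
case class Person(name: String, age: Int)

val df = Seq(("alice", 30), ("bob", 25)).toDF("name", "age")

df.as[Person]                  // class: columns are bound to fields by name
df.as[(String, Int)]           // tuple: columns are bound by ordinal (_1 = name)
df.select($"name").as[String]  // primitive: the first column is used

// If the schema does not match the target type, rearrange or rename first:
df.select($"age", $"name").as[(Int, String)]
```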