Repository: spark Updated Branches: refs/heads/branch-1.6 9b99b2b46 -> 6e3e3c648
[SPARK-12068][SQL] use a single column in Dataset.groupBy and count will fail The reason is that, for a single column `RowEncoder` (or a single-field product encoder), when we use it as the encoder for the grouping key, we should also combine the grouping attributes, although there is only one grouping attribute. Author: Wenchen Fan <wenc...@databricks.com> Closes #10059 from cloud-fan/bug. (cherry picked from commit 8ddc55f1d582cccc3ca135510b2ea776e889e481) Signed-off-by: Michael Armbrust <mich...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6e3e3c64 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6e3e3c64 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6e3e3c64 Branch: refs/heads/branch-1.6 Commit: 6e3e3c648c4f74d9c1aabe767dbadfe47bd7e658 Parents: 9b99b2b Author: Wenchen Fan <wenc...@databricks.com> Authored: Tue Dec 1 10:22:55 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Tue Dec 1 10:23:17 2015 -0800 ---------------------------------------------------------------------- .../scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/GroupedDataset.scala | 7 ++++--- .../org/apache/spark/sql/DatasetSuite.scala | 19 +++++++++++++++++++ .../scala/org/apache/spark/sql/QueryTest.scala | 6 +++--- 4 files changed, 27 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/6e3e3c64/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index da46001..c357f88 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -70,7 +70,7 @@ class Dataset[T] 
private[sql]( * implicit so that we can use it when constructing new [[Dataset]] objects that have the same * object type (that will be possibly resolved to a different schema). */ - private implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) + private[sql] implicit val unresolvedTEncoder: ExpressionEncoder[T] = encoderFor(tEncoder) /** The encoder for this [[Dataset]] that has been resolved to its output schema. */ private[sql] val resolvedTEncoder: ExpressionEncoder[T] = http://git-wip-us.apache.org/repos/asf/spark/blob/6e3e3c64/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index a10a893..4bf0b25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -228,10 +228,11 @@ class GroupedDataset[K, V] private[sql]( val namedColumns = columns.map( _.withInputType(resolvedVEncoder, dataAttributes).named) - val keyColumn = if (groupingAttributes.length > 1) { - Alias(CreateStruct(groupingAttributes), "key")() - } else { + val keyColumn = if (resolvedKEncoder.flat) { + assert(groupingAttributes.length == 1) groupingAttributes.head + } else { + Alias(CreateStruct(groupingAttributes), "key")() } val aggregate = Aggregate(groupingAttributes, keyColumn +: namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) http://git-wip-us.apache.org/repos/asf/spark/blob/6e3e3c64/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 7d53918..a2c8d20 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -272,6 +272,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext { 3 -> "abcxyz", 5 -> "hello") } + test("groupBy single field class, count") { + val ds = Seq("abc", "xyz", "hello").toDS() + val count = ds.groupBy(s => Tuple1(s.length)).count() + + checkAnswer( + count, + (Tuple1(3), 2L), (Tuple1(5), 1L) + ) + } + test("groupBy columns, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1") @@ -282,6 +292,15 @@ class DatasetSuite extends QueryTest with SharedSQLContext { ("a", 30), ("b", 3), ("c", 1)) } + test("groupBy columns, count") { + val ds = Seq("a" -> 1, "b" -> 1, "a" -> 2).toDS() + val count = ds.groupBy($"_1").count() + + checkAnswer( + count, + (Row("a"), 2L), (Row("b"), 1L)) + } + test("groupBy columns asKey, map") { val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() val grouped = ds.groupBy($"_1").keyAs[String] http://git-wip-us.apache.org/repos/asf/spark/blob/6e3e3c64/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 6ea1fe4..8f476dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -64,12 +64,12 @@ abstract class QueryTest extends PlanTest { * for cases where reordering is done on fields. For such tests, user `checkDecoding` instead * which performs a subset of the checks done by this function. 
*/ - protected def checkAnswer[T : Encoder]( - ds: => Dataset[T], + protected def checkAnswer[T]( + ds: Dataset[T], expectedAnswer: T*): Unit = { checkAnswer( ds.toDF(), - sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq) + sqlContext.createDataset(expectedAnswer)(ds.unresolvedTEncoder).toDF().collect().toSeq) checkDecoding(ds, expectedAnswer: _*) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org