Repository: spark Updated Branches: refs/heads/master dcb9a97f3 -> 623c2ec4e
[SPARK-25048][SQL] Pivoting by multiple columns in Scala/Java ## What changes were proposed in this pull request? In the PR, I propose to extend the implementation of the existing method: ``` def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset ``` to support values of the struct type. This allows pivoting by multiple columns combined by `struct`: ``` trainingSales .groupBy($"sales.year") .pivot( pivotColumn = struct(lower($"sales.course"), $"training"), values = Seq( struct(lit("dotnet"), lit("Experts")), struct(lit("java"), lit("Dummies"))) ).agg(sum($"sales.earnings")) ``` ## How was this patch tested? Added a test for values specified via `struct` in Java and Scala. Closes #22316 from MaxGekk/pivoting-by-multiple-columns2. Lead-authored-by: Maxim Gekk <maxim.g...@databricks.com> Co-authored-by: Maxim Gekk <max.g...@gmail.com> Signed-off-by: hyukjinkwon <gurwls...@apache.org> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/623c2ec4 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/623c2ec4 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/623c2ec4 Branch: refs/heads/master Commit: 623c2ec4ef3776bc5e2cac2c66300ddc6264db54 Parents: dcb9a97 Author: Maxim Gekk <maxim.g...@databricks.com> Authored: Sat Sep 29 21:50:35 2018 +0800 Committer: hyukjinkwon <gurwls...@apache.org> Committed: Sat Sep 29 21:50:35 2018 +0800 ---------------------------------------------------------------------- .../spark/sql/RelationalGroupedDataset.scala | 17 +++++++++++++-- .../apache/spark/sql/JavaDataFrameSuite.java | 16 ++++++++++++++ .../apache/spark/sql/DataFramePivotSuite.scala | 23 ++++++++++++++++++++ 3 files changed, 54 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala 
---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index d700fb8..dbacdbf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -330,6 +330,15 @@ class RelationalGroupedDataset protected[sql]( * df.groupBy("year").pivot("course").sum("earnings") * }}} * + * From Spark 2.5.0, values can be literal columns, for instance, struct. For pivoting by + * multiple columns, use the `struct` function to combine the columns and values: + * + * {{{ + * df.groupBy("year") + * .pivot("trainingCourse", Seq(struct(lit("java"), lit("Experts")))) + * .agg(sum($"earnings")) + * }}} + * * @param pivotColumn Name of the column to pivot. * @param values List of values that will be translated to columns in the output DataFrame. 
* @since 1.6.0 @@ -413,10 +422,14 @@ class RelationalGroupedDataset protected[sql]( def pivot(pivotColumn: Column, values: Seq[Any]): RelationalGroupedDataset = { groupType match { case RelationalGroupedDataset.GroupByType => + val valueExprs = values.map(_ match { + case c: Column => c.expr + case v => Literal.apply(v) + }) new RelationalGroupedDataset( df, groupingExprs, - RelationalGroupedDataset.PivotType(pivotColumn.expr, values.map(Literal.apply))) + RelationalGroupedDataset.PivotType(pivotColumn.expr, valueExprs)) case _: RelationalGroupedDataset.PivotType => throw new UnsupportedOperationException("repeated pivots are not supported") case _ => @@ -561,5 +574,5 @@ private[sql] object RelationalGroupedDataset { /** * To indicate it's the PIVOT */ - private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType + private[sql] case class PivotType(pivotCol: Expression, values: Seq[Expression]) extends GroupType } http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java ---------------------------------------------------------------------- diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index 3f37e58..00f41d6 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -317,6 +317,22 @@ public class JavaDataFrameSuite { Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); } + @Test + public void pivotColumnValues() { + Dataset<Row> df = spark.table("courseSales"); + List<Row> actual = df.groupBy("year") + .pivot(col("course"), Arrays.asList(lit("dotNET"), lit("Java"))) + .agg(sum("earnings")).orderBy("year").collectAsList(); + + Assert.assertEquals(2012, actual.get(0).getInt(0)); + Assert.assertEquals(15000.0, 
actual.get(0).getDouble(1), 0.01); + Assert.assertEquals(20000.0, actual.get(0).getDouble(2), 0.01); + + Assert.assertEquals(2013, actual.get(1).getInt(0)); + Assert.assertEquals(48000.0, actual.get(1).getDouble(1), 0.01); + Assert.assertEquals(30000.0, actual.get(1).getDouble(2), 0.01); + } + private String getResource(String resource) { try { // The following "getResource" has different behaviors in SBT and Maven. http://git-wip-us.apache.org/repos/asf/spark/blob/623c2ec4/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala index b972b9e..02ab197 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala @@ -308,4 +308,27 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext { assert(exception.getMessage.contains("aggregate functions are not allowed")) } + + test("pivoting column list with values") { + val expected = Row(2012, 10000.0, null) :: Row(2013, 48000.0, 30000.0) :: Nil + val df = trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training"), Seq( + struct(lit("dotnet"), lit("Experts")), + struct(lit("java"), lit("Dummies"))) + ).agg(sum($"sales.earnings")) + + checkAnswer(df, expected) + } + + test("pivoting column list") { + val exception = intercept[RuntimeException] { + trainingSales + .groupBy($"sales.year") + .pivot(struct(lower($"sales.course"), $"training")) + .agg(sum($"sales.earnings")) + .collect() + } + assert(exception.getMessage.contains("Unsupported literal type")) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: 
commits-h...@spark.apache.org