Repository: spark Updated Branches: refs/heads/master 8141d5592 -> b4fd75fb9
[SPARK-24972][SQL] PivotFirst could not handle pivot columns of complex types ## What changes were proposed in this pull request? When the pivot column is of a complex type, the eval() result will be an UnsafeRow, while the keys of the HashMap used for column value matching are GenericInternalRows. As a result, there will be no match and the result will always be empty. So for a pivot column of a complex type, we should: 1) If the complex type is not comparable (orderable), throw an Exception. It cannot be a pivot column. 2) Otherwise, if it goes through the `PivotFirst` code path, `PivotFirst` should use a TreeMap instead of HashMap for such columns. This PR also reverts the workaround in Analyzer that was introduced to avoid this `PivotFirst` issue. ## How was this patch tested? Added UT. Author: maryannxue <maryann...@apache.org> Closes #21926 from maryannxue/pivot_followup. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4fd75fb Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4fd75fb Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4fd75fb Branch: refs/heads/master Commit: b4fd75fb9b615cfe592ad269cf20d02b483a0d33 Parents: 8141d55 Author: maryannxue <maryann...@apache.org> Authored: Mon Jul 30 23:43:53 2018 -0700 Committer: Xiao Li <gatorsm...@gmail.com> Committed: Mon Jul 30 23:43:53 2018 -0700 ---------------------------------------------------------------------- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 +- .../expressions/aggregate/PivotFirst.scala | 11 +- .../test/resources/sql-tests/inputs/pivot.sql | 78 ++++++++++++- .../resources/sql-tests/results/pivot.sql.out | 116 +++++++++++++++++-- 4 files changed, 199 insertions(+), 18 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 1488ede..76dc867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -529,6 +529,10 @@ class Analyzer( || (p.groupByExprsOpt.isDefined && !p.groupByExprsOpt.get.forall(_.resolved)) || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) => + if (!RowOrdering.isOrderable(pivotColumn.dataType)) { + throw new AnalysisException( + s"Invalid pivot column '${pivotColumn}'. Pivot columns must be comparable.") + } // Check all aggregate expressions. aggregates.foreach(checkValidAggregateExpression) // Check all pivot values are literal and match pivot column data type. @@ -574,10 +578,14 @@ class Analyzer( // Since evaluating |pivotValues| if statements for each input row can get slow this is an // alternate plan that instead uses two steps of aggregation. 
val namedAggExps: Seq[NamedExpression] = aggregates.map(a => Alias(a, a.sql)()) - val bigGroup = groupByExprs ++ pivotColumn.references + val namedPivotCol = pivotColumn match { + case n: NamedExpression => n + case _ => Alias(pivotColumn, "__pivot_col")() + } + val bigGroup = groupByExprs :+ namedPivotCol val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child) val pivotAggs = namedAggExps.map { a => - Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues) + Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, evalPivotValues) .toAggregateExpression() , "__pivot_" + a.sql)() } http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala index 5237148..33bc5b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala @@ -17,11 +17,11 @@ package org.apache.spark.sql.catalyst.expressions.aggregate -import scala.collection.immutable.HashMap +import scala.collection.immutable.{HashMap, TreeMap} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils} import org.apache.spark.sql.types._ object PivotFirst { @@ -83,7 +83,12 @@ case class PivotFirst( override val dataType: DataType = ArrayType(valueDataType) - val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*) + val pivotIndex = if (pivotColumn.dataType.isInstanceOf[AtomicType]) { + 
HashMap(pivotColumnValues.zipWithIndex: _*) + } else { + TreeMap(pivotColumnValues.zipWithIndex: _*)( + TypeUtils.getInterpretedOrdering(pivotColumn.dataType)) + } val indexSize = pivotIndex.size http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/core/src/test/resources/sql-tests/inputs/pivot.sql ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql index a6c8d48..1f607b3 100644 --- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql +++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql @@ -11,10 +11,10 @@ create temporary view years as select * from values (2013, 2) as years(y, s); -create temporary view yearsWithArray as select * from values - (2012, array(1, 1)), - (2013, array(2, 2)) - as yearsWithArray(y, a); +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, a, m, s); -- pivot courses SELECT * FROM ( @@ -204,7 +204,7 @@ PIVOT ( SELECT * FROM ( SELECT course, year, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( min(a) @@ -215,9 +215,75 @@ PIVOT ( SELECT * FROM ( SELECT course, year, y, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( max(a) FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java')) ); + +-- pivot on pivot column of array type +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +); + +-- pivot on multiple pivot columns containing array type +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', 
array(1, 1)), ('Java', array(2, 2))) +); + +-- pivot on pivot column of struct type +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +); + +-- pivot on multiple pivot columns containing struct type +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +); + +-- pivot on pivot column of map type +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +); + +-- pivot on multiple pivot columns containing map type +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +); http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/core/src/test/resources/sql-tests/results/pivot.sql.out ---------------------------------------------------------------------- diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 6bb51b9..2dd9293 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -1,5 +1,5 @@ -- Automatically generated by SQLQueryTestSuite --- Number of queries: 25 +-- Number of queries: 31 -- !query 0 @@ -28,10 +28,10 @@ struct<> -- !query 2 -create temporary view yearsWithArray as select * from values - (2012, array(1, 1)), - (2013, array(2, 2)) - as yearsWithArray(y, a) +create temporary view yearsWithComplexTypes as select * from values + (2012, array(1, 1), map('1', 1), struct(1, 'a')), + (2013, array(2, 2), map('2', 2), struct(2, 'b')) + as yearsWithComplexTypes(y, 
a, m, s) -- !query 2 schema struct<> -- !query 2 output @@ -346,7 +346,7 @@ Literal expressions required for pivot values, found 'course#x'; SELECT * FROM ( SELECT course, year, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( min(a) @@ -363,7 +363,7 @@ struct<year:int,dotNET:array<int>,Java:array<int>> SELECT * FROM ( SELECT course, year, y, a FROM courseSales - JOIN yearsWithArray ON year = y + JOIN yearsWithComplexTypes ON year = y ) PIVOT ( max(a) @@ -374,3 +374,105 @@ struct<year:int,[2012, dotNET]:array<int>,[2013, Java]:array<int>> -- !query 24 output 2012 [1,1] NULL 2013 NULL [2,2] + + +-- !query 25 +SELECT * FROM ( + SELECT earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR a IN (array(1, 1), array(2, 2)) +) +-- !query 25 schema +struct<year:int,[1, 1]:bigint,[2, 2]:bigint> +-- !query 25 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 26 +SELECT * FROM ( + SELECT course, earnings, year, a + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2))) +) +-- !query 26 schema +struct<year:int,[dotNET, [1, 1]]:bigint,[Java, [2, 2]]:bigint> +-- !query 26 output +2012 15000 NULL +2013 NULL 30000 + + +-- !query 27 +SELECT * FROM ( + SELECT earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR s IN ((1, 'a'), (2, 'b')) +) +-- !query 27 schema +struct<year:int,[1, a]:bigint,[2, b]:bigint> +-- !query 27 output +2012 35000 NULL +2013 NULL 78000 + + +-- !query 28 +SELECT * FROM ( + SELECT course, earnings, year, s + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b'))) +) +-- !query 28 schema +struct<year:int,[dotNET, [1, a]]:bigint,[Java, [2, b]]:bigint> +-- !query 28 output +2012 15000 NULL +2013 
NULL 30000 + + +-- !query 29 +SELECT * FROM ( + SELECT earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR m IN (map('1', 1), map('2', 2)) +) +-- !query 29 schema +struct<> +-- !query 29 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'm#x'. Pivot columns must be comparable.; + + +-- !query 30 +SELECT * FROM ( + SELECT course, earnings, year, m + FROM courseSales + JOIN yearsWithComplexTypes ON year = y +) +PIVOT ( + sum(earnings) + FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) +) +-- !query 30 schema +struct<> +-- !query 30 output +org.apache.spark.sql.AnalysisException +Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable.; --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org