Repository: spark
Updated Branches:
  refs/heads/master 8141d5592 -> b4fd75fb9


[SPARK-24972][SQL] PivotFirst could not handle pivot columns of complex types

## What changes were proposed in this pull request?

When the pivot column is of a complex type, the eval() result will be an 
UnsafeRow, while the keys of the HashMap used for column value matching are 
GenericInternalRows. As a result, there will be no match and the result will 
always be empty.
So for a pivot column of a complex type, we should:
1) If the complex type is not comparable (orderable), throw an 
AnalysisException, because it cannot be used as a pivot column.
2) Otherwise, if it goes through the `PivotFirst` code path, `PivotFirst` 
should use a TreeMap instead of a HashMap for such columns.

This PR has also reverted the workaround in Analyzer that had been introduced 
to avoid this `PivotFirst` issue.

## How was this patch tested?

Added UT.

Author: maryannxue <maryann...@apache.org>

Closes #21926 from maryannxue/pivot_followup.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/b4fd75fb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/b4fd75fb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/b4fd75fb

Branch: refs/heads/master
Commit: b4fd75fb9b615cfe592ad269cf20d02b483a0d33
Parents: 8141d55
Author: maryannxue <maryann...@apache.org>
Authored: Mon Jul 30 23:43:53 2018 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Jul 30 23:43:53 2018 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  12 +-
 .../expressions/aggregate/PivotFirst.scala      |  11 +-
 .../test/resources/sql-tests/inputs/pivot.sql   |  78 ++++++++++++-
 .../resources/sql-tests/results/pivot.sql.out   | 116 +++++++++++++++++--
 4 files changed, 199 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 1488ede..76dc867 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -529,6 +529,10 @@ class Analyzer(
         || (p.groupByExprsOpt.isDefined && 
!p.groupByExprsOpt.get.forall(_.resolved))
         || !p.pivotColumn.resolved || !p.pivotValues.forall(_.resolved) => p
       case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) 
=>
+        if (!RowOrdering.isOrderable(pivotColumn.dataType)) {
+          throw new AnalysisException(
+            s"Invalid pivot column '${pivotColumn}'. Pivot columns must be 
comparable.")
+        }
         // Check all aggregate expressions.
         aggregates.foreach(checkValidAggregateExpression)
         // Check all pivot values are literal and match pivot column data type.
@@ -574,10 +578,14 @@ class Analyzer(
           // Since evaluating |pivotValues| if statements for each input row 
can get slow this is an
           // alternate plan that instead uses two steps of aggregation.
           val namedAggExps: Seq[NamedExpression] = aggregates.map(a => 
Alias(a, a.sql)())
-          val bigGroup = groupByExprs ++ pivotColumn.references
+          val namedPivotCol = pivotColumn match {
+            case n: NamedExpression => n
+            case _ => Alias(pivotColumn, "__pivot_col")()
+          }
+          val bigGroup = groupByExprs :+ namedPivotCol
           val firstAgg = Aggregate(bigGroup, bigGroup ++ namedAggExps, child)
           val pivotAggs = namedAggExps.map { a =>
-            Alias(PivotFirst(pivotColumn, a.toAttribute, evalPivotValues)
+            Alias(PivotFirst(namedPivotCol.toAttribute, a.toAttribute, 
evalPivotValues)
               .toAggregateExpression()
             , "__pivot_" + a.sql)()
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
index 5237148..33bc5b5 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PivotFirst.scala
@@ -17,11 +17,11 @@
 
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
-import scala.collection.immutable.HashMap
+import scala.collection.immutable.{HashMap, TreeMap}
 
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{GenericArrayData, TypeUtils}
 import org.apache.spark.sql.types._
 
 object PivotFirst {
@@ -83,7 +83,12 @@ case class PivotFirst(
 
   override val dataType: DataType = ArrayType(valueDataType)
 
-  val pivotIndex = HashMap(pivotColumnValues.zipWithIndex: _*)
+  val pivotIndex = if (pivotColumn.dataType.isInstanceOf[AtomicType]) {
+    HashMap(pivotColumnValues.zipWithIndex: _*)
+  } else {
+    TreeMap(pivotColumnValues.zipWithIndex: _*)(
+      TypeUtils.getInterpretedOrdering(pivotColumn.dataType))
+  }
 
   val indexSize = pivotIndex.size
 

http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql 
b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
index a6c8d48..1f607b3 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/pivot.sql
@@ -11,10 +11,10 @@ create temporary view years as select * from values
   (2013, 2)
   as years(y, s);
 
-create temporary view yearsWithArray as select * from values
-  (2012, array(1, 1)),
-  (2013, array(2, 2))
-  as yearsWithArray(y, a);
+create temporary view yearsWithComplexTypes as select * from values
+  (2012, array(1, 1), map('1', 1), struct(1, 'a')),
+  (2013, array(2, 2), map('2', 2), struct(2, 'b'))
+  as yearsWithComplexTypes(y, a, m, s);
 
 -- pivot courses
 SELECT * FROM (
@@ -204,7 +204,7 @@ PIVOT (
 SELECT * FROM (
   SELECT course, year, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   min(a)
@@ -215,9 +215,75 @@ PIVOT (
 SELECT * FROM (
   SELECT course, year, y, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   max(a)
   FOR (y, course) IN ((2012, 'dotNET'), (2013, 'Java'))
 );
+
+-- pivot on pivot column of array type
+SELECT * FROM (
+  SELECT earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR a IN (array(1, 1), array(2, 2))
+);
+
+-- pivot on multiple pivot columns containing array type
+SELECT * FROM (
+  SELECT course, earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
+);
+
+-- pivot on pivot column of struct type
+SELECT * FROM (
+  SELECT earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR s IN ((1, 'a'), (2, 'b'))
+);
+
+-- pivot on multiple pivot columns containing struct type
+SELECT * FROM (
+  SELECT course, earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
+);
+
+-- pivot on pivot column of map type
+SELECT * FROM (
+  SELECT earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR m IN (map('1', 1), map('2', 2))
+);
+
+-- pivot on multiple pivot columns containing map type
+SELECT * FROM (
+  SELECT course, earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
+);

http://git-wip-us.apache.org/repos/asf/spark/blob/b4fd75fb/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
----------------------------------------------------------------------
diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out 
b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
index 6bb51b9..2dd9293 100644
--- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 25
+-- Number of queries: 31
 
 
 -- !query 0
@@ -28,10 +28,10 @@ struct<>
 
 
 -- !query 2
-create temporary view yearsWithArray as select * from values
-  (2012, array(1, 1)),
-  (2013, array(2, 2))
-  as yearsWithArray(y, a)
+create temporary view yearsWithComplexTypes as select * from values
+  (2012, array(1, 1), map('1', 1), struct(1, 'a')),
+  (2013, array(2, 2), map('2', 2), struct(2, 'b'))
+  as yearsWithComplexTypes(y, a, m, s)
 -- !query 2 schema
 struct<>
 -- !query 2 output
@@ -346,7 +346,7 @@ Literal expressions required for pivot values, found 
'course#x';
 SELECT * FROM (
   SELECT course, year, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   min(a)
@@ -363,7 +363,7 @@ struct<year:int,dotNET:array<int>,Java:array<int>>
 SELECT * FROM (
   SELECT course, year, y, a
   FROM courseSales
-  JOIN yearsWithArray ON year = y
+  JOIN yearsWithComplexTypes ON year = y
 )
 PIVOT (
   max(a)
@@ -374,3 +374,105 @@ struct<year:int,[2012, dotNET]:array<int>,[2013, 
Java]:array<int>>
 -- !query 24 output
 2012   [1,1]   NULL
 2013   NULL    [2,2]
+
+
+-- !query 25
+SELECT * FROM (
+  SELECT earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR a IN (array(1, 1), array(2, 2))
+)
+-- !query 25 schema
+struct<year:int,[1, 1]:bigint,[2, 2]:bigint>
+-- !query 25 output
+2012   35000   NULL
+2013   NULL    78000
+
+
+-- !query 26
+SELECT * FROM (
+  SELECT course, earnings, year, a
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, a) IN (('dotNET', array(1, 1)), ('Java', array(2, 2)))
+)
+-- !query 26 schema
+struct<year:int,[dotNET, [1, 1]]:bigint,[Java, [2, 2]]:bigint>
+-- !query 26 output
+2012   15000   NULL
+2013   NULL    30000
+
+
+-- !query 27
+SELECT * FROM (
+  SELECT earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR s IN ((1, 'a'), (2, 'b'))
+)
+-- !query 27 schema
+struct<year:int,[1, a]:bigint,[2, b]:bigint>
+-- !query 27 output
+2012   35000   NULL
+2013   NULL    78000
+
+
+-- !query 28
+SELECT * FROM (
+  SELECT course, earnings, year, s
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, s) IN (('dotNET', (1, 'a')), ('Java', (2, 'b')))
+)
+-- !query 28 schema
+struct<year:int,[dotNET, [1, a]]:bigint,[Java, [2, b]]:bigint>
+-- !query 28 output
+2012   15000   NULL
+2013   NULL    30000
+
+
+-- !query 29
+SELECT * FROM (
+  SELECT earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR m IN (map('1', 1), map('2', 2))
+)
+-- !query 29 schema
+struct<>
+-- !query 29 output
+org.apache.spark.sql.AnalysisException
+Invalid pivot column 'm#x'. Pivot columns must be comparable.;
+
+
+-- !query 30
+SELECT * FROM (
+  SELECT course, earnings, year, m
+  FROM courseSales
+  JOIN yearsWithComplexTypes ON year = y
+)
+PIVOT (
+  sum(earnings)
+  FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2)))
+)
+-- !query 30 schema
+struct<>
+-- !query 30 output
+org.apache.spark.sql.AnalysisException
+Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns 
must be comparable.;


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to