Repository: spark
Updated Branches:
  refs/heads/master bb152cdfb -> 45ea46b7b


[SPARK-18504][SQL] Scalar subquery with extra group by columns returning 
incorrect result

## What changes were proposed in this pull request?

This PR blocks an incorrect-result scenario in a scalar subquery where there are 
GROUP BY column(s)
that are not part of the correlated predicate(s).

Example:
// Incorrect result
Seq(1).toDF("c1").createOrReplaceTempView("t1")
Seq((1,1),(1,2)).toDF("c1","c2").createOrReplaceTempView("t2")
sql("select (select sum(-1) from t2 where t1.c1=t2.c1 group by t2.c2) from 
t1").show

// How can selecting a scalar subquery from a 1-row table return 2 rows?

## How was this patch tested?
sql/test, catalyst/test
A new test case covering the reported problem is added to SubquerySuite.scala.

Author: Nattavut Sutyanyong <nsy....@gmail.com>

Closes #15936 from nsyca/scalarSubqueryIncorrect-1.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/45ea46b7
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/45ea46b7
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/45ea46b7

Branch: refs/heads/master
Commit: 45ea46b7b397f023b4da878eb11e21b08d931115
Parents: bb152cd
Author: Nattavut Sutyanyong <nsy....@gmail.com>
Authored: Tue Nov 22 12:06:21 2016 -0800
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Tue Nov 22 12:06:21 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  3 --
 .../sql/catalyst/analysis/CheckAnalysis.scala   | 30 ++++++++++++++++----
 .../org/apache/spark/sql/SubquerySuite.scala    | 12 ++++++++
 3 files changed, 36 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/45ea46b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ec5f710..0155741 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1241,9 +1241,6 @@ class Analyzer(
      */
     private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): 
LogicalPlan = {
       plan transformExpressions {
-        case s @ ScalarSubquery(sub, conditions, exprId)
-            if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
-          failAnalysis(s"Scalar subquery must return only one column, but got 
${sub.output.size}")
         case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, exprId) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/45ea46b7/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 80e577e..26d2638 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper {
                 failAnalysis(s"Window specification $s is not valid because 
$m")
               case None => w
             }
+          case s @ ScalarSubquery(query, conditions, _)
+            // If no correlation, the output must be exactly one column
+            if (conditions.isEmpty && query.output.size != 1) =>
+              failAnalysis(
+                s"Scalar subquery must return only one column, but got 
${query.output.size}")
 
           case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty 
=>
-            // Make sure correlated scalar subqueries contain one row for 
every outer row by
-            // enforcing that they are aggregates which contain exactly one 
aggregate expressions.
-            // The analyzer has already checked that subquery contained only 
one output column, and
-            // added all the grouping expressions to the aggregate.
-            def checkAggregate(a: Aggregate): Unit = {
-              val aggregates = a.expressions.flatMap(_.collect {
+            def checkAggregate(agg: Aggregate): Unit = {
+              // Make sure correlated scalar subqueries contain one row for 
every outer row by
+              // enforcing that they are aggregates which contain exactly one 
aggregate expressions.
+              // The analyzer has already checked that subquery contained only 
one output column,
+              // and added all the grouping expressions to the aggregate.
+              val aggregates = agg.expressions.flatMap(_.collect {
                 case a: AggregateExpression => a
               })
               if (aggregates.isEmpty) {
                 failAnalysis("The output of a correlated scalar subquery must 
be aggregated")
               }
+
+              // SPARK-18504: block cases where GROUP BY columns
+              // are not part of the correlated columns
+              val groupByCols = 
ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
+              val predicateCols = 
ExpressionSet.apply(conditions.flatMap(_.references))
+              val invalidCols = groupByCols.diff(predicateCols)
+              // GROUP BY columns must be a subset of columns in the predicates
+              if (invalidCols.nonEmpty) {
+                failAnalysis(
+                  "a GROUP BY clause in a scalar correlated subquery " +
+                    "cannot contain non-correlated columns: " +
+                    invalidCols.mkString(","))
+              }
             }
 
             // Skip projects and subquery aliases added by the Analyzer and 
the SQLBuilder.

http://git-wip-us.apache.org/repos/asf/spark/blob/45ea46b7/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index c84a6f1..f1dd1c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
       Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, 
null) :: Nil)
   }
 
+  test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not 
permitted") {
+    withTempView("t") {
+      Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+
+      val errMsg = intercept[AnalysisException] {
+        sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by 
t2.c2) sum from t t1")
+      }
+      assert(errMsg.getMessage.contains(
+        "a GROUP BY clause in a scalar correlated subquery cannot contain 
non-correlated columns:"))
+    }
+  }
+
   test("non-aggregated correlated scalar subquery") {
     val msg1 = intercept[AnalysisException] {
       sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to