Repository: spark
Updated Branches:
  refs/heads/branch-2.1 0e60e4b88 -> 0e624e990


[SPARK-18504][SQL] Scalar subquery with extra group by columns returning 
incorrect result

## What changes were proposed in this pull request?

This PR blocks an incorrect-result scenario in correlated scalar subqueries where
the GROUP BY clause contains column(s) that are not part of the correlated predicate(s).

Example:
// Incorrect result
Seq(1).toDF("c1").createOrReplaceTempView("t1")
Seq((1,1),(1,2)).toDF("c1","c2").createOrReplaceTempView("t2")
sql("select (select sum(-1) from t2 where t1.c1=t2.c1 group by t2.c2) from t1").show

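// How can selecting a scalar subquery from a 1-row table return 2 rows?
// Because the subquery groups by t2.c2, it yields one row per c2 value (two rows
// for t1.c1 = 1), so the "scalar" subquery actually produces two rows.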

## How was this patch tested?
Ran sql/test and catalyst/test.
A new test case covering the reported problem is added to SubquerySuite.scala.

Author: Nattavut Sutyanyong <nsy....@gmail.com>

Closes #15936 from nsyca/scalarSubqueryIncorrect-1.

(cherry picked from commit 45ea46b7b397f023b4da878eb11e21b08d931115)
Signed-off-by: Herman van Hovell <hvanhov...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e624e99
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e624e99
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e624e99

Branch: refs/heads/branch-2.1
Commit: 0e624e990b3b426dba0a6149ad6340f85d214a58
Parents: 0e60e4b
Author: Nattavut Sutyanyong <nsy....@gmail.com>
Authored: Tue Nov 22 12:06:21 2016 -0800
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Tue Nov 22 12:06:32 2016 -0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  3 --
 .../sql/catalyst/analysis/CheckAnalysis.scala   | 30 ++++++++++++++++----
 .../org/apache/spark/sql/SubquerySuite.scala    | 12 ++++++++
 3 files changed, 36 insertions(+), 9 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e624e99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index b7e1675..2918e9d 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1182,9 +1182,6 @@ class Analyzer(
      */
     private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): 
LogicalPlan = {
       plan transformExpressions {
-        case s @ ScalarSubquery(sub, conditions, exprId)
-            if sub.resolved && conditions.isEmpty && sub.output.size != 1 =>
-          failAnalysis(s"Scalar subquery must return only one column, but got 
${sub.output.size}")
         case s @ ScalarSubquery(sub, _, exprId) if !sub.resolved =>
           resolveSubQuery(s, plans, 1)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, exprId) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/0e624e99/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 80e577e..26d2638 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -117,19 +117,37 @@ trait CheckAnalysis extends PredicateHelper {
                 failAnalysis(s"Window specification $s is not valid because 
$m")
               case None => w
             }
+          case s @ ScalarSubquery(query, conditions, _)
+            // If no correlation, the output must be exactly one column
+            if (conditions.isEmpty && query.output.size != 1) =>
+              failAnalysis(
+                s"Scalar subquery must return only one column, but got 
${query.output.size}")
 
           case s @ ScalarSubquery(query, conditions, _) if conditions.nonEmpty 
=>
-            // Make sure correlated scalar subqueries contain one row for 
every outer row by
-            // enforcing that they are aggregates which contain exactly one 
aggregate expressions.
-            // The analyzer has already checked that subquery contained only 
one output column, and
-            // added all the grouping expressions to the aggregate.
-            def checkAggregate(a: Aggregate): Unit = {
-              val aggregates = a.expressions.flatMap(_.collect {
+            def checkAggregate(agg: Aggregate): Unit = {
+              // Make sure correlated scalar subqueries contain one row for 
every outer row by
+              // enforcing that they are aggregates which contain exactly one 
aggregate expressions.
+              // The analyzer has already checked that subquery contained only 
one output column,
+              // and added all the grouping expressions to the aggregate.
+              val aggregates = agg.expressions.flatMap(_.collect {
                 case a: AggregateExpression => a
               })
               if (aggregates.isEmpty) {
                 failAnalysis("The output of a correlated scalar subquery must 
be aggregated")
               }
+
+              // SPARK-18504: block cases where GROUP BY columns
+              // are not part of the correlated columns
+              val groupByCols = 
ExpressionSet.apply(agg.groupingExpressions.flatMap(_.references))
+              val predicateCols = 
ExpressionSet.apply(conditions.flatMap(_.references))
+              val invalidCols = groupByCols.diff(predicateCols)
+              // GROUP BY columns must be a subset of columns in the predicates
+              if (invalidCols.nonEmpty) {
+                failAnalysis(
+                  "a GROUP BY clause in a scalar correlated subquery " +
+                    "cannot contain non-correlated columns: " +
+                    invalidCols.mkString(","))
+              }
             }
 
             // Skip projects and subquery aliases added by the Analyzer and 
the SQLBuilder.

http://git-wip-us.apache.org/repos/asf/spark/blob/0e624e99/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index c84a6f1..f1dd1c6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -483,6 +483,18 @@ class SubquerySuite extends QueryTest with 
SharedSQLContext {
       Row(1, null) :: Row(2, 6.0) :: Row(3, 2.0) :: Row(null, null) :: Row(6, 
null) :: Nil)
   }
 
+  test("SPARK-18504 extra GROUP BY column in correlated scalar subquery is not 
permitted") {
+    withTempView("t") {
+      Seq((1, 1), (1, 2)).toDF("c1", "c2").createOrReplaceTempView("t")
+
+      val errMsg = intercept[AnalysisException] {
+        sql("select (select sum(-1) from t t2 where t1.c2 = t2.c1 group by 
t2.c2) sum from t t1")
+      }
+      assert(errMsg.getMessage.contains(
+        "a GROUP BY clause in a scalar correlated subquery cannot contain 
non-correlated columns:"))
+    }
+  }
+
   test("non-aggregated correlated scalar subquery") {
     val msg1 = intercept[AnalysisException] {
       sql("select a, (select b from l l2 where l2.a = l1.a) sum_b from l l1")


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to