This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 20edfdd  [SPARK-35845][SQL] OuterReference resolution should reject 
ambiguous column names
20edfdd is described below

commit 20edfdd39a83c52813f91e4028f816d06a6be99e
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Jun 23 14:32:34 2021 +0800

    [SPARK-35845][SQL] OuterReference resolution should reject ambiguous column 
names
    
    ### What changes were proposed in this pull request?
    
    The current OuterReference resolution is incorrect when the outer plan has
    more than one child: it tries to resolve OuterReference from the output of
    each child, one by one, left to right, and stops at the first child that
    resolves anything.
    
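    This is incorrect in the case of a join, as the column name is ambiguous
    if both the left and right sides output a column with this name; the
    left-to-right lookup picks the left side's column instead of reporting the
    ambiguity.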
    
    This PR fixes this bug by resolving OuterReference with 
`outerPlan.resolveChildren`, instead of something like 
`outerPlan.children.foreach(_.resolve(...))`
    
    ### Why are the changes needed?
    
    bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
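    The problem only occurs in joins, and join conditions don't support
    correlated subqueries yet. So this PR only improves the error message.
    Before this PR, people see
    ```
    java.lang.UnsupportedOperationException
    Cannot generate code for expression: outer(t1a#291)
    ```
    After this PR, the query fails during analysis instead; for example, the
    new test case reports
    ```
    org.apache.spark.sql.AnalysisException
    cannot resolve 't1a' given input columns: [t2.t2a, t2.t2b, t2.t2c]; line 4 pos 44
    ```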
    
    ### How was this patch tested?
    
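    A new negative test case in
    `subquery/negative-cases/invalid-correlation.sql`: a join whose `ON`
    condition contains an `EXISTS` subquery referencing a column name (`t1a`)
    that both join sides output.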
    
    Closes #33004 from cloud-fan/outer-ref.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Gengliang Wang <gengli...@apache.org>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 35 +++++++++++-----------
 .../catalyst/optimizer/DecorrelateInnerQuery.scala | 10 ++-----
 .../spark/sql/catalyst/optimizer/subquery.scala    | 26 ++++++++--------
 .../optimizer/DecorrelateInnerQuerySuite.scala     |  6 ++--
 .../negative-cases/invalid-correlation.sql         |  9 ++++++
 .../negative-cases/invalid-correlation.sql.out     | 24 ++++++++++++++-
 6 files changed, 68 insertions(+), 42 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 555be01..ba680ba 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -2285,8 +2285,8 @@ class Analyzer(override val catalogManager: 
CatalogManager)
     }
 
     /**
-     * Resolve the correlated expressions in a subquery by using the an outer 
plans' references. All
-     * resolved outer references are wrapped in an [[OuterReference]]
+     * Resolve the correlated expressions in a subquery, as if the expressions 
live in the outer
+     * plan. All resolved outer references are wrapped in an [[OuterReference]]
      */
     private def resolveOuterReferences(plan: LogicalPlan, outer: LogicalPlan): 
LogicalPlan = {
       
plan.resolveOperatorsDownWithPruning(_.containsPattern(UNRESOLVED_ATTRIBUTE)) {
@@ -2295,7 +2295,7 @@ class Analyzer(override val catalogManager: 
CatalogManager)
             case u @ UnresolvedAttribute(nameParts) =>
               withPosition(u) {
                 try {
-                  outer.resolve(nameParts, resolver) match {
+                  outer.resolveChildren(nameParts, resolver) match {
                     case Some(outerAttr) => wrapOuterReference(outerAttr)
                     case None => u
                   }
@@ -2317,7 +2317,7 @@ class Analyzer(override val catalogManager: 
CatalogManager)
      */
     private def resolveSubQuery(
         e: SubqueryExpression,
-        plans: Seq[LogicalPlan])(
+        outer: LogicalPlan)(
         f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): 
SubqueryExpression = {
       // Step 1: Resolve the outer expressions.
       var previous: LogicalPlan = null
@@ -2328,10 +2328,8 @@ class Analyzer(override val catalogManager: 
CatalogManager)
         current = executeSameContext(current)
 
         // Use the outer references to resolve the subquery plan if it isn't 
resolved yet.
-        val i = plans.iterator
-        val afterResolve = current
-        while (!current.resolved && current.fastEquals(afterResolve) && 
i.hasNext) {
-          current = resolveOuterReferences(current, i.next())
+        if (!current.resolved) {
+          current = resolveOuterReferences(current, outer)
         }
       } while (!current.resolved && !current.fastEquals(previous))
 
@@ -2354,20 +2352,20 @@ class Analyzer(override val catalogManager: 
CatalogManager)
      * (2) Any aggregate expression(s) that reference outer attributes are 
pushed down to
      *     outer plan to get evaluated.
      */
-    private def resolveSubQueries(plan: LogicalPlan, plans: Seq[LogicalPlan]): 
LogicalPlan = {
+    private def resolveSubQueries(plan: LogicalPlan, outer: LogicalPlan): 
LogicalPlan = {
       
plan.transformAllExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION), 
ruleId) {
         case s @ ScalarSubquery(sub, _, exprId, _) if !sub.resolved =>
-          resolveSubQuery(s, plans)(ScalarSubquery(_, _, exprId))
+          resolveSubQuery(s, outer)(ScalarSubquery(_, _, exprId))
         case e @ Exists(sub, _, exprId, _) if !sub.resolved =>
-          resolveSubQuery(e, plans)(Exists(_, _, exprId))
+          resolveSubQuery(e, outer)(Exists(_, _, exprId))
         case InSubquery(values, l @ ListQuery(_, _, exprId, _, _))
             if values.forall(_.resolved) && !l.resolved =>
-          val expr = resolveSubQuery(l, plans)((plan, exprs) => {
+          val expr = resolveSubQuery(l, outer)((plan, exprs) => {
             ListQuery(plan, exprs, exprId, plan.output)
           })
           InSubquery(values, expr.asInstanceOf[ListQuery])
         case s @ LateralSubquery(sub, _, exprId, _) if !sub.resolved =>
-          resolveSubQuery(s, plans)(LateralSubquery(_, _, exprId))
+          resolveSubQuery(s, outer)(LateralSubquery(_, _, exprId))
       }
     }
 
@@ -2377,14 +2375,17 @@ class Analyzer(override val catalogManager: 
CatalogManager)
     def apply(plan: LogicalPlan): LogicalPlan = 
plan.resolveOperatorsUpWithPruning(
       _.containsPattern(PLAN_EXPRESSION), ruleId) {
       case j: LateralJoin if j.left.resolved =>
-        resolveSubQueries(j, j.children)
+        // We can't pass `LateralJoin` as the outer plan, as its right child 
is not resolved yet
+        // and we can't call `LateralJoin.resolveChildren` to resolve outer 
references. Here we
+        // create a fake Project node as the outer plan.
+        resolveSubQueries(j, Project(Nil, j.left))
       // Only a few unary nodes (Project/Filter/Aggregate) can contain 
subqueries.
       case q: UnaryNode if q.childrenResolved =>
-        resolveSubQueries(q, q.children)
+        resolveSubQueries(q, q)
       case j: Join if j.childrenResolved && j.duplicateResolved =>
-        resolveSubQueries(j, j.children)
+        resolveSubQueries(j, j)
       case s: SupportsSubquery if s.childrenResolved =>
-        resolveSubQueries(s, s.children)
+        resolveSubQueries(s, s)
     }
   }
 
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
index a2404f0..f0441e3 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala
@@ -258,13 +258,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
   def apply(
       innerPlan: LogicalPlan,
       outerPlan: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
-    apply(innerPlan, Seq(outerPlan))
-  }
-
-  def apply(
-      innerPlan: LogicalPlan,
-      outerPlans: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
-    val outputSet = AttributeSet(outerPlans.flatMap(_.outputSet))
+    val outputPlanInputAttrs = outerPlan.inputSet
 
     // The return type of the recursion.
     // The first parameter is a new logical plan with correlation eliminated.
@@ -486,7 +480,7 @@ object DecorrelateInnerQuery extends PredicateHelper {
       }
     }
     val (newChild, joinCond, _) = 
decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty)
-    val (plan, conditions) = deduplicate(newChild, joinCond, outputSet)
+    val (plan, conditions) = deduplicate(newChild, joinCond, 
outputPlanInputAttrs)
     (plan, stripOuterReferences(conditions))
   }
 }
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index ab9c827..7914d14 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -220,7 +220,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
     */
   private def pullOutCorrelatedPredicates(
       sub: LogicalPlan,
-      outer: Seq[LogicalPlan]): (LogicalPlan, Seq[Expression]) = {
+      outer: LogicalPlan): (LogicalPlan, Seq[Expression]) = {
     val predicateMap = scala.collection.mutable.Map.empty[LogicalPlan, 
Seq[Expression]]
 
     /** Determine which correlated predicate references are missing from this 
plan. */
@@ -272,10 +272,10 @@ object PullupCorrelatedPredicates extends 
Rule[LogicalPlan] with PredicateHelper
     // In case of a collision, change the subquery plan's output to use
     // different attribute by creating alias(s).
     val baseConditions = predicateMap.values.flatten.toSeq
-    val (newPlan, newCond) = if (outer.nonEmpty) {
-      val outputSet = outer.map(_.outputSet).reduce(_ ++ _)
+    val outerPlanInputAttrs = outer.inputSet
+    val (newPlan, newCond) = if (outerPlanInputAttrs.nonEmpty) {
       val (plan, deDuplicatedConditions) =
-        DecorrelateInnerQuery.deduplicate(transformed, baseConditions, 
outputSet)
+        DecorrelateInnerQuery.deduplicate(transformed, baseConditions, 
outerPlanInputAttrs)
       (plan, stripOuterReferences(deDuplicatedConditions))
     } else {
       (transformed, stripOuterReferences(baseConditions))
@@ -283,7 +283,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
     (newPlan, newCond)
   }
 
-  private def rewriteSubQueries(plan: LogicalPlan, outerPlans: 
Seq[LogicalPlan]): LogicalPlan = {
+  private def rewriteSubQueries(plan: LogicalPlan): LogicalPlan = {
     /**
      * This function is used as a aid to enforce idempotency of 
pullUpCorrelatedPredicate rule.
      * In the first call to rewriteSubqueries, all the outer references from 
the subplan are
@@ -296,7 +296,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
       if (newCond.isEmpty) oldCond else newCond
     }
 
-    def decorrelate(sub: LogicalPlan, outer: Seq[LogicalPlan]): (LogicalPlan, 
Seq[Expression]) = {
+    def decorrelate(sub: LogicalPlan, outer: LogicalPlan): (LogicalPlan, 
Seq[Expression]) = {
       if (SQLConf.get.decorrelateInnerQueryEnabled) {
         DecorrelateInnerQuery(sub, outer)
       } else {
@@ -306,16 +306,16 @@ object PullupCorrelatedPredicates extends 
Rule[LogicalPlan] with PredicateHelper
 
     plan.transformExpressionsWithPruning(_.containsPattern(PLAN_EXPRESSION)) {
       case ScalarSubquery(sub, children, exprId, conditions) if 
children.nonEmpty =>
-        val (newPlan, newCond) = decorrelate(sub, outerPlans)
+        val (newPlan, newCond) = decorrelate(sub, plan)
         ScalarSubquery(newPlan, children, exprId, getJoinCondition(newCond, 
conditions))
       case Exists(sub, children, exprId, conditions) if children.nonEmpty =>
-        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         Exists(newPlan, children, exprId, getJoinCondition(newCond, 
conditions))
       case ListQuery(sub, children, exprId, childOutputs, conditions) if 
children.nonEmpty =>
-        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, outerPlans)
+        val (newPlan, newCond) = pullOutCorrelatedPredicates(sub, plan)
         ListQuery(newPlan, children, exprId, childOutputs, 
getJoinCondition(newCond, conditions))
       case LateralSubquery(sub, children, exprId, conditions) if 
children.nonEmpty =>
-        val (newPlan, newCond) = decorrelate(sub, outerPlans)
+        val (newPlan, newCond) = decorrelate(sub, plan)
         LateralSubquery(newPlan, children, exprId, getJoinCondition(newCond, 
conditions))
     }
   }
@@ -326,7 +326,7 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
     _.containsPattern(PLAN_EXPRESSION)) {
     case j: LateralJoin =>
-      val newPlan = rewriteSubQueries(j, j.children)
+      val newPlan = rewriteSubQueries(j)
       // Since a lateral join's output depends on its left child output and 
its lateral subquery's
       // plan output, we need to trim the domain attributes added to the 
subquery's plan output
       // to preserve the original output of the join.
@@ -337,9 +337,9 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] 
with PredicateHelper
       }
     // Only a few unary nodes (Project/Filter/Aggregate) can contain 
subqueries.
     case q: UnaryNode =>
-      rewriteSubQueries(q, q.children)
+      rewriteSubQueries(q)
     case s: SupportsSubquery =>
-      rewriteSubQueries(s, s.children)
+      rewriteSubQueries(s)
   }
 }
 
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
index 93b2703..92995c2 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala
@@ -44,7 +44,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
       outerPlan: LogicalPlan,
       correctAnswer: LogicalPlan,
       conditions: Seq[Expression]): Unit = {
-    val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
+    val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, 
outerPlan.select())
     assert(!hasOuterReferences(outputPlan))
     comparePlans(outputPlan, correctAnswer)
     assert(joinCond.length == conditions.length)
@@ -90,7 +90,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
       Project(Seq(a),
         Filter(OuterReference(a) === a,
           testRelation))
-    val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, outerPlan)
+    val (outputPlan, joinCond) = DecorrelateInnerQuery(innerPlan, 
outerPlan.select())
     val a1 = outputPlan.output.head
     val correctAnswer =
       Project(Seq(Alias(a, a1.name)(a1.exprId)),
@@ -197,7 +197,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
         Inner,
         Some(OuterReference(x) === a),
         JoinHint.NONE)
-    val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, 
outerPlan) }
+    val error = intercept[AssertionError] { DecorrelateInnerQuery(innerPlan, 
outerPlan.select()) }
     assert(error.getMessage.contains("Correlated column is not allowed in 
join"))
   }
 
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql
 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql
index 109ffa7..1260fb7 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/subquery/negative-cases/invalid-correlation.sql
@@ -71,3 +71,12 @@ WHERE  t1a IN (SELECT t2a
                WHERE  EXISTS (SELECT min(t2a) 
                               FROM   t3));
 
+CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
+  (1, 2, 3)
+AS t1(t1a, t1b, t1c);
+
+-- invalid because column name `t1a` is ambiguous in the subquery.
+SELECT t1.t1a
+FROM   t1
+JOIN   t1_copy
+ON     EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)
diff --git 
a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
index ea01d76..8734511 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/subquery/negative-cases/invalid-correlation.sql.out
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 8
+-- Number of queries: 10
 
 
 -- !query
@@ -116,3 +116,25 @@ Aggregate [min(outer(t2a#x)) AS min(outer(t2.t2a))#x]
          +- Project [t3a#x, t3b#x, t3c#x]
             +- SubqueryAlias t3
                +- LocalRelation [t3a#x, t3b#x, t3c#x]
+
+
+-- !query
+CREATE TEMPORARY VIEW t1_copy AS SELECT * FROM VALUES
+  (1, 2, 3)
+AS t1(t1a, t1b, t1c)
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT t1.t1a
+FROM   t1
+JOIN   t1_copy
+ON     EXISTS (SELECT 1 FROM t2 WHERE t2a > t1a)
+-- !query schema
+struct<>
+-- !query output
+org.apache.spark.sql.AnalysisException
+cannot resolve 't1a' given input columns: [t2.t2a, t2.t2b, t2.t2c]; line 4 pos 
44

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to