This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 2a10c8d93aa [SPARK-45069][SQL] SQL variable should always be resolved after outer reference
2a10c8d93aa is described below

commit 2a10c8d93aa9033842471e4f676fddb3b3f90940
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Mon Sep 11 22:57:47 2023 +0800

    [SPARK-45069][SQL] SQL variable should always be resolved after outer reference
    
    ### What changes were proposed in this pull request?
    
    This is a bug fix for the recently added SQL variable feature. Column resolution
    is designed to fall back to SQL variables only as the last resort, after outer
    references, but for columns in Aggregate we could end up resolving a column to a
    SQL variable before trying the outer reference.
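    
    For illustration, the new golden-file test added by this patch hits exactly this
    case: inside the aggregate subquery, `title` must bind to the outer column
    `t(title)`, not to the session variable of the same name. The fix funnels both
    last-resort lookups through a single helper, `resolveColsLastResort`, which tries
    outer references first and SQL variables last.
    
    ```sql
    SET VARIABLE title = 'Test variable in aggregate';
    -- `title` in the WHERE clause must resolve to the outer column t(title).
    SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title);
    ```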
    
    ### Why are the changes needed?
    
    bug fix
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, the query result could be wrong before this fix: a column could silently
    bind to a SQL variable instead of the intended outer reference.
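    
    For example, with the session variable `var1 = 1` declared, the following query
    from the updated golden files used to bind `var1` in the lateral subquery to the
    variable instead of the outer column `T(var1)`:
    
    ```sql
    SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1);
    -- before this fix: 1 (var1 bound to variablereference system.session.var1)
    -- after this fix:  2 (var1 bound to outer(var1#x), the lateral column)
    ```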
    
    ### How was this patch tested?
    
    new tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #42803 from cloud-fan/meta-col.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/Analyzer.scala     | 50 +++++++++++-----------
 .../catalyst/analysis/ColumnResolutionHelper.scala | 26 ++++++++---
 .../analysis/ResolveReferencesInAggregate.scala    | 24 +++++------
 .../analysis/ResolveReferencesInSort.scala         | 13 +++---
 .../analyzer-results/sql-session-variables.sql.out | 25 +++++++++--
 .../sql-tests/inputs/sql-session-variables.sql     |  3 ++
 .../results/sql-session-variables.sql.out          | 19 +++++++-
 7 files changed, 105 insertions(+), 55 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index a8c99075cdb..da983ff0c7c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -683,7 +683,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       //       of the analysis phase.
       val colResolved = h.mapExpressions { e =>
         resolveExpressionByPlanOutput(
-          resolveColWithAgg(e, aggForResolving), aggForResolving, allowOuter = true)
+          resolveColWithAgg(e, aggForResolving), aggForResolving, includeLastResort = true)
       }
       val cond = if (SubqueryExpression.hasSubquery(colResolved.havingCondition)) {
         val fake = Project(Alias(colResolved.havingCondition, "fake")() :: Nil, aggregate.child)
@@ -1450,6 +1450,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
    *    e.g. `SELECT col, current_date FROM t`.
    * 4. Resolves the columns to outer references with the outer plan if we are resolving subquery
    *    expressions.
+   * 5. Resolves the columns to SQL variables.
    *
    * Some plan nodes have special column reference resolution logic, please read these sub-rules for
    * details:
@@ -1568,7 +1569,8 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case g @ Generate(generator, _, _, _, _, _) if generator.resolved => g
 
       case g @ Generate(generator, join, outer, qualifier, output, child) =>
-        val newG = resolveExpressionByPlanOutput(generator, child, throws = true, allowOuter = true)
+        val newG = resolveExpressionByPlanOutput(
+          generator, child, throws = true, includeLastResort = true)
         if (newG.fastEquals(generator)) {
           g
         } else {
@@ -1584,7 +1586,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           case plan => plan
         }
         val resolvedOrder = mg.dataOrder
-            .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder])
+          .map(resolveExpressionByPlanOutput(_, planForResolve).asInstanceOf[SortOrder])
         mg.copy(dataOrder = resolvedOrder)
 
       // Left and right sort expression have to be resolved against the respective child plan only
@@ -1614,13 +1616,11 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
       // Special case for Project as it supports lateral column alias.
       case p: Project =>
-        val resolvedNoOuter = p.projectList
-          .map(resolveExpressionByPlanChildren(_, p, allowOuter = false))
+        val resolvedBasic = p.projectList.map(resolveExpressionByPlanChildren(_, p))
         // Lateral column alias has higher priority than outer reference.
-        val resolvedWithLCA = resolveLateralColumnAlias(resolvedNoOuter)
-        val resolvedWithOuter = resolvedWithLCA.map(resolveOuterRef)
-        val resolvedWithVariables = resolvedWithOuter.map(p => resolveVariables(p))
-        p.copy(projectList = resolvedWithVariables.map(_.asInstanceOf[NamedExpression]))
+        val resolvedWithLCA = resolveLateralColumnAlias(resolvedBasic)
+        val resolvedFinal = resolvedWithLCA.map(resolveColsLastResort)
+        p.copy(projectList = resolvedFinal.map(_.asInstanceOf[NamedExpression]))
 
       case o: OverwriteByExpression if o.table.resolved =>
         // The delete condition of `OverwriteByExpression` will be passed to the table
@@ -1714,7 +1714,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
           // Columns in HAVING should be resolved with `agg.child.output` first, to follow the SQL
           // standard. See more details in SPARK-31519.
           val resolvedWithAgg = resolveColWithAgg(e, agg)
-          resolveExpressionByPlanChildren(resolvedWithAgg, u, allowOuter = true)
+          resolveExpressionByPlanChildren(resolvedWithAgg, u, includeLastResort = true)
         }
 
       // RepartitionByExpression can host missing attributes that are from a descendant node.
@@ -1724,32 +1724,32 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       // node, and project them way at the end via an extra Project.
       case r @ RepartitionByExpression(partitionExprs, child, _, _)
         if !r.resolved || r.missingInput.nonEmpty =>
-        val resolvedNoOuter = partitionExprs.map(resolveExpressionByPlanChildren(_, r))
-        val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedNoOuter, child)
-        // Outer reference has lower priority than this. See the doc of `ResolveReferences`.
-        val resolvedWithOuter = newPartitionExprs.map(resolveOuterRef)
-        val finalPartitionExprs = resolvedWithOuter.map(e => resolveVariables(e))
+        val resolvedBasic = partitionExprs.map(resolveExpressionByPlanChildren(_, r))
+        val (newPartitionExprs, newChild) = resolveExprsAndAddMissingAttrs(resolvedBasic, child)
+        // Missing columns should be resolved right after basic column resolution.
+        // See the doc of `ResolveReferences`.
+        val resolvedFinal = newPartitionExprs.map(resolveColsLastResort)
         if (child.output == newChild.output) {
-          r.copy(finalPartitionExprs, newChild)
+          r.copy(resolvedFinal, newChild)
         } else {
-          Project(child.output, r.copy(finalPartitionExprs, newChild))
+          Project(child.output, r.copy(resolvedFinal, newChild))
         }
 
       // Filter can host both grouping expressions/aggregate functions and missing attributes.
       // The grouping expressions/aggregate functions resolution takes precedence over missing
       // attributes. See the classdoc of `ResolveReferences` for details.
       case f @ Filter(cond, child) if !cond.resolved || f.missingInput.nonEmpty =>
-        val resolvedNoOuter = resolveExpressionByPlanChildren(cond, f)
-        val resolvedWithAgg = resolveColWithAgg(resolvedNoOuter, child)
+        val resolvedBasic = resolveExpressionByPlanChildren(cond, f)
+        val resolvedWithAgg = resolveColWithAgg(resolvedBasic, child)
         val (newCond, newChild) = resolveExprsAndAddMissingAttrs(Seq(resolvedWithAgg), child)
-        // Outer reference has lowermost priority. See the doc of `ResolveReferences`.
-        val resolvedWithOuter = resolveOuterRef(newCond.head)
-        val finalCond = resolveVariables(resolvedWithOuter)
+        // Missing columns should be resolved right after basic column resolution.
+        // See the doc of `ResolveReferences`.
+        val resolvedFinal = resolveColsLastResort(newCond.head)
         if (child.output == newChild.output) {
-          f.copy(condition = finalCond)
+          f.copy(condition = resolvedFinal)
         } else {
           // Add missing attributes and then project them away.
-          val newFilter = Filter(finalCond, newChild)
+          val newFilter = Filter(resolvedFinal, newChild)
           Project(child.output, newFilter)
         }
 
@@ -1758,7 +1758,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
 
       case q: LogicalPlan =>
         logTrace(s"Attempting to resolve ${q.simpleString(conf.maxToStringFields)}")
-        q.mapExpressions(resolveExpressionByPlanChildren(_, q, allowOuter = true))
+        q.mapExpressions(resolveExpressionByPlanChildren(_, q, includeLastResort = true))
     }
 
     private object MergeResolvePolicy extends Enumeration {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
index d7b1f99f1ed..54a9c6ca018 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ColumnResolutionHelper.scala
@@ -135,7 +135,7 @@ trait ColumnResolutionHelper extends Logging {
       resolveColumnByName: Seq[String] => Option[Expression],
       getAttrCandidates: () => Seq[Attribute],
       throws: Boolean,
-      allowOuter: Boolean): Expression = {
+      includeLastResort: Boolean): Expression = {
     def innerResolve(e: Expression, isTopLevel: Boolean): Expression = withOrigin(e.origin) {
       if (e.resolved) return e
       val resolved = e match {
@@ -196,8 +196,11 @@ trait ColumnResolutionHelper extends Logging {
 
     try {
       val resolved = innerResolve(expr, isTopLevel = true)
-      val withOuterResolved = if (allowOuter) resolveOuterRef(resolved) else resolved
-      resolveVariables(withOuterResolved)
+      if (includeLastResort) {
+        resolveColsLastResort(resolved)
+      } else {
+        resolved
+      }
     } catch {
       case ae: AnalysisException if !throws =>
         logDebug(ae.getMessage)
@@ -421,7 +424,7 @@ trait ColumnResolutionHelper extends Logging {
       expr: Expression,
       plan: LogicalPlan,
       throws: Boolean = false,
-      allowOuter: Boolean = false): Expression = {
+      includeLastResort: Boolean = false): Expression = {
     resolveExpression(
       expr,
       resolveColumnByName = nameParts => {
@@ -429,7 +432,7 @@ trait ColumnResolutionHelper extends Logging {
       },
       getAttrCandidates = () => plan.output,
       throws = throws,
-      allowOuter = allowOuter)
+      includeLastResort = includeLastResort)
   }
 
   /**
@@ -443,7 +446,7 @@ trait ColumnResolutionHelper extends Logging {
   def resolveExpressionByPlanChildren(
       e: Expression,
       q: LogicalPlan,
-      allowOuter: Boolean = false): Expression = {
+      includeLastResort: Boolean = false): Expression = {
     val newE = if (e.exists(_.getTagValue(LogicalPlan.PLAN_ID_TAG).nonEmpty)) {
      // If the TreeNodeTag 'LogicalPlan.PLAN_ID_TAG' is attached, it means that the plan and
      // expression are from Spark Connect, and need to be resolved in this way:
@@ -467,7 +470,16 @@ trait ColumnResolutionHelper extends Logging {
         q.children.head.output
       },
       throws = true,
-      allowOuter = allowOuter)
+      includeLastResort = includeLastResort)
+  }
+
+  /**
+   * The last resort to resolve columns. Currently it does two things:
+   *  - Try to resolve column names as outer references
+   *  - Try to resolve column names as SQL variable
+   */
+  protected def resolveColsLastResort(e: Expression): Expression = {
+    resolveVariables(resolveOuterRef(e))
   }
 
   def resolveExprInAssignment(expr: Expression, hostPlan: LogicalPlan): Expression = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
index 6bc1949a4e0..4f5a11835c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInAggregate.scala
@@ -59,23 +59,23 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
       case _ => a
     }
 
-    val resolvedGroupExprsNoOuter = a.groupingExpressions
-      .map(resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false))
-    val resolvedAggExprsNoOuter = a.aggregateExpressions.map(
-      resolveExpressionByPlanChildren(_, planForResolve, allowOuter = false))
-    val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsNoOuter)
-    val resolvedAggExprsWithOuter = resolvedAggExprsWithLCA.map(resolveOuterRef)
+    val resolvedGroupExprsBasic = a.groupingExpressions
+      .map(resolveExpressionByPlanChildren(_, planForResolve))
+    val resolvedAggExprsBasic = a.aggregateExpressions.map(
+      resolveExpressionByPlanChildren(_, planForResolve))
+    val resolvedAggExprsWithLCA = resolveLateralColumnAlias(resolvedAggExprsBasic)
+    val resolvedAggExprsFinal = resolvedAggExprsWithLCA.map(resolveColsLastResort)
       .map(_.asInstanceOf[NamedExpression])
     // `groupingExpressions` may rely on `aggregateExpressions`, due to features like GROUP BY alias
     // and GROUP BY ALL. We only do basic resolution for `groupingExpressions`, and will further
     // resolve it after `aggregateExpressions` are all resolved. Note: the basic resolution is
     // needed as `aggregateExpressions` may rely on `groupingExpressions` as well, for the session
     // window feature. See the rule `SessionWindowing` for more details.
-    val resolvedGroupExprs = if (resolvedAggExprsWithOuter.forall(_.resolved)) {
+    val resolvedGroupExprs = if (resolvedAggExprsFinal.forall(_.resolved)) {
       val resolved = resolveGroupByAll(
-        resolvedAggExprsWithOuter,
-        resolveGroupByAlias(resolvedAggExprsWithOuter, resolvedGroupExprsNoOuter)
-      ).map(resolveOuterRef)
+        resolvedAggExprsFinal,
+        resolveGroupByAlias(resolvedAggExprsFinal, resolvedGroupExprsBasic)
+      ).map(resolveColsLastResort)
       // TODO: currently we don't support LCA in `groupingExpressions` yet.
       if (resolved.exists(_.containsPattern(LATERAL_COLUMN_ALIAS_REFERENCE))) {
         throw new AnalysisException(
@@ -89,7 +89,7 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
      // alias/ALL in the next iteration. If aggregate expressions end up as unresolved, we don't
      // need to resolve grouping expressions at all, as `CheckAnalysis` will report error for
       // aggregate expressions first.
-      resolvedGroupExprsNoOuter
+      resolvedGroupExprsBasic
     }
     a.copy(
      // The aliases in grouping expressions are useless and will be removed at the end of analysis
@@ -105,7 +105,7 @@ class ResolveReferencesInAggregate(val catalogManager: CatalogManager) extends S
         //       GROUP BY will be removed eventually, by following iterations.
         if (e.resolved) trimAliases(e) else e
       },
-      aggregateExpressions = resolvedAggExprsWithOuter)
+      aggregateExpressions = resolvedAggExprsFinal)
   }
 
   private def resolveGroupByAlias(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala
index e4e9188662a..02583ebb8f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveReferencesInSort.scala
@@ -50,19 +50,18 @@ class ResolveReferencesInSort(val catalogManager: CatalogManager)
   extends SQLConfHelper with ColumnResolutionHelper {
 
   def apply(s: Sort): LogicalPlan = {
-    val resolvedNoOuter = s.order.map(resolveExpressionByPlanOutput(_, s.child))
-    val resolvedWithAgg = resolvedNoOuter.map(resolveColWithAgg(_, s.child))
+    val resolvedBasic = s.order.map(resolveExpressionByPlanOutput(_, s.child))
+    val resolvedWithAgg = resolvedBasic.map(resolveColWithAgg(_, s.child))
     val (missingAttrResolved, newChild) = resolveExprsAndAddMissingAttrs(resolvedWithAgg, s.child)
     val orderByAllResolved = resolveOrderByAll(
       s.global, newChild, missingAttrResolved.map(_.asInstanceOf[SortOrder]))
-    val resolvedWithOuter = orderByAllResolved.map(e => resolveOuterRef(e))
-    val finalOrdering = resolvedWithOuter.map(e => resolveVariables(e)
-      .asInstanceOf[SortOrder])
+    val resolvedFinal = orderByAllResolved
+      .map(e => resolveColsLastResort(e).asInstanceOf[SortOrder])
     if (s.child.output == newChild.output) {
-      s.copy(order = finalOrdering)
+      s.copy(order = resolvedFinal)
     } else {
       // Add missing attributes and then project them away.
-      val newSort = s.copy(order = finalOrdering, child = newChild)
+      val newSort = s.copy(order = resolvedFinal, child = newChild)
       Project(s.child.output, newSort)
     }
   }
diff --git a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
index 45bfbf69db3..ff645867415 100644
--- a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
+++ b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-session-variables.sql.out
@@ -485,9 +485,28 @@ org.apache.spark.sql.AnalysisException
 
 
 -- !query
-SET VARIABLE title = 'Test qualifiers - fail'
+SET VARIABLE title = 'Test variable in aggregate'
 -- !query analysis
 SetVariable [variablereference(system.session.title='Test qualifiers - success')]
++- Project [Test variable in aggregate AS title#x]
+   +- OneRowRelation
+
+
+-- !query
+SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title)
+-- !query analysis
+Project [scalar-subquery#x [title#x] AS scalarsubquery(title)#xL]
+:  +- Aggregate [max(id#xL) AS max(id)#xL]
+:     +- Filter (id#xL < cast(outer(title#x) as bigint))
+:        +- Range (0, 10, step=1, splits=None)
++- SubqueryAlias t
+   +- LocalRelation [title#x]
+
+
+-- !query
+SET VARIABLE title = 'Test qualifiers - fail'
+-- !query analysis
+SetVariable [variablereference(system.session.title='Test variable in aggregate')]
 +- Project [Test qualifiers - fail AS title#x]
    +- OneRowRelation
 
@@ -1881,10 +1900,10 @@ Project [var1#x AS 2#x]
 SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1)
 -- !query analysis
 Project [c1#x AS 2#x]
-+- LateralJoin lateral-subquery#x [], Inner
++- LateralJoin lateral-subquery#x [var1#x], Inner
    :  +- SubqueryAlias TT
    :     +- Project [var1#x AS c1#x]
-   :        +- Project [variablereference(system.session.var1=1) AS var1#x]
+   :        +- Project [outer(var1#x)]
    :           +- OneRowRelation
    +- SubqueryAlias T
       +- LocalRelation [var1#x]
diff --git a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
index 4992453603c..53149a5e37b 100644
--- a/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
+++ b/sql/core/src/test/resources/sql-tests/inputs/sql-session-variables.sql
@@ -80,6 +80,9 @@ DECLARE OR REPLACE VARIABLE var1 INT;
 DROP TEMPORARY VARIABLE sysTem.sesSion.vAr1;
 DROP TEMPORARY VARIABLE var1;
 
+SET VARIABLE title = 'Test variable in aggregate';
+SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title);
+
 SET VARIABLE title = 'Test qualifiers - fail';
 DECLARE OR REPLACE VARIABLE builtin.var1 INT;
 DECLARE OR REPLACE VARIABLE system.sesion.var1 INT;
diff --git a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
index b3146e645c5..0297a8a11a9 100644
--- a/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/sql-session-variables.sql.out
@@ -544,6 +544,23 @@ org.apache.spark.sql.AnalysisException
 }
 
 
+-- !query
+SET VARIABLE title = 'Test variable in aggregate'
+-- !query schema
+struct<>
+-- !query output
+
+
+
+-- !query
+SELECT (SELECT MAX(id) FROM RANGE(10) WHERE id < title) FROM VALUES 1, 2 AS t(title)
+-- !query schema
+struct<scalarsubquery(title):bigint>
+-- !query output
+0
+1
+
+
 -- !query
 SET VARIABLE title = 'Test qualifiers - fail'
 -- !query schema
@@ -2058,7 +2075,7 @@ SELECT c1 AS `2` FROM VALUES(2) AS T(var1), LATERAL(SELECT var1) AS TT(c1)
 -- !query schema
 struct<2:int>
 -- !query output
-1
+2
 
 
 -- !query

