This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3feddec3d9c [SPARK-40862][SQL] Support non-aggregated subqueries in RewriteCorrelatedScalarSubquery
3feddec3d9c is described below

commit 3feddec3d9c0b2bd44610b20c9448445a6d761d3
Author: allisonwang-db <allison.w...@databricks.com>
AuthorDate: Fri Oct 28 12:25:28 2022 +0800

    [SPARK-40862][SQL] Support non-aggregated subqueries in RewriteCorrelatedScalarSubquery
    
    ### What changes were proposed in this pull request?
    This PR updates `splitSubquery` in `RewriteCorrelatedScalarSubquery` to support non-aggregated one-row subqueries.
    
    In CheckAnalysis, we allow three types of correlated scalar subquery patterns:
    1. SubqueryAlias/Project + Aggregate
    2. SubqueryAlias/Project + Filter + Aggregate
    3. SubqueryAlias/Project + LogicalPlan (maxRows <= 1)
    
    
https://github.com/apache/spark/blob/748fa2792e488a6b923b32e2898d9bb6e16fb4ca/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala#L851-L856
    
    We should support the third case in `splitSubquery` to avoid `Unexpected operator` exceptions.
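
    For example, a minimal sketch of the third pattern (the inline table and
    column names here are hypothetical, chosen only to produce a
    non-aggregated one-row inner plan):

        // The scalar subquery's plan is a Project over a one-row relation
        // (maxRows = 1), with no Aggregate node for splitSubquery to find.
        sql("SELECT (SELECT c || suffix FROM (SELECT '!' AS suffix)) FROM VALUES ('a') AS t(c)")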
    
    ### Why are the changes needed?
    To fix a bug where the correlated subquery rewrite threw an `Unexpected operator` exception for valid non-aggregated one-row subqueries.
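
    For background, the COUNT-bug corrections only matter when the subquery
    contains an Aggregate whose result on zero input rows is non-null. A
    contrasting sketch (tables t1/t2 are hypothetical) of a query that does
    need the correction:

        // COUNT(*) evaluates to 0 on zero matching rows, so a plain left
        // outer join would wrongly produce NULL for outer rows with no match.
        sql("SELECT (SELECT COUNT(*) FROM t2 WHERE t2.a = t1.a) FROM t1")

    A non-aggregated one-row subquery has no zero-tuple aggregate value, so it
    can safely take the new `planWithoutCountBug` path.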
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    New unit tests.
    
    Closes #38336 from allisonwang-db/spark-40862-split-subquery.
    
    Authored-by: allisonwang-db <allison.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/optimizer/subquery.scala    | 142 +++++++++++----------
 .../scala/org/apache/spark/sql/SubquerySuite.scala |  17 +++
 2 files changed, 95 insertions(+), 64 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
index 6665d885554..3c995573d53 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala
@@ -509,19 +509,21 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
   /**
    * Split the plan for a scalar subquery into the parts above the innermost query block
    * (first part of returned value), the HAVING clause of the innermost query block
-   * (optional second part) and the parts below the HAVING CLAUSE (third part).
+   * (optional second part) and the Aggregate below the HAVING CLAUSE (optional third part).
+   * When the third part is empty, it means the subquery is a non-aggregated single-row subquery.
    */
-  private def splitSubquery(plan: LogicalPlan) : (Seq[LogicalPlan], Option[Filter], Aggregate) = {
+  private def splitSubquery(
+      plan: LogicalPlan): (Seq[LogicalPlan], Option[Filter], Option[Aggregate]) = {
     val topPart = ArrayBuffer.empty[LogicalPlan]
     var bottomPart: LogicalPlan = plan
     while (true) {
       bottomPart match {
         case havingPart @ Filter(_, aggPart: Aggregate) =>
-          return (topPart.toSeq, Option(havingPart), aggPart)
+          return (topPart.toSeq, Option(havingPart), Some(aggPart))
 
         case aggPart: Aggregate =>
           // No HAVING clause
-          return (topPart.toSeq, None, aggPart)
+          return (topPart.toSeq, None, Some(aggPart))
 
         case p @ Project(_, child) =>
           topPart += p
@@ -531,6 +533,10 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
           topPart += s
           bottomPart = child
 
+        case p: LogicalPlan if p.maxRows.exists(_ <= 1) =>
+          // Non-aggregated one row subquery.
+          return (topPart.toSeq, None, None)
+
         case Filter(_, op) =>
           throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op, " below filter")
 
@@ -561,72 +567,80 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] with AliasHelpe
         val origOutput = query.output.head
 
         val resultWithZeroTups = evalSubqueryOnZeroTups(query)
+        lazy val planWithoutCountBug = Project(
+          currentChild.output :+ origOutput,
+          Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+
         if (resultWithZeroTups.isEmpty) {
           // CASE 1: Subquery guaranteed not to have the COUNT bug
-          Project(
-            currentChild.output :+ origOutput,
-            Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+          planWithoutCountBug
         } else {
-          // Subquery might have the COUNT bug. Add appropriate corrections.
           val (topPart, havingNode, aggNode) = splitSubquery(query)
-
-          // The next two cases add a leading column to the outer join input to make it
-          // possible to distinguish between the case when no tuples join and the case
-          // when the tuple that joins contains null values.
-          // The leading column always has the value TRUE.
-          val alwaysTrueExprId = NamedExpression.newExprId
-          val alwaysTrueExpr = Alias(Literal.TrueLiteral,
-            ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
-          val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
-            BooleanType)(exprId = alwaysTrueExprId)
-
-          val aggValRef = query.output.head
-
-          if (havingNode.isEmpty) {
-            // CASE 2: Subquery with no HAVING clause
-            val subqueryResultExpr =
-              Alias(If(IsNull(alwaysTrueRef),
-                resultWithZeroTups.get,
-                aggValRef), origOutput.name)()
-            subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
-            Project(
-              currentChild.output :+ subqueryResultExpr,
-              Join(currentChild,
-                Project(query.output :+ alwaysTrueExpr, query),
-                LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
-
+          if (aggNode.isEmpty) {
+            // SPARK-40862: When the aggregate node is empty, it means the subquery produces
+            // at most one row and it is not subject to the COUNT bug.
+            planWithoutCountBug
           } else {
-            // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
-            // Need to modify any operators below the join to pass through all columns
-            // referenced in the HAVING clause.
-            var subqueryRoot: UnaryNode = aggNode
-            val havingInputs: Seq[NamedExpression] = aggNode.output
-
-            topPart.reverse.foreach {
-              case Project(projList, _) =>
-                subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
-              case s @ SubqueryAlias(alias, _) =>
-                subqueryRoot = SubqueryAlias(alias, subqueryRoot)
-              case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op)
+            // Subquery might have the COUNT bug. Add appropriate corrections.
+            val aggregate = aggNode.get
+
+            // The next two cases add a leading column to the outer join input to make it
+            // possible to distinguish between the case when no tuples join and the case
+            // when the tuple that joins contains null values.
+            // The leading column always has the value TRUE.
+            val alwaysTrueExprId = NamedExpression.newExprId
+            val alwaysTrueExpr = Alias(Literal.TrueLiteral,
+              ALWAYS_TRUE_COLNAME)(exprId = alwaysTrueExprId)
+            val alwaysTrueRef = AttributeReference(ALWAYS_TRUE_COLNAME,
+              BooleanType)(exprId = alwaysTrueExprId)
+
+            val aggValRef = query.output.head
+
+            if (havingNode.isEmpty) {
+              // CASE 2: Subquery with no HAVING clause
+              val subqueryResultExpr =
+                Alias(If(IsNull(alwaysTrueRef),
+                  resultWithZeroTups.get,
+                  aggValRef), origOutput.name)()
+              subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
+              Project(
+                currentChild.output :+ subqueryResultExpr,
+                Join(currentChild,
+                  Project(query.output :+ alwaysTrueExpr, query),
+                  LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
+
+            } else {
+              // CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
+              // Need to modify any operators below the join to pass through all columns
+              // referenced in the HAVING clause.
+              var subqueryRoot: UnaryNode = aggregate
+              val havingInputs: Seq[NamedExpression] = aggregate.output
+
+              topPart.reverse.foreach {
+                case Project(projList, _) =>
+                  subqueryRoot = Project(projList ++ havingInputs, subqueryRoot)
+                case s @ SubqueryAlias(alias, _) =>
+                  subqueryRoot = SubqueryAlias(alias, subqueryRoot)
+                case op => throw QueryExecutionErrors.unexpectedOperatorInCorrelatedSubquery(op)
+              }
+
+              // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
+              //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
+              //      ELSE (aggregate value) END AS (original column name)
+              val caseExpr = Alias(CaseWhen(Seq(
+                (IsNull(alwaysTrueRef), resultWithZeroTups.get),
+                (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
+                aggValRef),
+                origOutput.name)()
+
+              subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
+
+              Project(
+                currentChild.output :+ caseExpr,
+                Join(currentChild,
+                  Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
+                  LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
             }
-
-            // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
-            //      WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
-            //      ELSE (aggregate value) END AS (original column name)
-            val caseExpr = Alias(CaseWhen(Seq(
-              (IsNull(alwaysTrueRef), resultWithZeroTups.get),
-              (Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
-              aggValRef),
-              origOutput.name)()
-
-            subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
-
-            Project(
-              currentChild.output :+ caseExpr,
-              Join(currentChild,
-                Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
-                LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
-
           }
         }
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
index 4b586356367..7b67648d475 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala
@@ -2491,4 +2491,21 @@ class SubquerySuite extends QueryTest
         Row("a"))
     }
   }
+
+  test("SPARK-40862: correlated one-row subquery with non-deterministic 
expressions") {
+    import org.apache.spark.sql.functions.udf
+    withTempView("t1") {
+      sql("CREATE TEMP VIEW t1 AS SELECT ARRAY('a', 'b') a")
+      val func = udf(() => "a")
+      spark.udf.register("func", func.asNondeterministic())
+      checkAnswer(sql(
+        """
+          |SELECT (
+          |  SELECT array_sort(a, (i, j) -> rank[i] - rank[j])[0] || str AS sorted
+          |  FROM (SELECT MAP('a', 1, 'b', 2) rank, func() AS str)
+          |) FROM t1
+          |""".stripMargin),
+        Row("aa"))
+    }
+  }
 }

