This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 61e25e1cdcbb [SPARK-47071][SQL] Inline With expression if it contains 
special expression
61e25e1cdcbb is described below

commit 61e25e1cdcbb867fca264fa444d30b20e27c5a00
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Fri Feb 16 00:28:29 2024 -0800

    [SPARK-47071][SQL] Inline With expression if it contains special expression
    
    ### What changes were proposed in this pull request?
    
    This is a bug fix for the With expression. If the common expression
    contains a special expression, such as an aggregate expression, we cannot
    pull it out and put it in a Project; we have to inline it instead.
    
    ### Why are the changes needed?
    
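    Bug fix: without it, a query whose common expression contains an aggregate,
    such as `SELECT nullif(SUM(id), 0) FROM range(5)`, fails.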
    
    ### Does this PR introduce _any_ user-facing change?
    
    A query that previously failed can now run after this fix, but the bug has
    not been released yet.
    
    ### How was this patch tested?
    
    New SQL golden-file test in sql-compatibility-functions.sql (aggregate
    function inside NULLIF).
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    no
    
    Closes #45134 from cloud-fan/with.
    
    Authored-by: Wenchen Fan <wenc...@databricks.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../catalyst/optimizer/RewriteWithExpression.scala | 23 ++++++++++++++++------
 .../sql-compatibility-functions.sql.out            |  7 +++++++
 .../inputs/sql-compatibility-functions.sql         |  3 +++
 .../results/sql-compatibility-functions.sql.out    |  8 ++++++++
 4 files changed, 35 insertions(+), 6 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
index cf2c77069a19..342c7ad09574 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -21,7 +21,7 @@ import scala.collection.mutable
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, PlanHelper, 
Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, 
WITH_EXPRESSION}
 
@@ -41,10 +41,15 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
           rewriteWithExprAndInputPlans(expr, inputPlans)
         }
         newPlan = newPlan.withNewChildren(inputPlans.toIndexedSeq)
-        if (p.output == newPlan.output) {
-          newPlan
-        } else {
+        // Since we add extra Projects with extra columns to pre-evaluate the 
common expressions,
+        // the current operator may have extra columns if it inherits the 
output columns from its
+        // child, and we need to project away the extra columns to keep the 
plan schema unchanged.
+        assert(p.output.length <= newPlan.output.length)
+        if (p.output.length < newPlan.output.length) {
+          assert(p.outputSet.subsetOf(newPlan.outputSet))
           Project(p.output, newPlan)
+        } else {
+          newPlan
         }
     }
   }
@@ -85,8 +90,14 @@ object RewriteWithExpression extends Rule[LogicalPlan] {
               refToExpr(id) = child
             } else {
               val alias = Alias(child, s"_common_expr_$index")()
-              childProjections(childProjectionIndex) += alias
-              refToExpr(id) = alias.toAttribute
+              val fakeProj = Project(Seq(alias), 
inputPlans(childProjectionIndex))
+              if 
(PlanHelper.specialExpressionsInUnsupportedOperator(fakeProj).nonEmpty) {
+                // We have to inline the common expression if it cannot be put 
in a Project.
+                refToExpr(id) = child
+              } else {
+                childProjections(childProjectionIndex) += alias
+                refToExpr(id) = alias.toAttribute
+              }
             }
           }
         }
diff --git 
a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out
index b713f4d50917..f80290c5ab34 100644
--- 
a/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/analyzer-results/sql-compatibility-functions.sql.out
@@ -116,3 +116,10 @@ Aggregate [nvl(st#x.col1, value)], [nvl(st#x.col1, value) 
AS nvl(st.col1, value)
       +- Project [cast(id#x as int) AS id#x, cast(st#x as 
struct<col1:string,col2:string>) AS st#x]
          +- SubqueryAlias T
             +- LocalRelation [id#x, st#x]
+
+
+-- !query
+SELECT nullif(SUM(id), 0) from range(5)
+-- !query analysis
+Aggregate [nullif(sum(id#xL), 0) AS nullif(sum(id), 0)#xL]
++- Range (0, 5, step=1, splits=None)
diff --git 
a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql 
b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
index 1ae49c8bfc76..6c840154c618 100644
--- 
a/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
+++ 
b/sql/core/src/test/resources/sql-tests/inputs/sql-compatibility-functions.sql
@@ -22,3 +22,6 @@ SELECT string(1, 2);
 -- SPARK-21555: RuntimeReplaceable used in group by
 CREATE TEMPORARY VIEW tempView1 AS VALUES (1, NAMED_STRUCT('col1', 'gamma', 
'col2', 'delta')) AS T(id, st);
 SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 GROUP BY 
nvl(st.col1, "value");
+
+-- aggregate function inside NULLIF
+SELECT nullif(SUM(id), 0) from range(5);
diff --git 
a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
 
b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
index 1d3257fdaae3..0dd8c738d212 100644
--- 
a/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
+++ 
b/sql/core/src/test/resources/sql-tests/results/sql-compatibility-functions.sql.out
@@ -126,3 +126,11 @@ SELECT nvl(st.col1, "value"), count(*) FROM from tempView1 
GROUP BY nvl(st.col1,
 struct<nvl(st.col1, value):string,FROM:bigint>
 -- !query output
 gamma  1
+
+
+-- !query
+SELECT nullif(SUM(id), 0) from range(5)
+-- !query schema
+struct<nullif(sum(id), 0):bigint>
+-- !query output
+10


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to