This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 01c294b05f3a [SPARK-45760][SQL] Add With expression to avoid 
duplicating expressions
01c294b05f3a is described below

commit 01c294b05f3a9b7bd87cda0ee8b0160f5f58bb24
Author: Wenchen Fan <wenc...@databricks.com>
AuthorDate: Wed Nov 8 00:57:31 2023 +0800

    [SPARK-45760][SQL] Add With expression to avoid duplicating expressions
    
    ### What changes were proposed in this pull request?
    
    Sometimes we need to duplicate expressions when rewriting the plan. This is 
fine for small queries, as codegen has common-subexpression-elimination (CSE) to 
avoid evaluating the same expression more than once. However, when the query is 
big, duplicating expressions can lead to a very big expression tree and make 
Catalyst rules very slow, or even cause OOM when updating a leaf node (which 
requires copying all tree nodes).
    
    This PR introduces a new `With` expression to do expression-level CTE, and a 
new optimizer rule `RewriteWithExpression` that adds a Project to pre-evaluate 
the common expressions (cheap ones are inlined instead), so that they appear 
only once in the query plan tree and are evaluated only once. `NullIf` now uses 
this new expression to avoid duplicating the `left` child expression.
    
    ### Why are the changes needed?
    
    To make Catalyst more efficient on big queries, by avoiding the large 
expression trees that duplicated expressions produce.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New test suite `RewriteWithExpressionSuite`, plus updated Spark Connect 
golden explain files.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43623 from cloud-fan/with.
    
    Lead-authored-by: Wenchen Fan <wenc...@databricks.com>
    Co-authored-by: Peter Toth <peter.t...@gmail.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../explain-results/function_count_if.explain      |   5 +-
 .../explain-results/function_regexp_substr.explain |   5 +-
 .../sql/connect/ProtoToParsedPlanTestSuite.scala   |  15 +-
 .../spark/sql/catalyst/expressions/With.scala      |  63 +++++++++
 .../sql/catalyst/expressions/nullExpressions.scala |   6 +-
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |   3 +
 .../catalyst/optimizer/RewriteWithExpression.scala |  90 ++++++++++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |   2 +
 .../optimizer/RewriteWithExpressionSuite.scala     | 157 +++++++++++++++++++++
 9 files changed, 338 insertions(+), 8 deletions(-)

diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
index 1c23bbf6bce5..f2ada15eccb7 100644
--- 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_count_if.explain
@@ -1,2 +1,3 @@
-Aggregate [count(if (((a#0 > 0) = false)) null else (a#0 > 0)) AS count_if((a 
> 0))#0L]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Aggregate [count(if ((_common_expr_0#0 = false)) null else _common_expr_0#0) 
AS count_if((a > 0))#0L]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, (a#0 > 0) AS _common_expr_0#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
index 69fc760c8291..1811f770f829 100644
--- 
a/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
+++ 
b/connector/connect/common/src/test/resources/query-tests/explain-results/function_regexp_substr.explain
@@ -1,2 +1,3 @@
-Project [if ((regexp_extract(g#0, \d{2}(a|b|m), 0) = )) null else 
regexp_extract(g#0, \d{2}(a|b|m), 0) AS regexp_substr(g, \d{2}(a|b|m))#0]
-+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
+Project [if ((_common_expr_0#0 = )) null else _common_expr_0#0 AS 
regexp_substr(g, \d{2}(a|b|m))#0]
++- Project [id#0L, a#0, b#0, d#0, e#0, f#0, g#0, regexp_extract(g#0, 
\d{2}(a|b|m), 0) AS _common_expr_0#0]
+   +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
diff --git 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
index 9fdaffcba670..e0c4e21503e9 100644
--- 
a/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
+++ 
b/connector/connect/server/src/test/scala/org/apache/spark/sql/connect/ProtoToParsedPlanTestSuite.scala
@@ -29,7 +29,9 @@ import org.apache.spark.connect.proto
 import org.apache.spark.sql.catalyst.{catalog, QueryPlanningTracker}
 import org.apache.spark.sql.catalyst.analysis.{caseSensitiveResolution, 
Analyzer, FunctionRegistry, Resolver, TableFunctionRegistry}
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
-import org.apache.spark.sql.catalyst.optimizer.ReplaceExpressions
+import org.apache.spark.sql.catalyst.optimizer.{ReplaceExpressions, 
RewriteWithExpression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.connect.config.Connect
 import org.apache.spark.sql.connect.planner.SparkConnectPlanner
 import org.apache.spark.sql.connect.service.SessionHolder
@@ -181,8 +183,15 @@ class ProtoToParsedPlanTestSuite
       val planner = new SparkConnectPlanner(SessionHolder.forTesting(spark))
       val catalystPlan =
         analyzer.executeAndCheck(planner.transformRelation(relation), new 
QueryPlanningTracker)
-      val actual =
-        
removeMemoryAddress(normalizeExprIds(ReplaceExpressions(catalystPlan)).treeString)
+      val finalAnalyzedPlan = {
+        object Helper extends RuleExecutor[LogicalPlan] {
+          val batches =
+            Batch("Finish Analysis", Once, ReplaceExpressions) ::
+              Batch("Rewrite With expression", Once, RewriteWithExpression) :: 
Nil
+        }
+        Helper.execute(catalystPlan)
+      }
+      val actual = 
removeMemoryAddress(normalizeExprIds(finalAnalyzedPlan).treeString)
       val goldenFile = 
goldenFilePath.resolve(relativePath).getParent.resolve(name + ".explain")
       Try(readGoldenFile(goldenFile)) match {
         case Success(expected) if expected == actual => // Test passes.
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
new file mode 100644
index 000000000000..bfed63af1740
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/With.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, 
TreePattern, WITH_EXPRESSION}
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Holds common expressions (`defs`) that the actual expression (`child`) may reference through
+ * `CommonExpressionRef`; each is evaluated only once however often it is referenced, unless inlined.
+ * Similar to CTE but expression-level. It is `Unevaluable`, so it must be rewritten before execution.
+ */
+case class With(child: Expression, defs: Seq[CommonExpressionDef])
+  extends Expression with Unevaluable {
+  override val nodePatterns: Seq[TreePattern] = Seq(WITH_EXPRESSION)
+  override def dataType: DataType = child.dataType
+  override def nullable: Boolean = child.nullable
+  override def children: Seq[Expression] = child +: defs // `child` first, then `defs` in order
+  override protected def withNewChildrenInternal(
+      newChildren: IndexedSeq[Expression]): Expression = {
+    copy(child = newChildren.head, defs = 
newChildren.tail.map(_.asInstanceOf[CommonExpressionDef])) // tail holds the defs, see `children`
+  }
+}
+
+/**
+ * Wraps a common expression of a `With`, with the id that `CommonExpressionRef` uses to refer to it.
+ */
+case class CommonExpressionDef(child: Expression, id: Long = 
CommonExpressionDef.newId)
+  extends UnaryExpression with Unevaluable {
+  override def dataType: DataType = child.dataType
+  override protected def withNewChildInternal(newChild: Expression): 
Expression =
+    copy(child = newChild) // keeps `id`, so existing references stay valid
+}
+
+/**
+ * A reference, by id, to a `CommonExpressionDef`. Only resolved common expressions can be
+ * referenced, so the reference can copy their data type and nullability (see the constructor).
+ */
+case class CommonExpressionRef(id: Long, dataType: DataType, nullable: Boolean)
+  extends LeafExpression with Unevaluable {
+  def this(exprDef: CommonExpressionDef) = this(exprDef.id, exprDef.dataType, 
exprDef.nullable)
+  override val nodePatterns: Seq[TreePattern] = Seq(COMMON_EXPR_REF) // enables pattern pruning
+}
+
+object CommonExpressionDef { // allocates the default `id` of `CommonExpressionDef`
+  private[sql] val curId = new java.util.concurrent.atomic.AtomicLong() // shared counter, starts at 0
+  def newId: Long = curId.getAndIncrement() // atomic, so concurrent callers get distinct ids
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
index 948cb6fbedd3..0e9e375b8acf 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala
@@ -154,7 +154,11 @@ case class NullIf(left: Expression, right: Expression, 
replacement: Expression)
   extends RuntimeReplaceable with InheritAnalysisRules {
 
   def this(left: Expression, right: Expression) = {
-    this(left, right, If(EqualTo(left, right), Literal.create(null, 
left.dataType), left))
+    this(left, right, {
+      val commonExpr = CommonExpressionDef(left)
+      val ref = new CommonExpressionRef(commonExpr)
+      With(If(EqualTo(ref, right), Literal.create(null, left.dataType), ref), 
Seq(commonExpr))
+    })
   }
 
   override def parameters: Seq[Expression] = Seq(left, right)
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 48ecb9aee211..decef766ae97 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -147,6 +147,9 @@ abstract class Optimizer(catalogManager: CatalogManager)
 
     val batches = (
     Batch("Finish Analysis", Once, FinishAnalysis) ::
+    // We must run this batch after `ReplaceExpressions`, as 
`RuntimeReplaceable` expression
+    // may produce `With` expressions that need to be rewritten.
+    Batch("Rewrite With expression", Once, RewriteWithExpression) ::
     
//////////////////////////////////////////////////////////////////////////////////////////
     // Optimizer rules start here
     
//////////////////////////////////////////////////////////////////////////////////////////
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
new file mode 100644
index 000000000000..c5bd71b4a7d1
--- /dev/null
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpression.scala
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.collection.mutable
+
+import org.apache.spark.sql.catalyst.expressions.{Alias, CommonExpressionDef, 
CommonExpressionRef, Expression, With}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.{COMMON_EXPR_REF, 
WITH_EXPRESSION}
+
+/**
+ * Rewrites the `With` expressions by adding a `Project` to pre-evaluate the common expressions, or
+ * just inline them if they are cheap or themselves reference other common expressions.
+ *
+ * Note: For now we only use `With` in a few `RuntimeReplaceable` expressions. If we expand its
+ *       usage, we should support aggregate/window functions as well.
+ */
+object RewriteWithExpression extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = {
+    plan.transformWithPruning(_.containsPattern(WITH_EXPRESSION)) {
+      case p if p.expressions.exists(_.containsPattern(WITH_EXPRESSION)) =>
+        var newChildren = p.children
+        val rewritten: LogicalPlan = p.transformExpressionsUp {
+          case With(body, defs) =>
+            val refToExpr = mutable.HashMap.empty[Long, Expression]
+            val childProjections = Array.fill(newChildren.size)(mutable.ArrayBuffer.empty[Alias])
+            // Inline cheap defs, and defs containing refs (a child Project would never replace them).
+            defs.zipWithIndex.foreach { case (CommonExpressionDef(commonExpr, id), index) =>
+              if (CollapseProject.isCheap(commonExpr) || commonExpr.containsPattern(COMMON_EXPR_REF)) {
+                refToExpr(id) = commonExpr
+              } else {
+                val childProjectionIndex = newChildren.indexWhere(
+                  c => commonExpr.references.subsetOf(c.outputSet)
+                )
+                if (childProjectionIndex == -1) {
+                  // When we cannot rewrite the common expressions, force to inline them so that the
+                  // query can still run. This can happen if the join condition contains `With` and
+                  // the common expression references columns from both join sides.
+                  // TODO: things can go wrong if the common expression is nondeterministic. We
+                  //       don't fix it for now to match the old buggy behavior when certain
+                  //       `RuntimeReplaceable` did not use the `With` expression.
+                  // TODO: we should calculate the ref count and also inline the common expression
+                  //       if its ref count is 1.
+                  refToExpr(id) = commonExpr
+                } else {
+                  val alias = Alias(commonExpr, s"_common_expr_$index")()
+                  childProjections(childProjectionIndex) += alias
+                  refToExpr(id) = alias.toAttribute
+                }
+              }
+            }
+
+            newChildren = newChildren.zip(childProjections).map { case (c, projections) =>
+              if (projections.nonEmpty) {
+                Project(c.output ++ projections, c)
+              } else {
+                c
+              }
+            }
+            // Only replace refs to our own defs; an enclosing `With` replaces the refs to its defs.
+            body.transformWithPruning(_.containsPattern(COMMON_EXPR_REF)) {
+              case ref: CommonExpressionRef if refToExpr.contains(ref.id) => refToExpr(ref.id)
+            }
+        }
+
+        val newPlan = rewritten.withNewChildren(newChildren)
+        if (p.output == newPlan.output) {
+          newPlan
+        } else {
+          Project(p.output, newPlan)
+        }
+    }
+  }
+}
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index 8b714d5a5d28..9b3337d1a940 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -36,6 +36,7 @@ object TreePattern extends Enumeration  {
   val CASE_WHEN: Value = Value
   val CAST: Value = Value
   val COALESCE: Value = Value
+  val COMMON_EXPR_REF: Value = Value
   val CONCAT: Value = Value
   val COUNT: Value = Value
   val CREATE_NAMED_STRUCT: Value = Value
@@ -132,6 +133,7 @@ object TreePattern extends Enumeration  {
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
   val WINDOW_GROUP_LIMIT: Value = Value
+  val WITH_EXPRESSION: Value = Value
   val WITH_WINDOW_DEFINITION: Value = Value
 
   // Unresolved expression patterns (Alphabetically ordered)
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
new file mode 100644
index 000000000000..c625379eb5ff
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteWithExpressionSuite.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, 
CommonExpressionDef, CommonExpressionRef, With}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.IntegerType
+
+class RewriteWithExpressionSuite extends PlanTest {
+  // Executor that runs only RewriteWithExpression, so each test checks that rule in isolation.
+  object Optimizer extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Rewrite With expression", Once, 
RewriteWithExpression) :: Nil
+  }
+
+  private val testRelation = LocalRelation($"a".int, $"b".int)
+  private val testRelation2 = LocalRelation($"x".int, $"y".int)
+
+  test("simple common expression") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a) // an attribute is cheap, so the rule inlines it
+    val ref = new CommonExpressionRef(commonExprDef)
+    val plan = testRelation.select(With(ref + ref, 
Seq(commonExprDef)).as("col"))
+    comparePlans(Optimizer.execute(plan), testRelation.select((a + 
a).as("col")))
+  }
+
+  test("non-cheap common expression") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a) // not cheap: pre-evaluated in a Project
+    val ref = new CommonExpressionRef(commonExprDef)
+    val plan = testRelation.select(With(ref * ref, 
Seq(commonExprDef)).as("col"))
+    val commonExprName = "_common_expr_0"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+        .select(($"$commonExprName" * $"$commonExprName").as("col"))
+        .analyze
+    )
+  }
+
+  test("nested WITH expression") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val innerExpr = With(ref + ref, Seq(commonExprDef))
+    val innerCommonExprName = "_common_expr_0"
+
+    val b = testRelation.output.last
+    val outerCommonExprDef = CommonExpressionDef(innerExpr + b)
+    val outerRef = new CommonExpressionRef(outerCommonExprDef)
+    val outerExpr = With(outerRef * outerRef, Seq(outerCommonExprDef))
+    val outerCommonExprName = "_common_expr_0"
+    // Alias names restart at 0 for every `With`, so both aliases are named `_common_expr_0`.
+    val plan = testRelation.select(outerExpr.as("col"))
+    val rewrittenOuterExpr = ($"$innerCommonExprName" + 
$"$innerCommonExprName" + b)
+      .as(outerCommonExprName)
+    val outerExprAttr = AttributeReference(outerCommonExprName, IntegerType)(
+      exprId = rewrittenOuterExpr.exprId) // match by exprId: the name alone is ambiguous
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select((testRelation.output :+ (a + a).as(innerCommonExprName)): _*)
+        .select((testRelation.output :+ $"$innerCommonExprName" :+ 
rewrittenOuterExpr): _*)
+        .select((outerExprAttr * outerExprAttr).as("col"))
+        .analyze
+    )
+  }
+
+  test("WITH expression in filter") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val plan = testRelation.where(With(ref < 10 && ref > 0, 
Seq(commonExprDef)))
+    val commonExprName = "_common_expr_0"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+        .where($"$commonExprName" < 10 && $"$commonExprName" > 0)
+        .select(testRelation.output: _*) // the Filter's output grew, so the rule restores it
+        .analyze
+    )
+  }
+
+  test("WITH expression in join condition: only reference left child") {
+    val a = testRelation.output.head
+    val commonExprDef = CommonExpressionDef(a + a)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val plan = testRelation.join(testRelation2, condition = Some(condition))
+    val commonExprName = "_common_expr_0"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .select((testRelation.output :+ (a + a).as(commonExprName)): _*)
+        .join(testRelation2, condition = Some($"$commonExprName" < 10 && 
$"$commonExprName" > 0))
+        .select((testRelation.output ++ testRelation2.output): _*) // restores the Join's output
+        .analyze
+    )
+  }
+
+  test("WITH expression in join condition: only reference right child") {
+    val x = testRelation2.output.head
+    val commonExprDef = CommonExpressionDef(x + x)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val plan = testRelation.join(testRelation2, condition = Some(condition))
+    val commonExprName = "_common_expr_0"
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .join(
+          testRelation2.select((testRelation2.output :+ (x + 
x).as(commonExprName)): _*),
+          condition = Some($"$commonExprName" < 10 && $"$commonExprName" > 0)
+        )
+        .select((testRelation.output ++ testRelation2.output): _*) // restores the Join's output
+        .analyze
+    )
+  }
+
+  test("WITH expression in join condition: reference both children") {
+    val a = testRelation.output.head
+    val x = testRelation2.output.head
+    val commonExprDef = CommonExpressionDef(a + x)
+    val ref = new CommonExpressionRef(commonExprDef)
+    val condition = With(ref < 10 && ref > 0, Seq(commonExprDef))
+    val plan = testRelation.join(testRelation2, condition = Some(condition))
+    comparePlans(
+      Optimizer.execute(plan),
+      testRelation
+        .join(
+          testRelation2,
+          // `a + x` references both join sides, so no child can pre-evaluate it: it is inlined
+          condition = Some((a + x) < 10 && (a + x) > 0)
+        )
+    )
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to