This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 1e94415739c [SPARK-45586][SQL] Reduce compiler latency for plans with large expression trees
1e94415739c is described below

commit 1e94415739ccfc4222a067459d3cb8be480530b4
Author: Kelvin Jiang <kelvin.ji...@databricks.com>
AuthorDate: Thu Oct 19 10:24:58 2023 +0800

    [SPARK-45586][SQL] Reduce compiler latency for plans with large expression trees
    
    ### What changes were proposed in this pull request?
    
    * Included rule ID pruning when traversing the expression trees in `TypeCoercionRule`, which avoids re-traversing the expression tree in future iterations of the rule (see the first sketch after this list)
    * Improved `EquivalentExpressions`:
      * Since `supportedExpression()` only checks for the existence of certain patterns in the tree, it now checks the precomputed `TreePatternBits` instead of recursing with `.exists()` (see the second sketch after this list)
      * When creating an `ExpressionEquals` object, calculating the height requires recursing through all of the expression's children, which is O(n^2) when done for each expression in the expression tree. The height is now cached in `TreeNode`, so computing it for each expression in the tree is O(n) overall (see the third sketch after this list)
    * More targeted `TreePatternBits` pruning in `ResolveTimeZone` and `ConstantPropagation`
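
    As a rough illustration of the rule ID pruning in the first bullet, here is a minimal sketch: each node remembers the IDs of rules that already proved to be no-ops on its subtree, so later iterations of the same rule skip that subtree entirely. `Node` and `RuleId` are hypothetical stand-ins, not Spark's actual `TreeNode`/rule-ID machinery:

    ```scala
    import scala.collection.mutable

    // Hypothetical rule ID; Spark assigns each registered rule a unique integer.
    final case class RuleId(id: Int)

    // Hypothetical simplified tree node.
    class Node(val children: Seq[Node]) {
      // IDs of rules known to leave this subtree unchanged.
      private val ineffectiveRules = mutable.BitSet.empty

      def transformUpWithPruning(ruleId: RuleId)(rule: Node => Node): Node = {
        // A previous pass proved the rule is a no-op here: skip the subtree.
        if (ineffectiveRules.contains(ruleId.id)) return this
        val newChildren = children.map(_.transformUpWithPruning(ruleId)(rule))
        val childrenChanged = newChildren.zip(children).exists { case (n, o) => n ne o }
        val afterChildren = if (childrenChanged) new Node(newChildren) else this
        val result = rule(afterChildren)
        // No change anywhere in the subtree: future iterations can skip it.
        if (result eq this) ineffectiveRules += ruleId.id
        result
      }
    }
    ```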
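
    The `TreePatternBits` idea behind the `supportedExpression()` change (and behind the extra pruning in `ResolveTimeZone` and `ConstantPropagation`) can be sketched as follows: each node caches a bitset of every pattern that appears in its subtree, so a containment check is a single bit test at the root instead of an `.exists()` traversal. Again a simplified model with a hypothetical `Expr` class and pattern IDs, not Spark's actual implementation:

    ```scala
    import java.util.BitSet

    // Hypothetical pattern IDs standing in for Spark's TreePattern enum values.
    object Pattern {
      val LAMBDA_VARIABLE = 0
      val PLAN_EXPRESSION = 1
    }

    // Hypothetical expression node; Spark's TreeNode plays this role.
    case class Expr(patternId: Int, children: Seq[Expr]) {
      // Each node caches the union of its own pattern bit and all of its
      // descendants' bits; the lazy val computes this once per node.
      lazy val patternBits: BitSet = {
        val bits = new BitSet()
        bits.set(patternId)
        children.foreach(c => bits.or(c.patternBits))
        bits
      }

      // A containment check is now one bit test, with no tree traversal.
      def containsPattern(p: Int): Boolean = patternBits.get(p)
    }
    ```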
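
    Finally, a sketch of the height-caching change: memoizing `height` as a `lazy val` means each node computes its height once from its children's cached values, so computing heights for all n nodes of a tree costs O(n) in total rather than O(n^2). `Node` is again a hypothetical stand-in for `TreeNode`:

    ```scala
    // Hypothetical simplified tree node.
    case class Node(children: Seq[Node]) {
      // Before: every call recurses through the entire subtree, so calling it
      // on each of the n nodes of a tree costs O(n^2) in total.
      def heightUncached: Int =
        children.map(_.heightUncached).reduceOption(_ max _).getOrElse(0) + 1

      // After: the lazy val memoizes the result per node. Each node's height
      // is computed once from its children's cached heights, so the total
      // cost over all n nodes is O(n).
      lazy val height: Int =
        children.map(_.height).reduceOption(_ max _).getOrElse(0) + 1
    }
    ```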
    
    ### Why are the changes needed?
    
    This PR improves some analyzer and optimizer rules to address inefficiencies when handling extremely large expression trees.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    There should be no plan changes, so no unit tests were modified.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #43420 from kelvinjian-db/SPARK-45586-large-expr-trees.
    
    Authored-by: Kelvin Jiang <kelvin.ji...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../spark/sql/catalyst/analysis/TypeCoercion.scala |  4 +++-
 .../sql/catalyst/analysis/timeZoneAnalysis.scala   |  8 ++++----
 .../expressions/EquivalentExpressions.scala        | 23 +++++++---------------
 .../spark/sql/catalyst/optimizer/expressions.scala |  6 ++++--
 .../apache/spark/sql/catalyst/trees/TreeNode.scala |  2 ++
 5 files changed, 20 insertions(+), 23 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
index c26569866e5..b34fd873621 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
 import org.apache.spark.sql.catalyst.types.DataTypeUtils
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.internal.SQLConf
@@ -1215,7 +1216,8 @@ trait TypeCoercionRule extends Rule[LogicalPlan] with Logging {
           } else {
             beforeMapChildren
           }
-          withPropagatedTypes.transformExpressionsUp(typeCoercionFn)
+          withPropagatedTypes.transformExpressionsUpWithPruning(
+            AlwaysProcess.fn, ruleId)(typeCoercionFn)
         }
     }
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
index 11a5bc99b6c..01d88f050ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/timeZoneAnalysis.scala
@@ -20,7 +20,7 @@ import org.apache.spark.sql.catalyst.SQLConfHelper
 import org.apache.spark.sql.catalyst.expressions.{Cast, Expression, ListQuery, TimeZoneAwareExpression}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.catalyst.trees.TreePattern.{LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION}
+import org.apache.spark.sql.catalyst.trees.TreePattern.TIME_ZONE_AWARE_EXPRESSION
 import org.apache.spark.sql.types.DataType
 
 /**
@@ -40,10 +40,10 @@ object ResolveTimeZone extends Rule[LogicalPlan] {
 
   override def apply(plan: LogicalPlan): LogicalPlan =
     plan.resolveExpressionsWithPruning(
-      _.containsAnyPattern(LIST_SUBQUERY, TIME_ZONE_AWARE_EXPRESSION), ruleId
-    )(transformTimeZoneExprs)
+      _.containsPattern(TIME_ZONE_AWARE_EXPRESSION), ruleId)(transformTimeZoneExprs)
 
-  def resolveTimeZones(e: Expression): Expression = e.transform(transformTimeZoneExprs)
+  def resolveTimeZones(e: Expression): Expression = e.transformWithPruning(
+    _.containsPattern(TIME_ZONE_AWARE_EXPRESSION))(transformTimeZoneExprs)
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 1a84859cc3a..8738015ce91 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -22,7 +22,7 @@ import java.util.Objects
 import scala.collection.mutable
 
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
-import org.apache.spark.sql.catalyst.expressions.objects.LambdaVariable
+import org.apache.spark.sql.catalyst.trees.TreePattern.{LAMBDA_VARIABLE, PLAN_EXPRESSION}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.Utils
 
@@ -163,18 +163,13 @@ class EquivalentExpressions(
     case _ => Nil
   }
 
-  private def supportedExpression(e: Expression) = {
-    !e.exists {
-      // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
-      // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
-      case _: LambdaVariable => true
-
+  private def supportedExpression(e: Expression): Boolean = {
+    // `LambdaVariable` is usually used as a loop variable, which can't be evaluated ahead of the
+    // loop. So we can't evaluate sub-expressions containing `LambdaVariable` at the beginning.
+    !(e.containsPattern(LAMBDA_VARIABLE) ||
       // `PlanExpression` wraps query plan. To compare query plans of `PlanExpression` on executor,
       // can cause error like NPE.
-      case _: PlanExpression[_] => Utils.isInRunningSparkTask
-
-      case _ => false
-    }
+      (e.containsPattern(PLAN_EXPRESSION) && Utils.isInRunningSparkTask))
   }
 
   /**
@@ -244,13 +239,9 @@ class EquivalentExpressions(
  * Wrapper around an Expression that provides semantic equality.
  */
 case class ExpressionEquals(e: Expression) {
-  private def getHeight(tree: Expression): Int = {
-    tree.children.map(getHeight).reduceOption(_ max _).getOrElse(0) + 1
-  }
-
   // This is used to do a fast pre-check for child-parent relationship. For example, expr1 can
   // only be a parent of expr2 if expr1.height is larger than expr2.height.
-  lazy val height = getHeight(e)
+  def height: Int = e.height
 
   override def equals(o: Any): Boolean = o match {
     case other: ExpressionEquals => e.semanticEquals(other.e) && height == other.height
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index cc14789f6f5..91d5e180c59 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -116,7 +116,7 @@ object ConstantFolding extends Rule[LogicalPlan] {
  */
 object ConstantPropagation extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
-    _.containsAllPatterns(LITERAL, FILTER), ruleId) {
+    _.containsAllPatterns(LITERAL, FILTER, BINARY_COMPARISON), ruleId) {
     case f: Filter =>
       val (newCondition, _) = traverse(f.condition, replaceChildren = true, nullIsFalse = true)
       if (newCondition.isDefined) {
@@ -147,6 +147,8 @@ object ConstantPropagation extends Rule[LogicalPlan] {
   private def traverse(condition: Expression, replaceChildren: Boolean, nullIsFalse: Boolean)
     : (Option[Expression], AttributeMap[(Literal, BinaryComparison)]) =
     condition match {
+      case _ if !condition.containsAllPatterns(LITERAL, BINARY_COMPARISON) =>
+        (None, AttributeMap.empty)
       case e @ EqualTo(left: AttributeReference, right: Literal)
         if safeToReplace(left, nullIsFalse) =>
         (None, AttributeMap(Map(left -> (right, e))))
@@ -206,7 +208,7 @@ object ConstantPropagation extends Rule[LogicalPlan] {
       equalityPredicates: AttributeMap[(Literal, BinaryComparison)]): Expression = {
     val constantsMap = AttributeMap(equalityPredicates.map { case (attr, (lit, _)) => attr -> lit })
     val predicates = equalityPredicates.values.map(_._2).toSet
-    condition transform {
+    condition.transformWithPruning(_.containsPattern(BINARY_COMPARISON)) {
       case b: BinaryComparison if !predicates.contains(b) => b transform {
         case a: AttributeReference => constantsMap.getOrElse(a, a)
       }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index a34ad10f36a..cc470d0de6f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -173,6 +173,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]]
 
   lazy val containsChild: Set[TreeNode[_]] = children.toSet
 
+  lazy val height: Int = children.map(_.height).reduceOption(_ max _).getOrElse(0) + 1
+
   // Copied from Scala 2.13.1
   // github.com/scala/scala/blob/v2.13.1/src/library/scala/util/hashing/MurmurHash3.scala#L56-L73
   // to prevent the issue https://github.com/scala/bug/issues/10495


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
