This is an automated email from the ASF dual-hosted git repository.

gengliang pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new d92018e  [SPARK-35298][SQL] Migrate to transformWithPruning for rules in Optimizer.scala
d92018e is described below

commit d92018ee358b0009dac626e2c5568db8363f53ee
Author: Yingyi Bu <yingyi...@databricks.com>
AuthorDate: Wed May 12 20:42:47 2021 +0800

    [SPARK-35298][SQL] Migrate to transformWithPruning for rules in Optimizer.scala
    
    ### What changes were proposed in this pull request?
    
    Added the following TreePattern enums (how a node advertises them is sketched after the list):
    - ALIAS
    - AND_OR
    - AVERAGE
    - GENERATE
    - INTERSECT
    - SORT
    - SUM
    - DISTINCT_LIKE
    - PROJECT
    - REPARTITION_OPERATION
    - UNION
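
    Each of these enums is attached to its node through a `nodePatterns` override; the `Average`, `Sum`, and `Alias` hunks below show the real declarations. As a self-contained toy sketch of the mechanism, assuming the usual bitset roll-up that Catalyst's `TreeNode` performs (simplified names, not the actual Catalyst API):

    ```scala
    import scala.collection.immutable.BitSet

    // Toy model of TreePattern/TreeNode; the real definitions live in
    // org.apache.spark.sql.catalyst.trees.
    object TreePattern extends Enumeration {
      type TreePattern = Value
      val ALIAS, AVERAGE, SORT, SUM, UNION = Value
    }
    import TreePattern._

    abstract class Node {
      def children: Seq[Node]
      // Patterns this node matches on its own, e.g. Seq(SORT) for a sort node.
      val nodePatterns: Seq[TreePattern] = Nil
      // Union of this node's pattern bits with all of its descendants',
      // computed once per node and cached.
      lazy val treePatternBits: BitSet =
        children.foldLeft(BitSet(nodePatterns.map(_.id): _*))(_ | _.treePatternBits)
      final def containsPattern(p: TreePattern): Boolean = treePatternBits(p.id)
    }

    // Mirrors Sort's one-line declaration in this commit.
    case class SortNode(child: Node) extends Node {
      override def children: Seq[Node] = Seq(child)
      override val nodePatterns: Seq[TreePattern] = Seq(SORT)
    }
    ```

    A rule guarded by `containsPattern(SORT)` can then reject an entire sort-free subtree with a single bitset test instead of walking it.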
    
    Added tree traversal pruning to the following rules in Optimizer.scala (the migration pattern is sketched after the list):
    - EliminateAggregateFilter
    - RemoveRedundantAggregates
    - RemoveNoopOperators
    - RemoveNoopUnion
    - LimitPushDown
    - ColumnPruning
    - CollapseRepartition
    - OptimizeRepartition
    - OptimizeWindowFunctions
    - CollapseWindow
    - TransposeWindow
    - InferFiltersFromGenerate
    - InferFiltersFromConstraints
    - CombineUnions
    - CombineFilters
    - EliminateSorts
    - PruneFilters
    - EliminateLimits
    - DecimalAggregates
    - ConvertToLocalRelation
    - ReplaceDistinctWithAggregate
    - ReplaceIntersectWithSemiJoin
    - ReplaceExceptWithAntiJoin
    - RewriteExceptAll
    - RewriteIntersectAll
    - RemoveLiteralFromGroupExpressions
    - RemoveRepetitionFromGroupExpressions
    - OptimizeLimitZero
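
    The migration itself is mechanical: the rule body is unchanged and only the traversal entry point gains a pattern guard plus a `ruleId` used for idempotence tracking. A sketch adapted from the PruneFilters hunks in this diff (the `// ...` elides the remaining cases):

    ```scala
    // Before: unconditionally visits every node of the plan.
    def apply(plan: LogicalPlan): LogicalPlan = plan transform {
      case Filter(Literal(true, BooleanType), child) => child
      // ...
    }

    // After: descend only into subtrees whose cached pattern bits contain
    // FILTER, and pass ruleId so already-processed subtrees can be skipped.
    def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
      _.containsPattern(FILTER), ruleId) {
      case Filter(Literal(true, BooleanType), child) => child
      // ...
    }
    ```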
    
    ### Why are the changes needed?
    
    Reduce the number of tree traversals and hence improve query compilation latency.
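
    The same check also short-circuits helper recursions. For example, `recursiveRemoveSort` in EliminateSorts (quoted from its hunk below) now returns early when a subtree contains no Sort at all:

    ```scala
    private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = {
      if (!plan.containsPattern(SORT)) {
        // One cached bitset lookup instead of traversing a sort-free subtree.
        return plan
      }
      plan match {
        case Sort(_, _, child) => recursiveRemoveSort(child)
        case other if canEliminateSort(other) =>
          other.withNewChildren(other.children.map(recursiveRemoveSort))
        case _ => plan
      }
    }
    ```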
    
    perf diff:
    Rule name | Total Time (baseline) | Total Time (experiment) | experiment/baseline
    RemoveRedundantAggregates | 51290766 | 67070477 | 1.31
    RemoveNoopOperators | 192371141 | 196631275 | 1.02
    RemoveNoopUnion | 49222561 | 43266681 | 0.88
    LimitPushDown | 40885185 | 21672646 | 0.53
    ColumnPruning | 2003406120 | 1285562149 | 0.64
    CollapseRepartition | 40648048 | 72646515 | 1.79
    OptimizeRepartition | 37813850 | 20600803 | 0.54
    OptimizeWindowFunctions | 174426904 | 46741409 | 0.27
    CollapseWindow | 38959957 | 24542426 | 0.63
    TransposeWindow | 33533191 | 20414930 | 0.61
    InferFiltersFromGenerate | 21758688 | 15597344 | 0.72
    InferFiltersFromConstraints | 518009794 | 493282321 | 0.95
    CombineUnions | 67694022 | 70550382 | 1.04
    CombineFilters | 35265060 | 29005424 | 0.82
    EliminateSorts | 57025509 | 19795776 | 0.35
    PruneFilters | 433964815 | 465579200 | 1.07
    EliminateLimits | 44275393 | 24476859 | 0.55
    DecimalAggregates | 83143172 | 28816090 | 0.35
    ReplaceDistinctWithAggregate | 21783760 | 18287489 | 0.84
    ReplaceIntersectWithSemiJoin | 22311271 | 16566393 | 0.74
    ReplaceExceptWithAntiJoin | 23838520 | 16588808 | 0.70
    RewriteExceptAll | 32750296 | 29421957 | 0.90
    RewriteIntersectAll | 29760454 | 21243599 | 0.71
    RemoveLiteralFromGroupExpressions | 28151861 | 25270947 | 0.90
    RemoveRepetitionFromGroupExpressions | 29587030 | 23447041 | 0.79
    OptimizeLimitZero | 18081943 | 15597344 | 0.86
    **Accumulated | 4129959311 | 3112676285 | 0.75**
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #32439 from sigmod/optimizer.
    
    Authored-by: Yingyi Bu <yingyi...@databricks.com>
    Signed-off-by: Gengliang Wang <ltn...@gmail.com>
---
 .../catalyst/expressions/aggregate/Average.scala   |   3 +
 .../sql/catalyst/expressions/aggregate/Sum.scala   |   3 +
 .../catalyst/expressions/namedExpressions.scala    |   2 +
 .../spark/sql/catalyst/optimizer/Optimizer.scala   | 113 ++++++++++++++-------
 .../plans/logical/basicLogicalOperators.scala      |  10 ++
 .../sql/catalyst/rules/RuleIdCollection.scala      |  24 +++++
 .../spark/sql/catalyst/trees/TreePatterns.scala    |  11 +-
 7 files changed, 128 insertions(+), 38 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
index 8ae24e5..82ad2df 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.{DecimalPrecision, FunctionRegistry, TypeCheckResult}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{AVERAGE, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.types._
@@ -51,6 +52,8 @@ case class Average(child: Expression) extends DeclarativeAggregate with Implicit
   // Return data type.
   override def dataType: DataType = resultType
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(AVERAGE)
+
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(p, s) =>
       DecimalType.bounded(p + 4, s + 4)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
index 31150fc..16cd9d7 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.trees.TreePattern.{SUM, TreePattern}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.util.TypeUtils
 import org.apache.spark.sql.internal.SQLConf
@@ -52,6 +53,8 @@ case class Sum(child: Expression) extends DeclarativeAggregate with ImplicitCast
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForAnsiIntervalOrNumericType(child.dataType, "sum")
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(SUM)
+
   private lazy val resultType = child.dataType match {
     case DecimalType.Fixed(precision, scale) =>
       DecimalType.bounded(precision + 10, scale)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 69f7d24..52487d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -158,6 +158,8 @@ case class Alias(child: Expression, name: String)(
     val nonInheritableMetadataKeys: Seq[String] = Seq.empty)
   extends UnaryExpression with NamedExpression {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(ALIAS)
+
   // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
   override lazy val resolved =
     childrenResolved && checkInputDataTypes().isSuccess && !child.isInstanceOf[Generator]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 07c86a7..19e9312 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -27,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
-import org.apache.spark.sql.catalyst.trees.TreePattern.PLAN_EXPRESSION
+import org.apache.spark.sql.catalyst.trees.AlwaysProcess
+import org.apache.spark.sql.catalyst.trees.TreePattern._
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
@@ -377,7 +378,8 @@ object EliminateDistinct extends Rule[LogicalPlan] {
  * This rule should be applied before RewriteDistinctAggregates.
  */
 object EliminateAggregateFilter extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan transformExpressions  {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformExpressionsWithPruning(
+    _.containsAllPatterns(TRUE_OR_FALSE_LITERAL), ruleId)  {
     case ae @ AggregateExpression(_, _, _, Some(Literal.TrueLiteral), _) =>
       ae.copy(filter = None)
     case AggregateExpression(af: DeclarativeAggregate, _, _, Some(Literal.FalseLiteral), _) =>
@@ -445,6 +447,9 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
    * (self) join or to prevent the removal of top-level subquery attributes.
    */
   private def removeRedundantAliases(plan: LogicalPlan, excluded: AttributeSet): LogicalPlan = {
+    if (!plan.containsPattern(ALIAS)) {
+      return plan
+    }
     plan match {
       // We want to keep the same output attributes for subqueries. This means we cannot remove
       // the aliases that produce these attributes
@@ -506,7 +511,8 @@ object RemoveRedundantAliases extends Rule[LogicalPlan] {
 * only goal is to keep distinct values, while its parent aggregate would ignore duplicate values.
  */
 object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case upper @ Aggregate(_, _, lower: Aggregate) if lowerIsRedundant(upper, lower) =>
       val aliasMap = getAliasMap(lower)
 
@@ -545,7 +551,8 @@ object RemoveRedundantAggregates extends Rule[LogicalPlan] with AliasHelper {
 * Remove no-op operators from the query plan that do not make any modifications.
  */
 object RemoveNoopOperators extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAnyPattern(PROJECT, WINDOW), ruleId) {
     // Eliminate no-op Projects
     case p @ Project(_, child) if child.sameOutput(p) => child
 
@@ -597,7 +604,8 @@ object RemoveNoopUnion extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(DISTINCT_LIKE, UNION)) {
     case d @ Distinct(u: Union) =>
       d.withNewChildren(Seq(simplifyUnion(u)))
     case d @ Deduplicate(_, u: Union) =>
@@ -648,7 +656,8 @@ object LimitPushDown extends Rule[LogicalPlan] {
     }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(LIMIT), ruleId) {
     // Adding extra Limits below UNION ALL for children which are not Limit or do not have Limit
     // descendants whose maxRow is larger. This heuristic is valid assuming there does not exist any
     // Limit push-down rule that is unable to infer the value of maxRows.
@@ -745,7 +754,8 @@ object PushProjectionThroughUnion extends Rule[LogicalPlan] with PredicateHelper
  */
 object ColumnPruning extends Rule[LogicalPlan] {
 
-  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = removeProjectBeforeFilter(
+    plan.transformWithPruning(AlwaysProcess.fn, ruleId) {
     // Prunes the unused columns from project list of Project/Aggregate/Expand
     case p @ Project(_, p2: Project) if !p2.outputSet.subsetOf(p.references) =>
       p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
@@ -863,7 +873,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
  */
 object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(PROJECT), ruleId) {
     case p1 @ Project(_, p2: Project) =>
       if (haveCommonNonDeterministicOutput(p1.projectList, p2.projectList)) {
         p1
@@ -921,7 +932,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
  * Combines adjacent [[RepartitionOperation]] operators
  */
 object CollapseRepartition extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(REPARTITION_OPERATION), ruleId) {
     // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression,
     // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child
     //   enables the shuffle. Returns the child node if the last numPartitions is bigger;
@@ -943,7 +955,8 @@ object CollapseRepartition extends Rule[LogicalPlan] {
  * and user not specify.
  */
 object OptimizeRepartition extends Rule[LogicalPlan] {
-  override def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(REPARTITION_OPERATION), ruleId) {
     case r @ RepartitionByExpression(partitionExpressions, _, numPartitions)
       if partitionExpressions.nonEmpty && partitionExpressions.forall(_.foldable) &&
         numPartitions.isEmpty =>
@@ -955,7 +968,8 @@ object OptimizeRepartition extends Rule[LogicalPlan] {
  * Replaces first(col) to nth_value(col, 1) for better performance.
  */
 object OptimizeWindowFunctions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveExpressions {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.resolveExpressionsWithPruning(
+    _.containsPattern(WINDOW_EXPRESSION), ruleId) {
     case we @ WindowExpression(AggregateExpression(first: First, _, _, _, _),
         WindowSpecDefinition(_, orderSpec, frameSpecification: SpecifiedWindowFrame))
         if orderSpec.nonEmpty && frameSpecification.frameType == RowFrame &&
@@ -972,7 +986,8 @@ object OptimizeWindowFunctions extends Rule[LogicalPlan] {
 *   independent and are of the same window function type, collapse into the parent.
  */
 object CollapseWindow extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(WINDOW), ruleId) {
     case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
         if ps1 == ps2 && os1 == os2 && w1.references.intersect(w2.windowOutputSet).isEmpty &&
           we1.nonEmpty && we2.nonEmpty &&
@@ -995,7 +1010,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
     })
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(WINDOW), ruleId) {
     case w1 @ Window(we1, ps1, os1, w2 @ Window(we2, ps2, os2, grandChild))
         if w1.references.intersect(w2.windowOutputSet).isEmpty &&
            w1.expressions.forall(_.deterministic) &&
@@ -1010,7 +1026,8 @@ object TransposeWindow extends Rule[LogicalPlan] {
 * by this [[Generate]] can be removed earlier - before joins and in data sources.
  */
 object InferFiltersFromGenerate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsPattern(GENERATE)) {
     // This rule does not infer filters from foldable expressions to avoid constant filters
     // like 'size([1, 2, 3]) > 0'. These do not show up in child's constraints and
     // then the idempotence will break.
@@ -1060,7 +1077,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
     }
   }
 
-  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan transform {
+  private def inferFilters(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(FILTER, JOIN)) {
     case filter @ Filter(condition, child) =>
       val newFilters = filter.constraints --
         (child.constraints ++ splitConjunctivePredicates(condition))
@@ -1123,7 +1141,8 @@ object InferFiltersFromConstraints extends Rule[LogicalPlan]
  * Combines all adjacent [[Union]] operators into a single [[Union]].
  */
 object CombineUnions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
+    _.containsAnyPattern(UNION, DISTINCT_LIKE), ruleId) {
     case u: Union => flattenUnion(u, false)
     case Distinct(u: Union) => Distinct(flattenUnion(u, true))
     // Only handle distinct-like 'Deduplicate', where the keys == output
@@ -1167,7 +1186,8 @@ object CombineUnions extends Rule[LogicalPlan] {
  * one conjunctive predicate.
  */
 object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(FILTER), ruleId)(applyLocally)
 
   val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     // The query execution/optimization does not guarantee the expressions are evaluated in order.
@@ -1202,7 +1222,8 @@ object CombineFilters extends Rule[LogicalPlan] with PredicateHelper {
  *    function is order irrelevant
  */
 object EliminateSorts extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform applyLocally
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(SORT))(applyLocally)
 
   private val applyLocally: PartialFunction[LogicalPlan, LogicalPlan] = {
     case s @ Sort(orders, _, child) if orders.isEmpty || orders.exists(_.child.foldable) =>
@@ -1221,11 +1242,16 @@ object EliminateSorts extends Rule[LogicalPlan] {
       g.copy(child = recursiveRemoveSort(originChild))
   }
 
-  private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = plan match {
-    case Sort(_, _, child) => recursiveRemoveSort(child)
-    case other if canEliminateSort(other) =>
-      other.withNewChildren(other.children.map(recursiveRemoveSort))
-    case _ => plan
+  private def recursiveRemoveSort(plan: LogicalPlan): LogicalPlan = {
+    if (!plan.containsPattern(SORT)) {
+      return plan
+    }
+    plan match {
+      case Sort(_, _, child) => recursiveRemoveSort(child)
+      case other if canEliminateSort(other) =>
+        other.withNewChildren(other.children.map(recursiveRemoveSort))
+      case _ => plan
+    }
   }
 
   private def canEliminateSort(plan: LogicalPlan): Boolean = plan match {
@@ -1264,7 +1290,8 @@ object EliminateSorts extends Rule[LogicalPlan] {
 * 3) by eliminating the always-true conditions given the constraints on the child's output.
  */
 object PruneFilters extends Rule[LogicalPlan] with PredicateHelper {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(FILTER), ruleId) {
     // If the filter condition always evaluate to true, remove the filter.
     case Filter(Literal(true, BooleanType), child) => child
     // If the filter condition always evaluate to null or false,
@@ -1620,7 +1647,8 @@ object EliminateLimits extends Rule[LogicalPlan] {
     limitExpr.foldable && child.maxRows.exists { _ <= limitExpr.eval().asInstanceOf[Int] }
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformDownWithPruning(
+    _.containsPattern(LIMIT), ruleId) {
     case Limit(l, child) if canEliminate(l, child) =>
       child
     case GlobalLimit(l, child) if canEliminate(l, child) =>
@@ -1667,7 +1695,7 @@ object CheckCartesianProducts extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan =
     if (conf.crossJoinEnabled) {
       plan
-    } else plan transform {
+    } else plan.transformWithPruning(_.containsAnyPattern(INNER_LIKE_JOIN, OUTER_JOIN))  {
       case j @ Join(left, right, Inner | LeftOuter | RightOuter | FullOuter, _, _)
         if isCartesianProduct(j) =>
           throw new AnalysisException(
@@ -1695,8 +1723,10 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   /** Maximum number of decimal digits representable precisely in a Double */
   private val MAX_DOUBLE_DIGITS = 15
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case q: LogicalPlan => q transformExpressionsDown {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsAnyPattern(SUM, AVERAGE), ruleId) {
+    case q: LogicalPlan => q.transformExpressionsDownWithPruning(
+      _.containsAnyPattern(SUM, AVERAGE), ruleId) {
       case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _, _), _) => af match {
         case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
           MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
@@ -1732,7 +1762,8 @@ object DecimalAggregates extends Rule[LogicalPlan] {
  * another `LocalRelation`.
  */
 object ConvertToLocalRelation extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(LOCAL_RELATION), ruleId) {
     case Project(projectList, LocalRelation(output, data, isStreaming))
         if !projectList.exists(hasUnevaluableExpr) =>
       val projection = new InterpretedMutableProjection(projectList, output)
@@ -1761,7 +1792,8 @@ object ConvertToLocalRelation extends Rule[LogicalPlan] {
  * }}}
  */
 object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(DISTINCT_LIKE), ruleId) {
     case Distinct(child) => Aggregate(child.output, child.output, child)
   }
 }
@@ -1805,7 +1837,8 @@ object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
  *    join conditions will be incorrect.
  */
 object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(INTERSECT), ruleId) {
     case Intersect(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
@@ -1826,7 +1859,8 @@ object ReplaceIntersectWithSemiJoin extends Rule[LogicalPlan] {
  *    join conditions will be incorrect.
  */
 object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(EXCEPT), ruleId) {
     case Except(left, right, false) =>
       assert(left.output.size == right.output.size)
       val joinCond = left.output.zip(right.output).map { case (l, r) => EqualNullSafe(l, r) }
@@ -1866,7 +1900,8 @@ object ReplaceExceptWithAntiJoin extends Rule[LogicalPlan] {
  */
 
 object RewriteExceptAll extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(EXCEPT), ruleId) {
     case Except(left, right, true) =>
       assert(left.output.size == right.output.size)
 
@@ -1923,7 +1958,8 @@ object RewriteExceptAll extends Rule[LogicalPlan] {
  * }}}
  */
 object RewriteIntersectAll extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(INTERSECT), ruleId) {
     case Intersect(left, right, true) =>
       assert(left.output.size == right.output.size)
 
@@ -1975,7 +2011,8 @@ object RewriteIntersectAll extends Rule[LogicalPlan] {
  * but only makes the grouping key bigger.
  */
 object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case a @ Aggregate(grouping, _, _) if grouping.nonEmpty =>
       val newGrouping = grouping.filter(!_.foldable)
       if (newGrouping.nonEmpty) {
@@ -1994,7 +2031,8 @@ object RemoveLiteralFromGroupExpressions extends Rule[LogicalPlan] {
  * but only makes the grouping key bigger.
  */
 object RemoveRepetitionFromGroupExpressions extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
+    _.containsPattern(AGGREGATE), ruleId) {
     case a @ Aggregate(grouping, _, _) if grouping.size > 1 =>
       val newGrouping = ExpressionSet(grouping).toSeq
       if (newGrouping.size == grouping.size) {
@@ -2014,7 +2052,8 @@ object OptimizeLimitZero extends Rule[LogicalPlan] {
   private def empty(plan: LogicalPlan) =
     LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+  def apply(plan: LogicalPlan): LogicalPlan = plan.transformUpWithPruning(
+    _.containsAllPatterns(LIMIT, LITERAL)) {
     // Nodes below GlobalLimit or LocalLimit can be pruned if the limit value is zero (0).
     // Any subtree in the logical plan that has GlobalLimit 0 or LocalLimit 0 as its root is
     // semantically equivalent to an empty relation.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index d3c5b51..88a58fd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -127,6 +127,8 @@ case class Generate(
     child: LogicalPlan)
   extends UnaryNode {
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(GENERATE)
+
   lazy val requiredChildOutput: Seq[Attribute] = {
     val unrequiredSet = unrequiredChildIndex.toSet
     child.output.zipWithIndex.filterNot(t => unrequiredSet.contains(t._2)).map(_._1)
@@ -211,6 +213,8 @@ case class Intersect(
 
   override def nodeName: String = getClass.getSimpleName + ( if ( isAll ) "All" else "" )
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(INTERSECT)
+
   override def output: Seq[Attribute] =
     left.output.zip(right.output).map { case (leftAttr, rightAttr) =>
       leftAttr.withNullability(leftAttr.nullable && rightAttr.nullable)
@@ -280,6 +284,8 @@ case class Union(
     }
   }
 
+  final override val nodePatterns: Seq[TreePattern] = Seq(UNION)
+
   /**
    * Note the definition has assumption about how union is implemented physically.
    */
@@ -632,6 +638,7 @@ case class Sort(
   override def output: Seq[Attribute] = child.output
   override def maxRows: Option[Long] = child.maxRows
   override def outputOrdering: Seq[SortOrder] = order
+  final override val nodePatterns: Seq[TreePattern] = Seq(SORT)
   override protected def withNewChildInternal(newChild: LogicalPlan): Sort = copy(child = newChild)
 }
 
@@ -1203,6 +1210,7 @@ case class Sample(
 case class Distinct(child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Distinct =
     copy(child = newChild)
 }
@@ -1215,6 +1223,7 @@ abstract class RepartitionOperation extends UnaryNode {
   def numPartitions: Int
   override final def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(REPARTITION_OPERATION)
   def partitioning: Partitioning
 }
 
@@ -1314,6 +1323,7 @@ case class Deduplicate(
     child: LogicalPlan) extends UnaryNode {
   override def maxRows: Option[Long] = child.maxRows
   override def output: Seq[Attribute] = child.output
+  final override val nodePatterns: Seq[TreePattern] = Seq(DISTINCT_LIKE)
   override protected def withNewChildInternal(newChild: LogicalPlan): Deduplicate =
     copy(child = newChild)
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
index 62f09d0..605b57e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala
@@ -88,38 +88,62 @@ object RuleIdCollection {
       "org.apache.spark.sql.catalyst.analysis.UpdateOuterReferences" ::
       // Catalyst Optimizer rules
       "org.apache.spark.sql.catalyst.optimizer.BooleanSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseProject" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseRepartition" ::
+      "org.apache.spark.sql.catalyst.optimizer.CollapseWindow" ::
+      "org.apache.spark.sql.catalyst.optimizer.ColumnPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineConcats" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.CombineTypedFilters" ::
+      "org.apache.spark.sql.catalyst.optimizer.CombineUnions" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantFolding" ::
       "org.apache.spark.sql.catalyst.optimizer.ConstantPropagation" ::
+      "org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation" ::
       "org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder" ::
+      "org.apache.spark.sql.catalyst.optimizer.DecimalAggregates" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateAggregateFilter" ::
+      "org.apache.spark.sql.catalyst.optimizer.EliminateLimits" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateMapObjects" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateOuterJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.EliminateSerialization" ::
       "org.apache.spark.sql.catalyst.optimizer.LikeSimplification" ::
+      "org.apache.spark.sql.catalyst.optimizer.LimitPushDown" ::
       "org.apache.spark.sql.catalyst.optimizer.LimitPushDownThroughWindow" ::
       "org.apache.spark.sql.catalyst.optimizer.NullPropagation" ::
       "org.apache.spark.sql.catalyst.optimizer.ObjectSerializerPruning" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeCsvJsonExprs" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeIn" ::
       "org.apache.spark.sql.catalyst.optimizer.Optimizer$OptimizeSubqueries" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeRepartition" ::
+      "org.apache.spark.sql.catalyst.optimizer.OptimizeWindowFunctions" ::
       "org.apache.spark.sql.catalyst.optimizer.OptimizeUpdateFields"::
       "org.apache.spark.sql.catalyst.optimizer.PropagateEmptyRelation" ::
+      "org.apache.spark.sql.catalyst.optimizer.PruneFilters" ::
       "org.apache.spark.sql.catalyst.optimizer.PushDownLeftSemiAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.PushExtraPredicateThroughJoin" 
::
       "org.apache.spark.sql.catalyst.optimizer.PushFoldableIntoBranches" ::
       
"org.apache.spark.sql.catalyst.optimizer.PushLeftSemiLeftAntiThroughJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.ReassignLambdaVariableID" ::
       "org.apache.spark.sql.catalyst.optimizer.RemoveDispensableExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveLiteralFromGroupExpressions" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveNoopOperators" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveRedundantAggregates" ::
+      "org.apache.spark.sql.catalyst.optimizer.RemoveRepetitionFromGroupExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderAssociativeOperator" ::
       "org.apache.spark.sql.catalyst.optimizer.ReorderJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithAntiJoin" ::
       "org.apache.spark.sql.catalyst.optimizer.ReplaceExceptWithFilter" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceDistinctWithAggregate" ::
       "org.apache.spark.sql.catalyst.optimizer.ReplaceNullWithFalseInPredicate" ::
+      "org.apache.spark.sql.catalyst.optimizer.ReplaceIntersectWithSemiJoin" ::
+      "org.apache.spark.sql.catalyst.optimizer.RewriteExceptAll" ::
+      "org.apache.spark.sql.catalyst.optimizer.RewriteIntersectAll" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyBinaryComparison" ::
       
"org.apache.spark.sql.catalyst.optimizer.SimplifyCaseConversionExpressions" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyCasts" ::
       "org.apache.spark.sql.catalyst.optimizer.SimplifyConditionals" ::
       
"org.apache.spark.sql.catalyst.optimizer.SimplifyConditionalsInPredicate" ::
+      "org.apache.spark.sql.catalyst.optimizer.TransposeWindow" ::
       "org.apache.spark.sql.catalyst.optimizer.UnwrapCastInBinaryComparison" 
::  Nil
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
index d1ba832..40ef7cb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala
@@ -23,9 +23,11 @@ object TreePattern extends Enumeration  {
 
   // Enum Ids start from 0.
   // Expression patterns (alphabetically ordered)
-  val AND_OR: Value = Value(0)
+  val ALIAS: Value = Value(0)
+  val AND_OR: Value = Value
   val ATTRIBUTE_REFERENCE: Value = Value
   val APPEND_COLUMNS: Value = Value
+  val AVERAGE: Value = Value
   val BINARY_ARITHMETIC: Value = Value
   val BINARY_COMPARISON: Value = Value
   val BOOL_AGG: Value = Value
@@ -41,10 +43,12 @@ object TreePattern extends Enumeration  {
   val EXISTS_SUBQUERY = Value
   val EXPRESSION_WITH_RANDOM_SEED: Value = Value
   val EXTRACT_VALUE: Value = Value
+  val GENERATE: Value = Value
   val IF: Value = Value
   val IN: Value = Value
   val IN_SUBQUERY: Value = Value
   val INSET: Value = Value
+  val INTERSECT: Value = Value
   val JSON_TO_STRUCT: Value = Value
   val LAMBDA_VARIABLE: Value = Value
   val LIKE_FAMLIY: Value = Value
@@ -59,6 +63,8 @@ object TreePattern extends Enumeration  {
   val PLAN_EXPRESSION: Value = Value
   val RUNTIME_REPLACEABLE: Value = Value
   val SCALAR_SUBQUERY: Value = Value
+  val SORT: Value = Value
+  val SUM: Value = Value
   val TRUE_OR_FALSE_LITERAL: Value = Value
   val WINDOW_EXPRESSION: Value = Value
   val UNARY_POSITIVE: Value = Value
@@ -66,6 +72,7 @@ object TreePattern extends Enumeration  {
 
   // Logical plan patterns (alphabetically ordered)
   val AGGREGATE: Value = Value
+  val DISTINCT_LIKE: Value = Value
   val EXCEPT: Value = Value
   val FILTER: Value = Value
   val INNER_LIKE_JOIN: Value = Value
@@ -76,6 +83,8 @@ object TreePattern extends Enumeration  {
   val NATURAL_LIKE_JOIN: Value = Value
   val OUTER_JOIN: Value = Value
   val PROJECT: Value = Value
+  val REPARTITION_OPERATION: Value = Value
+  val UNION: Value = Value
   val TYPED_FILTER: Value = Value
   val WINDOW: Value = Value
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
