Repository: spark
Updated Branches:
  refs/heads/master f41c0a93f -> 18b75d465


[SPARK-22719][SQL] Refactor ConstantPropagation

## What changes were proposed in this pull request?

The current time complexity of ConstantPropagation is O(n^2), which can be slow 
when the query is complex.
This patch refactors the implementation to run in O(n) time and adds some 
pruning to avoid traversing the whole `Condition`.

## How was this patch tested?

Unit test.

Also ran a simple benchmark test in ConstantPropagationSuite:
```
  val condition = (1 to 500).map{_ => Rand(0) === Rand(0)}.reduce(And)
  val query = testRelation
    .select(columnA)
    .where(condition)
  val start = System.currentTimeMillis()
  (1 to 40).foreach { _ =>
    Optimize.execute(query.analyze)
  }
  val end = System.currentTimeMillis()
  println(end - start)
```
Run time before changes: 18989ms (474ms per loop)
Run time after changes: 1275ms (32ms per loop)

Author: Wang Gengliang <ltn...@gmail.com>

Closes #19912 from gengliangwang/ConstantPropagation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/18b75d46
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/18b75d46
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/18b75d46

Branch: refs/heads/master
Commit: 18b75d465b7563de926c5690094086a72a75c09f
Parents: f41c0a9
Author: Wang Gengliang <ltn...@gmail.com>
Authored: Thu Dec 7 10:24:49 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Dec 7 10:24:49 2017 -0800

----------------------------------------------------------------------
 .../sql/catalyst/optimizer/expressions.scala    | 106 +++++++++++++------
 1 file changed, 73 insertions(+), 33 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/18b75d46/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
index 785e815..6305b6c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala
@@ -64,49 +64,89 @@ object ConstantFolding extends Rule[LogicalPlan] {
  * }}}
  *
  * Approach used:
- * - Start from AND operator as the root
- * - Get all the children conjunctive predicates which are EqualTo / 
EqualNullSafe such that they
- *   don't have a `NOT` or `OR` operator in them
  * - Populate a mapping of attribute => constant value by looking at all the 
equals predicates
  * - Using this mapping, replace occurrence of the attributes with the 
corresponding constant values
  *   in the AND node.
  */
 object ConstantPropagation extends Rule[LogicalPlan] with PredicateHelper {
-  private def containsNonConjunctionPredicates(expression: Expression): 
Boolean = expression.find {
-    case _: Not | _: Or => true
-    case _ => false
-  }.isDefined
-
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case f: Filter => f transformExpressionsUp {
-      case and: And =>
-        val conjunctivePredicates =
-          splitConjunctivePredicates(and)
-            .filter(expr => expr.isInstanceOf[EqualTo] || 
expr.isInstanceOf[EqualNullSafe])
-            .filterNot(expr => containsNonConjunctionPredicates(expr))
-
-        val equalityPredicates = conjunctivePredicates.collect {
-          case e @ EqualTo(left: AttributeReference, right: Literal) => 
((left, right), e)
-          case e @ EqualTo(left: Literal, right: AttributeReference) => 
((right, left), e)
-          case e @ EqualNullSafe(left: AttributeReference, right: Literal) => 
((left, right), e)
-          case e @ EqualNullSafe(left: Literal, right: AttributeReference) => 
((right, left), e)
-        }
+    case f: Filter =>
+      val (newCondition, _) = traverse(f.condition, replaceChildren = true)
+      if (newCondition.isDefined) {
+        f.copy(condition = newCondition.get)
+      } else {
+        f
+      }
+  }
 
-        val constantsMap = AttributeMap(equalityPredicates.map(_._1))
-        val predicates = equalityPredicates.map(_._2).toSet
+  type EqualityPredicates = Seq[((AttributeReference, Literal), 
BinaryComparison)]
 
-        def replaceConstants(expression: Expression) = expression transform {
-          case a: AttributeReference =>
-            constantsMap.get(a) match {
-              case Some(literal) => literal
-              case None => a
-            }
+  /**
+   * Traverse a condition as a tree and replace attributes with constant 
values.
+   * - On matching [[And]], recursively traverse each child and get 
propagated mappings.
+   *   If the current node is not a child of another [[And]], replace all 
occurrences of the
+   *   attributes with the corresponding constant values.
+   * - If a child of [[And]] is [[EqualTo]] or [[EqualNullSafe]], propagate 
the mapping
+   *   of attribute => constant.
+   * - On matching [[Or]] or [[Not]], recursively traverse each child and 
propagate an empty mapping.
+   * - Otherwise, stop traversal and propagate an empty mapping.
+   * @param condition condition to be traversed
+   * @param replaceChildren whether to replace attributes with constant values 
in children
+   * @return A tuple including:
+   *         1. Option[Expression]: optional changed condition after traversal
+   *         2. EqualityPredicates: propagated mapping of attribute => constant
+   */
+  private def traverse(condition: Expression, replaceChildren: Boolean)
+    : (Option[Expression], EqualityPredicates) =
+    condition match {
+      case e @ EqualTo(left: AttributeReference, right: Literal) => (None, 
Seq(((left, right), e)))
+      case e @ EqualTo(left: Literal, right: AttributeReference) => (None, 
Seq(((right, left), e)))
+      case e @ EqualNullSafe(left: AttributeReference, right: Literal) =>
+        (None, Seq(((left, right), e)))
+      case e @ EqualNullSafe(left: Literal, right: AttributeReference) =>
+        (None, Seq(((right, left), e)))
+      case a: And =>
+        val (newLeft, equalityPredicatesLeft) = traverse(a.left, 
replaceChildren = false)
+        val (newRight, equalityPredicatesRight) = traverse(a.right, 
replaceChildren = false)
+        val equalityPredicates = equalityPredicatesLeft ++ 
equalityPredicatesRight
+        val newSelf = if (equalityPredicates.nonEmpty && replaceChildren) {
+          Some(And(replaceConstants(newLeft.getOrElse(a.left), 
equalityPredicates),
+            replaceConstants(newRight.getOrElse(a.right), equalityPredicates)))
+        } else {
+          if (newLeft.isDefined || newRight.isDefined) {
+            Some(And(newLeft.getOrElse(a.left), newRight.getOrElse(a.right)))
+          } else {
+            None
+          }
         }
-
-        and transform {
-          case e @ EqualTo(_, _) if !predicates.contains(e) => 
replaceConstants(e)
-          case e @ EqualNullSafe(_, _) if !predicates.contains(e) => 
replaceConstants(e)
+        (newSelf, equalityPredicates)
+      case o: Or =>
+        // Ignore the EqualityPredicates from children since they are only 
propagated through And.
+        val (newLeft, _) = traverse(o.left, replaceChildren = true)
+        val (newRight, _) = traverse(o.right, replaceChildren = true)
+        val newSelf = if (newLeft.isDefined || newRight.isDefined) {
+          Some(Or(left = newLeft.getOrElse(o.left), right = 
newRight.getOrElse((o.right))))
+        } else {
+          None
         }
+        (newSelf, Seq.empty)
+      case n: Not =>
+        // Ignore the EqualityPredicates from children since they are only 
propagated through And.
+        val (newChild, _) = traverse(n.child, replaceChildren = true)
+        (newChild.map(Not), Seq.empty)
+      case _ => (None, Seq.empty)
+    }
+
+  private def replaceConstants(condition: Expression, equalityPredicates: 
EqualityPredicates)
+    : Expression = {
+    val constantsMap = AttributeMap(equalityPredicates.map(_._1))
+    val predicates = equalityPredicates.map(_._2).toSet
+    def replaceConstants0(expression: Expression) = expression transform {
+      case a: AttributeReference => constantsMap.getOrElse(a, a)
+    }
+    condition transform {
+      case e @ EqualTo(_, _) if !predicates.contains(e) => replaceConstants0(e)
+      case e @ EqualNullSafe(_, _) if !predicates.contains(e) => 
replaceConstants0(e)
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to