Repository: spark
Updated Branches:
  refs/heads/master c4008480b -> 5c8ef376e


[SPARK-17075][SQL][FOLLOWUP] Add Estimation of Constant Literal

### What changes were proposed in this pull request?
`FalseLiteral` and `TrueLiteral` should have been eliminated by optimizer rule 
`BooleanSimplification`, but null literals might be added by optimizer rule 
`NullPropagation`. For safety, our filter estimation should handle all the 
eligible literal cases.

The optimizer rule `BooleanSimplification` is unable to remove the null literal 
in many cases, for example `a < 0 or null`. Thus, we need to handle null 
literals in filter estimation.

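`Not` can be pushed down below `And` and `Or`. After the pushdown we may see two 
consecutive `Not` operators, which cancel each other out and are removed 
(`Not(Not(cond))` is estimated as `cond`). Because of the limited expression 
support in filter estimation, the only other case we need to handle is 
`Not(null)`: it is estimated the same way as `null`, so the generic 
`1 - selectivity` rule for `Not` is not applied to a null operand and does not 
produce an incorrect estimate. For details, see the truth table below.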

```
not NULL = NULL
NULL or false = NULL
NULL or true = true
NULL or NULL = NULL
NULL and false = false
NULL and true = NULL
NULL and NULL = NULL
```
### How was this patch tested?
Added the test cases.

Author: Xiao Li <gatorsm...@gmail.com>

Closes #17446 from gatorsmile/constantFilterEstimation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5c8ef376
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5c8ef376
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5c8ef376

Branch: refs/heads/master
Commit: 5c8ef376e874497766ba0cc4d97429e33a3d9c61
Parents: c400848
Author: Xiao Li <gatorsm...@gmail.com>
Authored: Wed Mar 29 12:43:22 2017 -0700
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Wed Mar 29 12:43:22 2017 -0700

----------------------------------------------------------------------
 .../statsEstimation/FilterEstimation.scala      | 39 ++++++++-
 .../statsEstimation/FilterEstimationSuite.scala | 87 ++++++++++++++++++++
 2 files changed, 124 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5c8ef376/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index f14df93..b32374c 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -24,6 +24,7 @@ import scala.math.BigDecimal.RoundingMode
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.CatalystConf
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, 
TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, 
LeafNode, Statistics}
 import org.apache.spark.sql.catalyst.util.DateTimeUtils
 import org.apache.spark.sql.types._
@@ -104,12 +105,23 @@ case class FilterEstimation(plan: Filter, catalystConf: 
CatalystConf) extends Lo
         val percent2 = calculateFilterSelectivity(cond2, update = 
false).getOrElse(1.0)
         Some(percent1 + percent2 - (percent1 * percent2))
 
+      // Not-operator pushdown
       case Not(And(cond1, cond2)) =>
         calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false)
 
+      // Not-operator pushdown
       case Not(Or(cond1, cond2)) =>
         calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false)
 
+      // Collapse two consecutive Not operators which could be generated after 
Not-operator pushdown
+      case Not(Not(cond)) =>
+        calculateFilterSelectivity(cond, update = false)
+
+      // The foldable Not has been processed in the ConstantFolding rule
+      // This is a top-down traversal. The Not could be pushed down by the 
above two cases.
+      case Not(l @ Literal(null, _)) =>
+        calculateSingleCondition(l, update = false)
+
       case Not(cond) =>
         calculateFilterSelectivity(cond, update = false) match {
           case Some(percent) => Some(1.0 - percent)
@@ -134,13 +146,16 @@ case class FilterEstimation(plan: Filter, catalystConf: 
CatalystConf) extends Lo
    */
   def calculateSingleCondition(condition: Expression, update: Boolean): 
Option[Double] = {
     condition match {
+      case l: Literal =>
+        evaluateLiteral(l)
+
       // For evaluateBinary method, we assume the literal on the right side of 
an operator.
       // So we will change the order if not.
 
       // EqualTo/EqualNullSafe does not care about the order
-      case op @ Equality(ar: Attribute, l: Literal) =>
+      case Equality(ar: Attribute, l: Literal) =>
         evaluateEquality(ar, l, update)
-      case op @ Equality(l: Literal, ar: Attribute) =>
+      case Equality(l: Literal, ar: Attribute) =>
         evaluateEquality(ar, l, update)
 
       case op @ LessThan(ar: Attribute, l: Literal) =>
@@ -343,6 +358,26 @@ case class FilterEstimation(plan: Filter, catalystConf: 
CatalystConf) extends Lo
   }
 
   /**
+   * Returns a percentage of rows meeting a Literal expression.
+   * This method evaluates all the possible literal cases in Filter.
+   *
+   * FalseLiteral and TrueLiteral should be eliminated by optimizer, but null 
literal might be added
+   * by optimizer rule NullPropagation. For safety, we handle all the cases 
here.
+   *
+   * @param literal a literal value (or constant)
+   * @return an optional double value to show the percentage of rows meeting a 
given condition
+   */
+  def evaluateLiteral(literal: Literal): Option[Double] = {
+    literal match {
+      case Literal(null, _) => Some(0.0)
+      case FalseLiteral => Some(0.0)
+      case TrueLiteral => Some(1.0)
+      // Ideally, we should not hit the following branch
+      case _ => None
+    }
+  }
+
+  /**
    * Returns a percentage of rows meeting "IN" operator expression.
    * This method evaluates the equality predicate for all data types.
    *

http://git-wip-us.apache.org/repos/asf/spark/blob/5c8ef376/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 07abe1e..1966c96 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation
 import java.sql.Date
 
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, 
TrueLiteral}
 import org.apache.spark.sql.catalyst.plans.LeftOuter
 import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, 
Statistics}
 import 
org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._
@@ -76,6 +77,82 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     attrDouble -> colStatDouble,
     attrString -> colStatString))
 
+  test("true") {
+    validateEstimatedStats(
+      Filter(TrueLiteral, childStatsTestPlan(Seq(attrInt), 10L)),
+      Seq(attrInt -> colStatInt),
+      expectedRowCount = 10)
+  }
+
+  test("false") {
+    validateEstimatedStats(
+      Filter(FalseLiteral, childStatsTestPlan(Seq(attrInt), 10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("null") {
+    validateEstimatedStats(
+      Filter(Literal(null, IntegerType), childStatsTestPlan(Seq(attrInt), 
10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("Not(null)") {
+    validateEstimatedStats(
+      Filter(Not(Literal(null, IntegerType)), childStatsTestPlan(Seq(attrInt), 
10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("Not(Not(null))") {
+    validateEstimatedStats(
+      Filter(Not(Not(Literal(null, IntegerType))), 
childStatsTestPlan(Seq(attrInt), 10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("cint < 3 AND null") {
+    val condition = And(LessThan(attrInt, Literal(3)), Literal(null, 
IntegerType))
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("cint < 3 OR null") {
+    val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, 
IntegerType))
+    val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 
10L)).stats(conf)
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Seq(attrInt -> colStatInt),
+      expectedRowCount = 3)
+  }
+
+  test("Not(cint < 3 AND null)") {
+    val condition = Not(And(LessThan(attrInt, Literal(3)), Literal(null, 
IntegerType)))
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Seq(attrInt -> colStatInt),
+      expectedRowCount = 8)
+  }
+
+  test("Not(cint < 3 OR null)") {
+    val condition = Not(Or(LessThan(attrInt, Literal(3)), Literal(null, 
IntegerType)))
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
+  test("Not(cint < 3 AND Not(null))") {
+    val condition = Not(And(LessThan(attrInt, Literal(3)), Not(Literal(null, 
IntegerType))))
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Seq(attrInt -> colStatInt),
+      expectedRowCount = 8)
+  }
+
   test("cint = 2") {
     validateEstimatedStats(
       Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 
10L)),
@@ -163,6 +240,16 @@ class FilterEstimationSuite extends 
StatsEstimationTestBase {
       expectedRowCount = 10)
   }
 
+  test("cint IS NOT NULL && null") {
+    // 'cint < null' will be optimized to 'cint IS NOT NULL && null'.
+    // More similar cases can be found in the Optimizer NullPropagation.
+    val condition = And(IsNotNull(attrInt), Literal(null, IntegerType))
+    validateEstimatedStats(
+      Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)),
+      Nil,
+      expectedRowCount = 0)
+  }
+
   test("cint > 3 AND cint <= 6") {
     val condition = And(GreaterThan(attrInt, Literal(3)), 
LessThanOrEqual(attrInt, Literal(6)))
     validateEstimatedStats(


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to