Repository: spark
Updated Branches:
  refs/heads/master d4f0b1d2c -> 224375c55


[SPARK-22892][SQL] Simplify some estimation logic by using double instead of decimal

## What changes were proposed in this pull request?

Simplify some estimation logic by using double instead of decimal. Estimated selectivities are approximate fractions in `[0, 1]`, so `Double` precision is sufficient, and the arithmetic is simpler than with `Decimal`/`BigDecimal`.
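
A minimal sketch of why `Double` suffices (hypothetical helper names, not part of this patch): selectivities are fractions combined with plain arithmetic, so arbitrary-precision decimals add cost without benefit here.

```scala
// Mirrors the And/Or cases in FilterEstimation.calculateFilterSelectivity.
def andSelectivity(p1: Double, p2: Double): Double = p1 * p2
def orSelectivity(p1: Double, p2: Double): Double = p1 + p2 - p1 * p2

println(andSelectivity(0.3, 0.5)) // 0.15
println(orSelectivity(0.3, 0.5))  // ~0.65
```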

## How was this patch tested?

Existing tests.

Author: Zhenhua Wang <wangzhen...@huawei.com>

Closes #20062 from wzhfy/simplify_by_double.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/224375c5
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/224375c5
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/224375c5

Branch: refs/heads/master
Commit: 224375c55ff4c832dafbb87c55f2971e6d8994f2
Parents: d4f0b1d
Author: Zhenhua Wang <wangzhen...@huawei.com>
Authored: Fri Dec 29 15:39:56 2017 +0800
Committer: Wenchen Fan <wenc...@databricks.com>
Committed: Fri Dec 29 15:39:56 2017 +0800

----------------------------------------------------------------------
 .../statsEstimation/EstimationUtils.scala       | 30 +++----
 .../statsEstimation/FilterEstimation.scala      | 84 +++++++++-----------
 .../logical/statsEstimation/ValueInterval.scala | 14 ++--
 3 files changed, 59 insertions(+), 69 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 71e852a..d793f77 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -89,29 +89,29 @@ object EstimationUtils {
   }
 
   /**
-   * For simplicity we use Decimal to unify operations for data types whose min/max values can be
+   * For simplicity we use Double to unify operations for data types whose min/max values can be
    * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true).
    * The two methods below are the contract of conversion.
    */
-  def toDecimal(value: Any, dataType: DataType): Decimal = {
+  def toDouble(value: Any, dataType: DataType): Double = {
     dataType match {
-      case _: NumericType | DateType | TimestampType => Decimal(value.toString)
-      case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0)
+      case _: NumericType | DateType | TimestampType => value.toString.toDouble
+      case BooleanType => if (value.asInstanceOf[Boolean]) 1 else 0
     }
   }
 
-  def fromDecimal(dec: Decimal, dataType: DataType): Any = {
+  def fromDouble(double: Double, dataType: DataType): Any = {
     dataType match {
-      case BooleanType => dec.toLong == 1
-      case DateType => dec.toInt
-      case TimestampType => dec.toLong
-      case ByteType => dec.toByte
-      case ShortType => dec.toShort
-      case IntegerType => dec.toInt
-      case LongType => dec.toLong
-      case FloatType => dec.toFloat
-      case DoubleType => dec.toDouble
-      case _: DecimalType => dec
+      case BooleanType => double.toInt == 1
+      case DateType => double.toInt
+      case TimestampType => double.toLong
+      case ByteType => double.toByte
+      case ShortType => double.toShort
+      case IntegerType => double.toInt
+      case LongType => double.toLong
+      case FloatType => double.toFloat
+      case DoubleType => double
+      case _: DecimalType => Decimal(double)
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 850dd1b..4cc32de 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -31,7 +31,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
 
   private val childStats = plan.child.stats
 
-  private val colStatsMap = new ColumnStatsMap(childStats.attributeStats)
+  private val colStatsMap = ColumnStatsMap(childStats.attributeStats)
 
   /**
    * Returns an option of Statistics for a Filter logical plan node.
@@ -47,7 +47,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
 
    // Estimate selectivity of this filter predicate, and update column stats if needed.
    // For not-supported condition, set filter selectivity to a conservative estimate 100%
-    val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1))
+    val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(1.0)

    val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity)
     val newColStats = if (filteredRowCount == 0) {
@@ -79,17 +79,16 @@ case class FilterEstimation(plan: Filter) extends Logging {
   * @return an optional double value to show the percentage of rows meeting a given condition.
    *         It returns None if the condition is not supported.
    */
-  def calculateFilterSelectivity(condition: Expression, update: Boolean = true)
-    : Option[BigDecimal] = {
+  def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = {
     condition match {
       case And(cond1, cond2) =>
-        val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1))
-        val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1))
+        val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0)
+        val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0)
         Some(percent1 * percent2)
 
       case Or(cond1, cond2) =>
-        val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1))
-        val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1))
+        val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0)
+        val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0)
         Some(percent1 + percent2 - (percent1 * percent2))
 
       // Not-operator pushdown
@@ -131,7 +130,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
   * @return an optional double value to show the percentage of rows meeting a given condition.
    *         It returns None if the condition is not supported.
    */
-  def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = {
+  def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = {
     condition match {
       case l: Literal =>
         evaluateLiteral(l)
@@ -225,17 +224,17 @@ case class FilterEstimation(plan: Filter) extends Logging {
   def evaluateNullCheck(
       attr: Attribute,
       isNull: Boolean,
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
     if (!colStatsMap.contains(attr)) {
       logDebug("[CBO] No statistics for " + attr)
       return None
     }
     val colStat = colStatsMap(attr)
     val rowCountValue = childStats.rowCount.get
-    val nullPercent: BigDecimal = if (rowCountValue == 0) {
+    val nullPercent: Double = if (rowCountValue == 0) {
       0
     } else {
-      BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)
+      (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble
     }
 
     if (update) {
@@ -271,7 +270,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
       op: BinaryComparison,
       attr: Attribute,
       literal: Literal,
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
     if (!colStatsMap.contains(attr)) {
       logDebug("[CBO] No statistics for " + attr)
       return None
@@ -305,13 +304,12 @@ case class FilterEstimation(plan: Filter) extends Logging {
   def evaluateEquality(
       attr: Attribute,
       literal: Literal,
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
     if (!colStatsMap.contains(attr)) {
       logDebug("[CBO] No statistics for " + attr)
       return None
     }
     val colStat = colStatsMap(attr)
-    val ndv = colStat.distinctCount
 
     // decide if the value is in [min, max] of the column.
     // We currently don't store min/max for binary/string type.
@@ -334,7 +332,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
 
       if (colStat.histogram.isEmpty) {
         // returns 1/ndv if there is no histogram
-        Some(1.0 / BigDecimal(ndv))
+        Some(1.0 / colStat.distinctCount.toDouble)
       } else {
         Some(computeEqualityPossibilityByHistogram(literal, colStat))
       }
@@ -354,7 +352,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
    * @param literal a literal value (or constant)
   * @return an optional double value to show the percentage of rows meeting a given condition
    */
-  def evaluateLiteral(literal: Literal): Option[BigDecimal] = {
+  def evaluateLiteral(literal: Literal): Option[Double] = {
     literal match {
       case Literal(null, _) => Some(0.0)
       case FalseLiteral => Some(0.0)
@@ -379,7 +377,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
   def evaluateInSet(
       attr: Attribute,
       hSet: Set[Any],
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
     if (!colStatsMap.contains(attr)) {
       logDebug("[CBO] No statistics for " + attr)
       return None
@@ -403,8 +401,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
           return Some(0.0)
         }
 
-        val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType))
-        val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType))
+        val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType))
+        val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType))
        // newNdv should not be greater than the old ndv.  For example, column has only 2 values
        // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5.
         newNdv = ndv.min(BigInt(validQuerySet.size))
@@ -425,7 +423,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
 
    // return the filter selectivity.  Without advanced statistics such as histograms,
     // we have to assume uniform distribution.
-    Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0))
+    Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0))
   }
 
   /**
@@ -443,21 +441,17 @@ case class FilterEstimation(plan: Filter) extends Logging {
       op: BinaryComparison,
       attr: Attribute,
       literal: Literal,
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
 
     val colStat = colStatsMap(attr)
     val statsInterval =
      ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval]
-    val max = statsInterval.max.toBigDecimal
-    val min = statsInterval.min.toBigDecimal
-    val ndv = BigDecimal(colStat.distinctCount)
+    val max = statsInterval.max
+    val min = statsInterval.min
+    val ndv = colStat.distinctCount.toDouble
 
    // determine the overlapping degree between predicate interval and column's interval
-    val numericLiteral = if (literal.dataType == BooleanType) {
-      if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0)
-    } else {
-      BigDecimal(literal.value.toString)
-    }
+    val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType)
     val (noOverlap: Boolean, completeOverlap: Boolean) = op match {
       case _: LessThan =>
         (numericLiteral <= min, numericLiteral > max)
@@ -469,7 +463,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
         (numericLiteral > max, numericLiteral <= min)
     }
 
-    var percent = BigDecimal(1)
+    var percent = 1.0
     if (noOverlap) {
       percent = 0.0
     } else if (completeOverlap) {
@@ -518,8 +512,6 @@ case class FilterEstimation(plan: Filter) extends Logging {
         val newValue = Some(literal.value)
         var newMax = colStat.max
         var newMin = colStat.min
-        var newNdv = ceil(ndv * percent)
-        if (newNdv < 1) newNdv = 1
 
         op match {
           case _: GreaterThan | _: GreaterThanOrEqual =>
@@ -528,8 +520,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
             newMax = newValue
         }
 
-        val newStats =
-          colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0)
+        val newStats = colStat.copy(distinctCount = ceil(ndv * percent),
+          min = newMin, max = newMax, nullCount = 0)
 
         colStatsMap.update(attr, newStats)
       }
@@ -543,13 +535,13 @@ case class FilterEstimation(plan: Filter) extends Logging {
    */
   private def computeEqualityPossibilityByHistogram(
       literal: Literal, colStat: ColumnStat): Double = {
-    val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble
+    val datum = EstimationUtils.toDouble(literal.value, literal.dataType)
     val histogram = colStat.histogram.get
 
    // find bins where column's current min and max locate.  Note that a column's [min, max]
     // range may change due to another condition applied earlier.
-    val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble
-    val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble
+    val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType)
+    val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType)
 
    // compute how many bins the column's current valid range [min, max] occupies.
     val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange(
@@ -574,13 +566,13 @@ case class FilterEstimation(plan: Filter) extends Logging {
    */
   private def computeComparisonPossibilityByHistogram(
       op: BinaryComparison, literal: Literal, colStat: ColumnStat): Double = {
-    val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble
+    val datum = EstimationUtils.toDouble(literal.value, literal.dataType)
     val histogram = colStat.histogram.get
 
    // find bins where column's current min and max locate.  Note that a column's [min, max]
     // range may change due to another condition applied earlier.
-    val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble
-    val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble
+    val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType)
+    val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType)
 
    // compute how many bins the column's current valid range [min, max] occupies.
     val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange(
@@ -643,7 +635,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
       op: BinaryComparison,
       attrLeft: Attribute,
       attrRight: Attribute,
-      update: Boolean): Option[BigDecimal] = {
+      update: Boolean): Option[Double] = {
 
     if (!colStatsMap.contains(attrLeft)) {
       logDebug("[CBO] No statistics for " + attrLeft)
@@ -726,7 +718,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
         )
     }
 
-    var percent = BigDecimal(1)
+    var percent = 1.0
     if (noOverlap) {
       percent = 0.0
     } else if (completeOverlap) {
@@ -740,11 +732,9 @@ case class FilterEstimation(plan: Filter) extends Logging {
         // Need to adjust new min/max after the filter condition is applied
 
         val ndvLeft = BigDecimal(colStatLeft.distinctCount)
-        var newNdvLeft = ceil(ndvLeft * percent)
-        if (newNdvLeft < 1) newNdvLeft = 1
+        val newNdvLeft = ceil(ndvLeft * percent)
         val ndvRight = BigDecimal(colStatRight.distinctCount)
-        var newNdvRight = ceil(ndvRight * percent)
-        if (newNdvRight < 1) newNdvRight = 1
+        val newNdvRight = ceil(ndvRight * percent)
 
         var newMaxLeft = colStatLeft.max
         var newMinLeft = colStatLeft.min
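
To make the interval arithmetic concrete, a hedged standalone sketch of the overlap test for a "col < literal" predicate: the first two branches mirror the LessThan case in the hunks above, while the partial-overlap interpolation is an assumption (that branch falls outside the lines shown here).

  // No overlap: nothing qualifies.  Complete overlap: everything qualifies.
  // Otherwise interpolate within [min, max], assuming a uniform distribution.
  def lessThanSelectivity(lit: Double, min: Double, max: Double): Double =
    if (lit <= min) 0.0
    else if (lit > max) 1.0
    else (lit - min) / (max - min)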

http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
index 0caaf79..f46b4ed 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala
@@ -26,10 +26,10 @@ trait ValueInterval {
   def contains(l: Literal): Boolean
 }
 
-/** For simplicity we use decimal to unify operations of numeric intervals. */
-case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval {
+/** For simplicity we use double to unify operations of numeric intervals. */
+case class NumericValueInterval(min: Double, max: Double) extends ValueInterval {
   override def contains(l: Literal): Boolean = {
-    val lit = EstimationUtils.toDecimal(l.value, l.dataType)
+    val lit = EstimationUtils.toDouble(l.value, l.dataType)
     min <= lit && max >= lit
   }
 }
@@ -56,8 +56,8 @@ object ValueInterval {
     case _ if min.isEmpty || max.isEmpty => new NullValueInterval()
     case _ =>
       NumericValueInterval(
-        min = EstimationUtils.toDecimal(min.get, dataType),
-        max = EstimationUtils.toDecimal(max.get, dataType))
+        min = EstimationUtils.toDouble(min.get, dataType),
+        max = EstimationUtils.toDouble(max.get, dataType))
   }
 
  def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match {
@@ -84,8 +84,8 @@ object ValueInterval {
        // Choose the maximum of two min values, and the minimum of two max values.
         val newMin = if (n1.min <= n2.min) n2.min else n1.min
         val newMax = if (n1.max <= n2.max) n1.max else n2.max
-        (Some(EstimationUtils.fromDecimal(newMin, dt)),
-          Some(EstimationUtils.fromDecimal(newMax, dt)))
+        (Some(EstimationUtils.fromDouble(newMin, dt)),
+          Some(EstimationUtils.fromDouble(newMax, dt)))
     }
   }
 }
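
For reference, the intersection rule in the last hunk reduces to this one-liner (a standalone sketch with hypothetical names):

  // Keep the larger of the two minimums and the smaller of the two maximums.
  def intersect(min1: Double, max1: Double,
                min2: Double, max2: Double): (Double, Double) =
    (math.max(min1, min2), math.min(max1, max2))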

