Repository: spark Updated Branches: refs/heads/master d4f0b1d2c -> 224375c55
[SPARK-22892][SQL] Simplify some estimation logic by using double instead of decimal ## What changes were proposed in this pull request? Simplify some estimation logic by using double instead of decimal. ## How was this patch tested? Existing tests. Author: Zhenhua Wang <wangzhen...@huawei.com> Closes #20062 from wzhfy/simplify_by_double. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/224375c5 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/224375c5 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/224375c5 Branch: refs/heads/master Commit: 224375c55ff4c832dafbb87c55f2971e6d8994f2 Parents: d4f0b1d Author: Zhenhua Wang <wangzhen...@huawei.com> Authored: Fri Dec 29 15:39:56 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Dec 29 15:39:56 2017 +0800 ---------------------------------------------------------------------- .../statsEstimation/EstimationUtils.scala | 30 +++---- .../statsEstimation/FilterEstimation.scala | 84 +++++++++----------- .../logical/statsEstimation/ValueInterval.scala | 14 ++-- 3 files changed, 59 insertions(+), 69 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 71e852a..d793f77 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -89,29 +89,29 @@ object EstimationUtils { } /** - * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * For simplicity we use Double to unify operations for data types whose min/max values can be * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). * The two methods below are the contract of conversion. */ - def toDecimal(value: Any, dataType: DataType): Decimal = { + def toDouble(value: Any, dataType: DataType): Double = { dataType match { - case _: NumericType | DateType | TimestampType => Decimal(value.toString) - case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + case _: NumericType | DateType | TimestampType => value.toString.toDouble + case BooleanType => if (value.asInstanceOf[Boolean]) 1 else 0 } } - def fromDecimal(dec: Decimal, dataType: DataType): Any = { + def fromDouble(double: Double, dataType: DataType): Any = { dataType match { - case BooleanType => dec.toLong == 1 - case DateType => dec.toInt - case TimestampType => dec.toLong - case ByteType => dec.toByte - case ShortType => dec.toShort - case IntegerType => dec.toInt - case LongType => dec.toLong - case FloatType => dec.toFloat - case DoubleType => dec.toDouble - case _: DecimalType => dec + case BooleanType => double.toInt == 1 + case DateType => double.toInt + case TimestampType => double.toLong + case ByteType => double.toByte + case ShortType => double.toShort + case IntegerType => double.toInt + case LongType => double.toLong + case FloatType => double.toFloat + case DoubleType => double + case _: DecimalType => Decimal(double) } } http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 850dd1b..4cc32de 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -31,7 +31,7 @@ case class FilterEstimation(plan: Filter) extends Logging { private val childStats = plan.child.stats - private val colStatsMap = new ColumnStatsMap(childStats.attributeStats) + private val colStatsMap = ColumnStatsMap(childStats.attributeStats) /** * Returns an option of Statistics for a Filter logical plan node. @@ -47,7 +47,7 @@ case class FilterEstimation(plan: Filter) extends Logging { // Estimate selectivity of this filter predicate, and update column stats if needed. // For not-supported condition, set filter selectivity to a conservative estimate 100% - val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(BigDecimal(1)) + val filterSelectivity = calculateFilterSelectivity(plan.condition).getOrElse(1.0) val filteredRowCount: BigInt = ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) val newColStats = if (filteredRowCount == 0) { @@ -79,17 +79,16 @@ case class FilterEstimation(plan: Filter) extends Logging { * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateFilterSelectivity(condition: Expression, update: Boolean = true) - : Option[BigDecimal] = { + def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(BigDecimal(1)) - val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(BigDecimal(1)) + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) Some(percent1 * percent2) case Or(cond1, cond2) => - val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(BigDecimal(1)) - val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(BigDecimal(1)) + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) Some(percent1 + percent2 - (percent1 * percent2)) // Not-operator pushdown @@ -131,7 +130,7 @@ case class FilterEstimation(plan: Filter) extends Logging { * @return an optional double value to show the percentage of rows meeting a given condition. * It returns None if the condition is not supported. */ - def calculateSingleCondition(condition: Expression, update: Boolean): Option[BigDecimal] = { + def calculateSingleCondition(condition: Expression, update: Boolean): Option[Double] = { condition match { case l: Literal => evaluateLiteral(l) @@ -225,17 +224,17 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateNullCheck( attr: Attribute, isNull: Boolean, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) val rowCountValue = childStats.rowCount.get - val nullPercent: BigDecimal = if (rowCountValue == 0) { + val nullPercent: Double = if (rowCountValue == 0) { 0 } else { - BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue) + (BigDecimal(colStat.nullCount) / BigDecimal(rowCountValue)).toDouble } if (update) { @@ -271,7 +270,7 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -305,13 +304,12 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateEquality( attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None } val colStat = colStatsMap(attr) - val ndv = colStat.distinctCount // decide if the value is in [min, max] of the column. // We currently don't store min/max for binary/string type. @@ -334,7 +332,7 @@ case class FilterEstimation(plan: Filter) extends Logging { if (colStat.histogram.isEmpty) { // returns 1/ndv if there is no histogram - Some(1.0 / BigDecimal(ndv)) + Some(1.0 / colStat.distinctCount.toDouble) } else { Some(computeEqualityPossibilityByHistogram(literal, colStat)) } @@ -354,7 +352,7 @@ case class FilterEstimation(plan: Filter) extends Logging { * @param literal a literal value (or constant) * @return an optional double value to show the percentage of rows meeting a given condition */ - def evaluateLiteral(literal: Literal): Option[BigDecimal] = { + def evaluateLiteral(literal: Literal): Option[Double] = { literal match { case Literal(null, _) => Some(0.0) case FalseLiteral => Some(0.0) @@ -379,7 +377,7 @@ case class FilterEstimation(plan: Filter) extends Logging { def evaluateInSet( attr: Attribute, hSet: Set[Any], - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attr)) { logDebug("[CBO] No statistics for " + attr) return None @@ -403,8 +401,8 @@ case class FilterEstimation(plan: Filter) extends Logging { return Some(0.0) } - val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) - val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) + val newMax = validQuerySet.maxBy(EstimationUtils.toDouble(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDouble(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) @@ -425,7 +423,7 @@ case class FilterEstimation(plan: Filter) extends Logging { // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some((BigDecimal(newNdv) / BigDecimal(ndv)).min(1.0)) + Some(math.min(newNdv.toDouble / ndv.toDouble, 1.0)) } /** @@ -443,21 +441,17 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attr: Attribute, literal: Literal, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { val colStat = colStatsMap(attr) val statsInterval = ValueInterval(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericValueInterval] - val max = statsInterval.max.toBigDecimal - val min = statsInterval.min.toBigDecimal - val ndv = BigDecimal(colStat.distinctCount) + val max = statsInterval.max + val min = statsInterval.min + val ndv = colStat.distinctCount.toDouble // determine the overlapping degree between predicate interval and column's interval - val numericLiteral = if (literal.dataType == BooleanType) { - if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) - } else { - BigDecimal(literal.value.toString) - } + val numericLiteral = EstimationUtils.toDouble(literal.value, literal.dataType) val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => (numericLiteral <= min, numericLiteral > max) @@ -469,7 +463,7 @@ case class FilterEstimation(plan: Filter) extends Logging { (numericLiteral > max, numericLiteral <= min) } - var percent = BigDecimal(1) + var percent = 1.0 if (noOverlap) { percent = 0.0 } else if (completeOverlap) { @@ -518,8 +512,6 @@ case class FilterEstimation(plan: Filter) extends Logging { val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min - var newNdv = ceil(ndv * percent) - if (newNdv < 1) newNdv = 1 op match { case _: GreaterThan | _: GreaterThanOrEqual => @@ -528,8 +520,8 @@ case class FilterEstimation(plan: Filter) extends Logging { newMax = newValue } - val newStats = - colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = ceil(ndv * percent), + min = newMin, max = newMax, nullCount = 0) colStatsMap.update(attr, newStats) } @@ -543,13 +535,13 @@ case class FilterEstimation(plan: Filter) extends Logging { */ private def computeEqualityPossibilityByHistogram( literal: Literal, colStat: ColumnStat): Double = { - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val datum = EstimationUtils.toDouble(literal.value, literal.dataType) val histogram = colStat.histogram.get // find bins where column's current min and max locate. Note that a column's [min, max] // range may change due to another condition applied earlier. - val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType) + val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType) // compute how many bins the column's current valid range [min, max] occupies. val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( @@ -574,13 +566,13 @@ case class FilterEstimation(plan: Filter) extends Logging { */ private def computeComparisonPossibilityByHistogram( op: BinaryComparison, literal: Literal, colStat: ColumnStat): Double = { - val datum = EstimationUtils.toDecimal(literal.value, literal.dataType).toDouble + val datum = EstimationUtils.toDouble(literal.value, literal.dataType) val histogram = colStat.histogram.get // find bins where column's current min and max locate. Note that a column's [min, max] // range may change due to another condition applied earlier. - val min = EstimationUtils.toDecimal(colStat.min.get, literal.dataType).toDouble - val max = EstimationUtils.toDecimal(colStat.max.get, literal.dataType).toDouble + val min = EstimationUtils.toDouble(colStat.min.get, literal.dataType) + val max = EstimationUtils.toDouble(colStat.max.get, literal.dataType) // compute how many bins the column's current valid range [min, max] occupies. val numBinsHoldingEntireRange = EstimationUtils.numBinsHoldingRange( @@ -643,7 +635,7 @@ case class FilterEstimation(plan: Filter) extends Logging { op: BinaryComparison, attrLeft: Attribute, attrRight: Attribute, - update: Boolean): Option[BigDecimal] = { + update: Boolean): Option[Double] = { if (!colStatsMap.contains(attrLeft)) { logDebug("[CBO] No statistics for " + attrLeft) @@ -726,7 +718,7 @@ case class FilterEstimation(plan: Filter) extends Logging { ) } - var percent = BigDecimal(1) + var percent = 1.0 if (noOverlap) { percent = 0.0 } else if (completeOverlap) { @@ -740,11 +732,9 @@ case class FilterEstimation(plan: Filter) extends Logging { // Need to adjust new min/max after the filter condition is applied val ndvLeft = BigDecimal(colStatLeft.distinctCount) - var newNdvLeft = ceil(ndvLeft * percent) - if (newNdvLeft < 1) newNdvLeft = 1 + val newNdvLeft = ceil(ndvLeft * percent) val ndvRight = BigDecimal(colStatRight.distinctCount) - var newNdvRight = ceil(ndvRight * percent) - if (newNdvRight < 1) newNdvRight = 1 + val newNdvRight = ceil(ndvRight * percent) var newMaxLeft = colStatLeft.max var newMinLeft = colStatLeft.min http://git-wip-us.apache.org/repos/asf/spark/blob/224375c5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala index 0caaf79..f46b4ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ValueInterval.scala @@ -26,10 +26,10 @@ trait ValueInterval { def contains(l: Literal): Boolean } -/** For simplicity we use decimal to unify operations of numeric intervals. */ -case class NumericValueInterval(min: Decimal, max: Decimal) extends ValueInterval { +/** For simplicity we use double to unify operations of numeric intervals. */ +case class NumericValueInterval(min: Double, max: Double) extends ValueInterval { override def contains(l: Literal): Boolean = { - val lit = EstimationUtils.toDecimal(l.value, l.dataType) + val lit = EstimationUtils.toDouble(l.value, l.dataType) min <= lit && max >= lit } } @@ -56,8 +56,8 @@ object ValueInterval { case _ if min.isEmpty || max.isEmpty => new NullValueInterval() case _ => NumericValueInterval( - min = EstimationUtils.toDecimal(min.get, dataType), - max = EstimationUtils.toDecimal(max.get, dataType)) + min = EstimationUtils.toDouble(min.get, dataType), + max = EstimationUtils.toDouble(max.get, dataType)) } def isIntersected(r1: ValueInterval, r2: ValueInterval): Boolean = (r1, r2) match { @@ -84,8 +84,8 @@ object ValueInterval { // Choose the maximum of two min values, and the minimum of two max values. val newMin = if (n1.min <= n2.min) n2.min else n1.min val newMax = if (n1.max <= n2.max) n1.max else n2.max - (Some(EstimationUtils.fromDecimal(newMin, dt)), - Some(EstimationUtils.fromDecimal(newMax, dt))) + (Some(EstimationUtils.fromDouble(newMin, dt)), + Some(EstimationUtils.fromDouble(newMax, dt))) } } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org