Repository: spark Updated Branches: refs/heads/master 7536e2849 -> fb036c441
[SPARK-20318][SQL] Use Catalyst type for min/max in ColumnStat for ease of estimation ## What changes were proposed in this pull request? Currently, when estimating predicates like `col > literal` or `col = literal`, we update the min or max in the column stats based on the literal value. However, the literal value is of Catalyst type (internal type), while min/max are of external type, so every subsequent predicate again needs a type conversion before it can compare against and update the column stats. This is awkward and causes many unnecessary conversions during estimation. To solve this, we use the Catalyst type for min/max in `ColumnStat`. Note that the format persisted in the metastore is still the external type, so statistics stored in the metastore remain consistent. This PR also fixes a bug with the boolean type in `IN` conditions: the previous code computed the new min/max via `BigDecimal(v.toString)`, which cannot parse `true`/`false`. ## How was this patch tested? The changes for ColumnStat are covered by existing tests. For the bug fix, a new test for the boolean type in an `IN` condition (`cbool IN (true)`) is added. Author: wangzhenhua <wangzhen...@huawei.com> Closes #17630 from wzhfy/refactorColumnStat. 
Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fb036c44 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fb036c44 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fb036c44 Branch: refs/heads/master Commit: fb036c4413c2cd4d90880d080f418ec468d6c0fc Parents: 7536e28 Author: wangzhenhua <wangzhen...@huawei.com> Authored: Fri Apr 14 19:16:47 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Apr 14 19:16:47 2017 +0800 ---------------------------------------------------------------------- .../sql/catalyst/plans/logical/Statistics.scala | 95 +++++++++++++------- .../statsEstimation/EstimationUtils.scala | 30 ++++++- .../statsEstimation/FilterEstimation.scala | 68 +++++--------- .../plans/logical/statsEstimation/Range.scala | 70 +++------------ .../statsEstimation/FilterEstimationSuite.scala | 41 +++++---- .../statsEstimation/JoinEstimationSuite.scala | 15 ++-- .../ProjectEstimationSuite.scala | 21 ++--- .../command/AnalyzeColumnCommand.scala | 8 +- .../spark/sql/StatisticsCollectionSuite.scala | 19 ++-- .../spark/sql/hive/HiveExternalCatalog.scala | 4 +- 10 files changed, 189 insertions(+), 182 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala index f24b240..3d4efef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala @@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging import 
org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -74,11 +75,10 @@ case class Statistics( * Statistics collected for a column. * * 1. Supported data types are defined in `ColumnStat.supportsType`. - * 2. The JVM data type stored in min/max is the external data type (used in Row) for the - * corresponding Catalyst data type. For example, for DateType we store java.sql.Date, and for - * TimestampType we store java.sql.Timestamp. - * 3. For integral types, they are all upcasted to longs, i.e. shorts are stored as longs. - * 4. There is no guarantee that the statistics collected are accurate. Approximation algorithms + * 2. The JVM data type stored in min/max is the internal data type for the corresponding + * Catalyst data type. For example, the internal type of DateType is Int, and that the internal + * type of TimestampType is Long. + * 3. There is no guarantee that the statistics collected are accurate. Approximation algorithms * (sketches) might have been used, and the data collected can also be stale. * * @param distinctCount number of distinct values @@ -104,22 +104,43 @@ case class ColumnStat( /** * Returns a map from string to string that can be used to serialize the column stats. * The key is the name of the field (e.g. "distinctCount" or "min"), and the value is the string - * representation for the value. The deserialization side is defined in [[ColumnStat.fromMap]]. + * representation for the value. min/max values are converted to the external data type. For + * example, for DateType we store java.sql.Date, and for TimestampType we store + * java.sql.Timestamp. The deserialization side is defined in [[ColumnStat.fromMap]]. * * As part of the protocol, the returned map always contains a key called "version". 
* In the case min/max values are null (None), they won't appear in the map. */ - def toMap: Map[String, String] = { + def toMap(colName: String, dataType: DataType): Map[String, String] = { val map = new scala.collection.mutable.HashMap[String, String] map.put(ColumnStat.KEY_VERSION, "1") map.put(ColumnStat.KEY_DISTINCT_COUNT, distinctCount.toString) map.put(ColumnStat.KEY_NULL_COUNT, nullCount.toString) map.put(ColumnStat.KEY_AVG_LEN, avgLen.toString) map.put(ColumnStat.KEY_MAX_LEN, maxLen.toString) - min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, v.toString) } - max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, v.toString) } + min.foreach { v => map.put(ColumnStat.KEY_MIN_VALUE, toExternalString(v, colName, dataType)) } + max.foreach { v => map.put(ColumnStat.KEY_MAX_VALUE, toExternalString(v, colName, dataType)) } map.toMap } + + /** + * Converts the given value from Catalyst data type to string representation of external + * data type. + */ + private def toExternalString(v: Any, colName: String, dataType: DataType): String = { + val externalValue = dataType match { + case DateType => DateTimeUtils.toJavaDate(v.asInstanceOf[Int]) + case TimestampType => DateTimeUtils.toJavaTimestamp(v.asInstanceOf[Long]) + case BooleanType | _: IntegralType | FloatType | DoubleType => v + case _: DecimalType => v.asInstanceOf[Decimal].toJavaBigDecimal + // This version of Spark does not use min/max for binary/string types so we ignore it. + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $colName of data type: $dataType.") + } + externalValue.toString + } + } @@ -150,28 +171,15 @@ object ColumnStat extends Logging { * Creates a [[ColumnStat]] object from the given map. This is used to deserialize column stats * from some external storage. The serialization side is defined in [[ColumnStat.toMap]]. 
*/ - def fromMap(table: String, field: StructField, map: Map[String, String]) - : Option[ColumnStat] = { - val str2val: (String => Any) = field.dataType match { - case _: IntegralType => _.toLong - case _: DecimalType => new java.math.BigDecimal(_) - case DoubleType | FloatType => _.toDouble - case BooleanType => _.toBoolean - case DateType => java.sql.Date.valueOf - case TimestampType => java.sql.Timestamp.valueOf - // This version of Spark does not use min/max for binary/string types so we ignore it. - case BinaryType | StringType => _ => null - case _ => - throw new AnalysisException("Column statistics deserialization is not supported for " + - s"column ${field.name} of data type: ${field.dataType}.") - } - + def fromMap(table: String, field: StructField, map: Map[String, String]): Option[ColumnStat] = { try { Some(ColumnStat( distinctCount = BigInt(map(KEY_DISTINCT_COUNT).toLong), // Note that flatMap(Option.apply) turns Option(null) into None. - min = map.get(KEY_MIN_VALUE).map(str2val).flatMap(Option.apply), - max = map.get(KEY_MAX_VALUE).map(str2val).flatMap(Option.apply), + min = map.get(KEY_MIN_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), + max = map.get(KEY_MAX_VALUE) + .map(fromExternalString(_, field.name, field.dataType)).flatMap(Option.apply), nullCount = BigInt(map(KEY_NULL_COUNT).toLong), avgLen = map.getOrElse(KEY_AVG_LEN, field.dataType.defaultSize.toString).toLong, maxLen = map.getOrElse(KEY_MAX_LEN, field.dataType.defaultSize.toString).toLong @@ -184,6 +192,30 @@ object ColumnStat extends Logging { } /** + * Converts from string representation of external data type to the corresponding Catalyst data + * type. 
+ */ + private def fromExternalString(s: String, name: String, dataType: DataType): Any = { + dataType match { + case BooleanType => s.toBoolean + case DateType => DateTimeUtils.fromJavaDate(java.sql.Date.valueOf(s)) + case TimestampType => DateTimeUtils.fromJavaTimestamp(java.sql.Timestamp.valueOf(s)) + case ByteType => s.toByte + case ShortType => s.toShort + case IntegerType => s.toInt + case LongType => s.toLong + case FloatType => s.toFloat + case DoubleType => s.toDouble + case _: DecimalType => Decimal(s) + // This version of Spark does not use min/max for binary/string types so we ignore it. + case BinaryType | StringType => null + case _ => + throw new AnalysisException("Column statistics deserialization is not supported for " + + s"column $name of data type: $dataType.") + } + } + + /** * Constructs an expression to compute column statistics for a given column. * * The expression should create a single struct column with the following schema: @@ -232,11 +264,14 @@ object ColumnStat extends Logging { } /** Convert a struct for column stats (defined in statExprs) into [[ColumnStat]]. 
*/ - def rowToColumnStat(row: Row): ColumnStat = { + def rowToColumnStat(row: Row, attr: Attribute): ColumnStat = { ColumnStat( distinctCount = BigInt(row.getLong(0)), - min = Option(row.get(1)), // for string/binary min/max, get should return null - max = Option(row.get(2)), + // for string/binary min/max, get should return null + min = Option(row.get(1)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), + max = Option(row.get(2)) + .map(v => fromExternalString(v.toString, attr.name, attr.dataType)).flatMap(Option.apply), nullCount = BigInt(row.getLong(3)), avgLen = row.getLong(4), maxLen = row.getLong(5) http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index 5577233..f1aff62 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -22,7 +22,7 @@ import scala.math.BigDecimal.RoundingMode import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StringType} +import org.apache.spark.sql.types.{DecimalType, _} object EstimationUtils { @@ -75,4 +75,32 @@ object EstimationUtils { // (simple computation of statistics returns product of children). 
if (outputRowCount > 0) outputRowCount * sizePerRow else 1 } + + /** + * For simplicity we use Decimal to unify operations for data types whose min/max values can be + * represented as numbers, e.g. Boolean can be represented as 0 (false) or 1 (true). + * The two methods below are the contract of conversion. + */ + def toDecimal(value: Any, dataType: DataType): Decimal = { + dataType match { + case _: NumericType | DateType | TimestampType => Decimal(value.toString) + case BooleanType => if (value.asInstanceOf[Boolean]) Decimal(1) else Decimal(0) + } + } + + def fromDecimal(dec: Decimal, dataType: DataType): Any = { + dataType match { + case BooleanType => dec.toLong == 1 + case DateType => dec.toInt + case TimestampType => dec.toLong + case ByteType => dec.toByte + case ShortType => dec.toShort + case IntegerType => dec.toInt + case LongType => dec.toLong + case FloatType => dec.toFloat + case DoubleType => dec.toDouble + case _: DecimalType => dec + } + } + } http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 7bd8e65..4b6b3b1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -25,7 +25,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, LeafNode, 
Statistics} -import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -302,30 +301,6 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } /** - * For a SQL data type, its internal data type may be different from its external type. - * For DateType, its internal type is Int, and its external data type is Java Date type. - * The min/max values in ColumnStat are saved in their corresponding external type. - * - * @param attrDataType the column data type - * @param litValue the literal value - * @return a BigDecimal value - */ - def convertBoundValue(attrDataType: DataType, litValue: Any): Option[Any] = { - attrDataType match { - case DateType => - Some(DateTimeUtils.toJavaDate(litValue.toString.toInt)) - case TimestampType => - Some(DateTimeUtils.toJavaTimestamp(litValue.toString.toLong)) - case _: DecimalType => - Some(litValue.asInstanceOf[Decimal].toJavaBigDecimal) - case StringType | BinaryType => - None - case _ => - Some(litValue) - } - } - - /** * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * @@ -356,12 +331,16 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val statsRange = Range(colStat.min, colStat.max, attr.dataType) if (statsRange.contains(literal)) { if (update) { - // We update ColumnStat structure after apply this equality predicate. - // Set distinctCount to 1. Set nullCount to 0. - // Need to save new min/max using the external type value of the literal - val newValue = convertBoundValue(attr.dataType, literal.value) - val newStats = colStat.copy(distinctCount = 1, min = newValue, - max = newValue, nullCount = 0) + // We update ColumnStat structure after apply this equality predicate: + // Set distinctCount to 1, nullCount to 0, and min/max values (if exist) to the literal + // value. 
+ val newStats = attr.dataType match { + case StringType | BinaryType => + colStat.copy(distinctCount = 1, nullCount = 0) + case _ => + colStat.copy(distinctCount = 1, min = Some(literal.value), + max = Some(literal.value), nullCount = 0) + } colStatsMap(attr) = newStats } @@ -430,18 +409,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging return Some(0.0) } - // Need to save new min/max using the external type value of the literal - val newMax = convertBoundValue( - attr.dataType, validQuerySet.maxBy(v => BigDecimal(v.toString))) - val newMin = convertBoundValue( - attr.dataType, validQuerySet.minBy(v => BigDecimal(v.toString))) - + val newMax = validQuerySet.maxBy(EstimationUtils.toDecimal(_, dataType)) + val newMin = validQuerySet.minBy(EstimationUtils.toDecimal(_, dataType)) // newNdv should not be greater than the old ndv. For example, column has only 2 values // 1 and 6. The predicate column IN (1, 2, 3, 4, 5). validQuerySet.size is 5. newNdv = ndv.min(BigInt(validQuerySet.size)) if (update) { - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = colStat.copy(distinctCount = newNdv, min = Some(newMin), + max = Some(newMax), nullCount = 0) colStatsMap(attr) = newStats } @@ -478,8 +453,8 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStat = colStatsMap(attr) val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] - val max = BigDecimal(statsRange.max) - val min = BigDecimal(statsRange.min) + val max = statsRange.max.toBigDecimal + val min = statsRange.min.toBigDecimal val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range @@ -540,8 +515,7 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging } if (update) { - // Need to save new min/max using the external type value of the literal - val 
newValue = convertBoundValue(attr.dataType, literal.value) + val newValue = Some(literal.value) var newMax = colStat.max var newMin = colStat.min var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() @@ -606,14 +580,14 @@ case class FilterEstimation(plan: Filter, catalystConf: SQLConf) extends Logging val colStatLeft = colStatsMap(attrLeft) val statsRangeLeft = Range(colStatLeft.min, colStatLeft.max, attrLeft.dataType) .asInstanceOf[NumericRange] - val maxLeft = BigDecimal(statsRangeLeft.max) - val minLeft = BigDecimal(statsRangeLeft.min) + val maxLeft = statsRangeLeft.max + val minLeft = statsRangeLeft.min val colStatRight = colStatsMap(attrRight) val statsRangeRight = Range(colStatRight.min, colStatRight.max, attrRight.dataType) .asInstanceOf[NumericRange] - val maxRight = BigDecimal(statsRangeRight.max) - val minRight = BigDecimal(statsRangeRight.min) + val maxRight = statsRangeRight.max + val minRight = statsRangeRight.min // determine the overlapping degree between predicate range and column's range val allNotNull = (colStatLeft.nullCount == 0) && (colStatRight.nullCount == 0) http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala index 3d13967..4ac5ba5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/Range.scala @@ -17,12 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation -import java.math.{BigDecimal => JDecimal} -import java.sql.{Date, Timestamp} - import 
org.apache.spark.sql.catalyst.expressions.Literal -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types.{BooleanType, DateType, TimestampType, _} +import org.apache.spark.sql.types._ /** Value range of a column. */ @@ -31,13 +27,10 @@ trait Range { } /** For simplicity we use decimal to unify operations of numeric ranges. */ -case class NumericRange(min: JDecimal, max: JDecimal) extends Range { +case class NumericRange(min: Decimal, max: Decimal) extends Range { override def contains(l: Literal): Boolean = { - val decimal = l.dataType match { - case BooleanType => if (l.value.asInstanceOf[Boolean]) new JDecimal(1) else new JDecimal(0) - case _ => new JDecimal(l.value.toString) - } - min.compareTo(decimal) <= 0 && max.compareTo(decimal) >= 0 + val lit = EstimationUtils.toDecimal(l.value, l.dataType) + min <= lit && max >= lit } } @@ -58,7 +51,10 @@ object Range { def apply(min: Option[Any], max: Option[Any], dataType: DataType): Range = dataType match { case StringType | BinaryType => new DefaultRange() case _ if min.isEmpty || max.isEmpty => new NullRange() - case _ => toNumericRange(min.get, max.get, dataType) + case _ => + NumericRange( + min = EstimationUtils.toDecimal(min.get, dataType), + max = EstimationUtils.toDecimal(max.get, dataType)) } def isIntersected(r1: Range, r2: Range): Boolean = (r1, r2) match { @@ -82,51 +78,11 @@ object Range { // binary/string types don't support intersecting. (None, None) case (n1: NumericRange, n2: NumericRange) => - val newRange = NumericRange(n1.min.max(n2.min), n1.max.min(n2.max)) - val (newMin, newMax) = fromNumericRange(newRange, dt) - (Some(newMin), Some(newMax)) + // Choose the maximum of two min values, and the minimum of two max values. 
+ val newMin = if (n1.min <= n2.min) n2.min else n1.min + val newMax = if (n1.max <= n2.max) n1.max else n2.max + (Some(EstimationUtils.fromDecimal(newMin, dt)), + Some(EstimationUtils.fromDecimal(newMax, dt))) } } - - /** - * For simplicity we use decimal to unify operations of numeric types, the two methods below - * are the contract of conversion. - */ - private def toNumericRange(min: Any, max: Any, dataType: DataType): NumericRange = { - dataType match { - case _: NumericType => - NumericRange(new JDecimal(min.toString), new JDecimal(max.toString)) - case BooleanType => - val min1 = if (min.asInstanceOf[Boolean]) 1 else 0 - val max1 = if (max.asInstanceOf[Boolean]) 1 else 0 - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case DateType => - val min1 = DateTimeUtils.fromJavaDate(min.asInstanceOf[Date]) - val max1 = DateTimeUtils.fromJavaDate(max.asInstanceOf[Date]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - case TimestampType => - val min1 = DateTimeUtils.fromJavaTimestamp(min.asInstanceOf[Timestamp]) - val max1 = DateTimeUtils.fromJavaTimestamp(max.asInstanceOf[Timestamp]) - NumericRange(new JDecimal(min1), new JDecimal(max1)) - } - } - - private def fromNumericRange(n: NumericRange, dataType: DataType): (Any, Any) = { - dataType match { - case _: IntegralType => - (n.min.longValue(), n.max.longValue()) - case FloatType | DoubleType => - (n.min.doubleValue(), n.max.doubleValue()) - case _: DecimalType => - (n.min, n.max) - case BooleanType => - (n.min.longValue() == 1, n.max.longValue() == 1) - case DateType => - (DateTimeUtils.toJavaDate(n.min.intValue()), DateTimeUtils.toJavaDate(n.max.intValue())) - case TimestampType => - (DateTimeUtils.toJavaTimestamp(n.min.longValue()), - DateTimeUtils.toJavaTimestamp(n.max.longValue())) - } - } - } http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index cffb0d8..a284478 100755 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.LeftOuter import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Join, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ /** @@ -45,15 +46,15 @@ class FilterEstimationSuite extends StatsEstimationTestBase { nullCount = 0, avgLen = 1, maxLen = 1) // column cdate has 10 values from 2017-01-01 through 2017-01-10. - val dMin = Date.valueOf("2017-01-01") - val dMax = Date.valueOf("2017-01-10") + val dMin = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-01")) + val dMax = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-10")) val attrDate = AttributeReference("cdate", DateType)() val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
- val decMin = new java.math.BigDecimal("0.200000000000000000") - val decMax = new java.math.BigDecimal("0.800000000000000000") + val decMin = Decimal("0.200000000000000000") + val decMax = Decimal("0.800000000000000000") val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) @@ -147,7 +148,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint < 3 OR null") { val condition = Or(LessThan(attrInt, Literal(3)), Literal(null, IntegerType)) - val m = Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)).stats(conf) validateEstimatedStats( Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), Seq(attrInt -> colStatInt), @@ -341,6 +341,14 @@ class FilterEstimationSuite extends StatsEstimationTestBase { expectedRowCount = 7) } + test("cbool IN (true)") { + validateEstimatedStats( + Filter(InSet(attrBool, Set(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) + } + test("cbool = true") { validateEstimatedStats( Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), @@ -358,9 +366,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate = cast('2017-01-02' AS DATE)") { - val d20170102 = Date.valueOf("2017-01-02") + val d20170102 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-02")) validateEstimatedStats( - Filter(EqualTo(attrDate, Literal(d20170102)), + Filter(EqualTo(attrDate, Literal(d20170102, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -368,9 +376,9 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdate < cast('2017-01-03' AS 
DATE)") { - val d20170103 = Date.valueOf("2017-01-03") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) validateEstimatedStats( - Filter(LessThan(attrDate, Literal(d20170103)), + Filter(LessThan(attrDate, Literal(d20170103, DateType)), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), nullCount = 0, avgLen = 4, maxLen = 4)), @@ -379,19 +387,19 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("""cdate IN ( cast('2017-01-03' AS DATE), cast('2017-01-04' AS DATE), cast('2017-01-05' AS DATE) )""") { - val d20170103 = Date.valueOf("2017-01-03") - val d20170104 = Date.valueOf("2017-01-04") - val d20170105 = Date.valueOf("2017-01-05") + val d20170103 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-03")) + val d20170104 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-04")) + val d20170105 = DateTimeUtils.fromJavaDate(Date.valueOf("2017-01-05")) validateEstimatedStats( - Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(attrDate), 10L)), + Filter(In(attrDate, Seq(Literal(d20170103, DateType), Literal(d20170104, DateType), + Literal(d20170105, DateType))), childStatsTestPlan(Seq(attrDate), 10L)), Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), nullCount = 0, avgLen = 4, maxLen = 4)), expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { - val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") + val dec_0_40 = Decimal("0.400000000000000000") validateEstimatedStats( Filter(EqualTo(attrDecimal, Literal(dec_0_40)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -401,7 +409,7 @@ class FilterEstimationSuite extends StatsEstimationTestBase { } test("cdecimal < 0.60 ") { - val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") + val dec_0_60 = Decimal("0.600000000000000000") validateEstimatedStats( 
Filter(LessThan(attrDecimal, Literal(dec_0_60)), childStatsTestPlan(Seq(attrDecimal), 4L)), @@ -532,7 +540,6 @@ class FilterEstimationSuite extends StatsEstimationTestBase { test("cint = cint3") { // no records qualify due to no overlap - val emptyColStats = Seq[(Attribute, ColumnStat)]() validateEstimatedStats( Filter(EqualTo(attrInt, attrInt3), childStatsTestPlan(Seq(attrInt, attrInt3), 10L)), Nil, // set to empty http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala index f62df84..2d6b6e8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/JoinEstimationSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeMap, import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, Project, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types.{DateType, TimestampType, _} @@ -254,24 +255,24 @@ class JoinEstimationSuite extends StatsEstimationTestBase { test("test join keys of different types") { /** Columns in a table with only one row */ def genColumnData: mutable.LinkedHashMap[Attribute, ColumnStat] = { - val dec = new java.math.BigDecimal("1.000000000000000000") - val date = Date.valueOf("2016-05-08") - val timestamp = Timestamp.valueOf("2016-05-08 00:00:01") + val dec = Decimal("1.000000000000000000") + val date 
= DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val timestamp = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) mutable.LinkedHashMap[Attribute, ColumnStat]( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 1, min = Some(false), max = Some(false), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 1, maxLen = 1), + min = Some(1.toByte), max = Some(1.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(1.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 1, - min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(1), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 1, min = Some(1L), max = Some(1L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 1, min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 1, - min = Some(1.0), max = Some(1.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(1.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdec", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 1, min = Some(dec), max = Some(dec), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 1, http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala index f408dc4..a5c4d22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/ProjectEstimationSuite.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -62,28 +63,28 @@ class ProjectEstimationSuite extends StatsEstimationTestBase { } test("test row size estimation") { - val dec1 = new java.math.BigDecimal("1.000000000000000000") - val dec2 = new java.math.BigDecimal("8.000000000000000000") - val d1 = Date.valueOf("2016-05-08") - val d2 = Date.valueOf("2016-05-09") - val t1 = Timestamp.valueOf("2016-05-08 00:00:01") - val t2 = Timestamp.valueOf("2016-05-09 00:00:02") + val dec1 = Decimal("1.000000000000000000") + val dec2 = Decimal("8.000000000000000000") + val d1 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-08")) + val d2 = DateTimeUtils.fromJavaDate(Date.valueOf("2016-05-09")) + val t1 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-08 00:00:01")) + val t2 = DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2016-05-09 00:00:02")) val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( AttributeReference("cbool", BooleanType)() -> ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cbyte", ByteType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(2L), nullCount = 0, avgLen 
= 1, maxLen = 1), + min = Some(1.toByte), max = Some(2.toByte), nullCount = 0, avgLen = 1, maxLen = 1), AttributeReference("cshort", ShortType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(3L), nullCount = 0, avgLen = 2, maxLen = 2), + min = Some(1.toShort), max = Some(3.toShort), nullCount = 0, avgLen = 2, maxLen = 2), AttributeReference("cint", IntegerType)() -> ColumnStat(distinctCount = 2, - min = Some(1L), max = Some(4L), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1), max = Some(4), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("clong", LongType)() -> ColumnStat(distinctCount = 2, min = Some(1L), max = Some(5L), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cdouble", DoubleType)() -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(6.0), nullCount = 0, avgLen = 8, maxLen = 8), AttributeReference("cfloat", FloatType)() -> ColumnStat(distinctCount = 2, - min = Some(1.0), max = Some(7.0), nullCount = 0, avgLen = 4, maxLen = 4), + min = Some(1.0f), max = Some(7.0f), nullCount = 0, avgLen = 4, maxLen = 4), AttributeReference("cdecimal", DecimalType.SYSTEM_DEFAULT)() -> ColumnStat(distinctCount = 2, min = Some(dec1), max = Some(dec2), nullCount = 0, avgLen = 16, maxLen = 16), AttributeReference("cstring", StringType)() -> ColumnStat(distinctCount = 2, http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala index b89014e..0d8db2f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzeColumnCommand.scala @@ -73,10 +73,10 @@ case 
class AnalyzeColumnCommand( val relation = sparkSession.table(tableIdent).logicalPlan // Resolve the column names and dedup using AttributeSet val resolver = sparkSession.sessionState.conf.resolver - val attributesToAnalyze = AttributeSet(columnNames.map { col => + val attributesToAnalyze = columnNames.map { col => val exprOption = relation.output.find(attr => resolver(attr.name, col)) exprOption.getOrElse(throw new AnalysisException(s"Column $col does not exist.")) - }).toSeq + } // Make sure the column types are supported for stats gathering. attributesToAnalyze.foreach { attr => @@ -99,8 +99,8 @@ case class AnalyzeColumnCommand( val statsRow = Dataset.ofRows(sparkSession, Aggregate(Nil, namedExpressions, relation)).head() val rowCount = statsRow.getLong(0) - val columnStats = attributesToAnalyze.zipWithIndex.map { case (expr, i) => - (expr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1))) + val columnStats = attributesToAnalyze.zipWithIndex.map { case (attr, i) => + (attr.name, ColumnStat.rowToColumnStat(statsRow.getStruct(i + 1), attr)) }.toMap (rowCount, columnStats) } http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index 1f547c5..ddc393c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -26,6 +26,7 @@ import scala.util.Random import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.{CatalogRelation, CatalogStatistics} import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.util.DateTimeUtils import 
org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.internal.StaticSQLConf import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} @@ -117,7 +118,7 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) stats.zip(df.schema).foreach { case ((k, v), field) => withClue(s"column $k with type ${field.dataType}") { - val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap) + val roundtrip = ColumnStat.fromMap("table_is_foo", field, v.toMap(k, field.dataType)) assert(roundtrip == Some(v)) } } @@ -201,17 +202,19 @@ abstract class StatisticsCollectionTestBase extends QueryTest with SQLTestUtils /** A mapping from column to the stats collected. */ protected val stats = mutable.LinkedHashMap( "cbool" -> ColumnStat(2, Some(false), Some(true), 1, 1, 1), - "cbyte" -> ColumnStat(2, Some(1L), Some(2L), 1, 1, 1), - "cshort" -> ColumnStat(2, Some(1L), Some(3L), 1, 2, 2), - "cint" -> ColumnStat(2, Some(1L), Some(4L), 1, 4, 4), + "cbyte" -> ColumnStat(2, Some(1.toByte), Some(2.toByte), 1, 1, 1), + "cshort" -> ColumnStat(2, Some(1.toShort), Some(3.toShort), 1, 2, 2), + "cint" -> ColumnStat(2, Some(1), Some(4), 1, 4, 4), "clong" -> ColumnStat(2, Some(1L), Some(5L), 1, 8, 8), "cdouble" -> ColumnStat(2, Some(1.0), Some(6.0), 1, 8, 8), - "cfloat" -> ColumnStat(2, Some(1.0), Some(7.0), 1, 4, 4), - "cdecimal" -> ColumnStat(2, Some(dec1), Some(dec2), 1, 16, 16), + "cfloat" -> ColumnStat(2, Some(1.0f), Some(7.0f), 1, 4, 4), + "cdecimal" -> ColumnStat(2, Some(Decimal(dec1)), Some(Decimal(dec2)), 1, 16, 16), "cstring" -> ColumnStat(2, None, None, 1, 3, 3), "cbinary" -> ColumnStat(2, None, None, 1, 3, 3), - "cdate" -> ColumnStat(2, Some(d1), Some(d2), 1, 4, 4), - "ctimestamp" -> ColumnStat(2, Some(t1), Some(t2), 1, 8, 8) + "cdate" -> ColumnStat(2, Some(DateTimeUtils.fromJavaDate(d1)), + Some(DateTimeUtils.fromJavaDate(d2)), 1, 4, 4), + "ctimestamp" -> 
ColumnStat(2, Some(DateTimeUtils.fromJavaTimestamp(t1)), + Some(DateTimeUtils.fromJavaTimestamp(t2)), 1, 8, 8) ) private val randomName = new Random(31) http://git-wip-us.apache.org/repos/asf/spark/blob/fb036c44/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 806f2be..8b0fdf4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -526,8 +526,10 @@ private[spark] class HiveExternalCatalog(conf: SparkConf, hadoopConf: Configurat if (stats.rowCount.isDefined) { statsProperties += STATISTICS_NUM_ROWS -> stats.rowCount.get.toString() } + val colNameTypeMap: Map[String, DataType] = + tableDefinition.schema.fields.map(f => (f.name, f.dataType)).toMap stats.colStats.foreach { case (colName, colStat) => - colStat.toMap.foreach { case (k, v) => + colStat.toMap(colName, colNameTypeMap(colName)).foreach { case (k, v) => statsProperties += (columnStatKeyPropName(colName, k) -> v) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org