This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 04ada05  [SPARK-36079][SQL] Null-based filter estimate should always be in the range [0, 1]
04ada05 is described below

commit 04ada0598d9c78253bde8378cac0a322c0ed1031
Author: Karen Feng <karen.f...@databricks.com>
AuthorDate: Tue Jul 20 21:32:13 2021 +0800

    [SPARK-36079][SQL] Null-based filter estimate should always be in the range [0, 1]

    ### What changes were proposed in this pull request?

    Forces the selectivity estimate for null-based filters to be in the range `[0,1]`.

    ### Why are the changes needed?

    I noticed in a few TPC-DS query tests that the column statistic null count can be higher than the table statistic row count. In the current implementation, the selectivity estimate for `IsNotNull` is negative.

    ### Does this PR introduce _any_ user-facing change?

    No

    ### How was this patch tested?

    Unit test

    Closes #33286 from karenfeng/bound-selectivity-est.

    Authored-by: Karen Feng <karen.f...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit ddc61e62b9af5deff1b93e22f466f2a13f281155)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/plans/logical/Statistics.scala    | 13 +++++++
 .../logical/statsEstimation/EstimationUtils.scala  | 18 ++++++----
 .../logical/statsEstimation/FilterEstimation.scala | 30 +++++++---------
 .../logical/statsEstimation/JoinEstimation.scala   | 13 +++----
 .../statsEstimation/FilterEstimationSuite.scala    | 40 +++++++++++++++++++++-
 5 files changed, 80 insertions(+), 34 deletions(-)
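To make the bug concrete before reading the patch: `IsNotNull` selectivity is estimated as `1 - nullCount / rowCount`, so a null count above the row count drives the estimate below zero. A worksheet-style sketch of the arithmetic and of the clamp this patch introduces (plain Scala; the actual `FilterEstimation` code works on `ColumnStat` objects rather than bare numbers):

```scala
// Inconsistent stats observed in some TPC-DS tests: the column-level null
// count exceeds the table-level row count.
val rowCount = BigDecimal(10)
val nullCount = BigDecimal(15)

// Raw IsNotNull selectivity: 1 - nullCount / rowCount = -0.5, which is not
// a valid probability and yields a negative row estimate downstream.
val raw = (1 - nullCount / rowCount).toDouble
println(raw) // -0.5

// The same clamp as the boundProbability helper added in this patch.
def boundProbability(p: Double): Double = Math.max(0.0, Math.min(1.0, p))
println(boundProbability(raw)) // 0.0 -- selectivity stays within [0, 1]
```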
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 1346f80..e80eae6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -24,6 +24,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
 
 import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -116,6 +117,18 @@ case class ColumnStat(
       maxLen = maxLen,
       histogram = histogram,
       version = version)
+
+  def updateCountStats(
+      oldNumRows: BigInt,
+      newNumRows: BigInt,
+      updatedColumnStatOpt: Option[ColumnStat] = None): ColumnStat = {
+    val updatedColumnStat = updatedColumnStatOpt.getOrElse(this)
+    val newDistinctCount = EstimationUtils.updateStat(oldNumRows, newNumRows,
+      distinctCount, updatedColumnStat.distinctCount)
+    val newNullCount = EstimationUtils.updateStat(oldNumRows, newNumRows,
+      nullCount, updatedColumnStat.nullCount)
+    updatedColumnStat.copy(distinctCount = newDistinctCount, nullCount = newNullCount)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 967cced..dafb979 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -52,14 +52,20 @@ object EstimationUtils {
   }
 
   /**
-   * Updates (scales down) the number of distinct values if the number of rows decreases after
-   * some operation (such as filter, join). Otherwise keep it unchanged.
+   * Updates (scales down) a statistic (eg. number of distinct values) if the number of rows
+   * decreases after some operation (such as filter, join). Otherwise keep it unchanged.
    */
-  def updateNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = {
-    if (newNumRows < oldNumRows) {
-      ceil(BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows))
+  def updateStat(
+      oldNumRows: BigInt,
+      newNumRows: BigInt,
+      oldStatOpt: Option[BigInt],
+      updatedStatOpt: Option[BigInt]): Option[BigInt] = {
+    if (oldStatOpt.isDefined && updatedStatOpt.isDefined && updatedStatOpt.get > 1 &&
+      newNumRows < oldNumRows) {
+      // no need to scale down since it is already down to 1
+      Some(ceil(BigDecimal(oldStatOpt.get) * BigDecimal(newNumRows) / BigDecimal(oldNumRows)))
     } else {
-      oldNdv
+      updatedStatOpt
     }
   }
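The `updateStat` generalization above scales any count statistic by the row-count ratio, where the old `updateNdv` handled only distinct counts. A rough standalone rendering of that proportional scale-down (the `scaleStat` name is hypothetical; Spark's version additionally threads through an `updatedStatOpt` from earlier estimation steps):

```scala
import scala.math.BigDecimal.RoundingMode

// Hypothetical standalone analogue of EstimationUtils.updateStat: shrink a
// count statistic by newRows / oldRows (rounding up), but leave it unchanged
// when it is absent, already down to 1, or the row count did not decrease.
def scaleStat(oldRows: BigInt, newRows: BigInt, stat: Option[BigInt]): Option[BigInt] =
  stat match {
    case Some(s) if s > 1 && newRows < oldRows =>
      Some((BigDecimal(s) * BigDecimal(newRows) / BigDecimal(oldRows))
        .setScale(0, RoundingMode.CEILING).toBigInt)
    case other => other
  }

// A filter keeps 100 of 1000 rows: 50 distinct values scale down to 5.
assert(scaleStat(oldRows = 1000, newRows = 100, stat = Some(50)) == Some(BigInt(5)))
// A statistic already at 1 (e.g. a constant column) is never scaled further.
assert(scaleStat(oldRows = 1000, newRows = 100, stat = Some(1)) == Some(BigInt(1)))
```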
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 2c5beef..bc341b9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -106,7 +106,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
       // The foldable Not has been processed in the ConstantFolding rule
       // This is a top-down traversal. The Not could be pushed down by the above two cases.
       case Not(l @ Literal(null, _)) =>
-        calculateSingleCondition(l, update = false)
+        calculateSingleCondition(l, update = false).map(boundProbability(_))
 
       case Not(cond) =>
         calculateFilterSelectivity(cond, update = false) match {
@@ -115,7 +115,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
         }
 
       case _ =>
-        calculateSingleCondition(condition, update)
+        calculateSingleCondition(condition, update).map(boundProbability(_))
     }
   }
 
@@ -233,6 +233,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
     val rowCountValue = childStats.rowCount.get
     val nullPercent: Double = if (rowCountValue == 0) {
       0
+    } else if (colStat.nullCount.get > rowCountValue) {
+      1
     } else {
       (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble
     }
@@ -854,6 +856,10 @@ case class FilterEstimation(plan: Filter) extends Logging {
     Some(percent)
   }
 
+  // Bound result in [0, 1]
+  private def boundProbability(p: Double): Double = {
+    Math.max(0.0, Math.min(1.0, p))
+  }
 }
 
 /**
@@ -907,26 +913,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
   def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats)
 
   /**
-   * Collects updated column stats, and scales down ndv for other column stats if the number of rows
-   * decreases after this Filter operator.
+   * Collects updated column stats; scales down column count stats if the
+   * number of rows decreases after this Filter operator.
    */
   def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt)
     : AttributeMap[ColumnStat] = {
     val newColumnStats = originalMap.map { case (attr, oriColStat) =>
-      val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat)
-      val newNdv = if (colStat.distinctCount.isEmpty) {
-        // No NDV in the original stats.
-        None
-      } else if (colStat.distinctCount.get > 1) {
-        // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
-        // decreases; otherwise keep it unchanged.
-        Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
-          newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get))
-      } else {
-        // no need to scale down since it is already down to 1 (for skewed distribution case)
-        colStat.distinctCount
-      }
-      attr -> colStat.copy(distinctCount = newNdv)
+      attr -> oriColStat.updateCountStats(
+        rowsBeforeFilter, rowsAfterFilter, updatedMap.get(attr.exprId).map(_._2))
     }
     AttributeMap(newColumnStats.toSeq)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 777a4c8..c966117 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -308,17 +308,12 @@ case class JoinEstimation(join: Join) extends Logging {
         outputAttrStats += a -> keyStatsAfterJoin(a)
       } else {
         val oldColStat = oldAttrStats(a)
-        val oldNdv = oldColStat.distinctCount
-        val newNdv = if (oldNdv.isDefined) {
-          Some(if (join.left.outputSet.contains(a)) {
-            updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get)
-          } else {
-            updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get)
-          })
+        val oldNumRows = if (join.left.outputSet.contains(a)) {
+          leftRows
         } else {
-          None
+          rightRows
         }
-        val newColStat = oldColStat.copy(distinctCount = newNdv)
+        val newColStat = oldColStat.updateCountStats(oldNumRows, outputRows)
         // TODO: support nullCount updates for specific outer joins
         outputAttrStats += a -> newColStat
       }
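Alongside the generic clamp, the null-check path now caps `nullPercent` at 1 whenever the null count exceeds the row count — exactly the case the new suite below exercises. A simplified sketch of that branch (standalone; the real code reads these values out of `ColumnStat` and the child plan's `Statistics`):

```scala
// Simplified form of the patched nullPercent computation: cap at 1 when the
// column's null count is (inconsistently) larger than the table's row count.
def nullPercent(nullCount: BigInt, rowCount: BigInt): Double =
  if (rowCount == 0) 0.0
  else if (nullCount > rowCount) 1.0
  else (BigDecimal(nullCount) / BigDecimal(rowCount)).toDouble

// Matches the "Null count higher than row count" test below: nullCount = 15
// against rowCount = 10 caps nullPercent at 1, so IsNotNull is estimated to
// keep 0 rows instead of a negative number.
println(1.0 - nullPercent(nullCount = 15, rowCount = 10)) // 0.0
```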
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 878fae4..2ec2475 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -822,6 +822,41 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 3)
   }
 
+  test("SPARK-36079: Null count should be no higher than row count after filter") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(10))
+    val condition = Filter(EqualTo(attrBool, Literal(true)),
+      childStatsTestPlan(Seq(attrBool, attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(
+          attrBool -> colStatBool, attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrBool -> colStatBool.copy(distinctCount = Some(1), min = Some(true)),
+        attrString -> colStatNullableString.copy(distinctCount = Some(5), nullCount = Some(5))),
+      expectedRowCount = 5)
+  }
+
+  test("SPARK-36079: Null count higher than row count") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(15))
+    val condition = Filter(IsNotNull(attrString),
+      childStatsTestPlan(Seq(attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrString -> colStatNullableString),
+      expectedRowCount = 0)
+  }
+
+  test("SPARK-36079: Bound selectivity >= 0") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(-1))
+    val condition = Filter(IsNotNull(attrString),
+      childStatsTestPlan(Seq(attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrString -> colStatString),
+      expectedRowCount = 10)
+  }
+
   test("ColumnStatsMap tests") {
     val attrNoDistinct = AttributeReference("att_without_distinct", IntegerType)()
     val attrNoCount = AttributeReference("att_without_count", BooleanType)()
@@ -848,7 +883,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     assert(!columnStatsMap.hasMinMaxStats(attrNoMinMax))
   }
 
-  private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
+  private def childStatsTestPlan(
+      outList: Seq[Attribute],
+      tableRowCount: BigInt,
+      attributeMap: AttributeMap[ColumnStat] = attributeMap): StatsTestPlan = {
     StatsTestPlan(
       outputList = outList,
       rowCount = tableRowCount,

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org