This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new 04ada05  [SPARK-36079][SQL] Null-based filter estimate should always be in the range [0, 1]
04ada05 is described below

commit 04ada0598d9c78253bde8378cac0a322c0ed1031
Author: Karen Feng <karen.f...@databricks.com>
AuthorDate: Tue Jul 20 21:32:13 2021 +0800

    [SPARK-36079][SQL] Null-based filter estimate should always be in the range [0, 1]
    
    Forces the selectivity estimate for null-based filters to be in the range `[0,1]`.

    I noticed in a few TPC-DS query tests that the column statistic null count can be higher than the table statistic row count. In the current implementation, the selectivity estimate for `IsNotNull` is negative.
    
    No user-facing change.

    Tested with a unit test.
    
    Closes #33286 from karenfeng/bound-selectivity-est.
    
    Authored-by: Karen Feng <karen.f...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit ddc61e62b9af5deff1b93e22f466f2a13f281155)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/plans/logical/Statistics.scala    | 13 +++++++
 .../logical/statsEstimation/EstimationUtils.scala  | 18 ++++++----
 .../logical/statsEstimation/FilterEstimation.scala | 30 +++++++---------
 .../logical/statsEstimation/JoinEstimation.scala   | 13 +++----
 .../statsEstimation/FilterEstimationSuite.scala    | 40 +++++++++++++++++++++-
 5 files changed, 80 insertions(+), 34 deletions(-)
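
For orientation before the diff: the arithmetic being corrected can be reproduced in isolation. The snippet below is a minimal standalone sketch, not code from this patch; it only assumes that statistics may be stale, so a column's null count can exceed the table's row count, in which case the old `IsNotNull` estimate `1 - nullCount / rowCount` goes negative, and the new behaviour clamps the estimate into `[0, 1]`.

// Standalone sketch (not from the patch): how a stale nullCount larger than
// rowCount drives the IsNotNull selectivity negative, and how clamping the
// estimate into [0, 1] (as the boundProbability helper below does) repairs it.
object NullFilterSelectivitySketch {
  // Pre-fix shape of the estimate: 1 - nullCount / rowCount.
  def isNotNullSelectivity(nullCount: BigInt, rowCount: BigInt): Double =
    1.0 - (BigDecimal(nullCount) / BigDecimal(rowCount)).toDouble

  // Same clamp as the boundProbability helper introduced in FilterEstimation.
  def bound(p: Double): Double = math.max(0.0, math.min(1.0, p))

  def main(args: Array[String]): Unit = {
    val raw = isNotNullSelectivity(nullCount = 15, rowCount = 10)
    println(raw)        // -0.5: a negative selectivity, as seen with stale TPC-DS stats
    println(bound(raw)) // 0.0: the bounded estimate after this change
  }
}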

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
index 1346f80..e80eae6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/Statistics.scala
@@ -24,6 +24,7 @@ import net.jpountz.lz4.{LZ4BlockInputStream, LZ4BlockOutputStream}
 
 import org.apache.spark.sql.catalyst.catalog.CatalogColumnStat
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 
@@ -116,6 +117,18 @@ case class ColumnStat(
       maxLen = maxLen,
       histogram = histogram,
       version = version)
+
+  def updateCountStats(
+      oldNumRows: BigInt,
+      newNumRows: BigInt,
+      updatedColumnStatOpt: Option[ColumnStat] = None): ColumnStat = {
+    val updatedColumnStat = updatedColumnStatOpt.getOrElse(this)
+    val newDistinctCount = EstimationUtils.updateStat(oldNumRows, newNumRows,
+      distinctCount, updatedColumnStat.distinctCount)
+    val newNullCount = EstimationUtils.updateStat(oldNumRows, newNumRows,
+      nullCount, updatedColumnStat.nullCount)
+    updatedColumnStat.copy(distinctCount = newDistinctCount, nullCount = newNullCount)
+  }
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
index 967cced..dafb979 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala
@@ -52,14 +52,20 @@ object EstimationUtils {
   }
 
   /**
-   * Updates (scales down) the number of distinct values if the number of rows decreases after
-   * some operation (such as filter, join). Otherwise keep it unchanged.
+   * Updates (scales down) a statistic (eg. number of distinct values) if the number of rows
+   * decreases after some operation (such as filter, join). Otherwise keep it unchanged.
    */
-  def updateNdv(oldNumRows: BigInt, newNumRows: BigInt, oldNdv: BigInt): BigInt = {
-    if (newNumRows < oldNumRows) {
-      ceil(BigDecimal(oldNdv) * BigDecimal(newNumRows) / BigDecimal(oldNumRows))
+  def updateStat(
+      oldNumRows: BigInt,
+      newNumRows: BigInt,
+      oldStatOpt: Option[BigInt],
+      updatedStatOpt: Option[BigInt]): Option[BigInt] = {
+    if (oldStatOpt.isDefined && updatedStatOpt.isDefined && updatedStatOpt.get > 1 &&
+      newNumRows < oldNumRows) {
+        // no need to scale down since it is already down to 1
+        Some(ceil(BigDecimal(oldStatOpt.get) * BigDecimal(newNumRows) / BigDecimal(oldNumRows)))
     } else {
-      oldNdv
+      updatedStatOpt
     }
   }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
index 2c5beef..bc341b9 100755
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala
@@ -106,7 +106,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
       // The foldable Not has been processed in the ConstantFolding rule
      // This is a top-down traversal. The Not could be pushed down by the above two cases.
       case Not(l @ Literal(null, _)) =>
-        calculateSingleCondition(l, update = false)
+        calculateSingleCondition(l, update = false).map(boundProbability(_))
 
       case Not(cond) =>
         calculateFilterSelectivity(cond, update = false) match {
@@ -115,7 +115,7 @@ case class FilterEstimation(plan: Filter) extends Logging {
         }
 
       case _ =>
-        calculateSingleCondition(condition, update)
+        calculateSingleCondition(condition, update).map(boundProbability(_))
     }
   }
 
@@ -233,6 +233,8 @@ case class FilterEstimation(plan: Filter) extends Logging {
     val rowCountValue = childStats.rowCount.get
     val nullPercent: Double = if (rowCountValue == 0) {
       0
+    } else if (colStat.nullCount.get > rowCountValue) {
+      1
     } else {
       (BigDecimal(colStat.nullCount.get) / BigDecimal(rowCountValue)).toDouble
     }
@@ -854,6 +856,10 @@ case class FilterEstimation(plan: Filter) extends Logging {
     Some(percent)
   }
 
+  // Bound result in [0, 1]
+  private def boundProbability(p: Double): Double = {
+    Math.max(0.0, Math.min(1.0, p))
+  }
 }
 
 /**
@@ -907,26 +913,14 @@ case class ColumnStatsMap(originalMap: AttributeMap[ColumnStat]) {
   def update(a: Attribute, stats: ColumnStat): Unit = updatedMap.update(a.exprId, a -> stats)
 
   /**
-   * Collects updated column stats, and scales down ndv for other column stats if the number of rows
-   * decreases after this Filter operator.
+   * Collects updated column stats; scales down column count stats if the
+   * number of rows decreases after this Filter operator.
    */
   def outputColumnStats(rowsBeforeFilter: BigInt, rowsAfterFilter: BigInt)
     : AttributeMap[ColumnStat] = {
     val newColumnStats = originalMap.map { case (attr, oriColStat) =>
-      val colStat = updatedMap.get(attr.exprId).map(_._2).getOrElse(oriColStat)
-      val newNdv = if (colStat.distinctCount.isEmpty) {
-        // No NDV in the original stats.
-        None
-      } else if (colStat.distinctCount.get > 1) {
-        // Update ndv based on the overall filter selectivity: scale down ndv if the number of rows
-        // decreases; otherwise keep it unchanged.
-        Some(EstimationUtils.updateNdv(oldNumRows = rowsBeforeFilter,
-          newNumRows = rowsAfterFilter, oldNdv = oriColStat.distinctCount.get))
-      } else {
-        // no need to scale down since it is already down to 1 (for skewed distribution case)
-        colStat.distinctCount
-      }
-      attr -> colStat.copy(distinctCount = newNdv)
+      attr -> oriColStat.updateCountStats(
+        rowsBeforeFilter, rowsAfterFilter, updatedMap.get(attr.exprId).map(_._2))
     }
     AttributeMap(newColumnStats.toSeq)
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
index 777a4c8..c966117 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/JoinEstimation.scala
@@ -308,17 +308,12 @@ case class JoinEstimation(join: Join) extends Logging {
         outputAttrStats += a -> keyStatsAfterJoin(a)
       } else {
         val oldColStat = oldAttrStats(a)
-        val oldNdv = oldColStat.distinctCount
-        val newNdv = if (oldNdv.isDefined) {
-          Some(if (join.left.outputSet.contains(a)) {
-            updateNdv(oldNumRows = leftRows, newNumRows = outputRows, oldNdv = oldNdv.get)
-          } else {
-            updateNdv(oldNumRows = rightRows, newNumRows = outputRows, oldNdv = oldNdv.get)
-          })
+        val oldNumRows = if (join.left.outputSet.contains(a)) {
+          leftRows
         } else {
-          None
+          rightRows
         }
-        val newColStat = oldColStat.copy(distinctCount = newNdv)
+        val newColStat = oldColStat.updateCountStats(oldNumRows, outputRows)
         // TODO: support nullCount updates for specific outer joins
         outputAttrStats += a -> newColStat
       }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
index 878fae4..2ec2475 100755
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala
@@ -822,6 +822,41 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
       expectedRowCount = 3)
   }
 
+  test("SPARK-36079: Null count should be no higher than row count after 
filter") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(10))
+    val condition = Filter(EqualTo(attrBool, Literal(true)),
+      childStatsTestPlan(Seq(attrBool, attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(
+          attrBool -> colStatBool, attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrBool -> colStatBool.copy(distinctCount = Some(1), min = Some(true)),
+        attrString -> colStatNullableString.copy(distinctCount = Some(5), nullCount = Some(5))),
+      expectedRowCount = 5)
+  }
+
+  test("SPARK-36079: Null count higher than row count") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(15))
+    val condition = Filter(IsNotNull(attrString),
+      childStatsTestPlan(Seq(attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrString -> colStatNullableString),
+      expectedRowCount = 0)
+  }
+
+  test("SPARK-36079: Bound selectivity >= 0") {
+    val colStatNullableString = colStatString.copy(nullCount = Some(-1))
+    val condition = Filter(IsNotNull(attrString),
+      childStatsTestPlan(Seq(attrString), tableRowCount = 10L,
+        attributeMap = AttributeMap(Seq(attrString -> colStatNullableString))))
+    validateEstimatedStats(
+      condition,
+      Seq(attrString -> colStatString),
+      expectedRowCount = 10)
+  }
+
   test("ColumnStatsMap tests") {
    val attrNoDistinct = AttributeReference("att_without_distinct", IntegerType)()
     val attrNoCount = AttributeReference("att_without_count", BooleanType)()
@@ -848,7 +883,10 @@ class FilterEstimationSuite extends StatsEstimationTestBase {
     assert(!columnStatsMap.hasMinMaxStats(attrNoMinMax))
   }
 
-  private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = {
+  private def childStatsTestPlan(
+      outList: Seq[Attribute],
+      tableRowCount: BigInt,
+      attributeMap: AttributeMap[ColumnStat] = attributeMap): StatsTestPlan = {
     StatsTestPlan(
       outputList = outList,
       rowCount = tableRowCount,

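A usage note on the generalized helper: `EstimationUtils.updateStat` (which replaces `updateNdv` above) scales a per-column count statistic down in proportion to the row-count reduction and otherwise returns the updated value unchanged. The sketch below is illustrative only, with values chosen for clarity rather than taken from FilterEstimationSuite, and assumes the catalyst classes are available on the classpath.

import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils

// Illustrative expectations for the generalized updateStat helper.
object UpdateStatSketch {
  def main(args: Array[String]): Unit = {
    // Rows drop from 100 to 25, so a count of 40 scales to ceil(40 * 25 / 100) = 10.
    println(EstimationUtils.updateStat(
      oldNumRows = BigInt(100), newNumRows = BigInt(25),
      oldStatOpt = Some(BigInt(40)), updatedStatOpt = Some(BigInt(40))))  // Some(10)

    // A statistic already down to 1 is passed through rather than scaled further.
    println(EstimationUtils.updateStat(
      oldNumRows = BigInt(100), newNumRows = BigInt(25),
      oldStatOpt = Some(BigInt(1)), updatedStatOpt = Some(BigInt(1))))    // Some(1)
  }
}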