Repository: spark
Updated Branches:
  refs/heads/master b0a5cd890 -> 9909f6d36


[SPARK-19350][SQL] Cardinality estimation of Limit and Sample

## What changes were proposed in this pull request?

Before this PR, LocalLimit/GlobalLimit/Sample propagated the same row count and
column stats from their child, which is incorrect.
We can get the correct rowCount in Statistics for GlobalLimit/Sample whether
CBO is enabled or not.
We don't know the rowCount for LocalLimit because we don't know the number of
partitions at that time. Column stats should not be propagated because we don't
know the distribution of columns after Limit or Sample.

## How was this patch tested?

Added test cases.

Author: wangzhenhua <wangzhen...@huawei.com>

Closes #16696 from wzhfy/limitEstimation.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/9909f6d3
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/9909f6d3
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/9909f6d3

Branch: refs/heads/master
Commit: 9909f6d361fdf2b7ef30fa7fbbc91e00f2999794
Parents: b0a5cd8
Author: wangzhenhua <wangzhen...@huawei.com>
Authored: Mon Mar 6 21:45:36 2017 -0800
Committer: Xiao Li <gatorsm...@gmail.com>
Committed: Mon Mar 6 21:45:36 2017 -0800

----------------------------------------------------------------------
 .../plans/logical/basicLogicalOperators.scala   |  38 +++---
 .../BasicStatsEstimationSuite.scala             | 122 +++++++++++++++++++
 .../statsEstimation/StatsConfSuite.scala        |  64 ----------
 .../spark/sql/StatisticsCollectionSuite.scala   |  24 ----
 4 files changed, 145 insertions(+), 103 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/9909f6d3/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index ccebae3..4d27ff2 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: 
LogicalPlan) extends UnaryN
   }
   override def computeStats(conf: CatalystConf): Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = if (limit == 0) {
-      // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be 
zero
-      // (product of children).
-      1
-    } else {
-      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
-    }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    val childStats = child.stats(conf)
+    val rowCount: BigInt = 
childStats.rowCount.map(_.min(limit)).getOrElse(limit)
+    // Don't propagate column stats, because we don't know the distribution 
after a limit operation
+    Statistics(
+      sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, 
childStats.attributeStats),
+      rowCount = Some(rowCount),
+      isBroadcastable = childStats.isBroadcastable)
   }
 }
 
@@ -773,14 +772,21 @@ case class LocalLimit(limitExpr: Expression, child: 
LogicalPlan) extends UnaryNo
   }
   override def computeStats(conf: CatalystConf): Statistics = {
     val limit = limitExpr.eval().asInstanceOf[Int]
-    val sizeInBytes = if (limit == 0) {
+    val childStats = child.stats(conf)
+    if (limit == 0) {
       // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be 
zero
       // (product of children).
-      1
+      Statistics(
+        sizeInBytes = 1,
+        rowCount = Some(0),
+        isBroadcastable = childStats.isBroadcastable)
     } else {
-      (limit: Long) * output.map(a => a.dataType.defaultSize).sum
+      // The output row count of LocalLimit should be the sum of row counts 
from each partition.
+      // However, since the number of partitions is not available here, we 
just use statistics of
+      // the child. Because the distribution after a limit operation is 
unknown, we do not propagate
+      // the column stats.
+      childStats.copy(attributeStats = AttributeMap(Nil))
     }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
   }
 }
 
@@ -816,12 +822,14 @@ case class Sample(
 
   override def computeStats(conf: CatalystConf): Statistics = {
     val ratio = upperBound - lowerBound
-    // BigInt can't multiply with Double
-    var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100
+    val childStats = child.stats(conf)
+    var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) 
* ratio)
     if (sizeInBytes == 0) {
       sizeInBytes = 1
     }
-    child.stats(conf).copy(sizeInBytes = sizeInBytes)
+    val sampledRowCount = childStats.rowCount.map(c => 
EstimationUtils.ceil(BigDecimal(c) * ratio))
+    // Don't propagate column stats, because we don't know the distribution 
after a sample operation
+    Statistics(sizeInBytes, sampledRowCount, isBroadcastable = 
childStats.isBroadcastable)
   }
 
   override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/9909f6d3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
new file mode 100644
index 0000000..e5dc811
--- /dev/null
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.statsEstimation
+
+import org.apache.spark.sql.catalyst.CatalystConf
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, 
AttributeReference, Literal}
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.types.IntegerType
+
+
+class BasicStatsEstimationSuite extends StatsEstimationTestBase {
+  val attribute = attr("key")
+  val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10),
+    nullCount = 0, avgLen = 4, maxLen = 4)
+
+  val plan = StatsTestPlan(
+    outputList = Seq(attribute),
+    attributeStats = AttributeMap(Seq(attribute -> colStat)),
+    rowCount = 10,
+    // row count * (overhead + column size)
+    size = Some(10 * (8 + 4)))
+
+  test("limit estimation: limit < child's rowCount") {
+    val localLimit = LocalLimit(Literal(2), plan)
+    val globalLimit = GlobalLimit(Literal(2), plan)
+    // LocalLimit's stats is just its child's stats except column stats
+    checkStats(localLimit, plan.stats(conf).copy(attributeStats = 
AttributeMap(Nil)))
+    checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2)))
+  }
+
+  test("limit estimation: limit > child's rowCount") {
+    val localLimit = LocalLimit(Literal(20), plan)
+    val globalLimit = GlobalLimit(Literal(20), plan)
+    checkStats(localLimit, plan.stats(conf).copy(attributeStats = 
AttributeMap(Nil)))
+    // Limit is larger than child's rowCount, so GlobalLimit's stats is equal 
to its child's stats.
+    checkStats(globalLimit, plan.stats(conf).copy(attributeStats = 
AttributeMap(Nil)))
+  }
+
+  test("limit estimation: limit = 0") {
+    val localLimit = LocalLimit(Literal(0), plan)
+    val globalLimit = GlobalLimit(Literal(0), plan)
+    val stats = Statistics(sizeInBytes = 1, rowCount = Some(0))
+    checkStats(localLimit, stats)
+    checkStats(globalLimit, stats)
+  }
+
+  test("sample estimation") {
+    val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 
1000).toLong, plan)()
+    checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5)))
+
+    // Child doesn't have rowCount in stats
+    val childStats = Statistics(sizeInBytes = 120)
+    val childPlan = DummyLogicalPlan(childStats, childStats)
+    val sample2 =
+      Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, 
childPlan)()
+    checkStats(sample2, Statistics(sizeInBytes = 14))
+  }
+
+  test("estimate statistics when the conf changes") {
+    val expectedDefaultStats =
+      Statistics(
+        sizeInBytes = 40,
+        rowCount = Some(10),
+        attributeStats = AttributeMap(Seq(
+          AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), 
Some(10), 0, 4, 4))),
+        isBroadcastable = false)
+    val expectedCboStats =
+      Statistics(
+        sizeInBytes = 4,
+        rowCount = Some(1),
+        attributeStats = AttributeMap(Seq(
+          AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), 
Some(5), 0, 4, 4))),
+        isBroadcastable = false)
+
+    val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats 
= expectedCboStats)
+    checkStats(
+      plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = 
expectedDefaultStats)
+  }
+
+  /** Check estimated stats when cbo is turned on/off. */
+  private def checkStats(
+      plan: LogicalPlan,
+      expectedStatsCboOn: Statistics,
+      expectedStatsCboOff: Statistics): Unit = {
+    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn)
+    // Invalidate statistics
+    plan.invalidateStatsCache()
+    assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff)
+  }
+
+  /** Check estimated stats when it's the same whether cbo is turned on or 
off. */
+  private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit =
+    checkStats(plan, expectedStats, expectedStats)
+}
+
+/**
+ * This class is used for unit-testing the cbo switch, it mimics a logical 
plan which computes
+ * a simple statistics or a cbo estimated statistics based on the conf.
+ */
+private case class DummyLogicalPlan(
+    defaultStats: Statistics,
+    cboStats: Statistics) extends LogicalPlan {
+  override def output: Seq[Attribute] = Nil
+  override def children: Seq[LogicalPlan] = Nil
+  override def computeStats(conf: CatalystConf): Statistics =
+    if (conf.cboEnabled) cboStats else defaultStats
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/9909f6d3/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
deleted file mode 100644
index 212d57a..0000000
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala
+++ /dev/null
@@ -1,64 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst.statsEstimation
-
-import org.apache.spark.sql.catalyst.CatalystConf
-import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, 
AttributeReference}
-import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, 
Statistics}
-import org.apache.spark.sql.types.IntegerType
-
-
-class StatsConfSuite extends StatsEstimationTestBase {
-  test("estimate statistics when the conf changes") {
-    val expectedDefaultStats =
-      Statistics(
-        sizeInBytes = 40,
-        rowCount = Some(10),
-        attributeStats = AttributeMap(Seq(
-          AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), 
Some(10), 0, 4, 4))),
-        isBroadcastable = false)
-    val expectedCboStats =
-      Statistics(
-        sizeInBytes = 4,
-        rowCount = Some(1),
-        attributeStats = AttributeMap(Seq(
-          AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), 
Some(5), 0, 4, 4))),
-        isBroadcastable = false)
-
-    val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats 
= expectedCboStats)
-    // Return the statistics estimated by cbo
-    assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats)
-    // Invalidate statistics
-    plan.invalidateStatsCache()
-    // Return the simple statistics
-    assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats)
-  }
-}
-
-/**
- * This class is used for unit-testing the cbo switch, it mimics a logical 
plan which computes
- * a simple statistics or a cbo estimated statistics based on the conf.
- */
-private case class DummyLogicalPlan(
-    defaultStats: Statistics,
-    cboStats: Statistics) extends LogicalPlan {
-  override def output: Seq[Attribute] = Nil
-  override def children: Seq[LogicalPlan] = Nil
-  override def computeStats(conf: CatalystConf): Statistics =
-    if (conf.cboEnabled) cboStats else defaultStats
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/9909f6d3/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
index bbb31db..1f547c5 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala
@@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends 
StatisticsCollectionTestBase with Shared
       spark.sessionState.conf.autoBroadcastJoinThreshold)
   }
 
-  test("estimates the size of limit") {
-    withTempView("test") {
-      Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v")
-        .createOrReplaceTempView("test")
-      Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) =>
-        val df = sql(s"""SELECT * FROM test limit $limit""")
-
-        val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: 
GlobalLimit =>
-          g.stats(conf).sizeInBytes
-        }
-        assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n 
${df.queryExecution}")
-        assert(sizesGlobalLimit.head === BigInt(expected),
-          s"expected exact size $expected for table 'test', got: 
${sizesGlobalLimit.head}")
-
-        val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: 
LocalLimit =>
-          l.stats(conf).sizeInBytes
-        }
-        assert(sizesLocalLimit.size === 1, s"Size wrong for:\n 
${df.queryExecution}")
-        assert(sizesLocalLimit.head === BigInt(expected),
-          s"expected exact size $expected for table 'test', got: 
${sizesLocalLimit.head}")
-      }
-    }
-  }
-
   test("column stats round trip serialization") {
     // Make sure we serialize and then deserialize and we will get the result 
data
     val df = data.toDF(stats.keys.toSeq :+ "carray" : _*)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to