Repository: spark Updated Branches: refs/heads/master 0bf605c2c -> 039ed9fe8
[SPARK-19271][SQL] Change non-cbo estimation of aggregate ## What changes were proposed in this pull request? Change non-cbo estimation behavior of aggregate: - If groupingExpressions is empty, the row count is known (= 1) and so is the corresponding size; - otherwise, estimation falls back to UnaryNode's computeStats method, which should not propagate rowCount and attributeStats in Statistics because they are not estimated in that method. ## How was this patch tested? Added a test case. Author: wangzhenhua <wangzhen...@huawei.com> Closes #16631 from wzhfy/aggNoCbo. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/039ed9fe Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/039ed9fe Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/039ed9fe Branch: refs/heads/master Commit: 039ed9fe8a2fdcd99e0561af64cda8fe3406bc12 Parents: 0bf605c Author: wangzhenhua <wangzhen...@huawei.com> Authored: Thu Jan 19 22:18:47 2017 -0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Jan 19 22:18:47 2017 -0800 ---------------------------------------------------------------------- .../catalyst/plans/logical/LogicalPlan.scala | 3 ++- .../plans/logical/basicLogicalOperators.scala | 7 ++++-- .../statsEstimation/AggregateEstimation.scala | 2 +- .../statsEstimation/EstimationUtils.scala | 4 ++-- .../statsEstimation/ProjectEstimation.scala | 2 +- .../AggregateEstimationSuite.scala | 24 +++++++++++++++++++- .../StatsEstimationTestBase.scala | 7 +++--- 7 files changed, 38 insertions(+), 11 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index 0587a59..93550e1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -344,7 +344,8 @@ abstract class UnaryNode extends LogicalPlan { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + // Don't propagate rowCount and attributeStats, since they are not estimated here. + Statistics(sizeInBytes = sizeInBytes, isBroadcastable = child.stats(conf).isBroadcastable) } } http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 3bd3143..432097d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTypes} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, ProjectEstimation} +import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.{AggregateEstimation, EstimationUtils, ProjectEstimation} import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -541,7 +541,10 @@ case class Aggregate( override def computeStats(conf: CatalystConf): Statistics = { 
def simpleEstimation: Statistics = { if (groupingExpressions.isEmpty) { - super.computeStats(conf).copy(sizeInBytes = 1) + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, outputRowCount = 1), + rowCount = Some(1), + isBroadcastable = child.stats(conf).isBroadcastable) } else { super.computeStats(conf) } http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala index 21e94fc..ce74554 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/AggregateEstimation.scala @@ -53,7 +53,7 @@ object AggregateEstimation { val outputAttrStats = getOutputMap(childStats.attributeStats, agg.output) Some(Statistics( - sizeInBytes = getOutputSize(agg.output, outputAttrStats, outputRows), + sizeInBytes = getOutputSize(agg.output, outputRows, outputAttrStats), rowCount = Some(outputRows), attributeStats = outputAttrStats, isBroadcastable = childStats.isBroadcastable)) http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala index cf4452d..e8b7942 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/EstimationUtils.scala @@ -37,8 +37,8 @@ object EstimationUtils { def getOutputSize( attributes: Seq[Attribute], - attrStats: AttributeMap[ColumnStat], - outputRowCount: BigInt): BigInt = { + outputRowCount: BigInt, + attrStats: AttributeMap[ColumnStat] = AttributeMap(Nil)): BigInt = { // We assign a generic overhead for a Row object, the actual overhead is different for different // Row format. val sizePerRow = 8 + attributes.map { attr => http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala index 50b869a..e9084ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/ProjectEstimation.scala @@ -36,7 +36,7 @@ object ProjectEstimation { val outputAttrStats = getOutputMap(AttributeMap(inputAttrStats.toSeq ++ aliasStats), project.output) Some(childStats.copy( - sizeInBytes = getOutputSize(project.output, outputAttrStats, childStats.rowCount.get), + sizeInBytes = getOutputSize(project.output, childStats.rowCount.get, outputAttrStats), attributeStats = outputAttrStats)) } else { None http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala index 41a4bc3..c0b9515 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/AggregateEstimationSuite.scala @@ -90,6 +90,28 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { expectedOutputRowCount = 0) } + test("non-cbo estimation") { + val attributes = Seq("key12").map(nameToAttr) + val child = StatsTestPlan( + outputList = attributes, + rowCount = 4, + // rowCount * (overhead + column size) + size = Some(4 * (8 + 4)), + attributeStats = AttributeMap(Seq("key12").map(nameToColInfo))) + + val noGroupAgg = Aggregate(groupingExpressions = Nil, + aggregateExpressions = Seq(Alias(Count(Literal(1)), "cnt")()), child) + assert(noGroupAgg.stats(conf.copy(cboEnabled = false)) == + // overhead + count result size + Statistics(sizeInBytes = 8 + 8, rowCount = Some(1))) + + val hasGroupAgg = Aggregate(groupingExpressions = attributes, + aggregateExpressions = attributes :+ Alias(Count(Literal(1)), "cnt")(), child) + assert(hasGroupAgg.stats(conf.copy(cboEnabled = false)) == + // From UnaryNode.computeStats, childSize * outputRowSize / childRowSize + Statistics(sizeInBytes = 48 * (8 + 4 + 8) / (8 + 4))) + } + private def checkAggStats( tableColumns: Seq[String], tableRowCount: BigInt, @@ -107,7 +129,7 @@ class AggregateEstimationSuite extends StatsEstimationTestBase { val expectedAttrStats = AttributeMap(groupByColumns.map(nameToColInfo)) val expectedStats = Statistics( - sizeInBytes = getOutputSize(testAgg.output, expectedAttrStats, expectedOutputRowCount), + sizeInBytes = getOutputSize(testAgg.output, expectedOutputRowCount, 
expectedAttrStats), rowCount = Some(expectedOutputRowCount), attributeStats = expectedAttrStats) http://git-wip-us.apache.org/repos/asf/spark/blob/039ed9fe/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index e6adb67..a5fac4b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -45,11 +45,12 @@ class StatsEstimationTestBase extends SparkFunSuite { protected case class StatsTestPlan( outputList: Seq[Attribute], rowCount: BigInt, - attributeStats: AttributeMap[ColumnStat]) extends LeafNode { + attributeStats: AttributeMap[ColumnStat], + size: Option[BigInt] = None) extends LeafNode { override def output: Seq[Attribute] = outputList override def computeStats(conf: CatalystConf): Statistics = Statistics( - // sizeInBytes in stats of StatsTestPlan is useless in cbo estimation, we just use a fake value - sizeInBytes = Int.MaxValue, + // If sizeInBytes is useless in testing, we just use a fake value + sizeInBytes = size.getOrElse(Int.MaxValue), rowCount = Some(rowCount), attributeStats = attributeStats) } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org