This is an automated email from the ASF dual-hosted git repository. sunchao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new ce818ba96953 [SPARK-45731][SQL] Also update partition statistics with `ANALYZE TABLE` command ce818ba96953 is described below commit ce818ba969537cf9eb16865a88148407a5992e98 Author: Chao Sun <sunc...@apple.com> AuthorDate: Thu Nov 9 15:56:47 2023 -0800 [SPARK-45731][SQL] Also update partition statistics with `ANALYZE TABLE` command ### What changes were proposed in this pull request? Also update partition statistics (e.g., total size in bytes, row count) with the `ANALYZE TABLE` command. ### Why are the changes needed? Currently, when an `ANALYZE TABLE <tableName>` command is run against a partitioned table, only table stats are updated, not partition stats. Spark users who want to update the latter have to use a different syntax, `ANALYZE TABLE <tableName> PARTITION(<partitionColumns>)`, which is more verbose. Given that `ANALYZE TABLE` internally already calculates the total size for all the partitions, it makes sense to also update partition stats using the result. In this way, Spark users do not need to remember two different syntaxes. In addition, when using `ANALYZE TABLE` with a scan, i.e., when `NOSCAN` is NOT specified, we can also calculate the row count for all the partitions and update the stats accordingly. The above behavior is controlled via a new flag `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled`, which is turned off by default. ### Does this PR introduce _any_ user-facing change? Not by default. When `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled` is set to true, Spark will also update partition stats with the `ANALYZE TABLE` command on a partitioned table. ### How was this patch tested? Added a unit test for this feature. ### Was this patch authored or co-authored using generative AI tooling? No Closes #43629 from sunchao/SPARK-45731. 
Authored-by: Chao Sun <sunc...@apple.com> Signed-off-by: Chao Sun <sunc...@apple.com> --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 ++++ .../command/AnalyzePartitionCommand.scala | 50 ++----------- .../spark/sql/execution/command/CommandUtils.scala | 87 ++++++++++++++++++---- .../apache/spark/sql/hive/StatisticsSuite.scala | 78 +++++++++++++++++++ 4 files changed, 170 insertions(+), 58 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index ecc3e6e101fc..ff6ab7b541a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2664,6 +2664,16 @@ object SQLConf { .booleanConf .createWithDefault(false) + val UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED = + buildConf("spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled") + .doc("When this config is enabled, Spark will also update partition statistics in analyze " + + "table command (i.e., ANALYZE TABLE .. COMPUTE STATISTICS [NOSCAN]). Note the command " + + "will also become more expensive. 
When this config is disabled, Spark will only " + + "update table level statistics.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CBO_ENABLED = buildConf("spark.sql.cbo.enabled") .doc("Enables CBO for estimation of plan statistics when set true.") @@ -5113,6 +5123,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf { def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED) + def updatePartStatsInAnalyzeTableEnabled: Boolean = + getConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED) + def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala index c2b227d6cad7..7fe4c73abf90 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala @@ -17,15 +17,12 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Column, Row, SparkSession} +import org.apache.spark.sql.{Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute -import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal} import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.util.PartitioningUtils -import org.apache.spark.util.collection.Utils /** * Analyzes a given set of partitions to generate per-partition 
statistics, which will be used in @@ -101,20 +98,13 @@ case class AnalyzePartitionCommand( if (noscan) { Map.empty } else { - calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec) + CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec) } // Update the metastore if newly computed statistics are different from those // recorded in the metastore. - - val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, tableMeta.identifier, - partitions.map(_.storage.locationUri)) - val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => - val newRowCount = rowCounts.get(p.spec) - val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount) - newStats.map(_ => p.copy(stats = newStats)) - } - + val (_, newPartitions) = CommandUtils.calculatePartitionStats( + sparkSession, tableMeta, partitions, Some(rowCounts)) if (newPartitions.nonEmpty) { sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions) } @@ -122,35 +112,5 @@ case class AnalyzePartitionCommand( Seq.empty[Row] } - private def calculateRowCountsPerPartition( - sparkSession: SparkSession, - tableMeta: CatalogTable, - partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { - val filter = if (partitionValueSpec.isDefined) { - val filters = partitionValueSpec.get.map { - case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) - } - filters.reduce(And) - } else { - Literal.TrueLiteral - } - - val tableDf = sparkSession.table(tableMeta.identifier) - val partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) - - val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() - df.collect().map { r => - val partitionColumnValues = partitionColumns.indices.map { i => - if (r.isNullAt(i)) { - ExternalCatalogUtils.DEFAULT_PARTITION_NAME - } else { - r.get(i).toString - } - } - val spec = Utils.toMap(tableMeta.partitionColumnNames, 
partitionColumnValues) - val count = BigInt(r.getLong(partitionColumns.size)) - (spec, count) - }.toMap - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala index c656bdbafa0c..73478272a684 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala @@ -25,9 +25,11 @@ import scala.util.control.NonFatal import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{Column, SparkSession} import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType} +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute +import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils} +import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -37,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex} import org.apache.spark.sql.internal.{SessionState, SQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.Utils /** * For the purpose of calculating total directory sizes, use this filter to @@ -76,29 +79,42 @@ object CommandUtils extends Logging { def calculateTotalSize( spark: SparkSession, - catalogTable: CatalogTable): (BigInt, Seq[CatalogTablePartition]) = { + catalogTable: CatalogTable, + 
partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None): + (BigInt, Seq[CatalogTablePartition]) = { val sessionState = spark.sessionState val startTime = System.nanoTime() val (totalSize, newPartitions) = if (catalogTable.partitionColumnNames.isEmpty) { - (calculateSingleLocationSize(sessionState, catalogTable.identifier, - catalogTable.storage.locationUri), Seq()) + val size = calculateSingleLocationSize(sessionState, catalogTable.identifier, + catalogTable.storage.locationUri) + (BigInt(size), Seq()) } else { // Calculate table size as a sum of the visible partitions. See SPARK-21079 val partitions = sessionState.catalog.listPartitions(catalogTable.identifier) logInfo(s"Starting to calculate sizes for ${partitions.length} partitions.") - val paths = partitions.map(_.storage.locationUri) - val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths) - val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => - val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), None) - newStats.map(_ => p.copy(stats = newStats)) - } - (sizes.sum, newPartitions) + calculatePartitionStats(spark, catalogTable, partitions, partitionRowCount) } logInfo(s"It took ${(System.nanoTime() - startTime) / (1000 * 1000)} ms to calculate" + s" the total size for table ${catalogTable.identifier}.") (totalSize, newPartitions) } + def calculatePartitionStats( + spark: SparkSession, + catalogTable: CatalogTable, + partitions: Seq[CatalogTablePartition], + partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None): + (BigInt, Seq[CatalogTablePartition]) = { + val paths = partitions.map(_.storage.locationUri) + val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths) + val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) => + val newRowCount = partitionRowCount.flatMap(_.get(p.spec)) + val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount) + newStats.map(_ => 
p.copy(stats = newStats)) + } + (sizes.sum, newPartitions) + } + def calculateSingleLocationSize( sessionState: SessionState, identifier: TableIdentifier, @@ -214,6 +230,7 @@ object CommandUtils extends Logging { tableIdent: TableIdentifier, noScan: Boolean): Unit = { val sessionState = sparkSession.sessionState + val partitionStatsEnabled = sessionState.conf.updatePartStatsInAnalyzeTableEnabled val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase) val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db)) val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB) @@ -231,7 +248,15 @@ object CommandUtils extends Logging { } } else { // Compute stats for the whole table - val (newTotalSize, _) = CommandUtils.calculateTotalSize(sparkSession, tableMeta) + val rowCounts: Map[TablePartitionSpec, BigInt] = + if (noScan || !partitionStatsEnabled) { + Map.empty + } else { + calculateRowCountsPerPartition(sparkSession, tableMeta, None) + } + val (newTotalSize, newPartitions) = CommandUtils.calculateTotalSize( + sparkSession, tableMeta, Some(rowCounts)) + val newRowCount = if (noScan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count())) @@ -241,6 +266,10 @@ object CommandUtils extends Logging { if (newStats.isDefined) { sessionState.catalog.alterTableStats(tableIdentWithDB, newStats) } + // Also update partition stats when the config is enabled + if (newPartitions.nonEmpty && partitionStatsEnabled) { + sessionState.catalog.alterPartitions(tableIdentWithDB, newPartitions) + } } } @@ -440,4 +469,36 @@ object CommandUtils extends Logging { case NonFatal(e) => logWarning(s"Exception when attempting to uncache $name", e) } } + + def calculateRowCountsPerPartition( + sparkSession: SparkSession, + tableMeta: CatalogTable, + partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = { + val filter = if (partitionValueSpec.isDefined) { + val filters = partitionValueSpec.get.map { + case 
(columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value)) + } + filters.reduce(And) + } else { + Literal.TrueLiteral + } + + val tableDf = sparkSession.table(tableMeta.identifier) + val partitionColumns = tableMeta.partitionColumnNames.map(Column(_)) + + val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count() + + df.collect().map { r => + val partitionColumnValues = partitionColumns.indices.map { i => + if (r.isNullAt(i)) { + ExternalCatalogUtils.DEFAULT_PARTITION_NAME + } else { + r.get(i).toString + } + } + val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues) + val count = BigInt(r.getLong(partitionColumns.size)) + (spec, count) + }.toMap + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index 11134a891960..21a115486298 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -363,6 +363,84 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto } } + test("SPARK-45731: update partition stats with ANALYZE TABLE") { + val tableName = "analyzeTable_part" + + def queryStats(ds: String): Option[CatalogStatistics] = { + val partition = + spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds)) + partition.stats + } + + val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03") + val expectedRowCount = 500 + + Seq(true, false).foreach { partitionStatsEnabled => + withSQLConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED.key -> + partitionStatsEnabled.toString) { + withTable(tableName) { + withTempPath { path => + // Create a table with 3 partitions all located under a directory 'path' + sql( + s""" + |CREATE TABLE $tableName (key INT, value STRING) + |USING parquet + |PARTITIONED BY (ds STRING) + |LOCATION 
'${path.toURI}' + """.stripMargin) + + partitionDates.foreach { ds => + sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds') LOCATION '$path/ds=$ds'") + sql("SELECT * FROM src").write.mode(SaveMode.Overwrite) + .format("parquet").save(s"$path/ds=$ds") + } + + assert(getCatalogTable(tableName).stats.isEmpty) + partitionDates.foreach { ds => + assert(queryStats(ds).isEmpty) + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS NOSCAN") + + // Table size should also have been updated + assert(getTableStats(tableName).sizeInBytes > 0) + // Row count should NOT be updated with the `NOSCAN` option + assert(getTableStats(tableName).rowCount.isEmpty) + + partitionDates.foreach { ds => + val partStats = queryStats(ds) + if (partitionStatsEnabled) { + assert(partStats.nonEmpty) + assert(partStats.get.sizeInBytes > 0) + assert(partStats.get.rowCount.isEmpty) + } else { + assert(partStats.isEmpty) + } + } + + sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS") + + assert(getTableStats(tableName).sizeInBytes > 0) + // Table row count should be updated + assert(getTableStats(tableName).rowCount.get == 3 * expectedRowCount) + + partitionDates.foreach { ds => + val partStats = queryStats(ds) + if (partitionStatsEnabled) { + assert(partStats.nonEmpty) + // The scan option should update partition row count + assert(partStats.get.sizeInBytes > 0) + assert(partStats.get.rowCount.get == expectedRowCount) + } else { + assert(partStats.isEmpty) + } + } + } + } + } + } + } + test("analyze single partition") { val tableName = "analyzeTable_part" --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org