This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new ce818ba96953 [SPARK-45731][SQL] Also update partition statistics with `ANALYZE TABLE` command
ce818ba96953 is described below

commit ce818ba969537cf9eb16865a88148407a5992e98
Author: Chao Sun <sunc...@apple.com>
AuthorDate: Thu Nov 9 15:56:47 2023 -0800

    [SPARK-45731][SQL] Also update partition statistics with `ANALYZE TABLE` command
    
    ### What changes were proposed in this pull request?
    
    Also update partition statistics (e.g., total size in bytes, row count)
    with the `ANALYZE TABLE` command.
    
    ### Why are the changes needed?
    
    Currently, when an `ANALYZE TABLE <tableName>` command is run against a
    partitioned table, only table-level stats are updated, not partition
    stats. Spark users who want to update the latter have to use a more
    verbose syntax: `ANALYZE TABLE <tableName> PARTITION(<partitionColumns>)`.
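
    For illustration, the two existing forms look roughly like the sketch
    below (the `sales` table and `ds` partition column are hypothetical, not
    part of this patch):

    ```scala
    // Table-level command: before this change it updates only table-level
    // stats, even when the table is partitioned.
    spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")

    // Partition-level command: the more verbose form currently required to
    // refresh per-partition stats.
    spark.sql("ANALYZE TABLE sales PARTITION(ds) COMPUTE STATISTICS")
    ```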
    
    Given that `ANALYZE TABLE` already calculates the total size of all the
    partitions internally, it makes sense to also update partition stats
    using that result. This way, Spark users do not need to remember two
    different syntaxes.
    
    In addition, when `ANALYZE TABLE` triggers a scan (i.e., `NOSCAN` is NOT
    specified), we can also calculate the row count for all the partitions
    and update their stats accordingly.
    
    The above behavior is controlled via a new flag
    `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled`, which
    by default is turned off.
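
    As a rough usage sketch (the `sales` table is hypothetical; only the flag
    name comes from this patch), enabling the flag and re-running the command
    would look like:

    ```scala
    // Opt in: updating partition stats from ANALYZE TABLE is off by default.
    spark.conf.set("spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled", true)

    // NOSCAN: size-only stats for the table and, with the flag on, for each partition.
    spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS NOSCAN")

    // Full scan: row counts are computed as well, at the table level and per partition.
    spark.sql("ANALYZE TABLE sales COMPUTE STATISTICS")
    ```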
    
    ### Does this PR introduce _any_ user-facing change?
    
    Not by default. When
    `spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled` is set
    to true, Spark will also update partition stats when `ANALYZE TABLE` is
    run on a partitioned table.
    
    ### How was this patch tested?
    
    Added a unit test for this feature.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #43629 from sunchao/SPARK-45731.
    
    Authored-by: Chao Sun <sunc...@apple.com>
    Signed-off-by: Chao Sun <sunc...@apple.com>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    | 13 ++++
 .../command/AnalyzePartitionCommand.scala          | 50 ++-----------
 .../spark/sql/execution/command/CommandUtils.scala | 87 ++++++++++++++++++----
 .../apache/spark/sql/hive/StatisticsSuite.scala    | 78 +++++++++++++++++++
 4 files changed, 170 insertions(+), 58 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index ecc3e6e101fc..ff6ab7b541a3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2664,6 +2664,16 @@ object SQLConf {
       .booleanConf
       .createWithDefault(false)
 
+  val UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED =
+    buildConf("spark.sql.statistics.updatePartitionStatsInAnalyzeTable.enabled")
+      .doc("When this config is enabled, Spark will also update partition statistics in analyze " +
+        "table command (i.e., ANALYZE TABLE .. COMPUTE STATISTICS [NOSCAN]). Note the command " +
+        "will also become more expensive. When this config is disabled, Spark will only " +
+        "update table level statistics.")
+      .version("4.0.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val CBO_ENABLED =
     buildConf("spark.sql.cbo.enabled")
       .doc("Enables CBO for estimation of plan statistics when set true.")
@@ -5113,6 +5123,9 @@ class SQLConf extends Serializable with Logging with SqlApiConf {
 
   def autoSizeUpdateEnabled: Boolean = getConf(SQLConf.AUTO_SIZE_UPDATE_ENABLED)
 
+  def updatePartStatsInAnalyzeTableEnabled: Boolean =
+    getConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED)
+
   def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED)
 
   def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
index c2b227d6cad7..7fe4c73abf90 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AnalyzePartitionCommand.scala
@@ -17,15 +17,12 @@
 
 package org.apache.spark.sql.execution.command
 
-import org.apache.spark.sql.{Column, Row, SparkSession}
+import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.catalyst.TableIdentifier
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType}
 import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
-import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Literal}
 import org.apache.spark.sql.errors.QueryCompilationErrors
 import org.apache.spark.sql.util.PartitioningUtils
-import org.apache.spark.util.collection.Utils
 
 /**
  * Analyzes a given set of partitions to generate per-partition statistics, which will be used in
@@ -101,20 +98,13 @@ case class AnalyzePartitionCommand(
       if (noscan) {
         Map.empty
       } else {
-        calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
+        CommandUtils.calculateRowCountsPerPartition(sparkSession, tableMeta, partitionValueSpec)
       }
 
     // Update the metastore if newly computed statistics are different from those
     // recorded in the metastore.
-
-    val sizes = CommandUtils.calculateMultipleLocationSizes(sparkSession, tableMeta.identifier,
-      partitions.map(_.storage.locationUri))
-    val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-      val newRowCount = rowCounts.get(p.spec)
-      val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
-      newStats.map(_ => p.copy(stats = newStats))
-    }
-
+    val (_, newPartitions) = CommandUtils.calculatePartitionStats(
+      sparkSession, tableMeta, partitions, Some(rowCounts))
     if (newPartitions.nonEmpty) {
       sessionState.catalog.alterPartitions(tableMeta.identifier, newPartitions)
     }
@@ -122,35 +112,5 @@ case class AnalyzePartitionCommand(
     Seq.empty[Row]
   }
 
-  private def calculateRowCountsPerPartition(
-      sparkSession: SparkSession,
-      tableMeta: CatalogTable,
-      partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = {
-    val filter = if (partitionValueSpec.isDefined) {
-      val filters = partitionValueSpec.get.map {
-        case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value))
-      }
-      filters.reduce(And)
-    } else {
-      Literal.TrueLiteral
-    }
-
-    val tableDf = sparkSession.table(tableMeta.identifier)
-    val partitionColumns = tableMeta.partitionColumnNames.map(Column(_))
-
-    val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count()
 
-    df.collect().map { r =>
-      val partitionColumnValues = partitionColumns.indices.map { i =>
-        if (r.isNullAt(i)) {
-          ExternalCatalogUtils.DEFAULT_PARTITION_NAME
-        } else {
-          r.get(i).toString
-        }
-      }
-      val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
-      val count = BigInt(r.getLong(partitionColumns.size))
-      (spec, count)
-    }.toMap
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
index c656bdbafa0c..73478272a684 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/CommandUtils.scala
@@ -25,9 +25,11 @@ import scala.util.control.NonFatal
 import org.apache.hadoop.fs.{FileStatus, FileSystem, Path, PathFilter}
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{Column, SparkSession}
 import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier}
-import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.catalog.{CatalogStatistics, CatalogTable, CatalogTablePartition, CatalogTableType, ExternalCatalogUtils}
+import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical._
@@ -37,6 +39,7 @@ import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.execution.datasources.{DataSourceUtils, InMemoryFileIndex}
 import org.apache.spark.sql.internal.{SessionState, SQLConf}
 import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.Utils
 
 /**
  * For the purpose of calculating total directory sizes, use this filter to
@@ -76,29 +79,42 @@ object CommandUtils extends Logging {
 
   def calculateTotalSize(
       spark: SparkSession,
-      catalogTable: CatalogTable): (BigInt, Seq[CatalogTablePartition]) = {
+      catalogTable: CatalogTable,
+      partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None):
+  (BigInt, Seq[CatalogTablePartition]) = {
     val sessionState = spark.sessionState
     val startTime = System.nanoTime()
     val (totalSize, newPartitions) = if (catalogTable.partitionColumnNames.isEmpty) {
-      (calculateSingleLocationSize(sessionState, catalogTable.identifier,
-        catalogTable.storage.locationUri), Seq())
+      val size = calculateSingleLocationSize(sessionState, catalogTable.identifier,
+        catalogTable.storage.locationUri)
+      (BigInt(size), Seq())
     } else {
       // Calculate table size as a sum of the visible partitions. See SPARK-21079
       val partitions = sessionState.catalog.listPartitions(catalogTable.identifier)
       logInfo(s"Starting to calculate sizes for ${partitions.length} partitions.")
-      val paths = partitions.map(_.storage.locationUri)
-      val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
-      val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
-        val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), None)
-        newStats.map(_ => p.copy(stats = newStats))
-      }
-      (sizes.sum, newPartitions)
+      calculatePartitionStats(spark, catalogTable, partitions, partitionRowCount)
     }
     logInfo(s"It took ${(System.nanoTime() - startTime) / (1000 * 1000)} ms to 
calculate" +
       s" the total size for table ${catalogTable.identifier}.")
     (totalSize, newPartitions)
   }
 
+  def calculatePartitionStats(
+      spark: SparkSession,
+      catalogTable: CatalogTable,
+      partitions: Seq[CatalogTablePartition],
+      partitionRowCount: Option[Map[TablePartitionSpec, BigInt]] = None):
+  (BigInt, Seq[CatalogTablePartition]) = {
+    val paths = partitions.map(_.storage.locationUri)
+    val sizes = calculateMultipleLocationSizes(spark, catalogTable.identifier, paths)
+    val newPartitions = partitions.zipWithIndex.flatMap { case (p, idx) =>
+      val newRowCount = partitionRowCount.flatMap(_.get(p.spec))
+      val newStats = CommandUtils.compareAndGetNewStats(p.stats, sizes(idx), newRowCount)
+      newStats.map(_ => p.copy(stats = newStats))
+    }
+    (sizes.sum, newPartitions)
+  }
+
   def calculateSingleLocationSize(
       sessionState: SessionState,
       identifier: TableIdentifier,
@@ -214,6 +230,7 @@ object CommandUtils extends Logging {
       tableIdent: TableIdentifier,
       noScan: Boolean): Unit = {
     val sessionState = sparkSession.sessionState
+    val partitionStatsEnabled = sessionState.conf.updatePartStatsInAnalyzeTableEnabled
     val db = tableIdent.database.getOrElse(sessionState.catalog.getCurrentDatabase)
     val tableIdentWithDB = TableIdentifier(tableIdent.table, Some(db))
     val tableMeta = sessionState.catalog.getTableMetadata(tableIdentWithDB)
@@ -231,7 +248,15 @@ object CommandUtils extends Logging {
       }
     } else {
       // Compute stats for the whole table
-      val (newTotalSize, _) = CommandUtils.calculateTotalSize(sparkSession, tableMeta)
+      val rowCounts: Map[TablePartitionSpec, BigInt] =
+        if (noScan || !partitionStatsEnabled) {
+          Map.empty
+        } else {
+          calculateRowCountsPerPartition(sparkSession, tableMeta, None)
+        }
+      val (newTotalSize, newPartitions) = CommandUtils.calculateTotalSize(
+        sparkSession, tableMeta, Some(rowCounts))
+
       val newRowCount =
         if (noScan) None else Some(BigInt(sparkSession.table(tableIdentWithDB).count()))
 
@@ -241,6 +266,10 @@ object CommandUtils extends Logging {
       if (newStats.isDefined) {
         sessionState.catalog.alterTableStats(tableIdentWithDB, newStats)
       }
+      // Also update partition stats when the config is enabled
+      if (newPartitions.nonEmpty && partitionStatsEnabled) {
+        sessionState.catalog.alterPartitions(tableIdentWithDB, newPartitions)
+      }
     }
   }
 
@@ -440,4 +469,36 @@ object CommandUtils extends Logging {
       case NonFatal(e) => logWarning(s"Exception when attempting to uncache 
$name", e)
     }
   }
+
+  def calculateRowCountsPerPartition(
+      sparkSession: SparkSession,
+      tableMeta: CatalogTable,
+      partitionValueSpec: Option[TablePartitionSpec]): Map[TablePartitionSpec, BigInt] = {
+    val filter = if (partitionValueSpec.isDefined) {
+      val filters = partitionValueSpec.get.map {
+        case (columnName, value) => EqualTo(UnresolvedAttribute(columnName), Literal(value))
+      }
+      filters.reduce(And)
+    } else {
+      Literal.TrueLiteral
+    }
+
+    val tableDf = sparkSession.table(tableMeta.identifier)
+    val partitionColumns = tableMeta.partitionColumnNames.map(Column(_))
+
+    val df = tableDf.filter(Column(filter)).groupBy(partitionColumns: _*).count()
+
+    df.collect().map { r =>
+      val partitionColumnValues = partitionColumns.indices.map { i =>
+        if (r.isNullAt(i)) {
+          ExternalCatalogUtils.DEFAULT_PARTITION_NAME
+        } else {
+          r.get(i).toString
+        }
+      }
+      val spec = Utils.toMap(tableMeta.partitionColumnNames, partitionColumnValues)
+      val count = BigInt(r.getLong(partitionColumns.size))
+      (spec, count)
+    }.toMap
+  }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
index 11134a891960..21a115486298 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala
@@ -363,6 +363,84 @@ class StatisticsSuite extends StatisticsCollectionTestBase with TestHiveSingleto
     }
   }
 
+  test("SPARK-45731: update partition stats with ANALYZE TABLE") {
+    val tableName = "analyzeTable_part"
+
+    def queryStats(ds: String): Option[CatalogStatistics] = {
+      val partition =
+        spark.sessionState.catalog.getPartition(TableIdentifier(tableName), Map("ds" -> ds))
+      partition.stats
+    }
+
+    val partitionDates = List("2010-01-01", "2010-01-02", "2010-01-03")
+    val expectedRowCount = 500
+
+    Seq(true, false).foreach { partitionStatsEnabled =>
+      withSQLConf(SQLConf.UPDATE_PART_STATS_IN_ANALYZE_TABLE_ENABLED.key ->
+        partitionStatsEnabled.toString) {
+        withTable(tableName) {
+          withTempPath { path =>
+            // Create a table with 3 partitions all located under a directory 'path'
+            sql(
+              s"""
+                 |CREATE TABLE $tableName (key INT, value STRING)
+                 |USING parquet
+                 |PARTITIONED BY (ds STRING)
+                 |LOCATION '${path.toURI}'
+               """.stripMargin)
+
+            partitionDates.foreach { ds =>
+              sql(s"ALTER TABLE $tableName ADD PARTITION (ds='$ds') LOCATION 
'$path/ds=$ds'")
+              sql("SELECT * FROM src").write.mode(SaveMode.Overwrite)
+                .format("parquet").save(s"$path/ds=$ds")
+            }
+
+            assert(getCatalogTable(tableName).stats.isEmpty)
+            partitionDates.foreach { ds =>
+              assert(queryStats(ds).isEmpty)
+            }
+
+            sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS NOSCAN")
+
+            // Table size should also have been updated
+            assert(getTableStats(tableName).sizeInBytes > 0)
+            // Row count should NOT be updated with the `NOSCAN` option
+            assert(getTableStats(tableName).rowCount.isEmpty)
+
+            partitionDates.foreach { ds =>
+              val partStats = queryStats(ds)
+              if (partitionStatsEnabled) {
+                assert(partStats.nonEmpty)
+                assert(partStats.get.sizeInBytes > 0)
+                assert(partStats.get.rowCount.isEmpty)
+              } else {
+                assert(partStats.isEmpty)
+              }
+            }
+
+            sql(s"ANALYZE TABLE $tableName COMPUTE STATISTICS")
+
+            assert(getTableStats(tableName).sizeInBytes > 0)
+            // Table row count should be updated
+            assert(getTableStats(tableName).rowCount.get == 3 * expectedRowCount)
+
+            partitionDates.foreach { ds =>
+              val partStats = queryStats(ds)
+              if (partitionStatsEnabled) {
+                assert(partStats.nonEmpty)
+                // The scan option should update partition row count
+                assert(partStats.get.sizeInBytes > 0)
+                assert(partStats.get.rowCount.get == expectedRowCount)
+              } else {
+                assert(partStats.isEmpty)
+              }
+            }
+          }
+        }
+      }
+    }
+  }
+
   test("analyze single partition") {
     val tableName = "analyzeTable_part"
 

