This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch branch-4.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-4.0 by this push:
     new 55fc6f5f3028 [SPARK-51097][SS] Adding state store instance metrics for 
last uploaded snapshot version in RocksDB
55fc6f5f3028 is described below

commit 55fc6f5f3028c35fa7564e45daa7c0e3d4d456f9
Author: Zeyu Chen <[email protected]>
AuthorDate: Fri Feb 21 06:57:52 2025 +0900

    [SPARK-51097][SS] Adding state store instance metrics for last uploaded 
snapshot version in RocksDB
    
    ### What changes were proposed in this pull request?
    
    SPARK-51097
    
    This PR sets up instance-specific metrics 
(`SnapshotLastUploaded.partition_<partition id>_<state store name>` to be 
precise) on the executor side and publishes them through StreamingQueryProgress.
    
    ### Why are the changes needed?
    
    There is currently little observability into state-store-specific 
maintenance information, notably the last snapshot version uploaded. 
This makes it harder to identify performance degradation in maintenance tasks, 
among other issues described in SPARK-51097.
    
    ### Does this PR introduce _any_ user-facing change?
    
    There will be some new metrics displayed from StreamingQueryProgress:
    ```
    Streaming query made progress: {
      ...
      "stateOperators" : [ {
        ...
        "customMetrics" : {
          ...
          "SnapshotLastUploaded.partition_0_default" : 2,
          "SnapshotLastUploaded.partition_12_default" : 10,
          "SnapshotLastUploaded.partition_8_default" : 10,
          ...
        }
      } ],
      "sources" : ...,
      "sink" : ...
    }
    ```
    To reduce noise in query progress messages, not every state store 
instance's metrics are published. The maximum number reported per stateful 
operator is configured by `STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT`, which 
defaults to 5 instance metrics.
    
    ### How was this patch tested?
    
    Four new tests are added in RocksDBStateStoreIntegrationSuite.
    
    The first two tests execute a dedup streaming query and verify that metrics 
are properly filtered and updated through the StreamingQueryProgress logs, but 
with different StateStore providers that skip maintenance tasks for specific 
partitions.
    
    The other two tests execute a join streaming query, which contains four 
state stores per partition instead of one. These two tests verify that metrics 
are properly collected and filtered as well.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #49816 from zecookiez/SPARK-51097.
    
    Lead-authored-by: Zeyu Chen <[email protected]>
    Co-authored-by: Zeyu Chen <[email protected]>
    Signed-off-by: Jungtaek Lim <[email protected]>
    (cherry picked from commit da1854e0cb38681b950ff39a2cfb99d303e192c8)
    Signed-off-by: Jungtaek Lim <[email protected]>
---
 .../org/apache/spark/sql/internal/SQLConf.scala    |  16 ++
 .../streaming/StreamingSymmetricHashJoinExec.scala |   7 +-
 .../sql/execution/streaming/state/RocksDB.scala    |  11 +-
 .../state/RocksDBStateStoreProvider.scala          |  16 +-
 .../sql/execution/streaming/state/StateStore.scala | 101 +++++++-
 .../state/SymmetricHashJoinStateManager.scala      |   4 +-
 .../execution/streaming/statefulOperators.scala    |  82 ++++++-
 .../RocksDBStateStoreCheckpointFormatV2Suite.scala |   3 +
 .../state/RocksDBStateStoreIntegrationSuite.scala  | 253 ++++++++++++++++++++-
 9 files changed, 473 insertions(+), 20 deletions(-)

diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 628ed928e85e..5efa3e8a4ca6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2218,6 +2218,19 @@ object SQLConf {
       .booleanConf
       .createWithDefault(true)
 
+  val STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT =
+    
buildConf("spark.sql.streaming.stateStore.numStateStoreInstanceMetricsToReport")
+      .internal()
+      .doc(
+        "Number of state store instance metrics included in streaming query 
progress messages " +
+        "per stateful operator. Instance metrics are selected based on 
metric-specific ordering " +
+        "to minimize noise in the progress report."
+      )
+      .version("4.0.0")
+      .intConf
+      .checkValue(k => k >= 0, "Must be greater than or equal to 0")
+      .createWithDefault(5)
+
   val STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT =
     buildConf("spark.sql.streaming.stateStore.minDeltasForSnapshot")
       .internal()
@@ -5720,6 +5733,9 @@ class SQLConf extends Serializable with Logging with 
SqlApiConf {
 
   def numStateStoreMaintenanceThreads: Int = 
getConf(NUM_STATE_STORE_MAINTENANCE_THREADS)
 
+  def numStateStoreInstanceMetricsToReport: Int =
+    getConf(STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+
   def stateStoreMinDeltasForSnapshot: Int = 
getConf(STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT)
 
   def stateStoreFormatValidationEnabled: Boolean = 
getConf(STATE_STORE_FORMAT_VALIDATION_ENABLED)
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
index 5eab57f7372c..7c8ba260b88a 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingSymmetricHashJoinExec.scala
@@ -224,7 +224,7 @@ case class StreamingSymmetricHashJoinExec(
 
   override def shortName: String = "symmetricHashJoin"
 
-  private val stateStoreNames =
+  override val stateStoreNames: Seq[String] =
     SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide)
 
   override def operatorStateMetadata(
@@ -527,9 +527,8 @@ case class StreamingSymmetricHashJoinExec(
           (leftSideJoiner.numUpdatedStateRows + 
rightSideJoiner.numUpdatedStateRows)
         numTotalStateRows += combinedMetrics.numKeys
         stateMemory += combinedMetrics.memoryUsedBytes
-        combinedMetrics.customMetrics.foreach { case (metric, value) =>
-          longMetric(metric.name) += value
-        }
+        setStoreCustomMetrics(combinedMetrics.customMetrics)
+        setStoreInstanceMetrics(combinedMetrics.instanceMetrics)
       }
 
       val stateStoreNames = 
SymmetricHashJoinStateManager.allStateStoreNames(LeftSide, RightSide);
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
index bc7ff53d9af3..820322d1e0ee 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala
@@ -22,7 +22,7 @@ import java.util.Locale
 import java.util.Set
 import java.util.UUID
 import java.util.concurrent.{ConcurrentHashMap, ConcurrentLinkedQueue, 
TimeUnit}
-import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
+import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong}
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.{mutable, Map}
@@ -147,6 +147,10 @@ class RocksDB(
   private val enableChangelogCheckpointing: Boolean = 
conf.enableChangelogCheckpointing
   @volatile protected var loadedVersion: Long = -1L   // -1 = nothing valid is 
loaded
 
+  // Can be updated by whichever thread uploaded a snapshot, which could be 
either task,
+  // maintenance, or both. -1 represents no version has ever been uploaded.
+  protected val lastUploadedSnapshotVersion: AtomicLong = new AtomicLong(-1L)
+
   // variables to manage checkpoint ID. Once a checkpointing finishes, it 
needs to return
   // `lastCommittedStateStoreCkptId` as the committed checkpointID, as well as
   // `lastCommitBasedStateStoreCkptId` as the checkpontID of the previous 
version that is based on.
@@ -1293,6 +1297,7 @@ class RocksDB(
       bytesCopied = fileManagerMetrics.bytesCopied,
       filesCopied = fileManagerMetrics.filesCopied,
       filesReused = fileManagerMetrics.filesReused,
+      lastUploadedSnapshotVersion = lastUploadedSnapshotVersion.get(),
       zipFileBytesUncompressed = fileManagerMetrics.zipFileBytesUncompressed,
       nativeOpsMetrics = nativeOpsMetrics)
   }
@@ -1461,6 +1466,7 @@ class RocksDB(
         log"with uniqueId: ${MDC(LogKeys.UUID, snapshot.uniqueId)} " +
         log"time taken: ${MDC(LogKeys.TIME_UNITS, uploadTime)} ms. " +
         log"Current lineage: ${MDC(LogKeys.LINEAGE, lineageManager)}")
+      lastUploadedSnapshotVersion.set(snapshot.version)
     } finally {
       snapshot.close()
     }
@@ -1912,7 +1918,8 @@ case class RocksDBMetrics(
     bytesCopied: Long,
     filesReused: Long,
     zipFileBytesUncompressed: Option[Long],
-    nativeOpsMetrics: Map[String, Long]) {
+    nativeOpsMetrics: Map[String, Long],
+    lastUploadedSnapshotVersion: Long) {
   def json: String = Serialization.write(this)(RocksDBMetrics.format)
 }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
index cd9fdb9469d6..47721cea4359 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala
@@ -316,14 +316,20 @@ private[sql] class RocksDBStateStoreProvider
         ) ++ rocksDBMetrics.zipFileBytesUncompressed.map(bytes =>
           Map(CUSTOM_METRIC_ZIP_FILE_BYTES_UNCOMPRESSED -> 
bytes)).getOrElse(Map())
 
+        val stateStoreInstanceMetrics = Map[StateStoreInstanceMetric, Long](
+          CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED
+            .withNewId(id.partitionId, id.storeName) -> 
rocksDBMetrics.lastUploadedSnapshotVersion
+        )
+
         StateStoreMetrics(
           rocksDBMetrics.numUncommittedKeys,
           rocksDBMetrics.totalMemUsageBytes,
-          stateStoreCustomMetrics)
+          stateStoreCustomMetrics,
+          stateStoreInstanceMetrics)
       } else {
         logInfo(log"Failed to collect metrics for 
store_id=${MDC(STATE_STORE_ID, id)} " +
           log"and version=${MDC(VERSION_NUM, version)}")
-        StateStoreMetrics(0, 0, Map.empty)
+        StateStoreMetrics(0, 0, Map.empty, Map.empty)
       }
     }
 
@@ -497,6 +503,8 @@ private[sql] class RocksDBStateStoreProvider
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = 
ALL_CUSTOM_METRICS
 
+  override def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = 
ALL_INSTANCE_METRICS
+
   private[state] def latestVersion: Long = rocksDB.getLatestVersion()
 
   /** Internal fields and methods */
@@ -888,6 +896,10 @@ object RocksDBStateStoreProvider {
     CUSTOM_METRIC_COMPACT_WRITTEN_BYTES, CUSTOM_METRIC_FLUSH_WRITTEN_BYTES,
     CUSTOM_METRIC_PINNED_BLOCKS_MEM_USAGE, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES_KEYS,
     CUSTOM_METRIC_NUM_EXTERNAL_COL_FAMILIES, 
CUSTOM_METRIC_NUM_INTERNAL_COL_FAMILIES)
+
+  val CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED = 
StateStoreSnapshotLastUploadInstanceMetric()
+
+  val ALL_INSTANCE_METRICS = Seq(CUSTOM_INSTANCE_METRIC_SNAPSHOT_LAST_UPLOADED)
 }
 
 /** [[StateStoreChangeDataReader]] implementation for 
[[RocksDBStateStoreProvider]] */
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 8ba3fc37162c..09acc24aff98 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -249,12 +249,17 @@ class WrappedReadStateStore(store: StateStore) extends 
ReadStateStore {
  * @param memoryUsedBytes Memory used by the state store
  * @param customMetrics   Custom implementation-specific metrics
  *                        The metrics reported through this must have the same 
`name` as those
- *                        reported by `StateStoreProvider.customMetrics`.
+ *                        reported by 
`StateStoreProvider.supportedCustomMetrics`.
+ * @param instanceMetrics Custom implementation-specific metrics that are 
specific to state stores
+ *                        The metrics reported through this must have the same 
`name` as those
+ *                        reported by 
`StateStoreProvider.supportedInstanceMetrics`,
+ *                        including partition id and store name.
  */
 case class StateStoreMetrics(
     numKeys: Long,
     memoryUsedBytes: Long,
-    customMetrics: Map[StateStoreCustomMetric, Long])
+    customMetrics: Map[StateStoreCustomMetric, Long],
+    instanceMetrics: Map[StateStoreInstanceMetric, Long] = Map.empty)
 
 /**
  * State store checkpoint information, used to pass checkpointing information 
from executors
@@ -284,7 +289,8 @@ object StateStoreMetrics {
     StateStoreMetrics(
       allMetrics.map(_.numKeys).sum,
       allMetrics.map(_.memoryUsedBytes).sum,
-      combinedCustomMetrics)
+      combinedCustomMetrics,
+      allMetrics.flatMap(_.instanceMetrics).toMap)
   }
 }
 
@@ -321,6 +327,86 @@ case class StateStoreCustomTimingMetric(name: String, 
desc: String) extends Stat
     SQLMetrics.createTimingMetric(sparkContext, desc)
 }
 
+trait StateStoreInstanceMetric {
+  def metricPrefix: String
+  def descPrefix: String
+  def partitionId: Option[Int]
+  def storeName: String
+  def initValue: Long
+
+  def createSQLMetric(sparkContext: SparkContext): SQLMetric
+
+  /**
+   * Defines how instance metrics are selected for progress reporting.
+   * Metrics are sorted by value using this ordering, and only the first N 
metrics are displayed.
+   * For example, the highest N metrics by value should use 
Ordering.Long.reverse.
+   */
+  def ordering: Ordering[Long]
+
+  /** Should this instance metric be reported if it is unchanged from its 
initial value */
+  def ignoreIfUnchanged: Boolean
+
+  /**
+   * Defines how to merge metric values from different executors for the same 
state store
+   * instance in situations like speculative execution or provider unloading. 
In most cases,
+   * the original metric value is at its initial value.
+   */
+  def combine(originalMetric: SQLMetric, value: Long): Long
+
+  def name: String = {
+    assert(partitionId.isDefined, "Partition ID must be defined for instance 
metric name")
+    s"$metricPrefix.partition_${partitionId.get}_$storeName"
+  }
+
+  def desc: String = {
+    assert(partitionId.isDefined, "Partition ID must be defined for instance 
metric description")
+    s"$descPrefix (partitionId = ${partitionId.get}, storeName = $storeName)"
+  }
+
+  def withNewId(partitionId: Int, storeName: String): StateStoreInstanceMetric
+}
+
+case class StateStoreSnapshotLastUploadInstanceMetric(
+    partitionId: Option[Int] = None,
+    storeName: String = StateStoreId.DEFAULT_STORE_NAME)
+  extends StateStoreInstanceMetric {
+
+  override def metricPrefix: String = "SnapshotLastUploaded"
+
+  override def descPrefix: String = {
+    "The last uploaded version of the snapshot for a specific state store 
instance"
+  }
+
+  override def initValue: Long = -1L
+
+  override def createSQLMetric(sparkContext: SparkContext): SQLMetric = {
+    SQLMetrics.createSizeMetric(sparkContext, desc, initValue)
+  }
+
+  override def ordering: Ordering[Long] = Ordering.Long
+
+  override def ignoreIfUnchanged: Boolean = false
+
+  override def combine(originalMetric: SQLMetric, value: Long): Long = {
+    // Check for cases where the initial value is less than 0, forcing 
metric.value to
+    // convert it to 0. Since the last uploaded snapshot version can have an 
initial
+    // value of -1, we need special handling to avoid turning the -1 into a 0.
+    if (originalMetric.isZero) {
+      value
+    } else {
+      // Use max to grab the most recent snapshot version across all executors
+      // of the same store instance
+      Math.max(originalMetric.value, value)
+    }
+  }
+
+  override def withNewId(
+      partitionId: Int,
+      storeName: String): StateStoreSnapshotLastUploadInstanceMetric = {
+    copy(partitionId = Some(partitionId), storeName = storeName)
+  }
+}
+
 sealed trait KeyStateEncoderSpec {
   def keySchema: StructType
   def jsonValue: JValue
@@ -495,9 +581,16 @@ trait StateStoreProvider {
   /**
    * Optional custom metrics that the implementation may want to report.
    * @note The StateStore objects created by this provider must report the 
same custom metrics
-   * (specifically, same names) through `StateStore.metrics`.
+   * (specifically, same names) through `StateStore.metrics.customMetrics`.
    */
   def supportedCustomMetrics: Seq[StateStoreCustomMetric] = Nil
+
+  /**
+   * Optional custom state store instance metrics that the implementation may 
want to report.
+   * @note The StateStore objects created by this provider must report the 
same instance metrics
+   * (specifically, same names) through `StateStore.metrics.instanceMetrics`.
+   */
+  def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] = Seq.empty
 }
 
 object StateStoreProvider {
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
index f487ddf4252c..66ab0006c498 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SymmetricHashJoinStateManager.scala
@@ -447,7 +447,9 @@ class SymmetricHashJoinStateManager(
       keyToNumValuesMetrics.memoryUsedBytes + 
keyWithIndexToValueMetrics.memoryUsedBytes,
       keyWithIndexToValueMetrics.customMetrics.map {
         case (metric, value) => (metric.withNewDesc(desc = 
newDesc(metric.desc)), value)
-      }
+      },
+      // We want to collect instance metrics from both state stores
+      keyWithIndexToValueMetrics.instanceMetrics ++ 
keyToNumValuesMetrics.instanceMetrics
     )
   }
 
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
index fc269897edd6..af47229dfa88 100644
--- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
+++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala
@@ -216,7 +216,14 @@ trait StateStoreWriter
     "stateMemory" -> SQLMetrics.createSizeMetric(sparkContext, "memory used by 
state"),
     "numStateStoreInstances" -> SQLMetrics.createMetric(sparkContext,
       "number of state store instances")
-  ) ++ stateStoreCustomMetrics ++ pythonMetrics
+  ) ++ stateStoreCustomMetrics ++ pythonMetrics ++ stateStoreInstanceMetrics
+
+  val stateStoreNames: Seq[String] = Seq(StateStoreId.DEFAULT_STORE_NAME)
+
+  // This is used to relate metric names back to their original metric object,
+  // which holds information on how to report the metric during getProgress.
+  lazy val instanceMetricConfiguration: Map[String, StateStoreInstanceMetric] =
+    stateStoreInstanceMetricObjects
 
   // This method is only used to fetch the state schema directory path for
   // operators that use StateSchemaV3, as prior versions only use a single
@@ -320,11 +327,40 @@ trait StateStoreWriter
    * the driver after this SparkPlan has been executed and metrics have been 
updated.
    */
   def getProgress(): StateOperatorProgress = {
+    val instanceMetricsToReport = instanceMetricConfiguration
+      .filter {
+        case (name, metricConfig) =>
+          // Keep instance metrics that are updated or aren't marked to be 
ignored,
+          // as their initial value could still be important.
+          !metricConfig.ignoreIfUnchanged || !longMetric(name).isZero
+      }
+      .groupBy {
+        // Group all instance metrics underneath their common metric prefix
+        // to ignore partition and store names.
+        case (name, metricConfig) => metricConfig.metricPrefix
+      }
+      .flatMap {
+        case (_, metrics) =>
+          // Select at most N metrics based on the metric's defined ordering
+          // to report to the driver. For example, ascending order would be 
taking the N smallest.
+          val metricConf = metrics.head._2
+          metrics
+            .map {
+              case (_, metric) =>
+                metric.name -> (if (longMetric(metric.name).isZero) 
metricConf.initValue
+                                else longMetric(metric.name).value)
+            }
+            .toSeq
+            .sortBy(_._2)(metricConf.ordering)
+            .take(conf.numStateStoreInstanceMetricsToReport)
+            .toMap
+      }
     val customMetrics = (stateStoreCustomMetrics ++ 
statefulOperatorCustomMetrics)
       .map(entry => entry._1 -> longMetric(entry._1).value)
+    val allCustomMetrics = customMetrics ++ instanceMetricsToReport
 
     val javaConvertedCustomMetrics: java.util.HashMap[String, java.lang.Long] =
-      new java.util.HashMap(customMetrics.transform((_, v) => 
long2Long(v)).asJava)
+      new java.util.HashMap(allCustomMetrics.transform((_, v) => 
long2Long(v)).asJava)
 
     // We now don't report number of shuffle partitions inside the state 
operator. Instead,
     // it will be filled when the stream query progress is reported
@@ -373,9 +409,8 @@ trait StateStoreWriter
     val storeMetrics = store.metrics
     longMetric("numTotalStateRows") += storeMetrics.numKeys
     longMetric("stateMemory") += storeMetrics.memoryUsedBytes
-    storeMetrics.customMetrics.foreach { case (metric, value) =>
-      longMetric(metric.name) += value
-    }
+    setStoreCustomMetrics(storeMetrics.customMetrics)
+    setStoreInstanceMetrics(storeMetrics.instanceMetrics)
 
     if (StatefulOperatorStateInfo.enableStateStoreCheckpointIds(conf)) {
       // Set the state store checkpoint information for the driver to collect
@@ -391,6 +426,23 @@ trait StateStoreWriter
     }
   }
 
+  protected def setStoreCustomMetrics(customMetrics: 
Map[StateStoreCustomMetric, Long]): Unit = {
+    customMetrics.foreach {
+      case (metric, value) =>
+        longMetric(metric.name) += value
+    }
+  }
+
+  protected def setStoreInstanceMetrics(
+      instanceMetrics: Map[StateStoreInstanceMetric, Long]): Unit = {
+    instanceMetrics.foreach {
+      case (metric, value) =>
+        val metricConfig = instanceMetricConfiguration(metric.name)
+        // Update the metric's value based on the defined combine method
+        
longMetric(metric.name).set(metricConfig.combine(longMetric(metric.name), 
value))
+    }
+  }
+
   private def stateStoreCustomMetrics: Map[String, SQLMetric] = {
     val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
     provider.supportedCustomMetrics.map {
@@ -398,6 +450,26 @@ trait StateStoreWriter
     }.toMap
   }
 
+  private def stateStoreInstanceMetrics: Map[String, SQLMetric] = {
+    instanceMetricConfiguration.map {
+      case (name, metric) => (name, metric.createSQLMetric(sparkContext))
+    }
+  }
+
+  private def stateStoreInstanceMetricObjects: Map[String, 
StateStoreInstanceMetric] = {
+    val provider = StateStoreProvider.create(conf.stateStoreProviderClass)
+    val maxPartitions = 
stateInfo.map(_.numPartitions).getOrElse(conf.defaultNumShufflePartitions)
+
+    (0 until maxPartitions).flatMap { partitionId =>
+      provider.supportedInstanceMetrics.flatMap { metric =>
+        stateStoreNames.map { storeName =>
+          val metricWithPartition = metric.withNewId(partitionId, storeName)
+          (metricWithPartition.name, metricWithPartition)
+        }
+      }
+    }.toMap
+  }
+
   /**
    * Set of stateful operator custom metrics. These are captured as part of 
the generic
    * key-value map [[StateOperatorProgress.customMetrics]].
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
index d35bbd49de0d..da4f685aaff8 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreCheckpointFormatV2Suite.scala
@@ -183,6 +183,9 @@ class CkptIdCollectingStateStoreProviderWrapper extends 
StateStoreProvider {
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] =
     innerProvider.supportedCustomMetrics
+
+  override def supportedInstanceMetrics: Seq[StateStoreInstanceMetric] =
+    innerProvider.supportedInstanceMetrics
 }
 
 class RocksDBStateStoreCheckpointFormatV2Suite extends StreamTest
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
index 1f4fd7f79571..b3807653d8d2 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala
@@ -19,21 +19,36 @@ package org.apache.spark.sql.execution.streaming.state
 
 import java.io.File
 
-import scala.jdk.CollectionConverters.SetHasAsScala
+import scala.concurrent.duration.DurationInt
+import scala.jdk.CollectionConverters.{MapHasAsScala, SetHasAsScala}
 
 import org.scalatest.time.{Minute, Span}
 
 import org.apache.spark.sql.execution.streaming.{MemoryStream, 
StreamingQueryWrapper}
-import org.apache.spark.sql.functions.count
+import org.apache.spark.sql.functions.{count, expr}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
 import org.apache.spark.sql.streaming.OutputMode.Update
 import org.apache.spark.util.Utils
 
+// SkipMaintenanceOnCertainPartitionsProvider is a test-only provider that 
skips running
+// maintenance for partitions 0 and 1 (these are arbitrary choices). This is 
used to test
+// snapshot upload lag can be observed through StreamingQueryProgress metrics.
+class SkipMaintenanceOnCertainPartitionsProvider extends 
RocksDBStateStoreProvider {
+  override def doMaintenance(): Unit = {
+    if (stateStoreId.partitionId == 0 || stateStoreId.partitionId == 1) {
+      return
+    }
+    super.doMaintenance()
+  }
+}
+
 class RocksDBStateStoreIntegrationSuite extends StreamTest
   with AlsoTestWithRocksDBFeatures {
   import testImplicits._
 
+  private val SNAPSHOT_LAG_METRIC_PREFIX = "SnapshotLastUploaded.partition_"
+
   testWithColumnFamilies("RocksDBStateStore",
     TestWithBothChangelogCheckpointingEnabledAndDisabled) { colFamiliesEnabled 
=>
     withTempDir { dir =>
@@ -71,6 +86,7 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
         (SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10"),
         (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> 
classOf[RocksDBStateStoreProvider].getName),
         (SQLConf.CHECKPOINT_LOCATION.key -> dir.getCanonicalPath),
+        (SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "0"),
         (SQLConf.SHUFFLE_PARTITIONS.key, "1")) {
         val inputData = MemoryStream[Int]
 
@@ -270,4 +286,237 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     assert(changelogVersionsPresent(dirForPartition0) == List(3L, 4L))
     assert(snapshotVersionsPresent(dirForPartition0).contains(5L))
   }
+
+  // Builds the metric name for one state store instance: the shared prefix followed by
+  // the partition id, an underscore and the store name.
+  private def snapshotLagMetricName(
+      partitionId: Long,
+      storeName: String = StateStoreId.DEFAULT_STORE_NAME): String = {
+    val instanceSuffix = s"${partitionId}_${storeName}"
+    SNAPSHOT_LAG_METRIC_PREFIX + instanceSuffix
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly with RocksDBStateStoreProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "3"
+    ) {
+      withTempDir { tmpDir =>
+        val source = MemoryStream[String]
+        val deduplicated = source.toDS().dropDuplicates()
+
+        testStream(deduplicated, outputMode = OutputMode.Update)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(source, "a"),
+          ProcessAllAvailable(),
+          AddData(source, "b"),
+          ProcessAllAvailable(),
+          CheckNewAnswer("a", "b"),
+          Execute { query =>
+            eventually(timeout(10.seconds)) {
+              // Keep only the snapshot instance metrics of the first stateful operator.
+              val customMetrics = query.lastProgress.stateOperators(0).customMetrics.asScala
+              val snapshotMetrics = customMetrics.view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // Only the smallest K active metrics are published, where K is
+              // STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.
+              val reportLimit =
+                query.sparkSession.conf.get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              assert(snapshotMetrics.size == reportLimit)
+              // Every published instance is expected to report version 1.
+              assert(snapshotMetrics.forall(_._2 == 1))
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly with " +
+    "SkipMaintenanceOnCertainPartitionsProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "3"
+    ) {
+      withTempDir { checkpointDir =>
+        val inputData = MemoryStream[String]
+        val result = inputData.toDS().dropDuplicates()
+
+        testStream(result, outputMode = OutputMode.Update)(
+          StartStream(checkpointLocation = checkpointDir.getCanonicalPath),
+          AddData(inputData, "a"),
+          ProcessAllAvailable(),
+          AddData(inputData, "b"),
+          ProcessAllAvailable(),
+          CheckNewAnswer("a", "b"),
+          Execute { q =>
+            eventually(timeout(10.seconds)) {
+              // Read the progress once per attempt so that every assertion below inspects
+              // the same progress update; separate lastProgress calls can each observe a
+              // newer one.
+              val customMetrics = q.lastProgress.stateOperators(0).customMetrics
+              // Upper bound on the number of published instance metrics.
+              val reportLimit = q.sparkSession.conf
+                .get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+
+              // Partitions getting skipped (id 0 and 1) do not have an uploaded version,
+              // leaving those instance metrics as -1.
+              assert(customMetrics.get(snapshotLagMetricName(0)) === -1)
+              assert(customMetrics.get(snapshotLagMetricName(1)) === -1)
+
+              // Make sure only smallest K active metrics are published
+              val instanceMetrics = customMetrics.asScala.view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              assert(instanceMetrics.size == reportLimit)
+              // Two metrics published are -1, the remainder should all be 1 as they
+              // uploaded properly.
+              assert(instanceMetrics.count(_._2 == 1) == reportLimit - 2)
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly for join queries with " +
+    "RocksDBStateStoreProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[RocksDBStateStoreProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "10"
+    ) {
+      withTempDir { tmpDir =>
+        val leftInput = MemoryStream[Int]
+        val rightInput = MemoryStream[Int]
+
+        // Stream-stream inner join on the key column of each side.
+        val leftSide = leftInput
+          .toDF()
+          .select($"value".as("leftKey"), ($"value" * 2).as("leftValue"))
+        val rightSide = rightInput
+          .toDF()
+          .select($"value".as("rightKey"), ($"value" * 3).as("rightValue"))
+        val joined = leftSide.join(rightSide, expr("leftKey = rightKey"))
+
+        testStream(joined)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(leftInput, 1, 5),
+          ProcessAllAvailable(),
+          AddData(rightInput, 1, 5, 10),
+          ProcessAllAvailable(),
+          CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)),
+          Execute { query =>
+            eventually(timeout(10.seconds)) {
+              // The join has 5 * 4 = 20 instance metrics in total, but only the smallest
+              // K active ones (10 here) are published.
+              val customMetrics = query.lastProgress.stateOperators(0).customMetrics.asScala
+              val snapshotMetrics = customMetrics.view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // K is given by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.
+              val reportLimit =
+                query.sparkSession.conf.get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              assert(snapshotMetrics.size == reportLimit)
+              // All state store instances should have uploaded a version.
+              assert(snapshotMetrics.forall(_._2 == 1))
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    "SPARK-51097: Verify snapshot lag metrics are updated correctly for join queries with " +
+    "SkipMaintenanceOnCertainPartitionsProvider"
+  ) {
+    withSQLConf(
+      SQLConf.STATE_STORE_PROVIDER_CLASS.key ->
+        classOf[SkipMaintenanceOnCertainPartitionsProvider].getName,
+      SQLConf.STREAMING_MAINTENANCE_INTERVAL.key -> "100",
+      SQLConf.STREAMING_NO_DATA_PROGRESS_EVENT_INTERVAL.key -> "10",
+      SQLConf.STATE_STORE_MIN_DELTAS_FOR_SNAPSHOT.key -> "1",
+      SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.key -> "10"
+    ) {
+      withTempDir { tmpDir =>
+        val leftInput = MemoryStream[Int]
+        val rightInput = MemoryStream[Int]
+
+        // Stream-stream inner join on the key column of each side.
+        val leftSide = leftInput
+          .toDF()
+          .select($"value".as("leftKey"), ($"value" * 2).as("leftValue"))
+        val rightSide = rightInput
+          .toDF()
+          .select($"value".as("rightKey"), ($"value" * 3).as("rightValue"))
+        val joined = leftSide.join(rightSide, expr("leftKey = rightKey"))
+
+        testStream(joined)(
+          StartStream(checkpointLocation = tmpDir.getCanonicalPath),
+          AddData(leftInput, 1, 5),
+          ProcessAllAvailable(),
+          AddData(rightInput, 1, 5, 10),
+          ProcessAllAvailable(),
+          CheckNewAnswer((1, 2, 1, 3), (5, 10, 5, 15)),
+          Execute { query =>
+            eventually(timeout(10.seconds)) {
+              // The join has 5 * 4 = 20 instance metrics in total, but only the smallest
+              // K active ones (10 here) are published.
+              val customMetrics = query.lastProgress.stateOperators(0).customMetrics.asScala
+              val publishedMetrics = customMetrics.view
+                .filterKeys(_.startsWith(SNAPSHOT_LAG_METRIC_PREFIX))
+              // Published metrics that belong to partition 0 or 1, whatever the store name.
+              val skippedMetrics = publishedMetrics.filterKeys { metricName =>
+                metricName.startsWith(snapshotLagMetricName(0, "")) ||
+                metricName.startsWith(snapshotLagMetricName(1, ""))
+              }
+              // K is given by STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT.
+              val reportLimit =
+                query.sparkSession.conf.get(SQLConf.STATE_STORE_INSTANCE_METRICS_REPORT_LIMIT)
+              assert(publishedMetrics.size == reportLimit)
+              // Two ids are blocked, each with four state stores
+              assert(skippedMetrics.count(_._2 == -1) == 2 * 4)
+              // The rest should have uploaded a version
+              assert(publishedMetrics.count(_._2 == 1) == reportLimit - 2 * 4)
+            }
+          },
+          StopStream
+        )
+      }
+    }
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to