This is an automated email from the ASF dual-hosted git repository.

ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new acc80fd1b205 [SPARK-54984][SS] State Repartition execution and integrate with State Rewriter
acc80fd1b205 is described below

commit acc80fd1b205792a43cc352a8e492e34bbe880da
Author: micheal-o <[email protected]>
AuthorDate: Fri Jan 9 14:55:29 2026 -0800

    [SPARK-54984][SS] State Repartition execution and integrate with State Rewriter
    
    ### What changes were proposed in this pull request?
    
    Integrate the repartition runner with the state rewriter introduced in my previous [PR](https://github.com/apache/spark/pull/53703), and enable end-to-end execution of repartitioning.
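
    As exercised by the updated suite, the entry point can be driven roughly
    like this (the checkpoint path and partition count below are illustrative):

        // Repartition the state of an existing checkpoint offline, then
        // restart the streaming query so it picks up the new partition count.
        spark.streamingCheckpointManager.repartition("/tmp/query-checkpoint", 10)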
    
    ### Why are the changes needed?
    
    For offline state repartitioning
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    New and existing tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    Claude-4.5-opus
    
    Closes #53747 from micheal-o/repart_rewrite_int.
    
    Authored-by: micheal-o <[email protected]>
    Signed-off-by: Anish Shrigondekar <[email protected]>
---
 .../datasources/v2/state/StateDataSource.scala     |  14 +-
 .../datasources/v2/state/StateScanBuilder.scala    |  11 +-
 .../datasources/v2/state/StateTable.scala          |   4 +-
 .../state/OfflineStateRepartitionRunner.scala      |  72 +++++++-
 .../streaming/state/RocksDBFileManager.scala       |  27 ++-
 .../state/OfflineStateRepartitionSuite.scala       | 201 ++++++++++++++++++---
 6 files changed, 277 insertions(+), 52 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 40fba5a90cbd..9ccbb9a649f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -67,7 +67,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       properties: util.Map[String, String]): Table = {
     val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
       StateSourceOptions.apply(session, hadoopConf, properties))
-    val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+    // Build the SQL conf for the batch we are reading, using the confs in the offset log
+    val batchSqlConf =
+      buildSqlConfForBatch(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+    val stateConf = StateStoreConf(batchSqlConf)
     // We only support RocksDB because the repartition work that this option
     // is built for only supports RocksDB
     if (sourceOptions.internalOnlyReadAllColumnFamilies
@@ -86,7 +89,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
       NoPrefixKeyStateEncoderSpec(keySchema)
     }
 
-    new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+    new StateTable(session, schema, sourceOptions, stateConf,
+      batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get, keyStateEncoderSpec,
       stateStoreReaderInfo.transformWithStateVariableInfoOpt,
       stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
       stateStoreReaderInfo.stateSchemaProviderOpt,
@@ -171,7 +175,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
         sourceOptions.operatorId)
   }
 
-  private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = {
+  private def buildSqlConfForBatch(
+      checkpointLocation: String,
+      batchId: Long): SQLConf = {
    val offsetLog = new StreamingQueryCheckpointMetadata(session, checkpointLocation).offsetLog
     offsetLog.get(batchId) match {
       case Some(value) =>
@@ -181,7 +187,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
 
         val clonedSqlConf = session.sessionState.conf.clone()
         OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf)
-        StateStoreConf(clonedSqlConf)
+        clonedSqlConf
 
       case _ =>
        throw StateDataSourceErrors.offsetLogUnavailable(batchId, checkpointLocation)
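
In outline, the conf plumbing in this file now looks as follows (names as in
the hunks above; the offset-log lookup inside buildSqlConfForBatch is elided):

    // Rehydrate the SQLConf the batch originally ran with from the offset log,
    // then derive both the store conf and the batch's partition count from it.
    val batchSqlConf = buildSqlConfForBatch(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
    val stateConf = StateStoreConf(batchSqlConf)
    val batchNumPartitions = batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get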
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 1d7e7f709a6e..c3056767ee4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -42,6 +42,7 @@ class StateScanBuilder(
     schema: StructType,
     sourceOptions: StateSourceOptions,
     stateStoreConf: StateStoreConf,
+    batchNumPartitions: Int,
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -49,7 +50,8 @@ class StateScanBuilder(
     joinColFamilyOpt: Option[String],
    allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends ScanBuilder {
   override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
-    keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
+    batchNumPartitions, keyStateEncoderSpec,
+    stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
     joinColFamilyOpt, allColumnFamiliesReaderInfo)
 }
 
@@ -65,6 +67,7 @@ class StateScan(
     schema: StructType,
     sourceOptions: StateSourceOptions,
     stateStoreConf: StateStoreConf,
+    batchNumPartitions: Int,
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -85,7 +88,11 @@ class StateScan(
    val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new PathFilter() {
       override def accept(path: Path): Boolean = {
         fs.getFileStatus(path).isDirectory &&
-        Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+        Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 &&
+        // Since we now support state repartitioning, it is possible that a future batch has
+        // increased the number of partitions, hence increased the number of partition directories.
+        // So we only want partition dirs for the number of partitions in this batch.
+        path.getName.toInt < batchNumPartitions
       }
     })
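
Standalone, the new filter keeps only the directories that belong to this
batch's partition count; a runnable sketch (directory names are illustrative):

    // Partition dirs "0".."N-1" are read; dirs created by a later repartition
    // to more partitions (here "3", "4") are skipped, as is non-numeric noise.
    val batchNumPartitions = 3
    val visible = Seq("0", "1", "2", "3", "4", "_metadata").filter { name =>
      scala.util.Try(name.toInt).isSuccess && name.toInt >= 0 &&
        name.toInt < batchNumPartitions
    }
    assert(visible == Seq("0", "1", "2"))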
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index e945b803d45b..43f2f28c6b95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -41,6 +41,7 @@ class StateTable(
     override val schema: StructType,
     sourceOptions: StateSourceOptions,
     stateConf: StateStoreConf,
+    batchNumPartitions: Int,
     keyStateEncoderSpec: KeyStateEncoderSpec,
     stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
     stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -86,7 +87,8 @@ class StateTable(
   override def capabilities(): util.Set[TableCapability] = CAPABILITY
 
   override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
-    new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+    new StateScanBuilder(session, schema, sourceOptions, stateConf,
+      batchNumPartitions, keyStateEncoderSpec,
      stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
       joinColFamilyOpt, allColumnFamiliesReaderInfo)
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
index 02c8e85986d0..19f75eb385f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
@@ -17,15 +17,19 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import org.apache.hadoop.fs.Path
+
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.LogKeys._
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
 import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata, OffsetMap, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase}
-import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
+import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata}
 import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, Utils}
 
 /**
  * Runs repartitioning for the state stores used by a streaming query.
@@ -78,9 +82,22 @@ class OfflineStateRepartitionRunner(
 
        val newBatchId = createNewBatchIfNeeded(lastBatchId, lastCommittedBatchId)
 
-        // todo(SPARK-54365): Do the repartitioning here, in subsequent PR
-
-        // todo(SPARK-54365): update operator metadata in subsequent PR.
+        val stateRepartitionFunc = (stateDf: DataFrame) => {
+          // Repartition the state by the partition key
+          stateDf.repartition(numPartitions, col("partition_key"))
+        }
+        val rewriter = new StateRewriter(
+          sparkSession,
+          readBatchId = lastCommittedBatchId,
+          writeBatchId = newBatchId,
+          resolvedCpLocation,
+          hadoopConf,
+          transformFunc = Some(stateRepartitionFunc),
+          writeCheckpointMetadata = Some(checkpointMetadata)
+        )
+        rewriter.run()
+
+        updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId)
 
         // Commit the repartition batch
         commitBatch(newBatchId, lastCommittedBatchId)
@@ -229,6 +246,49 @@ class OfflineStateRepartitionRunner(
     newBatchId
   }
 
+  private def updateNumPartitionsInOperatorMetadata(
+      newBatchId: Long,
+      readBatchId: Long): Unit = {
+    val stateMetadataReader = new StateMetadataPartitionReader(
+      resolvedCpLocation,
+      new SerializableConfiguration(hadoopConf),
+      readBatchId)
+
+    val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata
+    assert(allOperatorsMetadata.nonEmpty, "Operator metadata shouldn't be empty")
+
+    val stateRootLocation = new Path(
+      resolvedCpLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString
+
+    allOperatorsMetadata.foreach { opMetadata =>
+      opMetadata match {
+        // We would only update shuffle partitions for v2 op metadata since it is versioned.
+        // For v1, we wouldn't update it since there is only one metadata file.
+        case v2: OperatorStateMetadataV2 =>
+          // update for each state store
+          val updatedStoreInfo = v2.stateStoreInfo.map { stateStore =>
+            stateStore.copy(numPartitions = numPartitions)
+          }
+          val updatedMetadata = v2.copy(stateStoreInfo = updatedStoreInfo)
+          // write the updated metadata
+          val metadataWriter = OperatorStateMetadataWriter.createWriter(
+            new Path(stateRootLocation, updatedMetadata.operatorInfo.operatorId.toString),
+            hadoopConf,
+            updatedMetadata.version,
+            Some(newBatchId))
+          metadataWriter.write(updatedMetadata)
+
+          logInfo(log"Updated operator metadata for " +
+            log"operator=${MDC(OP_TYPE, 
updatedMetadata.operatorInfo.operatorName)}, " +
+            log"numStateStores=${MDC(COUNT, 
updatedMetadata.stateStoreInfo.length)}")
+        case v =>
+          logInfo(log"Skipping operator metadata update for " +
+            log"operator=${MDC(OP_TYPE, v.operatorInfo.operatorName)}, " +
+            log"since metadata version(${MDC(FILE_VERSION, v.version)}) is not 
versioned")
+      }
+    }
+  }
+
   private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
    val latestCommit = checkpointMetadata.commitLog.get(lastCommittedBatchId).get
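
The transform handed to the rewriter is just a hash repartition on the state's
partition key column; the same shape in isolation (column name as in the hunk
above, the surrounding runner wiring elided):

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.col

    // Redistribute every state row into numPartitions buckets, keyed by the
    // partition key column exposed by the state data source.
    def stateRepartitionFunc(numPartitions: Int)(stateDf: DataFrame): DataFrame =
      stateDf.repartition(numPartitions, col("partition_key"))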
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index f67d80679d51..75e459b26511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -201,6 +201,14 @@ class RocksDBFileManager(
     }
   }
 
+  private def createDfsRootDirIfNotExist(): Unit = {
+    if (!rootDirChecked) {
+      val rootDir = new Path(dfsRootDir)
+      if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
+      rootDirChecked = true
+    }
+  }
+
   def getChangeLogWriter(
       version: Long,
       useColumnFamilies: Boolean = false,
@@ -209,11 +217,7 @@ class RocksDBFileManager(
     ): StateStoreChangelogWriter = {
     try {
       val changelogFile = dfsChangelogFile(version, checkpointUniqueId)
-      if (!rootDirChecked) {
-        val rootDir = new Path(dfsRootDir)
-        if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
-        rootDirChecked = true
-      }
+      createDfsRootDirIfNotExist()
 
       val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined
       val changelogVersion = getChangelogWriterVersion(
@@ -332,11 +336,7 @@ class RocksDBFileManager(
        // CheckpointFileManager.createAtomic API which doesn't auto-initialize parent directories.
        // Moreover, once we disable tracking the number of keys (numKeys is -1), we
         // still need to create the initial dfs root directory anyway.
-        if (!rootDirChecked) {
-          val path = new Path(dfsRootDir)
-          if (!fm.exists(path)) fm.mkdirs(path)
-          rootDirChecked = true
-        }
+        createDfsRootDirIfNotExist()
       }
       zipToDfsFile(localOtherFiles :+ metadataFile,
         dfsBatchZipFile(version, checkpointUniqueId), verifyNonEmptyFilesInZip)
@@ -372,6 +372,7 @@ class RocksDBFileManager(
     val metadata = if (version == 0) {
       if (localDir.exists) Utils.deleteRecursively(localDir)
       Utils.createDirectory(localDir)
+      createDfsRootDirIfNotExist()
      // Since we cleared the local dir, we should also clear the local file mapping
       rocksDBFileMapping.clear()
       RocksDBCheckpointMetadata(Seq.empty, 0)
@@ -404,11 +405,7 @@ class RocksDBFileManager(
  // Returns whether there is a snapshot file at the corresponding version,
  // optionally with a checkpoint unique id, e.g. version.zip or version_uniqueId.zip
  def existsSnapshotFile(version: Long, checkpointUniqueId: Option[String] = None): Boolean = {
-    if (!rootDirChecked) {
-      val path = new Path(dfsRootDir)
-      if (!fm.exists(path)) fm.mkdirs(path)
-      rootDirChecked = true
-    }
+    createDfsRootDirIfNotExist()
     fm.exists(dfsBatchZipFile(version, checkpointUniqueId))
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
index 860e7a1ab2e4..1ab581d79437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
@@ -17,18 +17,26 @@
 
 package org.apache.spark.sql.execution.streaming.state
 
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
 import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog, CommitMetadata}
 import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream, StreamingQueryCheckpointMetadata}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.streaming._
+import org.apache.spark.util.SerializableConfiguration
 
 /**
  * Test for offline state repartitioning. This tests that repartition behaves as expected
  * for different scenarios.
  */
-class OfflineStateRepartitionSuite extends StreamTest {
+class OfflineStateRepartitionSuite extends StreamTest
+  with AlsoTestWithRocksDBFeatures {
   import testImplicits._
   import OfflineStateRepartitionUtils._
+  import OfflineStateRepartitionTestUtils._
 
   test("Fail if empty checkpoint directory") {
     withTempDir { dir =>
@@ -101,7 +109,8 @@ class OfflineStateRepartitionSuite extends StreamTest {
   test("Repartition: success, failure, retry") {
     withTempDir { dir =>
       val originalPartitions = 3
-      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+      val input = MemoryStream[Int]
+      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
       val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
       // Shouldn't be seen as a repartition batch
       assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog, dir.getAbsolutePath))
@@ -124,8 +133,9 @@ class OfflineStateRepartitionSuite extends StreamTest {
       val newPartitions = originalPartitions + 1
      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
      val repartitionBatchId = batchId + 1
+      val hadoopConf = spark.sessionState.newHadoopConf()
      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+        repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
 
      // Now delete the repartition commit to simulate a failed repartition attempt.
       // This will delete all the commits after the batchId.
@@ -150,7 +160,17 @@ class OfflineStateRepartitionSuite extends StreamTest {
       // Retrying with the same numPartitions should work
      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, newPartitions)
      verifyRepartitionBatch(
-        repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+        repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+      // Repartition with way more partitions, to verify that empty partitions are properly created
+      val morePartitions = newPartitions * 3
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
+      verifyRepartitionBatch(
+        repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+        dir.getAbsolutePath, morePartitions)
+
+      // Restart the query to make sure it can start after repartitioning
+      runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
     }
   }
 
@@ -188,6 +208,7 @@ class OfflineStateRepartitionSuite extends StreamTest {
       verifyRepartitionBatch(
         lastBatchId + 1,
         checkpointMetadata,
+        spark.sessionState.newHadoopConf(),
         dir.getAbsolutePath,
         originalPartitions + 1,
        // Repartition should be based on the first batch, since we skipped the others
@@ -197,18 +218,21 @@ class OfflineStateRepartitionSuite extends StreamTest {
 
   test("Consecutive repartition") {
     withTempDir { dir =>
-      val originalPartitions = 3
-      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+      val originalPartitions = 5
+      val input = MemoryStream[Int]
+      val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
 
      val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark, dir.getAbsolutePath)
+      val hadoopConf = spark.sessionState.newHadoopConf()
 
       // decrease
-      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 1)
+      spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 3)
       verifyRepartitionBatch(
         batchId + 1,
         checkpointMetadata,
+        hadoopConf,
         dir.getAbsolutePath,
-        originalPartitions - 1
+        originalPartitions - 3
       )
 
       // increase
@@ -216,9 +240,13 @@ class OfflineStateRepartitionSuite extends StreamTest {
       verifyRepartitionBatch(
         batchId + 2,
         checkpointMetadata,
+        hadoopConf,
         dir.getAbsolutePath,
         originalPartitions + 1
       )
+
+      // Restart the query to make sure it can start after repartitioning
+      runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input)
     }
   }
 
@@ -231,31 +259,56 @@ class OfflineStateRepartitionSuite extends StreamTest {
       SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)
 
     var committedBatchId: Long = -1
-    testStream(input.toDF().groupBy().count(), outputMode = OutputMode.Update)(
-      StartStream(checkpointLocation = checkpointLocation, additionalConfs = conf),
-      AddData(input, 1, 2, 3),
-      ProcessAllAvailable(),
-      Execute { query =>
-        committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1)
-      }
-    )
+    // Set the confs before starting the stream
+    withSQLConf(conf.toSeq: _*) {
+      testStream(input.toDF().groupBy("value").count(), outputMode = 
OutputMode.Update)(
+        StartStream(checkpointLocation = checkpointLocation),
+        AddData(input, 1, 2, 3),
+        ProcessAllAvailable(),
+        Execute { query =>
+          committedBatchId = 
Option(query.lastProgress).map(_.batchId).getOrElse(-1)
+        }
+      )
+    }
 
    assert(committedBatchId >= 0, "No batch was committed in the streaming query")
     committedBatchId
   }
+}
 
-  private def verifyRepartitionBatch(
+object OfflineStateRepartitionTestUtils {
+  import OfflineStateRepartitionUtils._
+
+  def verifyRepartitionBatch(
       batchId: Long,
       checkpointMetadata: StreamingQueryCheckpointMetadata,
+      hadoopConf: Configuration,
       checkpointLocation: String,
       expectedShufflePartitions: Int,
       baseBatchId: Option[Long] = None): Unit = {
     // Should be seen as a repartition batch
    assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog, checkpointLocation))
 
+    // When failed batches are skipped, then repartition can be based
+    // on an older batch and not batchId - 1.
+    val previousBatchId = baseBatchId.getOrElse(batchId - 1)
+
+    verifyOffsetAndCommitLog(
+      batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
+    verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+    verifyOperatorMetadata(
+      batchId, previousBatchId, checkpointLocation, expectedShufflePartitions, hadoopConf)
+  }
+
+  private def verifyOffsetAndCommitLog(
+      repartitionBatchId: Long,
+      previousBatchId: Long,
+      expectedShufflePartitions: Int,
+      checkpointMetadata: StreamingQueryCheckpointMetadata): Unit = {
     // Verify the repartition batch
     val lastBatchId = checkpointMetadata.offsetLog.getLatestBatchId().get
-    assert(lastBatchId == batchId)
+    assert(lastBatchId == repartitionBatchId,
+      "The latest batch in offset log should be the repartition batch")
 
     val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get
    val lastBatchShufflePartitions = getShufflePartitions(lastBatch.metadataOpt.get).get
@@ -263,18 +316,16 @@ class OfflineStateRepartitionSuite extends StreamTest {
 
     // Verify the commit log
     val lastCommitId = checkpointMetadata.commitLog.getLatestBatchId().get
-    assert(lastCommitId == batchId)
+    assert(lastCommitId == repartitionBatchId,
+      "The latest batch in commit log should be the repartition batch")
 
     // verify that the offset seq is the same between repartition batch and
     // the batch the repartition is based on except for the shuffle partitions.
-    // When failed batches are skipped, then repartition can be based
-    // on an older batch and not batchId - 1.
-    val previousBatchId = baseBatchId.getOrElse(batchId - 1)
     val previousBatch = checkpointMetadata.offsetLog.get(previousBatchId).get
 
     // Verify offsets are identical
     assert(lastBatch.offsets == previousBatch.offsets,
-      s"Offsets should be identical between batch $previousBatchId and 
$batchId")
+      s"Offsets should be identical between batch $previousBatchId and 
$repartitionBatchId")
 
     // Verify metadata is the same except for shuffle partitions config
     (lastBatch.metadataOpt, previousBatch.metadataOpt) match {
@@ -298,7 +349,109 @@ class OfflineStateRepartitionSuite extends StreamTest {
          getShufflePartitions(lastMetadata).get != getShufflePartitions(previousMetadata).get,
           "Shuffle partitions should be different between batches")
       case _ =>
-        fail("Both batches should have metadata")
+        assert(false, "Both batches should have metadata")
+    }
+  }
+
+  // verify number of partition dirs in state dir
+  private def verifyPartitionDirs(
+      checkpointLocation: String,
+      expectedShufflePartitions: Int): Unit = {
+    val stateDir = new java.io.File(checkpointLocation, "state")
+
+    def numDirs(file: java.io.File): Int = {
+      file.listFiles()
+        .filter(d => d.isDirectory && Try(d.getName.toInt).isSuccess)
+        .length
+    }
+
+    val numOperators = numDirs(stateDir)
+    for (op <- 0 until numOperators) {
+      val partitionsDir = new java.io.File(stateDir, s"$op")
+      val numPartitions = numDirs(partitionsDir)
+      // Doing <= in case of reduced number of partitions
+      assert(expectedShufflePartitions <= numPartitions,
+        s"Expected atleast $expectedShufflePartitions partition dirs for 
operator $op," +
+          s" but found $numPartitions")
+    }
+  }
+
+  private def verifyOperatorMetadata(
+      repartitionBatchId: Long,
+      baseBatchId: Long,
+      checkpointLocation: String,
+      expectedShufflePartitions: Int,
+      hadoopConf: Configuration): Unit = {
+    val serializableConf = new SerializableConfiguration(hadoopConf)
+
+    // Read operator metadata for both batches
+    val baseMetadataReader = new StateMetadataPartitionReader(
+      checkpointLocation, serializableConf, baseBatchId)
+    val repartitionMetadataReader = new StateMetadataPartitionReader(
+      checkpointLocation, serializableConf, repartitionBatchId)
+
+    val baseOperatorsMetadata = baseMetadataReader.allOperatorStateMetadata
+    val repartitionOperatorsMetadata = repartitionMetadataReader.allOperatorStateMetadata
+
+    assert(baseOperatorsMetadata.nonEmpty, "Base batch should have operator metadata")
+    assert(repartitionOperatorsMetadata.nonEmpty, "Repartition batch should have operator metadata")
+    assert(baseOperatorsMetadata.length == repartitionOperatorsMetadata.length,
+      "Both batches should have the same number of operators")
+
+    // Verify each operator's metadata
+    baseOperatorsMetadata.zip(repartitionOperatorsMetadata).foreach {
+      case (baseOp, repartitionOp) =>
+        // Verify both are of the same type
+        assert(baseOp.getClass == repartitionOp.getClass,
+          s"Metadata types should match: 
base=${baseOp.getClass.getSimpleName}, " +
+            s"repartition=${repartitionOp.getClass.getSimpleName}")
+
+        (baseOp, repartitionOp) match {
+          case (baseV2: OperatorStateMetadataV2, repartitionV2: OperatorStateMetadataV2) =>
+            // Verify operator info is the same
+            assert(baseV2.operatorInfo == repartitionV2.operatorInfo,
+              s"Operator info should match: base=${baseV2.operatorInfo}, " +
+                s"repartition=${repartitionV2.operatorInfo}")
+
+            // Verify operator properties JSON is the same
+            assert(baseV2.operatorPropertiesJson == repartitionV2.operatorPropertiesJson,
+              "Operator properties JSON should match")
+
+            // Verify state store info (except numPartitions)
+            assert(baseV2.stateStoreInfo.length == repartitionV2.stateStoreInfo.length,
+              "Should have same number of state stores")
+
+            baseV2.stateStoreInfo.zip(repartitionV2.stateStoreInfo).foreach {
+              case (baseStore, repartitionStore) =>
+                assert(baseStore.storeName == repartitionStore.storeName,
+                  s"Store name should match: ${baseStore.storeName} " +
+                    s"vs ${repartitionStore.storeName}")
+                assert(baseStore.numColsPrefixKey == repartitionStore.numColsPrefixKey,
+                  "numColsPrefixKey should match")
+                // Schema file paths should be the same (they reference the same schema files)
+                assert(baseStore.stateSchemaFilePaths == repartitionStore.stateSchemaFilePaths,
+                  "State schema file paths should match")
+                assert(baseStore.numPartitions != repartitionStore.numPartitions,
+                  "numPartitions shouldn't be the same")
+                // Verify numPartitions is updated to expectedShufflePartitions
+                assert(repartitionStore.numPartitions == expectedShufflePartitions,
+                  s"Repartition batch numPartitions should be $expectedShufflePartitions, " +
+                    s"but found ${repartitionStore.numPartitions}")
+            }
+
+          case (baseV1: OperatorStateMetadataV1, repartitionV1: OperatorStateMetadataV1) =>
+            // For v1, since we didn't update it, then it should be the same.
+            // Can't use == directly because Array uses reference equality
+            assert(baseV1.operatorInfo == repartitionV1.operatorInfo,
+              "V1 operator info should be the same")
+            assert(baseV1.stateStoreInfo.sameElements(repartitionV1.stateStoreInfo),
+              "V1 state store info should be the same")
+
+          case _ =>
+            assert(false,
+              s"Unexpected metadata types: 
base=${baseOp.getClass.getSimpleName}, " +
+                s"repartition=${repartitionOp.getClass.getSimpleName}")
+        }
     }
   }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
