This is an automated email from the ASF dual-hosted git repository.
ashrigondekar pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new acc80fd1b205 [SPARK-54984][SS] State Repartition execution and integrate with State Rewriter
acc80fd1b205 is described below
commit acc80fd1b205792a43cc352a8e492e34bbe880da
Author: micheal-o <[email protected]>
AuthorDate: Fri Jan 9 14:55:29 2026 -0800
[SPARK-54984][SS] State Repartition execution and integrate with State Rewriter
### What changes were proposed in this pull request?
Integrate the Repartition runner with the State rewriter introduced in my previous [PR](https://github.com/apache/spark/pull/53703), enabling end-to-end execution of repartitioning.
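For reference, the core of the integration is a transform function handed to the `StateRewriter`. A minimal sketch of that idea (simplified from the runner change below; `numPartitions` here is just a placeholder target):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

val numPartitions = 10 // placeholder target partition count

// The state rows read for the last committed batch are shuffled on their
// partition key into the target number of partitions; the StateRewriter
// then writes them back as a new batch.
val stateRepartitionFunc = (stateDf: DataFrame) =>
  stateDf.repartition(numPartitions, col("partition_key"))
```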
### Why are the changes needed?
For offline state repartitioning
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
New and existing tests
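Roughly, the new end-to-end test flow looks like this (a sketch only; `runSimpleStreamQuery` and `verifyRepartitionBatch` are helpers defined in the suite changes below):

```scala
// 1) Run a simple stateful query to create state with the original partition count.
val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
// 2) Offline-repartition the checkpoint to a new partition count.
spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions + 1)
// 3) Verify offset/commit logs, partition dirs and operator metadata for the new batch.
verifyRepartitionBatch(batchId + 1, checkpointMetadata, hadoopConf,
  dir.getAbsolutePath, originalPartitions + 1)
// 4) Restart the query to confirm it runs on the repartitioned state.
runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input)
```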
### Was this patch authored or co-authored using generative AI tooling?
Claude-4.5-opus
Closes #53747 from micheal-o/repart_rewrite_int.
Authored-by: micheal-o <[email protected]>
Signed-off-by: Anish Shrigondekar <[email protected]>
---
.../datasources/v2/state/StateDataSource.scala | 14 +-
.../datasources/v2/state/StateScanBuilder.scala | 11 +-
.../datasources/v2/state/StateTable.scala | 4 +-
.../state/OfflineStateRepartitionRunner.scala | 72 +++++++-
.../streaming/state/RocksDBFileManager.scala | 27 ++-
.../state/OfflineStateRepartitionSuite.scala | 201 ++++++++++++++++++---
6 files changed, 277 insertions(+), 52 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index 40fba5a90cbd..9ccbb9a649f2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -67,7 +67,10 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
properties: util.Map[String, String]): Table = {
val sourceOptions = StateSourceOptions.modifySourceOptions(hadoopConf,
StateSourceOptions.apply(session, hadoopConf, properties))
- val stateConf = buildStateStoreConf(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+ // Build the sql conf for the batch we are reading using confs in the offsetlog
+ val batchSqlConf =
+ buildSqlConfForBatch(sourceOptions.resolvedCpLocation, sourceOptions.batchId)
+ val stateConf = StateStoreConf(batchSqlConf)
// We only support RocksDB because the repartition work that this option
// is built for only supports RocksDB
if (sourceOptions.internalOnlyReadAllColumnFamilies
@@ -86,7 +89,8 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
NoPrefixKeyStateEncoderSpec(keySchema)
}
- new StateTable(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+ new StateTable(session, schema, sourceOptions, stateConf,
+ batchSqlConf.getConf(SQLConf.STATEFUL_SHUFFLE_PARTITIONS_INTERNAL).get, keyStateEncoderSpec,
stateStoreReaderInfo.transformWithStateVariableInfoOpt,
stateStoreReaderInfo.stateStoreColFamilySchemaOpt,
stateStoreReaderInfo.stateSchemaProviderOpt,
@@ -171,7 +175,9 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
sourceOptions.operatorId)
}
- private def buildStateStoreConf(checkpointLocation: String, batchId: Long): StateStoreConf = {
+ private def buildSqlConfForBatch(
+ checkpointLocation: String,
+ batchId: Long): SQLConf = {
val offsetLog = new StreamingQueryCheckpointMetadata(session,
checkpointLocation).offsetLog
offsetLog.get(batchId) match {
case Some(value) =>
@@ -181,7 +187,7 @@ class StateDataSource extends TableProvider with DataSourceRegister with Logging
val clonedSqlConf = session.sessionState.conf.clone()
OffsetSeqMetadata.setSessionConf(metadata, clonedSqlConf)
- StateStoreConf(clonedSqlConf)
+ clonedSqlConf
case _ =>
throw StateDataSourceErrors.offsetLogUnavailable(batchId,
checkpointLocation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
index 1d7e7f709a6e..c3056767ee4b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateScanBuilder.scala
@@ -42,6 +42,7 @@ class StateScanBuilder(
schema: StructType,
sourceOptions: StateSourceOptions,
stateStoreConf: StateStoreConf,
+ batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -49,7 +50,8 @@ class StateScanBuilder(
joinColFamilyOpt: Option[String],
allColumnFamiliesReaderInfo: Option[AllColumnFamiliesReaderInfo]) extends
ScanBuilder {
override def build(): Scan = new StateScan(session, schema, sourceOptions, stateStoreConf,
- keyStateEncoderSpec, stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
+ batchNumPartitions, keyStateEncoderSpec,
+ stateVariableInfoOpt, stateStoreColFamilySchemaOpt, stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
}
@@ -65,6 +67,7 @@ class StateScan(
schema: StructType,
sourceOptions: StateSourceOptions,
stateStoreConf: StateStoreConf,
+ batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -85,7 +88,11 @@ class StateScan(
val partitions = fs.listStatus(stateCheckpointPartitionsLocation, new
PathFilter() {
override def accept(path: Path): Boolean = {
fs.getFileStatus(path).isDirectory &&
- Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0
+ Try(path.getName.toInt).isSuccess && path.getName.toInt >= 0 &&
+ // Since we now support state repartitioning, it is possible that a future batch has
+ // increased the number of partitions, hence increased the number of partition directories.
+ // So we only want partition dirs for the number of partitions in this batch.
+ path.getName.toInt < batchNumPartitions
}
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
index e945b803d45b..43f2f28c6b95 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateTable.scala
@@ -41,6 +41,7 @@ class StateTable(
override val schema: StructType,
sourceOptions: StateSourceOptions,
stateConf: StateStoreConf,
+ batchNumPartitions: Int,
keyStateEncoderSpec: KeyStateEncoderSpec,
stateVariableInfoOpt: Option[TransformWithStateVariableInfo],
stateStoreColFamilySchemaOpt: Option[StateStoreColFamilySchema],
@@ -86,7 +87,8 @@ class StateTable(
override def capabilities(): util.Set[TableCapability] = CAPABILITY
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
- new StateScanBuilder(session, schema, sourceOptions, stateConf, keyStateEncoderSpec,
+ new StateScanBuilder(session, schema, sourceOptions, stateConf,
+ batchNumPartitions, keyStateEncoderSpec,
stateVariableInfoOpt, stateStoreColFamilySchemaOpt,
stateSchemaProviderOpt,
joinColFamilyOpt, allColumnFamiliesReaderInfo)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
index 02c8e85986d0..19f75eb385f5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionRunner.scala
@@ -17,15 +17,19 @@
package org.apache.spark.sql.execution.streaming.state
+import org.apache.hadoop.fs.Path
+
import org.apache.spark.internal.Logging
import org.apache.spark.internal.LogKeys._
-import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitMetadata,
OffsetMap, OffsetSeq, OffsetSeqLog, OffsetSeqMetadata, OffsetSeqMetadataBase}
-import org.apache.spark.sql.execution.streaming.runtime.StreamingQueryCheckpointMetadata
+import org.apache.spark.sql.execution.streaming.runtime.{StreamingCheckpointConstants, StreamingQueryCheckpointMetadata}
import org.apache.spark.sql.execution.streaming.utils.StreamingUtils
+import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{SerializableConfiguration, Utils}
/**
* Runs repartitioning for the state stores used by a streaming query.
@@ -78,9 +82,22 @@ class OfflineStateRepartitionRunner(
val newBatchId = createNewBatchIfNeeded(lastBatchId,
lastCommittedBatchId)
- // todo(SPARK-54365): Do the repartitioning here, in subsequent PR
-
- // todo(SPARK-54365): update operator metadata in subsequent PR.
+ val stateRepartitionFunc = (stateDf: DataFrame) => {
+ // Repartition the state by the partition key
+ stateDf.repartition(numPartitions, col("partition_key"))
+ }
+ val rewriter = new StateRewriter(
+ sparkSession,
+ readBatchId = lastCommittedBatchId,
+ writeBatchId = newBatchId,
+ resolvedCpLocation,
+ hadoopConf,
+ transformFunc = Some(stateRepartitionFunc),
+ writeCheckpointMetadata = Some(checkpointMetadata)
+ )
+ rewriter.run()
+
+ updateNumPartitionsInOperatorMetadata(newBatchId, readBatchId = lastCommittedBatchId)
// Commit the repartition batch
commitBatch(newBatchId, lastCommittedBatchId)
@@ -229,6 +246,49 @@ class OfflineStateRepartitionRunner(
newBatchId
}
+ private def updateNumPartitionsInOperatorMetadata(
+ newBatchId: Long,
+ readBatchId: Long): Unit = {
+ val stateMetadataReader = new StateMetadataPartitionReader(
+ resolvedCpLocation,
+ new SerializableConfiguration(hadoopConf),
+ readBatchId)
+
+ val allOperatorsMetadata = stateMetadataReader.allOperatorStateMetadata
+ assert(allOperatorsMetadata.nonEmpty, "Operator metadata shouldn't be empty")
+
+ val stateRootLocation = new Path(
+ resolvedCpLocation, StreamingCheckpointConstants.DIR_NAME_STATE).toString
+
+ allOperatorsMetadata.foreach { opMetadata =>
+ opMetadata match {
+ // We would only update shuffle partitions for v2 op metadata since it is versioned.
+ // For v1, we wouldn't update it since there is only one metadata file.
+ case v2: OperatorStateMetadataV2 =>
+ // update for each state store
+ val updatedStoreInfo = v2.stateStoreInfo.map { stateStore =>
+ stateStore.copy(numPartitions = numPartitions)
+ }
+ val updatedMetadata = v2.copy(stateStoreInfo = updatedStoreInfo)
+ // write the updated metadata
+ val metadataWriter = OperatorStateMetadataWriter.createWriter(
+ new Path(stateRootLocation, updatedMetadata.operatorInfo.operatorId.toString),
+ hadoopConf,
+ updatedMetadata.version,
+ Some(newBatchId))
+ metadataWriter.write(updatedMetadata)
+
+ logInfo(log"Updated operator metadata for " +
+ log"operator=${MDC(OP_TYPE,
updatedMetadata.operatorInfo.operatorName)}, " +
+ log"numStateStores=${MDC(COUNT,
updatedMetadata.stateStoreInfo.length)}")
+ case v =>
+ logInfo(log"Skipping operator metadata update for " +
+ log"operator=${MDC(OP_TYPE, v.operatorInfo.operatorName)}, " +
+ log"since metadata version(${MDC(FILE_VERSION, v.version)}) is not
versioned")
+ }
+ }
+ }
+
private def commitBatch(newBatchId: Long, lastCommittedBatchId: Long): Unit = {
val latestCommit =
checkpointMetadata.commitLog.get(lastCommittedBatchId).get
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
index f67d80679d51..75e459b26511 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBFileManager.scala
@@ -201,6 +201,14 @@ class RocksDBFileManager(
}
}
+ private def createDfsRootDirIfNotExist(): Unit = {
+ if (!rootDirChecked) {
+ val rootDir = new Path(dfsRootDir)
+ if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
+ rootDirChecked = true
+ }
+ }
+
def getChangeLogWriter(
version: Long,
useColumnFamilies: Boolean = false,
@@ -209,11 +217,7 @@ class RocksDBFileManager(
): StateStoreChangelogWriter = {
try {
val changelogFile = dfsChangelogFile(version, checkpointUniqueId)
- if (!rootDirChecked) {
- val rootDir = new Path(dfsRootDir)
- if (!fm.exists(rootDir)) fm.mkdirs(rootDir)
- rootDirChecked = true
- }
+ createDfsRootDirIfNotExist()
val enableStateStoreCheckpointIds = checkpointUniqueId.isDefined
val changelogVersion = getChangelogWriterVersion(
@@ -332,11 +336,7 @@ class RocksDBFileManager(
// CheckpointFileManager.createAtomic API which doesn't
auto-initialize parent directories.
// Moreover, once we disable to track the number of keys, in which the
numKeys is -1, we
// still need to create the initial dfs root directory anyway.
- if (!rootDirChecked) {
- val path = new Path(dfsRootDir)
- if (!fm.exists(path)) fm.mkdirs(path)
- rootDirChecked = true
- }
+ createDfsRootDirIfNotExist()
}
zipToDfsFile(localOtherFiles :+ metadataFile,
dfsBatchZipFile(version, checkpointUniqueId), verifyNonEmptyFilesInZip)
@@ -372,6 +372,7 @@ class RocksDBFileManager(
val metadata = if (version == 0) {
if (localDir.exists) Utils.deleteRecursively(localDir)
Utils.createDirectory(localDir)
+ createDfsRootDirIfNotExist()
// Since we cleared the local dir, we should also clear the local file
mapping
rocksDBFileMapping.clear()
RocksDBCheckpointMetadata(Seq.empty, 0)
@@ -404,11 +405,7 @@ class RocksDBFileManager(
// Return if there is a snapshot file at the corresponding version
// and optionally with checkpointunique id, e.g. version.zip or
version_uniqueId.zip
def existsSnapshotFile(version: Long, checkpointUniqueId: Option[String] = None): Boolean = {
- if (!rootDirChecked) {
- val path = new Path(dfsRootDir)
- if (!fm.exists(path)) fm.mkdirs(path)
- rootDirChecked = true
- }
+ createDfsRootDirIfNotExist()
fm.exists(dfsBatchZipFile(version, checkpointUniqueId))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
index 860e7a1ab2e4..1ab581d79437 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/OfflineStateRepartitionSuite.scala
@@ -17,18 +17,26 @@
package org.apache.spark.sql.execution.streaming.state
+import scala.util.Try
+
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.streaming.checkpointing.{CommitLog,
CommitMetadata}
import org.apache.spark.sql.execution.streaming.runtime.{MemoryStream,
StreamingQueryCheckpointMetadata}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.streaming._
+import org.apache.spark.util.SerializableConfiguration
/**
* Test for offline state repartitioning. This tests that repartition behaves
as expected
* for different scenarios.
*/
-class OfflineStateRepartitionSuite extends StreamTest {
+class OfflineStateRepartitionSuite extends StreamTest
+ with AlsoTestWithRocksDBFeatures {
import testImplicits._
import OfflineStateRepartitionUtils._
+ import OfflineStateRepartitionTestUtils._
test("Fail if empty checkpoint directory") {
withTempDir { dir =>
@@ -101,7 +109,8 @@ class OfflineStateRepartitionSuite extends StreamTest {
test("Repartition: success, failure, retry") {
withTempDir { dir =>
val originalPartitions = 3
- val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
// Shouldn't be seen as a repartition batch
assert(!isRepartitionBatch(batchId, checkpointMetadata.offsetLog,
dir.getAbsolutePath))
@@ -124,8 +133,9 @@ class OfflineStateRepartitionSuite extends StreamTest {
val newPartitions = originalPartitions + 1
spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
val repartitionBatchId = batchId + 1
+ val hadoopConf = spark.sessionState.newHadoopConf()
verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
// Now delete the repartition commit to simulate a failed repartition
attempt.
// This will delete all the commits after the batchId.
@@ -150,7 +160,17 @@ class OfflineStateRepartitionSuite extends StreamTest {
// Retrying with the same numPartitions should work
spark.streamingCheckpointManager.repartition(dir.getAbsolutePath,
newPartitions)
verifyRepartitionBatch(
- repartitionBatchId, checkpointMetadata, dir.getAbsolutePath, newPartitions)
+ repartitionBatchId, checkpointMetadata, hadoopConf, dir.getAbsolutePath, newPartitions)
+
+ // Repartition with way more partitions, to verify that empty partitions are properly created
+ val morePartitions = newPartitions * 3
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, morePartitions)
+ verifyRepartitionBatch(
+ repartitionBatchId + 1, checkpointMetadata, hadoopConf,
+ dir.getAbsolutePath, morePartitions)
+
+ // Restart the query to make sure it can start after repartitioning
+ runSimpleStreamQuery(morePartitions, dir.getAbsolutePath, input)
}
}
@@ -188,6 +208,7 @@ class OfflineStateRepartitionSuite extends StreamTest {
verifyRepartitionBatch(
lastBatchId + 1,
checkpointMetadata,
+ spark.sessionState.newHadoopConf(),
dir.getAbsolutePath,
originalPartitions + 1,
// Repartition should be based on the first batch, since we skipped
the others
@@ -197,18 +218,21 @@ class OfflineStateRepartitionSuite extends StreamTest {
test("Consecutive repartition") {
withTempDir { dir =>
- val originalPartitions = 3
- val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath)
+ val originalPartitions = 5
+ val input = MemoryStream[Int]
+ val batchId = runSimpleStreamQuery(originalPartitions, dir.getAbsolutePath, input)
val checkpointMetadata = new StreamingQueryCheckpointMetadata(spark,
dir.getAbsolutePath)
+ val hadoopConf = spark.sessionState.newHadoopConf()
// decrease
- spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 1)
+ spark.streamingCheckpointManager.repartition(dir.getAbsolutePath, originalPartitions - 3)
verifyRepartitionBatch(
batchId + 1,
checkpointMetadata,
+ hadoopConf,
dir.getAbsolutePath,
- originalPartitions - 1
+ originalPartitions - 3
)
// increase
@@ -216,9 +240,13 @@ class OfflineStateRepartitionSuite extends StreamTest {
verifyRepartitionBatch(
batchId + 2,
checkpointMetadata,
+ hadoopConf,
dir.getAbsolutePath,
originalPartitions + 1
)
+
+ // Restart the query to make sure it can start after repartitioning
+ runSimpleStreamQuery(originalPartitions + 1, dir.getAbsolutePath, input)
}
}
@@ -231,31 +259,56 @@ class OfflineStateRepartitionSuite extends StreamTest {
SQLConf.SHUFFLE_PARTITIONS.key -> numPartitions.toString)
var committedBatchId: Long = -1
- testStream(input.toDF().groupBy().count(), outputMode = OutputMode.Update)(
- StartStream(checkpointLocation = checkpointLocation, additionalConfs = conf),
- AddData(input, 1, 2, 3),
- ProcessAllAvailable(),
- Execute { query =>
- committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1)
- }
- )
+ // Set the confs before starting the stream
+ withSQLConf(conf.toSeq: _*) {
+ testStream(input.toDF().groupBy("value").count(), outputMode = OutputMode.Update)(
+ StartStream(checkpointLocation = checkpointLocation),
+ AddData(input, 1, 2, 3),
+ ProcessAllAvailable(),
+ Execute { query =>
+ committedBatchId = Option(query.lastProgress).map(_.batchId).getOrElse(-1)
+ }
+ )
+ }
assert(committedBatchId >= 0, "No batch was committed in the streaming
query")
committedBatchId
}
+}
- private def verifyRepartitionBatch(
+object OfflineStateRepartitionTestUtils {
+ import OfflineStateRepartitionUtils._
+
+ def verifyRepartitionBatch(
batchId: Long,
checkpointMetadata: StreamingQueryCheckpointMetadata,
+ hadoopConf: Configuration,
checkpointLocation: String,
expectedShufflePartitions: Int,
baseBatchId: Option[Long] = None): Unit = {
// Should be seen as a repartition batch
assert(isRepartitionBatch(batchId, checkpointMetadata.offsetLog,
checkpointLocation))
+ // When failed batches are skipped, then repartition can be based
+ // on an older batch and not batchId - 1.
+ val previousBatchId = baseBatchId.getOrElse(batchId - 1)
+
+ verifyOffsetAndCommitLog(
+ batchId, previousBatchId, expectedShufflePartitions, checkpointMetadata)
+ verifyPartitionDirs(checkpointLocation, expectedShufflePartitions)
+ verifyOperatorMetadata(
+ batchId, previousBatchId, checkpointLocation, expectedShufflePartitions, hadoopConf)
+ }
+
+ private def verifyOffsetAndCommitLog(
+ repartitionBatchId: Long,
+ previousBatchId: Long,
+ expectedShufflePartitions: Int,
+ checkpointMetadata: StreamingQueryCheckpointMetadata): Unit = {
// Verify the repartition batch
val lastBatchId = checkpointMetadata.offsetLog.getLatestBatchId().get
- assert(lastBatchId == batchId)
+ assert(lastBatchId == repartitionBatchId,
+ "The latest batch in offset log should be the repartition batch")
val lastBatch = checkpointMetadata.offsetLog.get(lastBatchId).get
val lastBatchShufflePartitions =
getShufflePartitions(lastBatch.metadataOpt.get).get
@@ -263,18 +316,16 @@ class OfflineStateRepartitionSuite extends StreamTest {
// Verify the commit log
val lastCommitId = checkpointMetadata.commitLog.getLatestBatchId().get
- assert(lastCommitId == batchId)
+ assert(lastCommitId == repartitionBatchId,
+ "The latest batch in commit log should be the repartition batch")
// verify that the offset seq is the same between repartition batch and
// the batch the repartition is based on except for the shuffle partitions.
- // When failed batches are skipped, then repartition can be based
- // on an older batch and not batchId - 1.
- val previousBatchId = baseBatchId.getOrElse(batchId - 1)
val previousBatch = checkpointMetadata.offsetLog.get(previousBatchId).get
// Verify offsets are identical
assert(lastBatch.offsets == previousBatch.offsets,
- s"Offsets should be identical between batch $previousBatchId and
$batchId")
+ s"Offsets should be identical between batch $previousBatchId and
$repartitionBatchId")
// Verify metadata is the same except for shuffle partitions config
(lastBatch.metadataOpt, previousBatch.metadataOpt) match {
@@ -298,7 +349,109 @@ class OfflineStateRepartitionSuite extends StreamTest {
getShufflePartitions(lastMetadata).get !=
getShufflePartitions(previousMetadata).get,
"Shuffle partitions should be different between batches")
case _ =>
- fail("Both batches should have metadata")
+ assert(false, "Both batches should have metadata")
+ }
+ }
+
+ // verify number of partition dirs in state dir
+ private def verifyPartitionDirs(
+ checkpointLocation: String,
+ expectedShufflePartitions: Int): Unit = {
+ val stateDir = new java.io.File(checkpointLocation, "state")
+
+ def numDirs(file: java.io.File): Int = {
+ file.listFiles()
+ .filter(d => d.isDirectory && Try(d.getName.toInt).isSuccess)
+ .length
+ }
+
+ val numOperators = numDirs(stateDir)
+ for (op <- 0 until numOperators) {
+ val partitionsDir = new java.io.File(stateDir, s"$op")
+ val numPartitions = numDirs(partitionsDir)
+ // Doing <= in case of reduced number of partitions
+ assert(expectedShufflePartitions <= numPartitions,
+ s"Expected atleast $expectedShufflePartitions partition dirs for operator $op," +
+ s" but found $numPartitions")
+ }
+ }
+
+ private def verifyOperatorMetadata(
+ repartitionBatchId: Long,
+ baseBatchId: Long,
+ checkpointLocation: String,
+ expectedShufflePartitions: Int,
+ hadoopConf: Configuration): Unit = {
+ val serializableConf = new SerializableConfiguration(hadoopConf)
+
+ // Read operator metadata for both batches
+ val baseMetadataReader = new StateMetadataPartitionReader(
+ checkpointLocation, serializableConf, baseBatchId)
+ val repartitionMetadataReader = new StateMetadataPartitionReader(
+ checkpointLocation, serializableConf, repartitionBatchId)
+
+ val baseOperatorsMetadata = baseMetadataReader.allOperatorStateMetadata
+ val repartitionOperatorsMetadata = repartitionMetadataReader.allOperatorStateMetadata
+
+ assert(baseOperatorsMetadata.nonEmpty, "Base batch should have operator metadata")
+ assert(repartitionOperatorsMetadata.nonEmpty, "Repartition batch should have operator metadata")
+ assert(baseOperatorsMetadata.length == repartitionOperatorsMetadata.length,
+ "Both batches should have the same number of operators")
+
+ // Verify each operator's metadata
+ baseOperatorsMetadata.zip(repartitionOperatorsMetadata).foreach {
+ case (baseOp, repartitionOp) =>
+ // Verify both are of the same type
+ assert(baseOp.getClass == repartitionOp.getClass,
+ s"Metadata types should match:
base=${baseOp.getClass.getSimpleName}, " +
+ s"repartition=${repartitionOp.getClass.getSimpleName}")
+
+ (baseOp, repartitionOp) match {
+ case (baseV2: OperatorStateMetadataV2, repartitionV2: OperatorStateMetadataV2) =>
+ // Verify operator info is the same
+ assert(baseV2.operatorInfo == repartitionV2.operatorInfo,
+ s"Operator info should match: base=${baseV2.operatorInfo}, " +
+ s"repartition=${repartitionV2.operatorInfo}")
+
+ // Verify operator properties JSON is the same
+ assert(baseV2.operatorPropertiesJson == repartitionV2.operatorPropertiesJson,
+ "Operator properties JSON should match")
+
+ // Verify state store info (except numPartitions)
+ assert(baseV2.stateStoreInfo.length == repartitionV2.stateStoreInfo.length,
+ "Should have same number of state stores")
+
+ baseV2.stateStoreInfo.zip(repartitionV2.stateStoreInfo).foreach {
+ case (baseStore, repartitionStore) =>
+ assert(baseStore.storeName == repartitionStore.storeName,
+ s"Store name should match: ${baseStore.storeName} " +
+ s"vs ${repartitionStore.storeName}")
+ assert(baseStore.numColsPrefixKey == repartitionStore.numColsPrefixKey,
+ "numColsPrefixKey should match")
+ // Schema file paths should be the same (they reference the same schema files)
+ assert(baseStore.stateSchemaFilePaths == repartitionStore.stateSchemaFilePaths,
+ "State schema file paths should match")
+ assert(baseStore.numPartitions != repartitionStore.numPartitions,
+ "numPartitions shouldn't be the same")
+ // Verify numPartitions is updated to expectedShufflePartitions
+ assert(repartitionStore.numPartitions == expectedShufflePartitions,
+ s"Repartition batch numPartitions should be $expectedShufflePartitions, " +
+ s"but found ${repartitionStore.numPartitions}")
+ }
+
+ case (baseV1: OperatorStateMetadataV1, repartitionV1: OperatorStateMetadataV1) =>
+ // For v1, since we didn't update it, then it should be the same.
+ // Can't use == directly because Array uses reference equality
+ assert(baseV1.operatorInfo == repartitionV1.operatorInfo,
+ "V1 operator info should be the same")
+ assert(baseV1.stateStoreInfo.sameElements(repartitionV1.stateStoreInfo),
+ "V1 state store info should be the same")
+
+ case _ =>
+ assert(false,
+ s"Unexpected metadata types:
base=${baseOp.getClass.getSimpleName}, " +
+ s"repartition=${repartitionOp.getClass.getSimpleName}")
+ }
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]