siying commented on code in PR #47895: URL: https://github.com/apache/spark/pull/47895#discussion_r1755481725
########## sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala: ########## @@ -105,7 +105,7 @@ class StreamStreamJoinStatePartitionReader( val stateInfo = StatefulOperatorStateInfo( partition.sourceOptions.stateCheckpointLocation.toString, partition.queryId, partition.sourceOptions.operatorId, - partition.sourceOptions.batchId + 1, -1) + partition.sourceOptions.batchId + 1, -1, None) Review Comment: Good point. Will check. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala: ########## @@ -900,12 +907,54 @@ class MicroBatchExecution( */ protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {} + private def updateCheckpointIdForOperator( + execCtx: MicroBatchExecutionContext, + opId: Long, + checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = { + // TODO validate baseCheckpointId + checkpointInfo.map(_.batchVersion).foreach { v => + assert( + execCtx.batchId == -1 || v == execCtx.batchId + 1, + s"version $v doesn't match current Batch ID ${execCtx.batchId}") + } + currentCheckpointUniqueId.put(opId, checkpointInfo.map { c => + assert(c.checkpointId.isDefined) + c.checkpointId.get + }) + } + + private def updateCheckpointId( Review Comment: That's right. I should write a comment somewhere. 
########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala: ########## @@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest } } + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. 
+ testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") { + val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName + TestStateStoreWrapper.clear() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)), + StopStream + ) + + // crash recovery again Review Comment: I'll fix the comment. ########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -72,7 +72,8 @@ class RocksDB( localRootDir: File = Utils.createTempDir(), hadoopConf: Configuration = new Configuration, loggingId: String = "", - useColumnFamilies: Boolean = false) extends Logging { + useColumnFamilies: Boolean = false, + ifEnableCheckpointId: Boolean = false) extends Logging { Review Comment: Do you have a suggestion for a better name? I can definitely change it. 
########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala: ########## @@ -57,7 +59,9 @@ class IncrementalExecution( val prevOffsetSeqMetadata: Option[OffsetSeqMetadata], val offsetSeqMetadata: OffsetSeqMetadata, val watermarkPropagator: WatermarkPropagator, - val isFirstBatch: Boolean) + val isFirstBatch: Boolean, + val currentCheckpointUniqueId: + MutableMap[Long, Array[String]] = MutableMap[Long, Array[String]]()) Review Comment: Yes it is true. ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala: ########## @@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest } } + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. Review Comment: It's a bad comment. Let me remove it. 
########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala: ########## @@ -900,12 +907,42 @@ class MicroBatchExecution( */ protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {} + private def updateCheckpointIdForOperator( + execCtx: MicroBatchExecutionContext, + opId: Long, + checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = { + // TODO validate baseCheckpointId + checkpointInfo.map(_.batchVersion).foreach { v => + assert( + execCtx.batchId == -1 || v == execCtx.batchId + 1, + s"version $v doesn't match current Batch ID ${execCtx.batchId}") Review Comment: I can rephrase it, but batch n commits to state store version n+1. ########## sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala: ########## @@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest } } + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") { + withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Run the stream with changelog checkpointing enabled. + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)) + ) + + // Run the stream with changelog checkpointing disabled. 
+ testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + } + + testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate ID") { + val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName + TestStateStoreWrapper.clear() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() + + val inputData = MemoryStream[Int] + val aggregated = + inputData + .toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream + ) + + // Test recovery + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. 
+ AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)), + AddData(inputData, 5, 5), + CheckLastBatch((5, 2)), + StopStream + ) + + // crash recovery again + testStream(aggregated, Update)( + StartStream(checkpointLocation = checkpointDir.getAbsolutePath), + AddData(inputData, 4), + CheckLastBatch((4, 5)) + ) + } + val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos + // We have 6 batches, 2 partitions, and 1 state store per batch + assert(checkpointInfoList.size == 12) + checkpointInfoList.foreach { l => + assert(l.checkpointId.isDefined) + if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) { + assert(l.baseCheckpointId.isDefined) + } + } + assert(checkpointInfoList.count(_.partitionId == 0) == 6) + assert(checkpointInfoList.count(_.partitionId == 1) == 6) + for (i <- 1 to 6) { + assert(checkpointInfoList.count(_.batchVersion == i) == 2) + } + for { + a <- checkpointInfoList + b <- checkpointInfoList + if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1 + } { + // if batch version exists, it should be the same as the checkpoint ID of the previous batch + assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId) + } + } + + testWithChangelogCheckpointingEnabled( + s"checkpointFormatVersion2 validate ID with dedup and groupBy") { + val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName + TestStateStoreWrapper.clear() + withSQLConf( + (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName), + (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"), + (SQLConf.SHUFFLE_PARTITIONS.key, "2")) { + val checkpointDir = Utils.createTempDir().getCanonicalFile + checkpointDir.delete() Review Comment: That's a good point. I'll do it. I copy&pasted from a previous test without thinking. 
########## sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala: ########## @@ -803,6 +843,14 @@ class RocksDB( /** Get the write buffer manager and cache */ def getWriteBufferManagerAndCache(): (WriteBufferManager, Cache) = (writeBufferManager, lruCache) + def getLatestCheckpointInfo(partitionId: Int): StateStoreCheckpointInfo = { Review Comment: This will always be called; the caller has no knowledge of what's going on there. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org