siying commented on code in PR #47895:
URL: https://github.com/apache/spark/pull/47895#discussion_r1755481725


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StreamStreamJoinStatePartitionReader.scala:
##########
@@ -105,7 +105,7 @@ class StreamStreamJoinStatePartitionReader(
       val stateInfo = StatefulOperatorStateInfo(
         partition.sourceOptions.stateCheckpointLocation.toString,
         partition.queryId, partition.sourceOptions.operatorId,
-        partition.sourceOptions.batchId + 1, -1)
+        partition.sourceOptions.batchId + 1, -1, None)

Review Comment:
   Good point. Will check.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +907,54 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")
+    }
+    currentCheckpointUniqueId.put(opId, checkpointInfo.map { c =>
+      assert(c.checkpointId.isDefined)
+      c.checkpointId.get
+    })
+  }
+
+  private def updateCheckpointId(

Review Comment:
   That's right. I should write a comment somewhere.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate 
ID") {
+    val providerClassName = 
classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again

Review Comment:
   I'll fix the comment.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -72,7 +72,8 @@ class RocksDB(
     localRootDir: File = Utils.createTempDir(),
     hadoopConf: Configuration = new Configuration,
     loggingId: String = "",
-    useColumnFamilies: Boolean = false) extends Logging {
+    useColumnFamilies: Boolean = false,
+    ifEnableCheckpointId: Boolean = false) extends Logging {

Review Comment:
   Do you have a suggestion for a better name? I can definitely change it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala:
##########
@@ -57,7 +59,9 @@ class IncrementalExecution(
     val prevOffsetSeqMetadata: Option[OffsetSeqMetadata],
     val offsetSeqMetadata: OffsetSeqMetadata,
     val watermarkPropagator: WatermarkPropagator,
-    val isFirstBatch: Boolean)
+    val isFirstBatch: Boolean,
+    val currentCheckpointUniqueId:
+      MutableMap[Long, Array[String]] = MutableMap[Long, Array[String]]())

Review Comment:
   Yes, that's true.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.

Review Comment:
   It's a bad comment. Let me remove it.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala:
##########
@@ -900,12 +907,42 @@ class MicroBatchExecution(
    */
   protected def markMicroBatchExecutionStart(execCtx: MicroBatchExecutionContext): Unit = {}
 
+  private def updateCheckpointIdForOperator(
+      execCtx: MicroBatchExecutionContext,
+      opId: Long,
+      checkpointInfo: Array[StateStoreCheckpointInfo]): Unit = {
+    // TODO validate baseCheckpointId
+    checkpointInfo.map(_.batchVersion).foreach { v =>
+      assert(
+        execCtx.batchId == -1 || v == execCtx.batchId + 1,
+        s"version $v doesn't match current Batch ID ${execCtx.batchId}")

Review Comment:
   I can rephrase it, but batch n commits to state store version n+1. 
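   For the record, a minimal sketch of how the assertion could be reworded to spell out that invariant (`validateBatchVersions` is a hypothetical helper mirroring the loop in the diff above, not code from this PR):
   
   ```scala
   // Sketch: batch n commits state store version n + 1, so every reported
   // batchVersion must equal batchId + 1 (batchId == -1 means no batch has
   // run yet, in which case there is nothing to check).
   def validateBatchVersions(batchId: Long, batchVersions: Seq[Long]): Unit = {
     batchVersions.foreach { v =>
       assert(
         batchId == -1 || v == batchId + 1,
         s"state store batchVersion $v does not match expected version " +
           s"${batchId + 1} for batch $batchId (batch n commits version n + 1)")
     }
   }
   ```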



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreIntegrationSuite.scala:
##########
@@ -222,6 +375,456 @@ class RocksDBStateStoreIntegrationSuite extends StreamTest
     }
   }
 
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2") {
+    withSQLConf((SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Run the stream with changelog checkpointing enabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2))
+      )
+
+      // Run the stream with changelog checkpointing disabled.
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(s"checkpointFormatVersion2 validate 
ID") {
+    val providerClassName = 
classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()
+
+      val inputData = MemoryStream[Int]
+      val aggregated =
+        inputData
+          .toDF()
+          .groupBy($"value")
+          .agg(count("*"))
+          .as[(Int, Long)]
+
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3),
+        CheckLastBatch((3, 1)),
+        AddData(inputData, 3, 2),
+        CheckLastBatch((3, 2), (2, 1)),
+        StopStream
+      )
+
+      // Test recovery
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 3, 2, 1),
+        CheckLastBatch((3, 3), (2, 2), (1, 1)),
+        // By default we run in new tuple mode.
+        AddData(inputData, 4, 4, 4, 4),
+        CheckLastBatch((4, 4)),
+        AddData(inputData, 5, 5),
+        CheckLastBatch((5, 2)),
+        StopStream
+      )
+
+      // crash recovery again
+      testStream(aggregated, Update)(
+        StartStream(checkpointLocation = checkpointDir.getAbsolutePath),
+        AddData(inputData, 4),
+        CheckLastBatch((4, 5))
+      )
+    }
+    val checkpointInfoList = TestStateStoreWrapper.getCheckpointInfos
+    // We have 6 batches, 2 partitions, and 1 state store per batch
+    assert(checkpointInfoList.size == 12)
+    checkpointInfoList.foreach { l =>
+      assert(l.checkpointId.isDefined)
+      if (l.batchVersion == 2 || l.batchVersion == 4 || l.batchVersion == 5) {
+        assert(l.baseCheckpointId.isDefined)
+      }
+    }
+    assert(checkpointInfoList.count(_.partitionId == 0) == 6)
+    assert(checkpointInfoList.count(_.partitionId == 1) == 6)
+    for (i <- 1 to 6) {
+      assert(checkpointInfoList.count(_.batchVersion == i) == 2)
+    }
+    for {
+      a <- checkpointInfoList
+      b <- checkpointInfoList
+      if a.partitionId == b.partitionId && a.batchVersion == b.batchVersion + 1
+    } {
+      // If baseCheckpointId is defined, it should match the checkpointId of
+      // the previous batch version.
+      assert(!a.baseCheckpointId.isDefined || b.checkpointId == a.baseCheckpointId)
+    }
+  }
+
+  testWithChangelogCheckpointingEnabled(
+    s"checkpointFormatVersion2 validate ID with dedup and groupBy") {
+    val providerClassName = classOf[TestStateStoreProviderWrapper].getCanonicalName
+    TestStateStoreWrapper.clear()
+    withSQLConf(
+      (SQLConf.STATE_STORE_PROVIDER_CLASS.key -> providerClassName),
+      (SQLConf.STATE_STORE_CHECKPOINT_FORMAT_VERSION.key -> "2"),
+      (SQLConf.SHUFFLE_PARTITIONS.key, "2")) {
+      val checkpointDir = Utils.createTempDir().getCanonicalFile
+      checkpointDir.delete()

Review Comment:
   That's a good point. I'll do it. I copied and pasted this from a previous test without thinking.



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDB.scala:
##########
@@ -803,6 +843,14 @@ class RocksDB(
   /** Get the write buffer manager and cache */
   def getWriteBufferManagerAndCache(): (WriteBufferManager, Cache) = 
(writeBufferManager, lruCache)
 
+  def getLatestCheckpointInfo(partitionId: Int): StateStoreCheckpointInfo = {

Review Comment:
   This will always be called. The caller has no knowledge of what's going on there.
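   Roughly, the contract looks like the sketch below (`reportCommit` is a hypothetical caller, and the `StateStoreCheckpointInfo` fields are inferred from their usages in the test diff above):
   
   ```scala
   // Sketch: the caller invokes getLatestCheckpointInfo unconditionally after
   // a commit. When checkpoint IDs are disabled, the Option fields come back
   // as None, so the caller never needs to know which mode the store runs in.
   case class StateStoreCheckpointInfo(
       partitionId: Int,
       batchVersion: Long,
       checkpointId: Option[String],
       baseCheckpointId: Option[String])
   
   def reportCommit(db: RocksDB, partitionId: Int): StateStoreCheckpointInfo =
     db.getLatestCheckpointInfo(partitionId)
   ```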


