This is an automated email from the ASF dual-hosted git repository.

kabhwan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git

The following commit(s) were added to refs/heads/master by this push:
     new 9dc5599b01b1 [SPARK-48105][SS] Fix the race condition between state store unloading and snapshotting

9dc5599b01b1 is described below

commit 9dc5599b01b197b6f703d93486ff960a67e4e25c
Author: Huanli Wang <huanli.w...@databricks.com>
AuthorDate: Tue May 7 09:39:13 2024 +0900

[SPARK-48105][SS] Fix the race condition between state store unloading and snapshotting

### What changes were proposed in this pull request?

* When we close the HDFS-backed state store, we should only remove the entries from `loadedMaps` rather than actively clearing their data. The JVM GC can then reclaim those objects.
* We should wait for the maintenance thread pool to stop before unloading the providers.

### Why are the changes needed?

There are two race conditions between state store snapshotting and state store unloading that could result in query failure and potential data corruption.

Case 1:
1. The maintenance thread pool encounters an issue and calls [stopMaintenanceTask](https://github.com/apache/spark/blob/d9d79a54a3cd487380039c88ebe9fa708e0dcf23/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L774), which in turn calls [threadPool.stop](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L587). However, this function doesn't wait for th [...]
2. The provider unload will [close the state store](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L719-L721), which [clears the values of loadedMaps](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L353-L355) for the HDFS-backed state store.
3. The maintenance thread, which has not stopped yet, may still be running and trying to take the snapshot even though the data in the underlying `HDFSBackedStateStoreMap` has been removed. If this snapshot process completes successfully, we write corrupted data, and the following batches consume it.

Case 2:
1. In executor_1, the maintenance thread is about to take a snapshot of state_store_1. It retrieves the `HDFSBackedStateStoreMap` object from `loadedMaps`, after which the maintenance thread [releases the lock on loadedMaps](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L750-L751).
2. state_store_1 is loaded in another executor, e.g. executor_2.
3. Another state store, state_store_2, is loaded on executor_1, which reports it to the driver via [reportActiveStoreInstance](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L854-L871).
4. executor_1 [unloads](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala#L713) the state stores that are no longer active, which clears the data entries in the `HDFSBackedStateStoreMap`.
5. The snapshotting thread's iteration ends early because the [iterator has no next element](https://github.com/apache/spark/blob/c6696cdcd611a682ebf5b7a183e2970ecea3b58c/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala#L634) after the clear, and the thread uploads the incomplete snapshot to cloud storage.
6. Future batches consume the corrupted data.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

```
[info] Run completed in 2 minutes, 55 seconds.
[info] Total number of tests run: 153
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 153, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
[success] Total time: 271 s (04:31), completed May 2, 2024, 6:26:33 PM
```

Before this change:

```
[info] - state store unload/close happens during the maintenance *** FAILED *** (648 milliseconds)
[info] Vector("a1", "a10", "a11", "a12", "a13", "a14", "a15", "a16", "a17", "a18", "a19", "a2", "a20", "a3", "a4", "a5", "a6", "a7", "a8", "a9") did not equal ArrayBuffer("a8") (StateStoreSuite.scala:414)
[info] Analysis:
[info] Vector1(0: "a1" -> "a8", 1: "a10" -> , 2: "a11" -> , 3: "a12" -> , 4: "a13" -> , 5: "a14" -> , 6: "a15" -> , 7: "a16" -> , 8: "a17" -> , 9: "a18" -> , 10: "a19" -> , 11: "a2" -> , 12: "a20" -> , 13: "a3" -> , 14: "a4" -> , 15: "a5" -> , 16: "a6" -> , 17: "a7" -> , 18: "a8" -> , 19: "a9" -> )
[info] org.scalatest.exceptions.TestFailedException:
[info] at org.scalatest.Assertions.newAssertionFailedException(Assertions.scala:472)
[info] at org.scalatest.Assertions.newAssertionFailedException$(Assertions.scala:471)
[info] at org.scalatest.Assertions$.newAssertionFailedException(Assertions.scala:1231)
[info] at org.scalatest.Assertions$AssertionsHelper.macroAssert(Assertions.scala:1295)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.$anonfun$new$39(StateStoreSuite.scala:414)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuiteBase.tryWithProviderResource(StateStoreSuite.scala:1663)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.$anonfun$new$38(StateStoreSuite.scala:394)
18:32:09.694 WARN org.apache.spark.sql.execution.streaming.state.StateStoreSuite: ===== POSSIBLE THREAD LEAK IN SUITE o.a.s.sql.execution.streaming.state.StateStoreSuite, threads: ForkJoinPool.commonPool-worker-1 (daemon=true) =====
[info] at org.scalatest.enablers.Timed$$anon$1.timeoutAfter(Timed.scala:127)
[info] at org.scalatest.concurrent.TimeLimits$.failAfterImpl(TimeLimits.scala:282)
[info] at org.scalatest.concurrent.TimeLimits.failAfter(TimeLimits.scala:231)
[info] at org.scalatest.concurrent.TimeLimits.failAfter$(TimeLimits.scala:230)
[info] at org.apache.spark.SparkFunSuite.failAfter(SparkFunSuite.scala:69)
[info] at org.apache.spark.SparkFunSuite.$anonfun$test$2(SparkFunSuite.scala:155)
[info] at org.scalatest.OutcomeOf.outcomeOf(OutcomeOf.scala:85)
[info] at org.scalatest.OutcomeOf.outcomeOf$(OutcomeOf.scala:83)
[info] at org.scalatest.OutcomeOf$.outcomeOf(OutcomeOf.scala:104)
[info] at org.scalatest.Transformer.apply(Transformer.scala:22)
[info] at org.scalatest.Transformer.apply(Transformer.scala:20)
[info] at org.scalatest.funsuite.AnyFunSuiteLike$$anon$1.apply(AnyFunSuiteLike.scala:226)
[info] at org.apache.spark.SparkFunSuite.withFixture(SparkFunSuite.scala:227)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.invokeWithFixture$1(AnyFunSuiteLike.scala:224)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTest$1(AnyFunSuiteLike.scala:236)
[info] at org.scalatest.SuperEngine.runTestImpl(Engine.scala:306)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.runTest(AnyFunSuiteLike.scala:236)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.runTest$(AnyFunSuiteLike.scala:218)
[info] at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterEach$$super$runTest(SparkFunSuite.scala:69)
[info] at org.scalatest.BeforeAndAfterEach.runTest(BeforeAndAfterEach.scala:234)
[info] at org.scalatest.BeforeAndAfterEach.runTest$(BeforeAndAfterEach.scala:227)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.org$scalatest$BeforeAndAfter$$super$runTest(StateStoreSuite.scala:90)
[info] at org.scalatest.BeforeAndAfter.runTest(BeforeAndAfter.scala:213)
[info] at org.scalatest.BeforeAndAfter.runTest$(BeforeAndAfter.scala:203)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.runTest(StateStoreSuite.scala:90)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$runTests$1(AnyFunSuiteLike.scala:269)
[info] at org.scalatest.SuperEngine.$anonfun$runTestsInBranch$1(Engine.scala:413)
[info] at scala.collection.immutable.List.foreach(List.scala:334)
[info] at org.scalatest.SuperEngine.traverseSubNodes$1(Engine.scala:401)
[info] at org.scalatest.SuperEngine.runTestsInBranch(Engine.scala:396)
[info] at org.scalatest.SuperEngine.runTestsImpl(Engine.scala:475)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.runTests(AnyFunSuiteLike.scala:269)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.runTests$(AnyFunSuiteLike.scala:268)
[info] at org.scalatest.funsuite.AnyFunSuite.runTests(AnyFunSuite.scala:1564)
[info] at org.scalatest.Suite.run(Suite.scala:1114)
[info] at org.scalatest.Suite.run$(Suite.scala:1096)
[info] at org.scalatest.funsuite.AnyFunSuite.org$scalatest$funsuite$AnyFunSuiteLike$$super$run(AnyFunSuite.scala:1564)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.$anonfun$run$1(AnyFunSuiteLike.scala:273)
[info] at org.scalatest.SuperEngine.runImpl(Engine.scala:535)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.run(AnyFunSuiteLike.scala:273)
[info] at org.scalatest.funsuite.AnyFunSuiteLike.run$(AnyFunSuiteLike.scala:272)
[info] at org.apache.spark.SparkFunSuite.org$scalatest$BeforeAndAfterAll$$super$run(SparkFunSuite.scala:69)
[info] at org.scalatest.BeforeAndAfterAll.liftedTree1$1(BeforeAndAfterAll.scala:213)
[info] at org.scalatest.BeforeAndAfterAll.run(BeforeAndAfterAll.scala:210)
[info] at org.scalatest.BeforeAndAfterAll.run$(BeforeAndAfterAll.scala:208)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.org$scalatest$BeforeAndAfter$$super$run(StateStoreSuite.scala:90)
[info] at org.scalatest.BeforeAndAfter.run(BeforeAndAfter.scala:273)
[info] at org.scalatest.BeforeAndAfter.run$(BeforeAndAfter.scala:271)
[info] at org.apache.spark.sql.execution.streaming.state.StateStoreSuite.run(StateStoreSuite.scala:90)
[info] at org.scalatest.tools.Framework.org$scalatest$tools$Framework$$runSuite(Framework.scala:321)
[info] at org.scalatest.tools.Framework$ScalaTestTask.execute(Framework.scala:517)
[info] at sbt.ForkMain$Run.lambda$runTest$1(ForkMain.java:414)
[info] at java.base/java.util.concurrent.FutureTask.run(FutureTask.java:264)
[info] at java.base/java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1136)
[info] at java.base/java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:635)
[info] at java.base/java.lang.Thread.run(Thread.java:840)
[info] Run completed in 2 seconds, 4 milliseconds.
[info] Total number of tests run: 1
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 0, failed 1, canceled 0, ignored 0, pending 0
[info] *** 1 TEST FAILED ***
```

### Was this patch authored or co-authored using generative AI tooling?

No

Closes #46351 from huanliwang-db/race.

Authored-by: Huanli Wang <huanli.w...@databricks.com>
Signed-off-by: Jungtaek Lim <kabhwan.opensou...@gmail.com>
---
 .../streaming/state/HDFSBackedStateStoreMap.scala  |  8 -----
 .../state/HDFSBackedStateStoreProvider.scala       |  5 ++-
 .../sql/execution/streaming/state/StateStore.scala | 16 ++++++++-
 .../streaming/state/StateStoreSuite.scala          | 38 ++++++++++++++++++++++
 4 files changed, 57 insertions(+), 10 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
index 32ff87f754d7..fe59703a1f45 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreMap.scala
@@ -32,7 +32,6 @@ trait HDFSBackedStateStoreMap {
   def remove(key: UnsafeRow): UnsafeRow
   def iterator(): Iterator[UnsafeRowPair]
   def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair]
-  def clear(): Unit
 }
 
 object HDFSBackedStateStoreMap {
@@ -80,8 +79,6 @@ class NoPrefixHDFSBackedStateStoreMap extends HDFSBackedStateStoreMap {
   override def prefixScan(prefixKey: UnsafeRow): Iterator[UnsafeRowPair] = {
     throw SparkUnsupportedOperationException()
   }
-
-  override def clear(): Unit = map.clear()
 }
 
 class PrefixScannableHDFSBackedStateStoreMap(
@@ -170,9 +167,4 @@ class PrefixScannableHDFSBackedStateStoreMap(
       .iterator
       .map { key => unsafeRowPair.withRows(key, map.get(key)) }
   }
-
-  override def clear(): Unit = {
-    map.clear()
-    prefixKeyToKeysMap.clear()
-  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
index 64a9eaad8805..543cd74c489d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala
@@ -351,7 +351,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
   }
 
   override def close(): Unit = {
-    synchronized { loadedMaps.values.asScala.foreach(_.clear()) }
+    // Clearing the map resets the TreeMap.root to null, and therefore entries inside the
+    // `loadedMaps` will be de-referenced and GCed automatically when their reference
+    // counts become 0.
+    synchronized { loadedMaps.clear() }
   }
 
   override def supportedCustomMetrics: Seq[StateStoreCustomMetric] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
index 2a1a0f7b01d9..8c2170abe311 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala
@@ -584,7 +584,21 @@ object StateStore extends Logging {
     }
 
     def stop(): Unit = {
-      threadPool.shutdown()
+      logInfo("Shutting down MaintenanceThreadPool")
+      threadPool.shutdown() // Disable new tasks from being submitted
+
+      // Wait a while for existing tasks to terminate
+      if (!threadPool.awaitTermination(5 * 60, TimeUnit.SECONDS)) {
+        logWarning(
+          s"MaintenanceThreadPool is not able to be terminated within 300 seconds," +
+            " forcefully shutting down now.")
+        threadPool.shutdownNow() // Cancel currently executing tasks
+
+        // Wait a while for tasks to respond to being cancelled
+        if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) {
+          logError("MaintenanceThreadPool did not terminate")
+        }
+      }
     }
   }
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
index 4523a14ca1cc..6a6867fbb552 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala
@@ -388,6 +388,44 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider]
     }
   }
 
+  test("SPARK-48105: state store unload/close happens during the maintenance") {
+    tryWithProviderResource(
+      newStoreProvider(opId = Random.nextInt(), partition = 0, minDeltasForSnapshot = 1)) {
+      provider =>
+        val store = provider.getStore(0).asInstanceOf[provider.HDFSBackedStateStore]
+        val values = (1 to 20)
+        val keys = values.map(i => ("a" + i))
+        keys.zip(values).map{case (k, v) => put(store, k, 0, v)}
+        // commit state store with 20 keys.
+        store.commit()
+        // get the state store iterator: mimic the case which the iterator is hold in the
+        // maintenance thread.
+        val storeIterator = store.iterator()
+
+        // the store iterator should still be valid as the maintenance thread may have already
+        // hold it and is doing snapshotting even though the state store is unloaded.
+        val outputKeys = new mutable.ArrayBuffer[String]
+        val outputValues = new mutable.ArrayBuffer[Int]
+        var cnt = 0
+        while (storeIterator.hasNext) {
+          if (cnt == 10) {
+            // Mimic the case where the provider is loaded in another executor in the middle of
+            // iteration. When this happens, the provider will be unloaded and closed in
+            // current executor.
+            provider.close()
+          }
+          val unsafeRowPair = storeIterator.next()
+          val (key, _) = keyRowToData(unsafeRowPair.key)
+          outputKeys.append(key)
+          outputValues.append(valueRowToData(unsafeRowPair.value))
+
+          cnt = cnt + 1
+        }
+        assert(keys.sorted === outputKeys.sorted)
+        assert(values.sorted === outputValues.sorted)
+    }
+  }
+
   test("maintenance") {
     val conf = new SparkConf()
       .setMaster("local")

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
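The Case 2 race described in the commit message can be reproduced in miniature. The following is a minimal, single-threaded Scala 2.13 sketch under stated assumptions, not Spark's code: `UnloadVsSnapshotSketch`, `VersionMap`, `LoadedMaps`, `oldClose`, `newClose` and `snapshotCount` are hypothetical names, a `java.util.TreeMap` stands in for `loadedMaps`, and each state version is assumed to behave like a `java.util.concurrent.ConcurrentHashMap`, whose iterators are weakly consistent and never throw `ConcurrentModificationException`. It compares the old `close()`, which clears every loaded map in place, with the new one, which only drops the references held by the outer map.

```scala
import java.util.concurrent.ConcurrentHashMap
import java.util.{TreeMap => JTreeMap}
import scala.jdk.CollectionConverters._

// Hypothetical model of the unload-vs-snapshot race; these are NOT Spark's classes.
object UnloadVsSnapshotSketch {
  // Assumption: one state version behaves like a ConcurrentHashMap (weakly consistent iterator).
  type VersionMap = ConcurrentHashMap[String, Int]
  // Hypothetical stand-in for `loadedMaps`: version -> state map.
  type LoadedMaps = JTreeMap[Long, VersionMap]

  // Old behavior: wipe the data of every loaded version in place.
  val oldClose: LoadedMaps => Unit =
    loaded => loaded.values().asScala.foreach(_.clear())

  // New behavior: only drop the outer references; the inner maps stay intact.
  val newClose: LoadedMaps => Unit =
    loaded => loaded.clear()

  // Counts how many entries a simulated snapshot sees when `close` runs after 10 entries.
  def snapshotCount(close: LoadedMaps => Unit): Int = {
    val loaded = new LoadedMaps()
    val version0 = new VersionMap()
    (1 to 20).foreach(i => version0.put("a" + i, i))
    loaded.put(0L, version0)

    // Mimic the maintenance thread: grab the version map, then iterate it with no lock held.
    val it = loaded.get(0L).entrySet().iterator()
    var seen = 0
    while (it.hasNext) {
      if (seen == 10) close(loaded) // the provider is unloaded in the middle of the snapshot
      it.next()
      seen += 1
    }
    seen
  }

  def main(args: Array[String]): Unit = {
    println(s"old close(): snapshot saw ${snapshotCount(oldClose)} of 20 entries")
    println(s"new close(): snapshot saw ${snapshotCount(newClose)} of 20 entries")
  }
}
```

With the old `close()`, the iterator stops early once the map's bins are emptied, so the simulated snapshot sees fewer than 20 entries and nothing signals that it is incomplete. With the new `close()`, the iterator still sees all 20 entries, because the map it holds is untouched and only becomes unreachable once the iterator itself is released. The sketch does not model the other half of the fix (waiting for the maintenance thread pool in `stop()`), which follows the usual `shutdown()`, `awaitTermination()`, `shutdownNow()` sequence of `java.util.concurrent.ExecutorService`.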