wankunde commented on code in PR #37533:
URL: https://github.com/apache/spark/pull/37533#discussion_r973869591


##########
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala:
##########
@@ -4443,36 +4443,115 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(mapStatuses.count(s => s != null && s.location.executorId == 
"hostB-exec") === 1)
   }
 
-  test("SPARK-40096: Send finalize events even if shuffle merger blocks 
indefinitely") {
+  test("SPARK-40096: Send finalize events even if shuffle merger blocks 
indefinitely " +
+    "with registerMergeResults is true") {
     initPushBasedShuffleConfs(conf)
 
+    sc.conf.set("spark.shuffle.push.results.timeout", "1s")
+    val myScheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+
+    val mergerLocs = Seq(makeBlockManagerId("hostA"), 
makeBlockManagerId("hostB"))
+    val timeoutSecs = 1
+    val sendRequestsLatch = new CountDownLatch(mergerLocs.size)
+    val completeLatch = new CountDownLatch(mergerLocs.size)
+    val canSendRequestLatch = new CountDownLatch(1)
+
     val blockStoreClient = mock(classOf[ExternalBlockStoreClient])
     val blockStoreClientField = 
classOf[BlockManager].getDeclaredField("blockStoreClient")
     blockStoreClientField.setAccessible(true)
     blockStoreClientField.set(sc.env.blockManager, blockStoreClient)
+
     val sentHosts = ArrayBuffer[String]()
+    var hostAInterrupted = false
     doAnswer { (invoke: InvocationOnMock) =>
       val host = invoke.getArgument[String](0)
-      sentHosts += host
-      // Block FinalizeShuffleMerge rpc for 2 seconds
-      if (invoke.getArgument[String](0) == "hostA") {
-        Thread.sleep(2000)
+      sendRequestsLatch.countDown()
+      try {
+        if (host == "hostA") {
+          canSendRequestLatch.await(timeoutSecs * 2, TimeUnit.SECONDS)
+        }
+        sentHosts += host
+      } catch {
+        case _: InterruptedException => hostAInterrupted = true
+      } finally {
+        completeLatch.countDown()
       }
     }.when(blockStoreClient).finalizeShuffleMerge(any(), any(), any(), any(), 
any())
 
     val shuffleMapRdd = new MyRDD(sc, 1, Nil)
     val shuffleDep = new ShuffleDependency(shuffleMapRdd, new 
HashPartitioner(2))
-    shuffleDep.setMergerLocs(Seq(makeBlockManagerId("hostA"), 
makeBlockManagerId("hostB")))
-    val shuffleStage = scheduler.createShuffleMapStage(shuffleDep, 0)
-
-    Seq(true, false).foreach { registerMergeResults =>
-      sentHosts.clear()
-      scheduler.finalizeShuffleMerge(shuffleStage, registerMergeResults)
-      verify(blockStoreClient, times(2))
-        .finalizeShuffleMerge(any(), any(), any(), any(), any())
-      assert((sentHosts diff Seq("hostA", "hostB")).isEmpty)
-      reset(blockStoreClient)
-    }
+    shuffleDep.setMergerLocs(mergerLocs)
+    val shuffleStage = myScheduler.createShuffleMapStage(shuffleDep, 0)
+
+    myScheduler.finalizeShuffleMerge(shuffleStage, true)
+    sendRequestsLatch.await()
+    verify(blockStoreClient, times(2))
+      .finalizeShuffleMerge(any(), any(), any(), any(), any())
+    assert(sentHosts === Seq("hostB"))
+    completeLatch.await()
+    assert(hostAInterrupted)
+  }
+
+  test("SPARK-40096: Send finalize events even if shuffle merger blocks 
indefinitely " +
+    "with registerMergeResults is false") {

Review Comment:
   Thanks @mridulm, I have merged the UTs.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to