This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new f6128a6  [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based shuffle
f6128a6 is described below

commit f6128a6f4215dc45a19209d799dd9bf98fab6d8a
Author: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com>
AuthorDate: Wed Jan 5 01:47:01 2022 -0600

    [SPARK-33701][SHUFFLE] Adaptive shuffle merge finalization for push-based shuffle
    
    ### What changes were proposed in this pull request?
    
    As part of SPARK-32920, a simple approach to finalization for push-based
    shuffle was implemented. Shuffle merge finalization is the final operation
    that happens at the end of the stage, once all of its tasks have completed,
    asking all the external shuffle services to complete the shuffle merge for
    the stage. Once this request is completed, no more shuffle pushes will be
    accepted. With this approach, `DAGScheduler` waits for a fixed time of 10s
    (`spark.shuffle.push.finalize.timeout`) to allow some time [...]
    
    In this PR, instead of waiting for a fixed amount of time before shuffle
    merge finalization, finalization is controlled adaptively: once a minimum
    threshold fraction of map tasks (`spark.shuffle.push.minPushRatio`) have
    completed their shuffle pushes, shuffle merge finalization is scheduled.
    Additionally, if the total shuffle data generated is smaller than the
    minimum threshold shuffle size (`spark.shuffle.push.minShuffleSizeToWait`),
    shuffle merge finalization is scheduled immediately.
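
    To make the new behavior concrete, here is a minimal, hypothetical Scala
    sketch of how the two thresholds interact (the object, method, and
    parameter names are illustrative, not the actual `DAGScheduler` code, which
    additionally decides whether merge results get registered):

    ```scala
    // Hedged sketch only: when can the driver skip the fixed
    // spark.shuffle.push.finalize.timeout wait? Names here are illustrative.
    object AdaptiveFinalizationSketch {
      def shouldFinalizeImmediately(
          totalShuffleBytes: Long,     // total shuffle data generated
          pushCompletedMapTasks: Int,  // map tasks that finished pushing blocks
          totalMapTasks: Int,
          minSizeToWait: Long,         // spark.shuffle.push.minShuffleSizeToWait
          minPushRatio: Double         // spark.shuffle.push.minPushRatio
      ): Boolean = {
        // Small shuffles are finalized immediately (without waiting to register
        // merge results); otherwise finalize once enough map tasks have pushed.
        val smallShuffle = totalShuffleBytes < minSizeToWait
        val enoughPushed =
          totalMapTasks > 0 &&
            pushCompletedMapTasks.toDouble / totalMapTasks >= minPushRatio
        smallShuffle || enoughPushed
      }
    }
    ```
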
    ### Why are the changes needed?
    
    This is a performance improvement to the existing functionality.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes, this adds the user-facing configs `spark.shuffle.push.minPushRatio`
    and `spark.shuffle.push.minShuffleSizeToWait`.
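
    For illustration, a sketch of setting these thresholds on a `SparkConf`
    (example values only; note that in the diff below the ratio threshold is
    registered under the key `spark.shuffle.push.minCompletedPushRatio`, and
    the first two keys are the existing push-based shuffle prerequisites):

    ```scala
    import org.apache.spark.SparkConf

    object PushMergeTuningExample {
      // Illustrative values only.
      val conf: SparkConf = new SparkConf()
        .set("spark.shuffle.service.enabled", "true")
        .set("spark.shuffle.push.enabled", "true")
        // Schedule merge finalization once 75% of map tasks finish pushing.
        .set("spark.shuffle.push.minCompletedPushRatio", "0.75")
        // Shuffles smaller than 100m are finalized immediately, without
        // waiting for merge results.
        .set("spark.shuffle.push.minShuffleSizeToWait", "100m")
    }
    ```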
    
    ### How was this patch tested?
    
    Added unit tests in `DAGSchedulerSuite` and `ShuffleBlockPusherSuite`.
    
    Lead-authored-by: Min Shen <mshenlinkedin.com>
    Co-authored-by: Venkata krishnan Sowrirajan <vsowrirajanlinkedin.com>
    
    Closes #33896 from venkata91/SPARK-33701.
    
    Lead-authored-by: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../main/scala/org/apache/spark/Dependency.scala   |  35 ++-
 .../scala/org/apache/spark/MapOutputTracker.scala  |   6 +-
 .../src/main/scala/org/apache/spark/SparkEnv.scala |   3 +
 .../executor/CoarseGrainedExecutorBackend.scala    |   6 +
 .../org/apache/spark/internal/config/package.scala |  27 ++
 .../org/apache/spark/scheduler/DAGScheduler.scala  | 278 +++++++++++++----
 .../apache/spark/scheduler/DAGSchedulerEvent.scala |   4 +
 .../cluster/CoarseGrainedClusterMessage.scala      |   3 +
 .../cluster/CoarseGrainedSchedulerBackend.scala    |   3 +
 .../apache/spark/shuffle/ShuffleBlockPusher.scala  |  39 ++-
 .../apache/spark/scheduler/DAGSchedulerSuite.scala | 340 +++++++++++++++++++--
 .../spark/shuffle/ShuffleBlockPusherSuite.scala    | 101 +++++-
 docs/configuration.md                              |  16 +
 13 files changed, 772 insertions(+), 89 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala 
b/core/src/main/scala/org/apache/spark/Dependency.scala
index 1b4e7ba..8e348ee 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -17,8 +17,12 @@
 
 package org.apache.spark
 
+import java.util.concurrent.ScheduledFuture
+
 import scala.reflect.ClassTag
 
+import org.roaringbitmap.RoaringBitmap
+
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.internal.Logging
 import org.apache.spark.rdd.RDD
@@ -131,9 +135,7 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
   def shuffleMergeId: Int = _shuffleMergeId
 
   def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = {
-    if (mergerLocs != null) {
-      this.mergerLocs = mergerLocs
-    }
+    this.mergerLocs = mergerLocs
   }
 
   def getMergerLocs: Seq[BlockManagerId] = mergerLocs
@@ -160,6 +162,8 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     _shuffleMergedFinalized = false
     mergerLocs = Nil
     _shuffleMergeId += 1
+    finalizeTask = None
+    shufflePushCompleted.clear()
   }
 
   private def canShuffleMergeBeEnabled(): Boolean = {
@@ -169,11 +173,34 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: 
ClassTag](
     if (isPushShuffleEnabled && rdd.isBarrier()) {
       logWarning("Push-based shuffle is currently not supported for barrier 
stages")
     }
-    isPushShuffleEnabled &&
+    isPushShuffleEnabled && numPartitions > 0 &&
       // TODO: SPARK-35547: Push based shuffle is currently unsupported for 
Barrier stages
       !rdd.isBarrier()
   }
 
+  @transient private[this] val shufflePushCompleted = new RoaringBitmap()
+
+  /**
+   * Mark a given map task as push completed in the tracking bitmap.
+   * Using the bitmap ensures that the same map task launched multiple times 
due to
+   * either speculation or stage retry is only counted once.
+   * @param mapIndex Map task index
+   * @return number of map tasks with block push completed
+   */
+  def incPushCompleted(mapIndex: Int): Int = {
+    shufflePushCompleted.add(mapIndex)
+    shufflePushCompleted.getCardinality
+  }
+
+  // Only used by DAGScheduler to coordinate shuffle merge finalization
+  @transient private[this] var finalizeTask: Option[ScheduledFuture[_]] = None
+
+  def getFinalizeTask: Option[ScheduledFuture[_]] = finalizeTask
+
+  def setFinalizeTask(task: ScheduledFuture[_]): Unit = {
+    finalizeTask = Option(task)
+  }
+
   _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this))
   _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId)
 }
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index af26abc..d71fb09 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -917,7 +917,7 @@ private[spark] class MapOutputTrackerMaster(
         Runtime.getRuntime.availableProcessors(),
         statuses.length.toLong * totalSizes.length / parallelAggThreshold + 
1).toInt
       if (parallelism <= 1) {
-        for (s <- statuses) {
+        statuses.filter(_ != null).foreach { s =>
           for (i <- 0 until totalSizes.length) {
             totalSizes(i) += s.getSizeForBlock(i)
           }
@@ -928,8 +928,8 @@ private[spark] class MapOutputTrackerMaster(
           implicit val executionContext = 
ExecutionContext.fromExecutor(threadPool)
           val mapStatusSubmitTasks = equallyDivide(totalSizes.length, 
parallelism).map {
             reduceIds => Future {
-              for (s <- statuses; i <- reduceIds) {
-                totalSizes(i) += s.getSizeForBlock(i)
+              statuses.filter(_ != null).foreach { s =>
+                reduceIds.foreach(i => totalSizes(i) += s.getSizeForBlock(i))
               }
             }
           }
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala 
b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 0388c7b..d07614a 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -32,6 +32,7 @@ import org.apache.hadoop.conf.Configuration
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.api.python.PythonWorkerFactory
 import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorBackend
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.internal.config._
 import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager}
@@ -81,6 +82,8 @@ class SparkEnv (
 
   private[spark] var driverTmpDir: Option[String] = None
 
+  private[spark] var executorBackend: Option[ExecutorBackend] = None
+
   private[spark] def stop(): Unit = {
 
     if (!isStopped) {
diff --git 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
 
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 43887a7..fb7b4e6 100644
--- 
a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ 
b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -106,6 +106,7 @@ private[spark] class CoarseGrainedExecutorBackend(
     rpcEnv.asyncSetupEndpointRefByURI(driverUrl).flatMap { ref =>
       // This is a very fast action so we can use "ThreadUtils.sameThread"
       driver = Some(ref)
+      env.executorBackend = Option(this)
       ref.ask[Boolean](RegisterExecutor(executorId, self, hostname, cores, 
extractLogUrls,
         extractAttributes, _resources, resourceProfile.id))
     }(ThreadUtils.sameThread).onComplete {
@@ -162,6 +163,11 @@ private[spark] class CoarseGrainedExecutorBackend(
       .map(e => (e._1.substring(prefix.length).toUpperCase(Locale.ROOT), 
e._2)).toMap
   }
 
+  def notifyDriverAboutPushCompletion(shuffleId: Int, shuffleMergeId: Int, 
mapIndex: Int): Unit = {
+    val msg = ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex)
+    driver.foreach(_.send(msg))
+  }
+
   override def receive: PartialFunction[Any, Unit] = {
     case RegisteredExecutor =>
       logInfo("Successfully registered with driver")
diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala 
b/core/src/main/scala/org/apache/spark/internal/config/package.scala
index 71a11f6..a942ba5 100644
--- a/core/src/main/scala/org/apache/spark/internal/config/package.scala
+++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala
@@ -2193,6 +2193,33 @@ package object config {
       // with small MB sized chunk of data.
       .createWithDefaultString("3m")
 
+  private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS =
+    ConfigBuilder("spark.shuffle.push.merge.finalizeThreads")
+      .doc("Number of threads used by driver to finalize shuffle merge. Since 
it could" +
+        " potentially take seconds for a large shuffle to finalize, having 
multiple threads helps" +
+        " driver to handle concurrent shuffle merge finalize requests when 
push-based" +
+        " shuffle is enabled.")
+      .version("3.3.0")
+      .intConf
+      .createWithDefault(3)
+
+  private[spark] val PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT =
+    ConfigBuilder("spark.shuffle.push.minShuffleSizeToWait")
+      .doc("Driver will wait for merge finalization to complete only if total 
shuffle size is" +
+        " more than this threshold. If total shuffle size is less, driver will 
immediately" +
+        " finalize the shuffle output")
+      .version("3.3.0")
+      .bytesConf(ByteUnit.BYTE)
+      .createWithDefaultString("500m")
+
+  private[spark] val PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO =
+    ConfigBuilder("spark.shuffle.push.minCompletedPushRatio")
+      .doc("Fraction of map partitions that should be push complete before 
driver starts" +
+        " shuffle merge finalization during push based shuffle")
+      .version("3.3.0")
+      .doubleConf
+      .createWithDefault(1.0)
+
   private[spark] val JAR_IVY_REPO_PATH =
     ConfigBuilder("spark.jars.ivy")
       .doc("Path to specify the Ivy user directory, used for the local Ivy 
cache and " +
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 4ed734c..eed71038 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -19,7 +19,7 @@ package org.apache.spark.scheduler
 
 import java.io.NotSerializableException
 import java.util.Properties
-import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit}
+import java.util.concurrent.{ConcurrentHashMap, ScheduledFuture, 
TimeoutException, TimeUnit }
 import java.util.concurrent.atomic.AtomicInteger
 
 import scala.annotation.tailrec
@@ -265,6 +265,14 @@ private[spark] class DAGScheduler(
   private val shuffleMergeFinalizeWaitSec =
     sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT)
 
+  private val shuffleMergeWaitMinSizeThreshold =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT)
+
+  private val shufflePushMinRatio = 
sc.getConf.get(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO)
+
+  private val shuffleMergeFinalizeNumThreads =
+    sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_THREADS)
+
   // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient 
needs to be
   // initialized lazily
   private lazy val externalShuffleClient: Option[BlockStoreClient] =
@@ -274,8 +282,12 @@ private[spark] class DAGScheduler(
       None
     }
 
+  // Use multi-threaded scheduled executor. The merge finalization task could 
take some time,
+  // depending on the time to establish connections to mergers, and the time 
to get MergeStatuses
+  // from all the mergers.
   private val shuffleMergeFinalizeScheduler =
-    
ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8)
+    ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer",
+      shuffleMergeFinalizeNumThreads)
 
   /**
    * Called by the TaskSetManager to report task's starting.
@@ -1065,6 +1077,14 @@ private[spark] class DAGScheduler(
   }
 
   /**
+   * Receives notification about shuffle push for a given shuffle from one map
+   * task has completed
+   */
+  def shufflePushCompleted(shuffleId: Int, shuffleMergeId: Int, mapIndex: 
Int): Unit = {
+    eventProcessLoop.post(ShufflePushCompleted(shuffleId, shuffleMergeId, 
mapIndex))
+  }
+
+  /**
    * Kill a given task. It will be retried.
    *
    * @return Whether the task was successfully killed.
@@ -1407,7 +1427,7 @@ private[spark] class DAGScheduler(
             // merger locations but the corresponding shuffle map stage did 
not complete
             // successfully, we would still enable push for its retry.
             s.shuffleDep.setShuffleMergeEnabled(false)
-            logInfo("Push-based shuffle disabled for $stage (${stage.name}) 
since it" +
+            logInfo(s"Push-based shuffle disabled for $stage (${stage.name}) 
since it" +
               " is already shuffle merge finalized")
           }
         }
@@ -1636,6 +1656,42 @@ private[spark] class DAGScheduler(
     }
   }
 
+  private[scheduler] def checkAndScheduleShuffleMergeFinalize(
+      shuffleStage: ShuffleMapStage): Unit = {
+    // Check if a finalize task has already been scheduled. This is to prevent 
scenarios
+    // where we don't schedule multiple shuffle merge finalization which can 
happen due to
+    // stage retry or shufflePushMinRatio is already hit etc.
+    if (shuffleStage.shuffleDep.getFinalizeTask.isEmpty) {
+      // 1. Stage indeterminate and some map outputs are not available - 
finalize
+      // immediately without registering shuffle merge results.
+      // 2. Stage determinate and some map outputs are not available - decide 
to
+      // register merge results based on map outputs size available and
+      // shuffleMergeWaitMinSizeThreshold.
+      // 3. All shuffle outputs available - decide to register merge results 
based
+      // on map outputs size available and shuffleMergeWaitMinSizeThreshold.
+      val totalSize = {
+        lazy val computedTotalSize =
+          mapOutputTracker.getStatistics(shuffleStage.shuffleDep).
+            bytesByPartitionId.filter(_ > 0).sum
+        if (shuffleStage.isAvailable) {
+          computedTotalSize
+        } else {
+          if (shuffleStage.isIndeterminate) {
+            0L
+          } else {
+            computedTotalSize
+          }
+        }
+      }
+
+      if (totalSize < shuffleMergeWaitMinSizeThreshold) {
+        scheduleShuffleMergeFinalize(shuffleStage, delay = 0, 
registerMergeResults = false)
+      } else {
+        scheduleShuffleMergeFinalize(shuffleStage, shuffleMergeFinalizeWaitSec)
+      }
+    }
+  }
+
   /**
    * Responds to a task finishing. This is called inside the event loop so it 
assumes that it can
    * modify the scheduler's internal state. Use taskEnded() to post a task end 
event from outside.
@@ -1767,7 +1823,7 @@ private[spark] class DAGScheduler(
             if (runningStages.contains(shuffleStage) && 
shuffleStage.pendingPartitions.isEmpty) {
               if (!shuffleStage.shuffleDep.shuffleMergeFinalized &&
                 shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
-                scheduleShuffleMergeFinalize(shuffleStage)
+                checkAndScheduleShuffleMergeFinalize(shuffleStage)
               } else {
                 processShuffleMapStageCompletion(shuffleStage)
               }
@@ -2074,20 +2130,63 @@ private[spark] class DAGScheduler(
   }
 
   /**
-   * Schedules shuffle merge finalize.
+   *
+   * Schedules shuffle merge finalization.
+   *
+   * @param stage the stage to finalize shuffle merge
+   * @param delay how long to wait before finalizing shuffle merge
+   * @param registerMergeResults indicate whether DAGScheduler would register 
the received
+   *                             MergeStatus with MapOutputTracker and wait to 
schedule the reduce
+   *                             stage until MergeStatus have been received 
from all mergers or
+   *                             reaches timeout. For very small shuffle, this 
could be set to
+   *                             false to avoid impact to job runtime.
    */
-  private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): 
Unit = {
-    // TODO: SPARK-33701: Instead of waiting for a constant amount of time for 
finalization
-    // TODO: for all the stages, adaptively tune timeout for merge finalization
-    logInfo(("%s (%s) scheduled for finalizing" +
-      " shuffle merge in %s s").format(stage, stage.name, 
shuffleMergeFinalizeWaitSec))
-    shuffleMergeFinalizeScheduler.schedule(
-      new Runnable {
-        override def run(): Unit = finalizeShuffleMerge(stage)
-      },
-      shuffleMergeFinalizeWaitSec,
-      TimeUnit.SECONDS
-    )
+  private[scheduler] def scheduleShuffleMergeFinalize(
+      stage: ShuffleMapStage,
+      delay: Long,
+      registerMergeResults: Boolean = true): Unit = {
+    val shuffleDep = stage.shuffleDep
+    val scheduledTask: Option[ScheduledFuture[_]] = shuffleDep.getFinalizeTask
+    scheduledTask match {
+      case Some(task) =>
+        // If we find an already scheduled task, check if the task has been 
triggered yet.
+        // If it's already triggered, do nothing. Otherwise, cancel it and 
schedule a new
+        // one for immediate execution. Note that we should get here only when
+        // handleShufflePushCompleted schedules a finalize task after the 
shuffle map stage
+        // completed earlier and scheduled a task with default delay.
+        // The current task should be coming from handleShufflePushCompleted, 
thus the
+        // delay should be 0 and registerMergeResults should be true.
+        assert(delay == 0 && registerMergeResults)
+        if (task.getDelay(TimeUnit.NANOSECONDS) > 0 && task.cancel(false)) {
+          logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle 
merge immediately " +
+            s"after cancelling previously scheduled task.")
+          shuffleDep.setFinalizeTask(
+            shuffleMergeFinalizeScheduler.schedule(
+              new Runnable {
+                override def run(): Unit = finalizeShuffleMerge(stage, 
registerMergeResults)
+              },
+              0,
+              TimeUnit.SECONDS
+            )
+          )
+        } else {
+          logInfo(s"$stage (${stage.name}) existing scheduled task for 
finalizing shuffle merge" +
+            s"would either be in-progress or finished. No need to schedule 
shuffle merge" +
+            s" finalization again.")
+        }
+      case None =>
+        // If no previous finalization task is scheduled, schedule the 
finalization task.
+        logInfo(s"$stage (${stage.name}) scheduled for finalizing shuffle 
merge in $delay s")
+        shuffleDep.setFinalizeTask(
+          shuffleMergeFinalizeScheduler.schedule(
+            new Runnable {
+              override def run(): Unit = finalizeShuffleMerge(stage, 
registerMergeResults)
+            },
+            delay,
+            TimeUnit.SECONDS
+          )
+        )
+    }
   }
 
   /**
@@ -2095,38 +2194,72 @@ private[spark] class DAGScheduler(
    * the given shuffle map stage to finalize the shuffle merge process for 
this shuffle. This is
    * invoked in a separate thread to reduce the impact on the DAGScheduler 
main thread, as the
    * scheduler might need to talk to 1000s of shuffle services to finalize 
shuffle merge.
+   *
+   * @param stage ShuffleMapStage to finalize shuffle merge for
+   * @param registerMergeResults indicate whether DAGScheduler would register 
the received
+   *                             MergeStatus with MapOutputTracker and wait to 
schedule the reduce
+   *                             stage until MergeStatus have been received 
from all mergers or
+   *                             reaches timeout. For very small shuffle, this 
could be set to
+   *                             false to avoid impact to job runtime.
    */
-  private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
-    logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+  private[scheduler] def finalizeShuffleMerge(
+      stage: ShuffleMapStage,
+      registerMergeResults: Boolean = true): Unit = {
+    logInfo(s"$stage (${stage.name}) finalizing the shuffle merge with 
registering merge " +
+      s"results set to $registerMergeResults")
+    val shuffleId = stage.shuffleDep.shuffleId
+    val shuffleMergeId = stage.shuffleDep.shuffleMergeId
+    val numMergers = stage.shuffleDep.getMergerLocs.length
+    val results = (0 until numMergers).map(_ => 
SettableFuture.create[Boolean]())
     externalShuffleClient.foreach { shuffleClient =>
-      val shuffleId = stage.shuffleDep.shuffleId
-      val numMergers = stage.shuffleDep.getMergerLocs.length
-      val results = (0 until numMergers).map(_ => 
SettableFuture.create[Boolean]())
-
-      stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
-        case (shuffleServiceLoc, index) =>
-          // Sends async request to shuffle service to finalize shuffle merge 
on that host
-          // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is 
cancelled
-          // TODO: during shuffleMergeFinalizeWaitSec
-          shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
-            shuffleServiceLoc.port, shuffleId, stage.shuffleDep.shuffleMergeId,
-            new MergeFinalizerListener {
-              override def onShuffleMergeSuccess(statuses: MergeStatuses): 
Unit = {
-                assert(shuffleId == statuses.shuffleId)
-                eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus.
-                  convertMergeStatusesToMergeStatusArr(statuses, 
shuffleServiceLoc)))
-                results(index).set(true)
-              }
+      if (!registerMergeResults) {
+        results.foreach(_.set(true))
+        // Finalize in separate thread as shuffle merge is a no-op in this case
+        shuffleMergeFinalizeScheduler.schedule(new Runnable {
+          override def run(): Unit = {
+            stage.shuffleDep.getMergerLocs.foreach {
+              case shuffleServiceLoc =>
+                // Sends async request to shuffle service to finalize shuffle 
merge on that host.
+                // Since merge statuses will not be registered in this case,
+                // we pass a no-op listener.
+                shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+                  shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+                  new MergeFinalizerListener {
+                    override def onShuffleMergeSuccess(statuses: 
MergeStatuses): Unit = {
+                    }
 
-              override def onShuffleMergeFailure(e: Throwable): Unit = {
-                logWarning(s"Exception encountered when trying to finalize 
shuffle " +
-                  s"merge on ${shuffleServiceLoc.host} for shuffle 
$shuffleId", e)
-                // Do not fail the future as this would cause dag scheduler to 
prematurely
-                // give up on waiting for merge results from the remaining 
shuffle services
-                // if one fails
-                results(index).set(false)
-              }
-            })
+                    override def onShuffleMergeFailure(e: Throwable): Unit = {
+                    }
+                  })
+            }
+          }
+        }, 0, TimeUnit.SECONDS)
+      } else {
+        stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+          case (shuffleServiceLoc, index) =>
+            // Sends async request to shuffle service to finalize shuffle 
merge on that host
+            // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is 
cancelled
+            // TODO: during shuffleMergeFinalizeWaitSec
+            shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+              shuffleServiceLoc.port, shuffleId, shuffleMergeId,
+              new MergeFinalizerListener {
+                override def onShuffleMergeSuccess(statuses: MergeStatuses): 
Unit = {
+                  assert(shuffleId == statuses.shuffleId)
+                  eventProcessLoop.post(RegisterMergeStatuses(stage, 
MergeStatus.
+                    convertMergeStatusesToMergeStatusArr(statuses, 
shuffleServiceLoc)))
+                  results(index).set(true)
+                }
+
+                override def onShuffleMergeFailure(e: Throwable): Unit = {
+                  logWarning(s"Exception encountered when trying to finalize 
shuffle " +
+                    s"merge on ${shuffleServiceLoc.host} for shuffle 
$shuffleId", e)
+                  // Do not fail the future as this would cause dag scheduler 
to prematurely
+                  // give up on waiting for merge results from the remaining 
shuffle services
+                  // if one fails
+                  results(index).set(false)
+                }
+              })
+        }
       }
       // DAGScheduler only waits for a limited amount of time for the merge 
results.
       // It will attempt to submit the next stage(s) irrespective of whether 
merge results
@@ -2185,15 +2318,45 @@ private[spark] class DAGScheduler(
     }
   }
 
-  private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage): 
Unit = {
-    // Only update MapOutputTracker metadata if the stage is still active. i.e 
not cancelled.
-    if (runningStages.contains(stage)) {
-      stage.shuffleDep.markShuffleMergeFinalized()
-      processShuffleMapStageCompletion(stage)
-    } else {
-      // Unregister all merge results if the stage is currently not
-      // active (i.e. the stage is cancelled)
-      mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+  private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage,
+        shuffleMergeId: Int): Unit = {
+    // Check if update is for the same merge id - finalization might have 
completed for an earlier
+    // adaptive attempt while the stage might have failed/killed and shuffle 
id is getting
+    // re-executing now.
+    if (stage.shuffleDep.shuffleMergeId == shuffleMergeId) {
+      if (stage.pendingPartitions.isEmpty) {
+        if (runningStages.contains(stage)) {
+          stage.shuffleDep.markShuffleMergeFinalized()
+          processShuffleMapStageCompletion(stage)
+        } else {
+          // Unregister all merge results if the stage is currently not
+          // active (i.e. the stage is cancelled)
+          mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId)
+        }
+      } else {
+        // stage still running, mark merge finalized. Stage completion will 
invoke
+        // processShuffleMapStageCompletion
+        stage.shuffleDep.markShuffleMergeFinalized()
+      }
+    }
+  }
+
+  private[scheduler] def handleShufflePushCompleted(
+      shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    shuffleIdToMapStage.get(shuffleId) match {
+      case Some(mapStage) =>
+        val shuffleDep = mapStage.shuffleDep
+        // Only update shufflePushCompleted events for the current active 
stage map tasks.
+        // This is required to prevent shuffle merge finalization by dangling 
tasks of a
+        // previous attempt in the case of indeterminate stage.
+        if (shuffleDep.shuffleMergeId == shuffleMergeId) {
+          if (!shuffleDep.shuffleMergeFinalized &&
+            shuffleDep.incPushCompleted(mapIndex).toDouble / 
shuffleDep.rdd.partitions.length
+              >= shufflePushMinRatio) {
+            scheduleShuffleMergeFinalize(mapStage, delay = 0)
+          }
+        }
+      case None =>
     }
   }
 
@@ -2649,7 +2812,10 @@ private[scheduler] class 
DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
       dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses)
 
     case ShuffleMergeFinalized(stage) =>
-      dagScheduler.handleShuffleMergeFinalized(stage)
+      dagScheduler.handleShuffleMergeFinalized(stage, 
stage.shuffleDep.shuffleMergeId)
+
+    case ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex) =>
+      dagScheduler.handleShufflePushCompleted(shuffleId, shuffleMergeId, 
mapIndex)
   }
 
   override def onError(e: Throwable): Unit = {
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala 
b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 307844c..f3798da 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -111,3 +111,7 @@ private[scheduler] case class RegisterMergeStatuses(
 
 private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage)
   extends DAGSchedulerEvent
+
+private[scheduler] case class ShufflePushCompleted(
+    shuffleId: Int, shuffleMergeId: Int, mapIndex: Int)
+  extends DAGSchedulerEvent
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
index 66ac40f..61ee865 100644
--- 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
+++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala
@@ -85,6 +85,9 @@ private[spark] object CoarseGrainedClusterMessages {
     }
   }
 
+  case class ShufflePushCompletion(shuffleId: Int, shuffleMergeId: Int, 
mapIndex: Int)
+    extends CoarseGrainedClusterMessage
+
   // Internal messages in driver
   case object ReviveOffers extends CoarseGrainedClusterMessage
 
diff --git 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 326ea83..13a7183 100644
--- 
a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ 
b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -168,6 +168,9 @@ class CoarseGrainedSchedulerBackend(scheduler: 
TaskSchedulerImpl, val rpcEnv: Rp
           }
         }
 
+      case ShufflePushCompletion(shuffleId, shuffleMergeId, mapIndex) =>
+        scheduler.dagScheduler.shufflePushCompleted(shuffleId, shuffleMergeId, 
mapIndex)
+
       case ReviveOffers =>
         makeOffers()
 
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala 
b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
index 8790371..d6972cd 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
@@ -26,6 +26,7 @@ import scala.collection.mutable.{ArrayBuffer, HashMap, 
HashSet, Queue}
 
 import org.apache.spark.{ShuffleDependency, SparkConf, SparkContext, SparkEnv}
 import org.apache.spark.annotation.Since
+import org.apache.spark.executor.{CoarseGrainedExecutorBackend, 
ExecutorBackend}
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config._
 import org.apache.spark.launcher.SparkLauncher
@@ -53,7 +54,7 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) 
extends Logging {
   private[this] val maxBytesInFlight = conf.get(REDUCER_MAX_SIZE_IN_FLIGHT) * 
1024 * 1024
   private[this] val maxReqsInFlight = conf.get(REDUCER_MAX_REQS_IN_FLIGHT)
   private[this] val maxBlocksInFlightPerAddress = 
conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
-  private[this] var bytesInFlight = 0L
+  private[shuffle] var bytesInFlight = 0L
   private[this] var reqsInFlight = 0
   private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, 
Int]()
   private[this] val deferredPushRequests = new HashMap[BlockManagerId, 
Queue[PushRequest]]()
@@ -61,6 +62,10 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) 
extends Logging {
   private[this] val errorHandler = createErrorHandler()
   // VisibleForTesting
   private[shuffle] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+  private[this] var shuffleId = -1
+  private[this] var mapIndex = -1
+  private[this] var shuffleMergeId = -1
+  private[this] var pushCompletionNotified = false
 
   // VisibleForTesting
   private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
@@ -84,6 +89,8 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) 
extends Logging {
       }
     }
   }
+  // VisibleForTesting
+  private[shuffle] def isPushCompletionNotified = pushCompletionNotified
 
   /**
    * Initiates the block push.
@@ -101,11 +108,17 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) 
extends Logging {
       mapIndex: Int): Unit = {
     val numPartitions = dep.partitioner.numPartitions
     val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+    this.shuffleId = dep.shuffleId
+    this.shuffleMergeId = dep.shuffleMergeId
+    this.mapIndex = mapIndex
     val requests = prepareBlockPushRequests(numPartitions, mapIndex, 
dep.shuffleId,
       dep.shuffleMergeId, dataFile, partitionLengths, dep.getMergerLocs, 
transportConf)
     // Randomize the orders of the PushRequest, so different mappers pushing 
blocks at the same
     // time won't be pushing the same ranges of shuffle partitions.
     pushRequests ++= Utils.randomize(requests)
+    if (pushRequests.isEmpty) {
+      notifyDriverAboutPushCompletion()
+    }
 
     submitTask(() => {
       tryPushUpToMax()
@@ -327,11 +340,35 @@ private[spark] class ShuffleBlockPusher(conf: SparkConf) 
extends Logging {
         s"stop.")
       return false
     } else {
+      if (reqsInFlight <= 0 && pushRequests.isEmpty && 
deferredPushRequests.isEmpty) {
+        notifyDriverAboutPushCompletion()
+      }
       remainingBlocks.isEmpty && (pushRequests.nonEmpty || 
deferredPushRequests.nonEmpty)
     }
   }
 
   /**
+   * Notify the driver about all the blocks generated by the current map task 
having been pushed.
+   * This enables the DAGScheduler to finalize shuffle merge as soon as 
sufficient map tasks have
+   * completed push instead of always waiting for a fixed amount of time.
+   *
+   * VisibleForTesting
+   */
+  protected def notifyDriverAboutPushCompletion(): Unit = {
+    assert(shuffleId >= 0 && mapIndex >= 0)
+    if (!pushCompletionNotified) {
+      SparkEnv.get.executorBackend match {
+        case Some(cb: CoarseGrainedExecutorBackend) =>
+          cb.notifyDriverAboutPushCompletion(shuffleId, shuffleMergeId, 
mapIndex)
+        case Some(eb: ExecutorBackend) =>
+          logWarning(s"Currently $eb doesn't support push-based shuffle")
+        case None =>
+      }
+      pushCompletionNotified = true
+    }
+  }
+
+  /**
    * Convert the shuffle data file of the current mapper into a list of 
PushRequest. Basically,
    * continuous blocks in the shuffle file are grouped into a single request 
to allow more
    * efficient read of the block data. Each mapper for a given shuffle will 
receive the same
diff --git 
a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala 
b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index afea912..76612cb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.scheduler
 
 import java.util.Properties
-import java.util.concurrent.{CountDownLatch, TimeUnit}
+import java.util.concurrent.{CountDownLatch, Delayed, ScheduledFuture, 
TimeUnit}
 import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong, AtomicReference}
 
 import scala.annotation.meta.param
@@ -124,6 +124,31 @@ class MyRDD(
   override def toString: String = "DAGSchedulerSuiteRDD " + id
 }
 
+class DummyScheduledFuture(
+    val delay: Long,
+    val registerMergeResults: Boolean)
+  extends ScheduledFuture[Int] {
+
+  override def get(timeout: Long, unit: TimeUnit): Int =
+    throw new IllegalStateException("should not be reached")
+
+  override def getDelay(unit: TimeUnit): Long = delay
+
+  override def compareTo(o: Delayed): Int =
+    throw new IllegalStateException("should not be reached")
+
+  override def cancel(mayInterruptIfRunning: Boolean): Boolean = true
+
+  override def isCancelled: Boolean =
+    throw new IllegalStateException("should not be reached")
+
+  override def isDone: Boolean =
+    throw new IllegalStateException("should not be reached")
+
+  override def get(): Int =
+    throw new IllegalStateException("should not be reached")
+}
+
 class DAGSchedulerSuiteDummyException extends Exception
 
 class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with 
TimeLimits {
@@ -312,16 +337,27 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
      * Schedules shuffle merge finalize.
      */
     override private[scheduler] def scheduleShuffleMergeFinalize(
-        shuffleMapStage: ShuffleMapStage): Unit = {
-      if (shuffleMergeRegister) {
+        shuffleMapStage: ShuffleMapStage,
+        delay: Long,
+        registerMergeResults: Boolean = true): Unit = {
+      if (shuffleMergeRegister && registerMergeResults) {
         for (part <- 0 until 
shuffleMapStage.shuffleDep.partitioner.numPartitions) {
           val mergeStatuses = Seq((part, makeMergeStatus("",
             shuffleMapStage.shuffleDep.shuffleMergeId)))
           handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses)
         }
-        if (shuffleMergeFinalize) {
-          handleShuffleMergeFinalized(shuffleMapStage)
-        }
+      }
+
+      shuffleMapStage.shuffleDep.getFinalizeTask match {
+        case Some(_) =>
+          assert(delay == 0 && registerMergeResults)
+        case None =>
+      }
+
+      shuffleMapStage.shuffleDep.setFinalizeTask(
+          new DummyScheduledFuture(delay, registerMergeResults))
+      if (shuffleMergeFinalize) {
+        handleShuffleMergeFinalized(shuffleMapStage, 
shuffleMapStage.shuffleDep.shuffleMergeId)
       }
     }
   }
@@ -472,6 +508,12 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(this.results === expected)
   }
 
+  /** Sends ShufflePushCompleted to the DAG scheduler. */
+  private def pushComplete(
+      shuffleId: Int, shuffleMergeId: Int, mapIndex: Int): Unit = {
+    runEvent(ShufflePushCompleted(shuffleId, shuffleMergeId, mapIndex))
+  }
+
   test("[SPARK-3353] parent stage should have lower stage id") {
     sc.parallelize(1 to 10).map(x => (x, x)).reduceByKey(_ + _, 4).count()
     val stageByOrderOfExecution = sparkListener.stageByOrderOfExecution
@@ -3428,6 +3470,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   private def initPushBasedShuffleConfs(conf: SparkConf) = {
     conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
     conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 1L)
     conf.set("spark.master", "pushbasedshuffleclustermanager")
     // Needed to run push-based shuffle tests in ad-hoc manner through IDE
     conf.set(Tests.IS_TESTING, true)
@@ -3439,7 +3482,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
   test("SPARK-32920: shuffle merge finalization") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 2
     val shuffleMapRdd = new MyRDD(sc, parts, Nil)
@@ -3459,7 +3502,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   test("SPARK-32920: merger locations not empty") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 2
 
@@ -3484,7 +3527,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   test("SPARK-32920: merger locations reuse from shuffle dependency") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 2
 
@@ -3524,7 +3567,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   test("SPARK-32920: Disable shuffle merge due to not enough mergers 
available") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 7
 
@@ -3548,7 +3591,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   test("SPARK-32920: Ensure child stage should not start before all the" +
       " parent stages are completed with shuffle merge finalized for all the 
parent stages") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 1
     val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
@@ -3585,7 +3628,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
       " ShuffleDependency should not cause DAGScheduler to hang") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 20
 
@@ -3616,7 +3659,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
       " ShuffleDependency with shuffle data loss should recompute missing 
partitions") {
     initPushBasedShuffleConfs(conf)
     conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val parts = 20
 
@@ -3632,7 +3675,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
     completeNextResultStageWithSuccess(1, 0)
 
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     val hosts = (6 to parts).map {x => s"Host$x" }
     DAGSchedulerSuite.addMergerLocs(hosts)
 
@@ -3669,7 +3712,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
   test("SPARK-32920: Merge results should be unregistered if the running stage 
is cancelled" +
     " before shuffle merge is finalized") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     scheduler = new MyDAGScheduler(
       sc,
@@ -3697,14 +3740,15 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== parts)
     val shuffleMapStageToCancel = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
     runEvent(StageCancelled(0, Option("Explicit cancel check")))
-    scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel)
+    scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel,
+      shuffleMapStageToCancel.shuffleDep.shuffleMergeId)
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== 0)
   }
 
   test("SPARK-32920: SPARK-35549: Merge results should not get registered" +
     " after shuffle merge finalization") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
 
     scheduler = new MyDAGScheduler(
@@ -3733,7 +3777,8 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
     val shuffleMapStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
     scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, 
makeMergeStatus("hostA",
       shuffleDep.shuffleMergeId))))
-    scheduler.handleShuffleMergeFinalized(shuffleMapStage)
+    scheduler.handleShuffleMergeFinalized(shuffleMapStage,
+      shuffleMapStage.shuffleDep.shuffleMergeId)
     scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, 
makeMergeStatus("hostA",
       shuffleDep.shuffleMergeId))))
     assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) 
== 1)
@@ -3741,7 +3786,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
   test("SPARK-32920: Disable push based shuffle in the case of a barrier 
stage") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
 
     val parts = 2
@@ -3788,7 +3833,7 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
   test("SPARK-32923: handle stage failure for indeterminate map stage with 
push-based shuffle") {
     initPushBasedShuffleConfs(conf)
-    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.clearMergerLocs()
     DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
     val (shuffleId1, shuffleId2) = constructIndeterminateStageFetchFailed()
 
@@ -3847,11 +3892,262 @@ class DAGSchedulerSuite extends SparkFunSuite with 
TempLocalSparkContext with Ti
 
     // Job successful ended.
     assert(results === Map(0 -> 11, 1 -> 12))
+  }
+
+  test("SPARK-33701: check adaptive shuffle merge finalization triggered 
after" +
+    " stage completion") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 2
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    // Verify finalize task is set with default delay of 10s and merge results 
are marked
+    // for registration
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+    assert(shuffleStage1.shuffleDep.shuffleMergeFinalized)
+
+    complete(taskSets(1), taskSets(1).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts, 10))
+    }.toSeq)
+    val shuffleStage2 = 
scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage2.shuffleDep.getFinalizeTask.nonEmpty)
+    val finalizeTask2 = shuffleStage2.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask2.delay == 10 && finalizeTask2.registerMergeResults)
+
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == parts)
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage2.shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
     results.clear()
     assertDataStructuresEmpty()
   }
 
-  /**
+  test("SPARK-33701: check adaptive shuffle merge finalization triggered after 
minimum" +
+    " threshold push complete") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 4
+
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(0), taskResults(0)._1, 
taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(1), taskResults(0)._1, 
taskResults(0)._2))
+
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+    // Minimum push complete for 2 tasks, should have scheduled merge 
finalization
+    val finalizeTask = shuffleStage1.shuffleDep.getFinalizeTask.get
+      .asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask.registerMergeResults && finalizeTask.delay == 0)
+
+    runEvent(makeCompletionEvent(taskSets(0).tasks(2), taskResults(0)._1, 
taskResults(0)._2))
+    runEvent(makeCompletionEvent(taskSets(0).tasks(3), taskResults(0)._1, 
taskResults(0)._2))
+
+    completeShuffleMapStageSuccessfully(1, 0, parts)
+
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42, 1 -> 42, 2 -> 42, 3 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  // Test the behavior of stage cancellation during the 
spark.shuffle.push.finalize.timeout
+  // wait for shuffle merge finalization, there are 2 different cases:
+  // 1. Deterministic stage - With deterministic stage, the shuffleMergeId = 0 
for multiple
+  // stage attempts, so if the stage is cancelled before shuffle is merge 
finalized then
+  // the merge results are unregistered from MapOutputTracker
+  // 2. Indeterminate stage - Different attempt of the same stage can trigger 
shuffle merge
+  // finalization but it is validated by the shuffleMergeId (unique across 
stages and stage
+  // attempts for indeterminate stages) and only the shuffle merge is finalized
+  test("SPARK-33701: check adaptive shuffle merge finalization behavior with 
stage" +
+    " cancellation during spark.shuffle.push.finalize.timeout wait") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", 
"host5"))
+    val parts = 4
+
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+    dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(scheduler)
+
+    // Determinate stage
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new 
HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new 
HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    val shuffleStage1 = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    runEvent(StageCancelled(0, Option("Explicit cancel check")))
+    scheduler.handleShuffleMergeFinalized(shuffleStage1, 
shuffleStage1.shuffleDep.shuffleMergeId)
+
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    assert(mapOutputTracker.
+      getNumAvailableMergeResults(shuffleStage1.shuffleDep.shuffleId) == 0)
+
+    // Indeterminate stage
+    val shuffleMapIndeterminateRdd1 = new MyRDD(sc, parts, Nil, indeterminate 
= true)
+    val shuffleIndeterminateDep1 = new ShuffleDependency(
+      shuffleMapIndeterminateRdd1, new HashPartitioner(parts))
+    val shuffleMapIndeterminateRdd2 = new MyRDD(sc, parts, Nil, indeterminate 
= true)
+    val shuffleIndeterminateDep2 = new ShuffleDependency(
+      shuffleMapIndeterminateRdd2, new HashPartitioner(parts))
+    val reduceIndeterminateRdd = new MyRDD(sc, parts, List(
+      shuffleIndeterminateDep1, shuffleIndeterminateDep2), tracker = 
mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceIndeterminateRdd, (0 until parts).toArray)
+
+    val indeterminateResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- indeterminateResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+
+    val shuffleIndeterminateStage = 
scheduler.stageIdToStage(3).asInstanceOf[ShuffleMapStage]
+    assert(shuffleIndeterminateStage.isIndeterminate)
+    scheduler.handleShuffleMergeFinalized(shuffleIndeterminateStage, 2)
+    assert(shuffleIndeterminateStage.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleIndeterminateStage.shuffleDep.shuffleMergeFinalized)
+  }
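
The test above relies on the scheduler ignoring a finalization request whose shuffleMergeId does not match the dependency's current one. As a rough, hedged sketch (not the actual DAGScheduler code; the guard and names below are illustrative only), the validation amounts to:

    // Illustrative sketch only -- not Spark's real handleShuffleMergeFinalized.
    // A finalization request is acted on only when its merge id matches the
    // dependency's current shuffleMergeId; stale requests from older stage
    // attempts (possible for indeterminate stages) are dropped.
    def finalizeIfCurrent(currentMergeId: Int, requestedMergeId: Int)(doFinalize: => Unit): Unit = {
      if (requestedMergeId == currentMergeId) {
        doFinalize // ids match: finalize this shuffle merge
      } else {
        // stale request from an earlier attempt; ignore it
      }
    }
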
+
+  // With adaptive shuffle merge finalization, once the minimum shuffle pushes complete after
+  // stage completion, the existing shuffle merge finalization task with
+  // delay = spark.shuffle.push.finalize.timeout should be replaced with a new shuffle merge
+  // finalization task with delay = 0
+  test("SPARK-33701: check adaptive shuffle merge finalization with minimum pushes complete" +
+    " after the stage completion replacing the finalize task with delay = 0") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.PUSH_BASED_SHUFFLE_SIZE_MIN_SHUFFLE_SIZE_TO_WAIT, 10L)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 5)
+    conf.set(config.PUSH_BASED_SHUFFLE_MIN_PUSH_RATIO, 0.5)
+    DAGSchedulerSuite.clearMergerLocs()
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 4
+
+    scheduler = new MyDAGScheduler(
+      sc,
+      taskScheduler,
+      sc.listenerBus,
+      mapOutputTracker,
+      blockManagerMaster,
+      sc.env,
+      shuffleMergeFinalize = false)
+    dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler)
+
+    // Determinate stage
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2),
+      tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends on these shuffles, which will create the map stages
+    submit(reduceRdd, (0 until parts).toArray)
+
+    val taskResults = taskSets(0).tasks.zipWithIndex.map {
+      case (_, idx) =>
+        (Success, makeMapStatus("host" + ('A' + idx).toChar, parts))
+    }.toSeq
+
+    for ((result, i) <- taskResults.zipWithIndex) {
+      runEvent(makeCompletionEvent(taskSets(0).tasks(i), result._1, result._2))
+    }
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.shuffleMergeEnabled)
+    assert(!shuffleStage1.shuffleDep.shuffleMergeFinalized)
+    val finalizeTask1 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+      asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask1.delay == 10 && finalizeTask1.registerMergeResults)
+
+    // Minimum shuffle pushes complete, replace the finalizeTask with delay = 10
+    // with a finalizeTask with delay = 0
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 0)
+    pushComplete(shuffleStage1.shuffleDep.shuffleId, 0, 1)
+
+    // Existing finalizeTask with delay = 10 should be replaced with finalizeTask
+    // with delay = 0
+    val finalizeTask2 = shuffleStage1.shuffleDep.getFinalizeTask.get.
+      asInstanceOf[DummyScheduledFuture]
+    assert(finalizeTask2.delay == 0 && finalizeTask2.registerMergeResults)
+  }
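
The replacement that the assertions above check via DummyScheduledFuture boils down to cancelling the pending delayed task and rescheduling it with no delay. A minimal, hedged sketch of that pattern (names such as scheduleFinalize are illustrative, not Spark's API; the two calls at the end are REPL-style usage):

    import java.util.concurrent.{Executors, ScheduledFuture, TimeUnit}

    object FinalizeSketch {
      private val scheduler = Executors.newSingleThreadScheduledExecutor()
      @volatile private var pending: Option[ScheduledFuture[_]] = None

      // Schedule (or reschedule) merge finalization after `delaySecs` seconds,
      // cancelling any finalization task that was scheduled earlier.
      def scheduleFinalize(delaySecs: Long)(body: => Unit): Unit = {
        pending.foreach(_.cancel(false))
        pending = Some(scheduler.schedule(new Runnable {
          override def run(): Unit = body
        }, delaySecs, TimeUnit.SECONDS))
      }
    }

    // Stage completes: wait up to spark.shuffle.push.finalize.timeout (10s here).
    FinalizeSketch.scheduleFinalize(10) { /* finalize shuffle merge */ }
    // Minimum push ratio reached: replace the pending task with an immediate one.
    FinalizeSketch.scheduleFinalize(0) { /* finalize shuffle merge */ }
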
+
+    /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
    */
@@ -3922,7 +4218,7 @@ object DAGSchedulerSuite {
     locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
   }
 
-  def clearMergerLocs: Unit = mergerLocs.clear()
+  def clearMergerLocs(): Unit = mergerLocs.clear()
 
 }
 
diff --git a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
index 298ba50..94c0417 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.shuffle
 import java.io.{File, FileNotFoundException, IOException}
 import java.net.ConnectException
 import java.nio.ByteBuffer
-import java.util.concurrent.LinkedBlockingQueue
+import java.util.concurrent.{CountDownLatch, LinkedBlockingQueue, Semaphore}
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.{ArgumentMatchers, Mock, MockitoAnnotations}
 import org.mockito.Answers.RETURNS_SMART_NULLS
 import org.mockito.ArgumentMatchers.any
 import org.mockito.Mockito._
@@ -32,6 +32,8 @@ import org.mockito.invocation.InvocationOnMock
 import org.scalatest.BeforeAndAfterEach
 
 import org.apache.spark._
+import org.apache.spark.executor.CoarseGrainedExecutorBackend
+import org.apache.spark.internal.config.REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.server.BlockPushNonFatalFailure
 import org.apache.spark.network.server.BlockPushNonFatalFailure.ReturnCode
@@ -40,12 +42,14 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.shuffle.ShuffleBlockPusher.PushRequest
 import org.apache.spark.storage._
+import org.apache.spark.util.ThreadUtils
 
 class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
 
   @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
   @Mock(answer = RETURNS_SMART_NULLS) private var dependency: ShuffleDependency[Int, Int, Int] = _
   @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: BlockStoreClient = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var executorBackend: CoarseGrainedExecutorBackend = _
 
   private var conf: SparkConf = _
   private var pushedBlocks = new ArrayBuffer[String]
@@ -54,6 +58,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     super.beforeEach()
     conf = new SparkConf(loadDefaults = false)
     MockitoAnnotations.openMocks(this).close()
+    when(dependency.shuffleId).thenReturn(0)
     when(dependency.partitioner).thenReturn(new HashPartitioner(8))
     when(dependency.serializer).thenReturn(new JavaSerializer(conf))
     when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1)))
@@ -62,6 +67,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     when(mockEnv.conf).thenReturn(conf)
     when(mockEnv.blockManager).thenReturn(blockManager)
     SparkEnv.set(mockEnv)
+    when(SparkEnv.get.executorBackend).thenReturn(Some(executorBackend))
     when(blockManager.blockStoreClient).thenReturn(shuffleClient)
   }
 
@@ -91,37 +97,104 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     })
   }
 
+  private def verifyBlockPushCompleted(
+      blockPusher: ShuffleBlockPusher): Unit = {
+    verify(executorBackend, times(1))
+      .notifyDriverAboutPushCompletion(dependency.shuffleId, 0, 0)
+    assert(blockPusher.isPushCompletionNotified)
+  }
+
   test("A batch of blocks is limited by maxBlocksBatchSize") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
     conf.set("spark.shuffle.push.maxBlockSizeToPush", "2048k")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
     val largeBlockSize = 2 * 1024 * 1024
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, largeBlockSize, largeBlockSize), mergerLocs,
       mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 3)
+    verifyBlockPushCompleted(blockPusher)
     verifyPushRequests(pushRequests, Seq(6, largeBlockSize, largeBlockSize))
   }
 
   test("Large blocks are excluded in the preparation") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 1028, 1024), mergerLocs, mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 2)
     verifyPushRequests(pushRequests, Seq(6, 1024))
+    verifyBlockPushCompleted(blockPusher)
   }
 
   test("Number of blocks in a push request are limited by 
maxBlocksInFlightPerAddress ") {
+    interceptPushedBlocksForSuccess()
     conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
     val blockPusher = new TestShuffleBlockPusher(conf)
     val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
     val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
       mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+    blockPusher.runPendingTasks()
     assert(pushRequests.length == 5)
     verifyPushRequests(pushRequests, Seq(2, 2, 2, 2, 2))
+    verifyBlockPushCompleted(blockPusher)
+  }
+
+  test("SPARK-33701: Ensure all the blocks are pushed before notifying driver" 
+
+    " about push completion") {
+    conf.set(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS, 12)
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "20b")
+    val latch = new CountDownLatch(1)
+    // Different remote servers to send 2 different requests to ensure that all the blocks
+    // are pushed before notifying driver about push completion
+    when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", "test-client", 1),
+      BlockManagerId("slow-client", "slow-client", 1)))
+    when(shuffleClient.pushBlocks(ArgumentMatchers.eq("slow-client"), any(), 
any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockPushListener = invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+        latch.await()
+        // Add a small wait here to delay the "onBlockPushSuccess" to mimic the real world
+        Thread.sleep(500)
+        blocks.foreach { blockId =>
+          blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))
+        }
+      })
+    when(shuffleClient.pushBlocks(ArgumentMatchers.eq("test-client"), any(), 
any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        val blockPushListener = invocation.getArguments()(4).asInstanceOf[BlockPushingListener]
+        latch.await()
+        blocks.foreach { blockId =>
+          blockPushListener.onBlockPushSuccess(blockId, mock(classOf[ManagedBuffer]))
+        }
+      })
+    val semaphore = new Semaphore(0)
+    val blockPusher = new ConcurrentTestBlockPusher(conf, semaphore)
+    val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
+    blockPusher.initiateBlockPush(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 5 }, dependency, 0)
+    val pushRequests = blockPusher.prepareBlockPushRequests(5, 0, 0, 0,
+      mock(classOf[File]), Array(2, 2, 2, 2, 2), mergerLocs, mock(classOf[TransportConf]))
+    latch.countDown()
+    latch.countDown()
+    semaphore.acquire()
+    assert(blockPusher.bytesInFlight <= 0)
+    assert(pushRequests.length == 2)
+    verifyPushRequests(pushRequests, Seq(6, 4))
+    verifyBlockPushCompleted(blockPusher)
   }
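
This test exercises the invariant that the driver is told about push completion only once, and only after every in-flight block has resolved (success or failure). A hedged sketch of that bookkeeping, using hypothetical names rather than the actual ShuffleBlockPusher fields, could look like:

    // Illustrative only: track outstanding pushes and fire a single driver
    // notification once the last one resolves.
    class PushCompletionTracker(totalBlocks: Int, notifyDriver: () => Unit) {
      private var resolved = 0
      private var notified = false

      def onBlockResolved(): Unit = synchronized {
        resolved += 1
        if (!notified && resolved == totalBlocks) {
          notified = true
          notifyDriver() // e.g. backed by notifyDriverAboutPushCompletion in the real pusher
        }
      }
    }
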
 
   test("Basic block push") {
@@ -133,6 +206,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(1))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(blockPusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -146,6 +220,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(1))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -159,6 +234,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(8))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -199,6 +275,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(4))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == 8)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -213,6 +290,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(4))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    verifyBlockPushCompleted(pusher)
     ShuffleBlockPusher.stop()
   }
 
@@ -279,6 +357,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     verify(shuffleClient, times(8))
       .pushBlocks(any(), any(), any(), any(), any())
     assert(pushedBlocks.length == 7)
+    verifyBlockPushCompleted(pusher)
   }
 
   test("More blocks are not pushed when a block push fails with too late " +
@@ -333,6 +412,7 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     // 2 blocks for each merger locations
     assert(pushedBlocks.length == 4)
     assert(pusher.unreachableBlockMgrs.size == 2)
+    verifyBlockPushCompleted(pusher)
   }
 
   test("SPARK-36255: FileNotFoundException stops the push") {
@@ -359,7 +439,8 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
     ShuffleBlockPusher.stop()
   }
 
-  private class TestShuffleBlockPusher(conf: SparkConf) extends ShuffleBlockPusher(conf) {
+  private class TestShuffleBlockPusher(
+      conf: SparkConf) extends ShuffleBlockPusher(conf) {
     val tasks = new LinkedBlockingQueue[Runnable]
 
     override protected def submitTask(task: Runnable): Unit = {
@@ -385,4 +466,18 @@ class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
       managedBuffer
     }
   }
+
+  private class ConcurrentTestBlockPusher(conf: SparkConf, semaphore: Semaphore)
+      extends TestShuffleBlockPusher(conf) {
+    val blockPusher = ThreadUtils.newDaemonFixedThreadPool(1, "test-block-pusher")
+
+    override protected def submitTask(task: Runnable): Unit = {
+      blockPusher.execute(task)
+    }
+
+    override def notifyDriverAboutPushCompletion(): Unit = {
+      super.notifyDriverAboutPushCompletion()
+      semaphore.release()
+    }
+  }
 }
diff --git a/docs/configuration.md b/docs/configuration.md
index 2d4164f..80f17a8 100644
--- a/docs/configuration.md
+++ b/docs/configuration.md
@@ -3268,4 +3268,20 @@ Push-based shuffle helps improve the reliability and performance of spark shuffl
   </td>
   <td>3.2.0</td>
 </tr>
+<tr>
+  <td><code>spark.shuffle.push.minShuffleSizeToWait</code></td>
+  <td><code>500m</code></td>
+  <td>
+    Driver will wait for merge finalization to complete only if total shuffle data size is more than this threshold. If total shuffle size is less, driver will immediately finalize the shuffle output.
+  </td>
+  <td>3.3.0</td>
+</tr>
+<tr>
+  <td><code>spark.shuffle.push.minCompletedPushRatio</code></td>
+  <td><code>1.0</code></td>
+  <td>
+    Fraction of minimum map partitions that should be push complete before driver starts shuffle merge finalization during push based shuffle.
+  </td>
+  <td>3.3.0</td>
+</tr>
 </table>

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
