otterc commented on a change in pull request #30691:
URL: https://github.com/apache/spark/pull/30691#discussion_r626178921



##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -254,6 +259,28 @@ private[spark] class DAGScheduler(
   private val blockManagerMasterDriverHeartbeatTimeout =
     sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis
 
+  private val shuffleMergeResultsTimeoutSec =
+    JavaUtils.timeStringAsSec(sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT))
+
+  private val shuffleMergeFinalizeWaitSec =
+    JavaUtils.timeStringAsSec(sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT))
+
+  // lazy initialized so that the shuffle client can be properly initialized

Review comment:
       Nit: this comment is not clear. What is needed for "proper initialization"?

##########
File path: core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager
##########
@@ -1,3 +1,4 @@
 org.apache.spark.scheduler.DummyExternalClusterManager
 org.apache.spark.scheduler.MockExternalClusterManager
 org.apache.spark.scheduler.CSMockExternalClusterManager
+org.apache.spark.scheduler.PushBasedClusterManager

Review comment:
       Nit: is this needed?
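
       For context, a minimal sketch of how a registration like this is typically consumed: Spark resolves a custom master URL by loading `ExternalClusterManager` implementations through `java.util.ServiceLoader`, so this test-only entry is what lets the suite's `pushbasedshuffleclustermanager` master resolve to `PushBasedClusterManager`. The snippet below is only an illustration (the real lookup lives in `SparkContext.getClusterManager` and also checks that exactly one manager matches; `ExternalClusterManager` is `private[spark]`, so this sketch assumes it sits under the `org.apache.spark` package):

```scala
package org.apache.spark.scheduler

import java.util.ServiceLoader

// Simplified, illustrative lookup; not the exact Spark implementation.
object ClusterManagerLookupSketch {
  def findClusterManager(masterURL: String): Option[ExternalClusterManager] = {
    val it = ServiceLoader.load(classOf[ExternalClusterManager]).iterator()
    var found: Option[ExternalClusterManager] = None
    while (it.hasNext && found.isEmpty) {
      val cm = it.next()
      // canCreate is the ExternalClusterManager hook used to claim a master URL.
      if (cm.canCreate(masterURL)) found = Some(cm)
    }
    found
  }
}
```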

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2004,6 +2006,142 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * Schedules shuffle merge finalize.
+   */
+  private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): Unit = {
+    // TODO Use the default single threaded scheduler or extend ThreadUtils to
+    // TODO support the multi-threaded scheduler?
+    logInfo(("%s (%s) scheduled for finalizing" +
+      " shuffle merge in %s s").format(stage, stage.name, shuffleMergeFinalizeWaitSec))
+    shuffleMergeFinalizeScheduler.schedule(
+      new Runnable {
+        override def run(): Unit = finalizeShuffleMerge(stage)
+      },
+      shuffleMergeFinalizeWaitSec,
+      TimeUnit.SECONDS
+    )
+  }
+
+  /**
+   * DAGScheduler notifies all the remote shuffle services chosen to serve shuffle merge request for
+   * the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is
+   * invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the
+   * scheduler might need to talk to 1000s of shuffle services to finalize shuffle merge.
+   */
+  private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
+    logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+    externalShuffleClient.foreach { shuffleClient =>
+      val shuffleId = stage.shuffleDep.shuffleId
+      val numMergers = stage.shuffleDep.getMergerLocs.length
+      val numResponses = new AtomicInteger()
+      val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]())
+      val timedOut = new AtomicBoolean()
+
+      // NOTE: This is a defensive check to post finalize event if numMergers is 0 (i.e. no shuffle
+      // service available).
+      if (numMergers == 0) {

Review comment:
       Will this be true here if we don't schedule the finalize at line 1706 when the stage has no mergers?

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3448,14 +3726,63 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 }
 
 object DAGSchedulerSuite {
+  val mergerLocs = ArrayBuffer[BlockManagerId]()
+
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus =
     MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId)
 
   def makeBlockManagerId(host: String): BlockManagerId = {
     BlockManagerId(host + "-exec", host, 12345)
   }
+
+  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+  def addMergerLocs(locs: Seq[String]): Unit = {
+    locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+  }
+
+  def clearMergerLocs: Unit = mergerLocs.clear()
+
 }
 
 object FailThisAttempt {
   val _fail = new AtomicBoolean(true)
 }
+
+private class PushBasedSchedulerBackend(
+  conf: SparkConf,
+  scheduler: TaskSchedulerImpl, cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {
+
+  override def getShufflePushMergerLocations(
+    numPartitions: Int,

Review comment:
       nit: indentation
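
       For reference, Spark's Scala style uses a 4-space continuation indent for multi-line parameter lists; a sketch of this declaration with only the indentation changed (the body is copied from the diff) would be:

```scala
private class PushBasedSchedulerBackend(
    conf: SparkConf,
    scheduler: TaskSchedulerImpl,
    cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {

  override def getShufflePushMergerLocations(
      numPartitions: Int,
      resourceProfileId: Int): Seq[BlockManagerId] = {
    // Same logic as in the diff; only the parameter indentation differs.
    val mergerLocations = Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions)
    if (mergerLocations.size < numPartitions && mergerLocations.size <
      conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key, 5)) {
      Seq.empty[BlockManagerId]
    } else {
      mergerLocations
    }
  }
}
```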

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3448,14 +3726,63 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 }
 
 object DAGSchedulerSuite {
+  val mergerLocs = ArrayBuffer[BlockManagerId]()
+
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus =
     MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId)
 
   def makeBlockManagerId(host: String): BlockManagerId = {
     BlockManagerId(host + "-exec", host, 12345)
   }
+
+  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+  def addMergerLocs(locs: Seq[String]): Unit = {
+    locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+  }
+
+  def clearMergerLocs: Unit = mergerLocs.clear()
+
 }
 
 object FailThisAttempt {
   val _fail = new AtomicBoolean(true)
 }
+
+private class PushBasedSchedulerBackend(
+  conf: SparkConf,
+  scheduler: TaskSchedulerImpl, cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {
+
+  override def getShufflePushMergerLocations(
+    numPartitions: Int,
+    resourceProfileId: Int): Seq[BlockManagerId] = {
+    val mergerLocations = Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions)
+    if (mergerLocations.size < numPartitions && mergerLocations.size <
+      conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key, 5)) {
+      Seq.empty[BlockManagerId]
+    } else {
+      mergerLocations
+    }
+  }
+}
+
+private class PushBasedClusterManager extends ExternalClusterManager {
+  def canCreate(masterURL: String): Boolean = masterURL == "pushbasedshuffleclustermanager"
+
+  override def createSchedulerBackend(
+    sc: SparkContext,
+    masterURL: String,
+    scheduler: TaskScheduler): SchedulerBackend = {
+    new PushBasedSchedulerBackend(sc.conf, scheduler.asInstanceOf[TaskSchedulerImpl], 1)
+  }
+
+  override def createTaskScheduler(
+    sc: SparkContext,

Review comment:
       nit: indentation

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3448,14 +3726,63 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 }
 
 object DAGSchedulerSuite {
+  val mergerLocs = ArrayBuffer[BlockManagerId]()
+
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus =
     MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId)
 
   def makeBlockManagerId(host: String): BlockManagerId = {
     BlockManagerId(host + "-exec", host, 12345)
   }
+
+  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+  def addMergerLocs(locs: Seq[String]): Unit = {
+    locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+  }
+
+  def clearMergerLocs: Unit = mergerLocs.clear()
+
 }
 
 object FailThisAttempt {
   val _fail = new AtomicBoolean(true)
 }
+
+private class PushBasedSchedulerBackend(
+  conf: SparkConf,

Review comment:
       Nit: indentation

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -689,7 +716,7 @@ private[spark] class DAGScheduler(
             dep match {
               case shufDep: ShuffleDependency[_, _, _] =>
                 val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
-                if (!mapStage.isAvailable) {
+                if (!mapStage.isAvailable || !mapStage.isMergeFinalized) {

Review comment:
       Nit: Can we add a comment here explaining why this is needed?
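
       For example, something along these lines (an assumption about the PR's intent, to be adjusted by the author; the `missing += mapStage` body is taken from the surrounding `getMissingParentStages` logic and may differ slightly here):

```scala
case shufDep: ShuffleDependency[_, _, _] =>
  val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId)
  // Treat the map stage as a missing parent not only when its map outputs are
  // unavailable, but also when push-based shuffle merge has not been finalized
  // yet, so that child stages do not start before the merged shuffle data is ready.
  if (!mapStage.isAvailable || !mapStage.isMergeFinalized) {
    missing += mapStage
  }
```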

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -1678,33 +1703,10 @@ private[spark] class DAGScheduler(
             }
 
             if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) {
-              markStageAsFinished(shuffleStage)
-              logInfo("looking for newly runnable stages")
-              logInfo("running: " + runningStages)
-              logInfo("waiting: " + waitingStages)
-              logInfo("failed: " + failedStages)
-
-              // This call to increment the epoch may not be strictly necessary, but it is retained
-              // for now in order to minimize the changes in behavior from an earlier version of the
-              // code. This existing behavior of always incrementing the epoch following any
-              // successful shuffle map stage completion may have benefits by causing unneeded
-              // cached map outputs to be cleaned up earlier on executors. In the future we can
-              // consider removing this call, but this will require some extra investigation.
-              // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details.
-              mapOutputTracker.incrementEpoch()
-
-              clearCacheLocs()
-
-              if (!shuffleStage.isAvailable) {
-                // Some tasks had failed; let's resubmit this shuffleStage.
-                // TODO: Lower-level scheduler should also deal with this
-                logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name +
-                  ") because some of its tasks had failed: " +
-                  shuffleStage.findMissingPartitions().mkString(", "))
-                submitStage(shuffleStage)
+              if (pushBasedShuffleEnabled) {

Review comment:
       Should this not check `shuffleStage.dependency.mergerLocations.nonEmpty`? That would be empty when push-based shuffle is effectively disabled for this stage, so there is no point in scheduling a `scheduleShuffleMergeFinalize`.
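
       A rough sketch of the suggested guard, using the accessor names that appear in the diff (`processShuffleMapStageCompletion` is a hypothetical stand-in for whatever the existing non-push-based completion path is called in this PR):

```scala
if (pushBasedShuffleEnabled && shuffleStage.shuffleDep.getMergerLocs.nonEmpty) {
  scheduleShuffleMergeFinalize(shuffleStage)
} else {
  // No merger locations were assigned, so there is nothing to finalize; fall
  // through to the pre-existing shuffle map stage completion logic instead of
  // scheduling a no-op finalization.
  processShuffleMapStageCompletion(shuffleStage)
}
```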

##########
File path: core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
##########
@@ -2004,6 +2006,142 @@ private[spark] class DAGScheduler(
     }
   }
 
+  /**
+   * Schedules shuffle merge finalize.
+   */
+  private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): Unit = {
+    // TODO Use the default single threaded scheduler or extend ThreadUtils to
+    // TODO support the multi-threaded scheduler?
+    logInfo(("%s (%s) scheduled for finalizing" +
+      " shuffle merge in %s s").format(stage, stage.name, shuffleMergeFinalizeWaitSec))
+    shuffleMergeFinalizeScheduler.schedule(
+      new Runnable {
+        override def run(): Unit = finalizeShuffleMerge(stage)
+      },
+      shuffleMergeFinalizeWaitSec,
+      TimeUnit.SECONDS
+    )
+  }
+
+  /**
+   * DAGScheduler notifies all the remote shuffle services chosen to serve shuffle merge request for
+   * the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is
+   * invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the
+   * scheduler might need to talk to 1000s of shuffle services to finalize shuffle merge.
+   */
+  private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = {
+    logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name))
+    externalShuffleClient.foreach { shuffleClient =>
+      val shuffleId = stage.shuffleDep.shuffleId
+      val numMergers = stage.shuffleDep.getMergerLocs.length
+      val numResponses = new AtomicInteger()
+      val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]())
+      val timedOut = new AtomicBoolean()
+
+      // NOTE: This is a defensive check to post finalize event if numMergers is 0 (i.e. no shuffle
+      // service available).
+      if (numMergers == 0) {
+        eventProcessLoop.post(ShuffleMergeFinalized(stage))
+        return
+      }
+
+      def increaseAndCheckResponseCount: Unit = {
+        if (numResponses.incrementAndGet() == numMergers) {
+          // Since this runs in the netty client thread and is outside of DAGScheduler
+          // event loop, we only post ShuffleMergeFinalized event into the event queue.
+          // The processing of this event should be done inside the event loop, so it
+          // can safely modify scheduler's internal state.
+          logInfo("%s (%s) shuffle merge finalized".format(stage, stage.name))
+          eventProcessLoop.post(ShuffleMergeFinalized(stage))
+        }
+      }
+
+      stage.shuffleDep.getMergerLocs.zipWithIndex.foreach {
+        case (shuffleServiceLoc, index) =>
+          // Sends async request to shuffle service to finalize shuffle merge on that host
+          shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host,
+            shuffleServiceLoc.port, shuffleId,
+            new MergeFinalizerListener {
+              override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = {
+                assert(shuffleId == statuses.shuffleId)
+                // Register the merge results even if already timed out, in case the reducer
+                // needing this merged block starts after dag scheduler receives this response.
+                mapOutputTracker.registerMergeResults(statuses.shuffleId,
+                  MergeStatus.convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc))
+                if (!timedOut.get()) {
+                  increaseAndCheckResponseCount
+                  results(index).set(true)
+                }
+              }
+
+              override def onShuffleMergeFailure(e: Throwable): Unit = {
+                if (!timedOut.get()) {
+                  logWarning(s"Exception encountered when trying to finalize shuffle " +
+                    s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e)
+                  increaseAndCheckResponseCount
+                  // Do not fail the future as this would cause dag scheduler to prematurely
+                  // give up on waiting for merge results from the remaining shuffle services
+                  // if one fails
+                  results(index).set(false)
+                }
+              }
+            })
+      }
+      // DAGScheduler only waits for a limited amount of time for the merge results.
+      // It will attempt to submit the next stage(s) irrespective of whether merge results
+      // from all shuffle services are received or not.
+      // TODO what are the reasonable configurations for the 2 timeouts? When # mappers

Review comment:
       Nit: Please remove these TODOs and reference the Spark JIRA created for adaptive merge finalization timeout instead.
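
       For reference, the two timeouts the TODO refers to are `shuffleMergeFinalizeWaitSec`, which delays the call to `finalizeShuffleMerge`, and `shuffleMergeResultsTimeoutSec`, which bounds how long the scheduler waits on the collected futures. A sketch of how the latter is typically applied to the `results` futures from the diff (the exact call in this PR may differ; the sketch assumes Guava's `Futures.allAsList` and `java.util.concurrent.{TimeUnit, TimeoutException}` are in scope):

```scala
try {
  // Block the finalization thread, not the event loop, until all mergers respond
  // or the configured results timeout expires.
  Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, TimeUnit.SECONDS)
} catch {
  case _: TimeoutException =>
    // Stop counting late responses and let the DAGScheduler move on with
    // whatever merge results have been registered so far.
    timedOut.set(true)
    logInfo(s"Timed out waiting for merge results from all $numMergers mergers " +
      s"for shuffle $shuffleId")
}
```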

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3393,6 +3406,271 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(rprofsE === Set())
   }
 
+  private def initPushBasedShuffleConfs(conf: SparkConf) = {
+    conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
+    conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set("spark.master", "pushbasedshuffleclustermanager")
+  }
+
+  test("SPARK-32920: shuffle merge finalization") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations not empty") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.getMergerLocs.nonEmpty)
+
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations reuse from shuffle dependency") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(shuffleDep.getMergerLocs.nonEmpty)
+    val mergerLocs = shuffleDep.getMergerLocs
+    completeNextResultStageWithSuccess(1, 0 )
+
+    // submit another job w/ the shared dependency, and have a fetch failure
+    val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduce2, Array(0, 1))
+    // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped
+    // stage.  If instead it reused the existing stage, then this would be stage 2
+    completeNextStageWithFetchFailure(3, 0, shuffleDep)
+    scheduler.resubmitFailedStages()
+
+    assert(scheduler.runningStages.nonEmpty)
+    assert(scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs.nonEmpty)
+    val newMergerLocs = scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+
+    // Check if same merger locs is reused for the new stage with shared shuffle dependency
+    assert(mergerLocs.zip(newMergerLocs).forall(x => x._1.host == x._2.host))
+    completeShuffleMapStageSuccessfully(2, 0, 2)
+    completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+    assert(results === Map(0 -> 1234, 1 -> 1235))
+
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Disable shuffle merge due to not enough mergers available") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.mergerLocs.isEmpty)
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(2 -> 42, 5 -> 42, 4 -> 42, 1 -> 42, 3 -> 42, 6 -> 42, 0 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Ensure child stage should not start before all the" +
+    " parent stages are completed with shuffle merge finalized for all the parent stages") {

Review comment:
       Nit: I think here the indentation should be 4 as well. 

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3448,14 +3726,63 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
 }
 
 object DAGSchedulerSuite {
+  val mergerLocs = ArrayBuffer[BlockManagerId]()
+
   def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus =
     MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId)
 
   def makeBlockManagerId(host: String): BlockManagerId = {
     BlockManagerId(host + "-exec", host, 12345)
   }
+
+  def makeMergeStatus(host: String, size: Long = 1000): MergeStatus =
+    MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size)
+
+  def addMergerLocs(locs: Seq[String]): Unit = {
+    locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) }
+  }
+
+  def clearMergerLocs: Unit = mergerLocs.clear()
+
 }
 
 object FailThisAttempt {
   val _fail = new AtomicBoolean(true)
 }
+
+private class PushBasedSchedulerBackend(
+  conf: SparkConf,
+  scheduler: TaskSchedulerImpl, cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) {
+
+  override def getShufflePushMergerLocations(
+    numPartitions: Int,
+    resourceProfileId: Int): Seq[BlockManagerId] = {
+    val mergerLocations = Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions)
+    if (mergerLocations.size < numPartitions && mergerLocations.size <
+      conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key, 5)) {
+      Seq.empty[BlockManagerId]
+    } else {
+      mergerLocations
+    }
+  }
+}
+
+private class PushBasedClusterManager extends ExternalClusterManager {
+  def canCreate(masterURL: String): Boolean = masterURL == "pushbasedshuffleclustermanager"
+
+  override def createSchedulerBackend(
+    sc: SparkContext,

Review comment:
       nit: indentation

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -3393,6 +3406,271 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
     assert(rprofsE === Set())
   }
 
+  private def initPushBasedShuffleConfs(conf: SparkConf) = {
+    conf.set(config.SHUFFLE_SERVICE_ENABLED, true)
+    conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true)
+    conf.set("spark.master", "pushbasedshuffleclustermanager")
+  }
+
+  test("SPARK-32920: shuffle merge finalization") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations not empty") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.getMergerLocs.nonEmpty)
+
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts)
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(0 -> 42, 1 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: merger locations reuse from shuffle dependency") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 2
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1))
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    assert(shuffleDep.getMergerLocs.nonEmpty)
+    val mergerLocs = shuffleDep.getMergerLocs
+    completeNextResultStageWithSuccess(1, 0 )
+
+    // submit another job w/ the shared dependency, and have a fetch failure
+    val reduce2 = new MyRDD(sc, 2, List(shuffleDep))
+    submit(reduce2, Array(0, 1))
+    // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped
+    // stage.  If instead it reused the existing stage, then this would be stage 2
+    completeNextStageWithFetchFailure(3, 0, shuffleDep)
+    scheduler.resubmitFailedStages()
+
+    assert(scheduler.runningStages.nonEmpty)
+    assert(scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs.nonEmpty)
+    val newMergerLocs = scheduler.stageIdToStage(2)
+      .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs
+
+    // Check if same merger locs is reused for the new stage with shared shuffle dependency
+    assert(mergerLocs.zip(newMergerLocs).forall(x => x._1.host == x._2.host))
+    completeShuffleMapStageSuccessfully(2, 0, 2)
+    completeNextResultStageWithSuccess(3, 1, idx => idx + 1234)
+    assert(results === Map(0 -> 1234, 1 -> 1235))
+
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Disable shuffle merge due to not enough mergers available") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 7
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+
+    // Submit a reduce job that depends which will create a map stage
+    submit(reduceRdd, (0 until parts).toArray)
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.mergerLocs.isEmpty)
+
+    completeNextResultStageWithSuccess(1, 0)
+    assert(results === Map(2 -> 42, 5 -> 42, 4 -> 42, 1 -> 42, 3 -> 42, 6 -> 42, 0 -> 42))
+
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Ensure child stage should not start before all the" +
+    " parent stages are completed with shuffle merge finalized for all the parent stages") {
+    initPushBasedShuffleConfs(conf)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 1
+    val shuffleMapRdd1 = new MyRDD(sc, parts, Nil)
+    val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts))
+
+    val shuffleMapRdd2 = new MyRDD(sc, parts, Nil)
+    val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts))
+
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2), tracker = mapOutputTracker)
+
+    // Submit a reduce job
+    submit(reduceRdd, (0 until parts).toArray)
+
+    complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1))))
+    val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage1.shuffleDep.getMergerLocs.nonEmpty)
+
+    complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1))))
+    val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage2.shuffleDep.getMergerLocs.nonEmpty)
+
+    assert(shuffleStage2.isMergeFinalized)
+    assert(shuffleStage1.isMergeFinalized)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep1.shuffleId) == parts)
+    assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep2.shuffleId) == parts)
+
+    completeNextResultStageWithSuccess(2, 0)
+    assert(results === Map(0 -> 42))
+    results.clear()
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for the corresponding" +
+    " ShuffleDependency should not cause DAGScheduler to hang") {
+    initPushBasedShuffleConfs(conf)
+    conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10)
+    DAGSchedulerSuite.clearMergerLocs
+    DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5"))
+    val parts = 20
+
+    val shuffleMapRdd = new MyRDD(sc, parts, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts))
+    val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker)
+    val partitions = (0 until parts).toArray
+    submit(reduceRdd, partitions)
+
+    completeShuffleMapStageSuccessfully(0, 0, parts)
+    val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage]
+    assert(shuffleStage.shuffleDep.mergerLocs.isEmpty)
+
+    completeNextResultStageWithSuccess(1, 0)
+    val reduce2 = new MyRDD(sc, parts, List(shuffleDep))
+    submit(reduce2, partitions)
+    // Stage 2 should not be executed as it should reuse the already computed shuffle output
+    assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics == null)
+    completeNextResultStageWithSuccess(3, 0, idx => idx + 1234)
+
+    val expected = (0 until parts).map(idx => (idx, idx + 1234))
+    assert(results === expected.toMap)
+
+    assertDataStructuresEmpty()
+  }
+
+  test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for the corresponding" +
+    " ShuffleDependency with shuffle data loss should recompute missing partitions") {

Review comment:
       Nit: please check what the indentation should be.

##########
File path: core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
##########
@@ -328,7 +328,20 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti
       sc.listenerBus,
       mapOutputTracker,
       blockManagerMaster,
-      sc.env)
+      sc.env) {
+      /**
+       * Schedules shuffle merge finalize.
+       */
+      override private[scheduler] def scheduleShuffleMergeFinalize(
+        shuffleMapStage: ShuffleMapStage): Unit = {

Review comment:
       Nit: indentation should be 4



