This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new b5a1503 [SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge with Push based shuffle and preparation step for the reduce stage b5a1503 is described below commit b5a15035851bfba12ef1c68d10103cec42cbac0c Author: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com> AuthorDate: Thu Jun 10 13:06:15 2021 -0500 [SPARK-32920][SHUFFLE] Finalization of Shuffle push/merge with Push based shuffle and preparation step for the reduce stage ### What changes were proposed in this pull request? Summary of the changes made as part of this PR: 1. `DAGScheduler` changes to finalize a ShuffleMapStage which involves talking to all the shuffle mergers (`ExternalShuffleService`) and getting all the completed merge statuses. 2. Once the `ShuffleMapStage` finalization is complete, mark the `ShuffleMapStage` to be finalized which marks the stage as complete and subsequently lets the child stage start. 3. Also added the relevant tests to `DAGSchedulerSuite` for changes made as part of [SPARK-32919](https://issues.apache.org/jira/browse/SPARK-32919) Lead-authored-by: Min Shen <mshen@linkedin.com> Co-authored-by: Venkata krishnan Sowrirajan <vsowrirajan@linkedin.com> Co-authored-by: Chandni Singh <chsingh@linkedin.com> ### Why are the changes needed? Refer to [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602) ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added unit tests to DAGSchedulerSuite Closes #30691 from venkata91/SPARK-32920. 
Lead-authored-by: Venkata krishnan Sowrirajan <vsowrira...@linkedin.com> Co-authored-by: Min Shen <ms...@linkedin.com> Co-authored-by: Chandni Singh <chsi...@linkedin.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../main/scala/org/apache/spark/Dependency.scala | 38 ++ .../scala/org/apache/spark/MapOutputTracker.scala | 44 +- .../org/apache/spark/internal/config/package.scala | 23 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 257 +++++++++--- .../apache/spark/scheduler/DAGSchedulerEvent.scala | 6 + .../org/apache/spark/scheduler/StageInfo.scala | 2 +- ...g.apache.spark.scheduler.ExternalClusterManager | 1 + .../apache/spark/scheduler/DAGSchedulerSuite.scala | 448 ++++++++++++++++++++- 8 files changed, 747 insertions(+), 72 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala index d21b9d9..0a9acf4 100644 --- a/core/src/main/scala/org/apache/spark/Dependency.scala +++ b/core/src/main/scala/org/apache/spark/Dependency.scala @@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleHandle, ShuffleWriteProcessor} import org.apache.spark.storage.BlockManagerId +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -96,12 +97,31 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( val shuffleHandle: ShuffleHandle = _rdd.context.env.shuffleManager.registerShuffle( shuffleId, this) + // By default, shuffle merge is enabled for ShuffleDependency if push based shuffle + // is enabled + private[this] var _shuffleMergeEnabled = + Utils.isPushBasedShuffleEnabled(rdd.sparkContext.getConf) && + // TODO: SPARK-35547: Push based shuffle is currently unsupported for Barrier stages + !rdd.isBarrier() + + private[spark] def setShuffleMergeEnabled(shuffleMergeEnabled: Boolean): Unit = { + _shuffleMergeEnabled = shuffleMergeEnabled + } + + def shuffleMergeEnabled : Boolean = 
_shuffleMergeEnabled + /** * Stores the location of the list of chosen external shuffle services for handling the * shuffle merge requests from mappers in this shuffle map stage. */ private[spark] var mergerLocs: Seq[BlockManagerId] = Nil + /** + * Stores the information about whether the shuffle merge is finalized for the shuffle map stage + * associated with this shuffle dependency + */ + private[this] var _shuffleMergedFinalized: Boolean = false + def setMergerLocs(mergerLocs: Seq[BlockManagerId]): Unit = { if (mergerLocs != null) { this.mergerLocs = mergerLocs @@ -110,6 +130,24 @@ class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag]( def getMergerLocs: Seq[BlockManagerId] = mergerLocs + private[spark] def markShuffleMergeFinalized(): Unit = { + _shuffleMergedFinalized = true + } + + /** + * Returns true if push-based shuffle is disabled for this stage or empty RDD, + * or if the shuffle merge for this stage is finalized, i.e. the shuffle merge + * results for all partitions are available. + */ + def shuffleMergeFinalized: Boolean = { + // Empty RDD won't be computed therefore shuffle merge finalized should be true by default. 
+ if (shuffleMergeEnabled && rdd.getNumPartitions > 0) { + _shuffleMergedFinalized + } else { + true + } + } + _rdd.sparkContext.cleaner.foreach(_.registerShuffleForCleanup(this)) _rdd.sparkContext.shuffleDriverComponents.registerShuffle(shuffleId) } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index ea9e641..9f2228b 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -214,6 +214,7 @@ private class ShuffleStatus( def removeOutputsOnHost(host: String): Unit = withWriteLock { logDebug(s"Removing outputs for host ${host}") removeOutputsByFilter(x => x.host == host) + removeMergeResultsByFilter(x => x.host == host) } /** @@ -238,6 +239,12 @@ private class ShuffleStatus( invalidateSerializedMapOutputStatusCache() } } + } + + /** + * Removes all shuffle merge result which satisfies the filter. + */ + def removeMergeResultsByFilter(f: BlockManagerId => Boolean): Unit = withWriteLock { for (reduceId <- mergeStatuses.indices) { if (mergeStatuses(reduceId) != null && f(mergeStatuses(reduceId).location)) { _numAvailableMergeResults -= 1 @@ -708,15 +715,16 @@ private[spark] class MapOutputTrackerMaster( } } - /** Unregister all map output information of the given shuffle. */ - def unregisterAllMapOutput(shuffleId: Int): Unit = { + /** Unregister all map and merge output information of the given shuffle. 
*/ + def unregisterAllMapAndMergeOutput(shuffleId: Int): Unit = { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => shuffleStatus.removeOutputsByFilter(x => true) + shuffleStatus.removeMergeResultsByFilter(x => true) incrementEpoch() case None => throw new SparkException( - s"unregisterAllMapOutput called for nonexistent shuffle ID $shuffleId.") + s"unregisterAllMapAndMergeOutput called for nonexistent shuffle ID $shuffleId.") } } @@ -731,25 +739,26 @@ private[spark] class MapOutputTrackerMaster( } /** - * Unregisters a merge result corresponding to the reduceId if present. If the optional mapId - * is specified, it will only unregister the merge result if the mapId is part of that merge + * Unregisters a merge result corresponding to the reduceId if present. If the optional mapIndex + * is specified, it will only unregister the merge result if the mapIndex is part of that merge * result. * * @param shuffleId the shuffleId. * @param reduceId the reduceId. * @param bmAddress block manager address. - * @param mapId the optional mapId which should be checked to see it was part of the merge - * result. + * @param mapIndex the optional mapIndex which should be checked to see it was part of the + * merge result. 
*/ def unregisterMergeResult( - shuffleId: Int, - reduceId: Int, - bmAddress: BlockManagerId, - mapId: Option[Int] = None): Unit = { + shuffleId: Int, + reduceId: Int, + bmAddress: BlockManagerId, + mapIndex: Option[Int] = None): Unit = { shuffleStatuses.get(shuffleId) match { case Some(shuffleStatus) => val mergeStatus = shuffleStatus.mergeStatuses(reduceId) - if (mergeStatus != null && (mapId.isEmpty || mergeStatus.tracker.contains(mapId.get))) { + if (mergeStatus != null && + (mapIndex.isEmpty || mergeStatus.tracker.contains(mapIndex.get))) { shuffleStatus.removeMergeResult(reduceId, bmAddress) incrementEpoch() } @@ -758,6 +767,17 @@ private[spark] class MapOutputTrackerMaster( } } + def unregisterAllMergeResult(shuffleId: Int): Unit = { + shuffleStatuses.get(shuffleId) match { + case Some(shuffleStatus) => + shuffleStatus.removeMergeResultsByFilter(x => true) + incrementEpoch() + case None => + throw new SparkException( + s"unregisterAllMergeResult called for nonexistent shuffle ID $shuffleId.") + } + } + /** Unregister shuffle data */ def unregisterShuffle(shuffleId: Int): Unit = { shuffleStatuses.remove(shuffleId).foreach { shuffleStatus => diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 9574416..84bd8cc 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2084,6 +2084,27 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT = + ConfigBuilder("spark.shuffle.push.merge.results.timeout") + .doc("Specify the max amount of time DAGScheduler waits for the merge results from " + + "all remote shuffle services for a given shuffle. 
DAGScheduler will start to submit " + + "following stages if not all results are received within the timeout.") + .version("3.2.0") + .timeConf(TimeUnit.SECONDS) + .checkValue(_ >= 0L, "Timeout must be >= 0.") + .createWithDefaultString("10s") + + private[spark] val PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT = + ConfigBuilder("spark.shuffle.push.merge.finalize.timeout") + .doc("Specify the amount of time DAGScheduler waits after all mappers finish for " + + "a given shuffle map stage before it starts sending merge finalize requests to " + + "remote shuffle services. This allows the shuffle services some extra time to " + + "merge as many blocks as possible.") + .version("3.2.0") + .timeConf(TimeUnit.SECONDS) + .checkValue(_ >= 0L, "Timeout must be >= 0.") + .createWithDefaultString("10s") + private[spark] val SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS = ConfigBuilder("spark.shuffle.push.maxRetainedMergerLocations") .doc("Maximum number of shuffle push merger locations cached for push based shuffle. 
" + @@ -2117,7 +2138,7 @@ package object config { s"${SHUFFLE_MERGER_LOCATIONS_MIN_THRESHOLD_RATIO.key} set to 0.05, we would need " + "at least 50 mergers to enable push based shuffle for that stage.") .version("3.1.0") - .doubleConf + .intConf .createWithDefault(5) private[spark] val SHUFFLE_NUM_PUSH_THREADS = diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b359501..1f37638 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -19,7 +19,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.util.Properties -import java.util.concurrent.{ConcurrentHashMap, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, TimeoutException, TimeUnit} import java.util.concurrent.atomic.AtomicInteger import scala.annotation.tailrec @@ -29,12 +29,16 @@ import scala.collection.mutable.{HashMap, HashSet, ListBuffer} import scala.concurrent.duration._ import scala.util.control.NonFatal +import com.google.common.util.concurrent.{Futures, SettableFuture} + import org.apache.spark._ import org.apache.spark.broadcast.Broadcast import org.apache.spark.executor.{ExecutorMetrics, TaskMetrics} import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.internal.config.Tests.TEST_NO_STAGE_RETRY +import org.apache.spark.network.shuffle.{BlockStoreClient, MergeFinalizerListener} +import org.apache.spark.network.shuffle.protocol.MergeStatuses import org.apache.spark.network.util.JavaUtils import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult} import org.apache.spark.rdd.{RDD, RDDCheckpointData} @@ -254,6 +258,24 @@ private[spark] class DAGScheduler( private val blockManagerMasterDriverHeartbeatTimeout = 
sc.getConf.get(config.STORAGE_BLOCKMANAGER_MASTER_DRIVER_HEARTBEAT_TIMEOUT).millis + private val shuffleMergeResultsTimeoutSec = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_RESULTS_TIMEOUT) + + private val shuffleMergeFinalizeWaitSec = + sc.getConf.get(config.PUSH_BASED_SHUFFLE_MERGE_FINALIZE_TIMEOUT) + + // Since SparkEnv gets initialized after DAGScheduler, externalShuffleClient needs to be + // initialized lazily + private lazy val externalShuffleClient: Option[BlockStoreClient] = + if (pushBasedShuffleEnabled) { + Some(env.blockManager.blockStoreClient) + } else { + None + } + + private val shuffleMergeFinalizeScheduler = + ThreadUtils.newDaemonThreadPoolScheduledExecutor("shuffle-merge-finalizer", 8) + /** * Called by the TaskSetManager to report task's starting. */ @@ -689,7 +711,10 @@ private[spark] class DAGScheduler( dep match { case shufDep: ShuffleDependency[_, _, _] => val mapStage = getOrCreateShuffleMapStage(shufDep, stage.firstJobId) - if (!mapStage.isAvailable) { + // Mark mapStage as available with shuffle outputs only after shuffle merge is + // finalized with push based shuffle. If not, subsequent ShuffleMapStage won't + // read from merged output as the MergeStatuses are not available. + if (!mapStage.isAvailable || !mapStage.shuffleDep.shuffleMergeFinalized) { missing += mapStage } case narrowDep: NarrowDependency[_] => @@ -1271,21 +1296,21 @@ private[spark] class DAGScheduler( * locations for block push/merge by getting the historical locations of past executors. 
*/ private def prepareShuffleServicesForShuffleMapStage(stage: ShuffleMapStage): Unit = { - // TODO(SPARK-32920) Handle stage reuse/retry cases separately as without finalize - // TODO changes we cannot disable shuffle merge for the retry/reuse cases - val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( - stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) - - if (mergerLocs.nonEmpty) { - stage.shuffleDep.setMergerLocs(mergerLocs) - logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + - s" ${stage.shuffleDep.getMergerLocs.size} merger locations") - - logDebug("List of shuffle push merger locations " + - s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") - } else { - logInfo("No available merger locations." + - s" Push-based shuffle disabled for $stage (${stage.name})") + assert(stage.shuffleDep.shuffleMergeEnabled && !stage.shuffleDep.shuffleMergeFinalized) + if (stage.shuffleDep.getMergerLocs.isEmpty) { + val mergerLocs = sc.schedulerBackend.getShufflePushMergerLocations( + stage.shuffleDep.partitioner.numPartitions, stage.resourceProfileId) + if (mergerLocs.nonEmpty) { + stage.shuffleDep.setMergerLocs(mergerLocs) + logInfo(s"Push-based shuffle enabled for $stage (${stage.name}) with" + + s" ${stage.shuffleDep.getMergerLocs.size} merger locations") + + logDebug("List of shuffle push merger locations " + + s"${stage.shuffleDep.getMergerLocs.map(_.host).mkString(", ")}") + } else { + stage.shuffleDep.setShuffleMergeEnabled(false) + logInfo("Push-based shuffle disabled for $stage (${stage.name})") + } } } @@ -1298,7 +1323,9 @@ private[spark] class DAGScheduler( // `findMissingPartitions()` returns all partitions every time. 
stage match { case sms: ShuffleMapStage if stage.isIndeterminate && !sms.isAvailable => - mapOutputTracker.unregisterAllMapOutput(sms.shuffleDep.shuffleId) + // TODO: SPARK-32923: Clean all push-based shuffle metadata like merge enabled and + // TODO: finalized as we are clearing all the merge results. + mapOutputTracker.unregisterAllMapAndMergeOutput(sms.shuffleDep.shuffleId) case _ => } @@ -1318,11 +1345,19 @@ private[spark] class DAGScheduler( stage match { case s: ShuffleMapStage => outputCommitCoordinator.stageStart(stage = s.id, maxPartitionId = s.numPartitions - 1) - // Only generate merger location for a given shuffle dependency once. This way, even if - // this stage gets retried, it would still be merging blocks using the same set of - // shuffle services. - if (pushBasedShuffleEnabled) { - prepareShuffleServicesForShuffleMapStage(s) + // Only generate merger location for a given shuffle dependency once. + if (s.shuffleDep.shuffleMergeEnabled) { + if (!s.shuffleDep.shuffleMergeFinalized) { + prepareShuffleServicesForShuffleMapStage(s) + } else { + // Disable Shuffle merge for the retry/reuse of the same shuffle dependency if it has + // already been merge finalized. If the shuffle dependency was previously assigned + // merger locations but the corresponding shuffle map stage did not complete + // successfully, we would still enable push for its retry. 
+ s.shuffleDep.setShuffleMergeEnabled(false) + logInfo("Push-based shuffle disabled for $stage (${stage.name}) since it" + + " is already shuffle merge finalized") + } } case s: ResultStage => outputCommitCoordinator.stageStart( @@ -1678,38 +1713,16 @@ private[spark] class DAGScheduler( } if (runningStages.contains(shuffleStage) && shuffleStage.pendingPartitions.isEmpty) { - markStageAsFinished(shuffleStage) - logInfo("looking for newly runnable stages") - logInfo("running: " + runningStages) - logInfo("waiting: " + waitingStages) - logInfo("failed: " + failedStages) - - // This call to increment the epoch may not be strictly necessary, but it is retained - // for now in order to minimize the changes in behavior from an earlier version of the - // code. This existing behavior of always incrementing the epoch following any - // successful shuffle map stage completion may have benefits by causing unneeded - // cached map outputs to be cleaned up earlier on executors. In the future we can - // consider removing this call, but this will require some extra investigation. - // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. - mapOutputTracker.incrementEpoch() - - clearCacheLocs() - - if (!shuffleStage.isAvailable) { - // Some tasks had failed; let's resubmit this shuffleStage. 
- // TODO: Lower-level scheduler should also deal with this - logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + - ") because some of its tasks had failed: " + - shuffleStage.findMissingPartitions().mkString(", ")) - submitStage(shuffleStage) + if (!shuffleStage.shuffleDep.shuffleMergeFinalized && + shuffleStage.shuffleDep.getMergerLocs.nonEmpty) { + scheduleShuffleMergeFinalize(shuffleStage) } else { - markMapStageJobsAsFinished(shuffleStage) - submitWaitingChildStages(shuffleStage) + processShuffleMapStageCompletion(shuffleStage) } } } - case FetchFailed(bmAddress, shuffleId, _, mapIndex, _, failureMessage) => + case FetchFailed(bmAddress, shuffleId, _, mapIndex, reduceId, failureMessage) => val failedStage = stageIdToStage(task.stageId) val mapStage = shuffleIdToMapStage(shuffleId) @@ -1739,10 +1752,18 @@ private[spark] class DAGScheduler( if (mapStage.rdd.isBarrier()) { // Mark all the map as broken in the map stage, to ensure retry all the tasks on // resubmitted stage attempt. - mapOutputTracker.unregisterAllMapOutput(shuffleId) + // TODO: SPARK-35547: Clean all push-based shuffle metadata like merge enabled and + // TODO: finalized as we are clearing all the merge results. + mapOutputTracker.unregisterAllMapAndMergeOutput(shuffleId) } else if (mapIndex != -1) { // Mark the map whose fetch failed as broken in the map stage mapOutputTracker.unregisterMapOutput(shuffleId, mapIndex, bmAddress) + if (pushBasedShuffleEnabled) { + // Possibly unregister the merge result <shuffleId, reduceId>, if the FetchFailed + // mapIndex is part of the merge result of <shuffleId, reduceId> + mapOutputTracker. + unregisterMergeResult(shuffleId, reduceId, bmAddress, Option(mapIndex)) + } } if (failedStage.rdd.isBarrier()) { @@ -1750,7 +1771,7 @@ private[spark] class DAGScheduler( case failedMapStage: ShuffleMapStage => // Mark all the map as broken in the map stage, to ensure retry all the tasks on // resubmitted stage attempt. 
- mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId) case failedResultStage: ResultStage => // Abort the failed result stage since we may have committed output for some @@ -1959,7 +1980,7 @@ private[spark] class DAGScheduler( case failedMapStage: ShuffleMapStage => // Mark all the map as broken in the map stage, to ensure retry all the tasks on // resubmitted stage attempt. - mapOutputTracker.unregisterAllMapOutput(failedMapStage.shuffleDep.shuffleId) + mapOutputTracker.unregisterAllMapAndMergeOutput(failedMapStage.shuffleDep.shuffleId) case failedResultStage: ResultStage => // Abort the failed result stage since we may have committed output for some @@ -2000,6 +2021,130 @@ private[spark] class DAGScheduler( } } + /** + * Schedules shuffle merge finalize. + */ + private[scheduler] def scheduleShuffleMergeFinalize(stage: ShuffleMapStage): Unit = { + // TODO: SPARK-33701: Instead of waiting for a constant amount of time for finalization + // TODO: for all the stages, adaptively tune timeout for merge finalization + logInfo(("%s (%s) scheduled for finalizing" + + " shuffle merge in %s s").format(stage, stage.name, shuffleMergeFinalizeWaitSec)) + shuffleMergeFinalizeScheduler.schedule( + new Runnable { + override def run(): Unit = finalizeShuffleMerge(stage) + }, + shuffleMergeFinalizeWaitSec, + TimeUnit.SECONDS + ) + } + + /** + * DAGScheduler notifies all the remote shuffle services chosen to serve shuffle merge request for + * the given shuffle map stage to finalize the shuffle merge process for this shuffle. This is + * invoked in a separate thread to reduce the impact on the DAGScheduler main thread, as the + * scheduler might need to talk to 1000s of shuffle services to finalize shuffle merge. 
+ */ + private[scheduler] def finalizeShuffleMerge(stage: ShuffleMapStage): Unit = { + logInfo("%s (%s) finalizing the shuffle merge".format(stage, stage.name)) + externalShuffleClient.foreach { shuffleClient => + val shuffleId = stage.shuffleDep.shuffleId + val numMergers = stage.shuffleDep.getMergerLocs.length + val results = (0 until numMergers).map(_ => SettableFuture.create[Boolean]()) + + stage.shuffleDep.getMergerLocs.zipWithIndex.foreach { + case (shuffleServiceLoc, index) => + // Sends async request to shuffle service to finalize shuffle merge on that host + // TODO: SPARK-35536: Cancel finalizeShuffleMerge if the stage is cancelled + // TODO: during shuffleMergeFinalizeWaitSec + shuffleClient.finalizeShuffleMerge(shuffleServiceLoc.host, + shuffleServiceLoc.port, shuffleId, + new MergeFinalizerListener { + override def onShuffleMergeSuccess(statuses: MergeStatuses): Unit = { + assert(shuffleId == statuses.shuffleId) + eventProcessLoop.post(RegisterMergeStatuses(stage, MergeStatus. + convertMergeStatusesToMergeStatusArr(statuses, shuffleServiceLoc))) + results(index).set(true) + } + + override def onShuffleMergeFailure(e: Throwable): Unit = { + logWarning(s"Exception encountered when trying to finalize shuffle " + + s"merge on ${shuffleServiceLoc.host} for shuffle $shuffleId", e) + // Do not fail the future as this would cause dag scheduler to prematurely + // give up on waiting for merge results from the remaining shuffle services + // if one fails + results(index).set(false) + } + }) + } + // DAGScheduler only waits for a limited amount of time for the merge results. + // It will attempt to submit the next stage(s) irrespective of whether merge results + // from all shuffle services are received or not. 
+ try { + Futures.allAsList(results: _*).get(shuffleMergeResultsTimeoutSec, TimeUnit.SECONDS) + } catch { + case _: TimeoutException => + logInfo(s"Timed out on waiting for merge results from all " + + s"$numMergers mergers for shuffle $shuffleId") + } finally { + eventProcessLoop.post(ShuffleMergeFinalized(stage)) + } + } + } + + private def processShuffleMapStageCompletion(shuffleStage: ShuffleMapStage): Unit = { + markStageAsFinished(shuffleStage) + logInfo("looking for newly runnable stages") + logInfo("running: " + runningStages) + logInfo("waiting: " + waitingStages) + logInfo("failed: " + failedStages) + + // This call to increment the epoch may not be strictly necessary, but it is retained + // for now in order to minimize the changes in behavior from an earlier version of the + // code. This existing behavior of always incrementing the epoch following any + // successful shuffle map stage completion may have benefits by causing unneeded + // cached map outputs to be cleaned up earlier on executors. In the future we can + // consider removing this call, but this will require some extra investigation. + // See https://github.com/apache/spark/pull/17955/files#r117385673 for more details. + mapOutputTracker.incrementEpoch() + + clearCacheLocs() + + if (!shuffleStage.isAvailable) { + // Some tasks had failed; let's resubmit this shuffleStage. + // TODO: Lower-level scheduler should also deal with this + logInfo("Resubmitting " + shuffleStage + " (" + shuffleStage.name + + ") because some of its tasks had failed: " + + shuffleStage.findMissingPartitions().mkString(", ")) + submitStage(shuffleStage) + } else { + markMapStageJobsAsFinished(shuffleStage) + submitWaitingChildStages(shuffleStage) + } + } + + private[scheduler] def handleRegisterMergeStatuses( + stage: ShuffleMapStage, + mergeStatuses: Seq[(Int, MergeStatus)]): Unit = { + // Register merge statuses if the stage is still running and shuffle merge is not finalized yet. 
+ // TODO: SPARK-35549: Currently merge statuses results which come after shuffle merge + // TODO: is finalized is not registered. + if (runningStages.contains(stage) && !stage.shuffleDep.shuffleMergeFinalized) { + mapOutputTracker.registerMergeResults(stage.shuffleDep.shuffleId, mergeStatuses) + } + } + + private[scheduler] def handleShuffleMergeFinalized(stage: ShuffleMapStage): Unit = { + // Only update MapOutputTracker metadata if the stage is still active. i.e not cancelled. + if (runningStages.contains(stage)) { + stage.shuffleDep.markShuffleMergeFinalized() + processShuffleMapStageCompletion(stage) + } else { + // Unregister all merge results if the stage is currently not + // active (i.e. the stage is cancelled) + mapOutputTracker.unregisterAllMergeResult(stage.shuffleDep.shuffleId) + } + } + private def handleResubmittedFailure(task: Task[_], stage: Stage): Unit = { logInfo(s"Resubmitted $task, so marking it as still running.") stage match { @@ -2447,6 +2592,12 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler case ResubmitFailedStages => dagScheduler.resubmitFailedStages() + + case RegisterMergeStatuses(stage, mergeStatuses) => + dagScheduler.handleRegisterMergeStatuses(stage, mergeStatuses) + + case ShuffleMergeFinalized(stage) => + dagScheduler.handleShuffleMergeFinalized(stage) } override def onError(e: Throwable): Unit = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index d226fe8..307844c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -105,3 +105,9 @@ private[scheduler] case class UnschedulableTaskSetRemoved(stageId: Int, stageAttemptId: Int) extends DAGSchedulerEvent +private[scheduler] case class RegisterMergeStatuses( + stage: ShuffleMapStage, mergeStatuses: Seq[(Int, MergeStatus)]) + 
extends DAGSchedulerEvent + +private[scheduler] case class ShuffleMergeFinalized(stage: ShuffleMapStage) + extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 556478d..7b681bf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -42,7 +42,7 @@ class StageInfo( val resourceProfileId: Int) { /** When this stage was submitted from the DAGScheduler to a TaskScheduler. */ var submissionTime: Option[Long] = None - /** Time when all tasks in the stage completed or when the stage was cancelled. */ + /** Time when the stage completed or when the stage was cancelled. */ var completionTime: Option[Long] = None /** If the stage failed, the reason why. */ var failureReason: Option[String] = None diff --git a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager index 60054c8..33b162e 100644 --- a/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager +++ b/core/src/test/resources/META-INF/services/org.apache.spark.scheduler.ExternalClusterManager @@ -1,3 +1,4 @@ org.apache.spark.scheduler.DummyExternalClusterManager org.apache.spark.scheduler.MockExternalClusterManager org.apache.spark.scheduler.CSMockExternalClusterManager +org.apache.spark.scheduler.PushBasedClusterManager diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 4c74e4f..f6e87ee 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -25,9 +25,8 @@ import scala.annotation.meta.param import 
scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} import scala.util.control.NonFatal -import org.mockito.Mockito.spy -import org.mockito.Mockito.times -import org.mockito.Mockito.verify +import org.mockito.Mockito._ +import org.roaringbitmap.RoaringBitmap import org.scalatest.concurrent.{Signaler, ThreadSignaler, TimeLimits} import org.scalatest.exceptions.TestFailedException import org.scalatest.time.SpanSugar._ @@ -40,9 +39,10 @@ import org.apache.spark.rdd.{DeterministicLevel, RDD} import org.apache.spark.resource.{ExecutorResourceRequests, ResourceProfile, ResourceProfileBuilder, TaskResourceRequests} import org.apache.spark.resource.ResourceUtils.{FPGA, GPU} import org.apache.spark.scheduler.SchedulingMode.SchedulingMode +import org.apache.spark.scheduler.local.LocalSchedulerBackend import org.apache.spark.shuffle.{FetchFailedException, MetadataFetchFailedException} import org.apache.spark.storage.{BlockId, BlockManagerId, BlockManagerMaster} -import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, LongAccumulator, Utils} +import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, CallSite, Clock, LongAccumulator, SystemClock, Utils} class DAGSchedulerEventProcessLoopTester(dagScheduler: DAGScheduler) extends DAGSchedulerEventProcessLoop(dagScheduler) { @@ -295,6 +295,35 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } } + class MyDAGScheduler( + sc: SparkContext, + taskScheduler: TaskScheduler, + listenerBus: LiveListenerBus, + mapOutputTracker: MapOutputTrackerMaster, + blockManagerMaster: BlockManagerMaster, + env: SparkEnv, + clock: Clock = new SystemClock(), + shuffleMergeFinalize: Boolean = true, + shuffleMergeRegister: Boolean = true + ) extends DAGScheduler( + sc, taskScheduler, listenerBus, mapOutputTracker, blockManagerMaster, env, clock) { + /** + * Schedules shuffle merge finalize. 
+ */ + override private[scheduler] def scheduleShuffleMergeFinalize( + shuffleMapStage: ShuffleMapStage): Unit = { + if (shuffleMergeRegister) { + for (part <- 0 until shuffleMapStage.shuffleDep.partitioner.numPartitions) { + val mergeStatuses = Seq((part, makeMergeStatus(""))) + handleRegisterMergeStatuses(shuffleMapStage, mergeStatuses) + } + if (shuffleMergeFinalize) { + handleShuffleMergeFinalized(shuffleMapStage) + } + } + } + } + override def beforeEach(): Unit = { super.beforeEach() firstInit = true @@ -322,13 +351,14 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti broadcastManager = new BroadcastManager(true, sc.getConf) mapOutputTracker = spy(new MyMapOutputTrackerMaster(sc.getConf, broadcastManager)) blockManagerMaster = spy(new MyBlockManagerMaster(sc.getConf)) - scheduler = new DAGScheduler( + scheduler = new MyDAGScheduler( sc, taskScheduler, sc.listenerBus, mapOutputTracker, blockManagerMaster, sc.env) + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) } @@ -3393,6 +3423,359 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti assert(rprofsE === Set()) } + private def initPushBasedShuffleConfs(conf: SparkConf) = { + conf.set(config.SHUFFLE_SERVICE_ENABLED, true) + conf.set(config.PUSH_BASED_SHUFFLE_ENABLED, true) + conf.set("spark.master", "pushbasedshuffleclustermanager") + } + + test("SPARK-32920: shuffle merge finalization") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 2 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + 
completeShuffleMapStageSuccessfully(0, 0, parts) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-32920: merger locations not empty") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 3) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 2 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + completeShuffleMapStageSuccessfully(0, 0, parts) + val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(shuffleStage.shuffleDep.getMergerLocs.nonEmpty) + + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts) + completeNextResultStageWithSuccess(1, 0) + assert(results === Map(0 -> 42, 1 -> 42)) + + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-32920: merger locations reuse from shuffle dependency") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_MAX_RETAINED_LOCATIONS, 3) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 2 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, Array(0, 1)) + + completeShuffleMapStageSuccessfully(0, 0, parts) + assert(shuffleDep.getMergerLocs.nonEmpty) + val mergerLocs = 
shuffleDep.getMergerLocs + completeNextResultStageWithSuccess(1, 0 ) + + // submit another job w/ the shared dependency, and have a fetch failure + val reduce2 = new MyRDD(sc, 2, List(shuffleDep)) + submit(reduce2, Array(0, 1)) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + assert(scheduler.runningStages.nonEmpty) + assert(scheduler.stageIdToStage(2) + .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs.nonEmpty) + val newMergerLocs = scheduler.stageIdToStage(2) + .asInstanceOf[ShuffleMapStage].shuffleDep.getMergerLocs + + // Check if same merger locs is reused for the new stage with shared shuffle dependency + assert(mergerLocs.zip(newMergerLocs).forall(x => x._1.host == x._2.host)) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + assert(results === Map(0 -> 1234, 1 -> 1235)) + + assertDataStructuresEmpty() + } + + test("SPARK-32920: Disable shuffle merge due to not enough mergers available") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 6) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 7 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + completeShuffleMapStageSuccessfully(0, 0, parts) + val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(!shuffleStage.shuffleDep.shuffleMergeEnabled) + + completeNextResultStageWithSuccess(1, 
0) + assert(results === Map(2 -> 42, 5 -> 42, 4 -> 42, 1 -> 42, 3 -> 42, 6 -> 42, 0 -> 42)) + + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-32920: Ensure child stage should not start before all the" + + " parent stages are completed with shuffle merge finalized for all the parent stages") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 1 + val shuffleMapRdd1 = new MyRDD(sc, parts, Nil) + val shuffleDep1 = new ShuffleDependency(shuffleMapRdd1, new HashPartitioner(parts)) + + val shuffleMapRdd2 = new MyRDD(sc, parts, Nil) + val shuffleDep2 = new ShuffleDependency(shuffleMapRdd2, new HashPartitioner(parts)) + + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep1, shuffleDep2), tracker = mapOutputTracker) + + // Submit a reduce job + submit(reduceRdd, (0 until parts).toArray) + + complete(taskSets(0), Seq((Success, makeMapStatus("hostA", 1)))) + val shuffleStage1 = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(shuffleStage1.shuffleDep.getMergerLocs.nonEmpty) + + complete(taskSets(1), Seq((Success, makeMapStatus("hostA", 1)))) + val shuffleStage2 = scheduler.stageIdToStage(1).asInstanceOf[ShuffleMapStage] + assert(shuffleStage2.shuffleDep.getMergerLocs.nonEmpty) + + assert(shuffleStage2.shuffleDep.shuffleMergeFinalized) + assert(shuffleStage1.shuffleDep.shuffleMergeFinalized) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep1.shuffleId) == parts) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep2.shuffleId) == parts) + + completeNextResultStageWithSuccess(2, 0) + assert(results === Map(0 -> 42)) + results.clear() + assertDataStructuresEmpty() + } + + test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for the corresponding" + + " ShuffleDependency should not cause DAGScheduler to hang") { + initPushBasedShuffleConfs(conf) + 
conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 20 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + val partitions = (0 until parts).toArray + submit(reduceRdd, partitions) + + completeShuffleMapStageSuccessfully(0, 0, parts) + val shuffleStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(!shuffleStage.shuffleDep.shuffleMergeEnabled) + + completeNextResultStageWithSuccess(1, 0) + val reduce2 = new MyRDD(sc, parts, List(shuffleDep)) + submit(reduce2, partitions) + // Stage 2 should not be executed as it should reuse the already computed shuffle output + assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics == null) + completeNextResultStageWithSuccess(3, 0, idx => idx + 1234) + + val expected = (0 until parts).map(idx => (idx, idx + 1234)) + assert(results === expected.toMap) + + assertDataStructuresEmpty() + } + + test("SPARK-32920: Reused ShuffleDependency with Shuffle Merge disabled for the corresponding" + + " ShuffleDependency with shuffle data loss should recompute missing partitions") { + initPushBasedShuffleConfs(conf) + conf.set(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD, 10) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + val parts = 20 + + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + val partitions = (0 until parts).toArray + submit(reduceRdd, partitions) + + completeShuffleMapStageSuccessfully(0, 0, parts) + val shuffleStage = 
scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(!shuffleStage.shuffleDep.shuffleMergeEnabled) + + completeNextResultStageWithSuccess(1, 0) + + DAGSchedulerSuite.clearMergerLocs + val hosts = (6 to parts).map {x => s"Host$x" } + DAGSchedulerSuite.addMergerLocs(hosts) + + val reduce2 = new MyRDD(sc, parts, List(shuffleDep)) + submit(reduce2, partitions) + // Note that the stage numbering here is only b/c the shared dependency produces a new, skipped + // stage. If instead it reused the existing stage, then this would be stage 2 + completeNextStageWithFetchFailure(3, 0, shuffleDep) + scheduler.resubmitFailedStages() + + // Make sure shuffle merge is disabled for the retry + val stage2 = scheduler.stageIdToStage(2).asInstanceOf[ShuffleMapStage] + assert(!stage2.shuffleDep.shuffleMergeEnabled) + + // the scheduler now creates a new task set to regenerate the missing map output, but this time + // using a different stage, the "skipped" one + assert(scheduler.stageIdToStage(2).latestInfo.taskMetrics != null) + completeShuffleMapStageSuccessfully(2, 0, 2) + completeNextResultStageWithSuccess(3, 1, idx => idx + 1234) + + val expected = (0 until parts).map(idx => (idx, idx + 1234)) + assert(results === expected.toMap) + assertDataStructuresEmpty() + } + + test("SPARK-32920: Empty RDD should not be computed") { + initPushBasedShuffleConfs(conf) + val data = sc.emptyRDD[Int] + data.sortBy(x => x).collect() + assert(taskSets.isEmpty) + assertDataStructuresEmpty() + } + + test("SPARK-32920: Merge results should be unregistered if the running stage is cancelled" + + " before shuffle merge is finalized") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + scheduler = new MyDAGScheduler( + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env, + shuffleMergeFinalize = false) + dagEventProcessLoopTester = new 
DAGSchedulerEventProcessLoopTester(scheduler) + + val parts = 2 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + // Complete shuffle map stage successfully on hostA + complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map { + case (task, _) => + (Success, makeMapStatus("hostA", parts)) + }.toSeq) + + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == parts) + val shuffleMapStageToCancel = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + runEvent(StageCancelled(0, Option("Explicit cancel check"))) + scheduler.handleShuffleMergeFinalized(shuffleMapStageToCancel) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 0) + } + + test("SPARK-32920: SPARK-35549: Merge results should not get registered" + + " after shuffle merge finalization") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + + scheduler = new MyDAGScheduler( + sc, + taskScheduler, + sc.listenerBus, + mapOutputTracker, + blockManagerMaster, + sc.env, + shuffleMergeFinalize = false, + shuffleMergeRegister = false) + dagEventProcessLoopTester = new DAGSchedulerEventProcessLoopTester(scheduler) + + val parts = 2 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + // Complete shuffle map stage successfully on hostA + complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map { + case 
(task, _) => + (Success, makeMapStatus("hostA", parts)) + }.toSeq) + val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((0, makeMergeStatus("hostA")))) + scheduler.handleShuffleMergeFinalized(shuffleMapStage) + scheduler.handleRegisterMergeStatuses(shuffleMapStage, Seq((1, makeMergeStatus("hostA")))) + assert(mapOutputTracker.getNumAvailableMergeResults(shuffleDep.shuffleId) == 1) + } + + test("SPARK-32920: Disable push based shuffle in the case of a barrier stage") { + initPushBasedShuffleConfs(conf) + DAGSchedulerSuite.clearMergerLocs + DAGSchedulerSuite.addMergerLocs(Seq("host1", "host2", "host3", "host4", "host5")) + + val parts = 2 + val shuffleMapRdd = new MyRDD(sc, parts, Nil).barrier().mapPartitions(iter => iter) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + + // Submit a reduce job that depends which will create a map stage + submit(reduceRdd, (0 until parts).toArray) + completeShuffleMapStageSuccessfully(0, 0, reduceRdd.partitions.length) + val shuffleMapStage = scheduler.stageIdToStage(0).asInstanceOf[ShuffleMapStage] + assert(!shuffleMapStage.shuffleDep.shuffleMergeEnabled) + } + + test("SPARK-32920: metadata fetch failure should not unregister map status") { + initPushBasedShuffleConfs(conf) + val parts = 2 + val shuffleMapRdd = new MyRDD(sc, parts, Nil) + val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(parts)) + + val reduceRdd = new MyRDD(sc, parts, List(shuffleDep), tracker = mapOutputTracker) + submit(reduceRdd, (0 until parts).toArray) + assert(taskSets.length == 1) + + // Complete shuffle map stage successfully on hostA + complete(taskSets(0), taskSets(0).tasks.zipWithIndex.map { + case (task, _) => + (Success, makeMapStatus("hostA", parts)) + }.toSeq) + + 
assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == parts) + + // Finish the first task + runEvent(makeCompletionEvent( + taskSets(1).tasks(0), Success, makeMapStatus("hostA", parts))) + + // The second task fails with Metadata Failed exception. + val metadataFetchFailedEx = new MetadataFetchFailedException( + shuffleDep.shuffleId, 1, "metadata failure"); + runEvent(makeCompletionEvent( + taskSets(1).tasks(1), metadataFetchFailedEx.toTaskFailedReason, null)) + assert(mapOutputTracker.getNumAvailableOutputs(shuffleDep.shuffleId) == parts) + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. @@ -3448,14 +3831,69 @@ class DAGSchedulerSuite extends SparkFunSuite with TempLocalSparkContext with Ti } object DAGSchedulerSuite { + val mergerLocs = ArrayBuffer[BlockManagerId]() + def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2, mapTaskId: Long = -1): MapStatus = MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes), mapTaskId) def makeBlockManagerId(host: String): BlockManagerId = { BlockManagerId(host + "-exec", host, 12345) } + + def makeMergeStatus(host: String, size: Long = 1000): MergeStatus = + MergeStatus(makeBlockManagerId(host), mock(classOf[RoaringBitmap]), size) + + def addMergerLocs(locs: Seq[String]): Unit = { + locs.foreach { loc => mergerLocs.append(makeBlockManagerId(loc)) } + } + + def clearMergerLocs: Unit = mergerLocs.clear() + } object FailThisAttempt { val _fail = new AtomicBoolean(true) } + +private class PushBasedSchedulerBackend( + conf: SparkConf, + scheduler: TaskSchedulerImpl, + cores: Int) extends LocalSchedulerBackend(conf, scheduler, cores) { + + override def getShufflePushMergerLocations( + numPartitions: Int, + resourceProfileId: Int): Seq[BlockManagerId] = { + val mergerLocations = Utils.randomize(DAGSchedulerSuite.mergerLocs).take(numPartitions) + if (mergerLocations.size < 
numPartitions && mergerLocations.size < + conf.getInt(config.SHUFFLE_MERGER_LOCATIONS_MIN_STATIC_THRESHOLD.key, 5)) { + Seq.empty[BlockManagerId] + } else { + mergerLocations + } + } + + // Currently this is only used in tests specifically for Push based shuffle + override def maxNumConcurrentTasks(rp: ResourceProfile): Int = { + 2 + } +} + +private class PushBasedClusterManager extends ExternalClusterManager { + def canCreate(masterURL: String): Boolean = masterURL == "pushbasedshuffleclustermanager" + + override def createSchedulerBackend( + sc: SparkContext, + masterURL: String, + scheduler: TaskScheduler): SchedulerBackend = { + new PushBasedSchedulerBackend(sc.conf, scheduler.asInstanceOf[TaskSchedulerImpl], 1) + } + + override def createTaskScheduler( + sc: SparkContext, + masterURL: String): TaskScheduler = new TaskSchedulerImpl(sc, 1, isLocal = true) + + override def initialize(scheduler: TaskScheduler, backend: SchedulerBackend): Unit = { + val sc = scheduler.asInstanceOf[TaskSchedulerImpl] + sc.initialize(backend) + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org