mridulm commented on code in PR #37922: URL: https://github.com/apache/spark/pull/37922#discussion_r1067941354
########## core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala: ########## @@ -321,6 +321,12 @@ class BlockManagerMasterEndpoint( } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { + val mergerLocations = + if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) { + mapOutputTracker.getShufflePushMergerLocations(shuffleId) + } else { + Seq.empty[BlockManagerId] + } Review Comment: `RemoveShuffle` is sent to the storage endpoint - not to the block manager. You can test it with this diff: ``` diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index 9f03228bb4f..e04ef7f11db 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -321,12 +321,6 @@ class BlockManagerMasterEndpoint( } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { - val mergerLocations = - if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) { - mapOutputTracker.getShufflePushMergerLocations(shuffleId) - } else { - Seq.empty[BlockManagerId] - } val removeMsg = RemoveShuffle(shuffleId) val removeShuffleFromExecutorsFutures = blockManagerInfo.values.map { bm => bm.storageEndpoint.ask[Boolean](removeMsg).recover { @@ -374,6 +368,13 @@ class BlockManagerMasterEndpoint( val removeShuffleMergeFromShuffleServicesFutures = externalBlockStoreClient.map { shuffleClient => + val mergerLocations = { + if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) { + mapOutputTracker.getShufflePushMergerLocations(shuffleId) + } else { + Seq.empty[BlockManagerId] + } + } mergerLocations.map { bmId => Future[Boolean] { shuffleClient.removeShuffleMerge(bmId.host, bmId.port, shuffleId, diff --git a/core/src/test/resources/log4j2.properties b/core/src/test/resources/log4j2.properties index ab02104c696..b906523789d 100644 --- a/core/src/test/resources/log4j2.properties +++ b/core/src/test/resources/log4j2.properties @@ -16,7 +16,7 @@ # # Set everything to be logged to the file target/unit-tests.log -rootLogger.level = info +rootLogger.level = debug rootLogger.appenderRef.file.ref = ${sys:test.appender:-File} appender.file.type = File diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index a13527f4b74..fb9dc8ff29b 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,9 +17,12 @@ package org.apache.spark +import java.util.{Collections => JCollections, HashSet => JHashSet} import java.util.concurrent.atomic.LongAdder +import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer +import scala.collection.concurrent import org.mockito.ArgumentMatchers.any import org.mockito.Mockito._ @@ -30,10 +33,12 @@ import org.apache.spark.broadcast.BroadcastManager import org.apache.spark.internal.config._ import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE} import org.apache.spark.internal.config.Tests.IS_TESTING +import org.apache.spark.network.shuffle.ExternalBlockStoreClient import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv} import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus, MapStatus, MergeStatus} import org.apache.spark.shuffle.FetchFailedException -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} +import org.apache.spark.storage.{BlockManagerId, BlockManagerInfo, BlockManagerMaster, BlockManagerMasterEndpoint, ShuffleBlockId, ShuffleMergedBlockId} +import org.mockito.invocation.InvocationOnMock class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { private val conf = new SparkConf @@ -913,9 +918,64 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext { slaveRpcEnv.shutdown() } + test("SPARK-40480: shuffle remove should cleanup merged files as well") { + + val newConf = new SparkConf + newConf.set("spark.shuffle.push.enabled", "true") + newConf.set("spark.shuffle.service.enabled", "true") + newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer") + newConf.set(IS_TESTING, true) + + val SHUFFLE_ID = 10 + + // needs TorrentBroadcast so need a SparkContext + withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc => + val rpcEnv = sc.env.rpcEnv + val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster] + + val blockStoreClient = mock(classOf[ExternalBlockStoreClient]) + val blockManagerMasterEndpoint = new BlockManagerMasterEndpoint( + rpcEnv, + sc.isLocal, + sc.conf, + sc.listenerBus, + Some(blockStoreClient), + // We dont care about this ... + new concurrent.TrieMap[BlockManagerId, BlockManagerInfo](), + masterTracker, + sc.env.shuffleManager, + true + ) + rpcEnv.stop(sc.env.blockManager.master.driverEndpoint) + sc.env.blockManager.master.driverEndpoint = + rpcEnv.setupEndpoint(BlockManagerMaster.DRIVER_ENDPOINT_NAME, + blockManagerMasterEndpoint) + + masterTracker.registerShuffle(SHUFFLE_ID, 10, 10) + val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x", s"host-$x", x)) + masterTracker.registerShufflePushMergerLocations(SHUFFLE_ID, mergerLocs) + + assert(masterTracker.getShufflePushMergerLocations(SHUFFLE_ID).map(_.host).toSet == + mergerLocs.map(_.host).toSet) + + val foundHosts = JCollections.synchronizedSet(new JHashSet[String]()) + when(blockStoreClient.removeShuffleMerge(any(), any(), any(), any())).thenAnswer( + (m: InvocationOnMock) => { + val host = m.getArgument(0).asInstanceOf[String] + val shuffleId = m.getArgument(2).asInstanceOf[Int] + assert(shuffleId == SHUFFLE_ID) + foundHosts.add(host) + true + }) + + sc.cleaner.get.doCleanupShuffle(SHUFFLE_ID, blocking = true) + assert(foundHosts.asScala == mergerLocs.map(_.host).toSet) + } + } + test("SPARK-34826: Adaptive shuffle mergers") { val newConf = new SparkConf - newConf.set("spark.shuffle.push.based.enabled", "true") + newConf.set("spark.shuffle.push.enabled", "true") newConf.set("spark.shuffle.service.enabled", "true") // needs TorrentBroadcast so need a SparkContext ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org