mridulm commented on code in PR #37922:
URL: https://github.com/apache/spark/pull/37922#discussion_r1067941354


##########
core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala:
##########
@@ -321,6 +321,12 @@ class BlockManagerMasterEndpoint(
   }
 
   private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
+    val mergerLocations =
+      if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
+        mapOutputTracker.getShufflePushMergerLocations(shuffleId)
+      } else {
+        Seq.empty[BlockManagerId]
+      }

Review Comment:
   `RemoveShuffle` is sent to the storage endpoint, not to the block manager.
   
   You can test it with this diff:
   
   ```
   diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
   index 9f03228bb4f..e04ef7f11db 100644
   --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
   +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala
   @@ -321,12 +321,6 @@ class BlockManagerMasterEndpoint(
      }
    
      private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = {
   -    val mergerLocations =
   -      if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
   -        mapOutputTracker.getShufflePushMergerLocations(shuffleId)
   -      } else {
   -        Seq.empty[BlockManagerId]
   -      }
        val removeMsg = RemoveShuffle(shuffleId)
        val removeShuffleFromExecutorsFutures = blockManagerInfo.values.map { bm =>
          bm.storageEndpoint.ask[Boolean](removeMsg).recover {
   @@ -374,6 +368,13 @@ class BlockManagerMasterEndpoint(
    
        val removeShuffleMergeFromShuffleServicesFutures =
          externalBlockStoreClient.map { shuffleClient =>
   +        val mergerLocations = {
   +          if (Utils.isPushBasedShuffleEnabled(conf, isDriver)) {
   +            mapOutputTracker.getShufflePushMergerLocations(shuffleId)
   +          } else {
   +            Seq.empty[BlockManagerId]
   +          }
   +        }
            mergerLocations.map { bmId =>
              Future[Boolean] {
            shuffleClient.removeShuffleMerge(bmId.host, bmId.port, shuffleId,
   diff --git a/core/src/test/resources/log4j2.properties b/core/src/test/resources/log4j2.properties
   index ab02104c696..b906523789d 100644
   --- a/core/src/test/resources/log4j2.properties
   +++ b/core/src/test/resources/log4j2.properties
   @@ -16,7 +16,7 @@
    #
    
    # Set everything to be logged to the file target/unit-tests.log
   -rootLogger.level = info
   +rootLogger.level = debug
    rootLogger.appenderRef.file.ref = ${sys:test.appender:-File}
    
    appender.file.type = File
   diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
   index a13527f4b74..fb9dc8ff29b 100644
   --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
   +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
   @@ -17,9 +17,12 @@
    
    package org.apache.spark
    
   +import java.util.{Collections => JCollections, HashSet => JHashSet}
    import java.util.concurrent.atomic.LongAdder
    
   +import scala.collection.JavaConverters._
    import scala.collection.mutable.ArrayBuffer
   +import scala.collection.concurrent
    
    import org.mockito.ArgumentMatchers.any
    import org.mockito.Mockito._
   @@ -30,10 +33,12 @@ import org.apache.spark.broadcast.BroadcastManager
    import org.apache.spark.internal.config._
    import org.apache.spark.internal.config.Network.{RPC_ASK_TIMEOUT, RPC_MESSAGE_MAX_SIZE}
    import org.apache.spark.internal.config.Tests.IS_TESTING
   +import org.apache.spark.network.shuffle.ExternalBlockStoreClient
    import org.apache.spark.rpc.{RpcAddress, RpcCallContext, RpcEnv}
    import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus, MapStatus, MergeStatus}
    import org.apache.spark.shuffle.FetchFailedException
   -import org.apache.spark.storage.{BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId}
   +import org.apache.spark.storage.{BlockManagerId, BlockManagerInfo, BlockManagerMaster, BlockManagerMasterEndpoint, ShuffleBlockId, ShuffleMergedBlockId}
   +import org.mockito.invocation.InvocationOnMock
    
    class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
      private val conf = new SparkConf
   @@ -913,9 +918,64 @@ class MapOutputTrackerSuite extends SparkFunSuite with LocalSparkContext {
        slaveRpcEnv.shutdown()
      }
    
   +  test("SPARK-40480: shuffle remove should cleanup merged files as well") {
   +
   +    val newConf = new SparkConf
   +    newConf.set("spark.shuffle.push.enabled", "true")
   +    newConf.set("spark.shuffle.service.enabled", "true")
   +    newConf.set(SERIALIZER, "org.apache.spark.serializer.KryoSerializer")
   +    newConf.set(IS_TESTING, true)
   +
   +    val SHUFFLE_ID = 10
   +
   +    // needs TorrentBroadcast so need a SparkContext
   +    withSpark(new SparkContext("local", "MapOutputTrackerSuite", newConf)) { sc =>
   +      val rpcEnv = sc.env.rpcEnv
   +      val masterTracker = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
   +
   +      val blockStoreClient = mock(classOf[ExternalBlockStoreClient])
   +      val blockManagerMasterEndpoint = new BlockManagerMasterEndpoint(
   +        rpcEnv,
   +        sc.isLocal,
   +        sc.conf,
   +        sc.listenerBus,
   +        Some(blockStoreClient),
   +        // We don't care about this ...
   +        new concurrent.TrieMap[BlockManagerId, BlockManagerInfo](),
   +        masterTracker,
   +        sc.env.shuffleManager,
   +        true
   +      )
   +      rpcEnv.stop(sc.env.blockManager.master.driverEndpoint)
   +      sc.env.blockManager.master.driverEndpoint =
   +        rpcEnv.setupEndpoint(BlockManagerMaster.DRIVER_ENDPOINT_NAME,
   +          blockManagerMasterEndpoint)
   +
   +      masterTracker.registerShuffle(SHUFFLE_ID, 10, 10)
   +      val mergerLocs = (1 to 10).map(x => BlockManagerId(s"exec-$x", s"host-$x", x))
   +      masterTracker.registerShufflePushMergerLocations(SHUFFLE_ID, mergerLocs)
   +
   +      assert(masterTracker.getShufflePushMergerLocations(SHUFFLE_ID).map(_.host).toSet ==
   +        mergerLocs.map(_.host).toSet)
   +
   +      val foundHosts = JCollections.synchronizedSet(new JHashSet[String]())
   +      when(blockStoreClient.removeShuffleMerge(any(), any(), any(), any())).thenAnswer(
   +        (m: InvocationOnMock) => {
   +          val host = m.getArgument(0).asInstanceOf[String]
   +          val shuffleId = m.getArgument(2).asInstanceOf[Int]
   +          assert(shuffleId == SHUFFLE_ID)
   +          foundHosts.add(host)
   +          true
   +      })
   +
   +      sc.cleaner.get.doCleanupShuffle(SHUFFLE_ID, blocking = true)
   +      assert(foundHosts.asScala == mergerLocs.map(_.host).toSet)
   +    }
   +  }
   +
      test("SPARK-34826: Adaptive shuffle mergers") {
        val newConf = new SparkConf
   -    newConf.set("spark.shuffle.push.based.enabled", "true")
   +    newConf.set("spark.shuffle.push.enabled", "true")
        newConf.set("spark.shuffle.service.enabled", "true")
    
        // needs TorrentBroadcast so need a SparkContext
   
   ```
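   
   If it helps, the new test can be run on its own with the usual sbt single-test invocation (the `-z` filter matches on a substring of the test name):
   
   ```
   build/sbt "core/testOnly org.apache.spark.MapOutputTrackerSuite -- -z SPARK-40480"
   ```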
   
   


