Ngone51 commented on a change in pull request #30312:
URL: https://github.com/apache/spark/pull/30312#discussion_r530141471



##########
File path: core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
##########
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleBlockPusher._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
+ * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
+ * file and initiates the block push process.
+ *
+ * @param dataFile         mapper generated shuffle data file
+ * @param partitionLengths array of shuffle block size so we can tell shuffle block
+ *                         boundaries within the shuffle file
+ * @param dep              shuffle dependency to get shuffle ID and the location of remote shuffle
+ *                         services to push local shuffle blocks
+ * @param partitionId      map index of the shuffle map task
+ * @param conf             spark configuration
+ */
+@Since("3.1.0")
+private[spark] class ShuffleBlockPusher(
+    dataFile: File,
+    partitionLengths: Array[Long],
+    dep: ShuffleDependency[_, _, _],
+    partitionId: Int,
+    conf: SparkConf) extends Logging {

Review comment:
       Passing these fields to `initiateBlockPush()` should be enough, rather than keeping them as constructor parameters?
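       For example, something roughly like this (an untested sketch of the suggested change; it moves the per-map-task inputs onto the method and keeps, say, `conf` on the constructor):

```scala
private[spark] class ShuffleBlockPusher(conf: SparkConf) extends Logging {
  // ...same in-flight bookkeeping state as today...

  // Hypothetical signature: the per-map-task inputs become method parameters.
  private[shuffle] def initiateBlockPush(
      dataFile: File,
      partitionLengths: Array[Long],
      dep: ShuffleDependency[_, _, _],
      partitionId: Int): Unit = {
    // same body as the current initiateBlockPush(), reading from the parameters
  }
}
```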

##########
File path: core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage._
+
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: 
ShuffleDependency[Int, Int, Int] = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: 
BlockStoreClient = _
+
+  private var conf: SparkConf = _
+  private var pushedBlocks = new ArrayBuffer[String]
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf(loadDefaults = false)
+    MockitoAnnotations.initMocks(this)
+    when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+    when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+    
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", 
"test-client", 1)))
+    conf.set("spark.shuffle.push.based.enabled", "true")
+    conf.set("spark.shuffle.service.enabled", "true")
+    // Set the env because the shuffle writer gets the shuffle client instance from the env.
+    val mockEnv = mock(classOf[SparkEnv])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.blockManager).thenReturn(blockManager)
+    SparkEnv.set(mockEnv)
+    when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+  }
+
+  override def afterEach(): Unit = {
+    pushedBlocks.clear()
+    super.afterEach()
+  }
+
+  private def interceptPushedBlocksForSuccess(): Unit = {
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = 
invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+          blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+        })
+      })
+  }
+
+  test("Basic block push") {
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())

Review comment:
       Could you also add tests for `prepareBlockPushRequests()`? That way we can check the generated requests directly, in addition to verifying the number of `pushBlocks()` invocations.
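       For example, a direct test could look roughly like this (untested sketch; it assumes `prepareBlockPushRequests` is reachable from the suite, e.g. `private[shuffle]`, that it returns the prepared requests, and that `SparkTransportConf` is imported):

```scala
test("prepareBlockPushRequests groups every block into exactly one request") {
  val pusher = new ShuffleBlockPusher(mock(classOf[File]),
    Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
  val mergerLocs = dependency.getMergerLocs.map(loc => BlockManagerId("", loc.host, loc.port))
  val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
  val requests = pusher.prepareBlockPushRequests(
    dependency.partitioner.numPartitions, 0, dependency.shuffleId, mock(classOf[File]),
    Array.fill(dependency.partitioner.numPartitions) { 2 }, mergerLocs, transportConf,
    1024 * 1024, 4 * 1024 * 1024)
  // Every shuffle block should be assigned to exactly one request.
  assert(requests.map(_.blocks.size).sum == dependency.partitioner.numPartitions)
}
```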

##########
File path: core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
##########
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleBlockPusher._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
+ * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
+ * file and initiates the block push process.
+ *
+ * @param dataFile         mapper generated shuffle data file
+ * @param partitionLengths array of shuffle block size so we can tell shuffle block
+ *                         boundaries within the shuffle file
+ * @param dep              shuffle dependency to get shuffle ID and the location of remote shuffle
+ *                         services to push local shuffle blocks
+ * @param partitionId      map index of the shuffle map task
+ * @param conf             spark configuration
+ */
+@Since("3.1.0")
+private[spark] class ShuffleBlockPusher(
+    dataFile: File,
+    partitionLengths: Array[Long],
+    dep: ShuffleDependency[_, _, _],
+    partitionId: Int,
+    conf: SparkConf) extends Logging {
+  private[this] var maxBytesInFlight = 0L
+  private[this] var maxReqsInFlight = 0
+  private[this] var maxBlocksInFlightPerAddress = 0
+  private[this] var bytesInFlight = 0L
+  private[this] var reqsInFlight = 0
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, 
Int]()
+  private[this] val deferredPushRequests = new HashMap[BlockManagerId, 
Queue[PushRequest]]()
+  private[this] val pushRequests = new Queue[PushRequest]
+  private[this] val errorHandler = createErrorHandler()
+  private[this] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+
+  // VisibleForTesting
+  private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
+    new BlockPushErrorHandler() {
+      // For a connection exception against a particular host, we will stop 
pushing any
+      // blocks to just that host and continue push blocks to other hosts. So, 
here push of
+      // all blocks will only stop when it is "Too Late". Also see 
updateStateAndCheckIfPushMore.
+      override def shouldRetryError(t: Throwable): Boolean = {
+        // If the block is too late, there is no need to retry it
+        
!Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+      }
+    }
+  }
+
+  private[shuffle] def initiateBlockPush(): Unit = {
+    val numPartitions = dep.partitioner.numPartitions
+    val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+
+    val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH) * 1024
+    val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH) * 
1024 * 1024
+    val mergerLocs = dep.getMergerLocs.map(loc => BlockManagerId("", loc.host, 
loc.port))
+
+    maxBytesInFlight = conf.getSizeAsMb("spark.reducer.maxSizeInFlight", 
"48m") * 1024 * 1024
+    maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", 
Int.MaxValue)
+    maxBlocksInFlightPerAddress = 
conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+
+    val requests = prepareBlockPushRequests(numPartitions, partitionId, 
dep.shuffleId, dataFile,
+      partitionLengths, mergerLocs, transportConf, maxBlockSizeToPush, 
maxBlockBatchSize)
+    // Randomize the orders of the PushRequest, so different mappers pushing 
blocks at the same
+    // time won't be pushing the same ranges of shuffle partitions.
+    pushRequests ++= Utils.randomize(requests)
+
+    submitTask(() => {
+      pushUpToMax()
+    })
+  }
+
+  /**
+   * Triggers the push. It's a separate method for testing.
+   * VisibleForTesting
+   */
+  protected def submitTask(task: Runnable): Unit = {
+    if (BLOCK_PUSHER_POOL != null) {
+      BLOCK_PUSHER_POOL.execute(task)
+    }
+  }
+
+  /**
+   * Since multiple netty client threads could potentially be calling 
pushUpToMax for the same
+   * mapper, we synchronize access to this method so that only one thread can 
push blocks for
+   * a given mapper. This helps to simplify access to the shared states. The 
down side of this
+   * is that we could unnecessarily block other mappers' block pushes if all 
netty client threads
+   * are occupied by block pushes from the same mapper.
+   *
+   * This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in 
how it throttles
+   * the data transfer between shuffle client/server.
+   */
+  private def pushUpToMax(): Unit = synchronized {
+    // Process any outstanding deferred push requests if possible.
+    if (deferredPushRequests.nonEmpty) {
+      for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
+        while (isRemoteBlockPushable(defReqQueue) &&
+          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+          val request = defReqQueue.dequeue()
+          logDebug(s"Processing deferred push request for $remoteAddress with "
+            + s"${request.blocks.length} blocks")
+          sendRequest(request)
+          if (defReqQueue.isEmpty) {
+            deferredPushRequests -= remoteAddress
+          }
+        }
+      }
+    }
+
+    // Process any regular push requests if possible.
+    while (isRemoteBlockPushable(pushRequests)) {
+      val request = pushRequests.dequeue()
+      val remoteAddress = request.address
+      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+        logDebug(s"Deferring push request for $remoteAddress with 
${request.blocks.size} blocks")
+        val defReqQueue = deferredPushRequests.getOrElse(remoteAddress, new 
Queue[PushRequest]())
+        defReqQueue.enqueue(request)
+        deferredPushRequests(remoteAddress) = defReqQueue
+      } else {
+        sendRequest(request)
+      }
+    }
+
+    def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
+      pushReqQueue.nonEmpty &&
+        (bytesInFlight == 0 ||
+          (reqsInFlight + 1 <= maxReqsInFlight &&
+            bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
+    }
+
+    // Checks if sending a new push request will exceed the max no. of blocks 
being pushed to a
+    // given remote address.
+    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: 
PushRequest): Boolean = {
+      (numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+        + request.blocks.size) > maxBlocksInFlightPerAddress
+    }
+  }
+
+  /**
+   * Push blocks to remote shuffle server. The callback listener will invoke 
#pushUpToMax again
+   * to trigger pushing the next batch of blocks once some block transfer is 
done in the current
+   * batch. This way, we decouple the map task from the block push process, 
since it is netty
+   * client thread instead of task execution thread which takes care of 
majority of the block
+   * pushes.
+   */
+  private def sendRequest(request: PushRequest): Unit = {
+    bytesInFlight = bytesInFlight + request.size
+    reqsInFlight = reqsInFlight + 1
+    numBlocksInFlightPerAddress(request.address) = 
numBlocksInFlightPerAddress.getOrElseUpdate(
+      request.address, 0) + request.blocks.length
+
+    val sizeMap = request.blocks.map { case (blockId, size) => 
(blockId.toString, size) }.toMap
+    val address = request.address
+    val blockIds = request.blocks.map(_._1.toString)
+    val remainingBlocks = new HashSet[String]() ++= blockIds
+
+    val blockPushListener = new BlockFetchingListener {
+      // Initiating a connection and pushing blocks to a remote shuffle 
service is always handled by
+      // the block-push-threads. We should not initiate the connection 
creation in the
+      // blockPushListener callbacks which are invoked by the netty eventloop 
because:
+      // 1. TransportClient.createConnection(...) blocks until the connection is established and it's
+      // recommended to avoid any blocking operations in the eventloop;
+      // 2. The actual connection creation is a task that gets added to the 
task queue of another
+      // eventloop which could have eventloops eventually blocking each other.
+      // Once the blockPushListener is notified of the block push success or 
failure, we
+      // just delegate it to block-push-threads.
+      def handleResult(result: PushResult): Unit = {
+        submitTask(() => {
+          if (updateStateAndCheckIfPushMore(
+            sizeMap(result.blockId), address, remainingBlocks, result)) {
+            pushUpToMax()
+          }
+        })
+      }
+
+      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): 
Unit = {
+        logTrace(s"Push for block $blockId to $address successful.")
+        handleResult(PushResult(blockId, null))
+      }
+
+      override def onBlockFetchFailure(blockId: String, exception: Throwable): 
Unit = {
+        // Check the message or its cause to see whether it needs to be logged.
+        if (!errorHandler.shouldLogError(exception)) {
+          logTrace(s"Pushing block $blockId to $address failed.", exception)
+        } else {
+          logWarning(s"Pushing block $blockId to $address failed.", exception)
+        }
+        handleResult(PushResult(blockId, exception))
+      }
+    }
+    SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
+      address.host, address.port, blockIds.toArray,
+      sliceReqBufferIntoBlockBuffers(request.reqBuffer, 
request.blocks.map(_._2)),
+      blockPushListener)
+  }
+
+  /**
+   * Given the ManagedBuffer representing all the continuous blocks inside the 
shuffle data file
+   * for a PushRequest and an array of individual block sizes, load the buffer 
from disk into
+   * memory and slice it into multiple smaller buffers representing each block.
+   *
+   * With nio ByteBuffer, the individual block buffers share data with the 
initial in memory
+   * buffer loaded from disk. Thus only one copy of the block data is kept in 
memory.
+   * @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the 
continuous blocks in
+   *                  the shuffle data file for a PushRequest
+   * @param blockSizes Array of block sizes
+   * @return Array of in memory buffer for each individual block
+   */
+  private def sliceReqBufferIntoBlockBuffers(
+      reqBuffer: ManagedBuffer,
+      blockSizes: Seq[Long]): Array[ManagedBuffer] = {
+    if (blockSizes.size == 1) {
+      Array(reqBuffer)
+    } else {
+      val inMemoryBuffer = reqBuffer.nioByteBuffer()
+      val blockOffsets = new Array[Long](blockSizes.size)
+      var offset = 0L
+      for (index <- blockSizes.indices) {
+        blockOffsets(index) = offset
+        offset += blockSizes(index)
+      }
+      blockOffsets.zip(blockSizes).map {
+        case (offset, size) =>
+          new NioManagedBuffer(inMemoryBuffer.duplicate()

Review comment:
       Why do we need `duplicate()` here? I thought we could apply `slice()` to `inMemoryBuffer` directly.
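       To make the question concrete, the `duplicate()`-free version would look roughly like this (sketch only; it relies on the offsets being processed in increasing order and it moves the shared buffer's position/limit as a side effect):

```scala
blockOffsets.zip(blockSizes).map { case (offset, size) =>
  // Mutates inMemoryBuffer's cursor instead of working on an independent duplicate.
  inMemoryBuffer.position(offset.toInt)
  inMemoryBuffer.limit((offset + size).toInt)
  new NioManagedBuffer(inMemoryBuffer.slice())
}.toArray
```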

##########
File path: core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala
##########
@@ -0,0 +1,462 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+import java.util.concurrent.ExecutorService
+
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+
+import com.google.common.base.Throwables
+
+import org.apache.spark.{ShuffleDependency, SparkConf, SparkEnv}
+import org.apache.spark.annotation.Since
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config._
+import org.apache.spark.launcher.SparkLauncher
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.shuffle.BlockFetchingListener
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleBlockPusher._
+import org.apache.spark.storage.{BlockId, BlockManagerId, ShufflePushBlockId}
+import org.apache.spark.util.{ThreadUtils, Utils}
+
+/**
+ * Used for pushing shuffle blocks to remote shuffle services when push shuffle is enabled.
+ * When push shuffle is enabled, it is created after the shuffle writer finishes writing the shuffle
+ * file and initiates the block push process.
+ *
+ * @param dataFile         mapper generated shuffle data file
+ * @param partitionLengths array of shuffle block size so we can tell shuffle block
+ *                         boundaries within the shuffle file
+ * @param dep              shuffle dependency to get shuffle ID and the location of remote shuffle
+ *                         services to push local shuffle blocks
+ * @param partitionId      map index of the shuffle map task
+ * @param conf             spark configuration
+ */
+@Since("3.1.0")
+private[spark] class ShuffleBlockPusher(
+    dataFile: File,
+    partitionLengths: Array[Long],
+    dep: ShuffleDependency[_, _, _],
+    partitionId: Int,
+    conf: SparkConf) extends Logging {
+  private[this] var maxBytesInFlight = 0L
+  private[this] var maxReqsInFlight = 0
+  private[this] var maxBlocksInFlightPerAddress = 0
+  private[this] var bytesInFlight = 0L
+  private[this] var reqsInFlight = 0
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, 
Int]()
+  private[this] val deferredPushRequests = new HashMap[BlockManagerId, 
Queue[PushRequest]]()
+  private[this] val pushRequests = new Queue[PushRequest]
+  private[this] val errorHandler = createErrorHandler()
+  private[this] val unreachableBlockMgrs = new HashSet[BlockManagerId]()
+
+  // VisibleForTesting
+  private[shuffle] def createErrorHandler(): BlockPushErrorHandler = {
+    new BlockPushErrorHandler() {
+      // For a connection exception against a particular host, we will stop 
pushing any
+      // blocks to just that host and continue push blocks to other hosts. So, 
here push of
+      // all blocks will only stop when it is "Too Late". Also see 
updateStateAndCheckIfPushMore.
+      override def shouldRetryError(t: Throwable): Boolean = {
+        // If the block is too late, there is no need to retry it
+        
!Throwables.getStackTraceAsString(t).contains(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX)
+      }
+    }
+  }
+
+  private[shuffle] def initiateBlockPush(): Unit = {
+    val numPartitions = dep.partitioner.numPartitions
+    val transportConf = SparkTransportConf.fromSparkConf(conf, "shuffle")
+
+    val maxBlockSizeToPush = conf.get(SHUFFLE_MAX_BLOCK_SIZE_TO_PUSH) * 1024
+    val maxBlockBatchSize = conf.get(SHUFFLE_MAX_BLOCK_BATCH_SIZE_FOR_PUSH) * 
1024 * 1024
+    val mergerLocs = dep.getMergerLocs.map(loc => BlockManagerId("", loc.host, 
loc.port))
+
+    maxBytesInFlight = conf.getSizeAsMb("spark.reducer.maxSizeInFlight", 
"48m") * 1024 * 1024
+    maxReqsInFlight = conf.getInt("spark.reducer.maxReqsInFlight", 
Int.MaxValue)
+    maxBlocksInFlightPerAddress = 
conf.get(REDUCER_MAX_BLOCKS_IN_FLIGHT_PER_ADDRESS)
+
+    val requests = prepareBlockPushRequests(numPartitions, partitionId, 
dep.shuffleId, dataFile,
+      partitionLengths, mergerLocs, transportConf, maxBlockSizeToPush, 
maxBlockBatchSize)
+    // Randomize the orders of the PushRequest, so different mappers pushing 
blocks at the same
+    // time won't be pushing the same ranges of shuffle partitions.
+    pushRequests ++= Utils.randomize(requests)
+
+    submitTask(() => {
+      pushUpToMax()
+    })
+  }
+
+  /**
+   * Triggers the push. It's a separate method for testing.
+   * VisibleForTesting
+   */
+  protected def submitTask(task: Runnable): Unit = {
+    if (BLOCK_PUSHER_POOL != null) {
+      BLOCK_PUSHER_POOL.execute(task)
+    }
+  }
+
+  /**
+   * Since multiple netty client threads could potentially be calling 
pushUpToMax for the same
+   * mapper, we synchronize access to this method so that only one thread can 
push blocks for
+   * a given mapper. This helps to simplify access to the shared states. The 
down side of this
+   * is that we could unnecessarily block other mappers' block pushes if all 
netty client threads
+   * are occupied by block pushes from the same mapper.
+   *
+   * This code is similar to ShuffleBlockFetcherIterator#fetchUpToMaxBytes in 
how it throttles
+   * the data transfer between shuffle client/server.
+   */
+  private def pushUpToMax(): Unit = synchronized {
+    // Process any outstanding deferred push requests if possible.
+    if (deferredPushRequests.nonEmpty) {
+      for ((remoteAddress, defReqQueue) <- deferredPushRequests) {
+        while (isRemoteBlockPushable(defReqQueue) &&
+          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)) {
+          val request = defReqQueue.dequeue()
+          logDebug(s"Processing deferred push request for $remoteAddress with "
+            + s"${request.blocks.length} blocks")
+          sendRequest(request)
+          if (defReqQueue.isEmpty) {
+            deferredPushRequests -= remoteAddress
+          }
+        }
+      }
+    }
+
+    // Process any regular push requests if possible.
+    while (isRemoteBlockPushable(pushRequests)) {
+      val request = pushRequests.dequeue()
+      val remoteAddress = request.address
+      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+        logDebug(s"Deferring push request for $remoteAddress with 
${request.blocks.size} blocks")
+        val defReqQueue = deferredPushRequests.getOrElse(remoteAddress, new 
Queue[PushRequest]())
+        defReqQueue.enqueue(request)
+        deferredPushRequests(remoteAddress) = defReqQueue
+      } else {
+        sendRequest(request)
+      }
+    }
+
+    def isRemoteBlockPushable(pushReqQueue: Queue[PushRequest]): Boolean = {
+      pushReqQueue.nonEmpty &&
+        (bytesInFlight == 0 ||
+          (reqsInFlight + 1 <= maxReqsInFlight &&
+            bytesInFlight + pushReqQueue.front.size <= maxBytesInFlight))
+    }
+
+    // Checks if sending a new push request will exceed the max no. of blocks 
being pushed to a
+    // given remote address.
+    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: 
PushRequest): Boolean = {
+      (numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0)
+        + request.blocks.size) > maxBlocksInFlightPerAddress
+    }
+  }
+
+  /**
+   * Push blocks to remote shuffle server. The callback listener will invoke 
#pushUpToMax again
+   * to trigger pushing the next batch of blocks once some block transfer is 
done in the current
+   * batch. This way, we decouple the map task from the block push process, 
since it is netty
+   * client thread instead of task execution thread which takes care of 
majority of the block
+   * pushes.
+   */
+  private def sendRequest(request: PushRequest): Unit = {
+    bytesInFlight = bytesInFlight + request.size
+    reqsInFlight = reqsInFlight + 1
+    numBlocksInFlightPerAddress(request.address) = 
numBlocksInFlightPerAddress.getOrElseUpdate(
+      request.address, 0) + request.blocks.length
+
+    val sizeMap = request.blocks.map { case (blockId, size) => 
(blockId.toString, size) }.toMap
+    val address = request.address
+    val blockIds = request.blocks.map(_._1.toString)
+    val remainingBlocks = new HashSet[String]() ++= blockIds
+
+    val blockPushListener = new BlockFetchingListener {
+      // Initiating a connection and pushing blocks to a remote shuffle 
service is always handled by
+      // the block-push-threads. We should not initiate the connection 
creation in the
+      // blockPushListener callbacks which are invoked by the netty eventloop 
because:
+      // 1. TransportClient.createConnection(...) blocks until the connection is established and it's
+      // recommended to avoid any blocking operations in the eventloop;
+      // 2. The actual connection creation is a task that gets added to the 
task queue of another
+      // eventloop which could have eventloops eventually blocking each other.
+      // Once the blockPushListener is notified of the block push success or 
failure, we
+      // just delegate it to block-push-threads.
+      def handleResult(result: PushResult): Unit = {
+        submitTask(() => {
+          if (updateStateAndCheckIfPushMore(
+            sizeMap(result.blockId), address, remainingBlocks, result)) {
+            pushUpToMax()
+          }
+        })
+      }
+
+      override def onBlockFetchSuccess(blockId: String, data: ManagedBuffer): 
Unit = {
+        logTrace(s"Push for block $blockId to $address successful.")
+        handleResult(PushResult(blockId, null))
+      }
+
+      override def onBlockFetchFailure(blockId: String, exception: Throwable): 
Unit = {
+        // Check the message or its cause to see whether it needs to be logged.
+        if (!errorHandler.shouldLogError(exception)) {
+          logTrace(s"Pushing block $blockId to $address failed.", exception)
+        } else {
+          logWarning(s"Pushing block $blockId to $address failed.", exception)
+        }
+        handleResult(PushResult(blockId, exception))
+      }
+    }
+    SparkEnv.get.blockManager.blockStoreClient.pushBlocks(
+      address.host, address.port, blockIds.toArray,
+      sliceReqBufferIntoBlockBuffers(request.reqBuffer, 
request.blocks.map(_._2)),
+      blockPushListener)
+  }
+
+  /**
+   * Given the ManagedBuffer representing all the continuous blocks inside the 
shuffle data file
+   * for a PushRequest and an array of individual block sizes, load the buffer 
from disk into
+   * memory and slice it into multiple smaller buffers representing each block.
+   *
+   * With nio ByteBuffer, the individual block buffers share data with the 
initial in memory
+   * buffer loaded from disk. Thus only one copy of the block data is kept in 
memory.
+   * @param reqBuffer A {{FileSegmentManagedBuffer}} representing all the 
continuous blocks in
+   *                  the shuffle data file for a PushRequest
+   * @param blockSizes Array of block sizes
+   * @return Array of in memory buffer for each individual block
+   */
+  private def sliceReqBufferIntoBlockBuffers(
+      reqBuffer: ManagedBuffer,
+      blockSizes: Seq[Long]): Array[ManagedBuffer] = {
+    if (blockSizes.size == 1) {
+      Array(reqBuffer)
+    } else {
+      val inMemoryBuffer = reqBuffer.nioByteBuffer()
+      val blockOffsets = new Array[Long](blockSizes.size)
+      var offset = 0L
+      for (index <- blockSizes.indices) {
+        blockOffsets(index) = offset
+        offset += blockSizes(index)
+      }
+      blockOffsets.zip(blockSizes).map {
+        case (offset, size) =>
+          new NioManagedBuffer(inMemoryBuffer.duplicate()
+            .position(offset.toInt)
+            .limit((offset + size).toInt).asInstanceOf[ByteBuffer].slice())
+      }.toArray
+    }
+  }
+
+  /**
+   * Updates the stats and based on the previous push result decides whether 
to push more blocks
+   * or stop.
+   *
+   * @param bytesPushed     number of bytes pushed.
+   * @param address         address of the remote service
+   * @param remainingBlocks remaining blocks
+   * @param pushResult      result of the last push
+   * @return true if more blocks should be pushed; false otherwise.
+   */
+  private def updateStateAndCheckIfPushMore(
+      bytesPushed: Long,
+      address: BlockManagerId,
+      remainingBlocks: HashSet[String],
+      pushResult: PushResult): Boolean = synchronized {
+    remainingBlocks -= pushResult.blockId
+    bytesInFlight = bytesInFlight - bytesPushed
+    numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+    if (remainingBlocks.isEmpty) {
+      reqsInFlight = reqsInFlight - 1
+    }
+    if (pushResult.failure != null && pushResult.failure.getCause != null &&
+      pushResult.failure.getCause.isInstanceOf[ConnectException]) {
+      // Remove all the blocks for this address just once because removing 
from pushRequests
+      // is expensive. If there is a ConnectException for the first block, all 
the subsequent
+      // blocks to that address will fail, so should avoid removing multiple 
times.
+      if (!unreachableBlockMgrs.contains(address)) {
+        var removed = 0
+        unreachableBlockMgrs.add(address)
+        removed += pushRequests.dequeueAll(req => req.address == 
address).length
+        val droppedReq = deferredPushRequests.remove(address)
+        if (droppedReq.isDefined) {
+          removed += droppedReq.get.length
+        }
+        logWarning(s"Received a ConnectException from $address. " +
+          s"Dropping push of $removed blocks and " +

Review comment:
       nit: should be `requests` rather than `blocks`?

##########
File path: core/src/main/scala/org/apache/spark/storage/BlockId.scala
##########
@@ -81,6 +81,12 @@ case class ShuffleIndexBlockId(shuffleId: Int, mapId: Long, reduceId: Int) exten
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId + ".index"
 }
 
+@Since("3.1.0")
+@DeveloperApi
+case class ShufflePushBlockId(shuffleId: Int, mapIndex: Int, reduceId: Int) extends BlockId {

Review comment:
       Couldn't we just reuse `ShuffleBlockId`? 
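       For reference, the existing id in the same file is (from memory, worth double-checking):

```scala
case class ShuffleBlockId(shuffleId: Int, mapId: Long, reduceId: Int) extends BlockId {
  override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + reduceId
}
```

       One visible difference is `mapId: Long` there versus `mapIndex: Int` here, and a distinct name prefix would let the remote shuffle service tell pushed blocks apart from regular shuffle fetches, so reuse may not be a drop-in change.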

##########
File path: core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage._
+
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: 
ShuffleDependency[Int, Int, Int] = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: 
BlockStoreClient = _
+
+  private var conf: SparkConf = _
+  private var pushedBlocks = new ArrayBuffer[String]
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf(loadDefaults = false)
+    MockitoAnnotations.initMocks(this)
+    when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+    when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+    
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", 
"test-client", 1)))
+    conf.set("spark.shuffle.push.based.enabled", "true")
+    conf.set("spark.shuffle.service.enabled", "true")
+    // Set the env because the shuffle writer gets the shuffle client instance from the env.
+    val mockEnv = mock(classOf[SparkEnv])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.blockManager).thenReturn(blockManager)
+    SparkEnv.set(mockEnv)
+    when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+  }
+
+  override def afterEach(): Unit = {
+    pushedBlocks.clear()
+    super.afterEach()
+  }
+
+  private def interceptPushedBlocksForSuccess(): Unit = {
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = 
invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+          blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+        })
+      })
+  }
+
+  test("Basic block push") {
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Large blocks are skipped for push") {
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 
1100),
+      dependency, 0, conf).initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of blocks in flight per address are limited by 
maxBlocksInFlightPerAddress") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(8))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Hit maxBlocksInFlightPerAddress limit so that the blocks are 
deferred") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    var blockPendingResponse : String = null
+    var listener : BlockFetchingListener = null
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = 
invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        // Expecting 2 blocks
+        assert(blocks.length == 2)
+        if (blockPendingResponse == null) {
+          blockPendingResponse = blocks(1)
+          listener = blockFetchListener
+          // Respond with success only for the first block which will cause 
all the rest of the
+          // blocks to be deferred
+          blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
+        } else {
+          (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+            blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+          })
+        }
+      })
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 2)
+    // this will trigger push of deferred blocks
+    listener.onBlockFetchSuccess(blockPendingResponse, 
mock(classOf[ManagedBuffer]))
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 8)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of shuffle blocks grouped in a single push request is limited 
by " +
+      "maxBlockBatchSize") {
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, 
dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Error retries") {
+    val pushShuffleSupport = new ShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+    val errorHandler = pushShuffleSupport.createErrorHandler()
+    assert(
+      !errorHandler.shouldRetryError(new RuntimeException(
+        new 
IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(errorHandler.shouldRetryError(new RuntimeException(new 
ConnectException())))
+    assert(
+      errorHandler.shouldRetryError(new RuntimeException(new 
IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert (errorHandler.shouldRetryError(new Throwable()))
+  }
+
+  test("Error logging") {
+    val pushShuffleSupport = new ShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+    val errorHandler = pushShuffleSupport.createErrorHandler()
+    assert(
+      !errorHandler.shouldLogError(new RuntimeException(
+        new 
IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(!errorHandler.shouldLogError(new RuntimeException(
+      new IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert(errorHandler.shouldLogError(new Throwable()))
+  }
+
+  test("Connect exceptions remove all the push requests for that host") {
+    when(dependency.getMergerLocs).thenReturn(
+      Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", 
"client2", 2)))
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.foreach(blockId => {
+          blockFetchListener.onBlockFetchFailure(
+            blockId, new RuntimeException(new ConnectException()))
+        })
+      })
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(2))
+      .pushBlocks(any(), any(), any(), any(), any())
+    // 2 blocks for each merger locations
+    assert(pushedBlocks.length == 4)

Review comment:
       Maybe also add an assertion on `unreachableBlockMgrs`?
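       e.g. something roughly like this (sketch; it assumes `unreachableBlockMgrs` is widened from `private[this]` to something the suite can read, or exposed through a test hook):

```scala
val pusher = new TestShuffleBlockPusher(mock(classOf[File]),
  Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
pusher.initiateBlockPush()
// Both merger locations fail with ConnectException, so both should be marked unreachable.
assert(pusher.unreachableBlockMgrs.size == 2)
```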

##########
File path: core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage._
+
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: 
ShuffleDependency[Int, Int, Int] = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: 
BlockStoreClient = _
+
+  private var conf: SparkConf = _
+  private var pushedBlocks = new ArrayBuffer[String]
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf(loadDefaults = false)
+    MockitoAnnotations.initMocks(this)
+    when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+    when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+    
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", 
"test-client", 1)))
+    conf.set("spark.shuffle.push.based.enabled", "true")
+    conf.set("spark.shuffle.service.enabled", "true")
+    // Set the env because the shuffle writer gets the shuffle client instance from the env.
+    val mockEnv = mock(classOf[SparkEnv])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.blockManager).thenReturn(blockManager)
+    SparkEnv.set(mockEnv)
+    when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+  }
+
+  override def afterEach(): Unit = {
+    pushedBlocks.clear()
+    super.afterEach()
+  }
+
+  private def interceptPushedBlocksForSuccess(): Unit = {
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = 
invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+          blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+        })
+      })
+  }
+
+  test("Basic block push") {
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Large blocks are skipped for push") {
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 
1100),
+      dependency, 0, conf).initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of blocks in flight per address are limited by 
maxBlocksInFlightPerAddress") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(8))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Hit maxBlocksInFlightPerAddress limit so that the blocks are 
deferred") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    var blockPendingResponse : String = null
+    var listener : BlockFetchingListener = null
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = 
invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        // Expecting 2 blocks
+        assert(blocks.length == 2)
+        if (blockPendingResponse == null) {
+          blockPendingResponse = blocks(1)
+          listener = blockFetchListener
+          // Respond with success only for the first block which will cause 
all the rest of the
+          // blocks to be deferred
+          blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
+        } else {
+          (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+            blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+          })
+        }
+      })
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, 
conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 2)
+    // this will trigger push of deferred blocks
+    listener.onBlockFetchSuccess(blockPendingResponse, 
mock(classOf[ManagedBuffer]))
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 8)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of shuffle blocks grouped in a single push request is limited 
by " +
+      "maxBlockBatchSize") {
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, 
dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Error retries") {

Review comment:
       Could you add a test that covers the block retry process itself, not only whether the error is classified as retriable?
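       A rough sketch of what such a test could look like (untested, and it only checks that a retriable failure does not stop the remaining pushes rather than exercising the full retry path):

```scala
test("Retriable failure on one block does not stop the remaining pushes") {
  conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
  var failedOnce = false
  when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
    .thenAnswer((invocation: InvocationOnMock) => {
      val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
      val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
      if (!failedOnce) {
        // Fail the first block with a retriable (append collision) error.
        failedOnce = true
        listener.onBlockFetchFailure(blocks(0), new RuntimeException(
          new IllegalArgumentException(
            BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX)))
      } else {
        pushedBlocks ++= blocks
        blocks.foreach(listener.onBlockFetchSuccess(_, mock(classOf[ManagedBuffer])))
      }
    })
  new TestShuffleBlockPusher(mock(classOf[File]),
    Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
      .initiateBlockPush()
  // The failed block is not re-pushed here, but all other blocks should still go out.
  assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
  ShuffleBlockPusher.stop()
}
```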

##########
File path: core/src/test/scala/org/apache/spark/shuffle/ShuffleBlockPusherSuite.scala
##########
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import java.io.File
+import java.net.ConnectException
+import java.nio.ByteBuffer
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.mockito.{Mock, MockitoAnnotations}
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.ArgumentMatchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.ManagedBuffer
+import org.apache.spark.network.shuffle.{BlockFetchingListener, BlockStoreClient}
+import org.apache.spark.network.shuffle.ErrorHandler.BlockPushErrorHandler
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.storage._
+
+class ShuffleBlockPusherSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = 
_
+  @Mock(answer = RETURNS_SMART_NULLS) private var dependency: 
ShuffleDependency[Int, Int, Int] = _
+  @Mock(answer = RETURNS_SMART_NULLS) private var shuffleClient: 
BlockStoreClient = _
+
+  private var conf: SparkConf = _
+  private var pushedBlocks = new ArrayBuffer[String]
+
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    conf = new SparkConf(loadDefaults = false)
+    MockitoAnnotations.initMocks(this)
+    when(dependency.partitioner).thenReturn(new HashPartitioner(8))
+    when(dependency.serializer).thenReturn(new JavaSerializer(conf))
+    
when(dependency.getMergerLocs).thenReturn(Seq(BlockManagerId("test-client", 
"test-client", 1)))
+    conf.set("spark.shuffle.push.based.enabled", "true")
+    conf.set("spark.shuffle.service.enabled", "true")
+    // Set the env because the shuffle writer gets the shuffle client instance from the env.
+    val mockEnv = mock(classOf[SparkEnv])
+    when(mockEnv.conf).thenReturn(conf)
+    when(mockEnv.blockManager).thenReturn(blockManager)
+    SparkEnv.set(mockEnv)
+    when(blockManager.blockStoreClient).thenReturn(shuffleClient)
+  }
+
+  override def afterEach(): Unit = {
+    pushedBlocks.clear()
+    super.afterEach()
+  }
+
+  private def interceptPushedBlocksForSuccess(): Unit = {
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+          blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+        })
+      })
+  }
+
+  test("Basic block push") {
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Large blocks are skipped for push") {
+    conf.set("spark.shuffle.push.maxBlockSizeToPush", "1k")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]), Array(2, 2, 2, 2, 2, 2, 2, 1100),
+      dependency, 0, conf).initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions - 1)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of blocks in flight per address are limited by 
maxBlocksInFlightPerAddress") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "1")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(8))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Hit maxBlocksInFlightPerAddress limit so that the blocks are 
deferred") {
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    var blockPendingResponse: String = null
+    var listener: BlockFetchingListener = null
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val managedBuffers = invocation.getArguments()(3).asInstanceOf[Array[ManagedBuffer]]
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        // Expecting 2 blocks
+        assert(blocks.length == 2)
+        if (blockPendingResponse == null) {
+          blockPendingResponse = blocks(1)
+          listener = blockFetchListener
+          // Respond with success only for the first block, which will cause all
+          // the rest of the blocks to be deferred
+          blockFetchListener.onBlockFetchSuccess(blocks(0), managedBuffers(0))
+        } else {
+          (blocks, managedBuffers).zipped.foreach((blockId, buffer) => {
+            blockFetchListener.onBlockFetchSuccess(blockId, buffer)
+          })
+        }
+      })
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(1))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 2)
+    // this will trigger push of deferred blocks
+    listener.onBlockFetchSuccess(blockPendingResponse, mock(classOf[ManagedBuffer]))
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == 8)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Number of shuffle blocks grouped in a single push request is limited 
by " +
+      "maxBlockBatchSize") {
+    conf.set("spark.shuffle.push.maxBlockBatchSize", "1m")
+    interceptPushedBlocksForSuccess()
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 512 * 1024 }, dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(4))
+      .pushBlocks(any(), any(), any(), any(), any())
+    assert(pushedBlocks.length == dependency.partitioner.numPartitions)
+    ShuffleBlockPusher.stop()
+  }
+
+  test("Error retries") {
+    val pushShuffleSupport = new ShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+    val errorHandler = pushShuffleSupport.createErrorHandler()
+    assert(
+      !errorHandler.shouldRetryError(new RuntimeException(
+        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(errorHandler.shouldRetryError(new RuntimeException(new ConnectException())))
+    assert(
+      errorHandler.shouldRetryError(new RuntimeException(new IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert(errorHandler.shouldRetryError(new Throwable()))
+  }
+
+  test("Error logging") {
+    val pushShuffleSupport = new ShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+    val errorHandler = pushShuffleSupport.createErrorHandler()
+    assert(
+      !errorHandler.shouldLogError(new RuntimeException(
+        new IllegalArgumentException(BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX))))
+    assert(!errorHandler.shouldLogError(new RuntimeException(
+      new IllegalArgumentException(
+        BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX))))
+    assert(errorHandler.shouldLogError(new Throwable()))
+  }
+
+  test("Connect exceptions remove all the push requests for that host") {
+    when(dependency.getMergerLocs).thenReturn(
+      Seq(BlockManagerId("client1", "client1", 1), BlockManagerId("client2", "client2", 2)))
+    conf.set("spark.reducer.maxBlocksInFlightPerAddress", "2")
+    when(shuffleClient.pushBlocks(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val blocks = invocation.getArguments()(2).asInstanceOf[Array[String]]
+        pushedBlocks ++= blocks
+        val blockFetchListener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        blocks.foreach(blockId => {
+          blockFetchListener.onBlockFetchFailure(
+            blockId, new RuntimeException(new ConnectException()))
+        })
+      })
+    new TestShuffleBlockPusher(mock(classOf[File]),
+      Array.fill(dependency.partitioner.numPartitions) { 2 }, dependency, 0, conf)
+        .initiateBlockPush()
+    verify(shuffleClient, times(2))
+      .pushBlocks(any(), any(), any(), any(), any())
+    // 2 blocks for each merger location
+    assert(pushedBlocks.length == 4)
+  }
+
+  private class TestShuffleBlockPusher(
+      dataFile: File,
+      partitionLengths: Array[Long],
+      dep: ShuffleDependency[_, _, _],
+      partitionId: Int,
+      conf: SparkConf)
+    extends ShuffleBlockPusher(dataFile, partitionLengths, dep, partitionId, conf) {
+
+    override protected def submitTask(task: Runnable): Unit = {
+      // Making this synchronous for testing
+      task.run()
+    }
+
+    def getPartitionLengths(): Array[Long] = {

Review comment:
       Is `getPartitionLengths()` unused? If so, it can be removed.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org
