This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 9a5cd15  [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data
9a5cd15 is described below

commit 9a5cd15e8726ccd93a550f90e8113b80fc6d0122
Author: Chandni Singh <singh.chan...@gmail.com>
AuthorDate: Tue Jun 29 17:44:15 2021 -0500

    [SPARK-32922][SHUFFLE][CORE] Adds support for executors to fetch local and remote merged shuffle data
    
    ### What changes were proposed in this pull request?
    This is the shuffle fetch side change where executors can fetch local/remote push-merged shuffle data from shuffle services. This is needed for push-based shuffle - SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    The change adds support to the `ShuffleBlockFetcherIterator` to fetch push-merged block meta and shuffle chunks from local and remote ESS. If the fetch of any of these fails, the iterator falls back to fetching the original shuffle blocks that were merged into the push-merged block.
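    
    For orientation, these are the block id types involved; a simplified sketch of the flow using names introduced in this patch, not the exact call sequence:
    
    ```scala
    import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
    import org.apache.spark.storage.{ShuffleBlockChunkId, ShuffleBlockId}
    
    // A push-merged block for reduce partition 10 of shuffle 5 is addressed
    // with the reserved map id SHUFFLE_PUSH_MAP_ID:
    val pushMerged = ShuffleBlockId(5, SHUFFLE_PUSH_MAP_ID, 10)
    // Its meta response yields one bitmap per chunk; the data is then fetched
    // as individual shuffle chunks:
    val chunk = ShuffleBlockChunkId(shuffleId = 5, reduceId = 10, chunkId = 0)
    // On any failure, the iterator falls back to the original per-map blocks
    // recorded in the chunk's bitmap, e.g. the block written by map task 3:
    val original = ShuffleBlockId(shuffleId = 5, mapId = 3, reduceId = 10)
    ```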
    
    ### Why are the changes needed?
    These changes are needed for push-based shuffle. Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. When push-based shuffle is turned on, executors will fetch push-merged blocks from the remote shuffle services, and the client logs will indicate this.
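    
    As a minimal sketch of how this would be enabled; the configuration keys below are the ones proposed in the SPIP (SPARK-30602) and should be verified against the released documentation:
    
    ```scala
    import org.apache.spark.SparkConf
    
    // Hypothetical illustration: keys per the SPARK-30602 SPIP.
    val conf = new SparkConf()
      .set("spark.shuffle.push.enabled", "true") // executor/client side opt-in
    // The external shuffle service is expected to be configured separately with
    // a merged-shuffle file manager implementation, e.g. via
    // spark.shuffle.push.server.mergedShuffleFileManagerImpl.
    ```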
    
    ### How was this patch tested?
    Added unit tests.
    The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602).
    We have already verified the functionality and the improved performance as documented in the SPIP doc.
    
    Lead-authored-by: Chandni Singh chsingh@linkedin.com
    Co-authored-by: Min Shen mshen@linkedin.com
    Co-authored-by: Ye Zhou yezhou@linkedin.com
    
    Closes #32140 from otterc/SPARK-32922.
    
    Lead-authored-by: Chandni Singh <singh.chan...@gmail.com>
    Co-authored-by: Chandni Singh <chsi...@linkedin.com>
    Co-authored-by: Min Shen <ms...@linkedin.com>
    Co-authored-by: otterc <singh.chan...@gmail.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../scala/org/apache/spark/MapOutputTracker.scala  |   3 +-
 .../spark/serializer/SerializerManager.scala       |   1 +
 .../spark/shuffle/BlockStoreShuffleReader.scala    |   2 +
 .../scala/org/apache/spark/storage/BlockId.scala   |  15 +-
 .../org/apache/spark/storage/BlockManager.scala    |   5 +
 .../spark/storage/PushBasedFetchHelper.scala       | 320 ++++++++++
 .../storage/ShuffleBlockFetcherIterator.scala      | 496 ++++++++++++---
 .../org/apache/spark/storage/BlockIdSuite.scala    |  25 +
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 684 ++++++++++++++++++++-
 9 files changed, 1464 insertions(+), 87 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala 
b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 2b06c49..e605eea 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -1453,7 +1453,8 @@ private[spark] object MapOutputTracker extends Logging {
             // ShuffleBlockId with mapId being SHUFFLE_PUSH_MAP_ID to indicate 
this is
             // a merged shuffle block.
             splitsByAddress.getOrElseUpdate(mergeStatus.location, 
ListBuffer()) +=
-              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), 
mergeStatus.totalSize, -1))
+              ((ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, partId), 
mergeStatus.totalSize,
+                SHUFFLE_PUSH_MAP_ID))
             // For the "holes" in this pre-merged shuffle partition, i.e., 
unmerged mapper
             // shuffle partition blocks, fetch the original map produced 
shuffle partition blocks
             val mapStatusesWithIndex = mapStatuses.zipWithIndex
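
A note on the hunk above: the replaced literal `-1` suggests that `SHUFFLE_PUSH_MAP_ID` is defined as -1 in `MapOutputTracker`, so the recorded mapIndex value is numerically unchanged and only the intent becomes explicit. A quick check, assuming that definition:

```scala
import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID

// Assumed: SHUFFLE_PUSH_MAP_ID == -1, i.e. merged blocks keep recording -1 as
// their mapIndex, but readers can now tell it marks a push-merged block
// rather than a real map index.
assert(SHUFFLE_PUSH_MAP_ID == -1)
```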
diff --git 
a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala 
b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 623db9d..640396a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -110,6 +110,7 @@ private[spark] class SerializerManager(
   private def shouldCompress(blockId: BlockId): Boolean = {
     blockId match {
       case _: ShuffleBlockId => compressShuffle
+      case _: ShuffleBlockChunkId => compressShuffle
       case _: BroadcastBlockId => compressBroadcast
       case _: RDDBlockId => compressRdds
       case _: TempLocalBlockId => compressShuffleSpill
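
A shuffle chunk is a server-side concatenation of shuffle blocks, so it must follow the same compression decision as regular shuffle blocks; the new `case` above mirrors `ShuffleBlockId` for exactly that reason. A minimal read-path sketch, with `rawInputStream` as a placeholder for the fetched bytes:

```scala
import java.io.InputStream
import org.apache.spark.SparkEnv
import org.apache.spark.storage.ShuffleBlockChunkId

// wrapStream consults shouldCompress, so a chunk stream is decompressed iff
// spark.shuffle.compress is on, exactly as for a regular shuffle block.
val rawInputStream: InputStream = ??? // placeholder: bytes of the fetched chunk
val in = SparkEnv.get.serializerManager.wrapStream(ShuffleBlockChunkId(5, 10, 0), rawInputStream)
```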
diff --git 
a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala 
b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index 6782c74..818aa2e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -35,6 +35,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
     readMetrics: ShuffleReadMetricsReporter,
     serializerManager: SerializerManager = SparkEnv.get.serializerManager,
     blockManager: BlockManager = SparkEnv.get.blockManager,
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker,
     shouldBatchFetch: Boolean = false)
   extends ShuffleReader[K, C] with Logging {
 
@@ -71,6 +72,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
       context,
       blockManager.blockStoreClient,
       blockManager,
+      mapOutputTracker,
       blocksByAddress,
       serializerManager.wrapStream,
       // Note: we use getSizeAsMb when no suffix is provided for backwards 
compatibility
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockId.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
index 47c1b96..dc70a9a 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockId.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockId.scala
@@ -43,6 +43,7 @@ sealed abstract class BlockId {
     (isInstanceOf[ShuffleBlockId] || isInstanceOf[ShuffleBlockBatchId] ||
      isInstanceOf[ShuffleDataBlockId] || isInstanceOf[ShuffleIndexBlockId])
   }
+  def isShuffleChunk: Boolean = isInstanceOf[ShuffleBlockChunkId]
   def isBroadcast: Boolean = isInstanceOf[BroadcastBlockId]
 
   override def toString: String = name
@@ -72,6 +73,15 @@ case class ShuffleBlockBatchId(
   }
 }
 
+@Since("3.2.0")
+@DeveloperApi
+case class ShuffleBlockChunkId(
+    shuffleId: Int,
+    reduceId: Int,
+    chunkId: Int) extends BlockId {
+  override def name: String = "shuffleChunk_" + shuffleId  + "_" + reduceId + 
"_" + chunkId
+}
+
 @DeveloperApi
 case class ShuffleDataBlockId(shuffleId: Int, mapId: Long, reduceId: Int) 
extends BlockId {
   override def name: String = "shuffle_" + shuffleId + "_" + mapId + "_" + 
reduceId + ".data"
@@ -152,7 +162,7 @@ class UnrecognizedBlockId(name: String)
 @DeveloperApi
 object BlockId {
   val RDD = "rdd_([0-9]+)_([0-9]+)".r
-  val SHUFFLE = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)".r
+  val SHUFFLE = "shuffle_([0-9]+)_(-?[0-9]+)_([0-9]+)".r
   val SHUFFLE_BATCH = "shuffle_([0-9]+)_([0-9]+)_([0-9]+)_([0-9]+)".r
   val SHUFFLE_DATA = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_INDEX = "shuffle_([0-9]+)_([0-9]+)_([0-9]+).index".r
@@ -160,6 +170,7 @@ object BlockId {
   val SHUFFLE_MERGED_DATA = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).data".r
   val SHUFFLE_MERGED_INDEX = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).index".r
   val SHUFFLE_MERGED_META = 
"shuffleMerged_([_A-Za-z0-9]*)_([0-9]+)_([0-9]+).meta".r
+  val SHUFFLE_CHUNK = "shuffleChunk_([0-9]+)_([0-9]+)_([0-9]+)".r
   val BROADCAST = "broadcast_([0-9]+)([_A-Za-z0-9]*)".r
   val TASKRESULT = "taskresult_([0-9]+)".r
   val STREAM = "input-([0-9]+)-([0-9]+)".r
@@ -186,6 +197,8 @@ object BlockId {
       ShuffleMergedIndexBlockId(appId, shuffleId.toInt, reduceId.toInt)
     case SHUFFLE_MERGED_META(appId, shuffleId, reduceId) =>
       ShuffleMergedMetaBlockId(appId, shuffleId.toInt, reduceId.toInt)
+    case SHUFFLE_CHUNK(shuffleId, reduceId, chunkId) =>
+      ShuffleBlockChunkId(shuffleId.toInt, reduceId.toInt, chunkId.toInt)
     case BROADCAST(broadcastId, field) =>
       BroadcastBlockId(broadcastId.toLong, field.stripPrefix("_"))
     case TASKRESULT(taskId) =>
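
With the `SHUFFLE_CHUNK` pattern added above, chunk ids round-trip through `BlockId.apply` like the other id types; a small sketch grounded in this hunk:

```scala
import org.apache.spark.storage.{BlockId, ShuffleBlockChunkId}

val id = ShuffleBlockChunkId(shuffleId = 5, reduceId = 10, chunkId = 0)
assert(id.name == "shuffleChunk_5_10_0")     // name format from BlockId.scala
assert(BlockId("shuffleChunk_5_10_0") == id) // parsed via the SHUFFLE_CHUNK regex
assert(id.isShuffleChunk)                    // predicate introduced in this patch
```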
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala 
b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index df449fb..98d0949 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -132,6 +132,11 @@ private[spark] class HostLocalDirManager(
       executorIdToLocalDirsCache.asMap().asScala.toMap
     }
 
+  private[spark] def getCachedHostLocalDirsFor(executorId: String): 
Option[Array[String]] =
+    executorIdToLocalDirsCache.synchronized {
+      Option(executorIdToLocalDirsCache.getIfPresent(executorId))
+    }
+
   private[spark] def getHostLocalDirs(
       host: String,
       port: Int,
diff --git 
a/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala 
b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
new file mode 100644
index 0000000..63f42a0
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/storage/PushBasedFetchHelper.scala
@@ -0,0 +1,320 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import java.util.concurrent.TimeUnit
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.{Failure, Success}
+
+import org.roaringbitmap.RoaringBitmap
+
+import org.apache.spark.MapOutputTracker
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.shuffle.{BlockStoreClient, MergedBlockMeta, 
MergedBlocksMetaListener}
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
+
+/**
+ * Helper class for [[ShuffleBlockFetcherIterator]] that encapsulates all the push-based
+ * functionality to fetch push-merged block meta and shuffle chunks.
+ * A push-merged block contains multiple shuffle chunks, where each shuffle chunk contains
+ * multiple shuffle blocks that belong to the same reduce partition and were merged by the
+ * ESS into that chunk.
+ */
+private class PushBasedFetchHelper(
+   private val iterator: ShuffleBlockFetcherIterator,
+   private val shuffleClient: BlockStoreClient,
+   private val blockManager: BlockManager,
+   private val mapOutputTracker: MapOutputTracker) extends Logging {
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  private[storage] val localShuffleMergerBlockMgrId = BlockManagerId(
+    SHUFFLE_MERGER_IDENTIFIER, blockManager.blockManagerId.host,
+    blockManager.blockManagerId.port, blockManager.blockManagerId.topologyInfo)
+
+  /**
+   * A map for storing shuffle chunk bitmap.
+   */
+  private[this] val chunksMetaMap = new mutable.HashMap[ShuffleBlockChunkId, 
RoaringBitmap]()
+
+  /**
+   * Returns true if the address is for a push-merged block.
+   */
+  def isPushMergedShuffleBlockAddress(address: BlockManagerId): Boolean = {
+    SHUFFLE_MERGER_IDENTIFIER == address.executorId
+  }
+
+  /**
+   * Returns true if the address is of a remote push-merged block, false otherwise.
+   */
+  def isRemotePushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host != 
blockManager.blockManagerId.host
+  }
+
+  /**
+   * Returns true if the address is of a push-merged-local block, false otherwise.
+   */
+  def isLocalPushMergedBlockAddress(address: BlockManagerId): Boolean = {
+    isPushMergedShuffleBlockAddress(address) && address.host == 
blockManager.blockManagerId.host
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked 
and the iterator
+   * processes a response of type 
[[ShuffleBlockFetcherIterator.SuccessFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def removeChunk(blockId: ShuffleBlockChunkId): Unit = {
+    chunksMetaMap.remove(blockId)
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked 
and the iterator
+   * processes a response of type 
[[ShuffleBlockFetcherIterator.PushMergedLocalMetaFetchResult]].
+   *
+   * @param blockId shuffle chunk id.
+   */
+  def addChunk(blockId: ShuffleBlockChunkId, chunkMeta: RoaringBitmap): Unit = 
{
+    chunksMetaMap(blockId) = chunkMeta
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked 
and the iterator
+   * processes a response of type 
[[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFetchResult]].
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId  reduce id.
+   * @param blockSize size of the push-merged block.
+   * @param bitmaps   chunk bitmaps, where each bitmap contains all the mapIds 
that were merged
+   *                  to that chunk.
+   * @return  shuffle chunks to fetch.
+   */
+  def createChunkBlockInfosFromMetaResponse(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap]): ArrayBuffer[(BlockId, Long, Int)] = {
+    val approxChunkSize = blockSize / bitmaps.length
+    val blocksToFetch = new ArrayBuffer[(BlockId, Long, Int)]()
+    for (i <- bitmaps.indices) {
+      val blockChunkId = ShuffleBlockChunkId(shuffleId, reduceId, i)
+      chunksMetaMap.put(blockChunkId, bitmaps(i))
+      logDebug(s"adding block chunk $blockChunkId of size $approxChunkSize")
+      blocksToFetch += ((blockChunkId, approxChunkSize, SHUFFLE_PUSH_MAP_ID))
+    }
+    blocksToFetch
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized and 
only if it has
+   * push-merged blocks for which it needs to fetch the metadata.
+   *
+   * @param req [[ShuffleBlockFetcherIterator.FetchRequest]] that only 
contains requests to fetch
+   *            metadata of push-merged blocks.
+   */
+  def sendFetchMergedStatusRequest(req: FetchRequest): Unit = {
+    val sizeMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, _) =>
+        val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+        ((shuffleBlockId.shuffleId, shuffleBlockId.reduceId), size)
+    }.toMap
+    val address = req.address
+    val mergedBlocksMetaListener = new MergedBlocksMetaListener {
+      override def onSuccess(shuffleId: Int, reduceId: Int, meta: 
MergedBlockMeta): Unit = {
+        logInfo(s"Received the meta of push-merged block for ($shuffleId, 
$reduceId)  " +
+          s"from ${req.address.host}:${req.address.port}")
+        try {
+          
iterator.addToResultsQueue(PushMergedRemoteMetaFetchResult(shuffleId, reduceId,
+            sizeMap((shuffleId, reduceId)), meta.readChunkBitmaps(), address))
+        } catch {
+          case exception: Exception =>
+            logError(s"Failed to parse the meta of push-merged block for 
($shuffleId, " +
+              s"$reduceId) from ${req.address.host}:${req.address.port}", 
exception)
+            iterator.addToResultsQueue(
+              PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, 
address))
+        }
+      }
+
+      override def onFailure(shuffleId: Int, reduceId: Int, exception: 
Throwable): Unit = {
+        logError(s"Failed to get the meta of push-merged block for 
($shuffleId, $reduceId) " +
+          s"from ${req.address.host}:${req.address.port}", exception)
+        iterator.addToResultsQueue(
+          PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address))
+      }
+    }
+    req.blocks.foreach { block =>
+      val shuffleBlockId = block.blockId.asInstanceOf[ShuffleBlockId]
+      shuffleClient.getMergedBlockMeta(address.host, address.port, 
shuffleBlockId.shuffleId,
+        shuffleBlockId.reduceId, mergedBlocksMetaListener)
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the iterator is initialized. It 
fetches all the
+   * outstanding push-merged local blocks.
+   * @param pushMergedLocalBlocks set of identified push-merged-local blocks.
+   */
+  def fetchAllPushMergedLocalBlocks(
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    if (pushMergedLocalBlocks.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchPushMergedLocalBlocks(_, 
pushMergedLocalBlocks))
+    }
+  }
+
+  /**
+   * Fetch the push-merged blocks dirs if they are not in the cache and 
eventually fetch push-merged
+   * local blocks.
+   */
+  private def fetchPushMergedLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): Unit = {
+    val cachedPushedMergedDirs = hostLocalDirManager.getCachedHostLocalDirsFor(
+      SHUFFLE_MERGER_IDENTIFIER)
+    if (cachedPushedMergedDirs.isDefined) {
+      logDebug(s"Fetch the push-merged-local blocks with cached merged dirs: " 
+
+        s"${cachedPushedMergedDirs.get.mkString(", ")}")
+      pushMergedLocalBlocks.foreach { blockId =>
+        fetchPushMergedLocalBlock(blockId, cachedPushedMergedDirs.get,
+          localShuffleMergerBlockMgrId)
+      }
+    } else {
+      logDebug(s"Asynchronous fetch the push-merged-local blocks without 
cached merged dirs")
+      hostLocalDirManager.getHostLocalDirs(localShuffleMergerBlockMgrId.host,
+        localShuffleMergerBlockMgrId.port, Array(SHUFFLE_MERGER_IDENTIFIER)) {
+        case Success(dirs) =>
+          logDebug(s"Fetched merged dirs in " +
+            s"${TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
startTimeNs)} ms")
+          pushMergedLocalBlocks.foreach {
+            blockId =>
+              logDebug(s"Successfully fetched local dirs: " +
+                s"${dirs.get(SHUFFLE_MERGER_IDENTIFIER).mkString(", ")}")
+              fetchPushMergedLocalBlock(blockId, 
dirs(SHUFFLE_MERGER_IDENTIFIER),
+                localShuffleMergerBlockMgrId)
+          }
+        case Failure(throwable) =>
+          // If we see an exception while getting the local dirs for push-merged-local blocks,
+          // we fall back to fetching the original blocks. We do not report a block fetch failure.
+          logWarning(s"Error while fetching the merged dirs for 
push-merged-local " +
+            s"blocks: ${pushMergedLocalBlocks.mkString(", ")}. Fetch the 
original blocks instead",
+            throwable)
+          pushMergedLocalBlocks.foreach {
+            blockId =>
+              iterator.addToResultsQueue(FallbackOnPushMergedFailureResult(
+                blockId, localShuffleMergerBlockMgrId, 0, isNetworkReqDone = 
false))
+          }
+      }
+    }
+  }
+
+  /**
+   * Fetch a single push-merged-local block. This can be executed by the task thread as well
+   * as the netty thread.
+   * @param blockId ShuffleBlockId to be fetched
+   * @param localDirs Local directories where the push-merged shuffle files 
are stored
+   * @param blockManagerId BlockManagerId
+   */
+  private[this] def fetchPushMergedLocalBlock(
+      blockId: BlockId,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Unit = {
+    try {
+      val shuffleBlockId = blockId.asInstanceOf[ShuffleBlockId]
+      val chunksMeta = blockManager.getLocalMergedBlockMeta(shuffleBlockId, 
localDirs)
+      iterator.addToResultsQueue(PushMergedLocalMetaFetchResult(
+        shuffleBlockId.shuffleId, shuffleBlockId.reduceId, 
chunksMeta.readChunkBitmaps(),
+        localDirs))
+    } catch {
+      case e: Exception =>
+        // If we see an exception while reading the push-merged-local meta, we fall back to
+        // fetching the original blocks. We do not report a block fetch failure
+        // and will continue with the remaining local block reads.
+        logWarning(s"Error occurred while fetching push-merged-local meta, " +
+          s"prepare to fetch the original blocks", e)
+        iterator.addToResultsQueue(
+          FallbackOnPushMergedFailureResult(blockId, blockManagerId, 0, 
isNetworkReqDone = false))
+    }
+  }
+
+  /**
+   * This is executed by the task thread when the `iterator.next()` is invoked 
and the iterator
+   * processes a response of type:
+   * 1) [[ShuffleBlockFetcherIterator.SuccessFetchResult]]
+   * 2) [[ShuffleBlockFetcherIterator.FallbackOnPushMergedFailureResult]]
+   * 3) [[ShuffleBlockFetcherIterator.PushMergedRemoteMetaFailedFetchResult]]
+   *
+   * This initiates fetching fallback blocks for a push-merged block or a 
shuffle chunk that
+   * failed to fetch.
+   * It makes a call to the map output tracker to get the list of original 
blocks for the
+   * given push-merged block/shuffle chunk, split them into remote and local 
blocks, and process
+   * them accordingly.
+   * It also updates numBlocksToFetch in the iterator: the count is decreased for the failed
+   * push-merged block or shuffle chunk and increased for the additional requests created for
+   * the original blocks.
+   * The fallback happens when:
+   * 1. There is an exception while creating shuffle chunks from a push-merged-local shuffle
+   *    block. See fetchPushMergedLocalBlock.
+   * 2. There is a failure when fetching remote shuffle chunks.
+   * 3. There is a failure when processing SuccessFetchResult which is for a 
shuffle chunk
+   *    (local or remote).
+   */
+  def initiateFallbackFetchForPushMergedBlock(
+      blockId: BlockId,
+      address: BlockManagerId): Unit = {
+    assert(blockId.isInstanceOf[ShuffleBlockId] || 
blockId.isInstanceOf[ShuffleBlockChunkId])
+    logWarning(s"Falling back to fetch the original blocks for push-merged 
block $blockId")
+    // Increase the blocks processed since we will process another block in 
the next iteration of
+    // the while loop in ShuffleBlockFetcherIterator.next().
+    val fallbackBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])] =
+      blockId match {
+        case shuffleBlockId: ShuffleBlockId =>
+          iterator.decreaseNumBlocksToFetch(1)
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleBlockId.shuffleId, shuffleBlockId.reduceId)
+        case _ =>
+          val shuffleChunkId = blockId.asInstanceOf[ShuffleBlockChunkId]
+          val chunkBitmap: RoaringBitmap = 
chunksMetaMap.remove(shuffleChunkId).get
+          var blocksProcessed = 1
+          // When there is a failure to fetch a remote shuffle chunk, then we 
try to
+          // fallback not only for that particular remote shuffle chunk but 
also for all the
+          // pending chunks that belong to the same host. The reason for doing 
so is that it
+          // is very likely that the subsequent requests for shuffle chunks 
from this host will
+          // fail as well. Since, push-based shuffle is best effort and we try 
not to increase the
+          // delay of the fetches, we immediately fallback for all the pending 
shuffle chunks in the
+          // fetchRequests queue.
+          if (isRemotePushMergedBlockAddress(address)) {
+            // Fallback for all the pending fetch requests
+            val pendingShuffleChunks = 
iterator.removePendingChunks(shuffleChunkId, address)
+            pendingShuffleChunks.foreach { pendingBlockId =>
+              logInfo(s"Falling back immediately for shuffle chunk 
$pendingBlockId")
+              val bitmapOfPendingChunk: RoaringBitmap = 
chunksMetaMap.remove(pendingBlockId).get
+              chunkBitmap.or(bitmapOfPendingChunk)
+            }
+            // These blocks were added to numBlocksToFetch so we increment 
numBlocksProcessed
+            blocksProcessed += pendingShuffleChunks.size
+          }
+          iterator.decreaseNumBlocksToFetch(blocksProcessed)
+          mapOutputTracker.getMapSizesForMergeResult(
+            shuffleChunkId.shuffleId, shuffleChunkId.reduceId, chunkBitmap)
+      }
+    iterator.fallbackFetch(fallbackBlocksByAddr)
+  }
+}
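
One detail worth noting in `createChunkBlockInfosFromMetaResponse` above: per-chunk sizes are only estimated by dividing the push-merged block size evenly across the chunk bitmaps. A worked example with made-up numbers:

```scala
// Hypothetical numbers: a 96 MiB push-merged block whose meta response
// carries 3 chunk bitmaps yields 3 shuffle chunks, each estimated at
// 96 MiB / 3 = 32 MiB. The estimate feeds the bytesInFlight accounting for
// remote chunk fetches; it does not delimit the chunk data itself.
val blockSize: Long = 96L * 1024 * 1024
val numBitmaps = 3
val approxChunkSize = blockSize / numBitmaps // 33554432 bytes (32 MiB)
```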
diff --git 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 4465d76..094c3b5 100644
--- 
a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ 
b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -24,13 +24,15 @@ import java.util.concurrent.atomic.AtomicBoolean
 import javax.annotation.concurrent.GuardedBy
 
 import scala.collection.mutable
-import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, LinkedHashMap, 
Queue}
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
 import scala.util.{Failure, Success}
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.commons.io.IOUtils
+import org.roaringbitmap.RoaringBitmap
 
-import org.apache.spark.{SparkException, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkException, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.internal.Logging
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.shuffle._
@@ -57,6 +59,8 @@ import org.apache.spark.util.{CompletionIterator, 
TaskCompletionListener, Utils}
  *                        block, which indicate the index in the map stage.
  *                        Note that zero-sized blocks are already excluded, 
which happened in
  *                        
[[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker [[MapOutputTracker]] for falling back to fetching 
the original blocks if
+ *                         we fail to fetch shuffle chunks when push based 
shuffle is enabled.
  * @param streamWrapper A function to wrap the returned input stream.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at 
any given point.
  * @param maxReqsInFlight max number of remote requests to fetch blocks at any 
given point.
@@ -75,6 +79,7 @@ final class ShuffleBlockFetcherIterator(
     context: TaskContext,
     shuffleClient: BlockStoreClient,
     blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
     blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
     streamWrapper: (BlockId, InputStream) => InputStream,
     maxBytesInFlight: Long,
@@ -108,13 +113,6 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val startTimeNs = System.nanoTime()
 
-  /** Local blocks to fetch, excluding zero-sized blocks. */
-  private[this] val localBlocks = 
scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
-
-  /** Host local blockIds to fetch by executors, excluding zero-sized blocks. 
*/
-  private[this] val hostLocalBlocksByExecutor =
-    LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
-
   /** Host local blocks to fetch, excluding zero-sized blocks. */
   private[this] val hostLocalBlocks = 
scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
 
@@ -179,6 +177,9 @@ final class ShuffleBlockFetcherIterator(
 
   private[this] val onCompleteCallback = new 
ShuffleFetchCompletionListener(this)
 
+  private[this] val pushBasedFetchHelper = new PushBasedFetchHelper(
+    this, shuffleClient, blockManager, mapOutputTracker)
+
   initialize()
 
   // Decrements the buffer reference count.
@@ -329,7 +330,14 @@ final class ShuffleBlockFetcherIterator(
               }
 
             case _ =>
-              results.put(FailureFetchResult(BlockId(blockId), 
infoMap(blockId)._2, address, e))
+              val block = BlockId(blockId)
+              if (block.isShuffleChunk) {
+                remainingBlocks -= blockId
+                results.put(FallbackOnPushMergedFailureResult(
+                  block, address, infoMap(blockId)._1, 
remainingBlocks.isEmpty))
+              } else {
+                results.put(FailureFetchResult(block, infoMap(blockId)._2, 
address, e))
+              }
           }
         }
       }
@@ -347,20 +355,42 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+  /**
+   * This is called from initialize and also from the fallback which is 
triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): 
ArrayBuffer[FetchRequest] = {
     logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
       + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: 
$maxBlocksInFlightPerAddress")
 
-    // Partition to local, host-local and remote blocks. Remote blocks are 
further split into
-    // FetchRequests of size at most maxBytesInFlight in order to limit the 
amount of data in flight
+    // Partition to local, host-local, push-merged-local, remote (includes 
push-merged-remote)
+    // blocks. Remote blocks are further split into FetchRequests of size at most maxBytesInFlight
+    // in order to limit the amount of data in flight
     val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
     var localBlockBytes = 0L
     var hostLocalBlockBytes = 0L
+    var numHostLocalBlocks = 0
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
 
     val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
     for ((address, blockInfos) <- blocksByAddress) {
-      if (Seq(blockManager.blockManagerId.executorId, 
fallback).contains(address.executorId)) {
-        checkBlockSizes(blockInfos)
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._3).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (localExecIds.contains(address.executorId)) {
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), 
doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
@@ -368,14 +398,13 @@ final class ShuffleBlockFetcherIterator(
         localBlockBytes += mergedBlockInfos.map(_.size).sum
       } else if (blockManager.hostLocalDirManager.isDefined &&
         address.host == blockManager.blockManagerId.host) {
-        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)), 
doBatchFetch)
         numBlocksToFetch += mergedBlockInfos.size
         val blocksForAddress =
           mergedBlockInfos.map(info => (info.blockId, info.size, 
info.mapIndex))
         hostLocalBlocksByExecutor += address -> blocksForAddress
-        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        numHostLocalBlocks += blocksForAddress.size
         hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
         val (_, timeCost) = Utils.timeTakenMs[Unit] {
@@ -386,40 +415,54 @@ final class ShuffleBlockFetcherIterator(
     }
     val (remoteBlockBytes, numRemoteBlocks) =
       collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 
+ y.blocks.size))
-    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes
-    assert(numBlocksToFetch == localBlocks.size + hostLocalBlocks.size + 
numRemoteBlocks,
-      s"The number of non-empty blocks $numBlocksToFetch doesn't equal to the 
number of local " +
-        s"blocks ${localBlocks.size} + the number of host-local blocks 
${hostLocalBlocks.size} " +
-        s"+ the number of remote blocks ${numRemoteBlocks}.")
-    logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) 
non-empty blocks " +
-      s"including ${localBlocks.size} 
(${Utils.bytesToString(localBlockBytes)}) local and " +
-      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) 
" +
-      s"host-local and $numRemoteBlocks 
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(blocksToFetchCurrentIteration == localBlocks.size +
+      numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+        s"The number of non-empty blocks $blocksToFetchCurrentIteration 
doesn't equal to the sum " +
+        s"of the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks ${numHostLocalBlocks} " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} 
" +
+        s"+ the number of remote blocks ${numRemoteBlocks} ")
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local 
and " +
+      s"${numHostLocalBlocks} (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"push-merged-local and $numRemoteBlocks 
(${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+      .flatMap { infos => infos.map(info => (info._1, info._3)) }
     collectedRemoteRequests
   }
 
   private def createFetchRequest(
       blocks: Seq[FetchBlockInfo],
-      address: BlockManagerId): FetchRequest = {
+      address: BlockManagerId,
+      forMergedMetas: Boolean): FetchRequest = {
     logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address 
"
       + s"with ${blocks.size} blocks")
-    FetchRequest(address, blocks)
+    FetchRequest(address, blocks, forMergedMetas)
   }
 
   private def createFetchRequests(
       curBlocks: Seq[FetchBlockInfo],
       address: BlockManagerId,
       isLast: Boolean,
-      collectedRemoteRequests: ArrayBuffer[FetchRequest]): 
ArrayBuffer[FetchBlockInfo] = {
-    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, 
doBatchFetch)
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, 
enableBatchFetch)
     numBlocksToFetch += mergedBlocks.size
     val retBlocks = new ArrayBuffer[FetchBlockInfo]
     if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
-      collectedRemoteRequests += createFetchRequest(mergedBlocks, address)
+      collectedRemoteRequests += createFetchRequest(mergedBlocks, address, 
forMergedMetas)
     } else {
       mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
         if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
-          collectedRemoteRequests += createFetchRequest(blocks, address)
+          collectedRemoteRequests += createFetchRequest(blocks, address, 
forMergedMetas)
         } else {
           // The last group does not exceed `maxBlocksInFlightPerAddress`. Put 
it back
           // to `curBlocks`.
@@ -441,20 +484,45 @@ final class ShuffleBlockFetcherIterator(
 
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks += FetchBlockInfo(blockId, size, mapIndex)
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged 
block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = 
false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are push-merged blocks, shuffle chunks, or 
original blocks.
+        // Based on these types, we decide to do batch fetch and create 
FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = 
false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = 
false,
+              collectedRemoteRequests, enableBatchFetch = false, 
forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for 
merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || 
mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks.toSeq, address, isLast = 
false,
+              collectedRemoteRequests, doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
-      createFetchRequests(curBlocks.toSeq, address, isLast = true, 
collectedRemoteRequests)
+      val (enableBatchFetch, forMergedMetas) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
+      createFetchRequests(curBlocks.toSeq, address, isLast = true, 
collectedRemoteRequests,
+        enableBatchFetch = enableBatchFetch, forMergedMetas = forMergedMetas)
     }
   }
 
@@ -475,7 +543,8 @@ final class ShuffleBlockFetcherIterator(
    * `ManagedBuffer`'s memory is allocated lazily when we create the input 
stream, so all we
    * track in-memory are the ManagedBuffer references themselves.
    */
-  private[this] def fetchLocalBlocks(): Unit = {
+  private[this] def fetchLocalBlocks(
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = {
     logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
     val iter = localBlocks.iterator
     while (iter.hasNext) {
@@ -529,7 +598,10 @@ final class ShuffleBlockFetcherIterator(
    * `ManagedBuffer`'s memory is allocated lazily when we create the input 
stream, so all we
    * track in-memory are the ManagedBuffer references themselves.
    */
-  private[this] def fetchHostLocalBlocks(hostLocalDirManager: 
HostLocalDirManager): Unit = {
+  private[this] def fetchHostLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]]):
+    Unit = {
     val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs
     val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = {
       val (hasCache, noCache) = hostLocalBlocksByExecutor.partition { case 
(hostLocalBmId, _) =>
@@ -602,9 +674,15 @@ final class ShuffleBlockFetcherIterator(
   private[this] def initialize(): Unit = {
     // Add a task completion callback (called in both success case and failure 
case) to cleanup.
     context.addTaskCompletionListener(onCompleteCallback)
-
-    // Partition blocks by the different fetch modes: local, host-local and 
remote blocks.
-    val remoteRequests = partitionBlocksByFetchMode()
+    // Local blocks to fetch, excluding zero-sized blocks.
+    val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val hostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    // Partition blocks by the different fetch modes: local, host-local, 
push-merged-local and
+    // remote blocks.
+    val remoteRequests = partitionBlocksByFetchMode(
+      blocksByAddress, localBlocks, hostLocalBlocksByExecutor, 
pushMergedLocalBlocks)
     // Add the remote requests into our queue in a random order
     fetchRequests ++= Utils.randomize(remoteRequests)
     assert ((0 == reqsInFlight) == (0 == bytesInFlight),
@@ -620,11 +698,18 @@ final class ShuffleBlockFetcherIterator(
       (if (numDeferredRequest > 0 ) s", deferred $numDeferredRequest requests" 
else ""))
 
     // Get Local Blocks
-    fetchLocalBlocks()
+    fetchLocalBlocks(localBlocks)
     logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
+    // Get host local blocks if any
+    fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+    pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
+  }
 
-    if (hostLocalBlocks.nonEmpty) {
-      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks)
+  private def fetchAllHostLocalBlocks(
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]]):
+    Unit = {
+    if (hostLocalBlocksByExecutor.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, 
hostLocalBlocksByExecutor))
     }
   }
 
@@ -661,7 +746,9 @@ final class ShuffleBlockFetcherIterator(
       result match {
         case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, 
isNetworkReqDone) =>
           if (address != blockManager.blockManagerId) {
-            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+            if (hostLocalBlocks.contains(blockId -> mapIndex) ||
+              pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)) {
+              // It is a host local block or a local shuffle chunk
               shuffleMetrics.incLocalBlocksFetched(1)
               shuffleMetrics.incLocalBytesRead(buf.size)
             } else {
@@ -712,38 +799,63 @@ final class ShuffleBlockFetcherIterator(
                 case e: IOException => logError("Failed to create input stream 
from local block", e)
               }
               buf.release()
-              throwFetchFailedException(blockId, mapIndex, address, e)
-          }
-          try {
-            input = streamWrapper(blockId, in)
-            // If the stream is compressed or wrapped, then we optionally 
decompress/unwrap the
-            // first maxBytesInFlight/3 bytes into memory, to check for 
corruption in that portion
-            // of the data. But even if 'detectCorruptUseExtraMemory' 
configuration is off, or if
-            // the corruption is later, we'll still detect the corruption 
later in the stream.
-            streamCompressedOrEncrypted = !input.eq(in)
-            if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
-              // TODO: manage the memory used here, and spill it into disk in 
case of OOM.
-              input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
-            }
-          } catch {
-            case e: IOException =>
-              buf.release()
-              if (buf.isInstanceOf[FileSegmentManagedBuffer]
-                  || corruptedBlocks.contains(blockId)) {
-                throwFetchFailedException(blockId, mapIndex, address, e)
-              } else {
-                logWarning(s"got an corrupted block $blockId from $address, 
fetch again", e)
-                corruptedBlocks += blockId
-                fetchRequests += FetchRequest(
-                  address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+              if (blockId.isShuffleChunk) {
+                
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                // Set result to null to trigger another iteration of the while loop to
+                // get either a SuccessFetchResult or a FailureFetchResult.
                 result = null
+                null
+              } else {
+                throwFetchFailedException(blockId, mapIndex, address, e)
+              }
+          }
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally 
decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for 
corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' 
configuration is off, or if
+              // the corruption is later, we'll still detect the corruption 
later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk 
in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>
+                buf.release()
+                if (blockId.isShuffleChunk) {
+                  // Retrying a corrupt block may result again in a corrupt 
block. For shuffle
+                  // chunks, we opt to fallback on the original shuffle blocks 
that belong to that
+                  // corrupt shuffle chunk immediately instead of retrying to 
fetch the corrupt
+                  // chunk. This also makes the code simpler because the 
chunkMeta corresponding to
+                  // a shuffle chunk is always removed from chunksMetaMap 
whenever a shuffle chunk
+                  // gets processed. If we try to re-fetch a corrupt shuffle 
chunk, then it has to
+                  // be added back to the chunksMetaMap.
+                  
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                  // Set result to null to trigger another iteration of the 
while loop.
+                  result = null
+                } else {
+                  if (buf.isInstanceOf[FileSegmentManagedBuffer]
+                    || corruptedBlocks.contains(blockId)) {
+                    throwFetchFailedException(blockId, mapIndex, address, e)
+                  } else {
+                    logWarning(s"got an corrupted block $blockId from 
$address, fetch again", e)
+                    corruptedBlocks += blockId
+                    fetchRequests += FetchRequest(
+                      address, Array(FetchBlockInfo(blockId, size, mapIndex)))
+                    result = null
+                  }
+                }
+            } finally {
+              if (blockId.isShuffleChunk) {
+                
pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId])
+              }
+              // TODO: release the buf here to free memory earlier
+              if (input == null) {
+                // Close the underlying stream if there was an issue in 
wrapping the stream using
+                // streamWrapper
+                in.close()
               }
-          } finally {
-            // TODO: release the buf here to free memory earlier
-            if (input == null) {
-              // Close the underlying stream if there was an issue in wrapping 
the stream using
-              // streamWrapper
-              in.close()
             }
           }
 
@@ -767,6 +879,83 @@ final class ShuffleBlockFetcherIterator(
             deferredFetchRequests.getOrElseUpdate(address, new 
Queue[FetchRequest]())
           defReqQueue.enqueue(request)
           result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, 
isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this 
case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the push-merged-local meta. In this case, the 
blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the push-merged-local directories from the ESS. 
In this case, the
+          //    blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop 
to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) =>
+          // Fetch push-merged-local shuffle block data as multiple shuffle chunks
+          val shuffleBlockId = ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId)
+          try {
+            val bufs: Seq[ManagedBuffer] = blockManager.getLocalMergedBlockData(shuffleBlockId,
+              localDirs)
+            // Since the request for local block meta completed successfully, numBlocksToFetch
+            // is decremented.
+            numBlocksToFetch -= 1
+            // Update total number of blocks to fetch, reflecting the multiple local shuffle
+            // chunks.
+            numBlocksToFetch += bufs.size
+            bufs.zipWithIndex.foreach { case (buf, chunkId) =>
+              buf.retain()
+              val shuffleChunkId = ShuffleBlockChunkId(shuffleId, reduceId, chunkId)
+              pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+              results.put(SuccessFetchResult(shuffleChunkId, SHUFFLE_PUSH_MAP_ID,
+                pushBasedFetchHelper.localShuffleMergerBlockMgrId, buf.size(), buf,
+                isNetworkReqDone = false))
+            }
+          } catch {
+            case e: Exception =>
+              // If we see an exception while reading the push-merged-local index file, we
+              // fall back to fetching the original blocks. We do not report a block fetch
+              // failure and will continue with the remaining local block reads.
+              logWarning(s"Error occurred while reading push-merged-local index, " +
+                s"prepare to fetch the original blocks", e)
+              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+                shuffleBlockId, pushBasedFetchHelper.localShuffleMergerBlockMgrId)
+          }
+          result = null
+
+        case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, 
bitmaps, address) =>
+          // The original meta request is processed so we decrease 
numBlocksToFetch and
+          // numBlocksInFlightPerAddress by 1. We will collect new shuffle 
chunks request and the
+          // count of this is added to numBlocksToFetch in 
collectFetchReqsFromMergedBlocks.
+          numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+          numBlocksToFetch -= 1
+          val blocksToFetch = 
pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId, reduceId, blockSize, bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToFetch.toSeq, 
additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null
+
+        case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, 
address) =>
+          // The original meta request failed so we decrease 
numBlocksInFlightPerAddress by 1.
+          numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+          // If we fail to fetch the meta of a push-merged block, we fall back 
to fetching the
+          // original blocks.
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+            ShuffleBlockId(shuffleId, SHUFFLE_PUSH_MAP_ID, reduceId), address)
+          // Set result to null to force another iteration.
+          result = null
       }
 
       // Send fetch requests up to maxBytesInFlight
@@ -834,7 +1023,11 @@ final class ShuffleBlockFetcherIterator(
     }
 
     def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
-      sendRequest(request)
+      if (request.forMergedMetas) {
+        pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
+      } else {
+        sendRequest(request)
+      }
       numBlocksInFlightPerAddress(remoteAddress) =
         numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + 
request.blocks.size
     }
@@ -871,6 +1064,82 @@ final class ShuffleBlockFetcherIterator(
           "Failed to get block " + blockId + ", which is not a shuffle block", 
e)
     }
   }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate 
with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = {
+    numBlocksToFetch -= blocksFetched
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when 
there is a fetch
+   * failure related to a push-merged block or shuffle chunk.
+   * This is executed by the task thread when the `iterator.next()` is invoked 
and if that initiates
+   * fallback.
+   */
+  private[storage] def fallbackFetch(
+      originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])]): Unit = {
+    val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val originalHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val originalRemoteReqs = partitionBlocksByFetchMode(originalBlocksByAddr,
+      originalLocalBlocks, originalHostLocalBlocksByExecutor, 
originalMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(originalRemoteReqs)
+    logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for 
push-merged")
+    // fetch all the fallback blocks that are local.
+    fetchLocalBlocks(originalLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(originalMergedLocalBlocks.isEmpty,
+      "There should be zero push-merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host and have the same
+   * reduceId as the current chunk that had a fetch failure.
+   * This is executed by the task thread when `iterator.next()` is invoked and that call
+   * initiates the fallback.
+   *
+   * @return set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleReducePartition(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = queue.dequeueAll { req =>
+        val firstBlock = req.blocks.head
+        firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleReducePartition(firstBlock.blockId)
+      }
+      removedChunkIds ++=
+        fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]))
+    }
+
+    filterRequests(fetchRequests)
+    deferredFetchRequests.get(address).foreach { defRequests =>
+      filterRequests(defRequests)
+      if (defRequests.isEmpty) deferredFetchRequests.remove(address)
+    }
+    removedChunkIds
+  }
 }
 
 /**
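
For context on how the fallback metadata above is consumed: on a successful remote meta
fetch, the chunk bitmaps are converted into shuffle-chunk block infos that re-enter the
iterator as new fetch requests. The sketch below is illustrative only, not the actual
PushBasedFetchHelper code; `toChunkBlocks` and the even per-chunk size split are assumptions.

    import org.roaringbitmap.RoaringBitmap

    // Hypothetical helper: one ShuffleBlockChunkId per chunk bitmap; the per-chunk size
    // is approximated by splitting the push-merged block size evenly across the chunks.
    def toChunkBlocks(
        shuffleId: Int,
        reduceId: Int,
        blockSize: Long,
        bitmaps: Array[RoaringBitmap]): Seq[(BlockId, Long, Int)] = {
      val approxChunkSize = blockSize / math.max(1, bitmaps.length)
      bitmaps.indices.map { chunkId =>
        (ShuffleBlockChunkId(shuffleId, reduceId, chunkId), approxChunkSize, SHUFFLE_PUSH_MAP_ID)
      }
    }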
@@ -1074,8 +1343,13 @@ object ShuffleBlockFetcherIterator {
    * A request to fetch blocks from a remote BlockManager.
    * @param address remote BlockManager to fetch from.
   * @param blocks Sequence of the information for blocks to fetch from the same address.
+   * @param forMergedMetas true if this request is for fetching push-merged meta information;
+   *                       false if it is for regular shuffle blocks or shuffle chunks.
    */
-  case class FetchRequest(address: BlockManagerId, blocks: Seq[FetchBlockInfo]) {
+  case class FetchRequest(
+      address: BlockManagerId,
+      blocks: Seq[FetchBlockInfo],
+      forMergedMetas: Boolean = false) {
     val size = blocks.map(_.size).sum
   }
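
The new flag only changes how a request is routed (see the `send` change above); size
accounting is unchanged. A small illustrative sketch, where `address` and `infos` stand in
for a real BlockManagerId and a Seq[FetchBlockInfo]:

    val metaReq = FetchRequest(address, infos, forMergedMetas = true)
    val dataReq = FetchRequest(address, infos)  // default flag
    // metaReq is dispatched to sendFetchMergedStatusRequest, dataReq to sendRequest,
    // but both contribute request.blocks.size to numBlocksInFlightPerAddress.
    assert(metaReq.size == dataReq.size)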
 
@@ -1124,4 +1398,64 @@ object ShuffleBlockFetcherIterator {
    */
   private[storage]
  case class DeferFetchRequestResult(fetchRequest: FetchRequest) extends FetchResult
+
+  /**
+   * Result of an unsuccessful fetch of either of these:
+   * 1) Remote shuffle chunk.
+   * 2) Local push-merged block.
+   *
+   * Instead of treating this as a [[FailureFetchResult]], we fall back to fetching the
+   * original blocks.
+   *
+   * @param blockId block id
+   * @param address BlockManager that the push-merged block was attempted to be fetched from
+   * @param size size of the block, used to update bytesInFlight.
+   * @param isNetworkReqDone Is this the last network request for this host in this fetch
+   *                         request. Used to update reqsInFlight.
+   */
+  private[storage] case class FallbackOnPushMergedFailureResult(
+      blockId: BlockId,
+      address: BlockManagerId,
+      size: Long,
+      isNetworkReqDone: Boolean) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param blockSize size of each push-merged block.
+   * @param bitmaps bitmaps for every chunk.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      blockSize: Long,
+      bitmaps: Array[RoaringBitmap],
+      address: BlockManagerId) extends FetchResult
+
+  /**
+   * Result of a failure while fetching the meta information for a remote push-merged block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param address BlockManager that the meta was fetched from.
+   */
+  private[storage] case class PushMergedRemoteMetaFailedFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      address: BlockManagerId) extends FetchResult
+
+  /**
+   * Result of a successful fetch of meta information for a push-merged-local block.
+   *
+   * @param shuffleId shuffle id.
+   * @param reduceId reduce id.
+   * @param bitmaps bitmaps for every chunk.
+   * @param localDirs local directories where the push-merged shuffle files are stored.
+   */
+  private[storage] case class PushMergedLocalMetaFetchResult(
+      shuffleId: Int,
+      reduceId: Int,
+      bitmaps: Array[RoaringBitmap],
+      localDirs: Array[String]) extends FetchResult
 }
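
Taken together, the iterator's result loop can dispatch on these new result types roughly as
follows. This is a non-authoritative sketch assuming the definitions above are in scope; the
comments describe intent, not the exact iterator code:

    def handlePushBasedResult(result: FetchResult): Unit = result match {
      case PushMergedLocalMetaFetchResult(shuffleId, reduceId, bitmaps, localDirs) =>
        () // read the local shuffle chunks described by `bitmaps` from `localDirs`
      case PushMergedRemoteMetaFetchResult(shuffleId, reduceId, blockSize, bitmaps, address) =>
        () // enqueue shuffle-chunk FetchRequests against `address`
      case PushMergedRemoteMetaFailedFetchResult(shuffleId, reduceId, address) =>
        () // fall back to fetching the original blocks of this reduce partition
      case FallbackOnPushMergedFailureResult(blockId, address, size, isNetworkReqDone) =>
        () // undo in-flight accounting, then fall back to the original blocks
      case _ =>
        () // regular results (e.g. SuccessFetchResult) keep their existing handling
    }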
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index b3138d7..e8c3c2d 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -210,4 +210,29 @@ class BlockIdSuite extends SparkFunSuite {
     assert(!id.isShuffle)
     assertSame(id, BlockId(id.toString))
   }
+
+  test("merged shuffle id") {
+    val id = ShuffleBlockId(1, -1, 0)
+    assertSame(id, ShuffleBlockId(1, -1, 0))
+    assertDifferent(id, ShuffleBlockId(1, 1, 1))
+    assert(id.name === "shuffle_1_-1_0")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.mapId === -1)
+    assert(id.reduceId === 0)
+    assertSame(id, BlockId(id.toString))
+  }
+
+  test("shuffle chunk") {
+    val id = ShuffleBlockChunkId(1, 1, 0)
+    assertSame(id, ShuffleBlockChunkId(1, 1, 0))
+    assertDifferent(id, ShuffleBlockChunkId(1, 1, 1))
+    assert(id.name === "shuffleChunk_1_1_0")
+    assert(id.asRDDId === None)
+    assert(id.shuffleId === 1)
+    assert(id.reduceId === 1)
+    assert(id.chunkId === 0)
+    assertSame(id, BlockId(id.toString))
+  }
+
 }
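
As a standalone illustration of the naming scheme the two tests above pin down (assuming the
org.apache.spark.storage BlockId types and MapOutputTracker.SHUFFLE_PUSH_MAP_ID are in scope):

    val merged = ShuffleBlockId(1, SHUFFLE_PUSH_MAP_ID, 0)  // mapId -1 marks a push-merged block
    assert(merged.name == "shuffle_1_-1_0")
    val chunk = ShuffleBlockChunkId(1, 1, 0)  // (shuffleId, reduceId, chunkId)
    assert(chunk.name == "shuffleChunk_1_1_0")
    assert(BlockId(chunk.name) == chunk)  // names parse back to the same id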
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 9b63347..a5143cd 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -22,32 +22,41 @@ import java.nio.ByteBuffer
 import java.util.UUID
 import java.util.concurrent.{CompletableFuture, Semaphore}
 
+import scala.collection.mutable
 import scala.concurrent.ExecutionContext.Implicits.global
 import scala.concurrent.Future
 
 import io.netty.util.internal.OutOfDirectMemoryError
 import org.apache.log4j.Level
 import org.mockito.ArgumentMatchers.{any, eq => meq}
-import org.mockito.Mockito.{mock, times, verify, when}
+import org.mockito.Mockito.{doThrow, mock, times, verify, when}
+import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer
+import org.roaringbitmap.RoaringBitmap
 import org.scalatest.PrivateMethodTester
 
-import org.apache.spark.{SparkFunSuite, TaskContext}
+import org.apache.spark.{MapOutputTracker, SparkFunSuite, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient, MergedBlockMeta, MergedBlocksMetaListener}
 import org.apache.spark.network.util.LimitedInputStream
 import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter}
-import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
+import org.apache.spark.storage.BlockManagerId.SHUFFLE_MERGER_IDENTIFIER
+import org.apache.spark.storage.ShuffleBlockFetcherIterator._
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester {
 
   private var transfer: BlockTransferService = _
+  private var mapOutputTracker: MapOutputTracker = _
 
   override def beforeEach(): Unit = {
     transfer = mock(classOf[BlockTransferService])
+    mapOutputTracker = mock(classOf[MapOutputTracker])
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(Seq.empty.iterator)
   }
 
   private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*)
@@ -178,6 +187,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       tContext,
       transfer,
       blockManager.getOrElse(createMockBlockManager()),
+      mapOutputTracker,
       blocksByAddress.toIterator,
       (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in),
       maxBytesInFlight,
@@ -1017,4 +1027,670 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     }
     assert(e.getMessage.contains("fetch failed after 10 retries due to Netty OOM"))
   }
+
+  /**
+   * Prepares the transfer to trigger success for all the blocks present in blockChunks.
+   * It will trigger a failure for any block that is not part of blockChunks.
+   */
+  private def configureMockTransferForPushShuffle(
+      blocksSem: Semaphore,
+      blockChunks: Map[BlockId, ManagedBuffer]): Unit = {
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val regularBlocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
+        val blockFetchListener =
+          invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
+        Future {
+          regularBlocks.foreach(blockId => {
+            val shuffleBlock = BlockId(blockId)
+            if (!blockChunks.contains(shuffleBlock)) {
+              // Force a failure.
+              blockFetchListener.onBlockFetchFailure(
+                blockId, new RuntimeException("failed to fetch"))
+            } else {
+              blockFetchListener.onBlockFetchSuccess(blockId, blockChunks(shuffleBlock))
+            }
+            blocksSem.release()
+          })
+        }
+      })
+  }
+
+  test("SPARK-32922: fetch remote push-merged block meta") {
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+      (BlockManagerId("remote-client-1", "remote-host-1", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1))
+    )
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(2)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array(new RoaringBitmap, new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        Future {
+          val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+          val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+          logInfo(s"acquiring semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.acquire()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    blocksSem.acquire(2)
+    // The first block should be returned without an exception
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    metaSem.release()
+    val (id3, _) = iterator.next()
+    blocksSem.acquire()
+    assert(id3 === ShuffleBlockChunkId(0, 2, 0))
+    val (id4, _) = iterator.next()
+    blocksSem.acquire()
+    assert(id4 === ShuffleBlockChunkId(0, 2, 1))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failed to fetch remote push-merged block meta so fall back to " +
+    "original blocks") {
+    val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)),
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)))
+
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((remoteBmId, toBlockList(
+        Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    blocksSem.acquire(2)
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: iterator has just 1 push-merged block and fails to fetch the meta") {
+    val remoteBmId = BlockManagerId("remote-client", "remote-host-1", 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "push-merged-host", 1),
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((remoteBmId, toBlockList(
+        Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onFailure(shuffleId, reduceId, new RuntimeException("forced error"))
+        }
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  private def createMockPushMergedBlockMeta(
+      numChunks: Int,
+      bitmaps: Array[RoaringBitmap]): MergedBlockMeta = {
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(numChunks)
+    if (bitmaps == null) {
+      when(pushMergedBlockMeta.readChunkBitmaps()).thenThrow(new IOException("forced error"))
+    } else {
+      when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(bitmaps)
+    }
+    doReturn(createMockManagedBuffer()).when(pushMergedBlockMeta).getChunksBitmapBuffer
+    pushMergedBlockMeta
+  }
+
+  private def prepareForFallbackToLocalBlocks(
+      blockManager: BlockManager,
+      localDirsMap: Map[String, Array[String]],
+      failReadingLocalChunksMeta: Boolean = false):
+    Map[BlockManagerId, Seq[(BlockId, Long, Int)]] = {
+    val localHost = "test-local-host"
+    val localBmId = BlockManagerId("test-client", localHost, 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    initHostLocalDirManager(blockManager, localDirsMap)
+
+    val blockBuffers = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer()
+    )
+
+    doReturn(blockBuffers(ShuffleBlockId(0, 0, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 0, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 1, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 1, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 2, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 2, 2))
+    doReturn(blockBuffers(ShuffleBlockId(0, 3, 2))).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 3, 2))
+
+    val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
+    doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), dirsForMergedData)
+
+    // Get a valid chunk meta for this test
+    val bitmaps = Array(new RoaringBitmap)
+    bitmaps(0).add(1) // chunk 0 has mapId 1
+    bitmaps(0).add(2) // chunk 0 has mapId 2
+    val pushMergedBlockMeta: MergedBlockMeta = if (failReadingLocalChunksMeta) {
+      createMockPushMergedBlockMeta(bitmaps.length, null)
+    } else {
+      createMockPushMergedBlockMeta(bitmaps.length, bitmaps)
+    }
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+      dirsForMergedData)).thenReturn(pushMergedBlockMeta)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2, bitmaps(0)))
+      .thenReturn(Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 1, 2), ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator)
+    val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
+    Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
+      (pushMergedBmId, toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+  }
+
+  private def verifyLocalBlocksFromFallback(iterator: ShuffleBlockFetcherIterator): Unit = {
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local meta should fall back to fetch " +
+    "original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to read chunkBitmaps of push-merged-local meta should " +
+    "fall back to original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs),
+      failReadingLocalChunksMeta = true)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local data should fall back to fetch " +
+    "original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local meta of a single merged block " +
+    "should not drop the fetch of other push-merged-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    val localHost = "test-local-host"
+    val localBmId = BlockManagerId("test-client", localHost, 1)
+    val pushMergedBmId = BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, localHost, 1)
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (localBmId, toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 3, 2)), 1L, 1)),
+      (pushMergedBmId, toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+        ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    // Create a valid chunk meta for partition 3
+    val bitmaps = Array(new RoaringBitmap)
+    bitmaps(0).add(1) // chunk 0 has mapId 1
+    doReturn(createMockPushMergedBlockMeta(bitmaps.length, bitmaps)).when(blockManager)
+      .getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+    // Return valid buffer for chunk in partition 3
+    doReturn(Seq(createMockManagedBuffer(2))).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 3), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 1, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockChunkId(0, 3, 0))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch push-merged block as well as fallback block should " +
+    "throw a FetchFailedException") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("testPath1", "testPath2")
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+    val localDirsMap = Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs)
+    initHostLocalDirManager(blockManager, localDirsMap)
+
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 0, 2))
+    // Force the read of the original block (0, 1, 2) to fail; this will throw a
+    // FetchFailedException.
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalBlockData(ShuffleBlockId(0, 1, 2))
+
+    val dirsForMergedData = localDirsMap(SHUFFLE_MERGER_IDENTIFIER)
+    // Since bitmaps are null, reading the push-merged block meta will fail, causing the
+    // fallback to initiate.
+    val pushMergedBlockMeta: MergedBlockMeta = createMockPushMergedBlockMeta(2, null)
+    when(blockManager.getLocalMergedBlockMeta(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2),
+      dirsForMergedData)).thenReturn(pushMergedBlockMeta)
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((localBmId,
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+
+    val blocksByAddress = Map[BlockManagerId, Seq[(BlockId, Long, Int)]](
+      (BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-local-host", 1), toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    // The 1st call to iterator.next() returns the original shuffle block (0, 0, 2).
+    assert(iterator.next()._1 === ShuffleBlockId(0, 0, 2))
+    // The 2nd call to iterator.next() throws a FetchFailedException.
+    intercept[FetchFailedException] { iterator.next() }
+  }
+
+  test("SPARK-32922: failure to fetch push-merged-local block should fall back to fetch " +
+    "original shuffle blocks which contain host-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    // BlockManagerId from another executor on the same host
+    val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1)
+    val hostLocalDirs = Map("test-client-1" -> Array("local-dir"),
+      SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"))
+    val blocksByAddress = prepareForFallbackToLocalBlocks(blockManager, hostLocalDirs)
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+    // host local read for a shuffle block
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 2, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1)),
+          (hostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fetch host-local blocks with a push-merged block during " +
+    "initialization and fall back to host-local blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    // BlockManagerId of another executor on the same host
+    val hostLocalBmId = BlockManagerId("test-client-1", "test-local-host", 1)
+    val originalHostLocalBmId = BlockManagerId("test-client-2", "test-local-host", 1)
+    val hostLocalDirs = Map(hostLocalBmId.executorId -> Array("local-dir"),
+      SHUFFLE_MERGER_IDENTIFIER -> Array("local-dir"),
+      originalHostLocalBmId.executorId -> Array("local-dir"))
+
+    val hostLocalBlocks = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (hostLocalBmId, Seq((ShuffleBlockId(0, 5, 2), 1L, 1))))
+
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, hostLocalDirs) ++ hostLocalBlocks
+
+    doThrow(new RuntimeException("Forced error")).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), Array("local-dir"))
+    // Host-local read for this original shuffle block.
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 1, 2), Array("local-dir"))
+    doReturn(createMockManagedBuffer()).when(blockManager)
+      .getHostLocalShuffleData(ShuffleBlockId(0, 5, 2), Array("local-dir"))
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenAnswer(
+      (_: InvocationOnMock) => {
+        Seq((blockManager.blockManagerId, toBlockList(Seq(ShuffleBlockId(0, 2, 2)), 1L, 1)),
+          (originalHostLocalBmId, toBlockList(Seq(ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator
+      })
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    val (id1, _) = iterator.next()
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 3, 2))
+    val (id3, _) = iterator.next()
+    assert(id3 === ShuffleBlockId(0, 5, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 2, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure while reading local shuffle chunks should fall back to " +
+    "original shuffle blocks") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    // This will throw an IOException when an input stream is created from the ManagedBuffer.
+    doReturn(Seq(
+      new FileSegmentManagedBuffer(null, new File("non-existent"), 0, 100)
+    )).when(blockManager).getLocalMergedBlockData(
+      ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fall back to original shuffle block when a push-merged shuffle " +
+    "chunk is corrupt") {
+    val blockManager = mock(classOf[BlockManager])
+    val localDirs = Array("local-dir")
+    val blocksByAddress = prepareForFallbackToLocalBlocks(
+      blockManager, Map(SHUFFLE_MERGER_IDENTIFIER -> localDirs))
+    val corruptBuffer = createMockManagedBuffer(2)
+    doReturn(Seq(corruptBuffer)).when(blockManager)
+      .getLocalMergedBlockData(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2), localDirs)
+    val corruptStream = mock(classOf[InputStream])
+    when(corruptStream.read(any(), any(), any())).thenThrow(new IOException("corrupt"))
+    doReturn(corruptStream).when(corruptBuffer).createInputStream()
+    val iterator = createShuffleBlockIteratorWithDefaults(blocksByAddress,
+      blockManager = Some(blockManager), streamWrapperLimitSize = Some(100))
+    verifyLocalBlocksFromFallback(iterator)
+  }
+
+  test("SPARK-32922: fall back to original blocks when a remote shuffle chunk fetch fails") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val bitmaps = Array(new RoaringBitmap, new RoaringBitmap)
+    bitmaps(1).add(3)
+    bitmaps(1).add(4)
+    bitmaps(1).add(5)
+    val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, bitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (BlockManagerId("remote-client", "remote-host-2", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+          ShuffleBlockId(0, 5, 2)), 4L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "remote-client-1", 1) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 12L, SHUFFLE_PUSH_MAP_ID)))
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(3)
+    assert(id3 === ShuffleBlockId(0, 3, 2))
+    val (id4, _) = iterator.next()
+    assert(id4 === ShuffleBlockId(0, 4, 2))
+    val (id5, _) = iterator.next()
+    assert(id5 === ShuffleBlockId(0, 5, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: fall back to original blocks when parsing remote merged block " +
+    "meta fails") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockId(0, 0, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 2) -> createMockManagedBuffer()
+    )
+    when(mapOutputTracker.getMapSizesForMergeResult(0, 2)).thenReturn(
+      Seq((BlockManagerId("remote-client-1", "remote-host-1", 1),
+        toBlockList(Seq(ShuffleBlockId(0, 0, 2), ShuffleBlockId(0, 1, 2)), 1L, 1))).iterator)
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val pushMergedBlockMeta = createMockPushMergedBlockMeta(2, null)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteMergedBlockMgrId = BlockManagerId(
+      SHUFFLE_MERGER_IDENTIFIER, "remote-host-2", 1)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteMergedBlockMgrId -> toBlockList(
+        Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 2L, SHUFFLE_PUSH_MAP_ID)))
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockId(0, 0, 2))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockId(0, 1, 2))
+    assert(!iterator.hasNext)
+  }
+
+  test("SPARK-32922: failure to fetch a remote shuffle chunk initiates the fallback of " +
+    "pending shuffle chunks immediately") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      // ShuffleBlockChunkId(0, 2, 1) will cause a failure as it is not in blockChunks.
+      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 3) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(4)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array.fill[RoaringBitmap](4)(new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          logInfo(s"releasing meta semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.release()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+        ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 16L, SHUFFLE_PUSH_MAP_ID)),
+      maxBytesInFlight = 4)
+    metaSem.acquire(1)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val regularBlocks = new mutable.HashSet[BlockId]()
+    val (id2, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id2)
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id3)
+    val (id4, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id4)
+    val (id5, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id5)
+    assert(!iterator.hasNext)
+    assert(regularBlocks === Set(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+      ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)))
+  }
+
+  test("SPARK-32922: failure to fetch a remote shuffle chunk immediately initiates the " +
+    "fallback of pending shuffle chunks, including those that were deferred") {
+    val blockChunks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockChunkId(0, 2, 0) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 1) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 2) -> createMockManagedBuffer(),
+      // ShuffleBlockChunkId(0, 2, 3) will cause a failure as it is not in blockChunks.
+      ShuffleBlockChunkId(0, 2, 4) -> createMockManagedBuffer(),
+      ShuffleBlockChunkId(0, 2, 5) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 3, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 5, 2) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 6, 2) -> createMockManagedBuffer()
+    )
+    val blocksSem = new Semaphore(0)
+    configureMockTransferForPushShuffle(blocksSem, blockChunks)
+    val metaSem = new Semaphore(0)
+    val pushMergedBlockMeta = mock(classOf[MergedBlockMeta])
+    when(pushMergedBlockMeta.getNumChunks).thenReturn(6)
+    when(pushMergedBlockMeta.getChunksBitmapBuffer).thenReturn(mock(classOf[ManagedBuffer]))
+    val roaringBitmaps = Array.fill[RoaringBitmap](6)(new RoaringBitmap)
+    when(pushMergedBlockMeta.readChunkBitmaps()).thenReturn(roaringBitmaps)
+    when(transfer.getMergedBlockMeta(any(), any(), any(), any(), any()))
+      .thenAnswer((invocation: InvocationOnMock) => {
+        val metaListener = invocation.getArguments()(4).asInstanceOf[MergedBlocksMetaListener]
+        val shuffleId = invocation.getArguments()(2).asInstanceOf[Int]
+        val reduceId = invocation.getArguments()(3).asInstanceOf[Int]
+        Future {
+          logInfo(s"releasing meta semaphore for host = ${invocation.getArguments()(0)}, " +
+            s"port = ${invocation.getArguments()(1)}, " +
+            s"shuffleId = $shuffleId, reduceId = $reduceId")
+          metaSem.release()
+          metaListener.onSuccess(shuffleId, reduceId, pushMergedBlockMeta)
+        }
+      })
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
+    val fallbackBlocksByAddr = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId, toBlockList(Seq(ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+        ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)), 1L, 1)))
+    when(mapOutputTracker.getMapSizesForMergeResult(any(), any(), any()))
+      .thenReturn(fallbackBlocksByAddr.iterator)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      BlockManagerId(SHUFFLE_MERGER_IDENTIFIER, "test-client-1", 2) ->
+        toBlockList(Seq(ShuffleBlockId(0, SHUFFLE_PUSH_MAP_ID, 2)), 24L, SHUFFLE_PUSH_MAP_ID)),
+      maxBytesInFlight = 8, maxBlocksInFlightPerAddress = 1)
+    metaSem.acquire(1)
+    val (id1, _) = iterator.next()
+    blocksSem.acquire(2)
+    assert(id1 === ShuffleBlockChunkId(0, 2, 0))
+    val (id2, _) = iterator.next()
+    assert(id2 === ShuffleBlockChunkId(0, 2, 1))
+    val (id3, _) = iterator.next()
+    blocksSem.acquire(1)
+    assert(id3 === ShuffleBlockChunkId(0, 2, 2))
+    val regularBlocks = new mutable.HashSet[BlockId]()
+    val (id4, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id4)
+    val (id5, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id5)
+    val (id6, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id6)
+    val (id7, _) = iterator.next()
+    blocksSem.acquire(1)
+    regularBlocks.add(id7)
+    assert(!iterator.hasNext)
+    assert(regularBlocks === Set[ShuffleBlockId](ShuffleBlockId(0, 3, 2), ShuffleBlockId(0, 4, 2),
+      ShuffleBlockId(0, 5, 2), ShuffleBlockId(0, 6, 2)))
+  }
+
 }
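
A note on the structure of the tests above: the mocked transfer delivers results
asynchronously on a Future, so the tests gate each assertion on Semaphores. The pattern,
reduced to its essentials (illustrative only):

    import java.util.concurrent.Semaphore
    import scala.concurrent.Future
    import scala.concurrent.ExecutionContext.Implicits.global

    val blocksSem = new Semaphore(0)
    Future {
      // the mocked BlockFetchingListener callback runs here
      blocksSem.release()  // signal that one block result has been queued
    }
    blocksSem.acquire()    // the test thread blocks until the result is available
    // iterator.next() can now be asserted deterministically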

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
