Copilot commented on code in PR #12370:
URL: https://github.com/apache/gluten/pull/12370#discussion_r3505163705


##########
cpp/velox/utils/CachedBatchQueue.h:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <glog/logging.h>
+#include "velox/common/base/Exceptions.h"
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+namespace gluten {
+
+template <typename T>
+class CachedBatchQueue {
+ public:
+  explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {}
+
+  void put(std::shared_ptr<T> batch) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is 
called");
+
+    const auto batchSize = batch->numBytes();
+    VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity");
+
+    notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; });
+
+    queue_.push(std::move(batch));
+    totalSize_ += batchSize;
+
+    notEmpty_.notify_one();
+  }
+
+  std::shared_ptr<T> get() {
+    std::unique_lock<std::mutex> lock(mtx_);
+    notEmpty_.wait(lock, [&]() { return noMoreBatches_ || !queue_.empty(); });
+
+    if (queue_.empty()) {
+      return nullptr;
+    }
+    auto batch = std::move(queue_.front());
+    LOG(INFO) << "Trying to get from cached buffer queue. Queue length: " << 
queue_.size()
+              << ", total size in queue: " << totalSize_ << ", current batch 
size: " << batch->numBytes() << std::endl;
+

Review Comment:
   `CachedBatchQueue::get()` logs at INFO for every dequeued batch (and also 
   flushes via `std::endl`). This is likely to be extremely noisy and expensive in 
   the shuffle fast path. Consider removing this log or downgrading it to `VLOG` 
   without `std::endl`.



##########
cpp/velox/utils/CachedBatchQueue.h:
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <glog/logging.h>
+#include "velox/common/base/Exceptions.h"
+
+#include <condition_variable>
+#include <memory>
+#include <mutex>
+#include <queue>
+
+namespace gluten {
+
+template <typename T>
+class CachedBatchQueue {
+ public:
+  explicit CachedBatchQueue(const int64_t capacity) : capacity_(capacity) {}
+
+  void put(std::shared_ptr<T> batch) {
+    std::unique_lock<std::mutex> lock(mtx_);
+    VELOX_CHECK(!noMoreBatches_, "Cannot put batch after noMoreBatches() is 
called");
+
+    const auto batchSize = batch->numBytes();
+    VELOX_CHECK_LE(batchSize, capacity_, "Batch size exceeds queue capacity");
+
+    notFull_.wait(lock, [&]() { return totalSize_ + batchSize <= capacity_; });
+

Review Comment:
   `CachedBatchQueue::put()` can deadlock if `noMoreBatches()` is called while a 
   producer is blocked in `notFull_.wait(...)`: `noMoreBatches()` notifies 
   `notFull_`, but the wait predicate never becomes true, so the producer can wait 
   forever. Include `noMoreBatches_` in the wait predicate and re-check it after 
   waking so that `put()` fails fast.



##########
gluten-arrow/src/main/scala/org/apache/spark/storage/SparkInputStreamUtil.scala:
##########
@@ -19,7 +19,7 @@ package org.apache.spark.storage
 import java.io.InputStream
 
 object SparkInputStreamUtil {
-  def unwrapBufferReleasingInputStream(in: BufferReleasingInputStream): 
InputStream = {
+  def unwrapBufferReleasingInputStream(in: GlutenBufferReleasingInputStream): 
InputStream = {
     in.delegate
   }

Review Comment:
   `unwrapBufferReleasingInputStream` takes `GlutenBufferReleasingInputStream`, 
which forces other modules (like Java `JniByteInputStreams`) to reference that 
shim-internal type directly. Since `GlutenBufferReleasingInputStream` is 
currently declared `private` in the shims, this makes cross-module compilation 
fragile. Consider changing this utility to accept a plain `InputStream` and do 
the unwrapping internally (pattern match) to avoid leaking the shim wrapper 
type.



##########
cpp/velox/compute/VeloxBackend.cc:
##########
@@ -294,12 +294,19 @@ void VeloxBackend::init(
   registerShuffleDictionaryWriterFactory([](MemoryManager* memoryManager, 
arrow::util::Codec* codec) {
     return std::make_unique<ArrowShuffleDictionaryWriter>(memoryManager, 
codec);
   });
+
+  readerThreadPool_ = std::make_unique<ReaderThreadPool>(
+      backendConf_->get<int32_t>(kShuffleReaderThreads, 
std::thread::hardware_concurrency()));

Review Comment:
   `std::thread::hardware_concurrency()` is permitted to return 0, and the 
config could also be set to 0. Creating `ReaderThreadPool(0)` would result in 
zero reader tasks and the GPU shuffle reader would block forever on 
`CachedBatchQueue::get()` (since `noMoreBatches()` is only called by reader 
threads). Clamp the thread count to at least 1 (and ideally validate the conf 
value).



##########
shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala:
##########
@@ -0,0 +1,1506 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.apache.spark.{MapOutputTracker, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.errors.SparkCoreErrors
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.util.{TaskCompletionListener, Utils}
+
+import io.netty.util.internal.OutOfDirectMemoryError
+import org.apache.commons.io.IOUtils
+
+import javax.annotation.concurrent.GuardedBy
+
+import java.io.{InputStream, IOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.zip.CheckedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.util.{Failure, Success}
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from 
the local block
+ * manager. For remote blocks, it fetches them using the provided 
BlockTransferService.
+ *
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can 
handle blocks in a
+ * pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed 
maxBytesInFlight to avoid
+ * using too much memory.
+ *
+ * @param context
+ *   [[TaskContext]], used for metrics update
+ * @param shuffleClient
+ *   [[BlockStoreClient]] for fetching remote blocks
+ * @param blockManager
+ *   [[BlockManager]] for reading local blocks
+ * @param blocksByAddress
+ *   list of blocks to fetch grouped by the [[BlockManagerId]]. For each block 
we also require two
+ *   info: 1. the size (in bytes as a long field) in order to throttle the 
memory usage; 2. the
+ *   mapIndex for this block, which indicate the index in the map stage. Note 
that zero-sized blocks
+ *   are already excluded, which happened in
+ *   [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker
+ *   [[MapOutputTracker]] for falling back to fetching the original blocks if 
we fail to fetch
+ *   shuffle chunks when push based shuffle is enabled.
+ * @param streamWrapper
+ *   A function to wrap the returned input stream.
+ * @param maxBytesInFlight
+ *   max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight
+ *   max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress
+ *   max number of shuffle blocks being fetched at any given point for a given 
remote host:port.
+ * @param maxReqSizeShuffleToMem
+ *   max size (in bytes) of a request that can be shuffled to memory.
+ * @param maxAttemptsOnNettyOOM
+ *   The max number of a block could retry due to Netty OOM before throwing 
the fetch failure.
+ * @param detectCorrupt
+ *   whether to detect any corruption in fetched blocks.
+ * @param checksumEnabled
+ *   whether the shuffle checksum is enabled. When enabled, Spark will try to 
diagnose the cause of
+ *   the block corruption.
+ * @param checksumAlgorithm
+ *   the checksum algorithm that is used when calculating the checksum value 
for the block data.
+ * @param shuffleMetrics
+ *   used to report shuffle metrics.
+ * @param doBatchFetch
+ *   fetch continuous shuffle blocks from same executor in batch if the server 
side supports.
+ */
+final class GlutenShuffleBlockFetcherIterator(
+    context: TaskContext,
+    shuffleClient: BlockStoreClient,
+    blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
+    maxBytesInFlight: Long,
+    maxReqsInFlight: Int,
+    maxBlocksInFlightPerAddress: Int,
+    val maxReqSizeShuffleToMem: Long,
+    maxAttemptsOnNettyOOM: Int,
+    detectCorrupt: Boolean,
+    detectCorruptUseExtraMemory: Boolean,
+    checksumEnabled: Boolean,
+    checksumAlgorithm: String,
+    shuffleMetrics: ShuffleReadMetricsReporter,
+    doBatchFetch: Boolean)
+  extends GlutenShuffleBlockFetcherIteratorBase
+  with DownloadFileManager
+  with Logging {
+
+  import ShuffleBlockFetcherIterator._
+
+  // Make remote requests at most maxBytesInFlight / 5 in length; the reason 
to keep them
+  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from 
up to 5
+  // nodes, rather than blocking on reading output from one node.
+  private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
+
+  /**
+   * Total number of blocks to fetch.
+   */
+  private[this] var numBlocksToFetch = 0
+
+  /**
+   * The number of blocks processed by the caller. The iterator is exhausted 
when
+   * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+   */
+  private[this] var numBlocksProcessed = 0
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  /** Host local blocks to fetch, excluding zero-sized blocks. */
+  private[this] val hostLocalBlocks = 
scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
+
+  /**
+   * A queue to hold our results. This turns the asynchronous model provided by
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous 
model (iterator).
+   */
+  private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+  /**
+   * Current [[FetchResult]] being processed per thread. We track this so we 
can release the current
+   * buffer in case of a runtime exception when processing the current buffer. 
Using
+   * ConcurrentHashMap to support concurrent access from multiple threads 
while allowing cleanup
+   * from any thread.
+   */
+  private[this] val currentResults: ConcurrentHashMap[Long, 
SuccessFetchResult] =
+    new ConcurrentHashMap[Long, SuccessFetchResult]()
+
+  /**
+   * Queue of fetch requests to issue; we'll pull requests off this gradually 
to make sure that the
+   * number of bytes in flight is limited to maxBytesInFlight.
+   */
+  private[this] val fetchRequests = new Queue[FetchRequest]
+
+  /**
+   * Queue of fetch requests which could not be issued the first time they 
were dequeued. These
+   * requests are tried again when the fetch constraints are satisfied.
+   */
+  private[this] val deferredFetchRequests = new HashMap[BlockManagerId, 
Queue[FetchRequest]]()
+
+  /** Current bytes in flight from our requests */
+  private[this] var bytesInFlight = 0L
+
+  /** Current number of requests in flight */
+  private[this] var reqsInFlight = 0
+
+  /** Current number of blocks in flight per host:port */
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, 
Int]()
+
+  /**
+   * Count the retry times for the blocks due to Netty OOM. The block will 
stop retry if retry times
+   * has exceeded the [[maxAttemptsOnNettyOOM]].
+   */
+  private[this] val blockOOMRetryCounts = new HashMap[String, Int]
+
+  /**
+   * The blocks that can't be decompressed successfully, it is used to 
guarantee that we retry at
+   * most once for those corrupted blocks.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
+  /**
+   * Whether the iterator is still active. If isZombie is true, the callback 
interface will no
+   * longer place fetched blocks into [[results]].
+   */
+  @GuardedBy("this")
+  private[this] var isZombie = false
+
+  /**
+   * A set to store the files used for shuffling remote huge blocks. Files in 
this set will be
+   * deleted when cleanup. This is a layer of defensiveness against disk file 
leaks.
+   */
+  @GuardedBy("this")
+  private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
+
+  private[this] val onCompleteCallback = new 
GlutenShuffleFetchCompletionListener(this)
+
+  private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper(
+    this,
+    shuffleClient,
+    blockManager,
+    mapOutputTracker)
+
+  initialize()
+
+  // Decrements the buffer reference count.
+  // The currentResult is removed from the map to prevent releasing the buffer 
again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
+    val threadId = Thread.currentThread().getId
+    // Release the current buffer if necessary
+    val result = currentResults.remove(threadId)
+    if (result != null) {
+      result.buf.release()
+    }
+  }
+
+  override def createTempFile(transportConf: TransportConf): DownloadFile = {
+    // we never need to do any encryption or decryption here, regardless of 
configs, because that
+    // is handled at another layer in the code.  When encryption is enabled, 
shuffle data is written
+    // to disk encrypted in the first place, and sent over the network still 
encrypted.
+    new SimpleDownloadFile(
+      blockManager.diskBlockManager.createTempLocalBlock()._2,
+      transportConf)
+  }
+
+  override def registerTempFileToClean(file: DownloadFile): Boolean = 
synchronized {
+    if (isZombie) {
+      false
+    } else {
+      shuffleFilesSet += file
+      true
+    }
+  }
+
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been 
deserialized yet.
+   */
+  private[storage] def cleanup(): Unit = {
+    synchronized {
+      isZombie = true
+    }
+    releaseCurrentResultBuffer()
+    // Release buffers in the results queue
+    val iter = results.iterator()
+    while (iter.hasNext) {
+      val result = iter.next()
+      result match {
+        case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) =>
+          if (address != blockManager.blockManagerId) {
+            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+              shuffleMetrics.incLocalBlocksFetched(1)
+              shuffleMetrics.incLocalBytesRead(buf.size)
+            } else {
+              shuffleMetrics.incRemoteBytesRead(buf.size)
+              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+              }
+              shuffleMetrics.incRemoteBlocksFetched(1)
+            }
+          }
+          buf.release()
+        case _ =>
+      }
+    }
+    shuffleFilesSet.foreach {
+      file =>
+        if (!file.delete()) {
+          logWarning("Failed to cleanup shuffle fetch temp file " + 
file.path())
+        }
+    }
+  }
+
+  private[this] def sendRequest(req: FetchRequest): Unit = {
+    logDebug("Sending request for %d blocks (%s) from %s".format(
+      req.blocks.size,
+      Utils.bytesToString(req.size),
+      req.address.hostPort))
+    bytesInFlight += req.size
+    reqsInFlight += 1
+
+    // so we can look up the block info of each blockID
+    val infoMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, 
(size, mapIndex))
+    }.toMap
+    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
+    val deferredBlocks = new ArrayBuffer[String]()
+    val blockIds = req.blocks.map(_.blockId.toString)
+    val address = req.address
+
+    @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
+      if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
+        val blocks = deferredBlocks.map {
+          blockId =>
+            val (size, mapIndex) = infoMap(blockId)
+            FetchBlockInfo(BlockId(blockId), size, mapIndex)
+        }
+        results.put(DeferFetchRequestResult(FetchRequest(address, 
blocks.toSeq)))
+        deferredBlocks.clear()
+      }
+    }
+
+    val blockFetchingListener = new BlockFetchingListener {
+      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): 
Unit = {
+        // Only add the buffer to results queue if the iterator is not zombie,
+        // i.e. cleanup() has not been called yet.
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          if (!isZombie) {
+            // Increment the ref count because we need to pass this to a 
different thread.
+            // This needs to be released after use.
+            buf.retain()
+            remainingBlocks -= blockId
+            blockOOMRetryCounts.remove(blockId)
+            results.put(new SuccessFetchResult(
+              BlockId(blockId),
+              infoMap(blockId)._2,
+              address,
+              infoMap(blockId)._1,
+              buf,
+              remainingBlocks.isEmpty))
+            logDebug("remainingBlocks: " + remainingBlocks)
+            enqueueDeferredFetchRequestIfNecessary()
+          }
+        }
+        logTrace(s"Got remote block $blockId after 
${Utils.getUsedTimeNs(startTimeNs)}")
+      }
+
+      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          logError(s"Failed to get block(s) from 
${req.address.host}:${req.address.port}", e)
+          e match {
+            // SPARK-27991: Catch the Netty OOM and set the flag 
`isNettyOOMOnShuffle` (shared among
+            // tasks) to true as early as possible. The pending fetch requests 
won't be sent
+            // afterwards until the flag is set to false on:
+            // 1) the Netty free memory >= maxReqSizeShuffleToMem
+            //    - we'll check this whenever there's a fetch request succeeds.
+            // 2) the number of in-flight requests becomes 0
+            //    - we'll check this in `fetchUpToMaxBytes` whenever it's 
invoked.
+            // Although Netty memory is shared across multiple modules, e.g., 
shuffle, rpc, the flag
+            // only takes effect for the shuffle due to the implementation 
simplicity concern.
+            // And we'll buffer the consecutive block failures caused by the 
OOM error until there's
+            // no remaining blocks in the current request. Then, we'll package 
these blocks into
+            // a same fetch request for the retry later. In this way, instead 
of creating the fetch
+            // request per block, it would help reduce the concurrent 
connections and data loads
+            // pressure at remote server.
+            // Note that catching OOM and do something based on it is only a 
workaround for
+            // handling the Netty OOM issue, which is not the best way towards 
memory management.
+            // We can get rid of it when we find a way to manage Netty's 
memory precisely.
+            case _: OutOfDirectMemoryError
+                if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < 
maxAttemptsOnNettyOOM =>
+              if (!isZombie) {
+                val failureTimes = blockOOMRetryCounts(blockId)
+                blockOOMRetryCounts(blockId) += 1
+                if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
+                  // The fetcher can fail remaining blocks in batch for the 
same error. So we only
+                  // log the warning once to avoid flooding the logs.
+                  logInfo(s"Block $blockId has failed $failureTimes times " +
+                    s"due to Netty OOM, will retry")
+                }
+                remainingBlocks -= blockId
+                deferredBlocks += blockId
+                enqueueDeferredFetchRequestIfNecessary()
+              }
+
+            case _ =>
+              val block = BlockId(blockId)
+              if (block.isShuffleChunk) {
+                remainingBlocks -= blockId
+                results.put(FallbackOnPushMergedFailureResult(
+                  block,
+                  address,
+                  infoMap(blockId)._1,
+                  remainingBlocks.isEmpty))
+              } else {
+                results.put(FailureFetchResult(block, infoMap(blockId)._2, 
address, e))
+              }
+          }
+        }
+      }
+    }
+
+    // Fetch remote shuffle blocks to disk when the request is too large. 
Since the shuffle data is
+    // already encrypted and compressed over the wire(w.r.t. the related 
configs), we can just fetch
+    // the data and write it to file directly.
+    if (req.size > maxReqSizeShuffleToMem) {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        this)
+    } else {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        null)
+    }
+  }
+
+  /**
+   * This is called from initialize and also from the fallback which is 
triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): 
ArrayBuffer[FetchRequest] = {
+    logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+      + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: 
$maxBlocksInFlightPerAddress")
+
+    // Partition to local, host-local, push-merged-local, remote (includes 
push-merged-remote)
+    // blocks.Remote blocks are further split into FetchRequests of size at 
most maxBytesInFlight
+    // in order to limit the amount of data in flight
+    val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    var localBlockBytes = 0L
+    var hostLocalBlockBytes = 0L
+    var numHostLocalBlocks = 0
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
+
+    val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
+    for ((address, blockInfos) <- blocksByAddress) {
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (localExecIds.contains(address.executorId)) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        localBlocks ++= mergedBlockInfos.map(info => (info.blockId, 
info.mapIndex))
+        localBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else if (
+        blockManager.hostLocalDirManager.isDefined &&
+        address.host == blockManager.blockManagerId.host
+      ) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        val blocksForAddress =
+          mergedBlockInfos.map(info => (info.blockId, info.size, 
info.mapIndex))
+        hostLocalBlocksByExecutor += address -> blocksForAddress
+        numHostLocalBlocks += blocksForAddress.size
+        hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else {
+        val (_, timeCost) = Utils.timeTakenMs[Unit] {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+        logDebug(s"Collected remote fetch requests for $address in $timeCost 
ms")
+      }
+    }
+    val (remoteBlockBytes, numRemoteBlocks) =
+      collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 
+ y.blocks.size))
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(
+      blocksToFetchCurrentIteration == localBlocks.size +
+        numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't 
equal to the sum " +
+        s"of the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks $numHostLocalBlocks " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} 
" +
+        s"+ the number of remote blocks $numRemoteBlocks "
+    )
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local 
and " +
+      s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"push-merged-local and $numRemoteBlocks 
(${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+      .flatMap(infos => infos.map(info => (info._1, info._3)))
+    collectedRemoteRequests
+  }
+
+  private def createFetchRequest(
+      blocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      forMergedMetas: Boolean): FetchRequest = {
+    logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address 
"
+      + s"with ${blocks.size} blocks")
+    FetchRequest(address, blocks, forMergedMetas)
+  }
+
+  private def createFetchRequests(
+      curBlocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      isLast: Boolean,
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, 
enableBatchFetch)
+    numBlocksToFetch += mergedBlocks.size
+    val retBlocks = new ArrayBuffer[FetchBlockInfo]
+    if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
+      collectedRemoteRequests += createFetchRequest(mergedBlocks, address, 
forMergedMetas)
+    } else {
+      mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach {
+        blocks =>
+          if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
+            collectedRemoteRequests += createFetchRequest(blocks, address, 
forMergedMetas)
+          } else {
+            // The last group does not exceed `maxBlocksInFlightPerAddress`. 
Put it back
+            // to `curBlocks`.
+            retBlocks ++= blocks
+            numBlocksToFetch -= blocks.size
+          }
+      }
+    }
+    retBlocks
+  }
+
+  private def collectFetchRequests(
+      address: BlockManagerId,
+      blockInfos: Seq[(BlockId, Long, Int)],
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+    val iterator = blockInfos.iterator
+    var curRequestSize = 0L
+    var curBlocks = new ArrayBuffer[FetchBlockInfo]()
+
+    while (iterator.hasNext) {
+      val (blockId, size, mapIndex) = iterator.next()
+      curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+      curRequestSize += size
+      blockId match {
+        // Either all blocks are push-merged blocks, shuffle chunks, or 
original blocks.
+        // Based on these types, we decide to do batch fetch and create 
FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _, _) =>
+          if (
+            curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress
+          ) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleMergedBlockId(_, _, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false,
+              forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for 
merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || 
mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+      }
+    }
+    // Add in the final request
+    if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, forMergedMetas) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _, _) => (false, false)
+          case ShuffleMergedBlockId(_, _, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
+      createFetchRequests(
+        curBlocks.toSeq,
+        address,
+        isLast = true,
+        collectedRemoteRequests,
+        enableBatchFetch = enableBatchFetch,
+        forMergedMetas = forMergedMetas)
+    }
+  }
+
+  private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit 
= {
+    if (blockSize < 0) {
+      throw BlockException(blockId, "Negative block size " + size)

Review Comment:
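   `assertPositiveBlockSize` builds its error message from `size` rather than 
   from the `blockSize` parameter. No `size` is declared in this method, so this 
   either fails to compile or, if a `size` member is inherited (e.g. from 
   `Iterator`), compiles but reports an unrelated value. It should reference the 
   `blockSize` parameter when building the error message.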



##########
gluten-arrow/src/main/java/org/apache/gluten/vectorized/JniByteInputStreams.java:
##########
@@ -58,8 +58,8 @@ public static JniByteInputStream create(InputStream in) {
 
   static InputStream unwrapSparkInputStream(InputStream in) {
     InputStream unwrapped = in;
-    if (unwrapped instanceof BufferReleasingInputStream) {
-      final BufferReleasingInputStream brin = (BufferReleasingInputStream) 
unwrapped;
+    if (unwrapped instanceof GlutenBufferReleasingInputStream) {
+      final GlutenBufferReleasingInputStream brin = 
(GlutenBufferReleasingInputStream) unwrapped;
       unwrapped =
           
org.apache.spark.storage.SparkInputStreamUtil.unwrapBufferReleasingInputStream(brin);
     }

Review Comment:
   This `instanceof` check and cast to `GlutenBufferReleasingInputStream` 
introduce a compile-time dependency on a shim-internal type that is currently 
declared as a `private class` in the Scala shims. That makes compilation 
fragile and likely to fail. Prefer moving the unwrapping logic behind a public 
helper that accepts `InputStream` (so this code doesn’t need to reference the 
shim class), or make the wrapper type public across shims.



##########
shims/spark33/src/main/scala/org/apache/spark/storage/GlutenShuffleBlockFetcherIterator.scala:
##########
@@ -0,0 +1,1506 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.apache.spark.{MapOutputTracker, TaskContext}
+import org.apache.spark.MapOutputTracker.SHUFFLE_PUSH_MAP_ID
+import org.apache.spark.errors.SparkCoreErrors
+import org.apache.spark.internal.Logging
+import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
+import org.apache.spark.network.shuffle._
+import org.apache.spark.network.shuffle.checksum.{Cause, ShuffleChecksumHelper}
+import org.apache.spark.network.util.TransportConf
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter
+import org.apache.spark.util.{TaskCompletionListener, Utils}
+
+import io.netty.util.internal.OutOfDirectMemoryError
+import org.apache.commons.io.IOUtils
+
+import javax.annotation.concurrent.GuardedBy
+
+import java.io.{InputStream, IOException}
+import java.nio.channels.ClosedByInterruptException
+import java.util.concurrent.{ConcurrentHashMap, LinkedBlockingQueue, TimeUnit}
+import java.util.zip.CheckedInputStream
+
+import scala.collection.mutable
+import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Queue}
+import scala.util.{Failure, Success}
+
+/**
+ * An iterator that fetches multiple blocks. For local blocks, it fetches from 
the local block
+ * manager. For remote blocks, it fetches them using the provided 
BlockTransferService.
+ *
+ * This creates an iterator of (BlockID, InputStream) tuples so the caller can 
handle blocks in a
+ * pipelined fashion as they are received.
+ *
+ * The implementation throttles the remote fetches so they don't exceed 
maxBytesInFlight to avoid
+ * using too much memory.
+ *
+ * @param context
+ *   [[TaskContext]], used for metrics update
+ * @param shuffleClient
+ *   [[BlockStoreClient]] for fetching remote blocks
+ * @param blockManager
+ *   [[BlockManager]] for reading local blocks
+ * @param blocksByAddress
+ *   list of blocks to fetch grouped by the [[BlockManagerId]]. For each block 
we also require two
+ *   info: 1. the size (in bytes as a long field) in order to throttle the 
memory usage; 2. the
+ *   mapIndex for this block, which indicate the index in the map stage. Note 
that zero-sized blocks
+ *   are already excluded, which happened in
+ *   [[org.apache.spark.MapOutputTracker.convertMapStatuses]].
+ * @param mapOutputTracker
+ *   [[MapOutputTracker]] for falling back to fetching the original blocks if 
we fail to fetch
+ *   shuffle chunks when push based shuffle is enabled.
+ * @param streamWrapper
+ *   A function to wrap the returned input stream.
+ * @param maxBytesInFlight
+ *   max size (in bytes) of remote blocks to fetch at any given point.
+ * @param maxReqsInFlight
+ *   max number of remote requests to fetch blocks at any given point.
+ * @param maxBlocksInFlightPerAddress
+ *   max number of shuffle blocks being fetched at any given point for a given 
remote host:port.
+ * @param maxReqSizeShuffleToMem
+ *   max size (in bytes) of a request that can be shuffled to memory.
+ * @param maxAttemptsOnNettyOOM
+ *   The max number of a block could retry due to Netty OOM before throwing 
the fetch failure.
+ * @param detectCorrupt
+ *   whether to detect any corruption in fetched blocks.
+ * @param checksumEnabled
+ *   whether the shuffle checksum is enabled. When enabled, Spark will try to 
diagnose the cause of
+ *   the block corruption.
+ * @param checksumAlgorithm
+ *   the checksum algorithm that is used when calculating the checksum value 
for the block data.
+ * @param shuffleMetrics
+ *   used to report shuffle metrics.
+ * @param doBatchFetch
+ *   fetch continuous shuffle blocks from same executor in batch if the server 
side supports.
+ */
+final class GlutenShuffleBlockFetcherIterator(
+    context: TaskContext,
+    shuffleClient: BlockStoreClient,
+    blockManager: BlockManager,
+    mapOutputTracker: MapOutputTracker,
+    blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+    streamWrapper: (BlockId, InputStream) => InputStream,
+    maxBytesInFlight: Long,
+    maxReqsInFlight: Int,
+    maxBlocksInFlightPerAddress: Int,
+    val maxReqSizeShuffleToMem: Long,
+    maxAttemptsOnNettyOOM: Int,
+    detectCorrupt: Boolean,
+    detectCorruptUseExtraMemory: Boolean,
+    checksumEnabled: Boolean,
+    checksumAlgorithm: String,
+    shuffleMetrics: ShuffleReadMetricsReporter,
+    doBatchFetch: Boolean)
+  extends GlutenShuffleBlockFetcherIteratorBase
+  with DownloadFileManager
+  with Logging {
+
+  import ShuffleBlockFetcherIterator._
+
+  // Make remote requests at most maxBytesInFlight / 5 in length; the reason 
to keep them
+  // smaller than maxBytesInFlight is to allow multiple, parallel fetches from 
up to 5
+  // nodes, rather than blocking on reading output from one node.
+  private val targetRemoteRequestSize = math.max(maxBytesInFlight / 5, 1L)
+
+  /**
+   * Total number of blocks to fetch.
+   */
+  private[this] var numBlocksToFetch = 0
+
+  /**
+   * The number of blocks processed by the caller. The iterator is exhausted 
when
+   * [[numBlocksProcessed]] == [[numBlocksToFetch]].
+   */
+  private[this] var numBlocksProcessed = 0
+
+  private[this] val startTimeNs = System.nanoTime()
+
+  /** Host local blocks to fetch, excluding zero-sized blocks. */
+  private[this] val hostLocalBlocks = 
scala.collection.mutable.LinkedHashSet[(BlockId, Int)]()
+
+  /**
+   * A queue to hold our results. This turns the asynchronous model provided by
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous 
model (iterator).
+   */
+  private[this] val results = new LinkedBlockingQueue[FetchResult]
+
+  /**
+   * Current [[FetchResult]] being processed per thread. We track this so we 
can release the current
+   * buffer in case of a runtime exception when processing the current buffer. 
Using
+   * ConcurrentHashMap to support concurrent access from multiple threads 
while allowing cleanup
+   * from any thread.
+   */
+  private[this] val currentResults: ConcurrentHashMap[Long, 
SuccessFetchResult] =
+    new ConcurrentHashMap[Long, SuccessFetchResult]()
+
+  /**
+   * Queue of fetch requests to issue; we'll pull requests off this gradually 
to make sure that the
+   * number of bytes in flight is limited to maxBytesInFlight.
+   */
+  private[this] val fetchRequests = new Queue[FetchRequest]
+
+  /**
+   * Queue of fetch requests which could not be issued the first time they 
were dequeued. These
+   * requests are tried again when the fetch constraints are satisfied.
+   */
+  private[this] val deferredFetchRequests = new HashMap[BlockManagerId, 
Queue[FetchRequest]]()
+
+  /** Current bytes in flight from our requests */
+  private[this] var bytesInFlight = 0L
+
+  /** Current number of requests in flight */
+  private[this] var reqsInFlight = 0
+
+  /** Current number of blocks in flight per host:port */
+  private[this] val numBlocksInFlightPerAddress = new HashMap[BlockManagerId, 
Int]()
+
+  /**
+   * Count the retry times for the blocks due to Netty OOM. The block will 
stop retry if retry times
+   * has exceeded the [[maxAttemptsOnNettyOOM]].
+   */
+  private[this] val blockOOMRetryCounts = new HashMap[String, Int]
+
+  /**
+   * The blocks that can't be decompressed successfully, it is used to 
guarantee that we retry at
+   * most once for those corrupted blocks.
+   */
+  private[this] val corruptedBlocks = mutable.HashSet[BlockId]()
+
+  /**
+   * Whether the iterator is still active. If isZombie is true, the callback 
interface will no
+   * longer place fetched blocks into [[results]].
+   */
+  @GuardedBy("this")
+  private[this] var isZombie = false
+
+  /**
+   * A set to store the files used for shuffling remote huge blocks. Files in 
this set will be
+   * deleted when cleanup. This is a layer of defensiveness against disk file 
leaks.
+   */
+  @GuardedBy("this")
+  private[this] val shuffleFilesSet = mutable.HashSet[DownloadFile]()
+
+  private[this] val onCompleteCallback = new 
GlutenShuffleFetchCompletionListener(this)
+
+  private[this] val pushBasedFetchHelper = new GlutenPushBasedFetchHelper(
+    this,
+    shuffleClient,
+    blockManager,
+    mapOutputTracker)
+
+  initialize()
+
+  // Decrements the buffer reference count.
+  // The currentResult is removed from the map to prevent releasing the buffer 
again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
+    val threadId = Thread.currentThread().getId
+    // Release the current buffer if necessary
+    val result = currentResults.remove(threadId)
+    if (result != null) {
+      result.buf.release()
+    }
+  }
+
+  override def createTempFile(transportConf: TransportConf): DownloadFile = {
+    // we never need to do any encryption or decryption here, regardless of 
configs, because that
+    // is handled at another layer in the code.  When encryption is enabled, 
shuffle data is written
+    // to disk encrypted in the first place, and sent over the network still 
encrypted.
+    new SimpleDownloadFile(
+      blockManager.diskBlockManager.createTempLocalBlock()._2,
+      transportConf)
+  }
+
+  override def registerTempFileToClean(file: DownloadFile): Boolean = 
synchronized {
+    if (isZombie) {
+      false
+    } else {
+      shuffleFilesSet += file
+      true
+    }
+  }
+
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been 
deserialized yet.
+   */
+  private[storage] def cleanup(): Unit = {
+    synchronized {
+      isZombie = true
+    }
+    releaseCurrentResultBuffer()
+    // Release buffers in the results queue
+    val iter = results.iterator()
+    while (iter.hasNext) {
+      val result = iter.next()
+      result match {
+        case SuccessFetchResult(blockId, mapIndex, address, _, buf, _) =>
+          if (address != blockManager.blockManagerId) {
+            if (hostLocalBlocks.contains(blockId -> mapIndex)) {
+              shuffleMetrics.incLocalBlocksFetched(1)
+              shuffleMetrics.incLocalBytesRead(buf.size)
+            } else {
+              shuffleMetrics.incRemoteBytesRead(buf.size)
+              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+              }
+              shuffleMetrics.incRemoteBlocksFetched(1)
+            }
+          }
+          buf.release()
+        case _ =>
+      }
+    }
+    shuffleFilesSet.foreach {
+      file =>
+        if (!file.delete()) {
+          logWarning("Failed to cleanup shuffle fetch temp file " + 
file.path())
+        }
+    }
+  }
+
+  private[this] def sendRequest(req: FetchRequest): Unit = {
+    logDebug("Sending request for %d blocks (%s) from %s".format(
+      req.blocks.size,
+      Utils.bytesToString(req.size),
+      req.address.hostPort))
+    bytesInFlight += req.size
+    reqsInFlight += 1
+
+    // so we can look up the block info of each blockID
+    val infoMap = req.blocks.map {
+      case FetchBlockInfo(blockId, size, mapIndex) => (blockId.toString, 
(size, mapIndex))
+    }.toMap
+    val remainingBlocks = new HashSet[String]() ++= infoMap.keys
+    val deferredBlocks = new ArrayBuffer[String]()
+    val blockIds = req.blocks.map(_.blockId.toString)
+    val address = req.address
+
+    @inline def enqueueDeferredFetchRequestIfNecessary(): Unit = {
+      if (remainingBlocks.isEmpty && deferredBlocks.nonEmpty) {
+        val blocks = deferredBlocks.map {
+          blockId =>
+            val (size, mapIndex) = infoMap(blockId)
+            FetchBlockInfo(BlockId(blockId), size, mapIndex)
+        }
+        results.put(DeferFetchRequestResult(FetchRequest(address, 
blocks.toSeq)))
+        deferredBlocks.clear()
+      }
+    }
+
+    val blockFetchingListener = new BlockFetchingListener {
+      override def onBlockFetchSuccess(blockId: String, buf: ManagedBuffer): 
Unit = {
+        // Only add the buffer to results queue if the iterator is not zombie,
+        // i.e. cleanup() has not been called yet.
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          if (!isZombie) {
+            // Increment the ref count because we need to pass this to a 
different thread.
+            // This needs to be released after use.
+            buf.retain()
+            remainingBlocks -= blockId
+            blockOOMRetryCounts.remove(blockId)
+            results.put(new SuccessFetchResult(
+              BlockId(blockId),
+              infoMap(blockId)._2,
+              address,
+              infoMap(blockId)._1,
+              buf,
+              remainingBlocks.isEmpty))
+            logDebug("remainingBlocks: " + remainingBlocks)
+            enqueueDeferredFetchRequestIfNecessary()
+          }
+        }
+        logTrace(s"Got remote block $blockId after 
${Utils.getUsedTimeNs(startTimeNs)}")
+      }
+
+      override def onBlockFetchFailure(blockId: String, e: Throwable): Unit = {
+        GlutenShuffleBlockFetcherIterator.this.synchronized {
+          logError(s"Failed to get block(s) from 
${req.address.host}:${req.address.port}", e)
+          e match {
+            // SPARK-27991: Catch the Netty OOM and set the flag 
`isNettyOOMOnShuffle` (shared among
+            // tasks) to true as early as possible. The pending fetch requests 
won't be sent
+            // afterwards until the flag is set to false on:
+            // 1) the Netty free memory >= maxReqSizeShuffleToMem
+            //    - we'll check this whenever there's a fetch request succeeds.
+            // 2) the number of in-flight requests becomes 0
+            //    - we'll check this in `fetchUpToMaxBytes` whenever it's 
invoked.
+            // Although Netty memory is shared across multiple modules, e.g., 
shuffle, rpc, the flag
+            // only takes effect for the shuffle due to the implementation 
simplicity concern.
+            // And we'll buffer the consecutive block failures caused by the 
OOM error until there's
+            // no remaining blocks in the current request. Then, we'll package 
these blocks into
+            // a same fetch request for the retry later. In this way, instead 
of creating the fetch
+            // request per block, it would help reduce the concurrent 
connections and data loads
+            // pressure at remote server.
+            // Note that catching OOM and do something based on it is only a 
workaround for
+            // handling the Netty OOM issue, which is not the best way towards 
memory management.
+            // We can get rid of it when we find a way to manage Netty's 
memory precisely.
+            case _: OutOfDirectMemoryError
+                if blockOOMRetryCounts.getOrElseUpdate(blockId, 0) < 
maxAttemptsOnNettyOOM =>
+              if (!isZombie) {
+                val failureTimes = blockOOMRetryCounts(blockId)
+                blockOOMRetryCounts(blockId) += 1
+                if (isNettyOOMOnShuffle.compareAndSet(false, true)) {
+                  // The fetcher can fail remaining blocks in batch for the 
same error. So we only
+                  // log the warning once to avoid flooding the logs.
+                  logInfo(s"Block $blockId has failed $failureTimes times " +
+                    s"due to Netty OOM, will retry")
+                }
+                remainingBlocks -= blockId
+                deferredBlocks += blockId
+                enqueueDeferredFetchRequestIfNecessary()
+              }
+
+            case _ =>
+              val block = BlockId(blockId)
+              if (block.isShuffleChunk) {
+                remainingBlocks -= blockId
+                results.put(FallbackOnPushMergedFailureResult(
+                  block,
+                  address,
+                  infoMap(blockId)._1,
+                  remainingBlocks.isEmpty))
+              } else {
+                results.put(FailureFetchResult(block, infoMap(blockId)._2, 
address, e))
+              }
+          }
+        }
+      }
+    }
+
+    // Fetch remote shuffle blocks to disk when the request is too large. 
Since the shuffle data is
+    // already encrypted and compressed over the wire(w.r.t. the related 
configs), we can just fetch
+    // the data and write it to file directly.
+    if (req.size > maxReqSizeShuffleToMem) {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        this)
+    } else {
+      shuffleClient.fetchBlocks(
+        address.host,
+        address.port,
+        address.executorId,
+        blockIds.toArray,
+        blockFetchingListener,
+        null)
+    }
+  }
+
+  /**
+   * This is called from initialize and also from the fallback which is 
triggered from
+   * [[PushBasedFetchHelper]].
+   */
+  private[this] def partitionBlocksByFetchMode(
+      blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])],
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)],
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]],
+      pushMergedLocalBlocks: mutable.LinkedHashSet[BlockId]): 
ArrayBuffer[FetchRequest] = {
+    logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+      + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: 
$maxBlocksInFlightPerAddress")
+
+    // Partition to local, host-local, push-merged-local, remote (includes 
push-merged-remote)
+    // blocks.Remote blocks are further split into FetchRequests of size at 
most maxBytesInFlight
+    // in order to limit the amount of data in flight
+    val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
+    var localBlockBytes = 0L
+    var hostLocalBlockBytes = 0L
+    var numHostLocalBlocks = 0
+    var pushMergedLocalBlockBytes = 0L
+    val prevNumBlocksToFetch = numBlocksToFetch
+
+    val fallback = FallbackStorage.FALLBACK_BLOCK_MANAGER_ID.executorId
+    val localExecIds = Set(blockManager.blockManagerId.executorId, fallback)
+    for ((address, blockInfos) <- blocksByAddress) {
+      checkBlockSizes(blockInfos)
+      if (pushBasedFetchHelper.isPushMergedShuffleBlockAddress(address)) {
+        // These are push-merged blocks or shuffle chunks of these blocks.
+        if (address.host == blockManager.blockManagerId.host) {
+          numBlocksToFetch += blockInfos.size
+          pushMergedLocalBlocks ++= blockInfos.map(_._1)
+          pushMergedLocalBlockBytes += blockInfos.map(_._2).sum
+        } else {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+      } else if (localExecIds.contains(address.executorId)) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        localBlocks ++= mergedBlockInfos.map(info => (info.blockId, 
info.mapIndex))
+        localBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else if (
+        blockManager.hostLocalDirManager.isDefined &&
+        address.host == blockManager.blockManagerId.host
+      ) {
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, info._3)),
+          doBatchFetch)
+        numBlocksToFetch += mergedBlockInfos.size
+        val blocksForAddress =
+          mergedBlockInfos.map(info => (info.blockId, info.size, 
info.mapIndex))
+        hostLocalBlocksByExecutor += address -> blocksForAddress
+        numHostLocalBlocks += blocksForAddress.size
+        hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else {
+        val (_, timeCost) = Utils.timeTakenMs[Unit] {
+          collectFetchRequests(address, blockInfos, collectedRemoteRequests)
+        }
+        logDebug(s"Collected remote fetch requests for $address in $timeCost 
ms")
+      }
+    }
+    val (remoteBlockBytes, numRemoteBlocks) =
+      collectedRemoteRequests.foldLeft((0L, 0))((x, y) => (x._1 + y.size, x._2 
+ y.blocks.size))
+    val totalBytes = localBlockBytes + remoteBlockBytes + hostLocalBlockBytes +
+      pushMergedLocalBlockBytes
+    val blocksToFetchCurrentIteration = numBlocksToFetch - prevNumBlocksToFetch
+    assert(
+      blocksToFetchCurrentIteration == localBlocks.size +
+        numHostLocalBlocks + numRemoteBlocks + pushMergedLocalBlocks.size,
+      s"The number of non-empty blocks $blocksToFetchCurrentIteration doesn't 
equal to the sum " +
+        s"of the number of local blocks ${localBlocks.size} + " +
+        s"the number of host-local blocks $numHostLocalBlocks " +
+        s"the number of push-merged-local blocks ${pushMergedLocalBlocks.size} 
" +
+        s"+ the number of remote blocks $numRemoteBlocks "
+    )
+    logInfo(s"Getting $blocksToFetchCurrentIteration " +
+      s"(${Utils.bytesToString(totalBytes)}) non-empty blocks including " +
+      s"${localBlocks.size} (${Utils.bytesToString(localBlockBytes)}) local 
and " +
+      s"$numHostLocalBlocks (${Utils.bytesToString(hostLocalBlockBytes)}) " +
+      s"host-local and ${pushMergedLocalBlocks.size} " +
+      s"(${Utils.bytesToString(pushMergedLocalBlockBytes)}) " +
+      s"push-merged-local and $numRemoteBlocks 
(${Utils.bytesToString(remoteBlockBytes)}) " +
+      s"remote blocks")
+    this.hostLocalBlocks ++= hostLocalBlocksByExecutor.values
+      .flatMap(infos => infos.map(info => (info._1, info._3)))
+    collectedRemoteRequests
+  }
+
+  private def createFetchRequest(
+      blocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      forMergedMetas: Boolean): FetchRequest = {
+    logDebug(s"Creating fetch request of ${blocks.map(_.size).sum} at $address 
"
+      + s"with ${blocks.size} blocks")
+    FetchRequest(address, blocks, forMergedMetas)
+  }
+
+  private def createFetchRequests(
+      curBlocks: Seq[FetchBlockInfo],
+      address: BlockManagerId,
+      isLast: Boolean,
+      collectedRemoteRequests: ArrayBuffer[FetchRequest],
+      enableBatchFetch: Boolean,
+      forMergedMetas: Boolean = false): ArrayBuffer[FetchBlockInfo] = {
+    val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks, 
enableBatchFetch)
+    numBlocksToFetch += mergedBlocks.size
+    val retBlocks = new ArrayBuffer[FetchBlockInfo]
+    if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
+      collectedRemoteRequests += createFetchRequest(mergedBlocks, address, 
forMergedMetas)
+    } else {
+      mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach {
+        blocks =>
+          if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
+            collectedRemoteRequests += createFetchRequest(blocks, address, 
forMergedMetas)
+          } else {
+            // The last group does not exceed `maxBlocksInFlightPerAddress`. 
Put it back
+            // to `curBlocks`.
+            retBlocks ++= blocks
+            numBlocksToFetch -= blocks.size
+          }
+      }
+    }
+    retBlocks
+  }
+
+  private def collectFetchRequests(
+      address: BlockManagerId,
+      blockInfos: Seq[(BlockId, Long, Int)],
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+    val iterator = blockInfos.iterator
+    var curRequestSize = 0L
+    var curBlocks = new ArrayBuffer[FetchBlockInfo]()
+
+    while (iterator.hasNext) {
+      val (blockId, size, mapIndex) = iterator.next()
+      curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+      curRequestSize += size
+      blockId match {
+        // Either all blocks are push-merged blocks, shuffle chunks, or 
original blocks.
+        // Based on these types, we decide to do batch fetch and create 
FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _, _) =>
+          if (
+            curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress
+          ) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleMergedBlockId(_, _, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              enableBatchFetch = false,
+              forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for 
merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || 
mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(
+              curBlocks.toSeq,
+              address,
+              isLast = false,
+              collectedRemoteRequests,
+              doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+      }
+    }
+    // Add in the final request
+    if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, forMergedMetas) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _, _) => (false, false)
+          case ShuffleMergedBlockId(_, _, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
+      createFetchRequests(
+        curBlocks.toSeq,
+        address,
+        isLast = true,
+        collectedRemoteRequests,
+        enableBatchFetch = enableBatchFetch,
+        forMergedMetas = forMergedMetas)
+    }
+  }
+
+  private def assertPositiveBlockSize(blockId: BlockId, blockSize: Long): Unit 
= {
+    if (blockSize < 0) {
+      throw BlockException(blockId, "Negative block size " + size)
+    } else if (blockSize == 0) {
+      throw BlockException(blockId, "Zero-sized blocks should be excluded.")
+    }
+  }
+
+  private def checkBlockSizes(blockInfos: Seq[(BlockId, Long, Int)]): Unit = {
+    blockInfos.foreach { case (blockId, size, _) => 
assertPositiveBlockSize(blockId, size) }
+  }
+
+  /**
+   * Fetch the local blocks while we are fetching remote blocks. This is ok 
because
+   * `ManagedBuffer`'s memory is allocated lazily when we create the input 
stream, so all we track
+   * in-memory are the ManagedBuffer references themselves.
+   */
+  private[this] def fetchLocalBlocks(
+      localBlocks: mutable.LinkedHashSet[(BlockId, Int)]): Unit = {
+    logDebug(s"Start fetching local blocks: ${localBlocks.mkString(", ")}")
+    val iter = localBlocks.iterator
+    while (iter.hasNext) {
+      val (blockId, mapIndex) = iter.next()
+      try {
+        val buf = blockManager.getLocalBlockData(blockId)
+        shuffleMetrics.incLocalBlocksFetched(1)
+        shuffleMetrics.incLocalBytesRead(buf.size)
+        buf.retain()
+        results.put(new SuccessFetchResult(
+          blockId,
+          mapIndex,
+          blockManager.blockManagerId,
+          buf.size(),
+          buf,
+          false))
+      } catch {
+        // If we see an exception, stop immediately.
+        case e: Exception =>
+          e match {
+            // ClosedByInterruptException is an excepted exception when kill 
task,
+            // don't log the exception stack trace to avoid confusing users.
+            // See: SPARK-28340
+            case ce: ClosedByInterruptException =>
+              logError("Error occurred while fetching local blocks, " + 
ce.getMessage)
+            case ex: Exception => logError("Error occurred while fetching 
local blocks", ex)
+          }
+          results.put(new FailureFetchResult(blockId, mapIndex, 
blockManager.blockManagerId, e))
+          return
+      }
+    }
+  }
+
+  private[this] def fetchHostLocalBlock(
+      blockId: BlockId,
+      mapIndex: Int,
+      localDirs: Array[String],
+      blockManagerId: BlockManagerId): Boolean = {
+    try {
+      val buf = blockManager.getHostLocalShuffleData(blockId, localDirs)
+      buf.retain()
+      results.put(SuccessFetchResult(
+        blockId,
+        mapIndex,
+        blockManagerId,
+        buf.size(),
+        buf,
+        isNetworkReqDone = false))
+      true
+    } catch {
+      case e: Exception =>
+        // If we see an exception, stop immediately.
+        logError(s"Error occurred while fetching local blocks", e)
+        results.put(FailureFetchResult(blockId, mapIndex, blockManagerId, e))
+        false
+    }
+  }
+
+  /**
+   * Fetch the host-local blocks while we are fetching remote blocks. This is 
ok because
+   * `ManagedBuffer`'s memory is allocated lazily when we create the input 
stream, so all we track
+   * in-memory are the ManagedBuffer references themselves.
+   */
+  private[this] def fetchHostLocalBlocks(
+      hostLocalDirManager: HostLocalDirManager,
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]])
+      : Unit = {
+    val cachedDirsByExec = hostLocalDirManager.getCachedHostLocalDirs
+    val (hostLocalBlocksWithCachedDirs, hostLocalBlocksWithMissingDirs) = {
+      val (hasCache, noCache) = hostLocalBlocksByExecutor.partition {
+        case (hostLocalBmId, _) =>
+          cachedDirsByExec.contains(hostLocalBmId.executorId)
+      }
+      (hasCache.toMap, noCache.toMap)
+    }
+
+    if (hostLocalBlocksWithMissingDirs.nonEmpty) {
+      logDebug(s"Asynchronous fetching host-local blocks without cached 
executors' dir: " +
+        s"${hostLocalBlocksWithMissingDirs.mkString(", ")}")
+
+      // If the external shuffle service is enabled, we'll fetch the local 
directories for
+      // multiple executors from the external shuffle service, which located 
at the same host
+      // with the executors, in once. Otherwise, we'll fetch the local 
directories from those
+      // executors directly one by one. The fetch requests won't be too much 
since one host is
+      // almost impossible to have many executors at the same time practically.
+      val dirFetchRequests = if (blockManager.externalShuffleServiceEnabled) {
+        val host = blockManager.blockManagerId.host
+        val port = blockManager.externalShuffleServicePort
+        Seq((host, port, hostLocalBlocksWithMissingDirs.keys.toArray))
+      } else {
+        hostLocalBlocksWithMissingDirs.keys.map(bmId => (bmId.host, bmId.port, 
Array(bmId))).toSeq
+      }
+
+      dirFetchRequests.foreach {
+        case (host, port, bmIds) =>
+          hostLocalDirManager.getHostLocalDirs(host, port, 
bmIds.map(_.executorId)) {
+            case Success(dirsByExecId) =>
+              fetchMultipleHostLocalBlocks(
+                
hostLocalBlocksWithMissingDirs.filterKeys(bmIds.contains).toMap,
+                dirsByExecId,
+                cached = false)
+
+            case Failure(throwable) =>
+              logError("Error occurred while fetching host local blocks", 
throwable)
+              val bmId = bmIds.head
+              val blockInfoSeq = hostLocalBlocksWithMissingDirs(bmId)
+              val (blockId, _, mapIndex) = blockInfoSeq.head
+              results.put(FailureFetchResult(blockId, mapIndex, bmId, 
throwable))
+          }
+      }
+    }
+
+    if (hostLocalBlocksWithCachedDirs.nonEmpty) {
+      logDebug(s"Synchronous fetching host-local blocks with cached executors' 
dir: " +
+        s"${hostLocalBlocksWithCachedDirs.mkString(", ")}")
+      fetchMultipleHostLocalBlocks(hostLocalBlocksWithCachedDirs, 
cachedDirsByExec, cached = true)
+    }
+  }
+
+  private def fetchMultipleHostLocalBlocks(
+      bmIdToBlocks: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      localDirsByExecId: Map[String, Array[String]],
+      cached: Boolean): Unit = {
+    // We use `forall` because once there's a failed block fetch, 
`fetchHostLocalBlock` will put
+    // a `FailureFetchResult` immediately to the `results`. So there's no 
reason to fetch the
+    // remaining blocks.
+    val allFetchSucceeded = bmIdToBlocks.forall {
+      case (bmId, blockInfos) =>
+        blockInfos.forall {
+          case (blockId, _, mapIndex) =>
+            fetchHostLocalBlock(blockId, mapIndex, 
localDirsByExecId(bmId.executorId), bmId)
+        }
+    }
+    if (allFetchSucceeded) {
+      logDebug(s"Got host-local blocks from ${bmIdToBlocks.keys.mkString(", 
")} " +
+        s"(${if (cached) "with" else "without"} cached executors' dir) " +
+        s"in ${Utils.getUsedTimeNs(startTimeNs)}")
+    }
+  }
+
+  private[this] def initialize(): Unit = {
+    // Add a task completion callback (called in both success case and failure 
case) to cleanup.
+    context.addTaskCompletionListener(onCompleteCallback)
+    // Local blocks to fetch, excluding zero-sized blocks.
+    val localBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val hostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val pushMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    // Partition blocks by the different fetch modes: local, host-local, 
push-merged-local and
+    // remote blocks.
+    val remoteRequests = partitionBlocksByFetchMode(
+      blocksByAddress,
+      localBlocks,
+      hostLocalBlocksByExecutor,
+      pushMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(remoteRequests)
+    assert(
+      (0 == reqsInFlight) == (0 == bytesInFlight),
+      "expected reqsInFlight = 0 but found reqsInFlight = " + reqsInFlight +
+        ", expected bytesInFlight = 0 but found bytesInFlight = " + 
bytesInFlight
+    )
+
+    // Send out initial requests for blocks, up to our maxBytesInFlight
+    fetchUpToMaxBytes()
+
+    val numDeferredRequest = deferredFetchRequests.values.map(_.size).sum
+    val numFetches = remoteRequests.size - fetchRequests.size - 
numDeferredRequest
+    logInfo(s"Started $numFetches remote fetches in 
${Utils.getUsedTimeNs(startTimeNs)}" +
+      (if (numDeferredRequest > 0) s", deferred $numDeferredRequest requests" 
else ""))
+
+    // Get Local Blocks
+    fetchLocalBlocks(localBlocks)
+    logDebug(s"Got local blocks in ${Utils.getUsedTimeNs(startTimeNs)}")
+    // Get host local blocks if any
+    fetchAllHostLocalBlocks(hostLocalBlocksByExecutor)
+    pushBasedFetchHelper.fetchAllPushMergedLocalBlocks(pushMergedLocalBlocks)
+  }
+
+  private def fetchAllHostLocalBlocks(
+      hostLocalBlocksByExecutor: mutable.LinkedHashMap[BlockManagerId, 
Seq[(BlockId, Long, Int)]])
+      : Unit = {
+    if (hostLocalBlocksByExecutor.nonEmpty) {
+      blockManager.hostLocalDirManager.foreach(fetchHostLocalBlocks(_, 
hostLocalBlocksByExecutor))
+    }
+  }
+
+  override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
+
+  /**
+   * Fetches the next (BlockId, InputStream). If a task fails, the 
ManagedBuffers underlying each
+   * InputStream will be freed by the cleanup() method registered with the 
TaskCompletionListener.
+   * However, callers should close() these InputStreams as soon as they are no 
longer needed, in
+   * order to release memory as early as possible.
+   *
+   * Throws a FetchFailedException if the next block could not be fetched.
+   */
+  override def next(): (BlockId, InputStream) = {
+    if (!hasNext) {
+      throw SparkCoreErrors.noSuchElementError()
+    }
+
+    numBlocksProcessed += 1
+
+    var result: FetchResult = null
+    var input: InputStream = null
+    // This's only initialized when shuffle checksum is enabled.
+    var checkedIn: CheckedInputStream = null
+    var streamCompressedOrEncrypted: Boolean = false
+    // Take the next fetched result and try to decompress it to detect data 
corruption,
+    // then fetch it one more time if it's corrupt, throw FailureFetchResult 
if the second fetch
+    // is also corrupt, so the previous stage could be retried.
+    // For local shuffle block, throw FailureFetchResult for the first 
IOException.
+    while (result == null) {
+      val startFetchWait = System.nanoTime()
+      result = results.take()
+      val fetchWaitTime = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
startFetchWait)
+      shuffleMetrics.incFetchWaitTime(fetchWaitTime)
+
+      result match {
+        case r @ SuccessFetchResult(blockId, mapIndex, address, size, buf, 
isNetworkReqDone) =>
+          if (address != blockManager.blockManagerId) {
+            if (
+              hostLocalBlocks.contains(blockId -> mapIndex) ||
+              pushBasedFetchHelper.isLocalPushMergedBlockAddress(address)
+            ) {
+              // It is a host local block or a local shuffle chunk
+              shuffleMetrics.incLocalBlocksFetched(1)
+              shuffleMetrics.incLocalBytesRead(buf.size)
+            } else {
+              numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+              shuffleMetrics.incRemoteBytesRead(buf.size)
+              if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                shuffleMetrics.incRemoteBytesReadToDisk(buf.size)
+              }
+              shuffleMetrics.incRemoteBlocksFetched(1)
+              bytesInFlight -= size
+            }
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            resetNettyOOMFlagIfPossible(maxReqSizeShuffleToMem)
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+
+          val in = if (buf.size == 0) {
+            // We will never legitimately receive a zero-size block. All 
blocks with zero records
+            // have zero size and all zero-size blocks have no records (and 
hence should never
+            // have been requested in the first place). This statement relies 
on behaviors of the
+            // shuffle writers, which are guaranteed by the following test 
cases:
+            //
+            // - BypassMergeSortShuffleWriterSuite: "write with some empty 
partitions"
+            // - UnsafeShuffleWriterSuite: "writeEmptyIterator"
+            // - DiskBlockObjectWriterSuite: "commit() and close() without 
ever opening or writing"
+            //
+            // There is not an explicit test for SortShuffleWriter but the 
underlying APIs that
+            // uses are shared by the UnsafeShuffleWriter (both writers use 
DiskBlockObjectWriter
+            // which returns a zero-size from commitAndGet() in case no 
records were written
+            // since the last call.
+            val msg = s"Received a zero-size buffer for block $blockId from 
$address " +
+              s"(expectedApproxSize = $size, 
isNetworkReqDone=$isNetworkReqDone)"
+            if (blockId.isShuffleChunk) {
+              // Zero-size block may come from nodes with hardware failures, 
For shuffle chunks,
+              // the original shuffle blocks that belong to that zero-size 
shuffle chunk is
+              // available and we can opt to fallback immediately.
+              logWarning(msg)
+              
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+              // Set result to null to trigger another iteration of the while 
loop to get either.
+              result = null
+              null
+            } else {
+              throwFetchFailedException(blockId, mapIndex, address, new 
IOException(msg))
+            }
+          } else {
+            try {
+              val bufIn = buf.createInputStream()
+              if (checksumEnabled) {
+                val checksum = 
ShuffleChecksumHelper.getChecksumByAlgorithm(checksumAlgorithm)
+                checkedIn = new CheckedInputStream(bufIn, checksum)
+                checkedIn
+              } else {
+                bufIn
+              }
+            } catch {
+              // The exception could only be throwed by local shuffle block
+              case e: IOException =>
+                assert(buf.isInstanceOf[FileSegmentManagedBuffer])
+                e match {
+                  case ce: ClosedByInterruptException =>
+                    logError("Failed to create input stream from local block, 
" +
+                      ce.getMessage)
+                  case e: IOException =>
+                    logError("Failed to create input stream from local block", 
e)
+                }
+                buf.release()
+                if (blockId.isShuffleChunk) {
+                  
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                  // Set result to null to trigger another iteration of the 
while loop to get
+                  // either.
+                  result = null
+                  null
+                } else {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                }
+            }
+          }
+
+          if (in != null) {
+            try {
+              input = streamWrapper(blockId, in)
+              // If the stream is compressed or wrapped, then we optionally 
decompress/unwrap the
+              // first maxBytesInFlight/3 bytes into memory, to check for 
corruption in that portion
+              // of the data. But even if 'detectCorruptUseExtraMemory' 
configuration is off, or if
+              // the corruption is later, we'll still detect the corruption 
later in the stream.
+              streamCompressedOrEncrypted = !input.eq(in)
+              if (streamCompressedOrEncrypted && detectCorruptUseExtraMemory) {
+                // TODO: manage the memory used here, and spill it into disk 
in case of OOM.
+                input = Utils.copyStreamUpTo(input, maxBytesInFlight / 3)
+              }
+            } catch {
+              case e: IOException =>
+                // When shuffle checksum is enabled, for a block that is 
corrupted twice,
+                // we'd calculate the checksum of the block by consuming the 
remaining data
+                // in the buf. So, we should release the buf later.
+                if (!(checksumEnabled && corruptedBlocks.contains(blockId))) {
+                  buf.release()
+                }
+
+                if (blockId.isShuffleChunk) {
+                  // TODO (SPARK-36284): Add shuffle checksum support for 
push-based shuffle
+                  // Retrying a corrupt block may result again in a corrupt 
block. For shuffle
+                  // chunks, we opt to fallback on the original shuffle blocks 
that belong to that
+                  // corrupt shuffle chunk immediately instead of retrying to 
fetch the corrupt
+                  // chunk. This also makes the code simpler because the 
chunkMeta corresponding to
+                  // a shuffle chunk is always removed from chunksMetaMap 
whenever a shuffle chunk
+                  // gets processed. If we try to re-fetch a corrupt shuffle 
chunk, then it has to
+                  // be added back to the chunksMetaMap.
+                  
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+                  // Set result to null to trigger another iteration of the 
while loop.
+                  result = null
+                } else if (buf.isInstanceOf[FileSegmentManagedBuffer]) {
+                  throwFetchFailedException(blockId, mapIndex, address, e)
+                } else if (corruptedBlocks.contains(blockId)) {
+                  // It's the second time this block is detected corrupted
+                  if (checksumEnabled) {
+                    // Diagnose the cause of data corruption if shuffle 
checksum is enabled
+                    val diagnosisResponse = diagnoseCorruption(checkedIn, 
address, blockId)
+                    buf.release()
+                    logError(diagnosisResponse)
+                    throwFetchFailedException(
+                      blockId,
+                      mapIndex,
+                      address,
+                      e,
+                      Some(diagnosisResponse))
+                  } else {
+                    throwFetchFailedException(blockId, mapIndex, address, e)
+                  }
+                } else {
+                  // It's the first time this block is detected corrupted
+                  logWarning(s"got an corrupted block $blockId from $address, 
fetch again", e)
+                  corruptedBlocks += blockId
+                  fetchRequests += FetchRequest(
+                    address,
+                    Array(FetchBlockInfo(blockId, size, mapIndex)))
+                  result = null
+                }
+            } finally {
+              if (blockId.isShuffleChunk) {
+                
pushBasedFetchHelper.removeChunk(blockId.asInstanceOf[ShuffleBlockChunkId])
+              }
+              // TODO: release the buf here to free memory earlier
+              if (input == null) {
+                // Close the underlying stream if there was an issue in 
wrapping the stream using
+                // streamWrapper
+                in.close()
+              }
+            }
+          }
+
+        case FailureFetchResult(blockId, mapIndex, address, e) =>
+          var errorMsg: String = null
+          if (e.isInstanceOf[OutOfDirectMemoryError]) {
+            errorMsg = s"Block $blockId fetch failed after 
$maxAttemptsOnNettyOOM " +
+              s"retries due to Netty OOM"
+            logError(errorMsg)
+          }
+          throwFetchFailedException(blockId, mapIndex, address, e, 
Some(errorMsg))
+
+        case DeferFetchRequestResult(request) =>
+          val address = request.address
+          numBlocksInFlightPerAddress(address) =
+            numBlocksInFlightPerAddress(address) - request.blocks.size
+          bytesInFlight -= request.size
+          reqsInFlight -= 1
+          logDebug("Number of requests in flight " + reqsInFlight)
+          val defReqQueue =
+            deferredFetchRequests.getOrElseUpdate(address, new 
Queue[FetchRequest]())
+          defReqQueue.enqueue(request)
+          result = null
+
+        case FallbackOnPushMergedFailureResult(blockId, address, size, 
isNetworkReqDone) =>
+          // We get this result in 3 cases:
+          // 1. Failure to fetch the data of a remote shuffle chunk. In this 
case, the
+          //    blockId is a ShuffleBlockChunkId.
+          // 2. Failure to read the push-merged-local meta. In this case, the 
blockId is
+          //    ShuffleBlockId.
+          // 3. Failure to get the push-merged-local directories from the 
external shuffle service.
+          //    In this case, the blockId is ShuffleBlockId.
+          if (pushBasedFetchHelper.isRemotePushMergedBlockAddress(address)) {
+            numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+            bytesInFlight -= size
+          }
+          if (isNetworkReqDone) {
+            reqsInFlight -= 1
+            logDebug("Number of requests in flight " + reqsInFlight)
+          }
+          
pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(blockId, address)
+          // Set result to null to trigger another iteration of the while loop 
to get either
+          // a SuccessFetchResult or a FailureFetchResult.
+          result = null
+
+        case PushMergedLocalMetaFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              bitmaps,
+              localDirs) =>
+          // Fetch push-merged-local shuffle block data as multiple shuffle 
chunks
+          val shuffleBlockId = ShuffleMergedBlockId(shuffleId, shuffleMergeId, 
reduceId)
+          try {
+            val bufs: Seq[ManagedBuffer] = 
blockManager.getLocalMergedBlockData(
+              shuffleBlockId,
+              localDirs)
+            // Since the request for local block meta completed successfully, 
numBlocksToFetch
+            // is decremented.
+            numBlocksToFetch -= 1
+            // Update total number of blocks to fetch, reflecting the multiple 
local shuffle
+            // chunks.
+            numBlocksToFetch += bufs.size
+            bufs.zipWithIndex.foreach {
+              case (buf, chunkId) =>
+                buf.retain()
+                val shuffleChunkId = ShuffleBlockChunkId(
+                  shuffleId,
+                  shuffleMergeId,
+                  reduceId,
+                  chunkId)
+                pushBasedFetchHelper.addChunk(shuffleChunkId, bitmaps(chunkId))
+                results.put(SuccessFetchResult(
+                  shuffleChunkId,
+                  SHUFFLE_PUSH_MAP_ID,
+                  pushBasedFetchHelper.localShuffleMergerBlockMgrId,
+                  buf.size(),
+                  buf,
+                  isNetworkReqDone = false))
+            }
+          } catch {
+            case e: Exception =>
+              // If we see an exception with reading push-merged-local index 
file, we fallback
+              // to fetch the original blocks. We do not report block fetch 
failure
+              // and will continue with the remaining local block read.
+              logWarning(
+                s"Error occurred while reading push-merged-local index, " +
+                  s"prepare to fetch the original blocks",
+                e)
+              pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+                shuffleBlockId,
+                pushBasedFetchHelper.localShuffleMergerBlockMgrId)
+          }
+          result = null
+
+        case PushMergedRemoteMetaFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              blockSize,
+              bitmaps,
+              address) =>
+          // The original meta request is processed so we decrease 
numBlocksToFetch and
+          // numBlocksInFlightPerAddress by 1. We will collect new shuffle 
chunks request and the
+          // count of this is added to numBlocksToFetch in 
collectFetchReqsFromMergedBlocks.
+          numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+          numBlocksToFetch -= 1
+          val blocksToFetch = 
pushBasedFetchHelper.createChunkBlockInfosFromMetaResponse(
+            shuffleId,
+            shuffleMergeId,
+            reduceId,
+            blockSize,
+            bitmaps)
+          val additionalRemoteReqs = new ArrayBuffer[FetchRequest]
+          collectFetchRequests(address, blocksToFetch.toSeq, 
additionalRemoteReqs)
+          fetchRequests ++= additionalRemoteReqs
+          // Set result to null to force another iteration.
+          result = null
+
+        case PushMergedRemoteMetaFailedFetchResult(
+              shuffleId,
+              shuffleMergeId,
+              reduceId,
+              address) =>
+          // The original meta request failed so we decrease 
numBlocksInFlightPerAddress by 1.
+          numBlocksInFlightPerAddress(address) = 
numBlocksInFlightPerAddress(address) - 1
+          // If we fail to fetch the meta of a push-merged block, we fall back 
to fetching the
+          // original blocks.
+          pushBasedFetchHelper.initiateFallbackFetchForPushMergedBlock(
+            ShuffleMergedBlockId(shuffleId, shuffleMergeId, reduceId),
+            address)
+          // Set result to null to force another iteration.
+          result = null
+      }
+
+      // Send fetch requests up to maxBytesInFlight
+      fetchUpToMaxBytes()
+    }
+
+    val successResult = result.asInstanceOf[SuccessFetchResult]
+    val threadId = Thread.currentThread().getId
+    currentResults.put(threadId, successResult)
+    (
+      successResult.blockId,
+      new GlutenBufferReleasingInputStream(
+        input,
+        this,
+        successResult.blockId,
+        successResult.mapIndex,
+        successResult.address,
+        detectCorrupt && streamCompressedOrEncrypted,
+        successResult.isNetworkReqDone,
+        Option(checkedIn)
+      ))
+  }
+
+  /**
+   * Get the suspect corruption cause for the corrupted block. It should be 
only invoked when
+   * checksum is enabled and corruption was detected at least once.
+   *
+   * This will firstly consume the rest of stream of the corrupted block to 
calculate the checksum
+   * of the block. Then, it will raise a synchronized RPC call along with the 
checksum to ask the
+   * server(where the corrupted block is fetched from) to diagnose the cause 
of corruption and
+   * return it.
+   *
+   * Any exception raised during the process will result in the 
[[Cause.UNKNOWN_ISSUE]] of the
+   * corruption cause since corruption diagnosis is only a best effort.
+   *
+   * @param checkedIn
+   *   the [[CheckedInputStream]] which is used to calculate the checksum.
+   * @param address
+   *   the address where the corrupted block is fetched from.
+   * @param blockId
+   *   the blockId of the corrupted block.
+   * @return
+   *   The corruption diagnosis response for different causes.
+   */
+  private[storage] def diagnoseCorruption(
+      checkedIn: CheckedInputStream,
+      address: BlockManagerId,
+      blockId: BlockId): String = {
+    logInfo("Start corruption diagnosis.")
+    blockId match {
+      case shuffleBlock: ShuffleBlockId =>
+        val startTimeNs = System.nanoTime()
+        val buffer = new 
Array[Byte](ShuffleChecksumHelper.CHECKSUM_CALCULATION_BUFFER)
+        // consume the remaining data to calculate the checksum
+        var cause: Cause = null
+        try {
+          while (checkedIn.read(buffer) != -1) {}
+          val checksum = checkedIn.getChecksum.getValue
+          cause = shuffleClient.diagnoseCorruption(
+            address.host,
+            address.port,
+            address.executorId,
+            shuffleBlock.shuffleId,
+            shuffleBlock.mapId,
+            shuffleBlock.reduceId,
+            checksum,
+            checksumAlgorithm)
+        } catch {
+          case e: Exception =>
+            logWarning("Unable to diagnose the corruption cause of the 
corrupted block", e)
+            cause = Cause.UNKNOWN_ISSUE
+        }
+        val duration = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
startTimeNs)
+        val diagnosisResponse = cause match {
+          case Cause.UNSUPPORTED_CHECKSUM_ALGORITHM =>
+            s"Block $blockId is corrupted but corruption diagnosis failed due 
to " +
+              s"unsupported checksum algorithm: $checksumAlgorithm"
+
+          case Cause.CHECKSUM_VERIFY_PASS =>
+            s"Block $blockId is corrupted but checksum verification passed"
+
+          case Cause.UNKNOWN_ISSUE =>
+            s"Block $blockId is corrupted but the cause is unknown"
+
+          case otherCause =>
+            s"Block $blockId is corrupted due to $otherCause"
+        }
+        logInfo(s"Finished corruption diagnosis in $duration ms. 
$diagnosisResponse")
+        diagnosisResponse
+      case shuffleBlockChunk: ShuffleBlockChunkId =>
+        // TODO SPARK-36284 Add shuffle checksum support for push-based shuffle
+        val diagnosisResponse = s"BlockChunk $shuffleBlockChunk is corrupted 
but corruption " +
+          s"diagnosis is skipped due to lack of shuffle checksum support for 
push-based shuffle."
+        logWarning(diagnosisResponse)
+        diagnosisResponse
+      case unexpected: BlockId =>
+        throw new IllegalArgumentException(s"Unexpected type of BlockId, 
$unexpected")
+    }
+  }
+
+  override def onComplete(): Unit = {
+    onCompleteCallback.onComplete(context)
+  }
+
+  private def fetchUpToMaxBytes(): Unit = {
+    if (isNettyOOMOnShuffle.get()) {
+      if (reqsInFlight > 0) {
+        // Return immediately if Netty is still OOMed and there're ongoing 
fetch requests
+        return
+      } else {
+        resetNettyOOMFlagIfPossible(0)
+      }
+    }
+
+    // Send fetch requests up to maxBytesInFlight. If you cannot fetch from a 
remote host
+    // immediately, defer the request until the next time it can be processed.
+
+    // Process any outstanding deferred fetch requests if possible.
+    if (deferredFetchRequests.nonEmpty) {
+      for ((remoteAddress, defReqQueue) <- deferredFetchRequests) {
+        while (
+          isRemoteBlockFetchable(defReqQueue) &&
+          !isRemoteAddressMaxedOut(remoteAddress, defReqQueue.front)
+        ) {
+          val request = defReqQueue.dequeue()
+          logDebug(s"Processing deferred fetch request for $remoteAddress with 
"
+            + s"${request.blocks.length} blocks")
+          send(remoteAddress, request)
+          if (defReqQueue.isEmpty) {
+            deferredFetchRequests -= remoteAddress
+          }
+        }
+      }
+    }
+
+    // Process any regular fetch requests if possible.
+    while (isRemoteBlockFetchable(fetchRequests)) {
+      val request = fetchRequests.dequeue()
+      val remoteAddress = request.address
+      if (isRemoteAddressMaxedOut(remoteAddress, request)) {
+        logDebug(s"Deferring fetch request for $remoteAddress with 
${request.blocks.size} blocks")
+        val defReqQueue = deferredFetchRequests.getOrElse(remoteAddress, new 
Queue[FetchRequest]())
+        defReqQueue.enqueue(request)
+        deferredFetchRequests(remoteAddress) = defReqQueue
+      } else {
+        send(remoteAddress, request)
+      }
+    }
+
+    def send(remoteAddress: BlockManagerId, request: FetchRequest): Unit = {
+      if (request.forMergedMetas) {
+        pushBasedFetchHelper.sendFetchMergedStatusRequest(request)
+      } else {
+        sendRequest(request)
+      }
+      numBlocksInFlightPerAddress(remoteAddress) =
+        numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + 
request.blocks.size
+    }
+
+    def isRemoteBlockFetchable(fetchReqQueue: Queue[FetchRequest]): Boolean = {
+      fetchReqQueue.nonEmpty &&
+      (bytesInFlight == 0 ||
+        (reqsInFlight + 1 <= maxReqsInFlight &&
+          bytesInFlight + fetchReqQueue.front.size <= maxBytesInFlight))
+    }
+
+    // Checks if sending a new fetch request will exceed the max no. of blocks 
being fetched from a
+    // given remote address.
+    def isRemoteAddressMaxedOut(remoteAddress: BlockManagerId, request: 
FetchRequest): Boolean = {
+      numBlocksInFlightPerAddress.getOrElse(remoteAddress, 0) + 
request.blocks.size >
+        maxBlocksInFlightPerAddress
+    }
+  }
+
+  private[storage] def throwFetchFailedException(
+      blockId: BlockId,
+      mapIndex: Int,
+      address: BlockManagerId,
+      e: Throwable,
+      message: Option[String] = None) = {
+    val msg = message.getOrElse(e.getMessage)
+    blockId match {
+      case ShuffleBlockId(shufId, mapId, reduceId) =>
+        throw SparkCoreErrors.fetchFailedError(address, shufId, mapId, 
mapIndex, reduceId, msg, e)
+      case ShuffleBlockBatchId(shuffleId, mapId, startReduceId, _) =>
+        throw SparkCoreErrors.fetchFailedError(
+          address,
+          shuffleId,
+          mapId,
+          mapIndex,
+          startReduceId,
+          msg,
+          e)
+      case _ => throw SparkCoreErrors.failToGetNonShuffleBlockError(blockId, e)
+    }
+  }
+
+  /**
+   * All the below methods are used by [[PushBasedFetchHelper]] to communicate 
with the iterator
+   */
+  private[storage] def addToResultsQueue(result: FetchResult): Unit = {
+    results.put(result)
+  }
+
+  private[storage] def decreaseNumBlocksToFetch(blocksFetched: Int): Unit = {
+    numBlocksToFetch -= blocksFetched
+  }
+
+  /**
+   * Currently used by [[PushBasedFetchHelper]] to fetch fallback blocks when 
there is a fetch
+   * failure related to a push-merged block or shuffle chunk. This is executed 
by the task thread
+   * when the `iterator.next()` is invoked and if that initiates fallback.
+   */
+  private[storage] def fallbackFetch(
+      originalBlocksByAddr: Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])]): Unit = {
+    val originalLocalBlocks = mutable.LinkedHashSet[(BlockId, Int)]()
+    val originalHostLocalBlocksByExecutor =
+      mutable.LinkedHashMap[BlockManagerId, Seq[(BlockId, Long, Int)]]()
+    val originalMergedLocalBlocks = mutable.LinkedHashSet[BlockId]()
+    val originalRemoteReqs = partitionBlocksByFetchMode(
+      originalBlocksByAddr,
+      originalLocalBlocks,
+      originalHostLocalBlocksByExecutor,
+      originalMergedLocalBlocks)
+    // Add the remote requests into our queue in a random order
+    fetchRequests ++= Utils.randomize(originalRemoteReqs)
+    logInfo(s"Created ${originalRemoteReqs.size} fallback remote requests for 
push-merged")
+    // fetch all the fallback blocks that are local.
+    fetchLocalBlocks(originalLocalBlocks)
+    // Merged local blocks should be empty during fallback
+    assert(
+      originalMergedLocalBlocks.isEmpty,
+      "There should be zero push-merged blocks during fallback")
+    // Some of the fallback local blocks could be host local blocks
+    fetchAllHostLocalBlocks(originalHostLocalBlocksByExecutor)
+  }
+
+  /**
+   * Removes all the pending shuffle chunks that are on the same host and have 
the same reduceId as
+   * the current chunk that had a fetch failure. This is executed by the task 
thread when the
+   * `iterator.next()` is invoked and if that initiates fallback.
+   *
+   * @return
+   *   set of all the removed shuffle chunk Ids.
+   */
+  private[storage] def removePendingChunks(
+      failedBlockId: ShuffleBlockChunkId,
+      address: BlockManagerId): mutable.HashSet[ShuffleBlockChunkId] = {
+    val removedChunkIds = new mutable.HashSet[ShuffleBlockChunkId]()
+
+    def sameShuffleReducePartition(block: BlockId): Boolean = {
+      val chunkId = block.asInstanceOf[ShuffleBlockChunkId]
+      chunkId.shuffleId == failedBlockId.shuffleId && chunkId.reduceId == 
failedBlockId.reduceId
+    }
+
+    def filterRequests(queue: mutable.Queue[FetchRequest]): Unit = {
+      val fetchRequestsToRemove = new mutable.Queue[FetchRequest]()
+      fetchRequestsToRemove ++= queue.dequeueAll {
+        req =>
+          val firstBlock = req.blocks.head
+          firstBlock.blockId.isShuffleChunk && req.address.equals(address) &&
+          sameShuffleReducePartition(firstBlock.blockId)
+      }
+      fetchRequestsToRemove.foreach {
+        _ =>
+          removedChunkIds ++=
+            
fetchRequestsToRemove.flatMap(_.blocks.map(_.blockId.asInstanceOf[ShuffleBlockChunkId]))
+      }

Review Comment:
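   `removePendingChunks` builds `removedChunkIds` inside a `foreach` over
   `fetchRequestsToRemove`, but the loop variable is ignored (`_ =>`) and each
   iteration appends the chunk ids of the full `fetchRequestsToRemove` list
   again. Because `removedChunkIds` is a `mutable.HashSet`, the final contents
   are still correct, but this is unnecessarily O(n²) work in a failure path and
   can inflate fallback latency. Consider dropping the outer `foreach` and
   adding the flattened ids once with a single
   `removedChunkIds ++= fetchRequestsToRemove.flatMap(...)`.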



##########
cpp/velox/shuffle/VeloxGpuShuffleReader.cc:
##########
@@ -62,62 +78,126 @@ 
VeloxGpuHashShuffleReaderDeserializer::VeloxGpuHashShuffleReaderDeserializer(
       rowType_(rowType),
       readerBufferSize_(readerBufferSize),
       memoryManager_(memoryManager),
+      threadPool_(threadPool),
       deserializeTime_(deserializeTime),
       decompressTime_(decompressTime) {}
 
-bool VeloxGpuHashShuffleReaderDeserializer::resolveNextBlockType() {
-  GLUTEN_ASSIGN_OR_THROW(auto blockType, readBlockType(in_.get()));
-  switch (blockType) {
-    case BlockType::kEndOfStream:
-      return false;
-    case BlockType::kPlainPayload:
-      return true;
-    default:
-      throw GlutenException(fmt::format("Unsupported block type: {}", 
static_cast<int32_t>(blockType)));
+VeloxGpuHashShuffleReaderDeserializer::~VeloxGpuHashShuffleReaderDeserializer()
 {
+  // Wait for all reader threads to complete before destroying
+  if (!isStopped()) {
+    stop();
   }
+
+  decompressTime_ += decompressTimeCounter_.load(std::memory_order_relaxed);
+  deserializeTime_ += deserializeTimeCounter_.load(std::memory_order_relaxed);
 }
 
-void VeloxGpuHashShuffleReaderDeserializer::loadNextStream() {
-  if (reachedEos_) {
-    return;
+std::unique_ptr<ColumnarBatchIterator> 
VeloxGpuHashShuffleReaderDeserializer::deserializeStreams(int32_t priority) {
+  batchQueue_ = std::make_unique<CachedBatchQueue<GpuBufferColumnarBatch>>(1L 
<< 30);
+
+  if (!threadPool_) {
+    throw GlutenException("Thread pool must be provided to 
VeloxGpuHashShuffleReaderDeserializer");
+  }
+
+  const size_t numThreads = threadPool_->getNumThreads();
+  activeReaders_.store(numThreads);
+
+  // Submit reader tasks to the thread pool.
+  std::vector<ReaderThreadPool::Task> tasks;
+  tasks.reserve(numThreads);
+  for (size_t i = 0; i < numThreads; ++i) {
+    tasks.emplace_back([this]() { read(); });
   }
+  threadPool_->submitBatch(std::move(tasks), priority);
 
-  auto in = 
streamReader_->readNextStream(memoryManager_->defaultArrowMemoryPool());
-  if (in == nullptr) {
-    reachedEos_ = true;
-    return;
+  if (priority == 0) {
+    threadPool_->start();
   }
 
-  GLUTEN_ASSIGN_OR_THROW(
-      in_,
-      arrow::io::BufferedInputStream::Create(
-          readerBufferSize_, memoryManager_->defaultArrowMemoryPool(), 
std::move(in)));
+  return 
std::make_unique<AsyncShuffleReaderIterator<GpuBufferColumnarBatch>>(batchQueue_.get());
 }
 
-std::shared_ptr<ColumnarBatch> VeloxGpuHashShuffleReaderDeserializer::next() {
-  if (in_ == nullptr) {
-    loadNextStream();
+void VeloxGpuHashShuffleReaderDeserializer::stop() {
+  // Signal threads to stop if not already stopped.
+  stop_.store(true, std::memory_order_release);
+  // Wait for all reader threads to complete.
+  std::unique_lock<std::mutex> lock(completionMtx_);

Review Comment:
   `stop()` waits for reader threads to exit, but reader threads may be blocked 
inside `batchQueue_->put(batch)` when the queue is full. In that case `stop()` 
can deadlock because blocked producers never reach the loop top to observe 
`stop_`. A robust fix needs a way to unblock/abort `put()` when stopping (e.g., 
close the queue and have `put()` wake and fail).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to