otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r649514612



##########
File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
##########
@@ -436,24 +485,48 @@ final class ShuffleBlockFetcherIterator(
     val iterator = blockInfos.iterator
     var curRequestSize = 0L
     var curBlocks = Seq.empty[FetchBlockInfo]
-
     while (iterator.hasNext) {
       val (blockId, size, mapIndex) = iterator.next()
-      assertPositiveBlockSize(blockId, size)
       curBlocks = curBlocks ++ Seq(FetchBlockInfo(blockId, size, mapIndex))
       curRequestSize += size
-      // For batch fetch, the actual block in flight should count for merged 
block.
-      val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
-      if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        curBlocks = createFetchRequests(curBlocks, address, isLast = false,
-          collectedRemoteRequests)
-        curRequestSize = curBlocks.map(_.size).sum
+      blockId match {
+        // Either all blocks are merged blocks, merged block chunks, or 
original non-merged blocks.
+        // Based on these types, we decide to do batch fetch and create 
FetchRequests with
+        // forMergedMetas set.
+        case ShuffleBlockChunkId(_, _, _) =>
+          if (curRequestSize >= targetRemoteRequestSize ||
+            curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
+        case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) =>
+          if (curBlocks.size >= maxBlocksInFlightPerAddress) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = false, 
forMergedMetas = true)
+          }
+        case _ =>
+          // For batch fetch, the actual block in flight should count for 
merged block.
+          val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= 
maxBlocksInFlightPerAddress
+          if (curRequestSize >= targetRemoteRequestSize || 
mayExceedsMaxBlocks) {
+            curBlocks = createFetchRequests(curBlocks, address, isLast = false,
+              collectedRemoteRequests, enableBatchFetch = doBatchFetch)
+            curRequestSize = curBlocks.map(_.size).sum
+          }
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
+      val (enableBatchFetch, areMergedBlocks) = {
+        curBlocks.head.blockId match {
+          case ShuffleBlockChunkId(_, _, _) => (false, false)
+          case ShuffleBlockId(_, SHUFFLE_PUSH_MAP_ID, _) => (false, true)
+          case _ => (doBatchFetch, false)
+        }
+      }
       curBlocks = createFetchRequests(curBlocks, address, isLast = true,
-        collectedRemoteRequests)
+        collectedRemoteRequests, enableBatchFetch = enableBatchFetch,
+        forMergedMetas = areMergedBlocks)
       curRequestSize = curBlocks.map(_.size).sum

Review comment:
       We do want the sum of the sizes of all the blocks in `curBlocks`, so I
think the `sum` is needed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to