vanzin commented on a change in pull request #25299: [SPARK-27651][Core] Avoid the network when shuffle blocks are fetched from the same host
URL: https://github.com/apache/spark/pull/25299#discussion_r346976078
 
 

 ##########
 File path: core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
 ##########
 @@ -272,73 +280,92 @@ final class ShuffleBlockFetcherIterator(
     }
   }
 
-  private[this] def splitLocalRemoteBlocks(): ArrayBuffer[FetchRequest] = {
-    // Make remote requests at most maxBytesInFlight / 5 in length; the reason 
to keep them
-    // smaller than maxBytesInFlight is to allow multiple, parallel fetches 
from up to 5
-    // nodes, rather than blocking on reading output from one node.
-    val targetRequestSize = math.max(maxBytesInFlight / 5, 1L)
-    logDebug("maxBytesInFlight: " + maxBytesInFlight + ", targetRequestSize: " 
+ targetRequestSize
-      + ", maxBlocksInFlightPerAddress: " + maxBlocksInFlightPerAddress)
-
-    // Split local and remote blocks. Remote blocks are further split into 
FetchRequests of size
-    // at most maxBytesInFlight in order to limit the amount of data in flight.
-    val remoteRequests = new ArrayBuffer[FetchRequest]
+  private[this] def partitionBlocksByFetchMode(): ArrayBuffer[FetchRequest] = {
+    logDebug(s"maxBytesInFlight: $maxBytesInFlight, targetRemoteRequestSize: "
+      + s"$targetRemoteRequestSize, maxBlocksInFlightPerAddress: 
$maxBlocksInFlightPerAddress")
+
+    // Partition to local, host-local and remote blocks. Remote blocks are 
further split into
+    // FetchRequests of size at most maxBytesInFlight in order to limit the 
amount of data in flight
+    val collectedRemoteRequests = new ArrayBuffer[FetchRequest]
     var localBlockBytes = 0L
+    var hostLocalBlockBytes = 0L
     var remoteBlockBytes = 0L
+    var numRemoteBlocks = 0
+
+    val hostLocalDirReadingEnabled =
+      blockManager.hostLocalDirManager != null && 
blockManager.hostLocalDirManager.isDefined
 
     for ((address, blockInfos) <- blocksByAddress) {
       if (address.executorId == blockManager.blockManagerId.executorId) {
-        blockInfos.find(_._2 <= 0) match {
-          case Some((blockId, size, _)) if size < 0 =>
-            throw new BlockException(blockId, "Negative block size " + size)
-          case Some((blockId, size, _)) if size == 0 =>
-            throw new BlockException(blockId, "Zero-sized blocks should be 
excluded.")
-          case None => // do nothing.
-        }
+        checkBlockSizes(blockInfos)
         val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
           blockInfos.map(info => FetchBlockInfo(info._1, info._2, 
info._3)).to[ArrayBuffer])
         localBlocks ++= mergedBlockInfos.map(info => (info.blockId, 
info.mapIndex))
         localBlockBytes += mergedBlockInfos.map(_.size).sum
+      } else if (hostLocalDirReadingEnabled && address.host == 
blockManager.blockManagerId.host) {
+        checkBlockSizes(blockInfos)
+        val mergedBlockInfos = mergeContinuousShuffleBlockIdsIfNeeded(
+          blockInfos.map(info => FetchBlockInfo(info._1, info._2, 
info._3)).to[ArrayBuffer])
+        val blocksForAddress =
+          mergedBlockInfos.map(info => (info.blockId, info.size, 
info.mapIndex))
+        hostLocalBlocksByExecutor += address -> blocksForAddress
+        hostLocalBlocks ++= blocksForAddress.map(info => (info._1, info._3))
+        hostLocalBlockBytes += mergedBlockInfos.map(_.size).sum
       } else {
-        val iterator = blockInfos.iterator
-        var curRequestSize = 0L
-        var curBlocks = new ArrayBuffer[FetchBlockInfo]
-        while (iterator.hasNext) {
-          val (blockId, size, mapIndex) = iterator.next()
-          remoteBlockBytes += size
-          if (size < 0) {
-            throw new BlockException(blockId, "Negative block size " + size)
-          } else if (size == 0) {
-            throw new BlockException(blockId, "Zero-sized blocks should be 
excluded.")
-          } else {
-            curBlocks += FetchBlockInfo(blockId, size, mapIndex)
-            curRequestSize += size
-          }
-          if (curRequestSize >= targetRequestSize ||
-              curBlocks.size >= maxBlocksInFlightPerAddress) {
-            // Add this FetchRequest
-            val mergedBlocks = 
mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-            remoteBlocks ++= mergedBlocks.map(_.blockId)
-            remoteRequests += new FetchRequest(address, mergedBlocks)
-            logDebug(s"Creating fetch request of $curRequestSize at $address "
-              + s"with ${mergedBlocks.size} blocks")
-            curBlocks = new ArrayBuffer[FetchBlockInfo]
-            curRequestSize = 0
-          }
-        }
-        // Add in the final request
-        if (curBlocks.nonEmpty) {
-          val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
-          remoteBlocks ++= mergedBlocks.map(_.blockId)
-          remoteRequests += new FetchRequest(address, mergedBlocks)
-        }
+        numRemoteBlocks += blockInfos.size
+        remoteBlockBytes += blockInfos.map(_._2).sum
+        collectFetchRequests(address, blockInfos, collectedRemoteRequests)
       }
     }
     val totalBytes = localBlockBytes + remoteBlockBytes
     logInfo(s"Getting $numBlocksToFetch (${Utils.bytesToString(totalBytes)}) 
non-empty blocks " +
-      s"including ${localBlocks.size} 
(${Utils.bytesToString(localBlockBytes)}) local blocks and " +
-      s"${remoteBlocks.size} (${Utils.bytesToString(remoteBlockBytes)}) remote 
blocks")
-    remoteRequests
+      s"including ${localBlocks.size} 
(${Utils.bytesToString(localBlockBytes)}) local and " +
+      s"${hostLocalBlocks.size} (${Utils.bytesToString(hostLocalBlockBytes)}) 
" +
+      s"host-local and $numRemoteBlocks 
(${Utils.bytesToString(remoteBlockBytes)}) remote blocks")
+    collectedRemoteRequests
+  }
+
+  private def collectFetchRequests(
+      address: BlockManagerId,
+      blockInfos: Seq[(BlockId, Long, Int)],
+      collectedRemoteRequests: ArrayBuffer[FetchRequest]): Unit = {
+    val iterator = blockInfos.iterator
+    var curRequestSize = 0L
+    var curBlocks = new ArrayBuffer[FetchBlockInfo]
+    while (iterator.hasNext) {
+      val (blockId, size, mapIndex) = iterator.next()
+      assertPositiveBlockSize(blockId, size)
+      curBlocks += FetchBlockInfo(blockId, size, mapIndex)
+      curRequestSize += size
+      if (curRequestSize >= targetRemoteRequestSize ||
+          curBlocks.size >= maxBlocksInFlightPerAddress) {
+        // Add this FetchRequest
+        val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
+        collectedRemoteRequests += new FetchRequest(address, mergedBlocks)
+        logDebug(s"Creating fetch request of $curRequestSize at $address "
+          + s"with ${mergedBlocks.size} blocks")
+        curBlocks = new ArrayBuffer[FetchBlockInfo]
+        curRequestSize = 0
+      }
+    }
+    // Add in the final request
+    if (curBlocks.nonEmpty) {
+      val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
+      collectedRemoteRequests += new FetchRequest(address, mergedBlocks)
+    }
+  }
+
 
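 Review comment:
   Too many empty lines: the blank line added after `collectFetchRequests` is followed by an existing blank line, leaving two in a row. Please remove one.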

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to