This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new f34898c  [SPARK-31034][CORE] ShuffleBlockFetcherIterator should always create request for last block group
f34898c is described below

commit f34898c5e19c9a35c091eded9652cd5e3d661d19
Author: yi.wu <yi...@databricks.com>
AuthorDate: Thu Mar 5 21:31:26 2020 +0800

    [SPARK-31034][CORE] ShuffleBlockFetcherIterator should always create request for last block group
    
    ### What changes were proposed in this pull request?
    
    This is a bug fix of #27280. This PR fixes the bug where `ShuffleBlockFetcherIterator` may
    forget to create a request for the last block group.
    
    ### Why are the changes needed?
    
    When (all blocks).sum < `targetRemoteRequestSize`, (all blocks).length > `maxBlocksInFlightPerAddress`,
    and (last block group).size < `maxBlocksInFlightPerAddress`, `ShuffleBlockFetcherIterator` will not
    create a request for the last group. Thus, the reduce task will lose data.
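
    To make the failure concrete, here is a minimal, self-contained Scala sketch of the grouping
    logic described above (an illustration only, not the actual Spark code; the config name and the
    method shape mirror the patch, everything else is hypothetical):

        // 15 blocks with maxBlocksInFlightPerAddress = 6 split into groups of 6, 6, 3.
        // Before the fix, a trailing group smaller than the limit was deferred; on the
        // final call there was nothing left to flush it, so its blocks were never fetched.
        val maxBlocksInFlightPerAddress = 6
        val mergedBlocks = (0 until 15).map(i => s"block-$i")
        def createFetchRequests(isLast: Boolean): Unit = {
          mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { group =>
            if (group.length == maxBlocksInFlightPerAddress || isLast) {
              println(s"send request for ${group.length} blocks")
            } else {
              println(s"defer ${group.length} blocks") // old behavior on the last call: data loss
            }
          }
        }
        createFetchRequests(isLast = true) // with the fix: groups of 6, 6, and 3 are all sent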
    
    ### Does this PR introduce any user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Updated test.
    
    Closes #27786 from Ngone51/fix_no_request_bug.
    
    Authored-by: yi.wu <yi...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 2257ce24437f05c417821c02e3e44c77c93f7211)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../storage/ShuffleBlockFetcherIterator.scala      |  8 +-
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 91 ++++++++++++++++++----
 2 files changed, 78 insertions(+), 21 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index cd4c860..2a0447d 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -339,14 +339,14 @@ final class ShuffleBlockFetcherIterator(
         + s"with ${blocks.size} blocks")
     }
 
-    def createFetchRequests(): Unit = {
+    def createFetchRequests(isLast: Boolean): Unit = {
       val mergedBlocks = mergeContinuousShuffleBlockIdsIfNeeded(curBlocks)
       curBlocks = new ArrayBuffer[FetchBlockInfo]
       if (mergedBlocks.length <= maxBlocksInFlightPerAddress) {
         createFetchRequest(mergedBlocks)
       } else {
         mergedBlocks.grouped(maxBlocksInFlightPerAddress).foreach { blocks =>
-          if (blocks.length == maxBlocksInFlightPerAddress) {
+          if (blocks.length == maxBlocksInFlightPerAddress || isLast) {
             createFetchRequest(blocks)
           } else {
            // The last group does not exceed `maxBlocksInFlightPerAddress`. Put it back
@@ -367,12 +367,12 @@ final class ShuffleBlockFetcherIterator(
       // For batch fetch, the actual block in flight should count for merged block.
       val mayExceedsMaxBlocks = !doBatchFetch && curBlocks.size >= maxBlocksInFlightPerAddress
       if (curRequestSize >= targetRemoteRequestSize || mayExceedsMaxBlocks) {
-        createFetchRequests()
+        createFetchRequests(isLast = false)
       }
     }
     // Add in the final request
     if (curBlocks.nonEmpty) {
-      createFetchRequests()
+      createFetchRequests(isLast = true)
     }
   }
 
diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 2090a51..773629c 100644
--- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -433,32 +433,86 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs().size === 1)
   }
 
-  test("fetch continuous blocks in batch respects maxSize and maxBlocks") {
+  test("fetch continuous blocks in batch should respect maxBytesInFlight") {
     val blockManager = mock(classOf[BlockManager])
     val localBmId = BlockManagerId("test-client", "test-local-host", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
 
     // Make sure remote blocks would return the merged block
-    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
-    val remoteBlocks = Seq[BlockId](
+    val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1)
+    val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2)
+    val remoteBlocks1 = (0 until 15).map(ShuffleBlockId(0, 3, _))
+    val remoteBlocks2 = Seq[BlockId](ShuffleBlockId(0, 4, 0), ShuffleBlockId(0, 4, 1))
+    val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
+      ShuffleBlockBatchId(0, 3, 0, 3) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 3, 3, 6) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 3, 6, 9) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 3, 9, 12) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 3, 12, 15) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer())
+    val transfer = createMockTransfer(mergedRemoteBlocks)
+
+    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
+      (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))),
+      (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 1)))).toIterator
+
+    val taskContext = TaskContext.empty()
+    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
+    val iterator = new ShuffleBlockFetcherIterator(
+      taskContext,
+      transfer,
+      blockManager,
+      blocksByAddress,
+      (_, in) => in,
+      1500,
+      Int.MaxValue,
+      Int.MaxValue,
+      Int.MaxValue,
+      true,
+      false,
+      metrics,
+      true)
+
+    var numResults = 0
+    // After initialize(), there will be 6 FetchRequests. Each of the first 5 requests
+    // includes 1 merged block which is merged from 3 shuffle blocks. The last request has
+    // 1 merged block which is merged from 2 shuffle blocks. So, only the first 5 requests
+    // (5 * 3 * 100 >= 1500) can be sent. The 6th FetchRequest will hit maxBytesInFlight,
+    // so it won't be sent.
+    verify(transfer, times(5)).fetchBlocks(any(), any(), any(), any(), any(), any())
+    while (iterator.hasNext) {
+      val (blockId, inputStream) = iterator.next()
+      // Make sure we release buffers when a wrapped input stream is closed.
+      val mockBuf = mergedRemoteBlocks(blockId)
+      verifyBufferRelease(mockBuf, inputStream)
+      numResults += 1
+    }
+    // The 6th request will be sent after next() is called.
+    verify(transfer, times(6)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    assert(numResults == 6)
+  }
+
+  test("fetch continuous blocks in batch should respect 
maxBlocksInFlightPerAddress") {
+    val blockManager = mock(classOf[BlockManager])
+    val localBmId = BlockManagerId("test-client", "test-local-host", 1)
+    doReturn(localBmId).when(blockManager).blockManagerId
+
+    // Make sure remote blocks would return the merged block
+    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1)
+    val remoteBlocks = Seq(
       ShuffleBlockId(0, 3, 0),
       ShuffleBlockId(0, 3, 1),
-      ShuffleBlockId(0, 3, 2),
       ShuffleBlockId(0, 4, 0),
       ShuffleBlockId(0, 4, 1),
-      ShuffleBlockId(0, 5, 0),
-      ShuffleBlockId(0, 5, 1),
-      ShuffleBlockId(0, 5, 2))
+      ShuffleBlockId(0, 5, 0))
     val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockBatchId(0, 3, 0, 3) -> createMockManagedBuffer(),
+      ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer(),
       ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(),
-      ShuffleBlockBatchId(0, 5, 0, 3) -> createMockManagedBuffer())
-    val transfer = createMockTransfer(mergedRemoteBlocks)
+      ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer())
 
+    val transfer = createMockTransfer(mergedRemoteBlocks)
     val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1)))
-    ).toIterator
-
+      (remoteBmId, remoteBlocks.map(blockId => (blockId, 100L, 1)))).toIterator
     val taskContext = TaskContext.empty()
     val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
     val iterator = new ShuffleBlockFetcherIterator(
@@ -467,7 +521,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       blockManager,
       blocksByAddress,
       (_, in) => in,
-      35,
+      Int.MaxValue,
       Int.MaxValue,
       2,
       Int.MaxValue,
@@ -475,8 +529,12 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       false,
       metrics,
       true)
-
     var numResults = 0
+    // After initialize(), there will be 2 FetchRequests. The first one has 2 merged blocks,
+    // each merged from 2 shuffle blocks; the second one has 1 merged block which is merged
+    // from 1 shuffle block. So only the first FetchRequest can be sent. The second
+    // FetchRequest will hit maxBlocksInFlightPerAddress, so it won't be sent.
+    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any())
     while (iterator.hasNext) {
       val (blockId, inputStream) = iterator.next()
       // Make sure we release buffers when a wrapped input stream is closed.
@@ -484,8 +542,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
       verifyBufferRelease(mockBuf, inputStream)
       numResults += 1
     }
-    // The first 2 batch block ids are in the same fetch request as they don't exceed the max size
-    // and max blocks, so 2 requests in total.
-    // The second request will be sent after next() is called.
+    // The second request will be sent after next() is called.
     verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any())
     assert(numResults == 3)
   }
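
For reference, the gating arithmetic that the first updated test relies on can be sketched in plain
Scala (a hypothetical illustration of the rule implied by the test's comment, not the iterator's
actual code: a queued request is sent eagerly only while total bytes in flight stay within
maxBytesInFlight):

    // Six requests as in the test: five of 3 * 100 bytes, then one of 2 * 100 bytes.
    val maxBytesInFlight = 1500L
    val requestSizes = Seq(300L, 300L, 300L, 300L, 300L, 200L)
    var bytesInFlight = 0L
    val sentEagerly = requestSizes.takeWhile { size =>
      val fits = bytesInFlight + size <= maxBytesInFlight
      if (fits) bytesInFlight += size
      fits
    }
    println(s"${sentEagerly.size} requests sent after initialize()") // prints 5

The second test follows the same pattern with a per-address block count in place of bytes: a request
adding 1 block to the 2 already in flight for that address is held back until next() frees them.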


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
