This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 186477c  [SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite 
to reduce duplicated code
186477c is described below

commit 186477c60e9cad71434b15fd9e08789740425d59
Author: Erik Krogen <xkro...@apache.org>
AuthorDate: Tue May 18 22:37:47 2021 -0500

    [SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite to reduce 
duplicated code
    
    ### What changes were proposed in this pull request?
    Introduce new shared methods to `ShuffleBlockFetcherIteratorSuite` to 
replace copy-pasted code. Use modern, Scala-like Mockito `Answer` syntax.
    
    ### Why are the changes needed?
    `ShuffleBlockFetcherIteratorSuite` has tons of duplicate code, like 
https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala#L172-L185
 . It's challenging to tell what the interesting parts are vs. what is just 
being set to some default/unused value.
    
    Similarly but not as bad, there are many calls like the following
    ```
    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), 
any())).thenAnswer ...
    ```
    
    These changes result in about 10% reduction in both lines and characters in 
the file:
    ```bash
    # Before
    > wc 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
        1063    3950   43201 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
    
    # After
    > wc 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
         928    3609   39053 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
    ```
    
    It also helps readability, e.g.:
    ```
        val iterator = createShuffleBlockIteratorWithDefaults(
          transfer,
          blocksByAddress,
          maxBytesInFlight = 1000L
        )
    ```
    Now I can clearly tell that `maxBytesInFlight` is the main parameter we're 
interested in here.
    
    ### Does this PR introduce _any_ user-facing change?
    No, test only. There aren't even any behavior changes, just refactoring.
    
    ### How was this patch tested?
    Unit tests pass.
    
    Closes #32389 from 
xkrogen/xkrogen-spark-35263-refactor-shuffleblockfetcheriteratorsuite.
    
    Authored-by: Erik Krogen <xkro...@apache.org>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../storage/ShuffleBlockFetcherIteratorSuite.scala | 689 ++++++++-------------
 1 file changed, 245 insertions(+), 444 deletions(-)

diff --git 
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
index 99c43b1..4be5fae 100644
--- 
a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
+++ 
b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
@@ -27,7 +27,7 @@ import scala.concurrent.Future
 
 import org.mockito.ArgumentMatchers.{any, eq => meq}
 import org.mockito.Mockito.{mock, times, verify, when}
-import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
 import org.scalatest.PrivateMethodTester
 
 import org.apache.spark.{SparkFunSuite, TaskContext}
@@ -35,35 +35,44 @@ import org.apache.spark.network._
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, 
ManagedBuffer}
 import org.apache.spark.network.shuffle.{BlockFetchingListener, 
DownloadFileManager, ExternalBlockStoreClient}
 import org.apache.spark.network.util.LimitedInputStream
-import org.apache.spark.shuffle.FetchFailedException
+import org.apache.spark.shuffle.{FetchFailedException, 
ShuffleReadMetricsReporter}
 import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo
 import org.apache.spark.util.Utils
 
 
 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with 
PrivateMethodTester {
 
+  private var transfer: BlockTransferService = _
+
+  override def beforeEach(): Unit = {
+    transfer = mock(classOf[BlockTransferService])
+  }
+
   private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, 
Seq.empty: _*)
 
+  private def answerFetchBlocks(answer: Answer[Unit]): Unit =
+    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), 
any())).thenAnswer(answer)
+
+  private def verifyFetchBlocksInvocationCount(expectedCount: Int): Unit =
+    verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), 
any(), any(), any())
+
   // Some of the tests are quite tricky because we are testing the cleanup 
behavior
   // in the presence of faults.
 
-  /** Creates a mock [[BlockTransferService]] that returns data from the given 
map. */
-  private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): 
BlockTransferService = {
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), 
any())).thenAnswer(
-      (invocation: InvocationOnMock) => {
-        val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]]
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-
-        for (blockId <- blocks) {
-          if (data.contains(BlockId(blockId))) {
-            listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
-          } else {
-            listener.onBlockFetchFailure(blockId, new 
BlockNotFoundException(blockId))
-          }
+  /** Configures `transfer` (mock [[BlockTransferService]]) to return data 
from the given map. */
+  private def configureMockTransfer(data: Map[BlockId, ManagedBuffer]): Unit = 
{
+    answerFetchBlocks { invocation =>
+      val blocks = invocation.getArgument[Array[String]](3)
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+
+      for (blockId <- blocks) {
+        if (data.contains(BlockId(blockId))) {
+          listener.onBlockFetchSuccess(blockId, data(BlockId(blockId)))
+        } else {
+          listener.onBlockFetchFailure(blockId, new 
BlockNotFoundException(blockId))
         }
-      })
-    transfer
+      }
+    }
   }
 
   private def createMockBlockManager(): BlockManager = {
@@ -88,10 +97,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite 
with PrivateMethodT
     
when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
     when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), 
any()))
       .thenAnswer { invocation =>
-        val completableFuture = invocation.getArguments()(3)
-          .asInstanceOf[CompletableFuture[java.util.Map[String, 
Array[String]]]]
         import scala.collection.JavaConverters._
-        completableFuture.complete(hostLocalDirs.asJava)
+        invocation.getArgument[CompletableFuture[java.util.Map[String, 
Array[String]]]](3)
+          .complete(hostLocalDirs.asJava)
       }
 
     blockManager.hostLocalDirManager = Some(hostLocalDirManager)
@@ -123,6 +131,49 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(wrappedInputStream.invokePrivate(delegateAccess()), 
times(1)).close()
   }
 
+  // scalastyle:off argcount
+  private def createShuffleBlockIteratorWithDefaults(
+      blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]],
+      taskContext: Option[TaskContext] = None,
+      streamWrapperLimitSize: Option[Long] = None,
+      blockManager: Option[BlockManager] = None,
+      maxBytesInFlight: Long = Long.MaxValue,
+      maxReqsInFlight: Int = Int.MaxValue,
+      maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+      maxReqSizeShuffleToMem: Int = Int.MaxValue,
+      detectCorrupt: Boolean = true,
+      detectCorruptUseExtraMemory: Boolean = true,
+      shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+    val tContext = taskContext.getOrElse(TaskContext.empty())
+    new ShuffleBlockFetcherIterator(
+      tContext,
+      transfer,
+      blockManager.getOrElse(createMockBlockManager()),
+      blocksByAddress.toIterator,
+      (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, 
_)).getOrElse(in),
+      maxBytesInFlight,
+      maxReqsInFlight,
+      maxBlocksInFlightPerAddress,
+      maxReqSizeShuffleToMem,
+      detectCorrupt,
+      detectCorruptUseExtraMemory,
+      
shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),
+      doBatchFetch)
+  }
+  // scalastyle:on argcount
+
+  /**
+   * Convert a list of block IDs into a list of blocks with metadata, assuming 
all blocks have the
+   * same size and index.
+   */
+  private def toBlockList(
+      blockIds: Traversable[BlockId],
+      blockSize: Long,
+      blockMapIndex: Int): Seq[(BlockId, Long, Int)] = {
+    blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq
+  }
+
   test("successful 3 local + 4 host local + 2 remote reads") {
     val blockManager = createMockBlockManager()
     val localBmId = blockManager.blockManagerId
@@ -142,15 +193,11 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
       ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())
 
-    val transfer = createMockTransfer(remoteBlocks)
+    configureMockTransfer(remoteBlocks)
 
     // Create a block manager running on the same host (host-local)
     val hostLocalBmId = BlockManagerId("test-host-local-client-1", 
"test-local-host", 3)
-    val hostLocalBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 5, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 6, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 7, 0) -> createMockManagedBuffer(),
-      ShuffleBlockId(0, 8, 0) -> createMockManagedBuffer())
+    val hostLocalBlocks = 5.to(8).map(ShuffleBlockId(0, _, 0) -> 
createMockManagedBuffer()).toMap
 
     hostLocalBlocks.foreach { case (blockId, buf) =>
       doReturn(buf)
@@ -161,28 +208,14 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // returning local dir for hostLocalBmId
     initHostLocalDirManager(blockManager, hostLocalDirs)
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq),
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq),
-      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 
1)).toSeq)
-    ).toIterator
-
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(
+        localBmId -> toBlockList(localBlocks.keys, 1L, 0),
+        remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1),
+        hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
+      ),
+      blockManager = Some(blockManager)
+    )
 
     // 3 local blocks fetched in initialization
     verify(blockManager, times(3)).getLocalBlockData(any())
@@ -203,7 +236,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       .getHostLocalShuffleData(any(), meq(Array("local-dir")))
 
     // 2 remote blocks are read from the same block manager
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(1)
     assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size 
=== 1)
   }
 
@@ -228,117 +261,64 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     
when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager))
     when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), 
any()))
       .thenAnswer { invocation =>
-        val completableFuture = invocation.getArguments()(3)
-          .asInstanceOf[CompletableFuture[java.util.Map[String, 
Array[String]]]]
-        completableFuture.completeExceptionally(new Throwable("failed fetch"))
+        invocation.getArgument[CompletableFuture[java.util.Map[String, 
Array[String]]]](3)
+          .completeExceptionally(new Throwable("failed fetch"))
       }
 
     blockManager.hostLocalDirManager = Some(hostLocalDirManager)
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 
1)).toSeq)
-    ).toIterator
 
-    val transfer = createMockTransfer(Map())
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      false)
+    configureMockTransfer(Map())
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1))
+    )
     intercept[FetchFailedException] { iterator.next() }
   }
 
   test("Hit maxBytesInFlight limitation before maxBlocksInFlightPerAddress") {
-    val blockManager = createMockBlockManager()
     val remoteBmId1 = BlockManagerId("test-remote-client-1", 
"test-remote-host1", 1)
     val remoteBmId2 = BlockManagerId("test-remote-client-2", 
"test-remote-host2", 2)
     val blockId1 = ShuffleBlockId(0, 1, 0)
     val blockId2 = ShuffleBlockId(1, 1, 0)
-    val blocksByAddress = Seq(
-      (remoteBmId1, Seq((blockId1, 1000L, 0))),
-      (remoteBmId2, Seq((blockId2, 1000L, 0)))).toIterator
-    val transfer = createMockTransfer(Map(
+    configureMockTransfer(Map(
       blockId1 -> createMockManagedBuffer(1000),
       blockId2 -> createMockManagedBuffer(1000)))
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      1000L, // allow 1 FetchRequests at most at the same time
-      Int.MaxValue,
-      Int.MaxValue, // set maxBlocksInFlightPerAddress to Int.MaxValue
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(Map(
+      remoteBmId1 -> toBlockList(Seq(blockId1), 1000L, 0),
+      remoteBmId2 -> toBlockList(Seq(blockId2), 1000L, 0)
+    ), maxBytesInFlight = 1000L)
     // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. 
So only the
     // first FetchRequests can be sent, and the second one will hit 
maxBytesInFlight so
     // it won't be sent.
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(1)
     assert(iterator.hasNext)
     // next() will trigger off sending deferred request
     iterator.next()
     // the second FetchRequest should be sent at this time
-    verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(2)
     assert(iterator.hasNext)
     iterator.next()
     assert(!iterator.hasNext)
   }
 
   test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") {
-    val blockManager = createMockBlockManager()
     val remoteBmId = BlockManagerId("test-remote-client-1", 
"test-remote-host", 2)
-    val blockId1 = ShuffleBlockId(0, 1, 0)
-    val blockId2 = ShuffleBlockId(0, 2, 0)
-    val blockId3 = ShuffleBlockId(0, 3, 0)
-    val blocksByAddress = Seq((remoteBmId,
-      Seq((blockId1, 1000L, 0), (blockId2, 1000L, 0), (blockId3, 1000L, 
0)))).toIterator
-    val transfer = createMockTransfer(Map(
-      blockId1 -> createMockManagedBuffer(),
-      blockId2 -> createMockManagedBuffer(),
-      blockId3 -> createMockManagedBuffer()))
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      Int.MaxValue, // set maxBytesInFlight to Int.MaxValue
-      Int.MaxValue,
-      2, // set maxBlocksInFlightPerAddress to 2
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      false)
+    val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0))
+    configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks, 1000L, 0)),
+      maxBlocksInFlightPerAddress = 2
+    )
     // After initialize(), we'll have 2 FetchRequests that one has 2 blocks 
inside and another one
     // has only one block. So only the first FetchRequest can be sent. The 
second FetchRequest will
     // hit maxBlocksInFlightPerAddress so it won't be sent.
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(1)
     // the first request packaged 2 blocks, so we also need to
     // call next() for 2 times to exhaust the iterator.
     assert(iterator.hasNext)
     iterator.next()
     assert(iterator.hasNext)
     iterator.next()
-    verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(2)
     assert(iterator.hasNext)
     iterator.next()
     assert(!iterator.hasNext)
@@ -365,7 +345,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 3, 1))
     val mergedRemoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer())
-    val transfer = createMockTransfer(mergedRemoteBlocks)
+    configureMockTransfer(mergedRemoteBlocks)
 
      // Create a block manager running on the same host (host-local)
     val hostLocalBmId = BlockManagerId("test-host-local-client-1", 
"test-local-host", 3)
@@ -386,28 +366,15 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // returning local dir for hostLocalBmId
     initHostLocalDirManager(blockManager, hostLocalDirs)
 
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))),
-      (remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))),
-      (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 
1)).toSeq)
-    ).toIterator
-
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      true)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(
+        localBmId -> toBlockList(localBlocks, 1L, 0),
+        remoteBmId -> toBlockList(remoteBlocks, 1L, 1),
+        hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)
+      ),
+      blockManager = Some(blockManager),
+      doBatchFetch = true
+    )
 
     // 3 local blocks batch fetched in initialization
     verify(blockManager, times(1)).getLocalBlockData(any())
@@ -416,7 +383,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     for (i <- 0 until 3) {
       assert(iterator.hasNext, s"iterator should have 3 elements but actually 
has $i elements")
       val (blockId, inputStream) = iterator.next()
-      verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), 
any(), any())
+      verifyFetchBlocksInvocationCount(1)
       // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = allBlocks(blockId)
       verifyBufferRelease(mockBuf, inputStream)
@@ -430,7 +397,6 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
   }
 
   test("fetch continuous blocks in batch should respect maxBytesInFlight") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return the merged block
     val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1)
     val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2)
@@ -443,28 +409,16 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       ShuffleBlockBatchId(0, 3, 9, 12) -> createMockManagedBuffer(),
       ShuffleBlockBatchId(0, 3, 12, 15) -> createMockManagedBuffer(),
       ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer())
-    val transfer = createMockTransfer(mergedRemoteBlocks)
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))),
-      (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 
1)))).toIterator
-
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      1500,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      true)
+    configureMockTransfer(mergedRemoteBlocks)
+
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(
+        remoteBmId1 -> toBlockList(remoteBlocks1, 100L, 1),
+        remoteBmId2 -> toBlockList(remoteBlocks2, 100L, 1)
+      ),
+      maxBytesInFlight = 1500,
+      doBatchFetch = true
+    )
 
     var numResults = 0
     // After initialize(), there will be 6 FetchRequests. And each of the 
first 5 requests
@@ -472,7 +426,7 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // block which merged from 2 shuffle blocks. So, only the first 5 
requests(5 * 3 * 100 >= 1500)
     // can be sent. The 6th FetchRequest will hit maxBytesInFlight so it won't
     // be sent.
-    verify(transfer, times(5)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(5)
     while (iterator.hasNext) {
       val (blockId, inputStream) = iterator.next()
       // Make sure we release buffers when a wrapped input stream is closed.
@@ -481,12 +435,11 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       numResults += 1
     }
     // The 6th request will be sent after next() is called.
-    verify(transfer, times(6)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(6)
     assert(numResults == 6)
   }
 
   test("fetch continuous blocks in batch should respect 
maxBlocksInFlightPerAddress") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return the merged block
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1)
     val remoteBlocks = Seq(
@@ -500,31 +453,18 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(),
       ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer())
 
-    val transfer = createMockTransfer(mergedRemoteBlocks)
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, remoteBlocks.map(blockId => (blockId, 100L, 1)))).toIterator
-    val taskContext = TaskContext.empty()
-    val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      Int.MaxValue,
-      Int.MaxValue,
-      2,
-      Int.MaxValue,
-      true,
-      false,
-      metrics,
-      true)
+    configureMockTransfer(mergedRemoteBlocks)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(remoteBlocks, 100L, 1)),
+      maxBlocksInFlightPerAddress = 2,
+      doBatchFetch = true
+    )
     var numResults = 0
     // After initialize(), there will be 2 FetchRequests. First one has 2 
merged blocks and each
     // of them is merged from 2 shuffle blocks, second one has 1 merged block 
which is merged from
     // 1 shuffle block. So only the first FetchRequest can be sent. The second 
FetchRequest will
     // hit maxBlocksInFlightPerAddress so it won't be sent.
-    verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(1)
     while (iterator.hasNext) {
       val (blockId, inputStream) = iterator.next()
       // Make sure we release buffers when a wrapped input stream is closed.
@@ -533,12 +473,11 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       numResults += 1
     }
     // The second request will be sent after next() is called.
-    verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), 
any())
+    verifyFetchBlocksInvocationCount(2)
     assert(numResults == 3)
   }
 
   test("release current unexhausted buffer in case the task completes early") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
@@ -549,40 +488,25 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
 
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first two blocks, and wait till task completion before 
returning the 3rd one
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
-          sem.acquire()
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 
0)).toSeq)).toIterator
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first two blocks, and wait till task completion before 
returning the 3rd one
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0)))
+        sem.acquire()
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0)))
+      }
+    }
 
     val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+      taskContext = Some(taskContext)
+    )
 
     verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release()
     iterator.next()._2.close() // close() first block's input stream
@@ -603,7 +527,6 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
   }
 
   test("fail all blocks if any of the remote request fails") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
@@ -615,41 +538,23 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
 
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first block, and then fail.
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-          listener.onBlockFetchFailure(
-            ShuffleBlockId(0, 1, 0).toString, new 
BlockNotFoundException("blah"))
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first block, and then fail.
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+        listener.onBlockFetchFailure(
+          ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah"))
           listener.onBlockFetchFailure(
             ShuffleBlockId(0, 2, 0).toString, new 
BlockNotFoundException("blah"))
-          sem.release()
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
-      .toIterator
+        sem.release()
+      }
+    }
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0))
+    )
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -690,7 +595,6 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
   }
 
   test("retry corrupt blocks") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
@@ -703,40 +607,24 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 
0, 100)
 
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first block, and then fail.
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
-          sem.release()
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 
0)).toSeq)).toIterator
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first block, and then fail.
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+        sem.release()
+      }
+    }
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 100),
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+      streamWrapperLimitSize = Some(100)
+    )
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -745,16 +633,14 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     val (id1, _) = iterator.next()
     assert(id1 === ShuffleBlockId(0, 0, 0))
 
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first block, and then fail.
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
-          sem.release()
-        }
-      })
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first block, and then fail.
+        listener.onBlockFetchSuccess(ShuffleBlockId(0, 1, 0).toString, 
mockCorruptBuffer())
+        sem.release()
+      }
+    }
 
     // The next block is corrupt local block (the second one is corrupt and 
retried)
     intercept[FetchFailedException] { iterator.next() }
@@ -765,47 +651,28 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
 
   test("big blocks are also checked for corruption") {
     val streamLength = 10000L
-    val blockManager = createMockBlockManager()
 
     // This stream will throw IOException when the first byte is read
     val corruptBuffer1 = mockCorruptBuffer(streamLength, 0)
     val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 
1)
     val shuffleBlockId1 = ShuffleBlockId(0, 1, 0)
-    val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]](
-      (shuffleBlockId1, corruptBuffer1.size(), 1)
-    )
 
     val streamNotCorruptTill = 8 * 1024
     // This stream will throw exception after streamNotCorruptTill bytes are 
read
     val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill)
     val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 
2)
     val shuffleBlockId2 = ShuffleBlockId(0, 2, 0)
-    val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]](
-      (shuffleBlockId2, corruptBuffer2.size(), 2)
-    )
 
-    val transfer = createMockTransfer(
+    configureMockTransfer(
       Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> 
corruptBuffer2))
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (blockManagerId1, blockLengths1),
-      (blockManagerId2, blockLengths2)
-    ).toIterator
-    val taskContext = TaskContext.empty()
-    val maxBytesInFlight = 3 * 1024
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => new LimitedInputStream(in, streamLength),
-      maxBytesInFlight,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(
+        blockManagerId1 -> toBlockList(Seq(shuffleBlockId1), 
corruptBuffer1.size(), 1),
+        blockManagerId2 -> toBlockList(Seq(shuffleBlockId2), 
corruptBuffer2.size(), 2)
+      ),
+      streamWrapperLimitSize = Some(streamLength),
+      maxBytesInFlight = 3 * 1024
+    )
 
     // We'll get back the block which has corruption after maxBytesInFlight/3 
because the other
     // block will detect corruption on first fetch, and then get added to the 
queue again for
@@ -848,30 +715,15 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     val localBmId = BlockManagerId("test-client", "test-client", 1)
     doReturn(localBmId).when(blockManager).blockManagerId
     
doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0,
 0, 0)))
-    val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]](
-      (ShuffleBlockId(0, 0, 0), 10000, 0)
-    )
-    val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> 
managedBuffer))
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (localBmId, localBlockLengths)
-    ).toIterator
+    configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer))
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 10000),
-      2048,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
-    val (id, st) = iterator.next()
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(localBmId -> toBlockList(Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)),
+      blockManager = Some(blockManager),
+      streamWrapperLimitSize = Some(10000),
+      maxBytesInFlight = 2048 // force concatenation of stream by limiting 
bytes in flight
+    )
+    val (_, st) = iterator.next()
     // Check that the test setup is correct -- make sure we have a 
concatenated stream.
     assert 
(st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream])
 
@@ -884,7 +736,6 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
   }
 
   test("retry corrupt blocks (disabled)") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
@@ -896,41 +747,25 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     // Semaphore to coordinate event sequence in two different threads.
     val sem = new Semaphore(0)
 
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first block, and then fail.
-          listener.onBlockFetchSuccess(
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first block, and then fail.
+        listener.onBlockFetchSuccess(
             ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
-          sem.release()
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
-      .toIterator
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer())
+        sem.release()
+      }
+    }
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 100),
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)),
+      streamWrapperLimitSize = Some(100),
+      detectCorruptUseExtraMemory = false
+    )
 
     // Continue only after the mock calls onBlockFetchFailure
     sem.acquire()
@@ -958,57 +793,38 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
       ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
-    val transfer = mock(classOf[BlockTransferService])
     var tempFileManager: DownloadFileManager = null
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        tempFileManager = 
invocation.getArguments()(5).asInstanceOf[DownloadFileManager]
-        Future {
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 
0, 0)))
-        }
-      })
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      tempFileManager = invocation.getArgument[DownloadFileManager](5)
+      Future {
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 
0)))
+      }
+    }
 
-    def fetchShuffleBlock(
-        blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, 
Int)])]): Unit = {
-      // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so 
that during the
+    def fetchShuffleBlock(blockSize: Long): Unit = {
+      // Use default `maxBytesInFlight` and `maxReqsInFlight` (`Int.MaxValue`) 
so that during the
       // construction of `ShuffleBlockFetcherIterator`, all requests to fetch 
remote shuffle blocks
       // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here.
-      val taskContext = TaskContext.empty()
-      new ShuffleBlockFetcherIterator(
-        taskContext,
-        transfer,
-        blockManager,
-        blocksByAddress,
-        (_, in) => in,
-        maxBytesInFlight = Int.MaxValue,
-        maxReqsInFlight = Int.MaxValue,
-        maxBlocksInFlightPerAddress = Int.MaxValue,
-        maxReqSizeShuffleToMem = 200,
-        detectCorrupt = true,
-        false,
-        taskContext.taskMetrics.createTempShuffleReadMetrics(),
-        false)
+      createShuffleBlockIteratorWithDefaults(
+        Map(remoteBmId -> toBlockList(remoteBlocks.keys, blockSize, 0)),
+        blockManager = Some(blockManager),
+        maxReqSizeShuffleToMem = 200)
     }
 
-    val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 
0)).toSeq)).toIterator
-    fetchShuffleBlock(blocksByAddress1)
+    fetchShuffleBlock(100L)
     // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 
100, so don't fetch
     // shuffle block to disk.
     assert(tempFileManager == null)
 
-    val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 
0)).toSeq)).toIterator
-    fetchShuffleBlock(blocksByAddress2)
+    fetchShuffleBlock(300L)
     // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 
300, so fetch
     // shuffle block to disk.
     assert(tempFileManager != null)
   }
 
   test("fail zero-size blocks") {
-    val blockManager = createMockBlockManager()
     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val blocks = Map[BlockId, ManagedBuffer](
@@ -1016,26 +832,11 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
       ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer()
     )
 
-    val transfer = createMockTransfer(blocks.mapValues(_ => 
createMockManagedBuffer(0)).toMap)
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq))
+    configureMockTransfer(blocks.mapValues(_ => 
createMockManagedBuffer(0)).toMap)
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress.toIterator,
-      (_, in) => in,
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      false,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0))
+    )
 
     // All blocks fetched return zero length and should trigger a receive-side 
error:
     val e = intercept[FetchFailedException] { iterator.next() }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to