This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 186477c [SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite to reduce duplicated code 186477c is described below commit 186477c60e9cad71434b15fd9e08789740425d59 Author: Erik Krogen <xkro...@apache.org> AuthorDate: Tue May 18 22:37:47 2021 -0500 [SPARK-35263][TEST] Refactor ShuffleBlockFetcherIteratorSuite to reduce duplicated code ### What changes were proposed in this pull request? Introduce new shared methods to `ShuffleBlockFetcherIteratorSuite` to replace copy-pasted code. Use modern, Scala-like Mockito `Answer` syntax. ### Why are the changes needed? `ShuffleBlockFetcherIteratorSuite` has tons of duplicate code, like https://github.com/apache/spark/blob/0494dc90af48ce7da0625485a4dc6917a244d580/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala#L172-L185 . It's challenging to tell what the interesting parts are vs. what is just being set to some default/unused value. Similarly but not as bad, there are many calls like the following ``` verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer ... ``` These changes result in about 10% reduction in both lines and characters in the file: ```bash # Before > wc core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala 1063 3950 43201 core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala # After > wc core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala 928 3609 39053 core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala ``` It also helps readability, e.g.: ``` val iterator = createShuffleBlockIteratorWithDefaults( transfer, blocksByAddress, maxBytesInFlight = 1000L ) ``` Now I can clearly tell that `maxBytesInFlight` is the main parameter we're interested in here. 
### Does this PR introduce _any_ user-facing change? No, test only. There aren't even any behavior changes, just refactoring. ### How was this patch tested? Unit tests pass. Closes #32389 from xkrogen/xkrogen-spark-35263-refactor-shuffleblockfetcheriteratorsuite. Authored-by: Erik Krogen <xkro...@apache.org> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../storage/ShuffleBlockFetcherIteratorSuite.scala | 689 ++++++++------------- 1 file changed, 245 insertions(+), 444 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala index 99c43b1..4be5fae 100644 --- a/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala @@ -27,7 +27,7 @@ import scala.concurrent.Future import org.mockito.ArgumentMatchers.{any, eq => meq} import org.mockito.Mockito.{mock, times, verify, when} -import org.mockito.invocation.InvocationOnMock +import org.mockito.stubbing.Answer import org.scalatest.PrivateMethodTester import org.apache.spark.{SparkFunSuite, TaskContext} @@ -35,35 +35,44 @@ import org.apache.spark.network._ import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer} import org.apache.spark.network.shuffle.{BlockFetchingListener, DownloadFileManager, ExternalBlockStoreClient} import org.apache.spark.network.util.LimitedInputStream -import org.apache.spark.shuffle.FetchFailedException +import org.apache.spark.shuffle.{FetchFailedException, ShuffleReadMetricsReporter} import org.apache.spark.storage.ShuffleBlockFetcherIterator.FetchBlockInfo import org.apache.spark.util.Utils class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodTester { + private var transfer: BlockTransferService = _ + + override def beforeEach(): Unit = { + transfer = 
mock(classOf[BlockTransferService]) + } + private def doReturn(value: Any) = org.mockito.Mockito.doReturn(value, Seq.empty: _*) + private def answerFetchBlocks(answer: Answer[Unit]): Unit = + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer(answer) + + private def verifyFetchBlocksInvocationCount(expectedCount: Int): Unit = + verify(transfer, times(expectedCount)).fetchBlocks(any(), any(), any(), any(), any(), any()) + // Some of the tests are quite tricky because we are testing the cleanup behavior // in the presence of faults. - /** Creates a mock [[BlockTransferService]] that returns data from the given map. */ - private def createMockTransfer(data: Map[BlockId, ManagedBuffer]): BlockTransferService = { - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())).thenAnswer( - (invocation: InvocationOnMock) => { - val blocks = invocation.getArguments()(3).asInstanceOf[Array[String]] - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - - for (blockId <- blocks) { - if (data.contains(BlockId(blockId))) { - listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) - } else { - listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) - } + /** Configures `transfer` (mock [[BlockTransferService]]) to return data from the given map. 
*/ + private def configureMockTransfer(data: Map[BlockId, ManagedBuffer]): Unit = { + answerFetchBlocks { invocation => + val blocks = invocation.getArgument[Array[String]](3) + val listener = invocation.getArgument[BlockFetchingListener](4) + + for (blockId <- blocks) { + if (data.contains(BlockId(blockId))) { + listener.onBlockFetchSuccess(blockId, data(BlockId(blockId))) + } else { + listener.onBlockFetchFailure(blockId, new BlockNotFoundException(blockId)) } - }) - transfer + } + } } private def createMockBlockManager(): BlockManager = { @@ -88,10 +97,9 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) .thenAnswer { invocation => - val completableFuture = invocation.getArguments()(3) - .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]] import scala.collection.JavaConverters._ - completableFuture.complete(hostLocalDirs.asJava) + invocation.getArgument[CompletableFuture[java.util.Map[String, Array[String]]]](3) + .complete(hostLocalDirs.asJava) } blockManager.hostLocalDirManager = Some(hostLocalDirManager) @@ -123,6 +131,49 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT verify(wrappedInputStream.invokePrivate(delegateAccess()), times(1)).close() } + // scalastyle:off argcount + private def createShuffleBlockIteratorWithDefaults( + blocksByAddress: Map[BlockManagerId, Seq[(BlockId, Long, Int)]], + taskContext: Option[TaskContext] = None, + streamWrapperLimitSize: Option[Long] = None, + blockManager: Option[BlockManager] = None, + maxBytesInFlight: Long = Long.MaxValue, + maxReqsInFlight: Int = Int.MaxValue, + maxBlocksInFlightPerAddress: Int = Int.MaxValue, + maxReqSizeShuffleToMem: Int = Int.MaxValue, + detectCorrupt: Boolean = true, + detectCorruptUseExtraMemory: Boolean = true, + shuffleMetrics: 
Option[ShuffleReadMetricsReporter] = None, + doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = { + val tContext = taskContext.getOrElse(TaskContext.empty()) + new ShuffleBlockFetcherIterator( + tContext, + transfer, + blockManager.getOrElse(createMockBlockManager()), + blocksByAddress.toIterator, + (_, in) => streamWrapperLimitSize.map(new LimitedInputStream(in, _)).getOrElse(in), + maxBytesInFlight, + maxReqsInFlight, + maxBlocksInFlightPerAddress, + maxReqSizeShuffleToMem, + detectCorrupt, + detectCorruptUseExtraMemory, + shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()), + doBatchFetch) + } + // scalastyle:on argcount + + /** + * Convert a list of block IDs into a list of blocks with metadata, assuming all blocks have the + * same size and index. + */ + private def toBlockList( + blockIds: Traversable[BlockId], + blockSize: Long, + blockMapIndex: Int): Seq[(BlockId, Long, Int)] = { + blockIds.map(blockId => (blockId, blockSize, blockMapIndex)).toSeq + } + test("successful 3 local + 4 host local + 2 remote reads") { val blockManager = createMockBlockManager() val localBmId = blockManager.blockManagerId @@ -142,15 +193,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(), ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer()) - val transfer = createMockTransfer(remoteBlocks) + configureMockTransfer(remoteBlocks) // Create a block manager running on the same host (host-local) val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3) - val hostLocalBlocks = Map[BlockId, ManagedBuffer]( - ShuffleBlockId(0, 5, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 6, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 7, 0) -> createMockManagedBuffer(), - ShuffleBlockId(0, 8, 0) -> createMockManagedBuffer()) + val hostLocalBlocks = 5.to(8).map(ShuffleBlockId(0, _, 0) -> createMockManagedBuffer()).toMap 
hostLocalBlocks.foreach { case (blockId, buf) => doReturn(buf) @@ -161,28 +208,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlocks.keys.map(blockId => (blockId, 1L, 0)).toSeq), - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq), - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator - - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map( + localBmId -> toBlockList(localBlocks.keys, 1L, 0), + remoteBmId -> toBlockList(remoteBlocks.keys, 1L, 1), + hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) + ), + blockManager = Some(blockManager) + ) // 3 local blocks fetched in initialization verify(blockManager, times(3)).getLocalBlockData(any()) @@ -203,7 +236,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT .getHostLocalShuffleData(any(), meq(Array("local-dir"))) // 2 remote blocks are read from the same block manager - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(1) assert(blockManager.hostLocalDirManager.get.getCachedHostLocalDirs.size === 1) } @@ -228,117 +261,64 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT when(blockManager.hostLocalDirManager).thenReturn(Some(hostLocalDirManager)) when(mockExternalBlockStoreClient.getHostLocalDirs(any(), any(), any(), any())) .thenAnswer { invocation 
=> - val completableFuture = invocation.getArguments()(3) - .asInstanceOf[CompletableFuture[java.util.Map[String, Array[String]]]] - completableFuture.completeExceptionally(new Throwable("failed fetch")) + invocation.getArgument[CompletableFuture[java.util.Map[String, Array[String]]]](3) + .completeExceptionally(new Throwable("failed fetch")) } blockManager.hostLocalDirManager = Some(hostLocalDirManager) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator - val transfer = createMockTransfer(Map()) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - false) + configureMockTransfer(Map()) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1)) + ) intercept[FetchFailedException] { iterator.next() } } test("Hit maxBytesInFlight limitation before maxBlocksInFlightPerAddress") { - val blockManager = createMockBlockManager() val remoteBmId1 = BlockManagerId("test-remote-client-1", "test-remote-host1", 1) val remoteBmId2 = BlockManagerId("test-remote-client-2", "test-remote-host2", 2) val blockId1 = ShuffleBlockId(0, 1, 0) val blockId2 = ShuffleBlockId(1, 1, 0) - val blocksByAddress = Seq( - (remoteBmId1, Seq((blockId1, 1000L, 0))), - (remoteBmId2, Seq((blockId2, 1000L, 0)))).toIterator - val transfer = createMockTransfer(Map( + configureMockTransfer(Map( blockId1 -> createMockManagedBuffer(1000), blockId2 -> createMockManagedBuffer(1000))) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, 
- transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 1000L, // allow 1 FetchRequests at most at the same time - Int.MaxValue, - Int.MaxValue, // set maxBlocksInFlightPerAddress to Int.MaxValue - Int.MaxValue, - true, - false, - metrics, - false) + val iterator = createShuffleBlockIteratorWithDefaults(Map( + remoteBmId1 -> toBlockList(Seq(blockId1), 1000L, 0), + remoteBmId2 -> toBlockList(Seq(blockId2), 1000L, 0) + ), maxBytesInFlight = 1000L) // After initialize() we'll have 2 FetchRequests and each is 1000 bytes. So only the // first FetchRequests can be sent, and the second one will hit maxBytesInFlight so // it won't be sent. - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(1) assert(iterator.hasNext) // next() will trigger off sending deferred request iterator.next() // the second FetchRequest should be sent at this time - verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) } test("Hit maxBlocksInFlightPerAddress limitation before maxBytesInFlight") { - val blockManager = createMockBlockManager() val remoteBmId = BlockManagerId("test-remote-client-1", "test-remote-host", 2) - val blockId1 = ShuffleBlockId(0, 1, 0) - val blockId2 = ShuffleBlockId(0, 2, 0) - val blockId3 = ShuffleBlockId(0, 3, 0) - val blocksByAddress = Seq((remoteBmId, - Seq((blockId1, 1000L, 0), (blockId2, 1000L, 0), (blockId3, 1000L, 0)))).toIterator - val transfer = createMockTransfer(Map( - blockId1 -> createMockManagedBuffer(), - blockId2 -> createMockManagedBuffer(), - blockId3 -> createMockManagedBuffer())) - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - Int.MaxValue, // set 
maxBytesInFlight to Int.MaxValue - Int.MaxValue, - 2, // set maxBlocksInFlightPerAddress to 2 - Int.MaxValue, - true, - false, - metrics, - false) + val blocks = 1.to(3).map(ShuffleBlockId(0, _, 0)) + configureMockTransfer(blocks.map(_ -> createMockManagedBuffer()).toMap) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks, 1000L, 0)), + maxBlocksInFlightPerAddress = 2 + ) // After initialize(), we'll have 2 FetchRequests that one has 2 blocks inside and another one // has only one block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(1) // the first request packaged 2 blocks, so we also need to // call next() for 2 times to exhaust the iterator. assert(iterator.hasNext) iterator.next() assert(iterator.hasNext) iterator.next() - verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(2) assert(iterator.hasNext) iterator.next() assert(!iterator.hasNext) @@ -365,7 +345,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 3, 1)) val mergedRemoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockBatchId(0, 3, 0, 2) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) + configureMockTransfer(mergedRemoteBlocks) // Create a block manager running on the same host (host-local) val hostLocalBmId = BlockManagerId("test-host-local-client-1", "test-local-host", 3) @@ -386,28 +366,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // returning local dir for hostLocalBmId initHostLocalDirManager(blockManager, hostLocalDirs) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlocks.map(blockId => (blockId, 1L, 0))), - 
(remoteBmId, remoteBlocks.map(blockId => (blockId, 1L, 1))), - (hostLocalBmId, hostLocalBlocks.keys.map(blockId => (blockId, 1L, 1)).toSeq) - ).toIterator - - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - true) + val iterator = createShuffleBlockIteratorWithDefaults( + Map( + localBmId -> toBlockList(localBlocks, 1L, 0), + remoteBmId -> toBlockList(remoteBlocks, 1L, 1), + hostLocalBmId -> toBlockList(hostLocalBlocks.keys, 1L, 1) + ), + blockManager = Some(blockManager), + doBatchFetch = true + ) // 3 local blocks batch fetched in initialization verify(blockManager, times(1)).getLocalBlockData(any()) @@ -416,7 +383,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT for (i <- 0 until 3) { assert(iterator.hasNext, s"iterator should have 3 elements but actually has $i elements") val (blockId, inputStream) = iterator.next() - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(1) // Make sure we release buffers when a wrapped input stream is closed. 
val mockBuf = allBlocks(blockId) verifyBufferRelease(mockBuf, inputStream) @@ -430,7 +397,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fetch continuous blocks in batch should respect maxBytesInFlight") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId1 = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBmId2 = BlockManagerId("test-client-2", "test-client-2", 2) @@ -443,28 +409,16 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 3, 9, 12) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 3, 12, 15) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId1, remoteBlocks1.map(blockId => (blockId, 100L, 1))), - (remoteBmId2, remoteBlocks2.map(blockId => (blockId, 100L, 1)))).toIterator - - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 1500, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - metrics, - true) + configureMockTransfer(mergedRemoteBlocks) + + val iterator = createShuffleBlockIteratorWithDefaults( + Map( + remoteBmId1 -> toBlockList(remoteBlocks1, 100L, 1), + remoteBmId2 -> toBlockList(remoteBlocks2, 100L, 1) + ), + maxBytesInFlight = 1500, + doBatchFetch = true + ) var numResults = 0 // After initialize(), there will be 6 FetchRequests. And each of the first 5 requests @@ -472,7 +426,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // block which merged from 2 shuffle blocks. So, only the first 5 requests(5 * 3 * 100 >= 1500) // can be sent. 
The 6th FetchRequest will hit maxBlocksInFlightPerAddress so it won't // be sent. - verify(transfer, times(5)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(5) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -481,12 +435,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The 6th request will be sent after next() is called. - verify(transfer, times(6)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(6) assert(numResults == 6) } test("fetch continuous blocks in batch should respect maxBlocksInFlightPerAddress") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return the merged block val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 1) val remoteBlocks = Seq( @@ -500,31 +453,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockBatchId(0, 4, 0, 2) -> createMockManagedBuffer(), ShuffleBlockBatchId(0, 5, 0, 1) -> createMockManagedBuffer()) - val transfer = createMockTransfer(mergedRemoteBlocks) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.map(blockId => (blockId, 100L, 1)))).toIterator - val taskContext = TaskContext.empty() - val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - Int.MaxValue, - Int.MaxValue, - 2, - Int.MaxValue, - true, - false, - metrics, - true) + configureMockTransfer(mergedRemoteBlocks) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(remoteBlocks, 100L, 1)), + maxBlocksInFlightPerAddress = 2, + doBatchFetch = true + ) var numResults = 0 // After initialize(), there will be 2 
FetchRequests. First one has 2 merged blocks and each // of them is merged from 2 shuffle blocks, second one has 1 merged block which is merged from // 1 shuffle block. So only the first FetchRequest can be sent. The second FetchRequest will // hit maxBlocksInFlightPerAddress so it won't be sent. - verify(transfer, times(1)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(1) while (iterator.hasNext) { val (blockId, inputStream) = iterator.next() // Make sure we release buffers when a wrapped input stream is closed. @@ -533,12 +473,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT numResults += 1 } // The second request will be sent after next() is called. - verify(transfer, times(2)).fetchBlocks(any(), any(), any(), any(), any(), any()) + verifyFetchBlocksInvocationCount(2) assert(numResults == 3) } test("release current unexhausted buffer in case the task completes early") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -549,40 +488,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. 
val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first two blocks, and wait till task completion before returning the 3rd one - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) - sem.acquire() - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) - } - }) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first two blocks, and wait till task completion before returning the 3rd one + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, blocks(ShuffleBlockId(0, 1, 0))) + sem.acquire() + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, blocks(ShuffleBlockId(0, 2, 0))) + } + } val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + taskContext = Some(taskContext) + ) verify(blocks(ShuffleBlockId(0, 0, 0)), times(0)).release() iterator.next()._2.close() // close() first block's input stream @@ 
-603,7 +527,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("fail all blocks if any of the remote request fails") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -615,41 +538,23 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchFailure( - ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchFailure( + ShuffleBlockId(0, 1, 0).toString, new BlockNotFoundException("blah")) listener.onBlockFetchFailure( ShuffleBlockId(0, 2, 0).toString, new BlockNotFoundException("blah")) - sem.release() - } - }) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) - .toIterator + sem.release() + } + } - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)) + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -690,7 +595,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -703,40 +607,24 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val sem = new Semaphore(0) val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 0, 100) - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. 
- listener.onBlockFetchSuccess( - ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) - sem.release() - } - }) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)).toIterator + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer) + sem.release() + } + } - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100) + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -745,16 +633,14 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val (id1, _) = iterator.next() assert(id1 === ShuffleBlockId(0, 0, 0)) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. 
- listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) - sem.release() - } - }) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. + listener.onBlockFetchSuccess(ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + sem.release() + } + } // The next block is corrupt local block (the second one is corrupt and retried) intercept[FetchFailedException] { iterator.next() } @@ -765,47 +651,28 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT test("big blocks are also checked for corruption") { val streamLength = 10000L - val blockManager = createMockBlockManager() // This stream will throw IOException when the first byte is read val corruptBuffer1 = mockCorruptBuffer(streamLength, 0) val blockManagerId1 = BlockManagerId("remote-client-1", "remote-client-1", 1) val shuffleBlockId1 = ShuffleBlockId(0, 1, 0) - val blockLengths1 = Seq[Tuple3[BlockId, Long, Int]]( - (shuffleBlockId1, corruptBuffer1.size(), 1) - ) val streamNotCorruptTill = 8 * 1024 // This stream will throw exception after streamNotCorruptTill bytes are read val corruptBuffer2 = mockCorruptBuffer(streamLength, streamNotCorruptTill) val blockManagerId2 = BlockManagerId("remote-client-2", "remote-client-2", 2) val shuffleBlockId2 = ShuffleBlockId(0, 2, 0) - val blockLengths2 = Seq[Tuple3[BlockId, Long, Int]]( - (shuffleBlockId2, corruptBuffer2.size(), 2) - ) - val transfer = createMockTransfer( + configureMockTransfer( Map(shuffleBlockId1 -> corruptBuffer1, shuffleBlockId2 -> corruptBuffer2)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (blockManagerId1, blockLengths1), - (blockManagerId2, blockLengths2) - ).toIterator - val taskContext = TaskContext.empty() - val maxBytesInFlight = 3 * 1024 - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, 
in) => new LimitedInputStream(in, streamLength), - maxBytesInFlight, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map( + blockManagerId1 -> toBlockList(Seq(shuffleBlockId1), corruptBuffer1.size(), 1), + blockManagerId2 -> toBlockList(Seq(shuffleBlockId2), corruptBuffer2.size(), 2) + ), + streamWrapperLimitSize = Some(streamLength), + maxBytesInFlight = 3 * 1024 + ) // We'll get back the block which has corruption after maxBytesInFlight/3 because the other // block will detect corruption on first fetch, and then get added to the queue again for @@ -848,30 +715,15 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val localBmId = BlockManagerId("test-client", "test-client", 1) doReturn(localBmId).when(blockManager).blockManagerId doReturn(managedBuffer).when(blockManager).getLocalBlockData(meq(ShuffleBlockId(0, 0, 0))) - val localBlockLengths = Seq[Tuple3[BlockId, Long, Int]]( - (ShuffleBlockId(0, 0, 0), 10000, 0) - ) - val transfer = createMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (localBmId, localBlockLengths) - ).toIterator + configureMockTransfer(Map(ShuffleBlockId(0, 0, 0) -> managedBuffer)) - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => new LimitedInputStream(in, 10000), - 2048, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - true, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) - val (id, st) = iterator.next() + val iterator = createShuffleBlockIteratorWithDefaults( + Map(localBmId -> toBlockList(Seq(ShuffleBlockId(0, 0, 0)), 10000L, 0)), + blockManager = Some(blockManager), + streamWrapperLimitSize = Some(10000), + maxBytesInFlight = 2048 // force 
concatenation of stream by limiting bytes in flight + ) + val (_, st) = iterator.next() // Check that the test setup is correct -- make sure we have a concatenated stream. assert (st.asInstanceOf[BufferReleasingInputStream].delegate.isInstanceOf[SequenceInputStream]) @@ -884,7 +736,6 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT } test("retry corrupt blocks (disabled)") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -896,41 +747,25 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT // Semaphore to coordinate event sequence in two different threads. val sem = new Semaphore(0) - val transfer = mock(classOf[BlockTransferService]) - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - Future { - // Return the first block, and then fail. - listener.onBlockFetchSuccess( + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + Future { + // Return the first block, and then fail. 
+ listener.onBlockFetchSuccess( ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0))) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) - sem.release() - } - }) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) - .toIterator + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer()) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, mockCorruptBuffer()) + sem.release() + } + } - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => new LimitedInputStream(in, 100), - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)), + streamWrapperLimitSize = Some(100), + detectCorruptUseExtraMemory = false + ) // Continue only after the mock calls onBlockFetchFailure sem.acquire() @@ -958,57 +793,38 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val remoteBlocks = Map[BlockId, ManagedBuffer]( ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) - val transfer = mock(classOf[BlockTransferService]) var tempFileManager: DownloadFileManager = null - when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) - .thenAnswer((invocation: InvocationOnMock) => { - val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] - tempFileManager = invocation.getArguments()(5).asInstanceOf[DownloadFileManager] - Future { - listener.onBlockFetchSuccess( - ShuffleBlockId(0, 
0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) - } - }) + answerFetchBlocks { invocation => + val listener = invocation.getArgument[BlockFetchingListener](4) + tempFileManager = invocation.getArgument[DownloadFileManager](5) + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } - def fetchShuffleBlock( - blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])]): Unit = { - // Set `maxBytesInFlight` and `maxReqsInFlight` to `Int.MaxValue`, so that during the + def fetchShuffleBlock(blockSize: Long): Unit = { + // Use default `maxBytesInFlight` and `maxReqsInFlight` (`Int.MaxValue`) so that during the // construction of `ShuffleBlockFetcherIterator`, all requests to fetch remote shuffle blocks // are issued. The `maxReqSizeShuffleToMem` is hard-coded as 200 here. - val taskContext = TaskContext.empty() - new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress, - (_, in) => in, - maxBytesInFlight = Int.MaxValue, - maxReqsInFlight = Int.MaxValue, - maxBlocksInFlightPerAddress = Int.MaxValue, - maxReqSizeShuffleToMem = 200, - detectCorrupt = true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(remoteBlocks.keys, blockSize, 0)), + blockManager = Some(blockManager), + maxReqSizeShuffleToMem = 200) } - val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L, 0)).toSeq)).toIterator - fetchShuffleBlock(blocksByAddress1) + fetchShuffleBlock(100L) // `maxReqSizeShuffleToMem` is 200, which is greater than the block size 100, so don't fetch // shuffle block to disk. 
assert(tempFileManager == null) - val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L, 0)).toSeq)).toIterator - fetchShuffleBlock(blocksByAddress2) + fetchShuffleBlock(300L) // `maxReqSizeShuffleToMem` is 200, which is smaller than the block size 300, so fetch // shuffle block to disk. assert(tempFileManager != null) } test("fail zero-size blocks") { - val blockManager = createMockBlockManager() // Make sure remote blocks would return val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) val blocks = Map[BlockId, ManagedBuffer]( @@ -1016,26 +832,11 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer() ) - val transfer = createMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) - - val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])]( - (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 0)).toSeq)) + configureMockTransfer(blocks.mapValues(_ => createMockManagedBuffer(0)).toMap) - val taskContext = TaskContext.empty() - val iterator = new ShuffleBlockFetcherIterator( - taskContext, - transfer, - blockManager, - blocksByAddress.toIterator, - (_, in) => in, - 48 * 1024 * 1024, - Int.MaxValue, - Int.MaxValue, - Int.MaxValue, - true, - false, - taskContext.taskMetrics.createTempShuffleReadMetrics(), - false) + val iterator = createShuffleBlockIteratorWithDefaults( + Map(remoteBmId -> toBlockList(blocks.keys, 1L, 0)) + ) // All blocks fetched return zero length and should trigger a receive-side error: val e = intercept[FetchFailedException] { iterator.next() } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org