Github user cloud-fan commented on a diff in the pull request: https://github.com/apache/spark/pull/16989#discussion_r116916099 --- Diff: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala --- @@ -401,4 +429,146 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT assert(id3 === ShuffleBlockId(0, 2, 0)) } + test("Blocks should be shuffled to disk when size of the request is above the" + + " threshold(maxReqSizeShuffleToMem.") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + } + } + }) + + val taskMemoryManager = createMockTaskMemoryManager() + val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + + val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. 
+ val iterator1 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress1, + (_, in) => in, + Int.MaxValue, + Int.MaxValue, + 200, + true, + taskMemoryManager) + assert(shuffleFiles === null) + + val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])]( + (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq)) + // Set maxReqSizeShuffleToMem to be 200. + val iterator2 = new ShuffleBlockFetcherIterator( + tc, + transfer, + blockManager, + blocksByAddress2, + (_, in) => in, + Int.MaxValue, + Int.MaxValue, + 200, + true, + taskMemoryManager) + assert(shuffleFiles != null) + } + + test("Blocks should be shuffled to disk when size is above memory threshold," + + " otherwise to memory.") { + val blockManager = mock(classOf[BlockManager]) + val localBmId = BlockManagerId("test-client", "test-client", 1) + doReturn(localBmId).when(blockManager).blockManagerId + + val diskBlockManager = mock(classOf[DiskBlockManager]) + doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String])) + doReturn(diskBlockManager).when(blockManager).diskBlockManager + + val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2) + val remoteBlocks = Map[BlockId, ManagedBuffer]( + ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(), + ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer()) + val transfer = mock(classOf[BlockTransferService]) + var shuffleFiles: Array[File] = null + when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any())) + .thenAnswer(new Answer[Unit] { + override def answer(invocation: InvocationOnMock): Unit = { + val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener] + shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]] + Future { + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0))) + listener.onBlockFetchSuccess( + 
ShuffleBlockId(0, 1, 0).toString, remoteBlocks(ShuffleBlockId(0, 1, 0))) + listener.onBlockFetchSuccess( + ShuffleBlockId(0, 2, 0).toString, remoteBlocks(ShuffleBlockId(0, 2, 0))) + } + } + }) + val taskMemoryManager = mock(classOf[TaskMemoryManager]) + when(taskMemoryManager.acquireExecutionMemory(any(), any())) + .thenAnswer(new Answer[Long] { + // 500 bytes at most can be offered from TaskMemoryManager. + override def answer(invocationOnMock: InvocationOnMock): Long = { + val required = invocationOnMock.getArguments()(0).asInstanceOf[Long] + if (required <= 500) { + return required + } else { + return 500 --- End diff -- Ah, I see. Then let's revert the last change; always returning 500 is fine.
--- If your project is set up for it, you can reply to this email and have your reply appear on GitHub as well. If your project does not have this feature enabled and wishes to have it enabled, or if the feature is enabled but not working, please contact infrastructure at infrastruct...@apache.org or file a JIRA ticket with INFRA. --- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org