GitHub user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/16989#discussion_r116916099
  
    --- Diff: core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala ---
    @@ -401,4 +429,146 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite with PrivateMethodT
         assert(id3 === ShuffleBlockId(0, 2, 0))
       }
     
    +  test("Blocks should be shuffled to disk when size of the request is 
above the" +
    +    " threshold(maxReqSizeShuffleToMem.") {
    +    val blockManager = mock(classOf[BlockManager])
    +    val localBmId = BlockManagerId("test-client", "test-client", 1)
    +    doReturn(localBmId).when(blockManager).blockManagerId
    +
    +    val diskBlockManager = mock(classOf[DiskBlockManager])
    +    doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String]))
    +    doReturn(diskBlockManager).when(blockManager).diskBlockManager
    +
    +    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
    +    val remoteBlocks = Map[BlockId, ManagedBuffer](
    +      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer())
    +    val transfer = mock(classOf[BlockTransferService])
    +    var shuffleFiles: Array[File] = null
    +    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
    +      .thenAnswer(new Answer[Unit] {
    +        override def answer(invocation: InvocationOnMock): Unit = {
    +          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
    +          shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
    +          Future {
    +            listener.onBlockFetchSuccess(
    +              ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
    +          }
    +        }
    +      })
    +
    +    val taskMemoryManager = createMockTaskMemoryManager()
    +    val tc = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null)
    +
    +    val blocksByAddress1 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
    +      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 100L)).toSeq))
    +    // Set maxReqSizeShuffleToMem to be 200.
    +    val iterator1 = new ShuffleBlockFetcherIterator(
    +      tc,
    +      transfer,
    +      blockManager,
    +      blocksByAddress1,
    +      (_, in) => in,
    +      Int.MaxValue,
    +      Int.MaxValue,
    +      200,
    +      true,
    +      taskMemoryManager)
    +    assert(shuffleFiles === null)
    +
    +    val blocksByAddress2 = Seq[(BlockManagerId, Seq[(BlockId, Long)])](
    +      (remoteBmId, remoteBlocks.keys.map(blockId => (blockId, 300L)).toSeq))
    +    // Set maxReqSizeShuffleToMem to be 200.
    +    val iterator2 = new ShuffleBlockFetcherIterator(
    +      tc,
    +      transfer,
    +      blockManager,
    +      blocksByAddress2,
    +      (_, in) => in,
    +      Int.MaxValue,
    +      Int.MaxValue,
    +      200,
    +      true,
    +      taskMemoryManager)
    +    assert(shuffleFiles != null)
    +  }
    +
    +  test("Blocks should be shuffled to disk when size is above memory 
threshold," +
    +    " otherwise to memory.") {
    +    val blockManager = mock(classOf[BlockManager])
    +    val localBmId = BlockManagerId("test-client", "test-client", 1)
    +    doReturn(localBmId).when(blockManager).blockManagerId
    +
    +    val diskBlockManager = mock(classOf[DiskBlockManager])
    +    doReturn(new File("shuffle-read-file")).when(diskBlockManager).getFile(any(classOf[String]))
    +    doReturn(diskBlockManager).when(blockManager).diskBlockManager
    +
    +    val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
    +    val remoteBlocks = Map[BlockId, ManagedBuffer](
    +      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
    +      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
    +      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
    +    val transfer = mock(classOf[BlockTransferService])
    +    var shuffleFiles: Array[File] = null
    +    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
    +      .thenAnswer(new Answer[Unit] {
    +        override def answer(invocation: InvocationOnMock): Unit = {
    +          val listener = invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
    +          shuffleFiles = invocation.getArguments()(5).asInstanceOf[Array[File]]
    +          Future {
    +            listener.onBlockFetchSuccess(
    +              ShuffleBlockId(0, 0, 0).toString, remoteBlocks(ShuffleBlockId(0, 0, 0)))
    +            listener.onBlockFetchSuccess(
    +              ShuffleBlockId(0, 1, 0).toString, remoteBlocks(ShuffleBlockId(0, 1, 0)))
    +            listener.onBlockFetchSuccess(
    +              ShuffleBlockId(0, 2, 0).toString, remoteBlocks(ShuffleBlockId(0, 2, 0)))
    +          }
    +        }
    +      })
    +    val taskMemoryManager = mock(classOf[TaskMemoryManager])
    +    when(taskMemoryManager.acquireExecutionMemory(any(), any()))
    +      .thenAnswer(new Answer[Long] {
    +        // At most 500 bytes can be offered by the TaskMemoryManager.
    +        override def answer(invocationOnMock: InvocationOnMock): Long = {
    +          val required = invocationOnMock.getArguments()(0).asInstanceOf[Long]
    +          if (required <= 500) {
    +            return required
    +          } else {
    +            return 500
    --- End diff --
    
    Ah, I see. Then let's revert the last change; always returning 500 is fine.
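    
    For reference, a minimal sketch of that simplification (assuming the same Mockito mock setup and imports as in the diff above): the required <= 500 branch goes away and the stub just grants 500 bytes unconditionally.
    
        // Sketch only: always offer 500 bytes from the mocked TaskMemoryManager,
        // regardless of how much execution memory is requested.
        val taskMemoryManager = mock(classOf[TaskMemoryManager])
        when(taskMemoryManager.acquireExecutionMemory(any(), any()))
          .thenAnswer(new Answer[Long] {
            override def answer(invocationOnMock: InvocationOnMock): Long = 500
          })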

