xkrogen commented on a change in pull request #32389:
URL: https://github.com/apache/spark/pull/32389#discussion_r633725651



##########
File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -123,6 +131,42 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(wrappedInputStream.invokePrivate(delegateAccess()), 
times(1)).close()
   }
 
+  // scalastyle:off argcount
+  private def createShuffleBlockIteratorWithDefaults(
+      blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)],
+      taskContext: Option[TaskContext] = None,
+      streamWrapperLimitSize: Option[Long] = None,
+      blockManager: Option[BlockManager] = None,
+      maxBytesInFlight: Long = Long.MaxValue,
+      maxReqsInFlight: Int = Int.MaxValue,
+      maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+      maxReqSizeShuffleToMem: Int = Int.MaxValue,
+      detectCorrupt: Boolean = true,
+      detectCorruptUseExtraMemory: Boolean = true,
+      shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+    val tContext = taskContext.getOrElse(TaskContext.empty())
+    new ShuffleBlockFetcherIterator(
+      tContext,
+      transfer,
+      blockManager.getOrElse(createMockBlockManager()),
+      blocksByAddress.map { case (blockManagerId, (blocks, blockSize, 
blockMapIndex)) =>
+        (blockManagerId, blocks.map(blockId => (blockId, blockSize, 
blockMapIndex)).toSeq)
+      }.toIterator,
+      streamWrapperLimitSize
+        .map(limit => (_: BlockId, in: InputStream) => new 
LimitedInputStream(in, limit))
+        .getOrElse((_: BlockId, in: InputStream) => in),

Review comment:
       Good suggestion! Incorporated.

##########
File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -703,40 +600,24 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     val sem = new Semaphore(0)
     val corruptLocalBuffer = new FileSegmentManagedBuffer(null, new File("a"), 
0, 100)
 
-    val transfer = mock(classOf[BlockTransferService])
-    when(transfer.fetchBlocks(any(), any(), any(), any(), any(), any()))
-      .thenAnswer((invocation: InvocationOnMock) => {
-        val listener = 
invocation.getArguments()(4).asInstanceOf[BlockFetchingListener]
-        Future {
-          // Return the first block, and then fail.
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
-          listener.onBlockFetchSuccess(
-            ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
-          sem.release()
-        }
-      })
-
-    val blocksByAddress = Seq[(BlockManagerId, Seq[(BlockId, Long, Int)])](
-      (remoteBmId, blocks.keys.map(blockId => (blockId, 1L, 
0)).toSeq)).toIterator
+    answerFetchBlocks { invocation =>
+      val listener = invocation.getArgument[BlockFetchingListener](4)
+      Future {
+        // Return the first block, and then fail.
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 0, 0).toString, blocks(ShuffleBlockId(0, 0, 0)))
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 1, 0).toString, mockCorruptBuffer())
+        listener.onBlockFetchSuccess(
+          ShuffleBlockId(0, 2, 0).toString, corruptLocalBuffer)
+        sem.release()
+      }
+    }
 
-    val taskContext = TaskContext.empty()
-    val iterator = new ShuffleBlockFetcherIterator(
-      taskContext,
-      transfer,
-      blockManager,
-      blocksByAddress,
-      (_, in) => new LimitedInputStream(in, 100),
-      48 * 1024 * 1024,
-      Int.MaxValue,
-      Int.MaxValue,
-      Int.MaxValue,
-      true,
-      true,
-      taskContext.taskMetrics.createTempShuffleReadMetrics(),
-      false)
+    val iterator = createShuffleBlockIteratorWithDefaults(
+      Map(remoteBmId ->(blocks.keys, 1L, 0)),

Review comment:
       thanks!

##########
File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -123,6 +131,42 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(wrappedInputStream.invokePrivate(delegateAccess()), 
times(1)).close()
   }
 
+  // scalastyle:off argcount
+  private def createShuffleBlockIteratorWithDefaults(
+      blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)],
+      taskContext: Option[TaskContext] = None,
+      streamWrapperLimitSize: Option[Long] = None,
+      blockManager: Option[BlockManager] = None,
+      maxBytesInFlight: Long = Long.MaxValue,
+      maxReqsInFlight: Int = Int.MaxValue,
+      maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+      maxReqSizeShuffleToMem: Int = Int.MaxValue,
+      detectCorrupt: Boolean = true,
+      detectCorruptUseExtraMemory: Boolean = true,
+      shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+    val tContext = taskContext.getOrElse(TaskContext.empty())
+    new ShuffleBlockFetcherIterator(
+      tContext,
+      transfer,
+      blockManager.getOrElse(createMockBlockManager()),
+      blocksByAddress.map { case (blockManagerId, (blocks, blockSize, 
blockMapIndex)) =>
+        (blockManagerId, blocks.map(blockId => (blockId, blockSize, 
blockMapIndex)).toSeq)
+      }.toIterator,

Review comment:
       I think this is closely related to your [other 
comment](https://github.com/apache/spark/pull/32389#discussion_r632823782), so 
I'll respond to both here.
   
   I agree that this is tailored to the current set of tests and thus current 
usage, as opposed to potential future usage. If we were designing a public API, 
or even a private API in the production code, I would agree with you. But in 
this case for a class-private method in a test file, I'm not convinced that 
designing for future possibilities is the right move. If someone later adds 
tests which do need blocks of different sizes, they can make the modification 
you've described, right? For now I would prefer simplicity, and we can 
introduce the additional complexity if necessary.
   
   WDYT? This is not a strong conviction on my side, so if you are not convinced, 
I can make the changes you proposed.

##########
File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -123,6 +131,42 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(wrappedInputStream.invokePrivate(delegateAccess()), 
times(1)).close()
   }
 
+  // scalastyle:off argcount
+  private def createShuffleBlockIteratorWithDefaults(
+      blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)],
+      taskContext: Option[TaskContext] = None,
+      streamWrapperLimitSize: Option[Long] = None,
+      blockManager: Option[BlockManager] = None,
+      maxBytesInFlight: Long = Long.MaxValue,
+      maxReqsInFlight: Int = Int.MaxValue,
+      maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+      maxReqSizeShuffleToMem: Int = Int.MaxValue,
+      detectCorrupt: Boolean = true,
+      detectCorruptUseExtraMemory: Boolean = true,
+      shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+    val tContext = taskContext.getOrElse(TaskContext.empty())
+    new ShuffleBlockFetcherIterator(
+      tContext,
+      transfer,
+      blockManager.getOrElse(createMockBlockManager()),
+      blocksByAddress.map { case (blockManagerId, (blocks, blockSize, 
blockMapIndex)) =>
+        (blockManagerId, blocks.map(blockId => (blockId, blockSize, 
blockMapIndex)).toSeq)
+      }.toIterator,
+      streamWrapperLimitSize
+        .map(limit => (_: BlockId, in: InputStream) => new 
LimitedInputStream(in, limit))
+        .getOrElse((_: BlockId, in: InputStream) => in),
+      maxBytesInFlight,
+      maxReqsInFlight,
+      maxBlocksInFlightPerAddress,
+      maxReqSizeShuffleToMem,
+      detectCorrupt,
+      detectCorruptUseExtraMemory,
+      
shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),

Review comment:
       `taskMetrics()` is an empty-paren method, so it should be called with 
parentheses:
   ```
     def taskMetrics(): TaskMetrics
   ```

##########
File path: 
core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala
##########
@@ -123,6 +131,42 @@ class ShuffleBlockFetcherIteratorSuite extends 
SparkFunSuite with PrivateMethodT
     verify(wrappedInputStream.invokePrivate(delegateAccess()), 
times(1)).close()
   }
 
+  // scalastyle:off argcount
+  private def createShuffleBlockIteratorWithDefaults(
+      blocksByAddress: Map[BlockManagerId, (Traversable[BlockId], Long, Int)],
+      taskContext: Option[TaskContext] = None,
+      streamWrapperLimitSize: Option[Long] = None,
+      blockManager: Option[BlockManager] = None,
+      maxBytesInFlight: Long = Long.MaxValue,
+      maxReqsInFlight: Int = Int.MaxValue,
+      maxBlocksInFlightPerAddress: Int = Int.MaxValue,
+      maxReqSizeShuffleToMem: Int = Int.MaxValue,
+      detectCorrupt: Boolean = true,
+      detectCorruptUseExtraMemory: Boolean = true,
+      shuffleMetrics: Option[ShuffleReadMetricsReporter] = None,
+      doBatchFetch: Boolean = false): ShuffleBlockFetcherIterator = {
+    val tContext = taskContext.getOrElse(TaskContext.empty())
+    new ShuffleBlockFetcherIterator(
+      tContext,
+      transfer,
+      blockManager.getOrElse(createMockBlockManager()),
+      blocksByAddress.map { case (blockManagerId, (blocks, blockSize, 
blockMapIndex)) =>
+        (blockManagerId, blocks.map(blockId => (blockId, blockSize, 
blockMapIndex)).toSeq)
+      }.toIterator,
+      streamWrapperLimitSize
+        .map(limit => (_: BlockId, in: InputStream) => new 
LimitedInputStream(in, limit))
+        .getOrElse((_: BlockId, in: InputStream) => in),
+      maxBytesInFlight,
+      maxReqsInFlight,
+      maxBlocksInFlightPerAddress,
+      maxReqSizeShuffleToMem,
+      detectCorrupt,
+      detectCorruptUseExtraMemory,
+      
shuffleMetrics.getOrElse(tContext.taskMetrics().createTempShuffleReadMetrics()),

Review comment:
       `taskMetrics()` is defined as an empty-paren method, so it should be 
called with parentheses:
   ```
     def taskMetrics(): TaskMetrics
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to