otterc commented on a change in pull request #32140:
URL: https://github.com/apache/spark/pull/32140#discussion_r646807905



##########
File path: 
common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +93,124 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && 
areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, 
execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, 
blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push 
based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged 
shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, 
they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
     for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
+      if (!blockId.startsWith(SHUFFLE_BLOCK_PREFIX) &&
+          !blockId.startsWith(SHUFFLE_CHUNK_PREFIX)) {
         return false;
       }
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link 
FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, 
true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, 
false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild 
internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For 
FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new 
LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
+      }
+      if (!primaryIdToBlocksInfo.containsKey(primaryId)) {
+        primaryIdToBlocksInfo.put(primaryId, new BlocksInfo());
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
         // When we read continuous shuffle blocks in batch, we will reuse 
reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For 
FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIds.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Number primaryId : primaryIds) {
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.get(primaryId);
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in 
in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s 
order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in 
FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order 
should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (int j = 0; j < blocksInfoByPrimaryId.blockIds.size(); j++) {
+        this.blockIds[blockIdIndex++] = blocksInfoByPrimaryId.blockIds.get(j);
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);

Review comment:
       This is invoked once per fetch request from the client, and each fetch 
request covers multiple blocks. Because it runs on every fetch request, it is 
still called frequently, so I will create utility methods for it.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org



---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to