mridulm commented on a change in pull request #32811:
URL: https://github.com/apache/spark/pull/32811#discussion_r651973359



##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -294,18 +336,30 @@ public ShuffleMetrics() {
     private int index = 0;
     private final Function<Integer, ManagedBuffer> blockDataForIndexFn;
     private final int size;
+    private boolean requestForMergedBlockChunks;
 
     ManagedBufferIterator(OpenBlocks msg) {
       String appId = msg.appId;
       String execId = msg.execId;
       String[] blockIds = msg.blockIds;
       String[] blockId0Parts = blockIds[0].split("_");
-      if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) {
+      if (blockId0Parts.length == 4 && (blockId0Parts[0].equals(SHUFFLE_BLOCK_ID) ||
+        blockId0Parts[0].equals(SHUFFLE_CHUNK_ID))) {
         final int shuffleId = Integer.parseInt(blockId0Parts[1]);
-        final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
-        size = mapIdAndReduceIds.length;
-        blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
-          mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]);
+        requestForMergedBlockChunks = blockId0Parts[0].equals(SHUFFLE_CHUNK_ID);
+        // For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds.
+        // For shuffle chunks, primaryId is reduceId and secondaryIds are chunkIds.
+        final int[] primaryIdAndSecondaryIds = shuffleMapIdAndReduceIds(blockIds, shuffleId);
+        size = primaryIdAndSecondaryIds.length;
+        blockDataForIndexFn = index -> {
+          if (requestForMergedBlockChunks) {
+            return mergeManager.getMergedBlockData(msg.appId, shuffleId,
+              primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]);
+          } else {
+            return blockManager.getBlockData(msg.appId, msg.execId, shuffleId,
+              primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]);
+          }
+        };

Review comment:
       nit: Wondering if this would be cleaner if we simply split this out into its own else block for block chunks?
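   
   Something along these lines, as a rough untested sketch (hoisting the merged-chunk check out of the lambda, with the names from the diff above):
   
   ```java
   if (!requestForMergedBlockChunks) {
     // Regular shuffle blocks: primaryId is mapId, secondaryIds are reduceIds.
     blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId,
       primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]);
   } else {
     // Merged shuffle chunks: primaryId is reduceId, secondaryIds are chunkIds.
     blockDataForIndexFn = index -> mergeManager.getMergedBlockData(appId, shuffleId,
       primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]);
   }
   ```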

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +94,125 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
-    for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
-        return false;
-      }
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
+    if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
+      // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
+      // check if all the block ids are shuffle chunk Ids.
+      return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
+        id -> new BlocksInfo());
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
+      // shuffleChunk block, then blockIdParts[3] = chunkId
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
+        // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        // blockIdParts[4] is the end reduce id for the batch range
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {

Review comment:
       `primaryIdToBlocksInfo.entrySet()` -> `primaryIdToBlocksInfo.values()`
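   
   i.e., since the keys are not needed in this loop (rough sketch):
   
   ```java
   for (BlocksInfo blocksInfoByPrimaryId : primaryIdToBlocksInfo.values()) {
     secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
     for (String blockId : blocksInfoByPrimaryId.blockIds) {
       this.blockIds[blockIdIndex++] = blockId;
     }
   }
   ```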

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java
##########
@@ -54,8 +59,12 @@
  * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk
  * is equivalent to one block.
  */
-public class ExternalBlockHandler extends RpcHandler {
+public class ExternalBlockHandler extends RpcHandler
+    implements RpcHandler.MergedBlockMetaReqHandler {
   private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class);
+  private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger";

Review comment:
       @Ngone51 This is another case of duplication between common modules and spark core (we had seen other cases where code is getting duplicated) ... we should find a way to unify them.

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +94,125 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
-    for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
-        return false;
-      }
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
+    if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
+      // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
+      // check if all the block ids are shuffle chunk Ids.
+      return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.

Review comment:
       super nit: `pass` -> `passed` (not from your pr)

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java
##########
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.shuffle.protocol;
+
+import java.util.Arrays;
+
+import io.netty.buffer.ByteBuf;
+
+import org.apache.spark.network.protocol.Encoders;
+
+// Needed by ScalaDoc. See SPARK-7726
+import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type;
+
+
+/**
+ * Request to read a set of block chunks. Returns {@link StreamHandle}.
+ *
+ * @since 3.2.0
+ */
+public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks {
+  // The length of reduceIds must equal chunkIds.length.
+  public final int[] reduceIds;
+  // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds.
+  public final int[][] chunkIds;
+
+  public FetchShuffleBlockChunks(
+      String appId,
+      String execId,
+      int shuffleId,
+      int[] reduceIds,
+      int[][] chunkIds) {
+    super(appId, execId, shuffleId);
+    this.reduceIds = reduceIds;
+    this.chunkIds = chunkIds;
+    assert(reduceIds.length == chunkIds.length);
+  }
+
+  @Override
+  protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; }
+
+  @Override
+  public String toString() {
+    return toStringHelper()
+      .append("reduceIds", Arrays.toString(reduceIds))
+      .append("chunkIds", Arrays.deepToString(chunkIds))
+      .toString();
+  }
+
+  @Override
+  public boolean equals(Object o) {
+    if (this == o) return true;
+    if (o == null || getClass() != o.getClass()) return false;
+
+    FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o;
+    if (!super.equals(that)) return false;
+    if (!Arrays.equals(reduceIds, that.reduceIds)) return false;
+    return Arrays.deepEquals(chunkIds, that.chunkIds);
+  }
+
+  @Override
+  public int hashCode() {
+    int result = super.hashCode();
+    result = 31 * result + Arrays.hashCode(reduceIds);
+    result = 31 * result + Arrays.deepHashCode(chunkIds);
+    return result;
+  }
+
+  @Override
+  public int encodedLength() {
+    int encodedLengthOfChunkIds = 0;
+    for (int[] ids: chunkIds) {
+      encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids);
+    }
+    return super.encodedLength()
+      + Encoders.IntArrays.encodedLength(reduceIds)
+      + 4 /* encoded length of chunkIds.length */
+      + encodedLengthOfChunkIds;
+  }
+
+  @Override
+  public void encode(ByteBuf buf) {
+    super.encode(buf);
+    Encoders.IntArrays.encode(buf, reduceIds);
+    buf.writeInt(chunkIds.length);

Review comment:
       Add a note that even though `reduceIds.length` == `chunkIds.length`, we are explicitly setting the length in the interest of forward compatibility?
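   
   Something like the following (sketch; assuming the remainder of `encode` writes each chunk id array, mirroring `encodedLength` above):
   
   ```java
   @Override
   public void encode(ByteBuf buf) {
     super.encode(buf);
     Encoders.IntArrays.encode(buf, reduceIds);
     // Even though reduceIds.length == chunkIds.length, we encode chunkIds.length
     // explicitly to keep the wire format forward compatible.
     buf.writeInt(chunkIds.length);
     for (int[] ids : chunkIds) {
       Encoders.IntArrays.encode(buf, ids);
     }
   }
   ```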

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +94,125 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);

Review comment:
       `OpenBlocks` does not support chunks, right?
   Do we want to error out in case there are chunk blocks?
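   
   e.g. something along these lines in the else branch (untested sketch):
   
   ```java
   } else {
     // OpenBlocks predates merged shuffle chunks, so fail fast rather than
     // sending a request the server cannot serve.
     if (Arrays.stream(blockIds).anyMatch(id -> id.startsWith(SHUFFLE_CHUNK_PREFIX))) {
       throw new IllegalArgumentException(
         "Merged shuffle chunks are not supported by OpenBlocks: " + Arrays.toString(blockIds));
     }
     this.blockIds = blockIds;
     this.message = new OpenBlocks(appId, execId, blockIds);
   }
   ```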

##########
File path: common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java
##########
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.client;
+
+import org.apache.spark.network.buffer.ManagedBuffer;
+
+/**
+ * Callback for the result of a single
+ * {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}.
+ *
+ * @since 3.2.0
+ */
+public interface MergedBlockMetaResponseCallback extends BaseResponseCallback {
+  /**
+   * Called upon receipt of a particular merged block meta.
+   *
+ * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this
+ * call returns. You must therefore either retain() the buffer or copy its contents before
+   * returning.
+   */
+  void onSuccess(int numChunks, ManagedBuffer buffer);

Review comment:
       Can you add doc on what `numChunks` and `buffer` refer to here?
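   
   e.g., going by the `MergedBlockMetaRequest` javadoc elsewhere in this PR, something like (sketch):
   
   ```java
   /**
    * Called upon receipt of a particular merged block meta.
    *
    * @param numChunks number of chunks in the merged block
    * @param buffer buffer holding the meta information, i.e. which map ids are
    *               merged into each chunk of the merged block
    */
   void onSuccess(int numChunks, ManagedBuffer buffer);
   ```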

##########
File path: common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java
##########
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.network.protocol;
+
+import com.google.common.base.Objects;
+import io.netty.buffer.ByteBuf;
+import org.apache.commons.lang3.builder.ToStringBuilder;
+import org.apache.commons.lang3.builder.ToStringStyle;
+
+/**
+ * Request to find the meta information for the specified merged block. The meta information
+ * contains the number of chunks in the merged blocks and the map ids in each chunk.
+ *
+ * @since 3.2.0
+ */
+public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage {
+  public final long requestId;
+  public final String appId;
+  public final int shuffleId;
+  public final int reduceId;
+
+  public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) {
+    super(null, false);
+    this.requestId = requestId;
+    this.appId = appId;
+    this.shuffleId = shuffleId;
+    this.reduceId = reduceId;
+  }
+
+  @Override
+  public Type type() {
+    return Type.MergedBlockMetaRequest;
+  }
+
+  @Override
+  public int encodedLength() {
+    return 8 + Encoders.Strings.encodedLength(appId) + 8;

Review comment:
       super nit: `+ 8` -> `+ 4 + 4`
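   
   i.e., making the per-field sizes explicit (sketch):
   
   ```java
   @Override
   public int encodedLength() {
     // 8 (requestId: long) + appId + 4 (shuffleId: int) + 4 (reduceId: int)
     return 8 + Encoders.Strings.encodedLength(appId) + 4 + 4;
   }
   ```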

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java
##########
@@ -187,6 +188,39 @@ public void onFailure(Throwable e) {
     }
   }
 
+  @Override
+  public void getMergedBlockMeta(
+      String host,
+      int port,
+      int shuffleId,
+      int reduceId,
+      MergedBlocksMetaListener listener) {
+    checkInit();
+    logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId 
{}", host, port,
+      shuffleId, reduceId);
+    try {
+      TransportClient client = clientFactory.createClient(host, port);
+      client.sendMergedBlockMetaReq(appId, shuffleId, reduceId,
+        new MergedBlockMetaResponseCallback() {
+          @Override
+          public void onSuccess(int numChunks, ManagedBuffer buffer) {
+            logger.trace("Successfully got merged block meta for shuffleId {} 
reduceId {}",
+              shuffleId, reduceId);
+            listener.onSuccess(shuffleId, reduceId, new 
MergedBlockMeta(numChunks, buffer));
+          }
+
+          @Override
+          public void onFailure(Throwable e) {
+            logger.error("Failed while getting merged block meta", e);

Review comment:
       Listener's `onFailure` will log this - do we need this? (here and below)
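   
   i.e. it could presumably just delegate (sketch; assuming `MergedBlocksMetaListener.onFailure(shuffleId, reduceId, e)` is the listener signature):
   
   ```java
   @Override
   public void onFailure(Throwable e) {
     // The listener already logs the failure, so avoid logging it twice here.
     listener.onFailure(shuffleId, reduceId, e);
   }
   ```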

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +94,125 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
-    for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
-        return false;
-      }
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
+    if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
+      // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
+      // check if all the block ids are shuffle chunk Ids.
+      return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
+        id -> new BlocksInfo());
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
+      // shuffleChunk block, then blockIdParts[3] = chunkId
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
+        // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        // blockIdParts[4] is the end reduce id for the batch range
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {
+      BlocksInfo blocksInfoByPrimaryId = entry.getValue();
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (String blockId : blocksInfoByPrimaryId.blockIds) {
+        this.blockIds[blockIdIndex++] = blockId;
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);
+      return new FetchShuffleBlocks(
+        appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
+    } else {
+      int[] reduceIds = Ints.toArray(primaryIds);
+      return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
+    }
   }
 
   /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */
   private String[] splitBlockId(String blockId) {
     String[] blockIdParts = blockId.split("_");
     // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId.
     // For single block id, the format contains shuffleId, mapId, reduceId.
-    if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) {
+    // For single block chunk id, the format contains shuffleId, reduceId, chunkId.
+    if (blockIdParts.length < 4 || blockIdParts.length > 5 ||
+      !(blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX) ||
+        !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {

Review comment:
       `!blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX)` -> `blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX)` ?
   
   It is `not` on either of them matching, right?
   
   If this is a valid comment, please do add some test to catch this - we might have future changes here which might cause bugs as well.
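   
   i.e. the intended check is presumably (sketch; keeping the same exception on mismatch):
   
   ```java
   // Reject unless the prefix is either a shuffle block id or a shuffle chunk id.
   if (blockIdParts.length < 4 || blockIdParts.length > 5 ||
       !(blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX) ||
         blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) {
     // throw the same IllegalArgumentException as before
   }
   ```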

##########
File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java
##########
@@ -88,82 +94,125 @@ public OneForOneBlockFetcher(
     if (blockIds.length == 0) {
       throw new IllegalArgumentException("Zero-sized blockIds array");
     }
-    if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) {
+    if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) {
       this.blockIds = new String[blockIds.length];
-      this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds);
+      this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds);
     } else {
       this.blockIds = blockIds;
       this.message = new OpenBlocks(appId, execId, blockIds);
     }
   }
 
-  private boolean isShuffleBlocks(String[] blockIds) {
-    for (String blockId : blockIds) {
-      if (!blockId.startsWith("shuffle_")) {
-        return false;
-      }
+  /**
+   * Check if the array of block IDs are all shuffle block IDs. With push based shuffle,
+   * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk
+   * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either
+   * all unmerged shuffle blocks or all merged shuffle chunks.
+   * @param blockIds block ID array
+   * @return whether the array contains only shuffle block IDs
+   */
+  private boolean areShuffleBlocksOrChunks(String[] blockIds) {
+    if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) {
+      // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we
+      // check if all the block ids are shuffle chunk Ids.
+      return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX));
     }
     return true;
   }
 
+  /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */
+  private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg(
+      String appId,
+      String execId,
+      String[] blockIds) {
+    if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true);
+    } else {
+      return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false);
+    }
+  }
+
   /**
-   * Create FetchShuffleBlocks message and rebuild internal blockIds by
+   * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by
    * analyzing the pass in blockIds.
    */
-  private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds(
-      String appId, String execId, String[] blockIds) {
+  private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds(
+      String appId,
+      String execId,
+      String[] blockIds,
+      boolean areMergedChunks) {
     String[] firstBlock = splitBlockId(blockIds[0]);
     int shuffleId = Integer.parseInt(firstBlock[1]);
     boolean batchFetchEnabled = firstBlock.length == 5;
 
-    LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>();
+    // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId
+    // is reduceId.
+    LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>();
     for (String blockId : blockIds) {
       String[] blockIdParts = splitBlockId(blockId);
       if (Integer.parseInt(blockIdParts[1]) != shuffleId) {
         throw new IllegalArgumentException("Expected shuffleId=" + shuffleId +
           ", got:" + blockId);
       }
-      long mapId = Long.parseLong(blockIdParts[2]);
-      if (!mapIdToBlocksInfo.containsKey(mapId)) {
-        mapIdToBlocksInfo.put(mapId, new BlocksInfo());
+      Number primaryId;
+      if (!areMergedChunks) {
+        primaryId = Long.parseLong(blockIdParts[2]);
+      } else {
+        primaryId = Integer.parseInt(blockIdParts[2]);
       }
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId);
-      blocksInfoByMapId.blockIds.add(blockId);
-      blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3]));
+      BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId,
+        id -> new BlocksInfo());
+      blocksInfoByPrimaryId.blockIds.add(blockId);
+      // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a
+      // shuffleChunk block, then blockIdParts[3] = chunkId
+      blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3]));
       if (batchFetchEnabled) {
+        // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block.
         // When we read continuous shuffle blocks in batch, we will reuse reduceIds in
         // FetchShuffleBlocks to store the start and end reduce id for range
         // [startReduceId, endReduceId).
         assert(blockIdParts.length == 5);
-        blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4]));
+        // blockIdParts[4] is the end reduce id for the batch range
+        blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4]));
       }
     }
-    long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet());
-    int[][] reduceIdArr = new int[mapIds.length][];
+    // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks,
+    // secondaryIds are chunkIds.
+    int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][];
     int blockIdIndex = 0;
-    for (int i = 0; i < mapIds.length; i++) {
-      BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]);
-      reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds);
+    int secIndex = 0;
+    for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) {
+      BlocksInfo blocksInfoByPrimaryId = entry.getValue();
+      secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids);
 
-      // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks
-      // because the shuffle data's return order should match the `blockIds`'s order to ensure
-      // blockId and data match.
-      for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) {
-        this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j);
+      // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/
+      // FetchShuffleBlockChunks because the shuffle data's return order should match the
+      // `blockIds`'s order to ensure blockId and data match.
+      for (String blockId : blocksInfoByPrimaryId.blockIds) {
+        this.blockIds[blockIdIndex++] = blockId;
       }
     }
     assert(blockIdIndex == this.blockIds.length);
-
-    return new FetchShuffleBlocks(
-      appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled);
+    Set<Number> primaryIds = primaryIdToBlocksInfo.keySet();
+    if (!areMergedChunks) {
+      long[] mapIds = Longs.toArray(primaryIds);
+      return new FetchShuffleBlocks(
+        appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled);
+    } else {
+      int[] reduceIds = Ints.toArray(primaryIds);
+      return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray);
+    }

Review comment:
       Note: Given that `primaryIdToBlocksInfo` is a `LinkedHashMap`, we end up with `mapIds`/`reduceIds` being consistent w.r.t. `secondaryIdsArray`.
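   
   For example, `LinkedHashMap` preserves insertion order, so `keySet()` and `values()` iterate in the same (insertion) order:
   
   ```java
   LinkedHashMap<Long, String> m = new LinkedHashMap<>();
   m.put(3L, "c");
   m.put(1L, "a");
   m.put(2L, "b");
   System.out.println(m.keySet()); // [3, 1, 2]
   System.out.println(m.values()); // [c, a, b]
   ```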



