mridulm commented on a change in pull request #32811: URL: https://github.com/apache/spark/pull/32811#discussion_r651973359
########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java ########## @@ -294,18 +336,30 @@ public ShuffleMetrics() { private int index = 0; private final Function<Integer, ManagedBuffer> blockDataForIndexFn; private final int size; + private boolean requestForMergedBlockChunks; ManagedBufferIterator(OpenBlocks msg) { String appId = msg.appId; String execId = msg.execId; String[] blockIds = msg.blockIds; String[] blockId0Parts = blockIds[0].split("_"); - if (blockId0Parts.length == 4 && blockId0Parts[0].equals("shuffle")) { + if (blockId0Parts.length == 4 && (blockId0Parts[0].equals(SHUFFLE_BLOCK_ID) || + blockId0Parts[0].equals(SHUFFLE_CHUNK_ID))) { final int shuffleId = Integer.parseInt(blockId0Parts[1]); - final int[] mapIdAndReduceIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); - size = mapIdAndReduceIds.length; - blockDataForIndexFn = index -> blockManager.getBlockData(appId, execId, shuffleId, - mapIdAndReduceIds[index], mapIdAndReduceIds[index + 1]); + requestForMergedBlockChunks = blockId0Parts[0].equals(SHUFFLE_CHUNK_ID); + // For regular shuffle blocks, primaryId is mapId and secondaryIds are reduceIds. + // For shuffle chunks, primaryIds is reduceId and secondaryIds are chunkIds. + final int[] primaryIdAndSecondaryIds = shuffleMapIdAndReduceIds(blockIds, shuffleId); + size = primaryIdAndSecondaryIds.length; + blockDataForIndexFn = index -> { + if (requestForMergedBlockChunks) { + return mergeManager.getMergedBlockData(msg.appId, shuffleId, + primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]); + } else { + return blockManager.getBlockData(msg.appId, msg.execId, shuffleId, + primaryIdAndSecondaryIds[index], primaryIdAndSecondaryIds[index + 1]); + } + }; Review comment: nit: Wondering if this would be cleaner if we simply split this out into its own `else` block for block chunks?
########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java ########## @@ -88,82 +94,125 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check if the array of block IDs are all shuffle block IDs. With push based shuffle, + * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk + * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. + * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we + // check if all the block ids are shuffle chunk Ids. + return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); } return true; } + /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. 
*/ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by * analyzing the pass in blockIds. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( - String appId, String execId, String[] blockIds) { + private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + String appId, + String execId, + String[] blockIds, + boolean areMergedChunks) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>(); + // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId + // is reduceId. 
+ LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToBlocksInfo.containsKey(mapId)) { - mapIdToBlocksInfo.put(mapId, new BlocksInfo()); + Number primaryId; + if (!areMergedChunks) { + primaryId = Long.parseLong(blockIdParts[2]); + } else { + primaryId = Integer.parseInt(blockIdParts[2]); } - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); - blocksInfoByMapId.blockIds.add(blockId); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + id -> new BlocksInfo()); + blocksInfoByPrimaryId.blockIds.add(blockId); + // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a + // shuffleChunk block, then blockIdParts[3] = chunkId + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { + // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); + // blockIdParts[4] is the end reduce id for the batch range + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; + // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, + // secondaryIds are chunkIds. 
+ int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; int blockIdIndex = 0; - for (int i = 0; i < mapIds.length; i++) { - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); - reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + int secIndex = 0; + for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) { Review comment: `primaryIdToBlocksInfo.entrySet()` -> `primaryIdToBlocksInfo.values()` ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java ########## @@ -54,8 +59,12 @@ * Blocks are registered with the "one-for-one" strategy, meaning each Transport-layer Chunk * is equivalent to one block. */ -public class ExternalBlockHandler extends RpcHandler { +public class ExternalBlockHandler extends RpcHandler + implements RpcHandler.MergedBlockMetaReqHandler { private static final Logger logger = LoggerFactory.getLogger(ExternalBlockHandler.class); + private static final String SHUFFLE_MERGER_IDENTIFIER = "shuffle-push-merger"; Review comment: @Ngone51 This is another case of duplication between common modules and spark core (we had seen other cases where code is getting duplicated) ... we should find a way to unify them. 
########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java ########## @@ -88,82 +94,125 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check if the array of block IDs are all shuffle block IDs. With push based shuffle, + * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk + * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. + * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we + // check if all the block ids are shuffle chunk Ids. + return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); } return true; } + /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. 
*/ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by * analyzing the pass in blockIds. Review comment: super nit: `pass` -> `passed` (not from your pr) ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FetchShuffleBlockChunks.java ########## @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle.protocol; + +import java.util.Arrays; + +import io.netty.buffer.ByteBuf; + +import org.apache.spark.network.protocol.Encoders; + +// Needed by ScalaDoc. See SPARK-7726 +import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; + + +/** + * Request to read a set of block chunks. Returns {@link StreamHandle}. 
+ * + * @since 3.2.0 + */ +public class FetchShuffleBlockChunks extends AbstractFetchShuffleBlocks { + // The length of reduceIds must equal to chunkIds.size(). + public final int[] reduceIds; + // The i-th int[] in chunkIds contains all the chunks for the i-th reduceId in reduceIds. + public final int[][] chunkIds; + + public FetchShuffleBlockChunks( + String appId, + String execId, + int shuffleId, + int[] reduceIds, + int[][] chunkIds) { + super(appId, execId, shuffleId); + this.reduceIds = reduceIds; + this.chunkIds = chunkIds; + assert(reduceIds.length == chunkIds.length); + } + + @Override + protected Type type() { return Type.FETCH_SHUFFLE_BLOCK_CHUNKS; } + + @Override + public String toString() { + return toStringHelper() + .append("reduceIds", Arrays.toString(reduceIds)) + .append("chunkIds", Arrays.deepToString(chunkIds)) + .toString(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + FetchShuffleBlockChunks that = (FetchShuffleBlockChunks) o; + if (!super.equals(that)) return false; + if (!Arrays.equals(reduceIds, that.reduceIds)) return false; + return Arrays.deepEquals(chunkIds, that.chunkIds); + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + Arrays.hashCode(reduceIds); + result = 31 * result + Arrays.deepHashCode(chunkIds); + return result; + } + + @Override + public int encodedLength() { + int encodedLengthOfChunkIds = 0; + for (int[] ids: chunkIds) { + encodedLengthOfChunkIds += Encoders.IntArrays.encodedLength(ids); + } + return super.encodedLength() + + Encoders.IntArrays.encodedLength(reduceIds) + + 4 /* encoded length of chunkIds.size() */ + + encodedLengthOfChunkIds; + } + + @Override + public void encode(ByteBuf buf) { + super.encode(buf); + Encoders.IntArrays.encode(buf, reduceIds); + buf.writeInt(chunkIds.length); Review comment: Add a note that even though `reduceIds.length` == 
`chunkIds.length`, we are explicitly setting the length in the interest of forward compatibility? ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java ########## @@ -88,82 +94,125 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); Review comment: `OpenBlocks` does not support chunks, right? Do we want to error out in case there are chunk blocks? ########## File path: common/network-common/src/main/java/org/apache/spark/network/client/MergedBlockMetaResponseCallback.java ########## @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
*/ + + package org.apache.spark.network.client; + +import org.apache.spark.network.buffer.ManagedBuffer; + +/** + * Callback for the result of a single + * {@link org.apache.spark.network.protocol.MergedBlockMetaRequest}. + * + * @since 3.2.0 + */ +public interface MergedBlockMetaResponseCallback extends BaseResponseCallback { + /** + * Called upon receipt of a particular merged block meta. + * + * The given buffer will initially have a refcount of 1, but will be release()'d as soon as this + * call returns. You must therefore either retain() the buffer or copy its contents before + * returning. + */ + void onSuccess(int numChunks, ManagedBuffer buffer); Review comment: Can you add doc on what `numChunks` and `buffer` refer to here? ########## File path: common/network-common/src/main/java/org/apache/spark/network/protocol/MergedBlockMetaRequest.java ########## @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.network.protocol; + +import com.google.common.base.Objects; +import io.netty.buffer.ByteBuf; +import org.apache.commons.lang3.builder.ToStringBuilder; +import org.apache.commons.lang3.builder.ToStringStyle; + +/** + * Request to find the meta information for the specified merged block. The meta information + * contains the number of chunks in the merged blocks and the maps ids in each chunk. + * + * @since 3.2.0 + */ +public class MergedBlockMetaRequest extends AbstractMessage implements RequestMessage { + public final long requestId; + public final String appId; + public final int shuffleId; + public final int reduceId; + + public MergedBlockMetaRequest(long requestId, String appId, int shuffleId, int reduceId) { + super(null, false); + this.requestId = requestId; + this.appId = appId; + this.shuffleId = shuffleId; + this.reduceId = reduceId; + } + + @Override + public Type type() { + return Type.MergedBlockMetaRequest; + } + + @Override + public int encodedLength() { + return 8 + Encoders.Strings.encodedLength(appId) + 8; Review comment: super nit: `+ 8` -> `+ 4 + 4` ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockStoreClient.java ########## @@ -187,6 +188,39 @@ public void onFailure(Throwable e) { } } + @Override + public void getMergedBlockMeta( + String host, + int port, + int shuffleId, + int reduceId, + MergedBlocksMetaListener listener) { + checkInit(); + logger.debug("Get merged blocks meta from {}:{} for shuffleId {} reduceId {}", host, port, + shuffleId, reduceId); + try { + TransportClient client = clientFactory.createClient(host, port); + client.sendMergedBlockMetaReq(appId, shuffleId, reduceId, + new MergedBlockMetaResponseCallback() { + @Override + public void onSuccess(int numChunks, ManagedBuffer buffer) { + logger.trace("Successfully got merged block meta for shuffleId {} reduceId {}", + shuffleId, reduceId); + listener.onSuccess(shuffleId, reduceId, new 
MergedBlockMeta(numChunks, buffer)); + } + + @Override + public void onFailure(Throwable e) { + logger.error("Failed while getting merged block meta", e); Review comment: Listener's `onFailure` will log this - do we need this ? (here and below) ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java ########## @@ -88,82 +94,125 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check if the array of block IDs are all shuffle block IDs. With push based shuffle, + * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk + * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. + * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we + // check if all the block ids are shuffle chunk Ids. 
+ return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); } return true; } + /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by * analyzing the pass in blockIds. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( - String appId, String execId, String[] blockIds) { + private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + String appId, + String execId, + String[] blockIds, + boolean areMergedChunks) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>(); + // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId + // is reduceId. 
+ LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToBlocksInfo.containsKey(mapId)) { - mapIdToBlocksInfo.put(mapId, new BlocksInfo()); + Number primaryId; + if (!areMergedChunks) { + primaryId = Long.parseLong(blockIdParts[2]); + } else { + primaryId = Integer.parseInt(blockIdParts[2]); } - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); - blocksInfoByMapId.blockIds.add(blockId); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + id -> new BlocksInfo()); + blocksInfoByPrimaryId.blockIds.add(blockId); + // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a + // shuffleChunk block, then blockIdParts[3] = chunkId + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { + // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); + // blockIdParts[4] is the end reduce id for the batch range + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; + // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, + // secondaryIds are chunkIds. 
+ int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; int blockIdIndex = 0; - for (int i = 0; i < mapIds.length; i++) { - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); - reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + int secIndex = 0; + for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) { + BlocksInfo blocksInfoByPrimaryId = entry.getValue(); + secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids); - // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks - // because the shuffle data's return order should match the `blockIds`'s order to ensure - // blockId and data match. - for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) { - this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j); + // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/ + // FetchShuffleBlockChunks because the shuffle data's return order should match the + // `blockIds`'s order to ensure blockId and data match. + for (String blockId : blocksInfoByPrimaryId.blockIds) { + this.blockIds[blockIdIndex++] = blockId; } } assert(blockIdIndex == this.blockIds.length); - - return new FetchShuffleBlocks( - appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); + Set<Number> primaryIds = primaryIdToBlocksInfo.keySet(); + if (!areMergedChunks) { + long[] mapIds = Longs.toArray(primaryIds); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled); + } else { + int[] reduceIds = Ints.toArray(primaryIds); + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray); + } } /** Split the shuffleBlockId and return shuffleId, mapId and reduceIds. */ private String[] splitBlockId(String blockId) { String[] blockIdParts = blockId.split("_"); // For batch block id, the format contains shuffleId, mapId, begin reduceId, end reduceId. 
// For single block id, the format contains shuffleId, mapId, educeId. - if (blockIdParts.length < 4 || blockIdParts.length > 5 || !blockIdParts[0].equals("shuffle")) { + // For single block chunk id, the format contains shuffleId, reduceId, chunkId. + if (blockIdParts.length < 4 || blockIdParts.length > 5 || + !(blockIdParts[0].equals(SHUFFLE_BLOCK_PREFIX) || + !blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX))) { Review comment: `!blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX)` -> `blockIdParts[0].equals(SHUFFLE_CHUNK_PREFIX)` ? It is `not` on either of them matching, right ? If this is valid comment, please do add some test to catch this - we might have future changes here which might cause bugs as well. ########## File path: common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockFetcher.java ########## @@ -88,82 +94,125 @@ public OneForOneBlockFetcher( if (blockIds.length == 0) { throw new IllegalArgumentException("Zero-sized blockIds array"); } - if (!transportConf.useOldFetchProtocol() && isShuffleBlocks(blockIds)) { + if (!transportConf.useOldFetchProtocol() && areShuffleBlocksOrChunks(blockIds)) { this.blockIds = new String[blockIds.length]; - this.message = createFetchShuffleBlocksMsgAndBuildBlockIds(appId, execId, blockIds); + this.message = createFetchShuffleBlocksOrChunksMsg(appId, execId, blockIds); } else { this.blockIds = blockIds; this.message = new OpenBlocks(appId, execId, blockIds); } } - private boolean isShuffleBlocks(String[] blockIds) { - for (String blockId : blockIds) { - if (!blockId.startsWith("shuffle_")) { - return false; - } + /** + * Check if the array of block IDs are all shuffle block IDs. With push based shuffle, + * the shuffle block ID could be either unmerged shuffle block IDs or merged shuffle chunk + * IDs. For a given stream of shuffle blocks to be fetched in one request, they would be either + * all unmerged shuffle blocks or all merged shuffle chunks. 
+ * @param blockIds block ID array + * @return whether the array contains only shuffle block IDs + */ + private boolean areShuffleBlocksOrChunks(String[] blockIds) { + if (Arrays.stream(blockIds).anyMatch(blockId -> !blockId.startsWith(SHUFFLE_BLOCK_PREFIX))) { + // It comes here because there is a blockId which doesn't have "shuffle_" prefix so we + // check if all the block ids are shuffle chunk Ids. + return Arrays.stream(blockIds).allMatch(blockId -> blockId.startsWith(SHUFFLE_CHUNK_PREFIX)); } return true; } + /** Creates either a {@link FetchShuffleBlocks} or {@link FetchShuffleBlockChunks} message. */ + private AbstractFetchShuffleBlocks createFetchShuffleBlocksOrChunksMsg( + String appId, + String execId, + String[] blockIds) { + if (blockIds[0].startsWith(SHUFFLE_CHUNK_PREFIX)) { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, true); + } else { + return createFetchShuffleMsgAndBuildBlockIds(appId, execId, blockIds, false); + } + } + /** - * Create FetchShuffleBlocks message and rebuild internal blockIds by + * Create FetchShuffleBlocks/FetchShuffleBlockChunks message and rebuild internal blockIds by * analyzing the pass in blockIds. */ - private FetchShuffleBlocks createFetchShuffleBlocksMsgAndBuildBlockIds( - String appId, String execId, String[] blockIds) { + private AbstractFetchShuffleBlocks createFetchShuffleMsgAndBuildBlockIds( + String appId, + String execId, + String[] blockIds, + boolean areMergedChunks) { String[] firstBlock = splitBlockId(blockIds[0]); int shuffleId = Integer.parseInt(firstBlock[1]); boolean batchFetchEnabled = firstBlock.length == 5; - LinkedHashMap<Long, BlocksInfo> mapIdToBlocksInfo = new LinkedHashMap<>(); + // In case of FetchShuffleBlocks, primaryId is mapId. For FetchShuffleBlockChunks, primaryId + // is reduceId. 
+ LinkedHashMap<Number, BlocksInfo> primaryIdToBlocksInfo = new LinkedHashMap<>(); for (String blockId : blockIds) { String[] blockIdParts = splitBlockId(blockId); if (Integer.parseInt(blockIdParts[1]) != shuffleId) { throw new IllegalArgumentException("Expected shuffleId=" + shuffleId + ", got:" + blockId); } - long mapId = Long.parseLong(blockIdParts[2]); - if (!mapIdToBlocksInfo.containsKey(mapId)) { - mapIdToBlocksInfo.put(mapId, new BlocksInfo()); + Number primaryId; + if (!areMergedChunks) { + primaryId = Long.parseLong(blockIdParts[2]); + } else { + primaryId = Integer.parseInt(blockIdParts[2]); } - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapId); - blocksInfoByMapId.blockIds.add(blockId); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[3])); + BlocksInfo blocksInfoByPrimaryId = primaryIdToBlocksInfo.computeIfAbsent(primaryId, + id -> new BlocksInfo()); + blocksInfoByPrimaryId.blockIds.add(blockId); + // If blockId is a regular shuffle block, then blockIdParts[3] = reduceId. If blockId is a + // shuffleChunk block, then blockIdParts[3] = chunkId + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[3])); if (batchFetchEnabled) { + // It comes here only if the blockId is a regular shuffle block not a shuffleChunk block. // When we read continuous shuffle blocks in batch, we will reuse reduceIds in // FetchShuffleBlocks to store the start and end reduce id for range // [startReduceId, endReduceId). assert(blockIdParts.length == 5); - blocksInfoByMapId.reduceIds.add(Integer.parseInt(blockIdParts[4])); + // blockIdParts[4] is the end reduce id for the batch range + blocksInfoByPrimaryId.ids.add(Integer.parseInt(blockIdParts[4])); } } - long[] mapIds = Longs.toArray(mapIdToBlocksInfo.keySet()); - int[][] reduceIdArr = new int[mapIds.length][]; + // In case of FetchShuffleBlocks, secondaryIds are reduceIds. For FetchShuffleBlockChunks, + // secondaryIds are chunkIds. 
+ int[][] secondaryIdsArray = new int[primaryIdToBlocksInfo.size()][]; int blockIdIndex = 0; - for (int i = 0; i < mapIds.length; i++) { - BlocksInfo blocksInfoByMapId = mapIdToBlocksInfo.get(mapIds[i]); - reduceIdArr[i] = Ints.toArray(blocksInfoByMapId.reduceIds); + int secIndex = 0; + for (Map.Entry<Number, BlocksInfo> entry : primaryIdToBlocksInfo.entrySet()) { + BlocksInfo blocksInfoByPrimaryId = entry.getValue(); + secondaryIdsArray[secIndex++] = Ints.toArray(blocksInfoByPrimaryId.ids); - // The `blockIds`'s order must be same with the read order specified in in FetchShuffleBlocks - // because the shuffle data's return order should match the `blockIds`'s order to ensure - // blockId and data match. - for (int j = 0; j < blocksInfoByMapId.blockIds.size(); j++) { - this.blockIds[blockIdIndex++] = blocksInfoByMapId.blockIds.get(j); + // The `blockIds`'s order must be same with the read order specified in FetchShuffleBlocks/ + // FetchShuffleBlockChunks because the shuffle data's return order should match the + // `blockIds`'s order to ensure blockId and data match. + for (String blockId : blocksInfoByPrimaryId.blockIds) { + this.blockIds[blockIdIndex++] = blockId; } } assert(blockIdIndex == this.blockIds.length); - - return new FetchShuffleBlocks( - appId, execId, shuffleId, mapIds, reduceIdArr, batchFetchEnabled); + Set<Number> primaryIds = primaryIdToBlocksInfo.keySet(); + if (!areMergedChunks) { + long[] mapIds = Longs.toArray(primaryIds); + return new FetchShuffleBlocks( + appId, execId, shuffleId, mapIds, secondaryIdsArray, batchFetchEnabled); + } else { + int[] reduceIds = Ints.toArray(primaryIds); + return new FetchShuffleBlockChunks(appId, execId, shuffleId, reduceIds, secondaryIdsArray); + } Review comment: Note: Given that `primaryIdToBlocksInfo` is a `LinkedHashMap`, we end up with `mapIds`/`reduceIds` being consistent w.r.t `secondaryIdsArray` -- This is an automated message from the Apache Git Service. 
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org