This is an automated email from the ASF dual-hosted git repository. mridulm80 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 8113c88 [SPARK-32916][SHUFFLE] Implementation of shuffle service that leverages push-based shuffle in YARN deployment mode 8113c88 is described below commit 8113c88542ee282b510c7e046d64df1761a85d14 Author: Chandni Singh <singh.chan...@gmail.com> AuthorDate: Mon Nov 9 11:00:52 2020 -0600 [SPARK-32916][SHUFFLE] Implementation of shuffle service that leverages push-based shuffle in YARN deployment mode ### What changes were proposed in this pull request? This is one of the patches for SPIP [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602), which is needed for push-based shuffle. Summary of changes: - Adds an implementation of `MergedShuffleFileManager`, which was introduced with [SPARK-32915](https://issues.apache.org/jira/browse/SPARK-32915). - Integrates the push-based shuffle service with `YarnShuffleService`. ### Why are the changes needed? Refer to the SPIP in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Added unit tests. The reference PR with the consolidated changes covering the complete implementation is also provided in [SPARK-30602](https://issues.apache.org/jira/browse/SPARK-30602). We have already verified the functionality and the improved performance as documented in the SPIP doc. Lead-authored-by: Min Shen mshenlinkedin.com Co-authored-by: Chandni Singh chsinghlinkedin.com Co-authored-by: Ye Zhou yezhoulinkedin.com Closes #30062 from otterc/SPARK-32916. 
Lead-authored-by: Chandni Singh <singh.chan...@gmail.com> Co-authored-by: Chandni Singh <chsi...@linkedin.com> Co-authored-by: Ye Zhou <yez...@linkedin.com> Co-authored-by: Min Shen <ms...@linkedin.com> Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com> --- .../apache/spark/network/protocol/Encoders.java | 26 +- .../apache/spark/network/util/TransportConf.java | 35 + .../spark/network/protocol/EncodersSuite.java | 68 ++ common/network-shuffle/pom.xml | 10 +- .../apache/spark/network/shuffle/ErrorHandler.java | 8 +- .../network/shuffle/ExternalBlockHandler.java | 25 +- .../spark/network/shuffle/MergedBlockMeta.java | 2 + .../network/shuffle/MergedShuffleFileManager.java | 28 +- .../network/shuffle/OneForOneBlockPusher.java | 11 +- .../network/shuffle/RemoteBlockPushResolver.java | 934 +++++++++++++++++++++ .../shuffle/protocol/FinalizeShuffleMerge.java | 2 + .../network/shuffle/protocol/MergeStatuses.java | 2 + .../network/shuffle/protocol/PushBlockStream.java | 37 +- .../network/shuffle/ExternalBlockHandlerSuite.java | 2 +- .../network/shuffle/OneForOneBlockPusherSuite.java | 66 +- .../shuffle/RemoteBlockPushResolverSuite.java | 496 +++++++++++ .../spark/network/yarn/YarnShuffleService.java | 23 +- .../network/yarn/YarnShuffleServiceSuite.java | 61 ++ 18 files changed, 1748 insertions(+), 88 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 4fa191b..8bab808 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -18,6 +18,7 @@ package org.apache.spark.network.protocol; import java.io.IOException; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import io.netty.buffer.ByteBuf; @@ -46,7 +47,11 @@ public class Encoders { } } - /** Bitmaps are encoded with 
their serialization length followed by the serialization bytes. */ + /** + * Bitmaps are encoded with their serialization length followed by the serialization bytes. + * + * @since 3.1.0 + */ public static class Bitmaps { public static int encodedLength(RoaringBitmap b) { // Compress the bitmap before serializing it. Note that since BlockTransferMessage @@ -57,13 +62,20 @@ public class Encoders { return b.serializedSizeInBytes(); } + /** + * The input ByteBuf for this encoder should have enough write capacity to fit the serialized + * bitmap. Other encoders which use {@link io.netty.buffer.AbstractByteBuf#writeBytes(byte[])} + * to write can expand the buf as writeBytes calls {@link ByteBuf#ensureWritable} internally. + * However, this encoder doesn't rely on netty's writeBytes and will fail if the input buf + * doesn't have enough write capacity. + */ public static void encode(ByteBuf buf, RoaringBitmap b) { - int encodedLength = b.serializedSizeInBytes(); // RoaringBitmap requires nio ByteBuffer for serde. We expose the netty ByteBuf as a nio // ByteBuffer. Here, we need to explicitly manage the index so we can write into the // ByteBuffer, and the write is reflected in the underneath ByteBuf. - b.serialize(buf.nioBuffer(buf.writerIndex(), encodedLength)); - buf.writerIndex(buf.writerIndex() + encodedLength); + ByteBuffer byteBuffer = buf.nioBuffer(buf.writerIndex(), buf.writableBytes()); + b.serialize(byteBuffer); + buf.writerIndex(buf.writerIndex() + byteBuffer.position()); } public static RoaringBitmap decode(ByteBuf buf) { @@ -172,7 +184,11 @@ public class Encoders { } } - /** Bitmap arrays are encoded with the number of bitmaps followed by per-Bitmap encoding. */ + /** + * Bitmap arrays are encoded with the number of bitmaps followed by per-Bitmap encoding. 
+ * + * @since 3.1.0 + */ public static class BitmapArrays { public static int encodedLength(RoaringBitmap[] bitmaps) { int totalLength = 4; diff --git a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java index 646e427..fd287b0 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java +++ b/common/network-common/src/main/java/org/apache/spark/network/util/TransportConf.java @@ -363,4 +363,39 @@ public class TransportConf { return conf.getBoolean("spark.shuffle.useOldFetchProtocol", false); } + /** + * Class name of the implementation of MergedShuffleFileManager that merges the blocks + * pushed to it when push-based shuffle is enabled. By default, push-based shuffle is disabled at + * a cluster level because this configuration is set to + * 'org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager'. + * To turn on push-based shuffle at a cluster level, set the configuration to + * 'org.apache.spark.network.shuffle.RemoteBlockPushResolver'. + */ + public String mergedShuffleFileManagerImpl() { + return conf.get("spark.shuffle.server.mergedShuffleFileManagerImpl", + "org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager"); + } + + /** + * The minimum size of a chunk when dividing a merged shuffle file into multiple chunks during + * push-based shuffle. + * A merged shuffle file consists of multiple small shuffle blocks. Fetching the + * complete merged shuffle file in a single response increases the memory requirements for the + * clients. Instead of serving the entire merged file, the shuffle service serves the + * merged file in `chunks`. A `chunk` constitutes few shuffle blocks in entirety and this + * configuration controls how big a chunk can get. 
A corresponding index file for each merged + * shuffle file will be generated indicating chunk boundaries. + */ + public int minChunkSizeInMergedShuffleFile() { + return Ints.checkedCast(JavaUtils.byteStringAsBytes( + conf.get("spark.shuffle.server.minChunkSizeInMergedShuffleFile", "2m"))); + } + + /** + * The size of cache in memory which is used in push-based shuffle for storing merged index files. + */ + public long mergedIndexCacheSize() { + return JavaUtils.byteStringAsBytes( + conf.get("spark.shuffle.server.mergedIndexCacheSize", "100m")); + } } diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java new file mode 100644 index 0000000..6e89702 --- /dev/null +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.protocol; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.Test; +import org.roaringbitmap.RoaringBitmap; + +import static org.junit.Assert.*; + +/** + * Tests for {@link Encoders}. 
+ */ +public class EncodersSuite { + + @Test + public void testRoaringBitmapEncodeDecode() { + RoaringBitmap bitmap = new RoaringBitmap(); + bitmap.add(1, 2, 3); + ByteBuf buf = Unpooled.buffer(Encoders.Bitmaps.encodedLength(bitmap)); + Encoders.Bitmaps.encode(buf, bitmap); + RoaringBitmap decodedBitmap = Encoders.Bitmaps.decode(buf); + assertEquals(bitmap, decodedBitmap); + } + + @Test (expected = java.nio.BufferOverflowException.class) + public void testRoaringBitmapEncodeShouldFailWhenBufferIsSmall() { + RoaringBitmap bitmap = new RoaringBitmap(); + bitmap.add(1, 2, 3); + ByteBuf buf = Unpooled.buffer(4); + Encoders.Bitmaps.encode(buf, bitmap); + } + + @Test + public void testBitmapArraysEncodeDecode() { + RoaringBitmap[] bitmaps = new RoaringBitmap[] { + new RoaringBitmap(), + new RoaringBitmap(), + new RoaringBitmap(), // empty + new RoaringBitmap(), + new RoaringBitmap() + }; + bitmaps[0].add(1, 2, 3); + bitmaps[1].add(1, 2, 4); + bitmaps[3].add(7L, 9L); + bitmaps[4].add(1L, 100L); + ByteBuf buf = Unpooled.buffer(Encoders.BitmapArrays.encodedLength(bitmaps)); + Encoders.BitmapArrays.encode(buf, bitmaps); + RoaringBitmap[] decodedBitmaps = Encoders.BitmapArrays.decode(buf); + assertArrayEquals(bitmaps, decodedBitmaps); + } +} diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index a4a1ff9..562a1d4 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -47,6 +47,11 @@ <artifactId>metrics-core</artifactId> </dependency> + <dependency> + <groupId>org.apache.spark</groupId> + <artifactId>spark-tags_${scala.binary.version}</artifactId> + </dependency> + <!-- Provided dependencies --> <dependency> <groupId>org.slf4j</groupId> @@ -70,11 +75,6 @@ <type>test-jar</type> <scope>test</scope> </dependency> - <dependency> - <groupId>org.apache.spark</groupId> - <artifactId>spark-tags_${scala.binary.version}</artifactId> - <scope>test</scope> - </dependency> <!-- This spark-tags test-dep is needed even though it isn't 
used in this module, otherwise testing-cmds that exclude diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java index 308b0b7..d13a027 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ErrorHandler.java @@ -21,14 +21,18 @@ import java.net.ConnectException; import com.google.common.base.Throwables; +import org.apache.spark.annotation.Evolving; + /** * Plugs into {@link RetryingBlockFetcher} to further control when an exception should be retried * and logged. * Note: {@link RetryingBlockFetcher} will delegate the exception to this handler only when * - remaining retries < max retries * - exception is an IOException + * + * @since 3.1.0 */ - +@Evolving public interface ErrorHandler { boolean shouldRetryError(Throwable t); @@ -44,6 +48,8 @@ public interface ErrorHandler { /** * The error handler for pushing shuffle blocks to remote shuffle services. 
+ * + * @since 3.1.0 */ class BlockPushErrorHandler implements ErrorHandler { /** diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java index 321b253..688ee1c 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalBlockHandler.java @@ -68,7 +68,7 @@ public class ExternalBlockHandler extends RpcHandler { throws IOException { this(new OneForOneStreamManager(), new ExternalShuffleBlockResolver(conf, registeredExecutorFile), - new NoOpMergedShuffleFileManager()); + new NoOpMergedShuffleFileManager(conf)); } public ExternalBlockHandler( @@ -89,7 +89,7 @@ public class ExternalBlockHandler extends RpcHandler { public ExternalBlockHandler( OneForOneStreamManager streamManager, ExternalShuffleBlockResolver blockManager) { - this(streamManager, blockManager, new NoOpMergedShuffleFileManager()); + this(streamManager, blockManager, new NoOpMergedShuffleFileManager(null)); } /** Enables mocking out the StreamManager, BlockManager, and MergeManager. 
*/ @@ -175,7 +175,7 @@ public class ExternalBlockHandler extends RpcHandler { RegisterExecutor msg = (RegisterExecutor) msgObj; checkAuth(client, msg.appId); blockManager.registerExecutor(msg.appId, msg.execId, msg.executorInfo); - mergeManager.registerExecutor(msg.appId, msg.executorInfo.localDirs); + mergeManager.registerExecutor(msg.appId, msg.executorInfo); callback.onSuccess(ByteBuffer.wrap(new byte[0])); } finally { responseDelayContext.stop(); @@ -232,6 +232,7 @@ public class ExternalBlockHandler extends RpcHandler { */ public void applicationRemoved(String appId, boolean cleanupLocalDirs) { blockManager.applicationRemoved(appId, cleanupLocalDirs); + mergeManager.applicationRemoved(appId, cleanupLocalDirs); } /** @@ -430,8 +431,15 @@ public class ExternalBlockHandler extends RpcHandler { /** * Dummy implementation of merged shuffle file manager. Suitable for when push-based shuffle * is not enabled. + * + * @since 3.1.0 */ - private static class NoOpMergedShuffleFileManager implements MergedShuffleFileManager { + public static class NoOpMergedShuffleFileManager implements MergedShuffleFileManager { + + // This constructor is needed because we use this constructor to instantiate an implementation + // of MergedShuffleFileManager using reflection. + // See YarnShuffleService#newMergedShuffleFileManagerInstance. + public NoOpMergedShuffleFileManager(TransportConf transportConf) {} @Override public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) { @@ -444,18 +452,13 @@ public class ExternalBlockHandler extends RpcHandler { } @Override - public void registerApplication(String appId, String user) { - // No-op. Do nothing. - } - - @Override - public void registerExecutor(String appId, String[] localDirs) { + public void registerExecutor(String appId, ExecutorShuffleInfo executorInfo) { // No-Op. Do nothing. 
} @Override public void applicationRemoved(String appId, boolean cleanupLocalDirs) { - throw new UnsupportedOperationException("Cannot handle shuffle block merge"); + // No-Op. Do nothing. } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java index e9d9e53..5541b74 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedBlockMeta.java @@ -34,6 +34,8 @@ import org.apache.spark.network.protocol.Encoders; * 1. Number of chunks in a merged shuffle block. * 2. Bitmaps for each chunk in the merged block. A chunk bitmap contains all the mapIds that were * merged to that merged block chunk. + * + * @since 3.1.0 */ public class MergedBlockMeta { private final int numChunks; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java index ef4dbb2..4ce6a47 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/MergedShuffleFileManager.java @@ -19,13 +19,14 @@ package org.apache.spark.network.shuffle; import java.io.IOException; +import org.apache.spark.annotation.Evolving; import org.apache.spark.network.buffer.ManagedBuffer; import org.apache.spark.network.client.StreamCallbackWithID; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; import org.apache.spark.network.shuffle.protocol.MergeStatuses; import org.apache.spark.network.shuffle.protocol.PushBlockStream; - /** * The MergedShuffleFileManager is used to process push 
based shuffle when enabled. It works * along side {@link ExternalBlockHandler} and serves as an RPCHandler for @@ -33,7 +34,10 @@ import org.apache.spark.network.shuffle.protocol.PushBlockStream; * remotely pushed streams of shuffle blocks to merge them into merged shuffle files. Right * now, support for push based shuffle is only implemented for external shuffle service in * YARN mode. + * + * @since 3.1.0 */ +@Evolving public interface MergedShuffleFileManager { /** * Provides the stream callback used to process a remotely pushed block. The callback is @@ -56,25 +60,15 @@ public interface MergedShuffleFileManager { MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException; /** - * Registers an application when it starts. It also stores the username which is necessary - * for generating the host local directories for merged shuffle files. - * Right now, this is invoked by YarnShuffleService. - * - * @param appId application ID - * @param user username - */ - void registerApplication(String appId, String user); - - /** - * Registers an executor with its local dir list when it starts. This provides the specific path - * so MergedShuffleFileManager knows where to store and look for shuffle data for a - * given application. It is invoked by the RPC call when executor tries to register with the - * local shuffle service. + * Registers an executor with MergedShuffleFileManager. This executor-info provides + * the directories and number of sub-dirs per dir so that MergedShuffleFileManager knows where to + * store and look for shuffle data for a given application. It is invoked by the RPC call when + * executor tries to register with the local shuffle service. 
* * @param appId application ID - * @param localDirs The list of local dirs that this executor gets granted from NodeManager + * @param executorInfo The list of local dirs that this executor gets granted from NodeManager */ - void registerExecutor(String appId, String[] localDirs); + void registerExecutor(String appId, ExecutorShuffleInfo executorInfo); /** * Invoked when an application finishes. This cleans up any remaining metadata associated with diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java index 407b248..6ee95ef 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/OneForOneBlockPusher.java @@ -35,10 +35,13 @@ import org.apache.spark.network.shuffle.protocol.PushBlockStream; * be merged instead of for fetching them from remote shuffle services. This is used by * ShuffleWriter when the block push process is initiated. The supplied BlockFetchingListener * is used to handle the success or failure in pushing each blocks. 
+ * + * @since 3.1.0 */ public class OneForOneBlockPusher { private static final Logger logger = LoggerFactory.getLogger(OneForOneBlockPusher.class); private static final ErrorHandler PUSH_ERROR_HANDLER = new ErrorHandler.BlockPushErrorHandler(); + public static final String SHUFFLE_PUSH_BLOCK_PREFIX = "shufflePush"; private final TransportClient client; private final String appId; @@ -115,7 +118,13 @@ public class OneForOneBlockPusher { for (int i = 0; i < blockIds.length; i++) { assert buffers.containsKey(blockIds[i]) : "Could not find the block buffer for block " + blockIds[i]; - ByteBuffer header = new PushBlockStream(appId, blockIds[i], i).toByteBuffer(); + String[] blockIdParts = blockIds[i].split("_"); + if (blockIdParts.length != 4 || !blockIdParts[0].equals(SHUFFLE_PUSH_BLOCK_PREFIX)) { + throw new IllegalArgumentException( + "Unexpected shuffle push block id format: " + blockIds[i]); + } + ByteBuffer header = new PushBlockStream(appId, Integer.parseInt(blockIdParts[1]), + Integer.parseInt(blockIdParts[2]), Integer.parseInt(blockIdParts[3]) , i).toByteBuffer(); client.uploadStream(new NioManagedBuffer(header), buffers.get(blockIds[i]), new BlockPushCallback(i, blockIds[i])); } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java new file mode 100644 index 0000000..76abb05 --- /dev/null +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/RemoteBlockPushResolver.java @@ -0,0 +1,934 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Objects; +import com.google.common.base.Preconditions; +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; +import com.google.common.cache.Weigher; +import com.google.common.collect.Maps; +import org.roaringbitmap.RoaringBitmap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.buffer.ManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; +import org.apache.spark.network.shuffle.protocol.MergeStatuses; +import 
org.apache.spark.network.shuffle.protocol.PushBlockStream; +import org.apache.spark.network.util.JavaUtils; +import org.apache.spark.network.util.NettyUtils; +import org.apache.spark.network.util.TransportConf; + +/** + * An implementation of {@link MergedShuffleFileManager} that provides the most essential shuffle + * service processing logic to support push based shuffle. + * + * @since 3.1.0 + */ +public class RemoteBlockPushResolver implements MergedShuffleFileManager { + + private static final Logger logger = LoggerFactory.getLogger(RemoteBlockPushResolver.class); + @VisibleForTesting + static final String MERGE_MANAGER_DIR = "merge_manager"; + + private final ConcurrentMap<String, AppPathsInfo> appsPathInfo; + private final ConcurrentMap<AppShuffleId, Map<Integer, AppShufflePartitionInfo>> partitions; + + private final Executor directoryCleaner; + private final TransportConf conf; + private final int minChunkSize; + private final ErrorHandler.BlockPushErrorHandler errorHandler; + + @SuppressWarnings("UnstableApiUsage") + private final LoadingCache<File, ShuffleIndexInformation> indexCache; + + @SuppressWarnings("UnstableApiUsage") + public RemoteBlockPushResolver(TransportConf conf) { + this.conf = conf; + this.partitions = Maps.newConcurrentMap(); + this.appsPathInfo = Maps.newConcurrentMap(); + this.directoryCleaner = Executors.newSingleThreadExecutor( + // Add `spark` prefix because it will run in NM in Yarn mode. 
+ NettyUtils.createThreadFactory("spark-shuffle-merged-shuffle-directory-cleaner")); + this.minChunkSize = conf.minChunkSizeInMergedShuffleFile(); + CacheLoader<File, ShuffleIndexInformation> indexCacheLoader = + new CacheLoader<File, ShuffleIndexInformation>() { + public ShuffleIndexInformation load(File file) throws IOException { + return new ShuffleIndexInformation(file); + } + }; + indexCache = CacheBuilder.newBuilder() + .maximumWeight(conf.mergedIndexCacheSize()) + .weigher((Weigher<File, ShuffleIndexInformation>) (file, indexInfo) -> indexInfo.getSize()) + .build(indexCacheLoader); + this.errorHandler = new ErrorHandler.BlockPushErrorHandler(); + } + + /** + * Given the appShuffleId and reduceId that uniquely identifies a given shuffle partition of an + * application, retrieves the associated metadata. If not present and the corresponding merged + * shuffle does not exist, initializes the metadata. + */ + private AppShufflePartitionInfo getOrCreateAppShufflePartitionInfo( + AppShuffleId appShuffleId, + int reduceId) { + File dataFile = getMergedShuffleDataFile(appShuffleId, reduceId); + if (!partitions.containsKey(appShuffleId) && dataFile.exists()) { + // If this partition is already finalized then the partitions map will not contain + // the appShuffleId but the data file would exist. In that case the block is considered late. + return null; + } + Map<Integer, AppShufflePartitionInfo> shufflePartitions = + partitions.computeIfAbsent(appShuffleId, id -> Maps.newConcurrentMap()); + return shufflePartitions.computeIfAbsent(reduceId, key -> { + // It only gets here when the key is not present in the map. This could either + // be the first time the merge manager receives a pushed block for a given application + // shuffle partition, or after the merged shuffle file is finalized. We handle these + // two cases accordingly by checking if the file already exists. 
+ File indexFile = getMergedShuffleIndexFile(appShuffleId, reduceId); + File metaFile = getMergedShuffleMetaFile(appShuffleId, reduceId); + try { + if (dataFile.exists()) { + return null; + } else { + return new AppShufflePartitionInfo(appShuffleId, reduceId, dataFile, indexFile, metaFile); + } + } catch (IOException e) { + logger.error( + "Cannot create merged shuffle partition with data file {}, index file {}, and " + + "meta file {}", dataFile.getAbsolutePath(), + indexFile.getAbsolutePath(), metaFile.getAbsolutePath()); + throw new RuntimeException( + String.format("Cannot initialize merged shuffle partition for appId %s shuffleId %s " + + "reduceId %s", appShuffleId.appId, appShuffleId.shuffleId, reduceId), e); + } + }); + } + + @Override + public MergedBlockMeta getMergedBlockMeta(String appId, int shuffleId, int reduceId) { + AppShuffleId appShuffleId = new AppShuffleId(appId, shuffleId); + File indexFile = getMergedShuffleIndexFile(appShuffleId, reduceId); + if (!indexFile.exists()) { + throw new RuntimeException(String.format( + "Merged shuffle index file %s not found", indexFile.getPath())); + } + int size = (int) indexFile.length(); + // First entry is the zero offset + int numChunks = (size / Long.BYTES) - 1; + File metaFile = getMergedShuffleMetaFile(appShuffleId, reduceId); + if (!metaFile.exists()) { + throw new RuntimeException(String.format("Merged shuffle meta file %s not found", + metaFile.getPath())); + } + FileSegmentManagedBuffer chunkBitMaps = + new FileSegmentManagedBuffer(conf, metaFile, 0L, metaFile.length()); + logger.trace( + "{} shuffleId {} reduceId {} num chunks {}", appId, shuffleId, reduceId, numChunks); + return new MergedBlockMeta(numChunks, chunkBitMaps); + } + + @SuppressWarnings("UnstableApiUsage") + @Override + public ManagedBuffer getMergedBlockData(String appId, int shuffleId, int reduceId, int chunkId) { + AppShuffleId appShuffleId = new AppShuffleId(appId, shuffleId); + File dataFile = 
getMergedShuffleDataFile(appShuffleId, reduceId); + if (!dataFile.exists()) { + throw new RuntimeException(String.format("Merged shuffle data file %s not found", + dataFile.getPath())); + } + File indexFile = getMergedShuffleIndexFile(appShuffleId, reduceId); + try { + // If we get here, the merged shuffle file should have been properly finalized. Thus we can + // use the file length to determine the size of the merged shuffle block. + ShuffleIndexInformation shuffleIndexInformation = indexCache.get(indexFile); + ShuffleIndexRecord shuffleIndexRecord = shuffleIndexInformation.getIndex(chunkId); + return new FileSegmentManagedBuffer( + conf, dataFile, shuffleIndexRecord.getOffset(), shuffleIndexRecord.getLength()); + } catch (ExecutionException e) { + throw new RuntimeException(String.format( + "Failed to open merged shuffle index file %s", indexFile.getPath()), e); + } + } + + /** + * The logic here is consistent with + * org.apache.spark.storage.DiskBlockManager#getMergedShuffleFile + */ + private File getFile(String appId, String filename) { + // TODO: [SPARK-33236] Change the message when this service is able to handle NM restart + AppPathsInfo appPathsInfo = Preconditions.checkNotNull(appsPathInfo.get(appId), + "application " + appId + " is not registered or NM was restarted."); + File targetFile = ExecutorDiskUtils.getFile(appPathsInfo.activeLocalDirs, + appPathsInfo.subDirsPerLocalDir, filename); + logger.debug("Get merged file {}", targetFile.getAbsolutePath()); + return targetFile; + } + + private File getMergedShuffleDataFile(AppShuffleId appShuffleId, int reduceId) { + String fileName = String.format("%s.data", generateFileName(appShuffleId, reduceId)); + return getFile(appShuffleId.appId, fileName); + } + + private File getMergedShuffleIndexFile(AppShuffleId appShuffleId, int reduceId) { + String indexName = String.format("%s.index", generateFileName(appShuffleId, reduceId)); + return getFile(appShuffleId.appId, indexName); + } + + private File 
getMergedShuffleMetaFile(AppShuffleId appShuffleId, int reduceId) { + String metaName = String.format("%s.meta", generateFileName(appShuffleId, reduceId)); + return getFile(appShuffleId.appId, metaName); + } + + @Override + public String[] getMergedBlockDirs(String appId) { + AppPathsInfo appPathsInfo = Preconditions.checkNotNull(appsPathInfo.get(appId), + "application " + appId + " is not registered or NM was restarted."); + String[] activeLocalDirs = Preconditions.checkNotNull(appPathsInfo.activeLocalDirs, + "application " + appId + + " active local dirs list has not been updated by any executor registration"); + return activeLocalDirs; + } + + @Override + public void applicationRemoved(String appId, boolean cleanupLocalDirs) { + logger.info("Application {} removed, cleanupLocalDirs = {}", appId, cleanupLocalDirs); + // TODO: [SPARK-33236] Change the message when this service is able to handle NM restart + AppPathsInfo appPathsInfo = Preconditions.checkNotNull(appsPathInfo.remove(appId), + "application " + appId + " is not registered or NM was restarted."); + Iterator<Map.Entry<AppShuffleId, Map<Integer, AppShufflePartitionInfo>>> iterator = + partitions.entrySet().iterator(); + while (iterator.hasNext()) { + Map.Entry<AppShuffleId, Map<Integer, AppShufflePartitionInfo>> entry = iterator.next(); + AppShuffleId appShuffleId = entry.getKey(); + if (appId.equals(appShuffleId.appId)) { + iterator.remove(); + for (AppShufflePartitionInfo partitionInfo : entry.getValue().values()) { + partitionInfo.closeAllFiles(); + } + } + } + if (cleanupLocalDirs) { + Path[] dirs = Arrays.stream(appPathsInfo.activeLocalDirs) + .map(dir -> Paths.get(dir)).toArray(Path[]::new); + directoryCleaner.execute(() -> deleteExecutorDirs(dirs)); + } + } + + /** + * Serially delete local dirs, executed in a separate thread. 
+   * A failure on one directory is logged and does not stop deletion of the remaining ones.
+   */
+  @VisibleForTesting
+  void deleteExecutorDirs(Path[] dirs) {
+    for (Path localDir : dirs) {
+      try {
+        if (Files.exists(localDir)) {
+          JavaUtils.deleteRecursively(localDir.toFile());
+          logger.debug("Successfully cleaned up directory: {}", localDir);
+        }
+      } catch (Exception e) {
+        logger.error("Failed to delete directory: {}", localDir, e);
+      }
+    }
+  }
+
+  /**
+   * Returns the stream callback for one pushed block. A block that can still be merged gets a
+   * {@link PushBlockStreamCallback}; a duplicate or too-late block gets a callback that drains
+   * the data, failing in onComplete only for the too-late case.
+   */
+  @Override
+  public StreamCallbackWithID receiveBlockDataAsStream(PushBlockStream msg) {
+    // Retrieve merged shuffle file metadata
+    AppShuffleId appShuffleId = new AppShuffleId(msg.appId, msg.shuffleId);
+    AppShufflePartitionInfo partitionInfoBeforeCheck =
+      getOrCreateAppShufflePartitionInfo(appShuffleId, msg.reduceId);
+    // Here partitionInfo will be null in 2 cases:
+    // 1) The request is received for a block that has already been merged, this is possible due
+    // to the retry logic.
+    // 2) The request is received after the merged shuffle is finalized, thus is too late.
+    //
+    // For case 1, we will drain the data in the channel and just respond success
+    // to the client. This is required because the response of the previously merged
+    // block will be ignored by the client, per the logic in RetryingBlockFetcher.
+    // Note that the netty server should receive data for a given block id only from 1 channel
+    // at any time. The block should be pushed only from successful maps, thus there should be
+    // only 1 source for a given block at any time. Although the netty client might retry sending
+    // this block to the server multiple times, the data of the same block always arrives from the
+    // same channel thus the server should have already processed the previous request of this
+    // block before seeing it again in the channel. This guarantees that we can simply just
+    // check the bitmap to determine if a block is a duplicate or not.
+    //
+    // For case 2, we will also drain the data in the channel, but throw an exception in
+    // {@link org.apache.spark.network.client.StreamCallback#onComplete(String)}. This way,
+    // the client will be notified of the failure but the channel will remain active. Keeping
+    // the channel alive is important because the same channel could be reused by multiple map
+    // tasks in the executor JVM, which belong to different stages. While one of the shuffles
+    // in these stages is finalized, the others might still be active. Tearing down the channel
+    // on the server side will disrupt these other on-going shuffle merges. It's also important
+    // to notify the client of the failure, so that it can properly halt pushing the remaining
+    // blocks upon receiving such failures to preserve resources on the server/client side.
+    //
+    // Speculative execution would also raise a possible scenario with duplicate blocks. Although
+    // speculative execution would kill the slower task attempt, leading to only 1 task attempt
+    // succeeding in the end, there is no guarantee that only one copy of the block will be
+    // pushed. This is due to our handling of block push process outside of the map task, thus
+    // it is possible for the speculative task attempt to initiate the block push process before
+    // getting killed. When this happens, we need to distinguish the duplicate blocks as they
+    // arrive. More details on this are explained in later comments.
+
+    // Track if the block is received after shuffle merge finalize
+    final boolean isTooLate = partitionInfoBeforeCheck == null;
+    // Check if the given block is already merged by checking the bitmap against the given map index
+    final AppShufflePartitionInfo partitionInfo = partitionInfoBeforeCheck != null
+      && partitionInfoBeforeCheck.mapTracker.contains(msg.mapIndex) ? null
+      : partitionInfoBeforeCheck;
+    final String streamId = String.format("%s_%d_%d_%d",
+      OneForOneBlockPusher.SHUFFLE_PUSH_BLOCK_PREFIX, appShuffleId.shuffleId, msg.mapIndex,
+      msg.reduceId);
+    if (partitionInfo != null) {
+      return new PushBlockStreamCallback(this, streamId, partitionInfo, msg.mapIndex);
+    } else {
+      // For a duplicate block or a block which is late, respond back with a callback that handles
+      // them differently.
+      return new StreamCallbackWithID() {
+        @Override
+        public String getID() {
+          return streamId;
+        }
+
+        @Override
+        public void onData(String streamId, ByteBuffer buf) {
+          // Ignore the requests. It reaches here either when a request is received after the
+          // shuffle file is finalized or when a request is for a duplicate block.
+        }
+
+        @Override
+        public void onComplete(String streamId) {
+          if (isTooLate) {
+            // Throw an exception here so the block data is drained from channel and server
+            // responds RpcFailure to the client.
+            throw new RuntimeException(String.format("Block %s %s", streamId,
+              ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+          }
+          // For duplicate block that is received before the shuffle merge finalizes, the
+          // server should respond success to the client.
+        }
+
+        @Override
+        public void onFailure(String streamId, Throwable cause) {
+          // Nothing to undo: this callback never writes to the merged files.
+        }
+      };
+    }
+  }
+
+  /**
+   * Finalizes every partition of the shuffle that is still being merged. Each partition's data
+   * file is truncated to its last fully merged block, a final chunk offset is recorded when
+   * needed, and its files are closed.
+   *
+   * @return the merge statuses of the partitions that finalized successfully. A partition whose
+   *     finalization hit an IOException is logged and left out, so every returned entry has a
+   *     real bitmap, reduce id and size.
+   */
+  @SuppressWarnings("SynchronizationOnLocalVariableOrMethodParameter")
+  @Override
+  public MergeStatuses finalizeShuffleMerge(FinalizeShuffleMerge msg) throws IOException {
+    logger.info("Finalizing shuffle {} from Application {}.", msg.shuffleId, msg.appId);
+    AppShuffleId appShuffleId = new AppShuffleId(msg.appId, msg.shuffleId);
+    Map<Integer, AppShufflePartitionInfo> shufflePartitions = partitions.get(appShuffleId);
+    MergeStatuses mergeStatuses;
+    if (shufflePartitions == null || shufflePartitions.isEmpty()) {
+      mergeStatuses =
+        new MergeStatuses(msg.shuffleId, new RoaringBitmap[0], new int[0], new long[0]);
+    } else {
+      Collection<AppShufflePartitionInfo> partitionsToFinalize = shufflePartitions.values();
+      int totalPartitions = partitionsToFinalize.size();
+      RoaringBitmap[] bitmaps = new RoaringBitmap[totalPartitions];
+      int[] reduceIds = new int[totalPartitions];
+      long[] sizes = new long[totalPartitions];
+      Iterator<AppShufflePartitionInfo> partitionsIter = partitionsToFinalize.iterator();
+      // Number of partitions finalized successfully so far; also the next free array slot.
+      int idx = 0;
+      while (partitionsIter.hasNext()) {
+        AppShufflePartitionInfo partition = partitionsIter.next();
+        synchronized (partition) {
+          // Get rid of any partial block data at the end of the file. This could either
+          // be due to failure or a request still being processed when the shuffle
+          // merge gets finalized.
+          try {
+            partition.dataChannel.truncate(partition.getPosition());
+            if (partition.getPosition() != partition.getLastChunkOffset()) {
+              partition.updateChunkInfo(partition.getPosition(), partition.lastMergedMapIndex);
+            }
+            bitmaps[idx] = partition.mapTracker;
+            reduceIds[idx] = partition.reduceId;
+            sizes[idx++] = partition.getPosition();
+          } catch (IOException ioe) {
+            logger.warn("Exception while finalizing shuffle partition {} {} {}", msg.appId,
+              msg.shuffleId, partition.reduceId, ioe);
+          } finally {
+            partition.closeAllFiles();
+            // The partition should be removed after the files are written so that any new stream
+            // for the same reduce partition will see that the data file exists.
+            partitionsIter.remove();
+          }
+        }
+      }
+      // A failed partition does not advance idx, so the arrays may have unpopulated tail slots
+      // (null bitmap, reduceId 0, size 0). Report only the partitions that were finalized.
+      if (idx < totalPartitions) {
+        bitmaps = Arrays.copyOf(bitmaps, idx);
+        reduceIds = Arrays.copyOf(reduceIds, idx);
+        sizes = Arrays.copyOf(sizes, idx);
+      }
+      mergeStatuses = new MergeStatuses(msg.shuffleId, bitmaps, reduceIds, sizes);
+    }
+    partitions.remove(appShuffleId);
+    logger.info("Finalized shuffle {} from Application {}.", msg.shuffleId, msg.appId);
+    return mergeStatuses;
+  }
+
+  /**
+   * Records the merge directories of {@code appId}. Only the first registration for an
+   * application takes effect (computeIfAbsent); later registrations leave it unchanged.
+   */
+  @Override
+  public void registerExecutor(String appId, ExecutorShuffleInfo executorInfo) {
+    if (logger.isDebugEnabled()) {
+      logger.debug("register executor with RemoteBlockPushResolver {} local-dirs {} "
+        + "num sub-dirs {}", appId, Arrays.toString(executorInfo.localDirs),
+          executorInfo.subDirsPerLocalDir);
+    }
+    appsPathInfo.computeIfAbsent(appId, id -> new AppPathsInfo(appId, executorInfo.localDirs,
+      executorInfo.subDirsPerLocalDir));
+  }
+
+  /**
+   * Base name {@code mergedShuffle_<appId>_<shuffleId>_<reduceId>} shared by a partition's
+   * .data, .index and .meta files.
+   */
+  private static String generateFileName(AppShuffleId appShuffleId, int reduceId) {
+    return String.format("mergedShuffle_%s_%d_%d", appShuffleId.appId, appShuffleId.shuffleId,
+      reduceId);
+  }
+
+  /**
+   * Callback for push stream that handles blocks which are not already merged.
+   * One instance is created per pushed block stream, bound to a single map index and the
+   * {@link AppShufflePartitionInfo} of the target reduce partition.
+   */
+  static class PushBlockStreamCallback implements StreamCallbackWithID {
+
+    private final RemoteBlockPushResolver mergeManager;
+    private final String streamId;
+    private final int mapIndex;
+    private final AppShufflePartitionInfo partitionInfo;
+    // Bytes of this block written to the merged data file so far.
+    private int length = 0;
+    // This indicates that this stream got the opportunity to write the blocks to the merged file.
+    // Once this is set to true and the stream encounters a failure then it will take necessary
+    // action to overwrite any partial written data. This is reset to false when the stream
+    // completes without any failures.
+    private boolean isWriting = false;
+    // Use on-heap instead of direct ByteBuffer since these buffers will be GC'ed very quickly
+    private List<ByteBuffer> deferredBufs;
+
+    private PushBlockStreamCallback(
+        RemoteBlockPushResolver mergeManager,
+        String streamId,
+        AppShufflePartitionInfo partitionInfo,
+        int mapIndex) {
+      this.mergeManager = Preconditions.checkNotNull(mergeManager);
+      this.streamId = streamId;
+      this.partitionInfo = Preconditions.checkNotNull(partitionInfo);
+      this.mapIndex = mapIndex;
+    }
+
+    @Override
+    public String getID() {
+      return streamId;
+    }
+
+    /**
+     * Write a ByteBuffer to the merged shuffle file. Here we keep track of the length of the
+     * block data written to file. In case of failure during writing block to file, we use the
+     * information tracked in partitionInfo to overwrite the corrupt block when writing the new
+     * block.
+     */
+    private void writeBuf(ByteBuffer buf) throws IOException {
+      while (buf.hasRemaining()) {
+        if (partitionInfo.isEncounteredFailure()) {
+          // After a failure the channel position may be past partial data; write at the
+          // explicit offset right after the last merged block plus what this block wrote.
+          long updatedPos = partitionInfo.getPosition() + length;
+          logger.debug(
+            "{} shuffleId {} reduceId {} encountered failure current pos {} updated pos {}",
+            partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+            partitionInfo.reduceId, partitionInfo.getPosition(), updatedPos);
+          length += partitionInfo.dataChannel.write(buf, updatedPos);
+        } else {
+          length += partitionInfo.dataChannel.write(buf);
+        }
+      }
+    }
+
+    /**
+     * There will be multiple streams of map blocks belonging to the same reduce partition. At any
+     * given point of time, only a single map stream can write its data to the merged file. Until
+     * this stream is completed, the other streams defer writing. This prevents corruption of
+     * merged data. This returns whether this stream is the active stream that can write to the
+     * merged file.
+     */
+    private boolean allowedToWrite() {
+      return partitionInfo.getCurrentMapIndex() < 0
+        || partitionInfo.getCurrentMapIndex() == mapIndex;
+    }
+
+    /**
+     * Returns if this is a duplicate block generated by speculative tasks. With speculative
+     * tasks, we could receive the same block from 2 different sources at the same time. One of
+     * them is going to be the first to set the currentMapIndex. When that block does so, it's
+     * going to see the currentMapIndex initially as -1. After it sets the currentMapIndex, it's
+     * going to write some data to disk, thus increasing the length counter. The other duplicate
+     * block is going to see the currentMapIndex already set to its mapIndex. However, it hasn't
+     * written any data yet. If the first block gets written completely and resets the
+     * currentMapIndex to -1 before the processing for the second block finishes, we can just
+     * check the bitmap to identify the second as a duplicate.
+     */
+    private boolean isDuplicateBlock() {
+      return (partitionInfo.getCurrentMapIndex() == mapIndex && length == 0)
+        || partitionInfo.mapTracker.contains(mapIndex);
+    }
+
+    /**
+     * This is only invoked when the stream is able to write. The stream first writes any deferred
+     * block parts buffered in memory.
+     */
+    private void writeAnyDeferredBufs() throws IOException {
+      if (deferredBufs != null && !deferredBufs.isEmpty()) {
+        for (ByteBuffer deferredBuf : deferredBufs) {
+          writeBuf(deferredBuf);
+        }
+        deferredBufs = null;
+      }
+    }
+
+    @Override
+    public void onData(String streamId, ByteBuffer buf) throws IOException {
+      // When handling the block data using StreamInterceptor, it can help to reduce the amount
+      // of data that needs to be buffered in memory since it does not wait till the completion
+      // of the frame before handling the message, thus releasing the ByteBuf earlier. However,
+      // this also means it would chunk a block into multiple buffers. Here, we want to preserve
+      // the benefit of handling the block data using StreamInterceptor as much as possible while
+      // providing the guarantee that one block would be continuously written to the merged
+      // shuffle file before the next block starts. For each shuffle partition, we would track
+      // the current map index to make sure only block matching the map index can be written to
+      // disk. If one server thread sees the block being handled is the current block, it would
+      // directly write the block to disk. Otherwise, it would buffer the block chunks in memory.
+      // If the block becomes the current block before we see the end of it, we would then dump
+      // all buffered block data to disk and write the remaining portions of the block directly
+      // to disk as well. This way, we avoid having to buffer the entirety of every block in
+      // memory, while still providing the necessary guarantee.
+      synchronized (partitionInfo) {
+        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+          mergeManager.partitions.get(partitionInfo.appShuffleId);
+        // If the partitionInfo corresponding to (appId, shuffleId, reduceId) is no longer present
+        // then it means that the shuffle merge has already been finalized. We should thus ignore
+        // the data and just drain the remaining bytes of this message. This check should be
+        // placed inside the synchronized block to make sure that checking the key is still
+        // present and processing the data is atomic.
+        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+          deferredBufs = null;
+          return;
+        }
+        // Check whether we can write to disk
+        if (allowedToWrite()) {
+          isWriting = true;
+          // Identify duplicate block generated by speculative tasks. We respond success to
+          // the client in cases of duplicate even though no data is written.
+          if (isDuplicateBlock()) {
+            deferredBufs = null;
+            return;
+          }
+          logger.trace("{} shuffleId {} reduceId {} onData writable",
+            partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+            partitionInfo.reduceId);
+          if (partitionInfo.getCurrentMapIndex() < 0) {
+            partitionInfo.setCurrentMapIndex(mapIndex);
+          }
+
+          // If we got here, it's safe to write the block data to the merged shuffle file. We
+          // first write any deferred block.
+          writeAnyDeferredBufs();
+          writeBuf(buf);
+          // If we got here, it means we successfully wrote the current chunk of block to merged
+          // shuffle file. If we encountered failure while writing the previous block, we should
+          // reset the file channel position and the status of partitionInfo to indicate that we
+          // have recovered from previous disk write failure. However, we do not update the
+          // position tracked by partitionInfo here. That is only updated while the entire block
+          // is successfully written to merged shuffle file.
+          if (partitionInfo.isEncounteredFailure()) {
+            partitionInfo.dataChannel.position(partitionInfo.getPosition() + length);
+            partitionInfo.setEncounteredFailure(false);
+          }
+        } else {
+          logger.trace("{} shuffleId {} reduceId {} onData deferred",
+            partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+            partitionInfo.reduceId);
+          // If we cannot write to disk, we buffer the current block chunk in memory so it could
+          // potentially be written to disk later. We take our best effort without guarantee
+          // that the block will be written to disk. If the block data is divided into multiple
+          // chunks during TCP transportation, each #onData invocation is an attempt to write
+          // the block to disk. If the block is still not written to disk after all #onData
+          // invocations, the final #onComplete invocation is the last attempt to write the
+          // block to disk. If we still couldn't write this block to disk after this, we give up
+          // on this block push request and respond failure to client. We could potentially
+          // buffer the block longer or wait for a few iterations inside #onData or #onComplete
+          // to increase the chance of writing the block to disk, however this would incur more
+          // memory footprint or decrease the server processing throughput for the shuffle
+          // service. In addition, during test we observed that by randomizing the order in
+          // which clients send block push request batches, only ~0.5% blocks failed to be
+          // written to disk due to this reason. We thus decide to optimize for server
+          // throughput and memory usage.
+          if (deferredBufs == null) {
+            deferredBufs = new LinkedList<>();
+          }
+          // Write the buffer to the in-memory deferred cache. Since buf is a slice of a larger
+          // byte buffer, we cache only the relevant bytes not the entire large buffer to save
+          // memory.
+          ByteBuffer deferredBuf = ByteBuffer.allocate(buf.remaining());
+          deferredBuf.put(buf);
+          deferredBuf.flip();
+          deferredBufs.add(deferredBuf);
+        }
+      }
+    }
+
+    @Override
+    public void onComplete(String streamId) throws IOException {
+      synchronized (partitionInfo) {
+        logger.trace("{} shuffleId {} reduceId {} onComplete invoked",
+          partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+          partitionInfo.reduceId);
+        Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+          mergeManager.partitions.get(partitionInfo.appShuffleId);
+        // When this request initially got to the server, the shuffle merge finalize request
+        // was not received yet. By the time we finish reading this message, the shuffle merge
+        // however is already finalized. We should thus respond RpcFailure to the client.
+        if (shufflePartitions == null || !shufflePartitions.containsKey(partitionInfo.reduceId)) {
+          deferredBufs = null;
+          throw new RuntimeException(String.format("Block %s %s", streamId,
+            ErrorHandler.BlockPushErrorHandler.TOO_LATE_MESSAGE_SUFFIX));
+        }
+        // Check if we can commit this block
+        if (allowedToWrite()) {
+          isWriting = true;
+          // Identify duplicate block generated by speculative tasks. We respond success to
+          // the client in cases of duplicate even though no data is written.
+          if (isDuplicateBlock()) {
+            deferredBufs = null;
+            return;
+          }
+          // No other stream holds the partition, so any bytes still buffered can be flushed now.
+          if (partitionInfo.getCurrentMapIndex() < 0) {
+            writeAnyDeferredBufs();
+          }
+          long updatedPos = partitionInfo.getPosition() + length;
+          boolean indexUpdated = false;
+          // Start a new chunk once the current one has reached minChunkSize.
+          if (updatedPos - partitionInfo.getLastChunkOffset() >= mergeManager.minChunkSize) {
+            partitionInfo.updateChunkInfo(updatedPos, mapIndex);
+            indexUpdated = true;
+          }
+          partitionInfo.setPosition(updatedPos);
+          partitionInfo.setCurrentMapIndex(-1);
+
+          // update merged results
+          partitionInfo.blockMerged(mapIndex);
+          if (indexUpdated) {
+            partitionInfo.resetChunkTracker();
+          }
+        } else {
+          deferredBufs = null;
+          throw new RuntimeException(String.format("%s %s to merged shuffle",
+            ErrorHandler.BlockPushErrorHandler.BLOCK_APPEND_COLLISION_DETECTED_MSG_PREFIX,
+            streamId));
+        }
+      }
+      isWriting = false;
+    }
+
+    @Override
+    public void onFailure(String streamId, Throwable throwable) throws IOException {
+      if (mergeManager.errorHandler.shouldLogError(throwable)) {
+        logger.error("Encountered issue when merging {}", streamId, throwable);
+      } else {
+        logger.debug("Encountered issue when merging {}", streamId, throwable);
+      }
+      // Only update partitionInfo if the failure corresponds to a valid request. If the
+      // request is too late, i.e. received after shuffle merge finalize, #onFailure will
+      // also be triggered, and we can just ignore. Also, if we couldn't find an opportunity
+      // to write the block data to disk, we should also ignore here.
+      if (isWriting) {
+        synchronized (partitionInfo) {
+          Map<Integer, AppShufflePartitionInfo> shufflePartitions =
+            mergeManager.partitions.get(partitionInfo.appShuffleId);
+          if (shufflePartitions != null && shufflePartitions.containsKey(partitionInfo.reduceId)) {
+            logger.debug("{} shuffleId {} reduceId {} set encountered failure",
+              partitionInfo.appShuffleId.appId, partitionInfo.appShuffleId.shuffleId,
+              partitionInfo.reduceId);
+            partitionInfo.setCurrentMapIndex(-1);
+            partitionInfo.setEncounteredFailure(true);
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * ID that uniquely identifies a shuffle for an application. This is used as a key in
+   * {@link #partitions}.
+   */
+  public static class AppShuffleId {
+    public final String appId;
+    public final int shuffleId;
+
+    AppShuffleId(String appId, int shuffleId) {
+      this.appId = appId;
+      this.shuffleId = shuffleId;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+      AppShuffleId that = (AppShuffleId) o;
+      return shuffleId == that.shuffleId && Objects.equal(appId, that.appId);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(appId, shuffleId);
+    }
+
+    @Override
+    public String toString() {
+      return Objects.toStringHelper(this)
+        .add("appId", appId)
+        .add("shuffleId", shuffleId)
+        .toString();
+    }
+  }
+
+  /** Metadata tracked for an actively merged shuffle partition */
+  public static class AppShufflePartitionInfo {
+
+    private final AppShuffleId appShuffleId;
+    private final int reduceId;
+    // The merged shuffle data file channel
+    public FileChannel dataChannel;
+    // Data-file offset just past the last successfully merged block for this shuffle partition
+    private long position;
+    // Indicating whether failure was encountered when merging the previous block
+    private boolean encounteredFailure;
+    // Track the map index whose block is being merged for this shuffle partition
+    private int currentMapIndex;
+    // Bitmap tracking which mapper's blocks have been merged for this shuffle partition
+    private RoaringBitmap mapTracker;
+    // The index file for a particular merged shuffle contains the chunk offsets.
+    private RandomAccessFile indexFile;
+    // The meta file for a particular merged shuffle contains all the map indices that belong to
+    // every chunk. The entry per chunk is a serialized bitmap.
+    private RandomAccessFile metaFile;
+    // The offset for the last chunk tracked in the index file for this shuffle partition
+    private long lastChunkOffset;
+    private int lastMergedMapIndex = -1;
+    // Bitmap tracking which mapper's blocks are in the current shuffle chunk
+    private RoaringBitmap chunkTracker;
+
+    /**
+     * Opens the data, index and meta files of the partition and records the initial 0 chunk
+     * offset. If any step fails, the files opened so far are closed before the exception
+     * propagates, so a failed construction does not leak file handles.
+     */
+    AppShufflePartitionInfo(
+        AppShuffleId appShuffleId,
+        int reduceId,
+        File dataFile,
+        File indexFile,
+        File metaFile) throws IOException {
+      this.appShuffleId = Preconditions.checkNotNull(appShuffleId, "app shuffle id");
+      this.reduceId = reduceId;
+      this.currentMapIndex = -1;
+      this.position = 0;
+      this.encounteredFailure = false;
+      this.mapTracker = new RoaringBitmap();
+      this.chunkTracker = new RoaringBitmap();
+      try {
+        this.dataChannel = new FileOutputStream(dataFile).getChannel();
+        this.indexFile = new RandomAccessFile(indexFile, "rw");
+        this.metaFile = new RandomAccessFile(metaFile, "rw");
+        // Writing 0 offset so that we can reuse ShuffleIndexInformation.getIndex()
+        updateChunkInfo(0L, -1);
+      } catch (IOException | RuntimeException e) {
+        closeAllFiles();
+        throw e;
+      }
+    }
+
+    public long getPosition() {
+      return position;
+    }
+
+    public void setPosition(long position) {
+      logger.trace("{} shuffleId {} reduceId {} current pos {} update pos {}", appShuffleId.appId,
+        appShuffleId.shuffleId, reduceId, this.position, position);
+      this.position = position;
+    }
+
+    boolean isEncounteredFailure() {
+      return encounteredFailure;
+    }
+
+    void setEncounteredFailure(boolean encounteredFailure) {
+      this.encounteredFailure = encounteredFailure;
+    }
+
+    int getCurrentMapIndex() {
+      return currentMapIndex;
+    }
+
+    void setCurrentMapIndex(int mapIndex) {
+      // Arguments follow the labels: the new value first, then the one being replaced.
+      logger.trace("{} shuffleId {} reduceId {} updated mapIndex {} current mapIndex {}",
+        appShuffleId.appId, appShuffleId.shuffleId, reduceId, mapIndex, currentMapIndex);
+      this.currentMapIndex = mapIndex;
+    }
+
+    long getLastChunkOffset() {
+      return lastChunkOffset;
+    }
+
+    /** Records that the block of {@code mapIndex} is now fully merged into this partition. */
+    void blockMerged(int mapIndex) {
+      logger.debug("{} shuffleId {} reduceId {} updated merging mapIndex {}", appShuffleId.appId,
+        appShuffleId.shuffleId, reduceId, mapIndex);
+      mapTracker.add(mapIndex);
+      chunkTracker.add(mapIndex);
+      lastMergedMapIndex = mapIndex;
+    }
+
+    void resetChunkTracker() {
+      chunkTracker.clear();
+    }
+
+    /**
+     * Appends the chunk offset to the index file and adds the map index to the chunk tracker.
+     *
+     * @param chunkOffset the offset of the chunk in the data file.
+     * @param mapIndex the map index to be added to chunk tracker.
+     */
+    void updateChunkInfo(long chunkOffset, int mapIndex) throws IOException {
+      long idxStartPos = -1;
+      try {
+        // update the chunk tracker to meta file before index file
+        writeChunkTracker(mapIndex);
+        idxStartPos = indexFile.getFilePointer();
+        logger.trace("{} shuffleId {} reduceId {} updated index current {} updated {}",
+          appShuffleId.appId, appShuffleId.shuffleId, reduceId, this.lastChunkOffset,
+          chunkOffset);
+        indexFile.writeLong(chunkOffset);
+      } catch (IOException ioe) {
+        if (idxStartPos != -1) {
+          // reset the position to avoid corrupting index files during exception.
+          logger.warn("{} shuffleId {} reduceId {} reset index to position {}",
+            appShuffleId.appId, appShuffleId.shuffleId, reduceId, idxStartPos);
+          indexFile.seek(idxStartPos);
+        }
+        throw ioe;
+      }
+      this.lastChunkOffset = chunkOffset;
+    }
+
+    /**
+     * Adds {@code mapIndex} to the chunk tracker and appends the serialized tracker to the meta
+     * file; a {@code mapIndex} of -1 writes nothing. On failure the meta file pointer is reset.
+     */
+    private void writeChunkTracker(int mapIndex) throws IOException {
+      if (mapIndex == -1) {
+        return;
+      }
+      chunkTracker.add(mapIndex);
+      long metaStartPos = metaFile.getFilePointer();
+      try {
+        logger.trace("{} shuffleId {} reduceId {} mapIndex {} write chunk to meta file",
+          appShuffleId.appId, appShuffleId.shuffleId, reduceId, mapIndex);
+        chunkTracker.serialize(metaFile);
+      } catch (IOException ioe) {
+        logger.warn("{} shuffleId {} reduceId {} mapIndex {} reset position of meta file to {}",
+          appShuffleId.appId, appShuffleId.shuffleId, reduceId, mapIndex, metaStartPos);
+        metaFile.seek(metaStartPos);
+        throw ioe;
+      }
+    }
+
+    /**
+     * Closes whichever of the data channel, meta file and index file are open and nulls them.
+     * Close failures are logged together with their cause and otherwise ignored.
+     */
+    void closeAllFiles() {
+      if (dataChannel != null) {
+        try {
+          dataChannel.close();
+        } catch (IOException ioe) {
+          logger.warn("Error closing data channel for {} shuffleId {} reduceId {}",
+            appShuffleId.appId, appShuffleId.shuffleId, reduceId, ioe);
+        } finally {
+          dataChannel = null;
+        }
+      }
+      if (metaFile != null) {
+        try {
+          // if the stream is closed, channel gets closed as well.
+          metaFile.close();
+        } catch (IOException ioe) {
+          logger.warn("Error closing meta file for {} shuffleId {} reduceId {}",
+            appShuffleId.appId, appShuffleId.shuffleId, reduceId, ioe);
+        } finally {
+          metaFile = null;
+        }
+      }
+      if (indexFile != null) {
+        try {
+          indexFile.close();
+        } catch (IOException ioe) {
+          logger.warn("Error closing index file for {} shuffleId {} reduceId {}",
+            appShuffleId.appId, appShuffleId.shuffleId, reduceId, ioe);
+        } finally {
+          indexFile = null;
+        }
+      }
+    }
+
+    @Override
+    protected void finalize() throws Throwable {
+      closeAllFiles();
+    }
+  }
+
+  /**
+   * Wraps all the information related to the merge directory of an application.
+   * The merge directories are derived from the block-manager local dirs reported by the first
+   * executor registration of the application.
+   */
+  private static class AppPathsInfo {
+
+    // Per-application merge directories (one per executor local dir).
+    private final String[] activeLocalDirs;
+    private final int subDirsPerLocalDir;
+
+    private AppPathsInfo(
+        String appId,
+        String[] localDirs,
+        int subDirsPerLocalDir) {
+      activeLocalDirs = Arrays.stream(localDirs)
+        .map(localDir ->
+          // Merge directory is created at the same level as block-manager directory. The list of
+          // local directories that we get from executorShuffleInfo are paths of each
+          // block-manager directory. To find out the merge directory location, we first find the
+          // parent dir and then append the MERGE_MANAGER_DIR directory to it.
+          Paths.get(localDir).getParent().resolve(MERGE_MANAGER_DIR).toFile().getPath())
+        .toArray(String[]::new);
+      this.subDirsPerLocalDir = subDirsPerLocalDir;
+      if (logger.isInfoEnabled()) {
+        logger.info("Updated active local dirs {} and sub dirs {} for application {}",
+          Arrays.toString(activeLocalDirs), subDirsPerLocalDir, appId);
+      }
+    }
+  }
+}
diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
index 9058575..8427837 100644
--- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
+++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/FinalizeShuffleMerge.java
@@ -25,6 +25,8 @@ import org.apache.spark.network.protocol.Encoders;
 /**
  * Request to finalize merge for a given shuffle.
* Returns {@link MergeStatuses} + * + * @since 3.1.0 */ public class FinalizeShuffleMerge extends BlockTransferMessage { public final String appId; diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java index f57e8b3..d506d9e 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/MergeStatuses.java @@ -32,6 +32,8 @@ import org.apache.spark.network.protocol.Encoders; * the set of mapper partition blocks that are merged for a given reducer partition, an array * of reducer IDs, and an array of merged shuffle partition sizes. The 3 arrays list information * about all the reducer partitions merged by the ExternalShuffleService in the same order. + * + * @since 3.1.0 */ public class MergeStatuses extends BlockTransferMessage { /** Shuffle ID **/ diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java index 7eab5a6..83fc7b2 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/protocol/PushBlockStream.java @@ -23,23 +23,27 @@ import io.netty.buffer.ByteBuf; import org.apache.spark.network.protocol.Encoders; // Needed by ScalaDoc. See SPARK-7726 -import static org.apache.spark.network.shuffle.protocol.BlockTransferMessage.Type; - /** * Request to push a block to a remote shuffle service to be merged in push based shuffle. * The remote shuffle service will also include this message when responding the push requests. 
+ * + * @since 3.1.0 */ public class PushBlockStream extends BlockTransferMessage { public final String appId; - public final String blockId; + public final int shuffleId; + public final int mapIndex; + public final int reduceId; // Similar to the chunkIndex in StreamChunkId, indicating the index of a block in a batch of // blocks to be pushed. public final int index; - public PushBlockStream(String appId, String blockId, int index) { + public PushBlockStream(String appId, int shuffleId, int mapIndex, int reduceId, int index) { this.appId = appId; - this.blockId = blockId; + this.shuffleId = shuffleId; + this.mapIndex = mapIndex; + this.reduceId = reduceId; this.index = index; } @@ -50,14 +54,16 @@ public class PushBlockStream extends BlockTransferMessage { @Override public int hashCode() { - return Objects.hashCode(appId, blockId, index); + return Objects.hashCode(appId, shuffleId, mapIndex , reduceId, index); } @Override public String toString() { return Objects.toStringHelper(this) .add("appId", appId) - .add("blockId", blockId) + .add("shuffleId", shuffleId) + .add("mapIndex", mapIndex) + .add("reduceId", reduceId) .add("index", index) .toString(); } @@ -67,7 +73,9 @@ public class PushBlockStream extends BlockTransferMessage { if (other != null && other instanceof PushBlockStream) { PushBlockStream o = (PushBlockStream) other; return Objects.equal(appId, o.appId) - && Objects.equal(blockId, o.blockId) + && shuffleId == o.shuffleId + && mapIndex == o.mapIndex + && reduceId == o.reduceId && index == o.index; } return false; @@ -75,21 +83,24 @@ public class PushBlockStream extends BlockTransferMessage { @Override public int encodedLength() { - return Encoders.Strings.encodedLength(appId) - + Encoders.Strings.encodedLength(blockId) + 4; + return Encoders.Strings.encodedLength(appId) + 16; } @Override public void encode(ByteBuf buf) { Encoders.Strings.encode(buf, appId); - Encoders.Strings.encode(buf, blockId); + buf.writeInt(shuffleId); + buf.writeInt(mapIndex); + 
buf.writeInt(reduceId); buf.writeInt(index); } public static PushBlockStream decode(ByteBuf buf) { String appId = Encoders.Strings.decode(buf); - String blockId = Encoders.Strings.decode(buf); + int shuffleId = buf.readInt(); + int mapIdx = buf.readInt(); + int reduceId = buf.readInt(); int index = buf.readInt(); - return new PushBlockStream(appId, blockId, index); + return new PushBlockStream(appId, shuffleId, mapIdx, reduceId, index); } } diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java index 680b8d7..f06e7cb 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/ExternalBlockHandlerSuite.java @@ -77,7 +77,7 @@ public class ExternalBlockHandlerSuite { ByteBuffer registerMessage = new RegisterExecutor("app0", "exec1", config).toByteBuffer(); handler.receive(client, registerMessage, callback); verify(blockResolver, times(1)).registerExecutor("app0", "exec1", config); - verify(mergedShuffleManager, times(1)).registerExecutor("app0", localDirs); + verify(mergedShuffleManager, times(1)).registerExecutor("app0", config); verify(callback, times(1)).onSuccess(any(ByteBuffer.class)); verify(callback, never()).onFailure(any(Throwable.class)); diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java index ebcdba7..46a0f6c 100644 --- a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/OneForOneBlockPusherSuite.java @@ -45,77 +45,77 @@ public class OneForOneBlockPusherSuite { @Test public void 
testPushOne() { LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); - blocks.put("shuffle_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); + blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[1]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", "shuffle_0_0_0", 0))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0))); - verify(listener).onBlockFetchSuccess(eq("shuffle_0_0_0"), any()); + verify(listener).onBlockFetchSuccess(eq("shufflePush_0_0_0"), any()); } @Test public void testPushThree() { LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); - blocks.put("b2", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); + blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[23]))); + blocks.put("shufflePush_0_2_0", new NettyManagedBuffer(Unpooled.wrappedBuffer(new byte[23]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", "b0", 0), - new PushBlockStream("app-id", "b1", 1), - new PushBlockStream("app-id", "b2", 2))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 2, 0, 2))); - for (int i = 0; i < 3; i ++) { - verify(listener, times(1)).onBlockFetchSuccess(eq("b" + i), any()); - } + verify(listener, times(1)).onBlockFetchSuccess(eq("shufflePush_0_0_0"), any()); + verify(listener, times(1)).onBlockFetchSuccess(eq("shufflePush_0_1_0"), any()); + verify(listener, 
times(1)).onBlockFetchSuccess(eq("shufflePush_0_2_0"), any()); } @Test public void testServerFailures() { LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); - blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_1_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + blocks.put("shufflePush_0_2_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", "b0", 0), - new PushBlockStream("app-id", "b1", 1), - new PushBlockStream("app-id", "b2", 2))); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 2, 0, 2))); - verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any()); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); - verify(listener, times(1)).onBlockFetchFailure(eq("b2"), any()); + verify(listener, times(1)).onBlockFetchSuccess(eq("shufflePush_0_0_0"), any()); + verify(listener, times(1)).onBlockFetchFailure(eq("shufflePush_0_1_0"), any()); + verify(listener, times(1)).onBlockFetchFailure(eq("shufflePush_0_2_0"), any()); } @Test public void testHandlingRetriableFailures() { LinkedHashMap<String, ManagedBuffer> blocks = Maps.newLinkedHashMap(); - blocks.put("b0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); - blocks.put("b1", null); - blocks.put("b2", new NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); + blocks.put("shufflePush_0_0_0", new NioManagedBuffer(ByteBuffer.wrap(new byte[12]))); + blocks.put("shufflePush_0_1_0", null); + blocks.put("shufflePush_0_2_0", new 
NioManagedBuffer(ByteBuffer.wrap(new byte[0]))); String[] blockIds = blocks.keySet().toArray(new String[blocks.size()]); BlockFetchingListener listener = pushBlocks( blocks, blockIds, - Arrays.asList(new PushBlockStream("app-id", "b0", 0), - new PushBlockStream("app-id", "b1", 1), - new PushBlockStream("app-id", "b2", 2))); - - verify(listener, times(1)).onBlockFetchSuccess(eq("b0"), any()); - verify(listener, times(0)).onBlockFetchSuccess(not(eq("b0")), any()); - verify(listener, times(0)).onBlockFetchFailure(eq("b0"), any()); - verify(listener, times(1)).onBlockFetchFailure(eq("b1"), any()); - verify(listener, times(2)).onBlockFetchFailure(eq("b2"), any()); + Arrays.asList(new PushBlockStream("app-id", 0, 0, 0, 0), + new PushBlockStream("app-id", 0, 1, 0, 1), + new PushBlockStream("app-id", 0, 2, 0, 2))); + + verify(listener, times(1)).onBlockFetchSuccess(eq("shufflePush_0_0_0"), any()); + verify(listener, times(0)).onBlockFetchSuccess(not(eq("shufflePush_0_0_0")), any()); + verify(listener, times(0)).onBlockFetchFailure(eq("shufflePush_0_0_0"), any()); + verify(listener, times(1)).onBlockFetchFailure(eq("shufflePush_0_1_0"), any()); + verify(listener, times(2)).onBlockFetchFailure(eq("shufflePush_0_2_0"), any()); } /** diff --git a/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java new file mode 100644 index 0000000..0f200dc --- /dev/null +++ b/common/network-shuffle/src/test/java/org/apache/spark/network/shuffle/RemoteBlockPushResolverSuite.java @@ -0,0 +1,496 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.network.shuffle; + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.concurrent.Semaphore; + +import com.google.common.base.Throwables; +import com.google.common.collect.ImmutableMap; + +import org.apache.commons.io.FileUtils; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.roaringbitmap.RoaringBitmap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import static org.junit.Assert.*; + +import org.apache.spark.network.buffer.FileSegmentManagedBuffer; +import org.apache.spark.network.client.StreamCallbackWithID; +import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo; +import org.apache.spark.network.shuffle.protocol.FinalizeShuffleMerge; +import org.apache.spark.network.shuffle.protocol.MergeStatuses; +import org.apache.spark.network.shuffle.protocol.PushBlockStream; +import org.apache.spark.network.util.MapConfigProvider; +import org.apache.spark.network.util.TransportConf; + +/** + * Tests for {@link RemoteBlockPushResolver}. 
+ */ +public class RemoteBlockPushResolverSuite { + + private static final Logger log = LoggerFactory.getLogger(RemoteBlockPushResolverSuite.class); + private final String TEST_APP = "testApp"; + private final String BLOCK_MANAGER_DIR = "blockmgr-193d8401"; + + private TransportConf conf; + private RemoteBlockPushResolver pushResolver; + private Path[] localDirs; + + @Before + public void before() throws IOException { + localDirs = createLocalDirs(2); + MapConfigProvider provider = new MapConfigProvider( + ImmutableMap.of("spark.shuffle.server.minChunkSizeInMergedShuffleFile", "4")); + conf = new TransportConf("shuffle", provider); + pushResolver = new RemoteBlockPushResolver(conf); + registerExecutor(TEST_APP, prepareLocalDirs(localDirs)); + } + + @After + public void after() { + try { + for (Path local : localDirs) { + FileUtils.deleteDirectory(local.toFile()); + } + removeApplication(TEST_APP); + } catch (Exception e) { + // don't fail if clean up doesn't succeed. + log.debug("Error while tearing down", e); + } + } + + @Test(expected = RuntimeException.class) + public void testNoIndexFile() { + try { + pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + } catch (Throwable t) { + assertTrue(t.getMessage().startsWith("Merged shuffle index file")); + Throwables.propagate(t); + } + } + + @Test + public void testBasicBlockMerge() throws IOException { + PushBlock[] pushBlocks = new PushBlock[] { + new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4])), + new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[5])) + }; + pushBlockHelper(TEST_APP, pushBlocks); + MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, 0)); + validateMergeStatuses(statuses, new int[] {0}, new long[] {9}); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 5}, new int[][]{{0}, {1}}); + } + + @Test + public void testDividingMergedBlocksIntoChunks() throws IOException { + 
PushBlock[] pushBlocks = new PushBlock[] { + new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])), + new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])), + new PushBlock(0, 2, 0, ByteBuffer.wrap(new byte[5])), + new PushBlock(0, 3, 0, ByteBuffer.wrap(new byte[3])) + }; + pushBlockHelper(TEST_APP, pushBlocks); + MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, 0)); + validateMergeStatuses(statuses, new int[] {0}, new long[] {13}); + MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, meta, new int[]{5, 5, 3}, new int[][]{{0, 1}, {2}, {3}}); + } + + @Test + public void testFinalizeWithMultipleReducePartitions() throws IOException { + PushBlock[] pushBlocks = new PushBlock[] { + new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[2])), + new PushBlock(0, 1, 0, ByteBuffer.wrap(new byte[3])), + new PushBlock(0, 0, 1, ByteBuffer.wrap(new byte[5])), + new PushBlock(0, 1, 1, ByteBuffer.wrap(new byte[3])) + }; + pushBlockHelper(TEST_APP, pushBlocks); + MergeStatuses statuses = pushResolver.finalizeShuffleMerge( + new FinalizeShuffleMerge(TEST_APP, 0)); + validateMergeStatuses(statuses, new int[] {0, 1}, new long[] {5, 8}); + MergedBlockMeta meta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, meta, new int[]{5}, new int[][]{{0, 1}}); + } + + @Test + public void testDeferredBufsAreWrittenDuringOnData() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + // This should be deferred + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); + // stream 1 now completes + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onComplete(stream1.getID()); + // 
stream 2 has more data and then completes + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); + } + + @Test + public void testDeferredBufsAreWrittenDuringOnComplete() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + // This should be deferred + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[3])); + // stream 1 now completes + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onComplete(stream1.getID()); + // stream 2 now completes + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4, 6}, new int[][]{{0}, {1}}); + } + + @Test + public void testDuplicateBlocksAreIgnoredWhenPrevStreamHasCompleted() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onComplete(stream1.getID()); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + // This should be ignored + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); +
stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + } + + @Test + public void testDuplicateBlocksAreIgnoredWhenPrevStreamIsInProgress() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + // This should be ignored + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[2])); + // stream 1 now completes + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onComplete(stream1.getID()); + // stream 2 now completes + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + } + + @Test + public void testFailureAfterData() throws IOException { + StreamCallbackWithID stream = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); + stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + assertEquals("num-chunks", 0, blockMeta.getNumChunks()); + } + + @Test + public void testFailureAfterMultipleDataBlocks() throws IOException { + StreamCallbackWithID stream = +
pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2])); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3])); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); + stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + assertEquals("num-chunks", 0, blockMeta.getNumChunks()); + } + + @Test + public void testFailureAfterComplete() throws IOException { + StreamCallbackWithID stream = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[2])); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[3])); + stream.onData(stream.getID(), ByteBuffer.wrap(new byte[4])); + stream.onComplete(stream.getID()); + stream.onFailure(stream.getID(), new RuntimeException("Forced Failure")); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); + } + + @Test (expected = RuntimeException.class) + public void testTooLateArrival() throws IOException { + ByteBuffer[] blocks = new ByteBuffer[]{ + ByteBuffer.wrap(new byte[4]), + ByteBuffer.wrap(new byte[5]) + }; + StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + for (ByteBuffer block : blocks) { + stream.onData(stream.getID(), block); + } + stream.onComplete(stream.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + StreamCallbackWithID stream1 = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); + 
try { + stream1.onComplete(stream1.getID()); + } catch (RuntimeException re) { + assertEquals( + "Block shufflePush_0_1_0 received after merged shuffle is finalized", + re.getMessage()); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{9}, new int[][]{{0}}); + throw re; + } + } + + @Test + public void testIncompleteStreamsAreOverwritten() throws IOException { + registerExecutor(TEST_APP, prepareLocalDirs(localDirs)); + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[4])); + // There is a failure + stream1.onFailure(stream1.getID(), new RuntimeException("forced error")); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); + stream2.onComplete(stream2.getID()); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[]{5}, new int[][]{{1}}); + } + + @Test (expected = RuntimeException.class) + public void testCollision() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + // This should be deferred + stream2.onData(stream2.getID(), ByteBuffer.wrap(new byte[5])); + // Since stream2 didn't get any opportunity it will throw couldn't find opportunity error + try { + stream2.onComplete(stream2.getID()); + } catch (RuntimeException re) { + assertEquals( + "Couldn't find an opportunity to write block shufflePush_0_1_0 to merged 
shuffle", + re.getMessage()); + throw re; + } + } + + @Test (expected = RuntimeException.class) + public void testFailureInAStreamDoesNotInterfereWithStreamWhichIsWriting() throws IOException { + StreamCallbackWithID stream1 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 0, 0, 0)); + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + StreamCallbackWithID stream2 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 1, 0, 0)); + // There is a failure with stream2 + stream2.onFailure(stream2.getID(), new RuntimeException("forced error")); + StreamCallbackWithID stream3 = + pushResolver.receiveBlockDataAsStream(new PushBlockStream(TEST_APP, 0, 2, 0, 0)); + // This should be deferred + stream3.onData(stream3.getID(), ByteBuffer.wrap(new byte[5])); + // Since this stream didn't get any opportunity it will throw couldn't find opportunity error + RuntimeException failedEx = null; + try { + stream3.onComplete(stream3.getID()); + } catch (RuntimeException re) { + assertEquals( + "Couldn't find an opportunity to write block shufflePush_0_2_0 to merged shuffle", + re.getMessage()); + failedEx = re; + } + // stream 1 now completes + stream1.onData(stream1.getID(), ByteBuffer.wrap(new byte[2])); + stream1.onComplete(stream1.getID()); + + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(TEST_APP, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(TEST_APP, 0, 0); + validateChunks(TEST_APP, 0, 0, blockMeta, new int[] {4}, new int[][] {{0}}); + if (failedEx != null) { + throw failedEx; + } + } + + @Test(expected = NullPointerException.class) + public void testUpdateLocalDirsOnlyOnce() throws IOException { + String testApp = "updateLocalDirsOnlyOnceTest"; + Path[] activeLocalDirs = createLocalDirs(1); + registerExecutor(testApp, prepareLocalDirs(activeLocalDirs)); + assertEquals(pushResolver.getMergedBlockDirs(testApp).length, 1); + 
assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains( + activeLocalDirs[0].toFile().getPath())); + // Any later executor register from the same application should not change the active local + // dirs list + Path[] updatedLocalDirs = localDirs; + registerExecutor(testApp, prepareLocalDirs(updatedLocalDirs)); + assertEquals(pushResolver.getMergedBlockDirs(testApp).length, 1); + assertTrue(pushResolver.getMergedBlockDirs(testApp)[0].contains( + activeLocalDirs[0].toFile().getPath())); + removeApplication(testApp); + try { + pushResolver.getMergedBlockDirs(testApp); + } catch (Throwable e) { + assertTrue(e.getMessage() + .startsWith("application " + testApp + " is not registered or NM was restarted.")); + Throwables.propagate(e); + } + } + + @Test + public void testCleanUpDirectory() throws IOException, InterruptedException { + String testApp = "cleanUpDirectory"; + Semaphore deleted = new Semaphore(0); + pushResolver = new RemoteBlockPushResolver(conf) { + @Override + void deleteExecutorDirs(Path[] dirs) { + super.deleteExecutorDirs(dirs); + deleted.release(); + } + }; + Path[] activeDirs = createLocalDirs(1); + registerExecutor(testApp, prepareLocalDirs(activeDirs)); + PushBlock[] pushBlocks = new PushBlock[] { + new PushBlock(0, 0, 0, ByteBuffer.wrap(new byte[4]))}; + pushBlockHelper(testApp, pushBlocks); + pushResolver.finalizeShuffleMerge(new FinalizeShuffleMerge(testApp, 0)); + MergedBlockMeta blockMeta = pushResolver.getMergedBlockMeta(testApp, 0, 0); + validateChunks(testApp, 0, 0, blockMeta, new int[]{4}, new int[][]{{0}}); + String[] mergeDirs = pushResolver.getMergedBlockDirs(testApp); + pushResolver.applicationRemoved(testApp, true); + // Since the cleanup happens in a different thread, block until the overridden + // deleteExecutorDirs signals (via the semaphore) that the merge dirs have been deleted.
+ deleted.acquire(); + for (String mergeDir : mergeDirs) { + Assert.assertFalse(Files.exists(Paths.get(mergeDir))); + } + } + + private Path[] createLocalDirs(int numLocalDirs) throws IOException { + Path[] localDirs = new Path[numLocalDirs]; + for (int i = 0; i < localDirs.length; i++) { + localDirs[i] = Files.createTempDirectory("shuffleMerge"); + localDirs[i].toFile().deleteOnExit(); + } + return localDirs; + } + + private void registerExecutor(String appId, String[] localDirs) throws IOException { + ExecutorShuffleInfo shuffleInfo = new ExecutorShuffleInfo(localDirs, 1, "mergedShuffle"); + pushResolver.registerExecutor(appId, shuffleInfo); + } + + private String[] prepareLocalDirs(Path[] localDirs) throws IOException { + String[] blockMgrDirs = new String[localDirs.length]; + for (int i = 0; i< localDirs.length; i++) { + Files.createDirectories(localDirs[i].resolve( + RemoteBlockPushResolver.MERGE_MANAGER_DIR + File.separator + "00")); + blockMgrDirs[i] = localDirs[i].toFile().getPath() + File.separator + BLOCK_MANAGER_DIR; + } + return blockMgrDirs; + } + + private void removeApplication(String appId) { + // PushResolver cleans up the local dirs in a different thread which can conflict with the test + // data of other tests, since they are using the same Application Id. 
+ pushResolver.applicationRemoved(appId, false); + } + + private void validateMergeStatuses( + MergeStatuses mergeStatuses, + int[] expectedReduceIds, + long[] expectedSizes) { + assertArrayEquals(expectedReduceIds, mergeStatuses.reduceIds); + assertArrayEquals(expectedSizes, mergeStatuses.sizes); + } + + private void validateChunks( + String appId, + int shuffleId, + int reduceId, + MergedBlockMeta meta, + int[] expectedSizes, + int[][] expectedMapsPerChunk) throws IOException { + assertEquals("num chunks", expectedSizes.length, meta.getNumChunks()); + RoaringBitmap[] bitmaps = meta.readChunkBitmaps(); + assertEquals("num of bitmaps", meta.getNumChunks(), bitmaps.length); + for (int i = 0; i < meta.getNumChunks(); i++) { + RoaringBitmap chunkBitmap = bitmaps[i]; + Arrays.stream(expectedMapsPerChunk[i]).forEach(x -> assertTrue(chunkBitmap.contains(x))); + } + for (int i = 0; i < meta.getNumChunks(); i++) { + FileSegmentManagedBuffer mb = + (FileSegmentManagedBuffer) pushResolver.getMergedBlockData(appId, shuffleId, reduceId, i); + assertEquals(expectedSizes[i], mb.getLength()); + } + } + + private void pushBlockHelper( + String appId, + PushBlock[] blocks) throws IOException { + for (int i = 0; i < blocks.length; i++) { + StreamCallbackWithID stream = pushResolver.receiveBlockDataAsStream( + new PushBlockStream(appId, blocks[i].shuffleId, blocks[i].mapIndex, blocks[i].reduceId, 0)); + stream.onData(stream.getID(), blocks[i].buffer); + stream.onComplete(stream.getID()); + } + } + + private static class PushBlock { + private final int shuffleId; + private final int mapIndex; + private final int reduceId; + private final ByteBuffer buffer; + PushBlock(int shuffleId, int mapIndex, int reduceId, ByteBuffer buffer) { + this.shuffleId = shuffleId; + this.mapIndex = mapIndex; + this.reduceId = reduceId; + this.buffer = buffer; + } + } +} diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java 
b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index 3d14318..548a5cc 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -41,6 +41,7 @@ import org.apache.hadoop.metrics2.impl.MetricsSystemImpl; import org.apache.hadoop.metrics2.lib.DefaultMetricsSystem; import org.apache.hadoop.yarn.api.records.ContainerId; import org.apache.hadoop.yarn.server.api.*; +import org.apache.spark.network.shuffle.MergedShuffleFileManager; import org.apache.spark.network.util.LevelDBProvider; import org.iq80.leveldb.DB; import org.iq80.leveldb.DBIterator; @@ -172,7 +173,10 @@ public class YarnShuffleService extends AuxiliaryService { } TransportConf transportConf = new TransportConf("shuffle", new HadoopConfigProvider(conf)); - blockHandler = new ExternalBlockHandler(transportConf, registeredExecutorFile); + MergedShuffleFileManager shuffleMergeManager = newMergedShuffleFileManagerInstance( + transportConf); + blockHandler = new ExternalBlockHandler( + transportConf, registeredExecutorFile, shuffleMergeManager); // If authentication is enabled, set up the shuffle server to use a // special RPC handler that filters out unauthenticated fetch requests @@ -219,6 +223,23 @@ public class YarnShuffleService extends AuxiliaryService { } } + @VisibleForTesting + static MergedShuffleFileManager newMergedShuffleFileManagerInstance(TransportConf conf) { + String mergeManagerImplClassName = conf.mergedShuffleFileManagerImpl(); + try { + Class<?> mergeManagerImplClazz = Class.forName( + mergeManagerImplClassName, true, Thread.currentThread().getContextClassLoader()); + Class<? 
extends MergedShuffleFileManager> mergeManagerSubClazz = + mergeManagerImplClazz.asSubclass(MergedShuffleFileManager.class); + // The assumption is that all the custom implementations just like the RemoteBlockPushResolver + // will also need the transport configuration. + return mergeManagerSubClazz.getConstructor(TransportConf.class).newInstance(conf); + } catch (Exception e) { + logger.error("Unable to create an instance of {}", mergeManagerImplClassName); + return new ExternalBlockHandler.NoOpMergedShuffleFileManager(conf); + } + } + private void loadSecretsFromDb() throws IOException { secretsFile = initRecoveryDb(SECRETS_RECOVERY_FILE_NAME); diff --git a/common/network-yarn/src/test/java/org/apache/spark/network/yarn/YarnShuffleServiceSuite.java b/common/network-yarn/src/test/java/org/apache/spark/network/yarn/YarnShuffleServiceSuite.java new file mode 100644 index 0000000..09bc4d8 --- /dev/null +++ b/common/network-yarn/src/test/java/org/apache/spark/network/yarn/YarnShuffleServiceSuite.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.network.yarn; + +import org.junit.Test; + +import static org.junit.Assert.*; +import static org.mockito.Mockito.*; + +import org.apache.spark.network.shuffle.ExternalBlockHandler; +import org.apache.spark.network.shuffle.MergedShuffleFileManager; +import org.apache.spark.network.shuffle.RemoteBlockPushResolver; +import org.apache.spark.network.util.TransportConf; + +public class YarnShuffleServiceSuite { + + @Test + public void testCreateDefaultMergedShuffleFileManagerInstance() { + TransportConf mockConf = mock(TransportConf.class); + when(mockConf.mergedShuffleFileManagerImpl()).thenReturn( + "org.apache.spark.network.shuffle.ExternalBlockHandler$NoOpMergedShuffleFileManager"); + MergedShuffleFileManager mergeMgr = YarnShuffleService.newMergedShuffleFileManagerInstance( + mockConf); + assertTrue(mergeMgr instanceof ExternalBlockHandler.NoOpMergedShuffleFileManager); + } + + @Test + public void testCreateRemoteBlockPushResolverInstance() { + TransportConf mockConf = mock(TransportConf.class); + when(mockConf.mergedShuffleFileManagerImpl()).thenReturn( + "org.apache.spark.network.shuffle.RemoteBlockPushResolver"); + MergedShuffleFileManager mergeMgr = YarnShuffleService.newMergedShuffleFileManagerInstance( + mockConf); + assertTrue(mergeMgr instanceof RemoteBlockPushResolver); + } + + @Test + public void testInvalidClassNameOfMergeManagerWillUseNoOpInstance() { + TransportConf mockConf = mock(TransportConf.class); + when(mockConf.mergedShuffleFileManagerImpl()).thenReturn( + "org.apache.spark.network.shuffle.NotExistent"); + MergedShuffleFileManager mergeMgr = YarnShuffleService.newMergedShuffleFileManagerInstance( + mockConf); + assertTrue(mergeMgr instanceof ExternalBlockHandler.NoOpMergedShuffleFileManager); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org