This is an automated email from the ASF dual-hosted git repository.

vanzin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new abef84a  [SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API
abef84a is described below

commit abef84a868e9e15f346eea315bbab0ec8ac8e389
Author: mcheah <mch...@palantir.com>
AuthorDate: Tue Jul 30 14:17:30 2019 -0700

[SPARK-28209][CORE][SHUFFLE] Proposed new shuffle writer API

## What changes were proposed in this pull request?

As part of the shuffle storage API proposed in SPARK-25299, this introduces an API for persisting shuffle data in arbitrary storage systems. This patch introduces several concepts:

* `ShuffleDataIO`, which is the root of the entire plugin tree that will be proposed over the course of the shuffle API project.
* `ShuffleExecutorComponents` - the subset of plugins for managing shuffle-related components for each executor. This will in turn instantiate shuffle readers and writers.
* `ShuffleMapOutputWriter` interface - instantiated once per map task. This provides child `ShufflePartitionWriter` instances for persisting the bytes for each partition in the map task.

The default implementation of these plugins exactly mirrors what was done by the existing shuffle writing code - namely, writing the data to local disk and writing an index file. For now, only `BypassMergeSortShuffleWriter` uses the new APIs. Follow-up PRs will adopt them in `SortShuffleWriter` and `UnsafeShuffleWriter`, but those changes are left as future work to minimize the review surface area.

## How was this patch tested?

New unit tests were added. Micro-benchmarks indicate there's no slowdown in the affected code paths.

Closes #25007 from mccheah/spark-shuffle-writer-refactor.

Lead-authored-by: mcheah <mch...@palantir.com>
Co-authored-by: mccheah <mch...@palantir.com>
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>
---
 .../apache/spark/shuffle/api/ShuffleDataIO.java    |  49 ++++
 .../shuffle/api/ShuffleExecutorComponents.java     |  55 +++++
 .../spark/shuffle/api/ShuffleMapOutputWriter.java  |  71 ++++++
 .../spark/shuffle/api/ShufflePartitionWriter.java  |  98 ++++++++
 .../shuffle/api/WritableByteChannelWrapper.java    |  42 ++++
 .../shuffle/sort/BypassMergeSortShuffleWriter.java | 173 +++++++++-----
 .../shuffle/sort/io/LocalDiskShuffleDataIO.java    |  40 ++++
 .../io/LocalDiskShuffleExecutorComponents.java     |  71 ++++++
 .../sort/io/LocalDiskShuffleMapOutputWriter.java   | 261 +++++++++++++++++++++
 .../org/apache/spark/internal/config/package.scala |   7 +
 .../spark/shuffle/sort/SortShuffleManager.scala    |  25 +-
 .../main/scala/org/apache/spark/util/Utils.scala   |  30 ++-
 .../test/scala/org/apache/spark/ShuffleSuite.scala |  16 +-
 .../sort/BypassMergeSortShuffleWriterSuite.scala   | 149 +++++-----
 .../io/LocalDiskShuffleMapOutputWriterSuite.scala  | 147 ++++++++++++
 15 files changed, 1087 insertions(+), 147 deletions(-)

diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java new file mode 100644 index 0000000..e9e50ec --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleDataIO.java @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for plugging in modules for storing and reading temporary shuffle data. + * <p> + * This is the root of a plugin system for storing shuffle bytes to arbitrary storage + * backends in the sort-based shuffle algorithm implemented by the + * {@link org.apache.spark.shuffle.sort.SortShuffleManager}. If another shuffle algorithm is + * needed instead of sort-based shuffle, one should implement + * {@link org.apache.spark.shuffle.ShuffleManager} instead. + * <p> + * A single instance of this module is loaded per process in the Spark application. + * The default implementation reads and writes shuffle data from the local disks of + * the executor, and is the implementation of shuffle file storage that has remained + * consistent throughout most of Spark's history. + * <p> + * Alternative implementations of shuffle data storage can be loaded via setting + * <code>spark.shuffle.sort.io.plugin.class</code>. + * @since 3.0.0 + */ +@Private +public interface ShuffleDataIO { + + /** + * Called once on executor processes to bootstrap the shuffle data storage modules that + * are only invoked on the executors. + */ + ShuffleExecutorComponents executor(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java new file mode 100644 index 0000000..70c112b --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleExecutorComponents.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for building shuffle support for Executors. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleExecutorComponents { + + /** + * Called once per executor to bootstrap this module with state that is specific to + * that executor, specifically the application ID and executor ID. + */ + void initializeExecutor(String appId, String execId); + + /** + * Called once per map task to create a writer that will be responsible for persisting all the + * partitioned bytes written by that map task. 
+ * @param shuffleId Unique identifier for the shuffle the map task is a part of + * @param mapId Within the shuffle, the identifier of the map task + * @param mapTaskAttemptId Identifier of the task attempt. Multiple attempts of the same map task + * with the same (shuffleId, mapId) pair can be distinguished by the + * different values of mapTaskAttemptId. + * @param numPartitions The number of partitions that will be written by the map task. Some of +* these partitions may be empty. + */ + ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java new file mode 100644 index 0000000..45a593c --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShuffleMapOutputWriter.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * A top-level writer that returns child writers for persisting the output of a map task, + * and then commits all of the writes as one atomic operation. + * + * @since 3.0.0 + */ +@Private +public interface ShuffleMapOutputWriter { + + /** + * Creates a writer that can open an output stream to persist bytes targeted for a given reduce + * partition id. + * <p> + * The chunk corresponds to bytes in the given reduce partition. This will not be called twice + * for the same partition within any given map task. The partition identifier will be in the + * range of precisely 0 (inclusive) to numPartitions (exclusive), where numPartitions was + * provided upon the creation of this map output writer via + * {@link ShuffleExecutorComponents#createMapOutputWriter(int, int, long, int)}. + * <p> + * Calls to this method will be invoked with monotonically increasing reducePartitionIds; each + * call to this method will be called with a reducePartitionId that is strictly greater than + * the reducePartitionIds given to any previous call to this method. This method is not + * guaranteed to be called for every partition id in the above described range. In particular, + * no guarantees are made as to whether or not this method will be called for empty partitions. + */ + ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException; + + /** + * Commits the writes done by all partition writers returned by all calls to this object's + * {@link #getPartitionWriter(int)}. 
+ * <p> + * This should ensure that the writes conducted by this module's partition writers are + * available to downstream reduce tasks. If this method throws any exception, this module's + * {@link #abort(Throwable)} method will be invoked before propagating the exception. + * <p> + * This can also close any resources and clean up temporary state if necessary. + */ + void commitAllPartitions() throws IOException; + + /** + * Abort all of the writes done by any writers returned by {@link #getPartitionWriter(int)}. + * <p> + * This should invalidate the results of writing bytes. This can also close any resources and + * clean up temporary state if necessary. + */ + void abort(Throwable error) throws IOException; +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java new file mode 100644 index 0000000..9288751 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/ShufflePartitionWriter.java @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.api; + +import java.io.IOException; +import java.util.Optional; +import java.io.OutputStream; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * An interface for opening streams to persist partition bytes to a backing data store. + * <p> + * This writer stores bytes for one (mapper, reducer) pair, corresponding to one shuffle + * block. + * + * @since 3.0.0 + */ +@Private +public interface ShufflePartitionWriter { + + /** + * Open and return an {@link OutputStream} that can write bytes to the underlying + * data store. + * <p> + * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The output stream will only be used to write the bytes for this + * partition. The map task closes this output stream upon writing all the bytes for this + * block, or if the write fails for any reason. + * <p> + * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same OutputStream instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link OutputStream#close()} does not close the resource, since it will be reused across + * partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + OutputStream openStream() throws IOException; + + /** + * Opens and returns a {@link WritableByteChannelWrapper} for transferring bytes from + * input byte channels to the underlying shuffle data store. 
+ * <p> + * This method will only be called once on this partition writer in the map task, to write the + * bytes to the partition. The channel will only be used to write the bytes for this + * partition. The map task closes this channel upon writing all the bytes for this + * block, or if the write fails for any reason. + * <p> + * Implementations that intend on combining the bytes for all the partitions written by this + * map task should reuse the same channel instance across all the partition writers provided + * by the parent {@link ShuffleMapOutputWriter}. If one does so, ensure that + * {@link WritableByteChannelWrapper#close()} does not close the resource, since the channel + * will be reused across partition writes. The underlying resources should be cleaned up in + * {@link ShuffleMapOutputWriter#commitAllPartitions()} and + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + * <p> + * This method is primarily for advanced optimizations where bytes can be copied from the input + * spill files to the output channel without copying data into memory. If such optimizations are + * not supported, the implementation should return {@link Optional#empty()}. By default, the + * implementation returns {@link Optional#empty()}. + * <p> + * Note that the returned {@link WritableByteChannelWrapper} itself is closed, but not the + * underlying channel that is returned by {@link WritableByteChannelWrapper#channel()}. Ensure + * that the underlying channel is cleaned up in {@link WritableByteChannelWrapper#close()}, + * {@link ShuffleMapOutputWriter#commitAllPartitions()}, or + * {@link ShuffleMapOutputWriter#abort(Throwable)}. + */ + default Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException { + return Optional.empty(); + } + + /** + * Returns the number of bytes written either by this writer's output stream opened by + * {@link #openStream()} or the byte channel opened by {@link #openChannelWrapper()}. + * <p> + * This can be different from the number of bytes given by the caller. For example, the + * stream might compress or encrypt the bytes before persisting the data to the backing + * data store. + */ + long getNumBytesWritten(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java new file mode 100644 index 0000000..a204903 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/api/WritableByteChannelWrapper.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.api; + +import java.io.Closeable; +import java.nio.channels.WritableByteChannel; + +import org.apache.spark.annotation.Private; + +/** + * :: Private :: + * + * A thin wrapper around a {@link WritableByteChannel}. + * <p> + * This is primarily provided for the local disk shuffle implementation to provide a + * {@link java.nio.channels.FileChannel} that keeps the channel open across partition writes. + * + * @since 3.0.0 + */ +@Private +public interface WritableByteChannelWrapper extends Closeable { + + /** + * The underlying channel to write bytes into. + */ + WritableByteChannel channel(); +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 32b4467..3ccee70 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -19,8 +19,10 @@ package org.apache.spark.shuffle.sort; import java.io.File; import java.io.FileInputStream; -import java.io.FileOutputStream; import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.util.Optional; import javax.annotation.Nullable; import scala.None$; @@ -34,16 +36,19 @@ import com.google.common.io.Closeables; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import org.apache.spark.internal.config.package$; import org.apache.spark.Partitioner; import org.apache.spark.ShuffleDependency; import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.internal.config.package$; import org.apache.spark.scheduler.MapStatus; import org.apache.spark.scheduler.MapStatus$; import org.apache.spark.serializer.Serializer; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.shuffle.ShuffleWriteMetricsReporter; -import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.shuffle.ShuffleWriter; import org.apache.spark.storage.*; import org.apache.spark.util.Utils; @@ -81,8 +86,9 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { private final ShuffleWriteMetricsReporter writeMetrics; private final int shuffleId; private final int mapId; + private final long mapTaskAttemptId; private final Serializer serializer; - private final IndexShuffleBlockResolver shuffleBlockResolver; + private final ShuffleExecutorComponents shuffleExecutorComponents; /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; @@ -99,74 +105,82 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { BypassMergeSortShuffleWriter( BlockManager blockManager, - IndexShuffleBlockResolver shuffleBlockResolver, BypassMergeSortShuffleHandle<K, V> handle, int mapId, + long mapTaskAttemptId, SparkConf conf, - ShuffleWriteMetricsReporter writeMetrics) { + ShuffleWriteMetricsReporter writeMetrics, + ShuffleExecutorComponents shuffleExecutorComponents) { // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided this.fileBufferSize = (int) (long) conf.get(package$.MODULE$.SHUFFLE_FILE_BUFFER_SIZE()) * 1024; this.transferToEnabled = 
conf.getBoolean("spark.file.transferTo", true); this.blockManager = blockManager; final ShuffleDependency<K, V, V> dep = handle.dependency(); this.mapId = mapId; + this.mapTaskAttemptId = mapTaskAttemptId; this.shuffleId = dep.shuffleId(); this.partitioner = dep.partitioner(); this.numPartitions = partitioner.numPartitions(); this.writeMetrics = writeMetrics; this.serializer = dep.serializer(); - this.shuffleBlockResolver = shuffleBlockResolver; + this.shuffleExecutorComponents = shuffleExecutorComponents; } @Override public void write(Iterator<Product2<K, V>> records) throws IOException { assert (partitionWriters == null); - if (!records.hasNext()) { - partitionLengths = new long[numPartitions]; - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null); - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); - return; - } - final SerializerInstance serInstance = serializer.newInstance(); - final long openStartTime = System.nanoTime(); - partitionWriters = new DiskBlockObjectWriter[numPartitions]; - partitionWriterSegments = new FileSegment[numPartitions]; - for (int i = 0; i < numPartitions; i++) { - final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile = - blockManager.diskBlockManager().createTempShuffleBlock(); - final File file = tempShuffleBlockIdPlusFile._2(); - final BlockId blockId = tempShuffleBlockIdPlusFile._1(); - partitionWriters[i] = - blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); - } - // Creating the file to write to and creating a disk writer both involve interacting with - // the disk, and can take a long time in aggregate when we open many files, so should be - // included in the shuffle write time. - writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - - while (records.hasNext()) { - final Product2<K, V> record = records.next(); - final K key = record._1(); - partitionWriters[partitioner.getPartition(key)].write(key, record._2()); - } + ShuffleMapOutputWriter mapOutputWriter = shuffleExecutorComponents + .createMapOutputWriter(shuffleId, mapId, mapTaskAttemptId, numPartitions); + try { + if (!records.hasNext()) { + partitionLengths = new long[numPartitions]; + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply( + blockManager.shuffleServerId(), + partitionLengths); + return; + } + final SerializerInstance serInstance = serializer.newInstance(); + final long openStartTime = System.nanoTime(); + partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; + for (int i = 0; i < numPartitions; i++) { + final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile = + blockManager.diskBlockManager().createTempShuffleBlock(); + final File file = tempShuffleBlockIdPlusFile._2(); + final BlockId blockId = tempShuffleBlockIdPlusFile._1(); + partitionWriters[i] = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, writeMetrics); + } + // Creating the file to write to and creating a disk writer both involve interacting with + // the disk, and can take a long time in aggregate when we open many files, so should be + // included in the shuffle write time. 
+ writeMetrics.incWriteTime(System.nanoTime() - openStartTime); - for (int i = 0; i < numPartitions; i++) { - try (DiskBlockObjectWriter writer = partitionWriters[i]) { - partitionWriterSegments[i] = writer.commitAndGet(); + while (records.hasNext()) { + final Product2<K, V> record = records.next(); + final K key = record._1(); + partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - } - File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); - File tmp = Utils.tempFileWith(output); - try { - partitionLengths = writePartitionedFile(tmp); - shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp); - } finally { - if (tmp.exists() && !tmp.delete()) { - logger.error("Error while deleting temp file {}", tmp.getAbsolutePath()); + for (int i = 0; i < numPartitions; i++) { + try (DiskBlockObjectWriter writer = partitionWriters[i]) { + partitionWriterSegments[i] = writer.commitAndGet(); + } + } + + partitionLengths = writePartitionedData(mapOutputWriter); + mapOutputWriter.commitAllPartitions(); + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); + } catch (Exception e) { + try { + mapOutputWriter.abort(e); + } catch (Exception e2) { + logger.error("Failed to abort the writer after failing to write map output.", e2); + e.addSuppressed(e2); } + throw e; } - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); } @VisibleForTesting @@ -179,43 +193,80 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { * * @return array of lengths, in bytes, of each partition of the file (used by map output tracker). */ - private long[] writePartitionedFile(File outputFile) throws IOException { + private long[] writePartitionedData(ShuffleMapOutputWriter mapOutputWriter) throws IOException { // Track location of the partition starts in the output file final long[] lengths = new long[numPartitions]; if (partitionWriters == null) { // We were passed an empty iterator return lengths; } - - final FileOutputStream out = new FileOutputStream(outputFile, true); final long writeStartTime = System.nanoTime(); - boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { final File file = partitionWriterSegments[i].file(); + ShufflePartitionWriter writer = mapOutputWriter.getPartitionWriter(i); if (file.exists()) { - final FileInputStream in = new FileInputStream(file); - boolean copyThrewException = true; - try { - lengths[i] = Utils.copyStream(in, out, false, transferToEnabled); - copyThrewException = false; - } finally { - Closeables.close(in, copyThrewException); + if (transferToEnabled) { + // Using WritableByteChannelWrapper to make resource closing consistent between + // this implementation and UnsafeShuffleWriter. 
+ Optional<WritableByteChannelWrapper> maybeOutputChannel = writer.openChannelWrapper(); + if (maybeOutputChannel.isPresent()) { + writePartitionedDataWithChannel(file, maybeOutputChannel.get()); + } else { + writePartitionedDataWithStream(file, writer); + } + } else { + writePartitionedDataWithStream(file, writer); } if (!file.delete()) { logger.error("Unable to delete file for partition {}", i); } } + lengths[i] = writer.getNumBytesWritten(); } - threwException = false; } finally { - Closeables.close(out, threwException); writeMetrics.incWriteTime(System.nanoTime() - writeStartTime); } partitionWriters = null; return lengths; } + private void writePartitionedDataWithChannel( + File file, + WritableByteChannelWrapper outputChannel) throws IOException { + boolean copyThrewException = true; + try { + FileInputStream in = new FileInputStream(file); + try (FileChannel inputChannel = in.getChannel()) { + Utils.copyFileStreamNIO( + inputChannel, outputChannel.channel(), 0L, inputChannel.size()); + copyThrewException = false; + } finally { + Closeables.close(in, copyThrewException); + } + } finally { + Closeables.close(outputChannel, copyThrewException); + } + } + + private void writePartitionedDataWithStream(File file, ShufflePartitionWriter writer) + throws IOException { + boolean copyThrewException = true; + FileInputStream in = new FileInputStream(file); + OutputStream outputStream; + try { + outputStream = writer.openStream(); + try { + Utils.copyStream(in, outputStream, false, false); + copyThrewException = false; + } finally { + Closeables.close(outputStream, copyThrewException); + } + } finally { + Closeables.close(in, copyThrewException); + } + } + @Override public Option<MapStatus> stop(boolean success) { if (stopping) { diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java new file mode 100644 index 0000000..cabcb17 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleDataIO.java @@ -0,0 +1,40 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleDataIO; + +/** + * Implementation of the {@link ShuffleDataIO} plugin system that replicates the local shuffle + * storage and index file functionality that has historically been used from Spark 2.4 and earlier. 
+ */ +public class LocalDiskShuffleDataIO implements ShuffleDataIO { + + private final SparkConf sparkConf; + + public LocalDiskShuffleDataIO(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @Override + public ShuffleExecutorComponents executor() { + return new LocalDiskShuffleExecutorComponents(sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java new file mode 100644 index 0000000..02eb710 --- /dev/null +++ b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleExecutorComponents.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import com.google.common.annotations.VisibleForTesting; + +import org.apache.spark.SparkConf; +import org.apache.spark.SparkEnv; +import org.apache.spark.shuffle.api.ShuffleExecutorComponents; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.storage.BlockManager; + +public class LocalDiskShuffleExecutorComponents implements ShuffleExecutorComponents { + + private final SparkConf sparkConf; + private BlockManager blockManager; + private IndexShuffleBlockResolver blockResolver; + + public LocalDiskShuffleExecutorComponents(SparkConf sparkConf) { + this.sparkConf = sparkConf; + } + + @VisibleForTesting + public LocalDiskShuffleExecutorComponents( + SparkConf sparkConf, + BlockManager blockManager, + IndexShuffleBlockResolver blockResolver) { + this.sparkConf = sparkConf; + this.blockManager = blockManager; + this.blockResolver = blockResolver; + } + + @Override + public void initializeExecutor(String appId, String execId) { + blockManager = SparkEnv.get().blockManager(); + if (blockManager == null) { + throw new IllegalStateException("No blockManager available from the SparkEnv."); + } + blockResolver = new IndexShuffleBlockResolver(sparkConf, blockManager); + } + + @Override + public ShuffleMapOutputWriter createMapOutputWriter( + int shuffleId, + int mapId, + long mapTaskAttemptId, + int numPartitions) { + if (blockResolver == null) { + throw new IllegalStateException( + "Executor components must be initialized before getting writers."); + } + return new LocalDiskShuffleMapOutputWriter( + shuffleId, mapId, numPartitions, blockResolver, sparkConf); + } +} diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java new file mode 100644 index 0000000..add4634 --- /dev/null +++ 
b/core/src/main/java/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriter.java @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.shuffle.sort.io; + +import java.io.BufferedOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +import java.util.Optional; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.spark.SparkConf; +import org.apache.spark.shuffle.api.ShuffleMapOutputWriter; +import org.apache.spark.shuffle.api.ShufflePartitionWriter; +import org.apache.spark.shuffle.api.WritableByteChannelWrapper; +import org.apache.spark.internal.config.package$; +import org.apache.spark.shuffle.IndexShuffleBlockResolver; +import org.apache.spark.util.Utils; + +/** + * Implementation of {@link ShuffleMapOutputWriter} that replicates the functionality of shuffle + * persisting shuffle data to local disk alongside index files, identical to Spark's historic + * canonical shuffle storage mechanism. 
+ */ +public class LocalDiskShuffleMapOutputWriter implements ShuffleMapOutputWriter { + + private static final Logger log = + LoggerFactory.getLogger(LocalDiskShuffleMapOutputWriter.class); + + private final int shuffleId; + private final int mapId; + private final IndexShuffleBlockResolver blockResolver; + private final long[] partitionLengths; + private final int bufferSize; + private int lastPartitionId = -1; + private long currChannelPosition; + + private final File outputFile; + private File outputTempFile; + private FileOutputStream outputFileStream; + private FileChannel outputFileChannel; + private BufferedOutputStream outputBufferedFileStream; + + public LocalDiskShuffleMapOutputWriter( + int shuffleId, + int mapId, + int numPartitions, + IndexShuffleBlockResolver blockResolver, + SparkConf sparkConf) { + this.shuffleId = shuffleId; + this.mapId = mapId; + this.blockResolver = blockResolver; + this.bufferSize = + (int) (long) sparkConf.get( + package$.MODULE$.SHUFFLE_UNSAFE_FILE_OUTPUT_BUFFER_SIZE()) * 1024; + this.partitionLengths = new long[numPartitions]; + this.outputFile = blockResolver.getDataFile(shuffleId, mapId); + this.outputTempFile = null; + } + + @Override + public ShufflePartitionWriter getPartitionWriter(int reducePartitionId) throws IOException { + if (reducePartitionId <= lastPartitionId) { + throw new IllegalArgumentException("Partitions should be requested in increasing order."); + } + lastPartitionId = reducePartitionId; + if (outputTempFile == null) { + outputTempFile = Utils.tempFileWith(outputFile); + } + if (outputFileChannel != null) { + currChannelPosition = outputFileChannel.position(); + } else { + currChannelPosition = 0L; + } + return new LocalDiskShufflePartitionWriter(reducePartitionId); + } + + @Override + public void commitAllPartitions() throws IOException { + cleanUp(); + File resolvedTmp = outputTempFile != null && outputTempFile.isFile() ? 
outputTempFile : null; + blockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, resolvedTmp); + } + + @Override + public void abort(Throwable error) throws IOException { + cleanUp(); + if (outputTempFile != null && outputTempFile.exists() && !outputTempFile.delete()) { + log.warn("Failed to delete temporary shuffle file at {}", outputTempFile.getAbsolutePath()); + } + } + + private void cleanUp() throws IOException { + if (outputBufferedFileStream != null) { + outputBufferedFileStream.close(); + } + if (outputFileChannel != null) { + outputFileChannel.close(); + } + if (outputFileStream != null) { + outputFileStream.close(); + } + } + + private void initStream() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputBufferedFileStream == null) { + outputBufferedFileStream = new BufferedOutputStream(outputFileStream, bufferSize); + } + } + + private void initChannel() throws IOException { + if (outputFileStream == null) { + outputFileStream = new FileOutputStream(outputTempFile, true); + } + if (outputFileChannel == null) { + outputFileChannel = outputFileStream.getChannel(); + } + } + + private class LocalDiskShufflePartitionWriter implements ShufflePartitionWriter { + + private final int partitionId; + private PartitionWriterStream partStream = null; + private PartitionWriterChannel partChannel = null; + + private LocalDiskShufflePartitionWriter(int partitionId) { + this.partitionId = partitionId; + } + + @Override + public OutputStream openStream() throws IOException { + if (partStream == null) { + if (outputFileChannel != null) { + throw new IllegalStateException("Requested an output channel for a previous write but" + + " now an output stream has been requested. Should not be using both channels" + + " and streams to write."); + } + initStream(); + partStream = new PartitionWriterStream(partitionId); + } + return partStream; + } + + @Override + public Optional<WritableByteChannelWrapper> openChannelWrapper() throws IOException { + if (partChannel == null) { + if (partStream != null) { + throw new IllegalStateException("Requested an output stream for a previous write but" + + " now an output channel has been requested. 
Should not be using both channels" + + " and streams to write."); + } + initChannel(); + partChannel = new PartitionWriterChannel(partitionId); + } + return Optional.of(partChannel); + } + + @Override + public long getNumBytesWritten() { + if (partChannel != null) { + try { + return partChannel.getCount(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } else if (partStream != null) { + return partStream.getCount(); + } else { + // Assume an empty partition if stream and channel are never created + return 0; + } + } + } + + private class PartitionWriterStream extends OutputStream { + private final int partitionId; + private int count = 0; + private boolean isClosed = false; + + PartitionWriterStream(int partitionId) { + this.partitionId = partitionId; + } + + public int getCount() { + return count; + } + + @Override + public void write(int b) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(b); + count++; + } + + @Override + public void write(byte[] buf, int pos, int length) throws IOException { + verifyNotClosed(); + outputBufferedFileStream.write(buf, pos, length); + count += length; + } + + @Override + public void close() { + isClosed = true; + partitionLengths[partitionId] = count; + } + + private void verifyNotClosed() { + if (isClosed) { + throw new IllegalStateException("Attempting to write to a closed block output stream."); + } + } + } + + private class PartitionWriterChannel implements WritableByteChannelWrapper { + + private final int partitionId; + + PartitionWriterChannel(int partitionId) { + this.partitionId = partitionId; + } + + public long getCount() throws IOException { + long writtenPosition = outputFileChannel.position(); + return writtenPosition - currChannelPosition; + } + + @Override + public WritableByteChannel channel() { + return outputFileChannel; + } + + @Override + public void close() throws IOException { + partitionLengths[partitionId] = getCount(); + } + } +} diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index f2b88fe..cda3b57 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -24,6 +24,7 @@ import org.apache.spark.metrics.GarbageCollectionMetrics import org.apache.spark.network.shuffle.Constants import org.apache.spark.network.util.ByteUnit import org.apache.spark.scheduler.{EventLoggingListener, SchedulingMode} +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleDataIO import org.apache.spark.storage.{DefaultTopologyMapper, RandomBlockReplicationPolicy} import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.util.Utils @@ -811,6 +812,12 @@ package object config { .booleanConf .createWithDefault(false) + private[spark] val SHUFFLE_IO_PLUGIN_CLASS = + ConfigBuilder("spark.shuffle.sort.io.plugin.class") + .doc("Name of the class to use for shuffle IO.") + .stringConf + .createWithDefault(classOf[LocalDiskShuffleDataIO].getName) + private[spark] val SHUFFLE_FILE_BUFFER_SIZE = ConfigBuilder("spark.shuffle.file.buffer") .doc("Size of the in-memory buffer for each shuffle file output stream, in KiB unless " + diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala index b59fa8e..17719f5 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala +++ 
b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala @@ -20,8 +20,10 @@ package org.apache.spark.shuffle.sort import java.util.concurrent.ConcurrentHashMap import org.apache.spark._ -import org.apache.spark.internal.Logging +import org.apache.spark.internal.{config, Logging} import org.apache.spark.shuffle._ +import org.apache.spark.shuffle.api.{ShuffleDataIO, ShuffleExecutorComponents} +import org.apache.spark.util.Utils /** * In sort-based shuffle, incoming records are sorted according to their target partition ids, then @@ -68,6 +70,8 @@ import org.apache.spark.shuffle._ */ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { + import SortShuffleManager._ + if (!conf.getBoolean("spark.shuffle.spill", true)) { logWarning( "spark.shuffle.spill was set to false, but this configuration is ignored as of Spark 1.6+." + @@ -79,6 +83,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager */ private[this] val numMapsForShuffle = new ConcurrentHashMap[Int, Int]() + private lazy val shuffleExecutorComponents = loadShuffleExecutorComponents(conf) + override val shuffleBlockResolver = new IndexShuffleBlockResolver(conf) /** @@ -134,7 +140,7 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case unsafeShuffleHandle: SerializedShuffleHandle[K @unchecked, V @unchecked] => new UnsafeShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], + shuffleBlockResolver, context.taskMemoryManager(), unsafeShuffleHandle, mapId, @@ -144,11 +150,12 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] => new BypassMergeSortShuffleWriter( env.blockManager, - shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver], bypassMergeSortHandle, mapId, + context.taskAttemptId(), env.conf, - metrics) + metrics, + shuffleExecutorComponents) case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] => new SortShuffleWriter(shuffleBlockResolver, other, mapId, context) } @@ -205,6 +212,16 @@ private[spark] object SortShuffleManager extends Logging { true } } + + private def loadShuffleExecutorComponents(conf: SparkConf): ShuffleExecutorComponents = { + val configuredPluginClasses = conf.get(config.SHUFFLE_IO_PLUGIN_CLASS) + val maybeIO = Utils.loadExtensions( + classOf[ShuffleDataIO], Seq(configuredPluginClasses), conf) + require(maybeIO.size == 1, s"Failed to load plugins of type $configuredPluginClasses") + val executorComponents = maybeIO.head.executor() + executorComponents.initializeExecutor(conf.getAppId, SparkEnv.get.executorId) + executorComponents + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 80d70a1..24042db 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -24,7 +24,7 @@ import java.lang.reflect.InvocationTargetException import java.math.{MathContext, RoundingMode} import java.net._ import java.nio.ByteBuffer -import java.nio.channels.{Channels, FileChannel} +import java.nio.channels.{Channels, FileChannel, WritableByteChannel} import java.nio.charset.StandardCharsets import java.nio.file.Files import java.security.SecureRandom @@ -394,10 +394,14 @@ private[spark] object Utils extends Logging { def copyFileStreamNIO( input: FileChannel, - output: FileChannel, + 
output: WritableByteChannel, startPosition: Long, bytesToCopy: Long): Unit = { - val initialPos = output.position() + val outputInitialState = output match { + case outputFileChannel: FileChannel => + Some((outputFileChannel.position(), outputFileChannel)) + case _ => None + } var count = 0L // In case transferTo method transferred less data than we have required. while (count < bytesToCopy) { @@ -412,15 +416,17 @@ private[spark] object Utils extends Logging { // kernel version 2.6.32, this issue can be seen in // https://bugs.openjdk.java.net/browse/JDK-7052359 // This will lead to stream corruption issue when using sort-based shuffle (SPARK-3948). - val finalPos = output.position() - val expectedPos = initialPos + bytesToCopy - assert(finalPos == expectedPos, - s""" - |Current position $finalPos do not equal to expected position $expectedPos - |after transferTo, please check your kernel version to see if it is 2.6.32, - |this is a kernel bug which will lead to unexpected behavior when using transferTo. - |You can set spark.file.transferTo = false to disable this NIO feature. - """.stripMargin) + outputInitialState.foreach { case (initialPos, outputFileChannel) => + val finalPos = outputFileChannel.position() + val expectedPos = initialPos + bytesToCopy + assert(finalPos == expectedPos, + s""" + |Current position $finalPos do not equal to expected position $expectedPos + |after transferTo, please check your kernel version to see if it is 2.6.32, + |this is a kernel bug which will lead to unexpected behavior when using transferTo. + |You can set spark.file.transferTo = false to disable this NIO feature. + """.stripMargin) + } } /** diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 8b1084a..923c9c9 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -383,13 +383,18 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // simultaneously, and everything is still OK def writeAndClose( - writer: ShuffleWriter[Int, Int])( - iter: Iterator[(Int, Int)]): Option[MapStatus] = { - val files = writer.write(iter) - writer.stop(true) + writer: ShuffleWriter[Int, Int], + taskContext: TaskContext)( + iter: Iterator[(Int, Int)]): Option[MapStatus] = { + try { + val files = writer.write(iter) + writer.stop(true) + } finally { + TaskContext.unset() + } } val interleaver = new InterleaveIterators( - data1, writeAndClose(writer1), data2, writeAndClose(writer2)) + data1, writeAndClose(writer1, context1), data2, writeAndClose(writer2, context2)) val (mapOutput1, mapOutput2) = interleaver.run() // check that we can read the map output and it has the right data @@ -407,6 +412,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC 1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem) val metrics = taskContext.taskMetrics.createTempShuffleReadMetrics() val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, taskContext, metrics) + TaskContext.unset() val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index fc1422d..b9f81fa 100644 --- 
a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -27,13 +27,15 @@ import org.mockito.{Mock, MockitoAnnotations} import org.mockito.Answers.RETURNS_SMART_NULLS import org.mockito.ArgumentMatchers.{any, anyInt} import org.mockito.Mockito._ -import org.mockito.invocation.InvocationOnMock import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} +import org.apache.spark.memory.{TaskMemoryManager, TestMemoryManager} import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.shuffle.api.ShuffleExecutorComponents +import org.apache.spark.shuffle.sort.io.LocalDiskShuffleExecutorComponents import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -48,68 +50,82 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte private var taskMetrics: TaskMetrics = _ private var tempDir: File = _ private var outputFile: File = _ + private var shuffleExecutorComponents: ShuffleExecutorComponents = _ private val conf: SparkConf = new SparkConf(loadDefaults = false) + .set("spark.app.id", "sampleApp") private val temporaryFilesCreated: mutable.Buffer[File] = new ArrayBuffer[File]() private val blockIdToFileMap: mutable.Map[BlockId, File] = new mutable.HashMap[BlockId, File] private var shuffleHandle: BypassMergeSortShuffleHandle[Int, Int] = _ override def beforeEach(): Unit = { super.beforeEach() + MockitoAnnotations.initMocks(this) tempDir = Utils.createTempDir() outputFile = File.createTempFile("shuffle", null, tempDir) taskMetrics = new TaskMetrics - MockitoAnnotations.initMocks(this) shuffleHandle = new BypassMergeSortShuffleHandle[Int, Int]( shuffleId = 0, numMaps = 2, dependency = dependency ) + val memoryManager = new TestMemoryManager(conf) + val taskMemoryManager = new TaskMemoryManager(memoryManager, 0) when(dependency.partitioner).thenReturn(new HashPartitioner(7)) when(dependency.serializer).thenReturn(new JavaSerializer(conf)) when(taskContext.taskMetrics()).thenReturn(taskMetrics) when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile) - doAnswer { (invocationOnMock: InvocationOnMock) => - val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] - if (tmp != null) { - outputFile.delete - tmp.renameTo(outputFile) - } - null - }.when(blockResolver) - .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File])) when(blockManager.diskBlockManager).thenReturn(diskBlockManager) + when(taskContext.taskMemoryManager()).thenReturn(taskMemoryManager) + + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { invocationOnMock => + val tmp = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + outputFile.delete + tmp.renameTo(outputFile) + } + null + } + when(blockManager.getDiskWriter( any[BlockId], any[File], any[SerializerInstance], anyInt(), - any[ShuffleWriteMetrics] - )).thenAnswer((invocation: InvocationOnMock) => { - val args = invocation.getArguments - val manager = new SerializerManager(new JavaSerializer(conf), conf) - new DiskBlockObjectWriter( - args(1).asInstanceOf[File], - manager, - args(2).asInstanceOf[SerializerInstance], - args(3).asInstanceOf[Int], - syncWrites = false, - 
args(4).asInstanceOf[ShuffleWriteMetrics], - blockId = args(0).asInstanceOf[BlockId] - ) - }) - when(diskBlockManager.createTempShuffleBlock()).thenAnswer((_: InvocationOnMock) => { - val blockId = new TempShuffleBlockId(UUID.randomUUID) - val file = new File(tempDir, blockId.name) - blockIdToFileMap.put(blockId, file) - temporaryFilesCreated += file - (blockId, file) - }) - when(diskBlockManager.getFile(any[BlockId])).thenAnswer { (invocation: InvocationOnMock) => + any[ShuffleWriteMetrics])) + .thenAnswer { invocation => + val args = invocation.getArguments + val manager = new SerializerManager(new JavaSerializer(conf), conf) + new DiskBlockObjectWriter( + args(1).asInstanceOf[File], + manager, + args(2).asInstanceOf[SerializerInstance], + args(3).asInstanceOf[Int], + syncWrites = false, + args(4).asInstanceOf[ShuffleWriteMetrics], + blockId = args(0).asInstanceOf[BlockId]) + } + + when(diskBlockManager.createTempShuffleBlock()) + .thenAnswer { _ => + val blockId = new TempShuffleBlockId(UUID.randomUUID) + val file = new File(tempDir, blockId.name) + blockIdToFileMap.put(blockId, file) + temporaryFilesCreated += file + (blockId, file) + } + + when(diskBlockManager.getFile(any[BlockId])).thenAnswer { invocation => blockIdToFileMap(invocation.getArguments.head.asInstanceOf[BlockId]) } + + shuffleExecutorComponents = new LocalDiskShuffleExecutorComponents( + conf, blockManager, blockResolver) } override def afterEach(): Unit = { + TaskContext.unset() try { Utils.deleteRecursively(tempDir) blockIdToFileMap.clear() @@ -122,12 +138,13 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("write empty iterator") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId + 0L, // MapTaskAttemptId conf, - taskContext.taskMetrics().shuffleWriteMetrics - ) + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + writer.write(Iterator.empty) writer.stop( /* success = */ true) assert(writer.getPartitionLengths.sum === 0) @@ -141,28 +158,31 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte assert(taskMetrics.memoryBytesSpilled === 0) } - test("write with some empty partitions") { - def records: Iterator[(Int, Int)] = - Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) - val writer = new BypassMergeSortShuffleWriter[Int, Int]( - blockManager, - blockResolver, - shuffleHandle, - 0, // MapId - conf, - taskContext.taskMetrics().shuffleWriteMetrics - ) - writer.write(records) - writer.stop( /* success = */ true) - assert(temporaryFilesCreated.nonEmpty) - assert(writer.getPartitionLengths.sum === outputFile.length()) - assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files - assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temporary files were deleted - val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics - assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) - assert(shuffleWriteMetrics.recordsWritten === records.length) - assert(taskMetrics.diskBytesSpilled === 0) - assert(taskMetrics.memoryBytesSpilled === 0) + Seq(true, false).foreach { transferTo => + test(s"write with some empty partitions - transferTo $transferTo") { + val transferConf = conf.clone.set("spark.file.transferTo", transferTo.toString) + def records: Iterator[(Int, Int)] = + Iterator((1, 1), (5, 5)) ++ (0 until 100000).iterator.map(x => (2, 2)) + val writer = new 
BypassMergeSortShuffleWriter[Int, Int]( + blockManager, + shuffleHandle, + 0, // MapId + 0L, + transferConf, + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) + writer.write(records) + writer.stop( /* success = */ true) + assert(temporaryFilesCreated.nonEmpty) + assert(writer.getPartitionLengths.sum === outputFile.length()) + assert(writer.getPartitionLengths.count(_ == 0L) === 4) // should be 4 zero length files + assert(temporaryFilesCreated.count(_.exists()) === 0) // check that temp files were deleted + val shuffleWriteMetrics = taskContext.taskMetrics().shuffleWriteMetrics + assert(shuffleWriteMetrics.bytesWritten === outputFile.length()) + assert(shuffleWriteMetrics.recordsWritten === records.length) + assert(taskMetrics.diskBytesSpilled === 0) + assert(taskMetrics.memoryBytesSpilled === 0) + } } test("only generate temp shuffle file for non-empty partition") { @@ -181,12 +201,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId + 0L, conf, - taskContext.taskMetrics().shuffleWriteMetrics - ) + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) intercept[SparkException] { writer.write(records) @@ -203,12 +223,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte test("cleanup of intermediate files after errors") { val writer = new BypassMergeSortShuffleWriter[Int, Int]( blockManager, - blockResolver, shuffleHandle, 0, // MapId + 0L, conf, - taskContext.taskMetrics().shuffleWriteMetrics - ) + taskContext.taskMetrics().shuffleWriteMetrics, + shuffleExecutorComponents) intercept[SparkException] { writer.write((0 until 100000).iterator.map(i => { if (i == 99990) { @@ -221,5 +241,4 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte writer.stop( /* success = */ false) assert(temporaryFilesCreated.count(_.exists()) === 0) } - } diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala new file mode 100644 index 0000000..5693b98 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/io/LocalDiskShuffleMapOutputWriterSuite.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.shuffle.sort.io + +import java.io.{File, FileInputStream} +import java.nio.channels.FileChannel +import java.nio.file.Files +import java.util.Arrays + +import org.mockito.Answers.RETURNS_SMART_NULLS +import org.mockito.ArgumentMatchers.{any, anyInt} +import org.mockito.Mock +import org.mockito.Mockito.when +import org.mockito.MockitoAnnotations +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.shuffle.IndexShuffleBlockResolver +import org.apache.spark.util.Utils + +class LocalDiskShuffleMapOutputWriterSuite extends SparkFunSuite with BeforeAndAfterEach { + + @Mock(answer = RETURNS_SMART_NULLS) + private var blockResolver: IndexShuffleBlockResolver = _ + + private val NUM_PARTITIONS = 4 + private val data: Array[Array[Byte]] = (0 until NUM_PARTITIONS).map { p => + if (p == 3) { + Array.emptyByteArray + } else { + (0 to p * 10).map(_ + p).map(_.toByte).toArray + } + }.toArray + + private val partitionLengths = data.map(_.length) + + private var tempFile: File = _ + private var mergedOutputFile: File = _ + private var tempDir: File = _ + private var partitionSizesInMergedFile: Array[Long] = _ + private var conf: SparkConf = _ + private var mapOutputWriter: LocalDiskShuffleMapOutputWriter = _ + + override def afterEach(): Unit = { + try { + Utils.deleteRecursively(tempDir) + } finally { + super.afterEach() + } + } + + override def beforeEach(): Unit = { + MockitoAnnotations.initMocks(this) + tempDir = Utils.createTempDir() + mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir) + tempFile = File.createTempFile("tempfile", "", tempDir) + partitionSizesInMergedFile = null + conf = new SparkConf() + .set("spark.app.id", "example.spark.app") + .set("spark.shuffle.unsafe.file.output.buffer", "16k") + when(blockResolver.getDataFile(anyInt, anyInt)).thenReturn(mergedOutputFile) + when(blockResolver.writeIndexFileAndCommit( + anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))) + .thenAnswer { invocationOnMock => + partitionSizesInMergedFile = invocationOnMock.getArguments()(2).asInstanceOf[Array[Long]] + val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File] + if (tmp != null) { + mergedOutputFile.delete() + tmp.renameTo(mergedOutputFile) + } + null + } + mapOutputWriter = new LocalDiskShuffleMapOutputWriter( + 0, + 0, + NUM_PARTITIONS, + blockResolver, + conf) + } + + test("writing to an outputstream") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val stream = writer.openStream() + data(p).foreach { i => stream.write(i) } + stream.close() + intercept[IllegalStateException] { + stream.write(p) + } + assert(writer.getNumBytesWritten === data(p).length) + } + verifyWrittenRecords() + } + + test("writing to a channel") { + (0 until NUM_PARTITIONS).foreach { p => + val writer = mapOutputWriter.getPartitionWriter(p) + val outputTempFile = File.createTempFile("channelTemp", "", tempDir) + Files.write(outputTempFile.toPath, data(p)) + val tempFileInput = new FileInputStream(outputTempFile) + val channel = writer.openChannelWrapper() + Utils.tryWithResource(new FileInputStream(outputTempFile)) { tempFileInput => + Utils.tryWithResource(writer.openChannelWrapper().get) { channelWrapper => + assert(channelWrapper.channel().isInstanceOf[FileChannel], + "Underlying channel should be a file channel") + Utils.copyFileStreamNIO( + tempFileInput.getChannel, channelWrapper.channel(), 0L, data(p).length) + } + } + 
assert(writer.getNumBytesWritten === data(p).length, + s"Partition $p does not have the correct number of bytes.") + } + verifyWrittenRecords() + } + + private def readRecordsFromFile() = { + val mergedOutputBytes = Files.readAllBytes(mergedOutputFile.toPath) + val result = (0 until NUM_PARTITIONS).map { part => + val startOffset = data.slice(0, part).map(_.length).sum + val partitionSize = data(part).length + Arrays.copyOfRange(mergedOutputBytes, startOffset, startOffset + partitionSize) + }.toArray + result + } + + private def verifyWrittenRecords(): Unit = { + mapOutputWriter.commitAllPartitions() + assert(partitionSizesInMergedFile === partitionLengths) + assert(mergedOutputFile.length() === partitionLengths.sum) + assert(data === readRecordsFromFile()) + } +}
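For context only (not part of this commit): a minimal Scala sketch of how the new map-output plugin surface is driven, mirroring LocalDiskShuffleMapOutputWriterSuite above. The object name, shuffle/map ids, partition count, and payload are made up for illustration, and the sketch sits in the same package as the suite so the private[spark] helpers resolve.

// Illustrative sketch, assuming only the API exercised by the suite above:
// one ShufflePartitionWriter per reduce partition, bytes written through
// openStream(), then a single commitAllPartitions() that commits via
// IndexShuffleBlockResolver.writeIndexFileAndCommit (mocked in the suite).
package org.apache.spark.shuffle.sort.io

import java.nio.charset.StandardCharsets

import org.apache.spark.SparkConf
import org.apache.spark.shuffle.IndexShuffleBlockResolver

object LocalDiskMapOutputWriterSketch {

  def writeMapOutput(resolver: IndexShuffleBlockResolver, conf: SparkConf): Unit = {
    val numPartitions = 4  // illustrative; matches NUM_PARTITIONS in the suite
    val mapOutputWriter =
      new LocalDiskShuffleMapOutputWriter(0, 0, numPartitions, resolver, conf)
    (0 until numPartitions).foreach { p =>
      // One ShufflePartitionWriter per partition; bytes could also go through
      // openChannelWrapper() for NIO copies, as the channel test does.
      val partitionWriter = mapOutputWriter.getPartitionWriter(p)
      val stream = partitionWriter.openStream()
      try {
        // Write byte-by-byte, exactly as the output-stream test does.
        s"partition $p".getBytes(StandardCharsets.UTF_8).foreach(b => stream.write(b))
      } finally {
        stream.close()
      }
    }
    // Committing all partitions writes the index file and renames the
    // temporary data file, which is what the mocked resolver verifies above.
    mapOutputWriter.commitAllPartitions()
  }
}

Since openChannelWrapper() returns an Optional (the suite calls .get on it), implementations appear free to decline the channel path and let callers fall back to openStream(); the transferTo true/false variants added to BypassMergeSortShuffleWriterSuite exercise both copy paths.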