Repository: spark
Updated Branches:
  refs/heads/branch-2.1 9e96ac5a9 -> c2c2fdcb7
[SPARK-18546][CORE] Fix merging shuffle spills when using encryption.

The problem exists because encrypted partition data from different spill
files cannot simply be concatenated: each partition currently has its own
initialization vector (IV) set up for encryption, whereas the final merged
file should contain a single IV for each merged partition; otherwise,
iterating over the records becomes very hard.

To fix that, UnsafeShuffleWriter now decrypts the partitions when merging, so
that the merged file contains a single IV at the start of each partition's
data. Because that cannot be done through the fast transferTo path, when
encryption is enabled UnsafeShuffleWriter falls back to file streams for
merging. It may be possible to use a hybrid approach with encryption, using
an intermediate direct buffer when reading from files and encrypting the
data, but that is better left for a separate patch. (A short illustrative
sketch of this per-partition decrypt/re-encrypt merge appears at the end of
this message, after the diff.)

As part of the change, DiskBlockObjectWriter now takes a SerializerManager
instead of a "wrap stream" closure, since that makes the code easier to test
without having to mock SerializerManager functionality.

Tested with newly added unit tests (UnsafeShuffleWriterSuite for the write
side and ExternalAppendOnlyMapSuite for integration), and by running some
apps that failed without the fix.

Author: Marcelo Vanzin <van...@cloudera.com>

Closes #15982 from vanzin/SPARK-18546.

(cherry picked from commit 93e9d880bf8a144112d74a6897af4e36fcfa5807)
Signed-off-by: Marcelo Vanzin <van...@cloudera.com>

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c2c2fdcb
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c2c2fdcb
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c2c2fdcb

Branch: refs/heads/branch-2.1
Commit: c2c2fdcb71e9bc82f0e88567148d1bae283f256a
Parents: 9e96ac5
Author: Marcelo Vanzin <van...@cloudera.com>
Authored: Wed Nov 30 14:10:32 2016 -0800
Committer: Marcelo Vanzin <van...@cloudera.com>
Committed: Wed Nov 30 14:10:44 2016 -0800

----------------------------------------------------------------------
 .../spark/shuffle/sort/UnsafeShuffleWriter.java |  48 +++++----
 .../spark/serializer/SerializerManager.scala    |   6 +-
 .../org/apache/spark/storage/BlockManager.scala |   5 +-
 .../spark/storage/DiskBlockObjectWriter.scala   |   6 +-
 .../shuffle/sort/UnsafeShuffleWriterSuite.java  | 100 +++++++++++++------
 .../map/AbstractBytesToBytesMapSuite.java       |  11 +-
 .../unsafe/sort/UnsafeExternalSorterSuite.java  |  21 ++--
 .../BypassMergeSortShuffleWriterSuite.scala     |   5 +-
 .../storage/DiskBlockObjectWriterSuite.scala    |  54 ++++------
 .../collection/ExternalAppendOnlyMapSuite.scala |   8 +-
 10 files changed, 145 insertions(+), 119 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index f235c43..8a17718 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -40,6 +40,8 @@ import org.apache.spark.annotation.Private; import org.apache.spark.executor.ShuffleWriteMetrics; import
org.apache.spark.io.CompressionCodec; import org.apache.spark.io.CompressionCodec$; +import org.apache.commons.io.output.CloseShieldOutputStream; +import org.apache.commons.io.output.CountingOutputStream; import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; @@ -264,6 +266,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true); final boolean fastMergeIsSupported = !compressionEnabled || CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec); + final boolean encryptionEnabled = blockManager.serializerManager().encryptionEnabled(); try { if (spills.length == 0) { new FileOutputStream(outputFile).close(); // Create an empty file @@ -289,7 +292,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { // Compression is disabled or we are using an IO compression codec that supports // decompression of concatenated compressed streams, so we can perform a fast spill merge // that doesn't need to interpret the spilled bytes. - if (transferToEnabled) { + if (transferToEnabled && !encryptionEnabled) { logger.debug("Using transferTo-based fast merge"); partitionLengths = mergeSpillsWithTransferTo(spills, outputFile); } else { @@ -320,9 +323,9 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { /** * Merges spill files using Java FileStreams. This code path is slower than the NIO-based merge, * {@link UnsafeShuffleWriter#mergeSpillsWithTransferTo(SpillInfo[], File)}, so it's only used in - * cases where the IO compression codec does not support concatenation of compressed data, or in - * cases where users have explicitly disabled use of {@code transferTo} in order to work around - * kernel bugs. + * cases where the IO compression codec does not support concatenation of compressed data, when + * encryption is enabled, or when users have explicitly disabled use of {@code transferTo} in + * order to work around kernel bugs. * * @param spills the spills to merge. * @param outputFile the file to write the merged data to. @@ -337,7 +340,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { final int numPartitions = partitioner.numPartitions(); final long[] partitionLengths = new long[numPartitions]; final InputStream[] spillInputStreams = new FileInputStream[spills.length]; - OutputStream mergedFileOutputStream = null; + + // Use a counting output stream to avoid having to close the underlying file and ask + // the file system for its size after each partition is written. + final CountingOutputStream mergedFileOutputStream = new CountingOutputStream( + new FileOutputStream(outputFile)); boolean threwException = true; try { @@ -345,34 +352,35 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { spillInputStreams[i] = new FileInputStream(spills[i].file); } for (int partition = 0; partition < numPartitions; partition++) { - final long initialFileLength = outputFile.length(); - mergedFileOutputStream = - new TimeTrackingOutputStream(writeMetrics, new FileOutputStream(outputFile, true)); + final long initialFileLength = mergedFileOutputStream.getByteCount(); + // Shield the underlying output stream from close() calls, so that we can close the higher + // level streams to make sure all data is really flushed and internal state is cleaned. 
+ OutputStream partitionOutput = new CloseShieldOutputStream( + new TimeTrackingOutputStream(writeMetrics, mergedFileOutputStream)); + partitionOutput = blockManager.serializerManager().wrapForEncryption(partitionOutput); if (compressionCodec != null) { - mergedFileOutputStream = compressionCodec.compressedOutputStream(mergedFileOutputStream); + partitionOutput = compressionCodec.compressedOutputStream(partitionOutput); } - for (int i = 0; i < spills.length; i++) { final long partitionLengthInSpill = spills[i].partitionLengths[partition]; if (partitionLengthInSpill > 0) { - InputStream partitionInputStream = null; - boolean innerThrewException = true; + InputStream partitionInputStream = new LimitedInputStream(spillInputStreams[i], + partitionLengthInSpill, false); try { - partitionInputStream = - new LimitedInputStream(spillInputStreams[i], partitionLengthInSpill, false); + partitionInputStream = blockManager.serializerManager().wrapForEncryption( + partitionInputStream); if (compressionCodec != null) { partitionInputStream = compressionCodec.compressedInputStream(partitionInputStream); } - ByteStreams.copy(partitionInputStream, mergedFileOutputStream); - innerThrewException = false; + ByteStreams.copy(partitionInputStream, partitionOutput); } finally { - Closeables.close(partitionInputStream, innerThrewException); + partitionInputStream.close(); } } } - mergedFileOutputStream.flush(); - mergedFileOutputStream.close(); - partitionLengths[partition] = (outputFile.length() - initialFileLength); + partitionOutput.flush(); + partitionOutput.close(); + partitionLengths[partition] = (mergedFileOutputStream.getByteCount() - initialFileLength); } threwException = false; } finally { http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 7371f88..686305e 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -75,6 +75,8 @@ private[spark] class SerializerManager( * loaded yet. 
*/ private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) + def encryptionEnabled: Boolean = encryptionKey.isDefined + def canUseKryo(ct: ClassTag[_]): Boolean = { primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag } @@ -129,7 +131,7 @@ private[spark] class SerializerManager( /** * Wrap an input stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: InputStream): InputStream = { + def wrapForEncryption(s: InputStream): InputStream = { encryptionKey .map { key => CryptoStreamUtils.createCryptoInputStream(s, conf, key) } .getOrElse(s) @@ -138,7 +140,7 @@ private[spark] class SerializerManager( /** * Wrap an output stream for encryption if shuffle encryption is enabled */ - private[this] def wrapForEncryption(s: OutputStream): OutputStream = { + def wrapForEncryption(s: OutputStream): OutputStream = { encryptionKey .map { key => CryptoStreamUtils.createCryptoOutputStream(s, conf, key) } .getOrElse(s) http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/storage/BlockManager.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 982b833..04521c9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -62,7 +62,7 @@ private[spark] class BlockManager( executorId: String, rpcEnv: RpcEnv, val master: BlockManagerMaster, - serializerManager: SerializerManager, + val serializerManager: SerializerManager, val conf: SparkConf, memoryManager: MemoryManager, mapOutputTracker: MapOutputTracker, @@ -745,9 +745,8 @@ private[spark] class BlockManager( serializerInstance: SerializerInstance, bufferSize: Int, writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = { - val wrapStream: OutputStream => OutputStream = serializerManager.wrapStream(blockId, _) val syncWrites = conf.getBoolean("spark.shuffle.sync", false) - new DiskBlockObjectWriter(file, serializerInstance, bufferSize, wrapStream, + new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize, syncWrites, writeMetrics, blockId) } http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index a499827..3cb12fc 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -22,7 +22,7 @@ import java.nio.channels.FileChannel import org.apache.spark.executor.ShuffleWriteMetrics import org.apache.spark.internal.Logging -import org.apache.spark.serializer.{SerializationStream, SerializerInstance} +import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager} import org.apache.spark.util.Utils /** @@ -37,9 +37,9 @@ import org.apache.spark.util.Utils */ private[spark] class DiskBlockObjectWriter( val file: File, + serializerManager: SerializerManager, serializerInstance: SerializerInstance, bufferSize: Int, - wrapStream: OutputStream => OutputStream, syncWrites: Boolean, // These write metrics 
concurrently shared with other active DiskBlockObjectWriters who // are themselves performing writes. All updates must be relative. @@ -116,7 +116,7 @@ private[spark] class DiskBlockObjectWriter( initialized = true } - bs = wrapStream(mcs) + bs = serializerManager.wrapStream(blockId, mcs) objOut = serializerInstance.serializeStream(bs) streamOpen = true this http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index a96cd82..088b681 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -26,11 +26,9 @@ import scala.Product2; import scala.Tuple2; import scala.Tuple2$; import scala.collection.Iterator; -import scala.runtime.AbstractFunction1; import com.google.common.collect.HashMultiset; import com.google.common.collect.Iterators; -import com.google.common.io.ByteStreams; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -53,6 +51,7 @@ import org.apache.spark.memory.TaskMemoryManager; import org.apache.spark.memory.TestMemoryManager; import org.apache.spark.network.util.LimitedInputStream; import org.apache.spark.scheduler.MapStatus; +import org.apache.spark.security.CryptoStreamUtils; import org.apache.spark.serializer.*; import org.apache.spark.shuffle.IndexShuffleBlockResolver; import org.apache.spark.storage.*; @@ -77,7 +76,6 @@ public class UnsafeShuffleWriterSuite { final LinkedList<File> spillFilesCreated = new LinkedList<>(); SparkConf conf; final Serializer serializer = new KryoSerializer(new SparkConf()); - final SerializerManager serializerManager = new SerializerManager(serializer, new SparkConf()); TaskMetrics taskMetrics; @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @@ -86,17 +84,6 @@ public class UnsafeShuffleWriterSuite { @Mock(answer = RETURNS_SMART_NULLS) TaskContext taskContext; @Mock(answer = RETURNS_SMART_NULLS) ShuffleDependency<Object, Object, Object> shuffleDep; - private final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { - @Override - public OutputStream apply(OutputStream stream) { - if (conf.getBoolean("spark.shuffle.compress", true)) { - return CompressionCodec$.MODULE$.createCodec(conf).compressedOutputStream(stream); - } else { - return stream; - } - } - } - @After public void tearDown() { Utils.deleteRecursively(tempDir); @@ -121,6 +108,11 @@ public class UnsafeShuffleWriterSuite { memoryManager = new TestMemoryManager(conf); taskMemoryManager = new TaskMemoryManager(memoryManager, 0); + // Some tests will override this manager because they change the configuration. This is a + // default for tests that don't need a specific one. 
+ SerializerManager manager = new SerializerManager(serializer, conf); + when(blockManager.serializerManager()).thenReturn(manager); + when(blockManager.diskBlockManager()).thenReturn(diskBlockManager); when(blockManager.getDiskWriter( any(BlockId.class), @@ -131,12 +123,11 @@ public class UnsafeShuffleWriterSuite { @Override public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable { Object[] args = invocationOnMock.getArguments(); - return new DiskBlockObjectWriter( (File) args[1], + blockManager.serializerManager(), (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] @@ -201,9 +192,10 @@ public class UnsafeShuffleWriterSuite { for (int i = 0; i < NUM_PARTITITONS; i++) { final long partitionSize = partitionSizesInMergedFile[i]; if (partitionSize > 0) { - InputStream in = new FileInputStream(mergedOutputFile); - ByteStreams.skipFully(in, startOffset); - in = new LimitedInputStream(in, partitionSize); + FileInputStream fin = new FileInputStream(mergedOutputFile); + fin.getChannel().position(startOffset); + InputStream in = new LimitedInputStream(fin, partitionSize); + in = blockManager.serializerManager().wrapForEncryption(in); if (conf.getBoolean("spark.shuffle.compress", true)) { in = CompressionCodec$.MODULE$.createCodec(conf).compressedInputStream(in); } @@ -294,14 +286,32 @@ public class UnsafeShuffleWriterSuite { } private void testMergingSpills( - boolean transferToEnabled, - String compressionCodecName) throws IOException { + final boolean transferToEnabled, + String compressionCodecName, + boolean encrypt) throws Exception { if (compressionCodecName != null) { conf.set("spark.shuffle.compress", "true"); conf.set("spark.io.compression.codec", compressionCodecName); } else { conf.set("spark.shuffle.compress", "false"); } + conf.set(org.apache.spark.internal.config.package$.MODULE$.IO_ENCRYPTION_ENABLED(), encrypt); + + SerializerManager manager; + if (encrypt) { + manager = new SerializerManager(serializer, conf, + Option.apply(CryptoStreamUtils.createKey(conf))); + } else { + manager = new SerializerManager(serializer, conf); + } + + when(blockManager.serializerManager()).thenReturn(manager); + testMergingSpills(transferToEnabled, encrypt); + } + + private void testMergingSpills( + boolean transferToEnabled, + boolean encrypted) throws IOException { final UnsafeShuffleWriter<Object, Object> writer = createWriter(transferToEnabled); final ArrayList<Product2<Object, Object>> dataToWrite = new ArrayList<>(); for (int i : new int[] { 1, 2, 3, 4, 4, 2 }) { @@ -324,6 +334,7 @@ public class UnsafeShuffleWriterSuite { for (long size: partitionSizesInMergedFile) { sumOfPartitionSizes += size; } + assertEquals(sumOfPartitionSizes, mergedOutputFile.length()); assertEquals(HashMultiset.create(dataToWrite), HashMultiset.create(readRecordsFromFile())); @@ -338,42 +349,72 @@ public class UnsafeShuffleWriterSuite { @Test public void mergeSpillsWithTransferToAndLZF() throws Exception { - testMergingSpills(true, LZFCompressionCodec.class.getName()); + testMergingSpills(true, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndLZF() throws Exception { - testMergingSpills(false, LZFCompressionCodec.class.getName()); + testMergingSpills(false, LZFCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndLZ4() throws Exception { - testMergingSpills(true, LZ4CompressionCodec.class.getName()); + testMergingSpills(true, 
LZ4CompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndLZ4() throws Exception { - testMergingSpills(false, LZ4CompressionCodec.class.getName()); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndSnappy() throws Exception { - testMergingSpills(true, SnappyCompressionCodec.class.getName()); + testMergingSpills(true, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithFileStreamAndSnappy() throws Exception { - testMergingSpills(false, SnappyCompressionCodec.class.getName()); + testMergingSpills(false, SnappyCompressionCodec.class.getName(), false); } @Test public void mergeSpillsWithTransferToAndNoCompression() throws Exception { - testMergingSpills(true, null); + testMergingSpills(true, null, false); } @Test public void mergeSpillsWithFileStreamAndNoCompression() throws Exception { - testMergingSpills(false, null); + testMergingSpills(false, null, false); + } + + @Test + public void mergeSpillsWithCompressionAndEncryption() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. + testMergingSpills(true, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithFileStreamAndCompressionAndEncryption() throws Exception { + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithCompressionAndEncryptionSlowPath() throws Exception { + conf.set("spark.shuffle.unsafe.fastMergeEnabled", "false"); + testMergingSpills(false, LZ4CompressionCodec.class.getName(), true); + } + + @Test + public void mergeSpillsWithEncryptionAndNoCompression() throws Exception { + // This should actually be translated to a "file stream merge" internally, just have the + // test to make sure that it's the case. 
+ testMergingSpills(true, null, true); + } + + @Test + public void mergeSpillsWithFileStreamAndEncryptionAndNoCompression() throws Exception { + testMergingSpills(false, null, true); } @Test @@ -531,4 +572,5 @@ public class UnsafeShuffleWriterSuite { writer.stop(false); } } + } http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java index 33709b4..2656814 100644 --- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java +++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java @@ -19,13 +19,11 @@ package org.apache.spark.unsafe.map; import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.nio.ByteBuffer; import java.util.*; import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Assert; @@ -75,13 +73,6 @@ public abstract class AbstractBytesToBytesMapSuite { @Mock(answer = RETURNS_SMART_NULLS) BlockManager blockManager; @Mock(answer = RETURNS_SMART_NULLS) DiskBlockManager diskBlockManager; - private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } - @Before public void setup() { memoryManager = @@ -120,9 +111,9 @@ public abstract class AbstractBytesToBytesMapSuite { return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java ---------------------------------------------------------------------- diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java index a9cf8ff..fbbe530 100644 --- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java +++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java @@ -19,14 +19,12 @@ package org.apache.spark.util.collection.unsafe.sort; import java.io.File; import java.io.IOException; -import java.io.OutputStream; import java.util.Arrays; import java.util.LinkedList; import java.util.UUID; import scala.Tuple2; import scala.Tuple2$; -import scala.runtime.AbstractFunction1; import org.junit.After; import org.junit.Before; @@ -57,13 +55,15 @@ import static org.mockito.Mockito.*; public class UnsafeExternalSorterSuite { + private final SparkConf conf = new SparkConf(); + final LinkedList<File> spillFilesCreated = new LinkedList<>(); final TestMemoryManager memoryManager = - new TestMemoryManager(new SparkConf().set("spark.memory.offHeap.enabled", "false")); + new TestMemoryManager(conf.clone().set("spark.memory.offHeap.enabled", "false")); final TaskMemoryManager taskMemoryManager = new TaskMemoryManager(memoryManager, 0); final SerializerManager serializerManager = new SerializerManager( - new JavaSerializer(new SparkConf()), - new 
SparkConf().set("spark.shuffle.spill.compress", "false")); + new JavaSerializer(conf), + conf.clone().set("spark.shuffle.spill.compress", "false")); // Use integer comparison for comparing prefixes (which are partition ids, in this case) final PrefixComparator prefixComparator = PrefixComparators.LONG; // Since the key fits within the 8-byte prefix, we don't need to do any record comparison, so @@ -86,14 +86,7 @@ public class UnsafeExternalSorterSuite { protected boolean shouldUseRadixSort() { return false; } - private final long pageSizeBytes = new SparkConf().getSizeAsBytes("spark.buffer.pageSize", "4m"); - - private static final class WrapStream extends AbstractFunction1<OutputStream, OutputStream> { - @Override - public OutputStream apply(OutputStream stream) { - return stream; - } - } + private final long pageSizeBytes = conf.getSizeAsBytes("spark.buffer.pageSize", "4m"); @Before public void setUp() { @@ -126,9 +119,9 @@ public class UnsafeExternalSorterSuite { return new DiskBlockObjectWriter( (File) args[1], + serializerManager, (SerializerInstance) args[2], (Integer) args[3], - new WrapStream(), false, (ShuffleWriteMetrics) args[4], (BlockId) args[0] http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala index 4429416..85ccb33 100644 --- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala @@ -33,7 +33,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark._ import org.apache.spark.executor.{ShuffleWriteMetrics, TaskMetrics} -import org.apache.spark.serializer.{JavaSerializer, SerializerInstance} +import org.apache.spark.serializer.{JavaSerializer, SerializerInstance, SerializerManager} import org.apache.spark.shuffle.IndexShuffleBlockResolver import org.apache.spark.storage._ import org.apache.spark.util.Utils @@ -90,11 +90,12 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte )).thenAnswer(new Answer[DiskBlockObjectWriter] { override def answer(invocation: InvocationOnMock): DiskBlockObjectWriter = { val args = invocation.getArguments + val manager = new SerializerManager(new JavaSerializer(conf), conf) new DiskBlockObjectWriter( args(1).asInstanceOf[File], + manager, args(2).asInstanceOf[SerializerInstance], args(3).asInstanceOf[Int], - wrapStream = identity, syncWrites = false, args(4).asInstanceOf[ShuffleWriteMetrics], blockId = args(0).asInstanceOf[BlockId] http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index 684e978..bfb3ac4 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -22,7 +22,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.{SparkConf, SparkFunSuite} import 
org.apache.spark.executor.ShuffleWriteMetrics -import org.apache.spark.serializer.JavaSerializer +import org.apache.spark.serializer.{JavaSerializer, SerializerManager} import org.apache.spark.util.Utils class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { @@ -42,11 +42,19 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } } - test("verify write metrics") { + private def createWriter(): (DiskBlockObjectWriter, File, ShuffleWriteMetrics) = { val file = new File(tempDir, "somefile") + val conf = new SparkConf() + val serializerManager = new SerializerManager(new JavaSerializer(conf), conf) val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + file, serializerManager, new JavaSerializer(new SparkConf()).newInstance(), 1024, true, + writeMetrics) + (writer, file, writeMetrics) + } + + test("verify write metrics") { + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -66,10 +74,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("verify write metrics on revert") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) // Record metrics update on every write @@ -89,10 +94,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("Reopening a closed block writer") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() writer.open() writer.close() @@ -102,10 +104,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -120,10 +119,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() after commit() should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() writer.write(Long.box(20), Long.box(30)) val firstSegment = writer.commitAndGet() @@ -136,10 +132,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new 
JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -153,10 +146,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -173,10 +163,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("revertPartialWritesAndClose() should be idempotent") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, file, writeMetrics) = createWriter() for (i <- 1 to 1000) { writer.write(i, i) } @@ -191,10 +178,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } test("commit() and close() without ever opening or writing") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + val (writer, _, _) = createWriter() val segment = writer.commitAndGet() writer.close() assert(segment.length === 0) http://git-wip-us.apache.org/repos/asf/spark/blob/c2c2fdcb/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala index 5141e36..7f08382 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalAppendOnlyMapSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.util.collection import scala.collection.mutable.ArrayBuffer import org.apache.spark._ +import org.apache.spark.internal.config._ import org.apache.spark.io.CompressionCodec import org.apache.spark.memory.MemoryTestingUtils @@ -230,14 +231,19 @@ class ExternalAppendOnlyMapSuite extends SparkFunSuite with LocalSparkContext { } } + test("spilling with compression and encryption") { + testSimpleSpilling(Some(CompressionCodec.DEFAULT_COMPRESSION_CODEC), encrypt = true) + } + /** * Test spilling through simple aggregations and cogroups. * If a compression codec is provided, use it. Otherwise, do not compress spills. 
*/ - private def testSimpleSpilling(codec: Option[String] = None): Unit = { + private def testSimpleSpilling(codec: Option[String] = None, encrypt: Boolean = false): Unit = { val size = 1000 val conf = createSparkConf(loadDefaults = true, codec) // Load defaults for Spark home conf.set("spark.shuffle.spill.numElementsForceSpillThreshold", (size / 4).toString) + conf.set(IO_ENCRYPTION_ENABLED, encrypt) sc = new SparkContext("local-cluster[1,1,1024]", "test", conf) assertSpilled(sc, "reduceByKey") {
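----------------------------------------------------------------------

For readers skimming the diff, below is a minimal, self-contained sketch of the merge strategy this patch adopts: decrypt each spill's partition chunk while reading it, and re-encrypt the merged partition through a single wrapper, so the merged file carries exactly one IV per partition. Everything here is illustrative, not Spark's actual code: the class name EncryptedSpillMergeSketch, the encrypting/decrypting helpers, and the explicit key/IV parameters with a fixed AES/CTR setup are stand-ins. The real code goes through SerializerManager.wrapForEncryption and CryptoStreamUtils, which generate and embed a fresh IV at the start of each encrypted stream, and it also tracks write metrics and handles errors more carefully than shown.

// Minimal sketch only; assumed names, not Spark's implementation.
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.GeneralSecurityException;

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.CipherOutputStream;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;

import com.google.common.io.ByteStreams;
import org.apache.commons.io.output.CloseShieldOutputStream;
import org.apache.commons.io.output.CountingOutputStream;

public class EncryptedSpillMergeSketch {

  // Stand-in for SerializerManager.wrapForEncryption(OutputStream).
  static OutputStream encrypting(OutputStream out, byte[] key, byte[] iv)
      throws GeneralSecurityException {
    Cipher c = Cipher.getInstance("AES/CTR/NoPadding");
    c.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    return new CipherOutputStream(out, c);
  }

  // Stand-in for SerializerManager.wrapForEncryption(InputStream).
  static InputStream decrypting(InputStream in, byte[] key, byte[] iv)
      throws GeneralSecurityException {
    Cipher c = Cipher.getInstance("AES/CTR/NoPadding");
    c.init(Cipher.DECRYPT_MODE, new SecretKeySpec(key, "AES"), new IvParameterSpec(iv));
    return new CipherInputStream(in, c);
  }

  /**
   * Merges spill files laid out partition-by-partition. partitionLengths[i][p]
   * is the size of partition p's encrypted chunk in spill i. Returns the size
   * of each merged (re-encrypted) partition in the output file.
   */
  static long[] mergeSpills(File[] spills, long[][] partitionLengths, int numPartitions,
      File output, byte[] key, byte[] iv) throws IOException, GeneralSecurityException {
    long[] merged = new long[numPartitions];
    // Counting wrapper: partition sizes come from the byte count, so the file
    // never has to be closed and re-stat'ed between partitions.
    try (CountingOutputStream counted = new CountingOutputStream(new FileOutputStream(output))) {
      InputStream[] spillIn = new InputStream[spills.length];
      for (int i = 0; i < spills.length; i++) {
        spillIn[i] = new FileInputStream(spills[i]);
      }
      try {
        for (int p = 0; p < numPartitions; p++) {
          long start = counted.getByteCount();
          // One encrypting wrapper per merged partition, hence one IV per partition.
          // CloseShield keeps partOut.close() from closing the underlying file.
          OutputStream partOut = encrypting(new CloseShieldOutputStream(counted), key, iv);
          for (int i = 0; i < spills.length; i++) {
            long len = partitionLengths[i][p];
            if (len > 0) {
              // Read exactly this chunk, decrypting it back to plaintext before it
              // is re-encrypted by partOut. (Left unclosed on purpose: closing it
              // would close the spill file; the real code limits reads similarly
              // with LimitedInputStream.)
              InputStream chunk = decrypting(ByteStreams.limit(spillIn[i], len), key, iv);
              ByteStreams.copy(chunk, partOut);
            }
          }
          partOut.close();  // flushes the cipher's final bytes; the file stays open
          merged[p] = counted.getByteCount() - start;
        }
      } finally {
        for (InputStream in : spillIn) {
          in.close();
        }
      }
    }
    return merged;
  }
}

The sketch also shows why the transferTo fast path is disabled when encryption is on: transferTo copies raw, still-encrypted bytes between file channels, which would simply concatenate chunks that were each encrypted with their own IV, leaving the merged partition unreadable as a single stream.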