Repository: spark Updated Branches: refs/heads/master 1221ce040 -> daace6014
[SPARK-5581][CORE] When writing sorted map output file, avoid open / close between each partition ## What changes were proposed in this pull request? Replace commitAndClose with separate commit and close to avoid opening and closing the file between partitions. ## How was this patch tested? Run existing unit tests, add a few unit tests regarding reverts. Observed a ~20% reduction in total time in tasks on stages with shuffle writes to many partitions. JoshRosen Author: Brian Cho <b...@fb.com> Closes #13382 from dafrista/separatecommit-master. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/daace601 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/daace601 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/daace601 Branch: refs/heads/master Commit: daace6014216b996bcc8937f1fdcea732b6910ca Parents: 1221ce0 Author: Brian Cho <b...@fb.com> Authored: Sun Jul 24 19:36:58 2016 -0700 Committer: Josh Rosen <joshro...@databricks.com> Committed: Sun Jul 24 19:36:58 2016 -0700 ---------------------------------------------------------------------- .../sort/BypassMergeSortShuffleWriter.java | 10 +- .../shuffle/sort/ShuffleExternalSorter.java | 31 ++-- .../unsafe/sort/UnsafeSorterSpillWriter.java | 3 +- .../spark/storage/DiskBlockObjectWriter.scala | 157 ++++++++++++------- .../util/collection/ExternalAppendOnlyMap.scala | 28 ++-- .../spark/util/collection/ExternalSorter.scala | 52 +++--- .../storage/DiskBlockObjectWriterSuite.scala | 67 +++++--- 7 files changed, 192 insertions(+), 156 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java 
b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java index 0e9defe..83dc61c 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java @@ -88,6 +88,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { /** Array of file writers, one for each partition */ private DiskBlockObjectWriter[] partitionWriters; + private FileSegment[] partitionWriterSegments; @Nullable private MapStatus mapStatus; private long[] partitionLengths; @@ -131,6 +132,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { final SerializerInstance serInstance = serializer.newInstance(); final long openStartTime = System.nanoTime(); partitionWriters = new DiskBlockObjectWriter[numPartitions]; + partitionWriterSegments = new FileSegment[numPartitions]; for (int i = 0; i < numPartitions; i++) { final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile = blockManager.diskBlockManager().createTempShuffleBlock(); @@ -150,8 +152,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { partitionWriters[partitioner.getPartition(key)].write(key, record._2()); } - for (DiskBlockObjectWriter writer : partitionWriters) { - writer.commitAndClose(); + for (int i = 0; i < numPartitions; i++) { + final DiskBlockObjectWriter writer = partitionWriters[i]; + partitionWriterSegments[i] = writer.commitAndGet(); + writer.close(); } File output = shuffleBlockResolver.getDataFile(shuffleId, mapId); @@ -184,7 +188,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> { boolean threwException = true; try { for (int i = 0; i < numPartitions; i++) { - final File file = partitionWriters[i].fileSegment().file(); + final File file = partitionWriterSegments[i].file(); if (file.exists()) { final FileInputStream in = new FileInputStream(file); boolean 
copyThrewException = true; http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java index cf38a04..cfec724 100644 --- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java +++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java @@ -37,6 +37,7 @@ import org.apache.spark.serializer.DummySerializerInstance; import org.apache.spark.serializer.SerializerInstance; import org.apache.spark.storage.BlockManager; import org.apache.spark.storage.DiskBlockObjectWriter; +import org.apache.spark.storage.FileSegment; import org.apache.spark.storage.TempShuffleBlockId; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.array.LongArray; @@ -150,10 +151,6 @@ final class ShuffleExternalSorter extends MemoryConsumer { final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords = inMemSorter.getSortedIterator(); - // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this - // after SPARK-5581 is fixed. - DiskBlockObjectWriter writer; - // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to // be an API to directly transfer bytes from managed memory to the disk writer, we buffer // data through a byte array. This array does not need to be large enough to hold a single @@ -175,7 +172,8 @@ final class ShuffleExternalSorter extends MemoryConsumer { // around this, we pass a dummy no-op serializer. 
final SerializerInstance ser = DummySerializerInstance.INSTANCE; - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); + final DiskBlockObjectWriter writer = + blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); int currentPartition = -1; while (sortedRecords.hasNext()) { @@ -185,12 +183,10 @@ final class ShuffleExternalSorter extends MemoryConsumer { if (partition != currentPartition) { // Switch to the new partition if (currentPartition != -1) { - writer.commitAndClose(); - spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); + final FileSegment fileSegment = writer.commitAndGet(); + spillInfo.partitionLengths[currentPartition] = fileSegment.length(); } currentPartition = partition; - writer = - blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse); } final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer(); @@ -209,15 +205,14 @@ final class ShuffleExternalSorter extends MemoryConsumer { writer.recordWritten(); } - if (writer != null) { - writer.commitAndClose(); - // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, - // then the file might be empty. Note that it might be better to avoid calling - // writeSortedFile() in that case. - if (currentPartition != -1) { - spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length(); - spills.add(spillInfo); - } + final FileSegment committedSegment = writer.commitAndGet(); + writer.close(); + // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted, + // then the file might be empty. Note that it might be better to avoid calling + // writeSortedFile() in that case. + if (currentPartition != -1) { + spillInfo.partitionLengths[currentPartition] = committedSegment.length(); + spills.add(spillInfo); } if (!isLastFile) { // i.e. 
this is a spill file http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java ---------------------------------------------------------------------- diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java index 9ba760e..164b9d7 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java @@ -136,7 +136,8 @@ public final class UnsafeSorterSpillWriter { } public void close() throws IOException { - writer.commitAndClose(); + writer.commitAndGet(); + writer.close(); writer = null; writeBuffer = null; } http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala index 5b493f4..e5b1bf2 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala @@ -27,8 +27,10 @@ import org.apache.spark.util.Utils /** * A class for writing JVM objects directly to a file on disk. This class allows data to be appended - * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to - * revert partial writes. + * to an existing block. For efficiency, it retains the underlying file channel across + * multiple commits. This channel is kept open until close() is called. 
In case of faults, + * callers should instead close with revertPartialWritesAndClose() to atomically revert the + * uncommitted partial writes. * * This class does not support concurrent writes. Also, once the writer has been opened it cannot be * reopened again. @@ -46,34 +48,49 @@ private[spark] class DiskBlockObjectWriter( extends OutputStream with Logging { + /** + * Guards against close calls, e.g. from a wrapping stream. + * Call manualClose to close the stream that was extended by this trait. + * Commit uses this trait to close object streams without paying the + * cost of closing and opening the underlying file. + */ + private trait ManualCloseOutputStream extends OutputStream { + abstract override def close(): Unit = { + flush() + } + + def manualClose(): Unit = { + super.close() + } + } + /** The file channel, used for repositioning / truncating the file. */ private var channel: FileChannel = null + private var mcs: ManualCloseOutputStream = null private var bs: OutputStream = null private var fos: FileOutputStream = null private var ts: TimeTrackingOutputStream = null private var objOut: SerializationStream = null private var initialized = false + private var streamOpen = false private var hasBeenClosed = false - private var commitAndCloseHasBeenCalled = false /** * Cursors used to represent positions in the file. * - * xxxxxxxx|--------|--- | - * ^ ^ ^ - * | | finalPosition - * | reportedPosition - * initialPosition + * xxxxxxxxxx|----------|-----| + * ^ ^ ^ + * | | channel.position() + * | reportedPosition + * committedPosition * - * initialPosition: Offset in the file where we start writing. Immutable. * reportedPosition: Position at the time of the last update to the write metrics. - * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed. + * committedPosition: Offset after last committed write. * -----: Current writes to the underlying file. - * xxxxx: Existing contents of the file. 
+ * xxxxx: Committed contents of the file. */ - private val initialPosition = file.length() - private var finalPosition: Long = -1 - private var reportedPosition = initialPosition + private var committedPosition = file.length() + private var reportedPosition = committedPosition /** * Keep track of number of records written and also use this to periodically @@ -81,67 +98,98 @@ private[spark] class DiskBlockObjectWriter( */ private var numRecordsWritten = 0 + private def initialize(): Unit = { + fos = new FileOutputStream(file, true) + channel = fos.getChannel() + ts = new TimeTrackingOutputStream(writeMetrics, fos) + class ManualCloseBufferedOutputStream + extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream + mcs = new ManualCloseBufferedOutputStream + } + def open(): DiskBlockObjectWriter = { if (hasBeenClosed) { throw new IllegalStateException("Writer already closed. Cannot be reopened.") } - fos = new FileOutputStream(file, true) - ts = new TimeTrackingOutputStream(writeMetrics, fos) - channel = fos.getChannel() - bs = compressStream(new BufferedOutputStream(ts, bufferSize)) + if (!initialized) { + initialize() + initialized = true + } + bs = compressStream(mcs) objOut = serializerInstance.serializeStream(bs) - initialized = true + streamOpen = true this } - override def close() { + /** + * Close and cleanup all resources. + * Should call after committing or reverting partial writes. 
+ */ + private def closeResources(): Unit = { if (initialized) { - Utils.tryWithSafeFinally { - if (syncWrites) { - // Force outstanding writes to disk and track how long it takes - objOut.flush() - val start = System.nanoTime() - fos.getFD.sync() - writeMetrics.incWriteTime(System.nanoTime() - start) - } - } { - objOut.close() - } - + mcs.manualClose() channel = null + mcs = null bs = null fos = null ts = null objOut = null initialized = false + streamOpen = false hasBeenClosed = true } } - def isOpen: Boolean = objOut != null + /** + * Commits any remaining partial writes and closes resources. + */ + override def close() { + if (initialized) { + Utils.tryWithSafeFinally { + commitAndGet() + } { + closeResources() + } + } + } /** * Flush the partial writes and commit them as a single atomic block. + * A commit may write additional bytes to frame the atomic block. + * + * @return file segment with previous offset and length committed on this call. */ - def commitAndClose(): Unit = { - if (initialized) { + def commitAndGet(): FileSegment = { + if (streamOpen) { // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the // serializer stream and the lower level stream. 
objOut.flush() bs.flush() - close() - finalPosition = file.length() - // In certain compression codecs, more bytes are written after close() is called - writeMetrics.incBytesWritten(finalPosition - reportedPosition) + objOut.close() + streamOpen = false + + if (syncWrites) { + // Force outstanding writes to disk and track how long it takes + val start = System.nanoTime() + fos.getFD.sync() + writeMetrics.incWriteTime(System.nanoTime() - start) + } + + val pos = channel.position() + val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition) + committedPosition = pos + // In certain compression codecs, more bytes are written after streams are closed + writeMetrics.incBytesWritten(committedPosition - reportedPosition) + reportedPosition = committedPosition + fileSegment } else { - finalPosition = file.length() + new FileSegment(file, committedPosition, 0) } - commitAndCloseHasBeenCalled = true } /** - * Reverts writes that haven't been flushed yet. Callers should invoke this function + * Reverts writes that haven't been committed yet. Callers should invoke this function * when there are runtime exceptions. This method will not throw, though it may be * unsuccessful in truncating written data. * @@ -152,16 +200,15 @@ private[spark] class DiskBlockObjectWriter( // truncating the file to its initial position. try { if (initialized) { - writeMetrics.decBytesWritten(reportedPosition - initialPosition) + writeMetrics.decBytesWritten(reportedPosition - committedPosition) writeMetrics.decRecordsWritten(numRecordsWritten) - objOut.flush() - bs.flush() - close() + streamOpen = false + closeResources() } val truncateStream = new FileOutputStream(file, true) try { - truncateStream.getChannel.truncate(initialPosition) + truncateStream.getChannel.truncate(committedPosition) file } finally { truncateStream.close() @@ -177,7 +224,7 @@ private[spark] class DiskBlockObjectWriter( * Writes a key-value pair. 
*/ def write(key: Any, value: Any) { - if (!initialized) { + if (!streamOpen) { open() } @@ -189,7 +236,7 @@ private[spark] class DiskBlockObjectWriter( override def write(b: Int): Unit = throw new UnsupportedOperationException() override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = { - if (!initialized) { + if (!streamOpen) { open() } @@ -209,18 +256,6 @@ private[spark] class DiskBlockObjectWriter( } /** - * Returns the file segment of committed data that this Writer has written. - * This is only valid after commitAndClose() has been called. - */ - def fileSegment(): FileSegment = { - if (!commitAndCloseHasBeenCalled) { - throw new IllegalStateException( - "fileSegment() is only valid after commitAndClose() has been called") - } - new FileSegment(file, initialPosition, finalPosition - initialPosition) - } - - /** * Report the number of bytes written in this writer's shuffle write metrics. * Note that this is only valid before the underlying streams are closed. */ http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala index 6ddc72a..8c8860b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala @@ -105,8 +105,8 @@ class ExternalAppendOnlyMap[K, V, C]( private val fileBufferSize = sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024 - // Write metrics for current spill - private var curWriteMetrics: ShuffleWriteMetrics = _ + // Write metrics + private val writeMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics() // Peak size of the in-memory map observed so far, in bytes private 
var _peakMemoryUsedBytes: Long = 0L @@ -206,8 +206,7 @@ class ExternalAppendOnlyMap[K, V, C]( private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)]) : DiskMapIterator = { val (blockId, file) = diskBlockManager.createTempLocalBlock() - curWriteMetrics = new ShuffleWriteMetrics() - var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) + val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics) var objectsWritten = 0 // List of batch sizes (bytes) in the order they are written to disk @@ -215,11 +214,9 @@ class ExternalAppendOnlyMap[K, V, C]( // Flush the disk writer's contents to disk, and update relevant variables def flush(): Unit = { - val w = writer - writer = null - w.commitAndClose() - _diskBytesSpilled += curWriteMetrics.bytesWritten - batchSizes.append(curWriteMetrics.bytesWritten) + val segment = writer.commitAndGet() + batchSizes.append(segment.length) + _diskBytesSpilled += segment.length objectsWritten = 0 } @@ -232,25 +229,20 @@ class ExternalAppendOnlyMap[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - curWriteMetrics = new ShuffleWriteMetrics() - writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics) } } if (objectsWritten > 0) { flush() - } else if (writer != null) { - val w = writer - writer = null - w.revertPartialWritesAndClose() + writer.close() + } else { + writer.revertPartialWritesAndClose() } success = true } finally { if (!success) { // This code path only happens if an exception was thrown above before we set success; // close our stuff and let the exception be thrown further - if (writer != null) { - writer.revertPartialWritesAndClose() - } + writer.revertPartialWritesAndClose() if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting ${file}") http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala 
---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala index 4067ace..708a007 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala @@ -272,14 +272,9 @@ private[spark] class ExternalSorter[K, V, C]( // These variables are reset after each flush var objectsWritten: Long = 0 - var spillMetrics: ShuffleWriteMetrics = null - var writer: DiskBlockObjectWriter = null - def openWriter(): Unit = { - assert (writer == null && spillMetrics == null) - spillMetrics = new ShuffleWriteMetrics - writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) - } - openWriter() + val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics + val writer: DiskBlockObjectWriter = + blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics) // List of batch sizes (bytes) in the order they are written to disk val batchSizes = new ArrayBuffer[Long] @@ -288,14 +283,11 @@ private[spark] class ExternalSorter[K, V, C]( val elementsPerPartition = new Array[Long](numPartitions) // Flush the disk writer's contents to disk, and update relevant variables. - // The writer is closed at the end of this process, and cannot be reused. + // The writer is committed at the end of this process. 
def flush(): Unit = { - val w = writer - writer = null - w.commitAndClose() - _diskBytesSpilled += spillMetrics.bytesWritten - batchSizes.append(spillMetrics.bytesWritten) - spillMetrics = null + val segment = writer.commitAndGet() + batchSizes.append(segment.length) + _diskBytesSpilled += segment.length objectsWritten = 0 } @@ -311,24 +303,21 @@ private[spark] class ExternalSorter[K, V, C]( if (objectsWritten == serializerBatchSize) { flush() - openWriter() } } if (objectsWritten > 0) { flush() - } else if (writer != null) { - val w = writer - writer = null - w.revertPartialWritesAndClose() + } else { + writer.revertPartialWritesAndClose() } success = true } finally { - if (!success) { + if (success) { + writer.close() + } else { // This code path only happens if an exception was thrown above before we set success; // close our stuff and let the exception be thrown further - if (writer != null) { - writer.revertPartialWritesAndClose() - } + writer.revertPartialWritesAndClose() if (file.exists()) { if (!file.delete()) { logWarning(s"Error deleting ${file}") @@ -693,42 +682,37 @@ private[spark] class ExternalSorter[K, V, C]( blockId: BlockId, outputFile: File): Array[Long] = { - val writeMetrics = context.taskMetrics().shuffleWriteMetrics - // Track location of each range in the output file val lengths = new Array[Long](numPartitions) + val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize, + context.taskMetrics().shuffleWriteMetrics) if (spills.isEmpty) { // Case where we only have in-memory data val collection = if (aggregator.isDefined) map else buffer val it = collection.destructiveSortedWritablePartitionedIterator(comparator) while (it.hasNext) { - val writer = blockManager.getDiskWriter( - blockId, outputFile, serInstance, fileBufferSize, writeMetrics) val partitionId = it.nextPartition() while (it.hasNext && it.nextPartition() == partitionId) { it.writeNext(writer) } - writer.commitAndClose() - val segment = 
writer.fileSegment() + val segment = writer.commitAndGet() lengths(partitionId) = segment.length } } else { // We must perform merge-sort; get an iterator by partition and write everything directly. for ((id, elements) <- this.partitionedIterator) { if (elements.hasNext) { - val writer = blockManager.getDiskWriter( - blockId, outputFile, serInstance, fileBufferSize, writeMetrics) for (elem <- elements) { writer.write(elem._1, elem._2) } - writer.commitAndClose() - val segment = writer.fileSegment() + val segment = writer.commitAndGet() lengths(id) = segment.length } } } + writer.close() context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled) context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled) context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes) http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala ---------------------------------------------------------------------- diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala index ec4ef4b..059c2c2 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala @@ -60,7 +60,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } assert(writeMetrics.bytesWritten > 0) assert(writeMetrics.recordsWritten === 16385) - writer.commitAndClose() + writer.commitAndGet() + writer.close() assert(file.length() == writeMetrics.bytesWritten) } @@ -100,6 +101,40 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { } } + test("calling revertPartialWritesAndClose() on a partial write should truncate up to commit") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new 
DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20), Long.box(30)) + val firstSegment = writer.commitAndGet() + assert(firstSegment.length === file.length()) + assert(writeMetrics.shuffleBytesWritten === file.length()) + + writer.write(Long.box(40), Long.box(50)) + + writer.revertPartialWritesAndClose() + assert(firstSegment.length === file.length()) + assert(writeMetrics.shuffleBytesWritten === file.length()) + } + + test("calling revertPartialWritesAndClose() after commit() should have no effect") { + val file = new File(tempDir, "somefile") + val writeMetrics = new ShuffleWriteMetrics() + val writer = new DiskBlockObjectWriter( + file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) + + writer.write(Long.box(20), Long.box(30)) + val firstSegment = writer.commitAndGet() + assert(firstSegment.length === file.length()) + assert(writeMetrics.shuffleBytesWritten === file.length()) + + writer.revertPartialWritesAndClose() + assert(firstSegment.length === file.length()) + assert(writeMetrics.shuffleBytesWritten === file.length()) + } + test("calling revertPartialWritesAndClose() on a closed block writer should have no effect") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() @@ -108,7 +143,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { for (i <- 1 to 1000) { writer.write(i, i) } - writer.commitAndClose() + writer.commitAndGet() + writer.close() val bytesWritten = writeMetrics.bytesWritten assert(writeMetrics.recordsWritten === 1000) writer.revertPartialWritesAndClose() @@ -116,7 +152,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { assert(writeMetrics.bytesWritten === bytesWritten) } - test("commitAndClose() should be idempotent") { + test("commit() and close() should be idempotent") { val file = new File(tempDir, 
"somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( @@ -124,11 +160,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { for (i <- 1 to 1000) { writer.write(i, i) } - writer.commitAndClose() + writer.commitAndGet() + writer.close() val bytesWritten = writeMetrics.bytesWritten val writeTime = writeMetrics.writeTime assert(writeMetrics.recordsWritten === 1000) - writer.commitAndClose() + writer.commitAndGet() + writer.close() assert(writeMetrics.recordsWritten === 1000) assert(writeMetrics.bytesWritten === bytesWritten) assert(writeMetrics.writeTime === writeTime) @@ -152,26 +190,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach { assert(writeMetrics.writeTime === writeTime) } - test("fileSegment() can only be called after commitAndClose() has been called") { + test("commit() and close() without ever opening or writing") { val file = new File(tempDir, "somefile") val writeMetrics = new ShuffleWriteMetrics() val writer = new DiskBlockObjectWriter( file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - for (i <- 1 to 1000) { - writer.write(i, i) - } - intercept[IllegalStateException] { - writer.fileSegment() - } + val segment = writer.commitAndGet() writer.close() - } - - test("commitAndClose() without ever opening or writing") { - val file = new File(tempDir, "somefile") - val writeMetrics = new ShuffleWriteMetrics() - val writer = new DiskBlockObjectWriter( - file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics) - writer.commitAndClose() - assert(writer.fileSegment().length === 0) + assert(segment.length === 0) } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org