Repository: spark
Updated Branches:
  refs/heads/master 1221ce040 -> daace6014


[SPARK-5581][CORE] When writing sorted map output file, avoid open / close between each partition

## What changes were proposed in this pull request?

Replace `commitAndClose()` with separate `commitAndGet()` and `close()` calls, so the file no
longer has to be opened and closed between partitions; see the sketch below.
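
For illustration, a minimal sketch of the resulting single-writer flow (modeled on the
`ExternalSorter.writePartitionedFile` change in this patch); `blockManager`, `blockId`,
`outputFile`, `serInstance`, `fileBufferSize`, `writeMetrics`, `numPartitions`, and
`iteratorsByPartition` are placeholders for the caller's existing state, not names from the patch:

```scala
// Sketch only: one writer for the whole output file, committed once per partition.
val writer = blockManager.getDiskWriter(
  blockId, outputFile, serInstance, fileBufferSize, writeMetrics)

val lengths = new Array[Long](numPartitions)
for ((partitionId, elements) <- iteratorsByPartition) {
  elements.foreach { case (k, v) => writer.write(k, v) }
  // Commit this partition's bytes; the underlying file channel stays open,
  // so the next partition does not pay for a close() and re-open().
  val segment = writer.commitAndGet()
  lengths(partitionId) = segment.length
}
// Close the writer, and with it the file, exactly once at the end.
writer.close()
```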

## How was this patch tested?

Ran existing unit tests and added a few unit tests covering reverts (sketched below).
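
A condensed sketch of the revert behaviour the new tests exercise (uncommitted writes are truncated
back to the last commit, not to the start of the file); `file` and `writeMetrics` are assumed to be
set up as in `DiskBlockObjectWriterSuite`:

```scala
// Sketch only, mirroring the new "truncate up to commit" test.
val writer = new DiskBlockObjectWriter(
  file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)

writer.write(Long.box(20), Long.box(30))
val firstSegment = writer.commitAndGet()   // commit the first record

writer.write(Long.box(40), Long.box(50))   // further writes, never committed
writer.revertPartialWritesAndClose()       // truncates back to the committed position

assert(firstSegment.length == file.length())
```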

Observed a ~20% reduction in total task time on stages with shuffle writes to many partitions.

JoshRosen

Author: Brian Cho <b...@fb.com>

Closes #13382 from dafrista/separatecommit-master.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/daace601
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/daace601
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/daace601

Branch: refs/heads/master
Commit: daace6014216b996bcc8937f1fdcea732b6910ca
Parents: 1221ce0
Author: Brian Cho <b...@fb.com>
Authored: Sun Jul 24 19:36:58 2016 -0700
Committer: Josh Rosen <joshro...@databricks.com>
Committed: Sun Jul 24 19:36:58 2016 -0700

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java      |  10 +-
 .../shuffle/sort/ShuffleExternalSorter.java     |  31 ++--
 .../unsafe/sort/UnsafeSorterSpillWriter.java    |   3 +-
 .../spark/storage/DiskBlockObjectWriter.scala   | 157 ++++++++++++-------
 .../util/collection/ExternalAppendOnlyMap.scala |  28 ++--
 .../spark/util/collection/ExternalSorter.scala  |  52 +++---
 .../storage/DiskBlockObjectWriterSuite.scala    |  67 +++++---
 7 files changed, 192 insertions(+), 156 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index 0e9defe..83dc61c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -88,6 +88,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
 
   /** Array of file writers, one for each partition */
   private DiskBlockObjectWriter[] partitionWriters;
+  private FileSegment[] partitionWriterSegments;
   @Nullable private MapStatus mapStatus;
   private long[] partitionLengths;
 
@@ -131,6 +132,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     final SerializerInstance serInstance = serializer.newInstance();
     final long openStartTime = System.nanoTime();
     partitionWriters = new DiskBlockObjectWriter[numPartitions];
+    partitionWriterSegments = new FileSegment[numPartitions];
     for (int i = 0; i < numPartitions; i++) {
       final Tuple2<TempShuffleBlockId, File> tempShuffleBlockIdPlusFile =
         blockManager.diskBlockManager().createTempShuffleBlock();
@@ -150,8 +152,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       partitionWriters[partitioner.getPartition(key)].write(key, record._2());
     }
 
-    for (DiskBlockObjectWriter writer : partitionWriters) {
-      writer.commitAndClose();
+    for (int i = 0; i < numPartitions; i++) {
+      final DiskBlockObjectWriter writer = partitionWriters[i];
+      partitionWriterSegments[i] = writer.commitAndGet();
+      writer.close();
     }
 
     File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
@@ -184,7 +188,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     boolean threwException = true;
     try {
       for (int i = 0; i < numPartitions; i++) {
-        final File file = partitionWriters[i].fileSegment().file();
+        final File file = partitionWriterSegments[i].file();
         if (file.exists()) {
           final FileInputStream in = new FileInputStream(file);
           boolean copyThrewException = true;

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index cf38a04..cfec724 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -37,6 +37,7 @@ import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
+import org.apache.spark.storage.FileSegment;
 import org.apache.spark.storage.TempShuffleBlockId;
 import org.apache.spark.unsafe.Platform;
 import org.apache.spark.unsafe.array.LongArray;
@@ -150,10 +151,6 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     final ShuffleInMemorySorter.ShuffleSorterIterator sortedRecords =
       inMemSorter.getSortedIterator();
 
-    // Currently, we need to open a new DiskBlockObjectWriter for each partition; we can avoid this
-    // after SPARK-5581 is fixed.
-    DiskBlockObjectWriter writer;
-
     // Small writes to DiskBlockObjectWriter will be fairly inefficient. Since there doesn't seem to
     // be an API to directly transfer bytes from managed memory to the disk writer, we buffer
     // data through a byte array. This array does not need to be large enough to hold a single
@@ -175,7 +172,8 @@ final class ShuffleExternalSorter extends MemoryConsumer {
     // around this, we pass a dummy no-op serializer.
     final SerializerInstance ser = DummySerializerInstance.INSTANCE;
 
-    writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
+    final DiskBlockObjectWriter writer =
+      blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
 
     int currentPartition = -1;
     while (sortedRecords.hasNext()) {
@@ -185,12 +183,10 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       if (partition != currentPartition) {
         // Switch to the new partition
         if (currentPartition != -1) {
-          writer.commitAndClose();
-          spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
+          final FileSegment fileSegment = writer.commitAndGet();
+          spillInfo.partitionLengths[currentPartition] = fileSegment.length();
         }
         currentPartition = partition;
-        writer =
-          blockManager.getDiskWriter(blockId, file, ser, fileBufferSizeBytes, writeMetricsToUse);
       }
 
      final long recordPointer = sortedRecords.packedRecordPointer.getRecordPointer();
@@ -209,15 +205,14 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       writer.recordWritten();
     }
 
-    if (writer != null) {
-      writer.commitAndClose();
-      // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
-      // then the file might be empty. Note that it might be better to avoid calling
-      // writeSortedFile() in that case.
-      if (currentPartition != -1) {
-        spillInfo.partitionLengths[currentPartition] = writer.fileSegment().length();
-        spills.add(spillInfo);
-      }
+    final FileSegment committedSegment = writer.commitAndGet();
+    writer.close();
+    // If `writeSortedFile()` was called from `closeAndGetSpills()` and no records were inserted,
+    // then the file might be empty. Note that it might be better to avoid calling
+    // writeSortedFile() in that case.
+    if (currentPartition != -1) {
+      spillInfo.partitionLengths[currentPartition] = committedSegment.length();
+      spills.add(spillInfo);
     }
 
     if (!isLastFile) {  // i.e. this is a spill file

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
index 9ba760e..164b9d7 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillWriter.java
@@ -136,7 +136,8 @@ public final class UnsafeSorterSpillWriter {
   }
 
   public void close() throws IOException {
-    writer.commitAndClose();
+    writer.commitAndGet();
+    writer.close();
     writer = null;
     writeBuffer = null;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 5b493f4..e5b1bf2 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -27,8 +27,10 @@ import org.apache.spark.util.Utils
 
 /**
 * A class for writing JVM objects directly to a file on disk. This class allows data to be appended
- * to an existing block and can guarantee atomicity in the case of faults as it allows the caller to
- * revert partial writes.
+ * to an existing block. For efficiency, it retains the underlying file channel across
+ * multiple commits. This channel is kept open until close() is called. In case of faults,
+ * callers should instead close with revertPartialWritesAndClose() to atomically revert the
+ * uncommitted partial writes.
  *
 * This class does not support concurrent writes. Also, once the writer has been opened it cannot be
  * reopened again.
@@ -46,34 +48,49 @@ private[spark] class DiskBlockObjectWriter(
   extends OutputStream
   with Logging {
 
+  /**
+   * Guards against close calls, e.g. from a wrapping stream.
+   * Call manualClose to close the stream that was extended by this trait.
+   * Commit uses this trait to close object streams without paying the
+   * cost of closing and opening the underlying file.
+   */
+  private trait ManualCloseOutputStream extends OutputStream {
+    abstract override def close(): Unit = {
+      flush()
+    }
+
+    def manualClose(): Unit = {
+      super.close()
+    }
+  }
+
   /** The file channel, used for repositioning / truncating the file. */
   private var channel: FileChannel = null
+  private var mcs: ManualCloseOutputStream = null
   private var bs: OutputStream = null
   private var fos: FileOutputStream = null
   private var ts: TimeTrackingOutputStream = null
   private var objOut: SerializationStream = null
   private var initialized = false
+  private var streamOpen = false
   private var hasBeenClosed = false
-  private var commitAndCloseHasBeenCalled = false
 
   /**
    * Cursors used to represent positions in the file.
    *
-   * xxxxxxxx|--------|---       |
-   *         ^        ^          ^
-   *         |        |        finalPosition
-   *         |      reportedPosition
-   *       initialPosition
+   * xxxxxxxxxx|----------|-----|
+   *           ^          ^     ^
+   *           |          |    channel.position()
+   *           |        reportedPosition
+   *         committedPosition
    *
-   * initialPosition: Offset in the file where we start writing. Immutable.
   * reportedPosition: Position at the time of the last update to the write metrics.
-   * finalPosition: Offset where we stopped writing. Set on closeAndCommit() then never changed.
+   * committedPosition: Offset after last committed write.
    * -----: Current writes to the underlying file.
-   * xxxxx: Existing contents of the file.
+   * xxxxx: Committed contents of the file.
    */
-  private val initialPosition = file.length()
-  private var finalPosition: Long = -1
-  private var reportedPosition = initialPosition
+  private var committedPosition = file.length()
+  private var reportedPosition = committedPosition
 
   /**
    * Keep track of number of records written and also use this to periodically
@@ -81,67 +98,98 @@ private[spark] class DiskBlockObjectWriter(
    */
   private var numRecordsWritten = 0
 
+  private def initialize(): Unit = {
+    fos = new FileOutputStream(file, true)
+    channel = fos.getChannel()
+    ts = new TimeTrackingOutputStream(writeMetrics, fos)
+    class ManualCloseBufferedOutputStream
+      extends BufferedOutputStream(ts, bufferSize) with ManualCloseOutputStream
+    mcs = new ManualCloseBufferedOutputStream
+  }
+
   def open(): DiskBlockObjectWriter = {
     if (hasBeenClosed) {
      throw new IllegalStateException("Writer already closed. Cannot be reopened.")
     }
-    fos = new FileOutputStream(file, true)
-    ts = new TimeTrackingOutputStream(writeMetrics, fos)
-    channel = fos.getChannel()
-    bs = compressStream(new BufferedOutputStream(ts, bufferSize))
+    if (!initialized) {
+      initialize()
+      initialized = true
+    }
+    bs = compressStream(mcs)
     objOut = serializerInstance.serializeStream(bs)
-    initialized = true
+    streamOpen = true
     this
   }
 
-  override def close() {
+  /**
+   * Close and cleanup all resources.
+   * Should call after committing or reverting partial writes.
+   */
+  private def closeResources(): Unit = {
     if (initialized) {
-      Utils.tryWithSafeFinally {
-        if (syncWrites) {
-          // Force outstanding writes to disk and track how long it takes
-          objOut.flush()
-          val start = System.nanoTime()
-          fos.getFD.sync()
-          writeMetrics.incWriteTime(System.nanoTime() - start)
-        }
-      } {
-        objOut.close()
-      }
-
+      mcs.manualClose()
       channel = null
+      mcs = null
       bs = null
       fos = null
       ts = null
       objOut = null
       initialized = false
+      streamOpen = false
       hasBeenClosed = true
     }
   }
 
-  def isOpen: Boolean = objOut != null
+  /**
+   * Commits any remaining partial writes and closes resources.
+   */
+  override def close() {
+    if (initialized) {
+      Utils.tryWithSafeFinally {
+        commitAndGet()
+      } {
+        closeResources()
+      }
+    }
+  }
 
   /**
    * Flush the partial writes and commit them as a single atomic block.
+   * A commit may write additional bytes to frame the atomic block.
+   *
+   * @return file segment with previous offset and length committed on this call.
    */
-  def commitAndClose(): Unit = {
-    if (initialized) {
+  def commitAndGet(): FileSegment = {
+    if (streamOpen) {
      // NOTE: Because Kryo doesn't flush the underlying stream we explicitly flush both the
       //       serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
-      close()
-      finalPosition = file.length()
-      // In certain compression codecs, more bytes are written after close() is called
-      writeMetrics.incBytesWritten(finalPosition - reportedPosition)
+      objOut.close()
+      streamOpen = false
+
+      if (syncWrites) {
+        // Force outstanding writes to disk and track how long it takes
+        val start = System.nanoTime()
+        fos.getFD.sync()
+        writeMetrics.incWriteTime(System.nanoTime() - start)
+      }
+
+      val pos = channel.position()
+      val fileSegment = new FileSegment(file, committedPosition, pos - committedPosition)
+      committedPosition = pos
+      // In certain compression codecs, more bytes are written after streams are closed
+      writeMetrics.incBytesWritten(committedPosition - reportedPosition)
+      reportedPosition = committedPosition
+      fileSegment
     } else {
-      finalPosition = file.length()
+      new FileSegment(file, committedPosition, 0)
     }
-    commitAndCloseHasBeenCalled = true
   }
 
 
   /**
-   * Reverts writes that haven't been flushed yet. Callers should invoke this function
+   * Reverts writes that haven't been committed yet. Callers should invoke this function
    * when there are runtime exceptions. This method will not throw, though it may be
    * unsuccessful in truncating written data.
    *
@@ -152,16 +200,15 @@ private[spark] class DiskBlockObjectWriter(
     // truncating the file to its initial position.
     try {
       if (initialized) {
-        writeMetrics.decBytesWritten(reportedPosition - initialPosition)
+        writeMetrics.decBytesWritten(reportedPosition - committedPosition)
         writeMetrics.decRecordsWritten(numRecordsWritten)
-        objOut.flush()
-        bs.flush()
-        close()
+        streamOpen = false
+        closeResources()
       }
 
       val truncateStream = new FileOutputStream(file, true)
       try {
-        truncateStream.getChannel.truncate(initialPosition)
+        truncateStream.getChannel.truncate(committedPosition)
         file
       } finally {
         truncateStream.close()
@@ -177,7 +224,7 @@ private[spark] class DiskBlockObjectWriter(
    * Writes a key-value pair.
    */
   def write(key: Any, value: Any) {
-    if (!initialized) {
+    if (!streamOpen) {
       open()
     }
 
@@ -189,7 +236,7 @@ private[spark] class DiskBlockObjectWriter(
   override def write(b: Int): Unit = throw new UnsupportedOperationException()
 
   override def write(kvBytes: Array[Byte], offs: Int, len: Int): Unit = {
-    if (!initialized) {
+    if (!streamOpen) {
       open()
     }
 
@@ -209,18 +256,6 @@ private[spark] class DiskBlockObjectWriter(
   }
 
   /**
-   * Returns the file segment of committed data that this Writer has written.
-   * This is only valid after commitAndClose() has been called.
-   */
-  def fileSegment(): FileSegment = {
-    if (!commitAndCloseHasBeenCalled) {
-      throw new IllegalStateException(
-        "fileSegment() is only valid after commitAndClose() has been called")
-    }
-    new FileSegment(file, initialPosition, finalPosition - initialPosition)
-  }
-
-  /**
    * Report the number of bytes written in this writer's shuffle write metrics.
    * Note that this is only valid before the underlying streams are closed.
    */

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 6ddc72a..8c8860b 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -105,8 +105,8 @@ class ExternalAppendOnlyMap[K, V, C](
   private val fileBufferSize =
     sparkConf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
 
-  // Write metrics for current spill
-  private var curWriteMetrics: ShuffleWriteMetrics = _
+  // Write metrics
+  private val writeMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics()
 
   // Peak size of the in-memory map observed so far, in bytes
   private var _peakMemoryUsedBytes: Long = 0L
@@ -206,8 +206,7 @@ class ExternalAppendOnlyMap[K, V, C](
  private[this] def spillMemoryIteratorToDisk(inMemoryIterator: Iterator[(K, C)])
       : DiskMapIterator = {
     val (blockId, file) = diskBlockManager.createTempLocalBlock()
-    curWriteMetrics = new ShuffleWriteMetrics()
-    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
+    val writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, writeMetrics)
     var objectsWritten = 0
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -215,11 +214,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
     // Flush the disk writer's contents to disk, and update relevant variables
     def flush(): Unit = {
-      val w = writer
-      writer = null
-      w.commitAndClose()
-      _diskBytesSpilled += curWriteMetrics.bytesWritten
-      batchSizes.append(curWriteMetrics.bytesWritten)
+      val segment = writer.commitAndGet()
+      batchSizes.append(segment.length)
+      _diskBytesSpilled += segment.length
       objectsWritten = 0
     }
 
@@ -232,25 +229,20 @@ class ExternalAppendOnlyMap[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          curWriteMetrics = new ShuffleWriteMetrics()
-          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {
         flush()
-      } else if (writer != null) {
-        val w = writer
-        writer = null
-        w.revertPartialWritesAndClose()
+        writer.close()
+      } else {
+        writer.revertPartialWritesAndClose()
       }
       success = true
     } finally {
       if (!success) {
        // This code path only happens if an exception was thrown above before we set success;
         // close our stuff and let the exception be thrown further
-        if (writer != null) {
-          writer.revertPartialWritesAndClose()
-        }
+        writer.revertPartialWritesAndClose()
         if (file.exists()) {
           if (!file.delete()) {
             logWarning(s"Error deleting ${file}")

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 4067ace..708a007 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -272,14 +272,9 @@ private[spark] class ExternalSorter[K, V, C](
 
     // These variables are reset after each flush
     var objectsWritten: Long = 0
-    var spillMetrics: ShuffleWriteMetrics = null
-    var writer: DiskBlockObjectWriter = null
-    def openWriter(): Unit = {
-      assert (writer == null && spillMetrics == null)
-      spillMetrics = new ShuffleWriteMetrics
-      writer = blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
-    }
-    openWriter()
+    val spillMetrics: ShuffleWriteMetrics = new ShuffleWriteMetrics
+    val writer: DiskBlockObjectWriter =
+      blockManager.getDiskWriter(blockId, file, serInstance, fileBufferSize, spillMetrics)
 
     // List of batch sizes (bytes) in the order they are written to disk
     val batchSizes = new ArrayBuffer[Long]
@@ -288,14 +283,11 @@ private[spark] class ExternalSorter[K, V, C](
     val elementsPerPartition = new Array[Long](numPartitions)
 
     // Flush the disk writer's contents to disk, and update relevant variables.
-    // The writer is closed at the end of this process, and cannot be reused.
+    // The writer is committed at the end of this process.
     def flush(): Unit = {
-      val w = writer
-      writer = null
-      w.commitAndClose()
-      _diskBytesSpilled += spillMetrics.bytesWritten
-      batchSizes.append(spillMetrics.bytesWritten)
-      spillMetrics = null
+      val segment = writer.commitAndGet()
+      batchSizes.append(segment.length)
+      _diskBytesSpilled += segment.length
       objectsWritten = 0
     }
 
@@ -311,24 +303,21 @@ private[spark] class ExternalSorter[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          openWriter()
         }
       }
       if (objectsWritten > 0) {
         flush()
-      } else if (writer != null) {
-        val w = writer
-        writer = null
-        w.revertPartialWritesAndClose()
+      } else {
+        writer.revertPartialWritesAndClose()
       }
       success = true
     } finally {
-      if (!success) {
+      if (success) {
+        writer.close()
+      } else {
        // This code path only happens if an exception was thrown above before we set success;
         // close our stuff and let the exception be thrown further
-        if (writer != null) {
-          writer.revertPartialWritesAndClose()
-        }
+        writer.revertPartialWritesAndClose()
         if (file.exists()) {
           if (!file.delete()) {
             logWarning(s"Error deleting ${file}")
@@ -693,42 +682,37 @@ private[spark] class ExternalSorter[K, V, C](
       blockId: BlockId,
       outputFile: File): Array[Long] = {
 
-    val writeMetrics = context.taskMetrics().shuffleWriteMetrics
-
     // Track location of each range in the output file
     val lengths = new Array[Long](numPartitions)
+    val writer = blockManager.getDiskWriter(blockId, outputFile, serInstance, fileBufferSize,
+      context.taskMetrics().shuffleWriteMetrics)
 
     if (spills.isEmpty) {
       // Case where we only have in-memory data
       val collection = if (aggregator.isDefined) map else buffer
      val it = collection.destructiveSortedWritablePartitionedIterator(comparator)
       while (it.hasNext) {
-        val writer = blockManager.getDiskWriter(
-          blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
         val partitionId = it.nextPartition()
         while (it.hasNext && it.nextPartition() == partitionId) {
           it.writeNext(writer)
         }
-        writer.commitAndClose()
-        val segment = writer.fileSegment()
+        val segment = writer.commitAndGet()
         lengths(partitionId) = segment.length
       }
     } else {
      // We must perform merge-sort; get an iterator by partition and write everything directly.
       for ((id, elements) <- this.partitionedIterator) {
         if (elements.hasNext) {
-          val writer = blockManager.getDiskWriter(
-            blockId, outputFile, serInstance, fileBufferSize, writeMetrics)
           for (elem <- elements) {
             writer.write(elem._1, elem._2)
           }
-          writer.commitAndClose()
-          val segment = writer.fileSegment()
+          val segment = writer.commitAndGet()
           lengths(id) = segment.length
         }
       }
     }
 
+    writer.close()
     context.taskMetrics().incMemoryBytesSpilled(memoryBytesSpilled)
     context.taskMetrics().incDiskBytesSpilled(diskBytesSpilled)
     context.taskMetrics().incPeakExecutionMemory(peakMemoryUsedBytes)

http://git-wip-us.apache.org/repos/asf/spark/blob/daace601/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
index ec4ef4b..059c2c2 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockObjectWriterSuite.scala
@@ -60,7 +60,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
     assert(writeMetrics.bytesWritten > 0)
     assert(writeMetrics.recordsWritten === 16385)
-    writer.commitAndClose()
+    writer.commitAndGet()
+    writer.close()
     assert(file.length() == writeMetrics.bytesWritten)
   }
 
@@ -100,6 +101,40 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     }
   }
 
+  test("calling revertPartialWritesAndClose() on a partial write should 
truncate up to commit") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(
+      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20), Long.box(30))
+    val firstSegment = writer.commitAndGet()
+    assert(firstSegment.length === file.length())
+    assert(writeMetrics.shuffleBytesWritten === file.length())
+
+    writer.write(Long.box(40), Long.box(50))
+
+    writer.revertPartialWritesAndClose()
+    assert(firstSegment.length === file.length())
+    assert(writeMetrics.shuffleBytesWritten === file.length())
+  }
+
+  test("calling revertPartialWritesAndClose() after commit() should have no 
effect") {
+    val file = new File(tempDir, "somefile")
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(
+      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20), Long.box(30))
+    val firstSegment = writer.commitAndGet()
+    assert(firstSegment.length === file.length())
+    assert(writeMetrics.shuffleBytesWritten === file.length())
+
+    writer.revertPartialWritesAndClose()
+    assert(firstSegment.length === file.length())
+    assert(writeMetrics.shuffleBytesWritten === file.length())
+  }
+
   test("calling revertPartialWritesAndClose() on a closed block writer should 
have no effect") {
     val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
@@ -108,7 +143,8 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
-    writer.commitAndClose()
+    writer.commitAndGet()
+    writer.close()
     val bytesWritten = writeMetrics.bytesWritten
     assert(writeMetrics.recordsWritten === 1000)
     writer.revertPartialWritesAndClose()
@@ -116,7 +152,7 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(writeMetrics.bytesWritten === bytesWritten)
   }
 
-  test("commitAndClose() should be idempotent") {
+  test("commit() and close() should be idempotent") {
     val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
@@ -124,11 +160,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     for (i <- 1 to 1000) {
       writer.write(i, i)
     }
-    writer.commitAndClose()
+    writer.commitAndGet()
+    writer.close()
     val bytesWritten = writeMetrics.bytesWritten
     val writeTime = writeMetrics.writeTime
     assert(writeMetrics.recordsWritten === 1000)
-    writer.commitAndClose()
+    writer.commitAndGet()
+    writer.close()
     assert(writeMetrics.recordsWritten === 1000)
     assert(writeMetrics.bytesWritten === bytesWritten)
     assert(writeMetrics.writeTime === writeTime)
@@ -152,26 +190,13 @@ class DiskBlockObjectWriterSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(writeMetrics.writeTime === writeTime)
   }
 
-  test("fileSegment() can only be called after commitAndClose() has been 
called") {
+  test("commit() and close() without ever opening or writing") {
     val file = new File(tempDir, "somefile")
     val writeMetrics = new ShuffleWriteMetrics()
     val writer = new DiskBlockObjectWriter(
      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    for (i <- 1 to 1000) {
-      writer.write(i, i)
-    }
-    intercept[IllegalStateException] {
-      writer.fileSegment()
-    }
+    val segment = writer.commitAndGet()
     writer.close()
-  }
-
-  test("commitAndClose() without ever opening or writing") {
-    val file = new File(tempDir, "somefile")
-    val writeMetrics = new ShuffleWriteMetrics()
-    val writer = new DiskBlockObjectWriter(
-      file, new JavaSerializer(new SparkConf()).newInstance(), 1024, os => os, true, writeMetrics)
-    writer.commitAndClose()
-    assert(writer.fileSegment().length === 0)
+    assert(segment.length === 0)
   }
 }

