Repository: spark
Updated Branches:
  refs/heads/branch-1.1 4c19614e9 -> a65c9ac11


SPARK-2566. Update ShuffleWriteMetrics incrementally

I haven't tested this out on a cluster yet, but wanted to make sure the approach (passing ShuffleWriteMetrics down to DiskBlockObjectWriter) was OK.

Author: Sandy Ryza <sa...@cloudera.com>

Closes #1481 from sryza/sandy-spark-2566 and squashes the following commits:

8090d88 [Sandy Ryza] Fix ExternalSorter
b2a62ed [Sandy Ryza] Fix more test failures
8be6218 [Sandy Ryza] Fix test failures and mark a couple variables private
c5e68e5 [Sandy Ryza] SPARK-2566. Update ShuffleWriteMetrics incrementally
(cherry picked from commit 4e982364426c7d65032e8006c63ca4f9a0d40470)

Signed-off-by: Patrick Wendell <pwend...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a65c9ac1
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a65c9ac1
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a65c9ac1

Branch: refs/heads/branch-1.1
Commit: a65c9ac11e7075c2d7a925772273b9b7cf9586d6
Parents: 4c19614
Author: Sandy Ryza <sa...@cloudera.com>
Authored: Wed Aug 6 13:10:33 2014 -0700
Committer: Patrick Wendell <pwend...@gmail.com>
Committed: Wed Aug 6 13:10:43 2014 -0700

----------------------------------------------------------------------
 .../org/apache/spark/executor/TaskMetrics.scala |  4 +-
 .../spark/shuffle/hash/HashShuffleWriter.scala  | 16 ++--
 .../spark/shuffle/sort/SortShuffleWriter.scala  | 16 ++--
 .../org/apache/spark/storage/BlockManager.scala | 12 +--
 .../spark/storage/BlockObjectWriter.scala       | 77 +++++++++++---------
 .../spark/storage/ShuffleBlockManager.scala     |  9 ++-
 .../util/collection/ExternalAppendOnlyMap.scala | 18 +++--
 .../spark/util/collection/ExternalSorter.scala  | 17 +++--
 .../spark/storage/BlockObjectWriterSuite.scala  | 65 +++++++++++++++++
 .../spark/storage/DiskBlockManagerSuite.scala   |  9 ++-
 .../apache/spark/tools/StoragePerfTester.scala  |  3 +-
 11 files changed, 164 insertions(+), 82 deletions(-)
----------------------------------------------------------------------
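
At a glance, the patch threads a single ShuffleWriteMetrics instance from each shuffle writer down to every DiskBlockObjectWriter it opens, so the task's write counters advance while the task runs rather than being set once at commit time. A minimal, self-contained sketch of that shape (stand-in classes for illustration only, not the real Spark APIs):

    // Stand-ins mirroring ShuffleWriteMetrics and DiskBlockObjectWriter; hypothetical, not Spark code.
    class WriteMetrics {
      @volatile var shuffleBytesWritten: Long = 0L  // @volatile so another thread can read a live value
      @volatile var shuffleWriteTime: Long = 0L
    }

    class DiskObjectWriter(metrics: WriteMetrics) {
      def write(bytes: Array[Byte]): Unit = {
        val start = System.nanoTime()
        // ... serialize and write the bytes to disk here ...
        metrics.shuffleBytesWritten += bytes.length        // counters advance as the task writes
        metrics.shuffleWriteTime += System.nanoTime() - start
      }
    }

    object ShuffleWriteExample extends App {
      val writeMetrics = new WriteMetrics                  // one instance per map task
      val writers = Array.fill(4)(new DiskObjectWriter(writeMetrics))  // all bucket writers share it
      writers.foreach(_.write(new Array[Byte](128)))
      println(writeMetrics.shuffleBytesWritten)            // 512, visible before the task finishes
    }

In the actual patch the sharing happens through the new writeMetrics parameter on ShuffleBlockManager.forMapTask and BlockManager.getDiskWriter, as the hunks below show.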


http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 56cd872..11a6e10 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -190,10 +190,10 @@ class ShuffleWriteMetrics extends Serializable {
   /**
    * Number of bytes written for the shuffle by this task
    */
-  var shuffleBytesWritten: Long = _
+  @volatile var shuffleBytesWritten: Long = _
 
   /**
    * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
    */
-  var shuffleWriteTime: Long = _
+  @volatile var shuffleWriteTime: Long = _
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 45d3b8b..51e454d 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -39,10 +39,14 @@ private[spark] class HashShuffleWriter[K, V](
   // we don't try deleting files, etc twice.
   private var stopping = false
 
+  private val writeMetrics = new ShuffleWriteMetrics()
+  metrics.shuffleWriteMetrics = Some(writeMetrics)
+
   private val blockManager = SparkEnv.get.blockManager
   private val shuffleBlockManager = blockManager.shuffleBlockManager
   private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
-  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser)
+  private val shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
+    writeMetrics)
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
@@ -99,22 +103,12 @@ private[spark] class HashShuffleWriter[K, V](
 
   private def commitWritesAndBuildStatus(): MapStatus = {
     // Commit the writes. Get the size of each bucket block (total block size).
-    var totalBytes = 0L
-    var totalTime = 0L
     val compressedSizes = shuffle.writers.map { writer: BlockObjectWriter =>
       writer.commitAndClose()
       val size = writer.fileSegment().length
-      totalBytes += size
-      totalTime += writer.timeWriting()
       MapOutputTracker.compressSize(size)
     }
 
-    // Update shuffle metrics.
-    val shuffleMetrics = new ShuffleWriteMetrics
-    shuffleMetrics.shuffleBytesWritten = totalBytes
-    shuffleMetrics.shuffleWriteTime = totalTime
-    metrics.shuffleWriteMetrics = Some(shuffleMetrics)
-
     new MapStatus(blockManager.blockManagerId, compressedSizes)
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 24db2f2..e54e638 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -52,6 +52,9 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   private var mapStatus: MapStatus = null
 
+  private val writeMetrics = new ShuffleWriteMetrics()
+  context.taskMetrics.shuffleWriteMetrics = Some(writeMetrics)
+
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
     // Get an iterator with the elements for each partition ID
@@ -84,13 +87,10 @@ private[spark] class SortShuffleWriter[K, V, C](
     val offsets = new Array[Long](numPartitions + 1)
     val lengths = new Array[Long](numPartitions)
 
-    // Statistics
-    var totalBytes = 0L
-    var totalTime = 0L
-
     for ((id, elements) <- partitions) {
       if (elements.hasNext) {
-        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize)
+        val writer = blockManager.getDiskWriter(blockId, outputFile, ser, fileBufferSize,
+          writeMetrics)
         for (elem <- elements) {
           writer.write(elem)
         }
@@ -98,18 +98,12 @@ private[spark] class SortShuffleWriter[K, V, C](
         val segment = writer.fileSegment()
         offsets(id + 1) = segment.offset + segment.length
         lengths(id) = segment.length
-        totalTime += writer.timeWriting()
-        totalBytes += segment.length
       } else {
-        // The partition is empty; don't create a new writer to avoid writing headers, etc
         offsets(id + 1) = offsets(id)
       }
     }
 
-    val shuffleMetrics = new ShuffleWriteMetrics
-    shuffleMetrics.shuffleBytesWritten = totalBytes
-    shuffleMetrics.shuffleWriteTime = totalTime
-    context.taskMetrics.shuffleWriteMetrics = Some(shuffleMetrics)
     context.taskMetrics.memoryBytesSpilled += sorter.memoryBytesSpilled
     context.taskMetrics.diskBytesSpilled += sorter.diskBytesSpilled
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 3876cf4..8d21b02 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -29,7 +29,7 @@ import akka.actor.{ActorSystem, Cancellable, Props}
 import sun.nio.ch.DirectBuffer
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.executor.{DataReadMethod, InputMetrics, ShuffleWriteMetrics}
 import org.apache.spark.io.CompressionCodec
 import org.apache.spark.network._
 import org.apache.spark.serializer.Serializer
@@ -562,17 +562,19 @@ private[spark] class BlockManager(
 
   /**
    * A short circuited method to get a block writer that can write data directly to disk.
-   * The Block will be appended to the File specified by filename. This is currently used for
-   * writing shuffle files out. Callers should handle error cases.
+   * The Block will be appended to the File specified by filename. Callers should handle error
+   * cases.
    */
   def getDiskWriter(
       blockId: BlockId,
       file: File,
       serializer: Serializer,
-      bufferSize: Int): BlockObjectWriter = {
+      bufferSize: Int,
+      writeMetrics: ShuffleWriteMetrics): BlockObjectWriter = {
     val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
-    new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites)
+    new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites,
+      writeMetrics)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 01d46e1..adda971 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -22,6 +22,7 @@ import java.nio.channels.FileChannel
 
 import org.apache.spark.Logging
 import org.apache.spark.serializer.{SerializationStream, Serializer}
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * An interface for writing JVM objects to some underlying storage. This interface allows
@@ -60,41 +61,26 @@ private[spark] abstract class BlockObjectWriter(val blockId: BlockId) {
    * This is only valid after commitAndClose() has been called.
    */
   def fileSegment(): FileSegment
-
-  /**
-   * Cumulative time spent performing blocking writes, in ns.
-   */
-  def timeWriting(): Long
-
-  /**
-   * Number of bytes written so far
-   */
-  def bytesWritten: Long
 }
 
-/** BlockObjectWriter which writes directly to a file on disk. Appends to the given file. */
+/**
+ * BlockObjectWriter which writes directly to a file on disk. Appends to the given file.
+ * The given write metrics will be updated incrementally, but will not necessarily be current until
+ * commitAndClose is called.
+ */
 private[spark] class DiskBlockObjectWriter(
     blockId: BlockId,
     file: File,
     serializer: Serializer,
     bufferSize: Int,
     compressStream: OutputStream => OutputStream,
-    syncWrites: Boolean)
+    syncWrites: Boolean,
+    writeMetrics: ShuffleWriteMetrics)
   extends BlockObjectWriter(blockId)
   with Logging
 {
-
   /** Intercepts write calls and tracks total time spent writing. Not thread safe. */
   private class TimeTrackingOutputStream(out: OutputStream) extends OutputStream {
-    def timeWriting = _timeWriting
-    private var _timeWriting = 0L
-
-    private def callWithTiming(f: => Unit) = {
-      val start = System.nanoTime()
-      f
-      _timeWriting += (System.nanoTime() - start)
-    }
-
     def write(i: Int): Unit = callWithTiming(out.write(i))
     override def write(b: Array[Byte]) = callWithTiming(out.write(b))
     override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
@@ -111,7 +97,11 @@ private[spark] class DiskBlockObjectWriter(
   private val initialPosition = file.length()
   private var finalPosition: Long = -1
   private var initialized = false
-  private var _timeWriting = 0L
+
+  /** Calling channel.position() to update the write metrics can be a little bit expensive, so we
+    * only call it every N writes */
+  private var writesSinceMetricsUpdate = 0
+  private var lastPosition = initialPosition
 
   override def open(): BlockObjectWriter = {
     fos = new FileOutputStream(file, true)
@@ -128,14 +118,11 @@ private[spark] class DiskBlockObjectWriter(
       if (syncWrites) {
         // Force outstanding writes to disk and track how long it takes
         objOut.flush()
-        val start = System.nanoTime()
-        fos.getFD.sync()
-        _timeWriting += System.nanoTime() - start
+        def sync = fos.getFD.sync()
+        callWithTiming(sync)
       }
       objOut.close()
 
-      _timeWriting += ts.timeWriting
-
       channel = null
       bs = null
       fos = null
@@ -153,6 +140,7 @@ private[spark] class DiskBlockObjectWriter(
       //       serializer stream and the lower level stream.
       objOut.flush()
       bs.flush()
+      updateBytesWritten()
       close()
     }
     finalPosition = file.length()
@@ -162,6 +150,8 @@ private[spark] class DiskBlockObjectWriter(
   // truncating the file to its initial position.
   override def revertPartialWritesAndClose() {
     try {
+      writeMetrics.shuffleBytesWritten -= (lastPosition - initialPosition)
+
       if (initialized) {
         objOut.flush()
         bs.flush()
@@ -184,19 +174,36 @@ private[spark] class DiskBlockObjectWriter(
     if (!initialized) {
       open()
     }
+
     objOut.writeObject(value)
+
+    if (writesSinceMetricsUpdate == 32) {
+      writesSinceMetricsUpdate = 0
+      updateBytesWritten()
+    } else {
+      writesSinceMetricsUpdate += 1
+    }
   }
 
   override def fileSegment(): FileSegment = {
-    new FileSegment(file, initialPosition, bytesWritten)
+    new FileSegment(file, initialPosition, finalPosition - initialPosition)
   }
 
-  // Only valid if called after close()
-  override def timeWriting() = _timeWriting
+  private def updateBytesWritten() {
+    val pos = channel.position()
+    writeMetrics.shuffleBytesWritten += (pos - lastPosition)
+    lastPosition = pos
+  }
+
+  private def callWithTiming(f: => Unit) = {
+    val start = System.nanoTime()
+    f
+    writeMetrics.shuffleWriteTime += (System.nanoTime() - start)
+  }
 
-  // Only valid if called after commit()
-  override def bytesWritten: Long = {
-    assert(finalPosition != -1, "bytesWritten is only valid after successful commit()")
-    finalPosition - initialPosition
+  // For testing
+  private[spark] def flush() {
+    objOut.flush()
+    bs.flush()
   }
 }
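
The doc comments in DiskBlockObjectWriter above note that the metrics are only eventually current: channel.position() is sampled every 32 writes and again on commit, and a revert subtracts whatever has been reported so far. A self-contained sketch of that sampling logic, with a plain Long standing in for the FileChannel position (illustrative only, not the real class):

    class WriteMetrics { var shuffleBytesWritten = 0L }     // stand-in for ShuffleWriteMetrics

    class BatchedMetricsWriter(metrics: WriteMetrics, samplePeriod: Int = 32) {
      private val initialPosition = 0L
      private var position = initialPosition                // stands in for channel.position()
      private var lastPosition = initialPosition
      private var writesSinceMetricsUpdate = 0

      def write(recordSize: Long): Unit = {
        position += recordSize                              // the real writer serializes to a stream here
        if (writesSinceMetricsUpdate == samplePeriod) {
          writesSinceMetricsUpdate = 0                      // sampling the file position has a cost,
          updateBytesWritten()                              // so only do it every N writes
        } else {
          writesSinceMetricsUpdate += 1
        }
      }

      def commitAndClose(): Unit = updateBytesWritten()     // catch up, so metrics are exact after commit

      def revertPartialWritesAndClose(): Unit =             // give back everything reported so far
        metrics.shuffleBytesWritten -= (lastPosition - initialPosition)

      private def updateBytesWritten(): Unit = {
        metrics.shuffleBytesWritten += (position - lastPosition)
        lastPosition = position
      }
    }

This is also what the new BlockObjectWriterSuite further down exercises: a single write leaves shuffleBytesWritten at 0, the counter becomes non-zero only after enough writes (or a commit), and a revert brings it back to 0.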

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index f9fdffa..3565719 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -29,6 +29,7 @@ import org.apache.spark.storage.ShuffleBlockManager.ShuffleFileGroup
 import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
 import org.apache.spark.util.collection.{PrimitiveKeyOpenHashMap, PrimitiveVector}
 import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /** A group of writers for a ShuffleMapTask, one writer per reducer. */
 private[spark] trait ShuffleWriterGroup {
@@ -111,7 +112,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
    * Get a ShuffleWriterGroup for the given map task, which will register it as complete
    * when the writers are closed successfully
    */
-  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = {
+  def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer,
+      writeMetrics: ShuffleWriteMetrics) = {
     new ShuffleWriterGroup {
       shuffleStates.putIfAbsent(shuffleId, new ShuffleState(numBuckets))
       private val shuffleState = shuffleStates(shuffleId)
@@ -121,7 +123,8 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
         fileGroup = getUnusedFileGroup()
         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
           val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
-          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, fileGroup(bucketId), serializer, bufferSize,
+            writeMetrics)
         }
       } else {
         Array.tabulate[BlockObjectWriter](numBuckets) { bucketId =>
@@ -136,7 +139,7 @@ class ShuffleBlockManager(blockManager: BlockManager) extends Logging {
               logWarning(s"Failed to remove existing shuffle file $blockFile")
             }
           }
-          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize)
+          blockManager.getDiskWriter(blockId, blockFile, serializer, bufferSize, writeMetrics)
         }
       }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
index 260a5c3..9f85b94 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalAppendOnlyMap.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.{BlockId, BlockManager}
 import org.apache.spark.util.collection.ExternalAppendOnlyMap.HashComparator
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * :: DeveloperApi ::
@@ -102,6 +103,10 @@ class ExternalAppendOnlyMap[K, V, C](
   private var _diskBytesSpilled = 0L
 
   private val fileBufferSize = sparkConf.getInt("spark.shuffle.file.buffer.kb", 32) * 1024
+
+  // Write metrics for current spill
+  private var curWriteMetrics: ShuffleWriteMetrics = _
+
   private val keyComparator = new HashComparator[K]
   private val ser = serializer.newInstance()
 
@@ -172,7 +177,9 @@ class ExternalAppendOnlyMap[K, V, C](
     logInfo("Thread %d spilling in-memory map of %d MB to disk (%d time%s so 
far)"
       .format(threadId, mapSize / (1024 * 1024), spillCount, if (spillCount > 
1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
-    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+    curWriteMetrics = new ShuffleWriteMetrics()
+    var writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+      curWriteMetrics)
     var objectsWritten = 0
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -183,9 +190,8 @@ class ExternalAppendOnlyMap[K, V, C](
       val w = writer
       writer = null
       w.commitAndClose()
-      val bytesWritten = w.bytesWritten
-      batchSizes.append(bytesWritten)
-      _diskBytesSpilled += bytesWritten
+      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
       objectsWritten = 0
     }
 
@@ -199,7 +205,9 @@ class ExternalAppendOnlyMap[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize)
+          curWriteMetrics = new ShuffleWriteMetrics()
+          writer = blockManager.getDiskWriter(blockId, file, serializer, fileBufferSize,
+            curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {
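
ExternalAppendOnlyMap (and ExternalSorter in the next file) no longer ask the writer for bytesWritten; each spill batch gets a throwaway ShuffleWriteMetrics, and the spilled byte count is read from it after commitAndClose. A rough, self-contained sketch of that per-spill bookkeeping (stand-in types, not the real classes):

    import scala.collection.mutable.ArrayBuffer

    class WriteMetrics { var shuffleBytesWritten = 0L }     // stand-in for ShuffleWriteMetrics

    class DiskWriter(metrics: WriteMetrics) {               // stand-in for DiskBlockObjectWriter
      private var pending = 0L
      def write(recordSize: Long): Unit = { pending += recordSize }
      def commitAndClose(): Unit = { metrics.shuffleBytesWritten += pending; pending = 0L }
    }

    object SpillExample extends App {
      var diskBytesSpilled = 0L
      val batchSizes = new ArrayBuffer[Long]

      for (batch <- Seq(Seq(40L, 60L), Seq(25L, 25L, 50L))) {
        val curWriteMetrics = new WriteMetrics()            // fresh metrics per spill batch
        val writer = new DiskWriter(curWriteMetrics)
        batch.foreach(writer.write)
        writer.commitAndClose()                             // metrics are reliable only after commit
        diskBytesSpilled += curWriteMetrics.shuffleBytesWritten  // replaces the removed writer.bytesWritten
        batchSizes += curWriteMetrics.shuffleBytesWritten
      }
      println(s"diskBytesSpilled=$diskBytesSpilled, batchSizes=$batchSizes")  // 200, ArrayBuffer(100, 100)
    }

Because these throwaway metrics objects are never registered on the task's metrics, spill bytes feed _diskBytesSpilled rather than the task's shuffle write counters.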

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 3f93afd..eb4849e 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -28,6 +28,7 @@ import com.google.common.io.ByteStreams
 import org.apache.spark.{Aggregator, SparkEnv, Logging, Partitioner}
 import org.apache.spark.serializer.{DeserializationStream, Serializer}
 import org.apache.spark.storage.BlockId
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * Sorts and potentially merges a number of key-value pairs of type (K, V) to produce key-combiner
@@ -112,11 +113,14 @@ private[spark] class ExternalSorter[K, V, C](
   // What threshold of elementsRead we start estimating map size at.
   private val trackMemoryThreshold = 1000
 
-  // Spilling statistics
+  // Total spilling statistics
   private var spillCount = 0
   private var _memoryBytesSpilled = 0L
   private var _diskBytesSpilled = 0L
 
+  // Write metrics for current spill
+  private var curWriteMetrics: ShuffleWriteMetrics = _
+
   // How much of the shared memory pool this collection has claimed
   private var myMemoryThreshold = 0L
 
@@ -239,7 +243,8 @@ private[spark] class ExternalSorter[K, V, C](
     logInfo("Thread %d spilling in-memory batch of %d MB to disk (%d spill%s 
so far)"
       .format(threadId, memorySize / (1024 * 1024), spillCount, if (spillCount 
> 1) "s" else ""))
     val (blockId, file) = diskBlockManager.createTempBlock()
-    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+    curWriteMetrics = new ShuffleWriteMetrics()
+    var writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
     var objectsWritten = 0   // Objects written since the last flush
 
     // List of batch sizes (bytes) in the order they are written to disk
@@ -254,9 +259,8 @@ private[spark] class ExternalSorter[K, V, C](
       val w = writer
       writer = null
       w.commitAndClose()
-      val bytesWritten = w.bytesWritten
-      batchSizes.append(bytesWritten)
-      _diskBytesSpilled += bytesWritten
+      _diskBytesSpilled += curWriteMetrics.shuffleBytesWritten
+      batchSizes.append(curWriteMetrics.shuffleBytesWritten)
       objectsWritten = 0
     }
 
@@ -275,7 +279,8 @@ private[spark] class ExternalSorter[K, V, C](
 
         if (objectsWritten == serializerBatchSize) {
           flush()
-          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize)
+          curWriteMetrics = new ShuffleWriteMetrics()
+          writer = blockManager.getDiskWriter(blockId, file, ser, fileBufferSize, curWriteMetrics)
         }
       }
       if (objectsWritten > 0) {

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
new file mode 100644
index 0000000..bbc7e13
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/BlockObjectWriterSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.storage
+
+import org.scalatest.FunSuite
+import java.io.File
+import org.apache.spark.executor.ShuffleWriteMetrics
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.SparkConf
+
+class BlockObjectWriterSuite extends FunSuite {
+  test("verify write metrics") {
+    val file = new File("somefile")
+    file.deleteOnExit()
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20))
+    // Metrics don't update on every write
+    assert(writeMetrics.shuffleBytesWritten == 0)
+    // After 32 writes, metrics should update
+    for (i <- 0 until 32) {
+      writer.flush()
+      writer.write(Long.box(i))
+    }
+    assert(writeMetrics.shuffleBytesWritten > 0)
+    writer.commitAndClose()
+    assert(file.length() == writeMetrics.shuffleBytesWritten)
+  }
+
+  test("verify write metrics on revert") {
+    val file = new File("somefile")
+    file.deleteOnExit()
+    val writeMetrics = new ShuffleWriteMetrics()
+    val writer = new DiskBlockObjectWriter(new TestBlockId("0"), file,
+      new JavaSerializer(new SparkConf()), 1024, os => os, true, writeMetrics)
+
+    writer.write(Long.box(20))
+    // Metrics don't update on every write
+    assert(writeMetrics.shuffleBytesWritten == 0)
+    // After 32 writes, metrics should update
+    for (i <- 0 until 32) {
+      writer.flush()
+      writer.write(Long.box(i))
+    }
+    assert(writeMetrics.shuffleBytesWritten > 0)
+    writer.revertPartialWritesAndClose()
+    assert(writeMetrics.shuffleBytesWritten == 0)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 985ac93..b8299e2 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -30,6 +30,7 @@ import org.apache.spark.SparkConf
 import org.apache.spark.scheduler.LiveListenerBus
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
   private val testConf = new SparkConf(false)
@@ -153,7 +154,7 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
 
       val shuffleManager = store.shuffleBlockManager
 
-      val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer)
+      val shuffle1 = shuffleManager.forMapTask(1, 1, 1, serializer, new ShuffleWriteMetrics)
       for (writer <- shuffle1.writers) {
         writer.write("test1")
         writer.write("test2")
@@ -165,7 +166,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
       val shuffle1Segment = shuffle1.writers(0).fileSegment()
       shuffle1.releaseWriters(success = true)
 
-      val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf))
+      val shuffle2 = shuffleManager.forMapTask(1, 2, 1, new JavaSerializer(testConf),
+        new ShuffleWriteMetrics)
 
       for (writer <- shuffle2.writers) {
         writer.write("test3")
@@ -183,7 +185,8 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with Before
      // of block based on remaining data in file : which could mess things up when there is concurrent read
       // and writes happening to the same shuffle group.
 
-      val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf))
+      val shuffle3 = shuffleManager.forMapTask(1, 3, 1, new JavaSerializer(testConf),
+        new ShuffleWriteMetrics)
       for (writer <- shuffle3.writers) {
         writer.write("test3")
         writer.write("test4")

http://git-wip-us.apache.org/repos/asf/spark/blob/a65c9ac1/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
----------------------------------------------------------------------
diff --git a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
index 8a05fcb..17bf7c2 100644
--- a/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
+++ b/tools/src/main/scala/org/apache/spark/tools/StoragePerfTester.scala
@@ -23,6 +23,7 @@ import java.util.concurrent.atomic.AtomicLong
 import org.apache.spark.SparkContext
 import org.apache.spark.serializer.KryoSerializer
 import org.apache.spark.util.Utils
+import org.apache.spark.executor.ShuffleWriteMetrics
 
 /**
  * Internal utility for micro-benchmarking shuffle write performance.
@@ -56,7 +57,7 @@ object StoragePerfTester {
 
     def writeOutputBytes(mapId: Int, total: AtomicLong) = {
       val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits,
-        new KryoSerializer(sc.conf))
+        new KryoSerializer(sc.conf), new ShuffleWriteMetrics())
       val writers = shuffle.writers
       for (i <- 1 to recordsPerMap) {
         writers(i % numOutputSplits).write(writeData)

