Repository: spark
Updated Branches:
  refs/heads/master 85383d29e -> 6a064ba8f


[SPARK-26141] Enable custom metrics implementation in shuffle write

## What changes were proposed in this pull request?
This is the write-side counterpart to https://github.com/apache/spark/pull/23105.
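
For illustration, the seam this change introduces is the ShuffleWriteMetricsReporter
abstraction that shuffle writers now accept in place of the concrete ShuffleWriteMetrics.
Below is a minimal sketch of a custom reporter, mirroring the five methods that
ShuffleWriteMetrics overrides in the diff further down; the class name and the atomic
counters are illustrative assumptions, not part of this commit. Since the overrides in
ShuffleWriteMetrics.scala are private[spark], the sketch assumes an implementation lives
under the org.apache.spark namespace.

    package org.apache.spark.shuffle

    import java.util.concurrent.atomic.AtomicLong

    // Hypothetical reporter: aggregates into plain thread-safe counters
    // instead of the task's accumulators.
    private[spark] class InMemoryWriteMetricsReporter extends ShuffleWriteMetricsReporter {
      private val bytes = new AtomicLong
      private val records = new AtomicLong
      private val writeTime = new AtomicLong

      private[spark] override def incBytesWritten(v: Long): Unit = bytes.addAndGet(v)
      private[spark] override def incRecordsWritten(v: Long): Unit = records.addAndGet(v)
      private[spark] override def incWriteTime(v: Long): Unit = writeTime.addAndGet(v)
      private[spark] override def decBytesWritten(v: Long): Unit = bytes.addAndGet(-v)
      private[spark] override def decRecordsWritten(v: Long): Unit = records.addAndGet(-v)
    }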

## How was this patch tested?
No behavior change is expected, as this is a straightforward refactoring. Updated all existing test cases.

Closes #23106 from rxin/SPARK-26141.

Authored-by: Reynold Xin <r...@databricks.com>
Signed-off-by: Reynold Xin <r...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6a064ba8
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6a064ba8
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6a064ba8

Branch: refs/heads/master
Commit: 6a064ba8f271d5f9d04acd41d0eea50a5b0f5018
Parents: 85383d2
Author: Reynold Xin <r...@databricks.com>
Authored: Mon Nov 26 22:35:52 2018 -0800
Committer: Reynold Xin <r...@databricks.com>
Committed: Mon Nov 26 22:35:52 2018 -0800

----------------------------------------------------------------------
 .../sort/BypassMergeSortShuffleWriter.java        | 11 +++++------
 .../spark/shuffle/sort/ShuffleExternalSorter.java | 18 ++++++++++++------
 .../spark/shuffle/sort/UnsafeShuffleWriter.java   |  9 +++++----
 .../spark/storage/TimeTrackingOutputStream.java   |  7 ++++---
 .../spark/executor/ShuffleWriteMetrics.scala      | 13 +++++++------
 .../apache/spark/scheduler/ShuffleMapTask.scala   |  3 ++-
 .../org/apache/spark/shuffle/ShuffleManager.scala |  6 +++++-
 .../spark/shuffle/sort/SortShuffleManager.scala   | 10 ++++++----
 .../org/apache/spark/storage/BlockManager.scala   |  7 +++----
 .../spark/storage/DiskBlockObjectWriter.scala     |  4 ++--
 .../spark/util/collection/ExternalSorter.scala    |  4 ++--
 .../shuffle/sort/UnsafeShuffleWriterSuite.java    |  6 ++++--
 .../scala/org/apache/spark/ShuffleSuite.scala     | 12 ++++++++----
 .../sort/BypassMergeSortShuffleWriterSuite.scala  | 16 ++++++++--------
 project/MimaExcludes.scala                        |  7 ++++++-
 15 files changed, 79 insertions(+), 54 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index b020a6d..fda33cd 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -37,12 +37,11 @@ import org.slf4j.LoggerFactory;
 import org.apache.spark.Partitioner;
 import org.apache.spark.ShuffleDependency;
 import org.apache.spark.SparkConf;
-import org.apache.spark.TaskContext;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
 import org.apache.spark.serializer.Serializer;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.storage.*;
@@ -79,7 +78,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final int numPartitions;
   private final BlockManager blockManager;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final Serializer serializer;
@@ -103,8 +102,8 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       IndexShuffleBlockResolver shuffleBlockResolver,
       BypassMergeSortShuffleHandle<K, V> handle,
       int mapId,
-      TaskContext taskContext,
-      SparkConf conf) {
+      SparkConf conf,
+      ShuffleWriteMetricsReporter writeMetrics) {
     // Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
     this.fileBufferSize = (int) conf.getSizeAsKb("spark.shuffle.file.buffer", "32k") * 1024;
     this.transferToEnabled = conf.getBoolean("spark.file.transferTo", true);
@@ -114,7 +113,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.partitioner = dep.partitioner();
     this.numPartitions = partitioner.numPartitions();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.serializer = dep.serializer();
     this.shuffleBlockResolver = shuffleBlockResolver;
   }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
index 1c0d664..6ee9d5f 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/ShuffleExternalSorter.java
@@ -38,6 +38,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.memory.TooLargePageException;
 import org.apache.spark.serializer.DummySerializerInstance;
 import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.storage.BlockManager;
 import org.apache.spark.storage.DiskBlockObjectWriter;
 import org.apache.spark.storage.FileSegment;
@@ -75,7 +76,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
   private final TaskMemoryManager taskMemoryManager;
   private final BlockManager blockManager;
   private final TaskContext taskContext;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
 
   /**
    * Force this sorter to spill when there are this many elements in memory.
@@ -113,7 +114,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       int initialSize,
       int numPartitions,
       SparkConf conf,
-      ShuffleWriteMetrics writeMetrics) {
+      ShuffleWriteMetricsReporter writeMetrics) {
     super(memoryManager,
       (int) Math.min(PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES, memoryManager.pageSizeBytes()),
       memoryManager.getTungstenMemoryMode());
@@ -144,7 +145,7 @@ final class ShuffleExternalSorter extends MemoryConsumer {
    */
   private void writeSortedFile(boolean isLastFile) {
 
-    final ShuffleWriteMetrics writeMetricsToUse;
+    final ShuffleWriteMetricsReporter writeMetricsToUse;
 
     if (isLastFile) {
       // We're writing the final non-spill file, so we _do_ want to count this as shuffle bytes.
@@ -241,9 +242,14 @@ final class ShuffleExternalSorter extends MemoryConsumer {
       //
       // Note that we intentionally ignore the value of `writeMetricsToUse.shuffleWriteTime()`.
       // Consistent with ExternalSorter, we do not count this IO towards shuffle write time.
-      // This means that this IO time is not accounted for anywhere; SPARK-3577 will fix this.
-      writeMetrics.incRecordsWritten(writeMetricsToUse.recordsWritten());
-      taskContext.taskMetrics().incDiskBytesSpilled(writeMetricsToUse.bytesWritten());
+      // SPARK-3577 tracks the spill time separately.
+
+      // This is guaranteed to be a ShuffleWriteMetrics based on the if check in the beginning
+      // of this method.
+      writeMetrics.incRecordsWritten(
+        ((ShuffleWriteMetrics)writeMetricsToUse).recordsWritten());
+      taskContext.taskMetrics().incDiskBytesSpilled(
+        ((ShuffleWriteMetrics)writeMetricsToUse).bytesWritten());
     }
   }
 

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 4839d04..4b0c743 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -37,7 +37,6 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.spark.*;
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.io.CompressionCodec;
 import org.apache.spark.io.CompressionCodec$;
 import org.apache.spark.io.NioBufferedFileInputStream;
@@ -47,6 +46,7 @@ import org.apache.spark.memory.TaskMemoryManager;
 import org.apache.spark.network.util.LimitedInputStream;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.scheduler.MapStatus$;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 import org.apache.spark.serializer.SerializationStream;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
@@ -73,7 +73,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final TaskMemoryManager memoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final int shuffleId;
   private final int mapId;
   private final TaskContext taskContext;
@@ -122,7 +122,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
       SerializedShuffleHandle<K, V> handle,
       int mapId,
       TaskContext taskContext,
-      SparkConf sparkConf) throws IOException {
+      SparkConf sparkConf,
+      ShuffleWriteMetricsReporter writeMetrics) throws IOException {
     final int numPartitions = handle.dependency().partitioner().numPartitions();
     if (numPartitions > SortShuffleManager.MAX_SHUFFLE_OUTPUT_PARTITIONS_FOR_SERIALIZED_MODE()) {
       throw new IllegalArgumentException(
@@ -138,7 +139,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
     this.shuffleId = dep.shuffleId();
     this.serializer = dep.serializer().newInstance();
     this.partitioner = dep.partitioner();
-    this.writeMetrics = taskContext.taskMetrics().shuffleWriteMetrics();
+    this.writeMetrics = writeMetrics;
     this.taskContext = taskContext;
     this.sparkConf = sparkConf;
     this.transferToEnabled = sparkConf.getBoolean("spark.file.transferTo", true);

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
index 5d0555a..fcba3b7 100644
--- a/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
+++ b/core/src/main/java/org/apache/spark/storage/TimeTrackingOutputStream.java
@@ -21,7 +21,7 @@ import java.io.IOException;
 import java.io.OutputStream;
 
 import org.apache.spark.annotation.Private;
-import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
 
 /**
 * Intercepts write calls and tracks total time spent writing in order to update shuffle write
@@ -30,10 +30,11 @@ import org.apache.spark.executor.ShuffleWriteMetrics;
 @Private
 public final class TimeTrackingOutputStream extends OutputStream {
 
-  private final ShuffleWriteMetrics writeMetrics;
+  private final ShuffleWriteMetricsReporter writeMetrics;
   private final OutputStream outputStream;
 
-  public TimeTrackingOutputStream(ShuffleWriteMetrics writeMetrics, OutputStream outputStream) {
+  public TimeTrackingOutputStream(
+      ShuffleWriteMetricsReporter writeMetrics, OutputStream outputStream) {
     this.writeMetrics = writeMetrics;
     this.outputStream = outputStream;
   }
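
Since the constructor now takes the reporter abstraction, any implementation can observe
write time. A usage sketch from Scala, under the same visibility assumption as before
(InMemoryWriteMetricsReporter is the hypothetical reporter sketched near the top of this
message; the stream and bytes are arbitrary):

    import java.io.ByteArrayOutputStream
    import org.apache.spark.storage.TimeTrackingOutputStream

    val reporter = new InMemoryWriteMetricsReporter  // hypothetical, sketched earlier
    val out = new TimeTrackingOutputStream(reporter, new ByteArrayOutputStream())
    out.write(Array[Byte](1, 2, 3))  // elapsed write time flows to reporter.incWriteTime
    out.close()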

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
index 0c9da65..d0b0e7d 100644
--- a/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ShuffleWriteMetrics.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.executor
 
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.LongAccumulator
 
 
@@ -27,7 +28,7 @@ import org.apache.spark.util.LongAccumulator
  * Operations are not thread-safe.
  */
 @DeveloperApi
-class ShuffleWriteMetrics private[spark] () extends Serializable {
+class ShuffleWriteMetrics private[spark] () extends ShuffleWriteMetricsReporter with Serializable {
   private[executor] val _bytesWritten = new LongAccumulator
   private[executor] val _recordsWritten = new LongAccumulator
   private[executor] val _writeTime = new LongAccumulator
@@ -47,13 +48,13 @@ class ShuffleWriteMetrics private[spark] () extends Serializable {
    */
   def writeTime: Long = _writeTime.sum
 
-  private[spark] def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
-  private[spark] def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
-  private[spark] def incWriteTime(v: Long): Unit = _writeTime.add(v)
-  private[spark] def decBytesWritten(v: Long): Unit = {
+  private[spark] override def incBytesWritten(v: Long): Unit = _bytesWritten.add(v)
+  private[spark] override def incRecordsWritten(v: Long): Unit = _recordsWritten.add(v)
+  private[spark] override def incWriteTime(v: Long): Unit = _writeTime.add(v)
+  private[spark] override def decBytesWritten(v: Long): Unit = {
     _bytesWritten.setValue(bytesWritten - v)
   }
-  private[spark] def decRecordsWritten(v: Long): Unit = {
+  private[spark] override def decRecordsWritten(v: Long): Unit = {
     _recordsWritten.setValue(recordsWritten - v)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index f2cd65f..5412717 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -95,7 +95,8 @@ private[spark] class ShuffleMapTask(
     var writer: ShuffleWriter[Any, Any] = null
     try {
       val manager = SparkEnv.get.shuffleManager
-      writer = manager.getWriter[Any, Any](dep.shuffleHandle, partitionId, context)
+      writer = manager.getWriter[Any, Any](
+        dep.shuffleHandle, partitionId, context, context.taskMetrics().shuffleWriteMetrics)
       writer.write(rdd.iterator(partition, context).asInstanceOf[Iterator[_ <: Product2[Any, Any]]])
       writer.stop(success = true).get
     } catch {

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index df601cb..18a743f 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -38,7 +38,11 @@ private[spark] trait ShuffleManager {
       dependency: ShuffleDependency[K, V, C]): ShuffleHandle
 
   /** Get a writer for a given partition. Called on executors by map tasks. */
-  def getWriter[K, V](handle: ShuffleHandle, mapId: Int, context: TaskContext): ShuffleWriter[K, V]
+  def getWriter[K, V](
+      handle: ShuffleHandle,
+      mapId: Int,
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V]
 
   /**
   * Get a reader for a range of reduce partitions (startPartition to endPartition-1, inclusive).
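
The ShuffleMapTask call site above passes the task's own ShuffleWriteMetrics, but the new
parameter is the hook for supplying anything else. A sketch of a call under that assumption
(handle, mapId, and context come from the surrounding task; the reporter is the hypothetical
one sketched earlier in this message):

    val metrics = new InMemoryWriteMetricsReporter  // hypothetical custom reporter
    val writer = manager.getWriter[Int, Int](handle, mapId, context, metrics)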

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 4f8be19..b51a843 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -125,7 +125,8 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
   override def getWriter[K, V](
       handle: ShuffleHandle,
       mapId: Int,
-      context: TaskContext): ShuffleWriter[K, V] = {
+      context: TaskContext,
+      metrics: ShuffleWriteMetricsReporter): ShuffleWriter[K, V] = {
     numMapsForShuffle.putIfAbsent(
       handle.shuffleId, handle.asInstanceOf[BaseShuffleHandle[_, _, _]].numMaps)
     val env = SparkEnv.get
@@ -138,15 +139,16 @@ private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager
           unsafeShuffleHandle,
           mapId,
           context,
-          env.conf)
+          env.conf,
+          metrics)
       case bypassMergeSortHandle: BypassMergeSortShuffleHandle[K @unchecked, V @unchecked] =>
         new BypassMergeSortShuffleWriter(
           env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockResolver],
           bypassMergeSortHandle,
           mapId,
-          context,
-          env.conf)
+          env.conf,
+          metrics)
       case other: BaseShuffleHandle[K @unchecked, V @unchecked, _] =>
         new SortShuffleWriter(shuffleBlockResolver, other, mapId, context)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index edae2f9..1b61729 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -33,10 +33,9 @@ import scala.util.Random
 import scala.util.control.NonFatal
 
 import com.codahale.metrics.{MetricRegistry, MetricSet}
-import com.google.common.io.CountingOutputStream
 
 import org.apache.spark._
-import org.apache.spark.executor.{DataReadMethod, ShuffleWriteMetrics}
+import org.apache.spark.executor.DataReadMethod
 import org.apache.spark.internal.{config, Logging}
 import org.apache.spark.memory.{MemoryManager, MemoryMode}
 import org.apache.spark.metrics.source.Source
@@ -50,7 +49,7 @@ import org.apache.spark.network.util.TransportConf
 import org.apache.spark.rpc.RpcEnv
 import org.apache.spark.scheduler.ExecutorCacheTaskLocation
 import org.apache.spark.serializer.{SerializerInstance, SerializerManager}
-import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.shuffle.{ShuffleManager, ShuffleWriteMetricsReporter}
 import org.apache.spark.storage.memory._
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.util._
@@ -932,7 +931,7 @@ private[spark] class BlockManager(
       file: File,
       serializerInstance: SerializerInstance,
       bufferSize: Int,
-      writeMetrics: ShuffleWriteMetrics): DiskBlockObjectWriter = {
+      writeMetrics: ShuffleWriteMetricsReporter): DiskBlockObjectWriter = {
     val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
     new DiskBlockObjectWriter(file, serializerManager, serializerInstance, bufferSize,
       syncWrites, writeMetrics, blockId)

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index a024c83..17390f9 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -20,9 +20,9 @@ package org.apache.spark.storage
 import java.io.{BufferedOutputStream, File, FileOutputStream, OutputStream}
 import java.nio.channels.FileChannel
 
-import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.internal.Logging
 import org.apache.spark.serializer.{SerializationStream, SerializerInstance, SerializerManager}
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter
 import org.apache.spark.util.Utils
 
 /**
@@ -43,7 +43,7 @@ private[spark] class DiskBlockObjectWriter(
     syncWrites: Boolean,
     // These write metrics concurrently shared with other active DiskBlockObjectWriters who
     // are themselves performing writes. All updates must be relative.
-    writeMetrics: ShuffleWriteMetrics,
+    writeMetrics: ShuffleWriteMetricsReporter,
     val blockId: BlockId = null)
   extends OutputStream
   with Logging {

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index b159200..eac3db0 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -793,8 +793,8 @@ private[spark] class ExternalSorter[K, V, C](
 
           def nextPartition(): Int = cur._1._1
         }
-        logInfo(s"Task ${context.taskAttemptId} force spilling in-memory map 
to disk and " +
-          s" it will release 
${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
+        logInfo(s"Task ${TaskContext.get().taskAttemptId} force spilling 
in-memory map to disk " +
+          s"and it will release 
${org.apache.spark.util.Utils.bytesToString(getUsed())} memory")
         val spillFile = spillMemoryIteratorToDisk(inMemoryIterator)
         forceSpillFiles += spillFile
         val spillReader = new SpillReader(spillFile)

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index a07d0e8..30ad3f5 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -162,7 +162,8 @@ public class UnsafeShuffleWriterSuite {
       new SerializedShuffleHandle<>(0, 1, shuffleDep),
       0, // map id
       taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics()
     );
   }
 
@@ -521,7 +522,8 @@ public class UnsafeShuffleWriterSuite {
         new SerializedShuffleHandle<>(0, 1, shuffleDep),
         0, // map id
         taskContext,
-        conf);
+        conf,
+        taskContext.taskMetrics().shuffleWriteMetrics());
 
     // Peak memory should be monotonically increasing. More specifically, every time
     // we allocate a new page it should increase by exactly the size of the page.

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 419a26b..35f728c 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -362,15 +362,19 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
     mapTrackerMaster.registerShuffle(0, 1)
 
     // first attempt -- its successful
-    val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem))
+    val context1 =
+      new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)
+    val writer1 = manager.getWriter[Int, Int](
+      shuffleHandle, 0, context1, context1.taskMetrics.shuffleWriteMetrics)
     val data1 = (1 to 10).map { x => x -> x}
 
     // second attempt -- also successful.  We'll write out different data,
     // just to simulate the fact that the records may get written differently
     // depending on what gets spilled, what gets combined, etc.
-    val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
-      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem))
+    val context2 =
+      new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)
+    val writer2 = manager.getWriter[Int, Int](
+      shuffleHandle, 0, context2, context2.taskMetrics.shuffleWriteMetrics)
     val data2 = (11 to 20).map { x => x -> x}
 
     // interleave writes of both attempts -- we want to test that both attempts can occur

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
----------------------------------------------------------------------
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index 85ccb33..4467c32 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -136,8 +136,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     writer.write(Iterator.empty)
     writer.stop( /* success = */ true)
@@ -160,8 +160,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     writer.write(records)
     writer.stop( /* success = */ true)
@@ -195,8 +195,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
 
     intercept[SparkException] {
@@ -217,8 +217,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
       blockResolver,
       shuffleHandle,
       0, // MapId
-      taskContext,
-      conf
+      conf,
+      taskContext.taskMetrics().shuffleWriteMetrics
     )
     intercept[SparkException] {
       writer.write((0 until 100000).iterator.map(i => {

http://git-wip-us.apache.org/repos/asf/spark/blob/6a064ba8/project/MimaExcludes.scala
----------------------------------------------------------------------
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 333adb0..3fabec0 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -226,7 +226,12 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchReader"),
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.DataSourceWriter"),
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.v2.writer.DataWriterFactory.createWriter"),
-    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter")
+    ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.v2.writer.streaming.StreamWriter"),
+
+    // [SPARK-26141] Enable custom metrics implementation in shuffle write
+    // Following are Java private classes
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.shuffle.sort.UnsafeShuffleWriter.this"),
+    ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.storage.TimeTrackingOutputStream.this")
   )
 
   // Exclude rules for 2.4.x

