This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new c9a990985 [CELEBORN-693][SPARK] Align the `incWriteTime` in the hash-based shuffle writer with the sort-based shuffle
c9a990985 is described below

commit c9a99098591495f0662d103453c01c9205bd01ed
Author: Fu Chen <[email protected]>
AuthorDate: Mon Jun 19 15:42:01 2023 +0800

    [CELEBORN-693][SPARK] Align the `incWriteTime` in the hash-based shuffle writer with the sort-based shuffle
    
    ### What changes were proposed in this pull request?
    
    As the title states: accumulate the insert-and-push time per record and report it via `incWriteTime` once per write loop, matching the sort-based shuffle writer's accounting.
    
    ### Why are the changes needed?
    
    https://github.com/apache/incubator-celeborn/pull/1585#issuecomment-1589164128
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tested locally.
    
    Closes #1604 from cfmcgrady/hash-based-writer-metrics.
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
    (cherry picked from commit 18f2be0fbec7258462541457b2bda4279fbc00af)
    Signed-off-by: Cheng Pan <[email protected]>
---
 client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | 13 ++++++++-----
 client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java | 17 +++++++++++++----
 2 files changed, 21 insertions(+), 9 deletions(-)
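
The gist of the patch, for readers skimming the diffs below: both writers used to call writeMetrics.incWriteTime() inside pushGiantRecord() and flushSendBuffer(), so only the push calls themselves were timed. The patch instead brackets the whole insert-and-push path of each record and increments the metric once per write loop, which matches how the sort-based writer accounts for write time. A minimal sketch of the resulting pattern (insertAndMaybePush is an illustrative stand-in for the buffer-copy/push branch in the real methods, not a name from the patch):

    long shuffleWriteTimeSum = 0L;
    while (records.hasNext()) {
      final Product2<K, ?> record = records.next();
      // Time only the insert-and-push work; iterator consumption and
      // serialization happen outside this bracket.
      long insertAndPushStartTime = System.nanoTime();
      insertAndMaybePush(record);  // copy into the partition buffer, push on overflow
      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
    }
    // Report the accumulated time once per loop instead of once per push.
    writeMetrics.incWriteTime(shuffleWriteTimeSum);

Note that the timed region now also covers the buffer copy and offset bookkeeping, not just pushData()/dataPusher.addTask(), so the reported write time can grow slightly for the same workload; the point of the change is that it grows to match the sort-based writer's number.
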

diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index f93eee8d5..e2f12f969 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -191,7 +191,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
     SQLMetric dataSize =
         SparkUtils.getUnsafeRowSerializerDataSizeMetric((UnsafeRowSerializer) dep.serializer());
-
+    long shuffleWriteTimeSum = 0L;
     while (records.hasNext()) {
       final Product2<Integer, UnsafeRow> record = records.next();
       final int partitionId = record._1();
@@ -203,6 +203,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         dataSize.add(serializedRecordSize);
       }
 
+      long insertAndPushStartTime = System.nanoTime();
       if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         byte[] giantBuffer = new byte[serializedRecordSize];
         Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
@@ -225,13 +226,16 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             rowSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
+      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
       tmpRecords[partitionId] += 1;
     }
+    writeMetrics.incWriteTime(shuffleWriteTimeSum);
   }
 
   private void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException {
     final scala.collection.Iterator<Product2<K, ?>> records = iterator;
 
+    long shuffleWriteTimeSum = 0L;
     while (records.hasNext()) {
       final Product2<K, ?> record = records.next();
       final K key = record._1();
@@ -244,6 +248,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       final int serializedRecordSize = serBuffer.size();
       assert (serializedRecordSize > 0);
 
+      long insertAndPushStartTime = System.nanoTime();
       if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize);
       } else {
@@ -252,8 +257,10 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         System.arraycopy(serBuffer.getBuf(), 0, buffer, offset, serializedRecordSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
+      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
       tmpRecords[partitionId] += 1;
     }
+    writeMetrics.incWriteTime(shuffleWriteTimeSum);
   }
 
   private byte[] getOrCreateBuffer(int partitionId) {
@@ -268,7 +275,6 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
     logger.debug("Push giant record for partition {}, size {}.", partitionId, numBytes);
-    long pushStartTime = System.nanoTime();
     int bytesWritten =
         rssShuffleClient.pushData(
             appId,
@@ -283,7 +289,6 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             numPartitions);
     mapStatusLengths[partitionId].add(bytesWritten);
     writeMetrics.incBytesWritten(bytesWritten);
-    writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
   }
 
   private int getOrUpdateOffset(int partitionId, int serializedRecordSize)
@@ -310,10 +315,8 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private void flushSendBuffer(int partitionId, byte[] buffer, int size)
       throws IOException, InterruptedException {
-    long pushStartTime = System.nanoTime();
     logger.debug("Flush buffer for partition {}, size {}.", partitionId, size);
     dataPusher.addTask(partitionId, buffer, size);
-    writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
   }
 
   private void close() throws IOException, InterruptedException {
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index a3a6c02aa..34c2d9937 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -226,6 +226,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
 
     SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) dep.serializer());
+    long shuffleWriteTimeSum = 0L;
     while (records.hasNext()) {
       final Product2<Integer, UnsafeRow> record = records.next();
       final int partitionId = record._1();
@@ -247,6 +248,8 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         columnBuilders.newBuilders();
         rssBatchBuilders[partitionId] = columnBuilders;
       }
+
+      long insertAndPushStartTime = System.nanoTime();
       rssBatchBuilders[partitionId].writeRow(row);
       if (rssBatchBuilders[partitionId].getRowCnt() >= columnarShuffleBatchSize) {
         byte[] arr = rssBatchBuilders[partitionId].buildColumnBytes();
@@ -256,8 +259,10 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         }
         rssBatchBuilders[partitionId].newBuilders();
       }
+      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
       tmpRecords[partitionId] += 1;
     }
+    writeMetrics.incWriteTime(shuffleWriteTimeSum);
   }
 
   private void fastWrite0(scala.collection.Iterator iterator)
@@ -265,6 +270,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     final scala.collection.Iterator<Product2<Integer, UnsafeRow>> records = iterator;
 
     SQLMetric dataSize = SparkUtils.getDataSize((UnsafeRowSerializer) dep.serializer());
+    long shuffleWriteTimeSum = 0L;
     while (records.hasNext()) {
       final Product2<Integer, UnsafeRow> record = records.next();
       final int partitionId = record._1();
@@ -277,6 +283,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         dataSize.add(rowSize);
       }
 
+      long insertAndPushStartTime = System.nanoTime();
       if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         byte[] giantBuffer = new byte[serializedRecordSize];
         Platform.putInt(giantBuffer, Platform.BYTE_ARRAY_OFFSET, Integer.reverseBytes(rowSize));
@@ -299,13 +306,16 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             rowSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
+      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
       tmpRecords[partitionId] += 1;
     }
+    writeMetrics.incWriteTime(shuffleWriteTimeSum);
   }
 
   private void write0(scala.collection.Iterator iterator) throws IOException, InterruptedException {
     final scala.collection.Iterator<Product2<K, ?>> records = iterator;
 
+    long shuffleWriteTimeSum = 0L;
     while (records.hasNext()) {
       final Product2<K, ?> record = records.next();
       final K key = record._1();
@@ -318,6 +328,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
       final int serializedRecordSize = serBuffer.size();
       assert (serializedRecordSize > 0);
 
+      long insertAndPushStartTime = System.nanoTime();
       if (serializedRecordSize > PUSH_BUFFER_MAX_SIZE) {
         pushGiantRecord(partitionId, serBuffer.getBuf(), serializedRecordSize);
       } else {
@@ -326,8 +337,10 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         System.arraycopy(serBuffer.getBuf(), 0, buffer, offset, serializedRecordSize);
         sendOffsets[partitionId] = offset + serializedRecordSize;
       }
+      shuffleWriteTimeSum += System.nanoTime() - insertAndPushStartTime;
       tmpRecords[partitionId] += 1;
     }
+    writeMetrics.incWriteTime(shuffleWriteTimeSum);
   }
 
   private byte[] getOrCreateBuffer(int partitionId) {
@@ -342,7 +355,6 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private void pushGiantRecord(int partitionId, byte[] buffer, int numBytes) throws IOException {
     logger.debug("Push giant record, size {}.", numBytes);
-    long pushStartTime = System.nanoTime();
     int bytesWritten =
         rssShuffleClient.pushData(
             appId,
@@ -357,7 +369,6 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
             numPartitions);
     mapStatusLengths[partitionId].add(bytesWritten);
     writeMetrics.incBytesWritten(bytesWritten);
-    writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
   }
 
   private int getOrUpdateOffset(int partitionId, int serializedRecordSize)
@@ -384,10 +395,8 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   private void flushSendBuffer(int partitionId, byte[] buffer, int size)
       throws IOException, InterruptedException {
-    long pushStartTime = System.nanoTime();
     logger.debug("Flush buffer, size {}.", size);
     dataPusher.addTask(partitionId, buffer, size);
-    writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
   }
 
   private void closeColumnarWrite() throws IOException {
