This is an automated email from the ASF dual-hosted git repository.

chengpan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 0583cdb5a [CELEBORN-1048] Align fetchWaitTime metrics to spark 
implementation
0583cdb5a is described below

commit 0583cdb5a8ae90850103d6d0e1301ffd0a542077
Author: TongWei1105 <[email protected]>
AuthorDate: Thu Nov 2 15:27:30 2023 +0800

    [CELEBORN-1048] Align fetchWaitTime metrics to spark implementation
    
    ### What changes were proposed in this pull request?
    Align fetchWaitTime metrics to spark implementation
    
    ### Why are the changes needed?
    In our production environment, there are variations in the fetchWaitTime 
metric for the same stage of the same job.
    
    ON YARN ESS:
    
![image](https://github.com/apache/incubator-celeborn/assets/68682646/601a8315-1317-48dc-b9a6-7ea651d5122d)
    ON CELEBORN
    
![image](https://github.com/apache/incubator-celeborn/assets/68682646/e00ed60f-3789-4330-a7ed-fdd5754acf1d)
    Then, based on the implementation of Spark ShuffleBlockFetcherIterator, I 
made adjustments to the fetchWaitTime metrics code
    
    Now, it looks more reasonable:
    
![image](https://github.com/apache/incubator-celeborn/assets/68682646/ce5e46e4-8ed2-422e-b54b-cd594aad73dd)
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Tested in our production environment.
    
    Closes #2000 from TongWei1105/CELEBORN-1048.
    
    Lead-authored-by: TongWei1105 <[email protected]>
    Co-authored-by: Keyong Zhou <[email protected]>
    Co-authored-by: zky.zhoukeyong <[email protected]>
    Signed-off-by: Cheng Pan <[email protected]>
---
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  4 +--
 .../shuffle/celeborn/CelebornShuffleReader.scala   |  4 +--
 .../org/apache/celeborn/client/ShuffleClient.java  |  8 ++++-
 .../apache/celeborn/client/ShuffleClientImpl.java  | 11 +++++--
 .../celeborn/client/read/CelebornInputStream.java  | 36 ++++++++--------------
 .../celeborn/client/read/DfsPartitionReader.java   |  8 ++++-
 .../celeborn/client/read/LocalPartitionReader.java |  8 ++++-
 .../client/read/WorkerPartitionReader.java         |  8 ++++-
 .../apache/celeborn/client/DummyShuffleClient.java |  8 ++++-
 .../celeborn/client/WithShuffleClientSuite.scala   | 10 ++++--
 .../service/deploy/cluster/ReadWriteTestBase.scala |  7 ++++-
 11 files changed, 74 insertions(+), 38 deletions(-)

diff --git 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 7518b1e6b..1059f3604 100644
--- 
a/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-2/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -91,7 +91,8 @@ class CelebornShuffleReader[K, C](
                 partitionId,
                 context.attemptNumber(),
                 startMapIndex,
-                endMapIndex)
+                endMapIndex,
+                metricsCallback)
               streams.put(partitionId, inputStream)
             } catch {
               case e: IOException =>
@@ -119,7 +120,6 @@ class CelebornShuffleReader[K, C](
         }
         metricsCallback.incReadTime(
           TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
-        inputStream.setCallback(metricsCallback)
         // ensure inputStream is closed when task completes
         context.addTaskCompletionListener(_ => inputStream.close())
         inputStream
diff --git 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
index 063ad0b6f..5ec0fed9b 100644
--- 
a/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
+++ 
b/client-spark/spark-3/src/main/scala/org/apache/spark/shuffle/celeborn/CelebornShuffleReader.scala
@@ -93,7 +93,8 @@ class CelebornShuffleReader[K, C](
                 partitionId,
                 context.attemptNumber(),
                 startMapIndex,
-                endMapIndex)
+                endMapIndex,
+                metricsCallback)
               streams.put(partitionId, inputStream)
             } catch {
               case e: IOException =>
@@ -121,7 +122,6 @@ class CelebornShuffleReader[K, C](
         }
         metricsCallback.incReadTime(
           TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait))
-        inputStream.setCallback(metricsCallback)
         // ensure inputStream is closed when task completes
         context.addTaskCompletionListener[Unit](_ => inputStream.close())
         inputStream
diff --git a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
index e9e4d78bd..22318e542 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClient.java
@@ -26,6 +26,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.read.CelebornInputStream;
+import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.protocol.PartitionLocation;
@@ -191,7 +192,12 @@ public abstract class ShuffleClient {
    * @throws IOException
    */
   public abstract CelebornInputStream readPartition(
-      int shuffleId, int partitionId, int attemptNumber, int startMapIndex, 
int endMapIndex)
+      int shuffleId,
+      int partitionId,
+      int attemptNumber,
+      int startMapIndex,
+      int endMapIndex,
+      MetricsCallback metricsCallback)
       throws IOException;
 
   public abstract boolean cleanupShuffle(int shuffleId);
diff --git 
a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java 
b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
index 339da8a7a..d913a9e06 100644
--- a/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
+++ b/client/src/main/java/org/apache/celeborn/client/ShuffleClientImpl.java
@@ -36,6 +36,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.compress.Compressor;
 import org.apache.celeborn.client.read.CelebornInputStream;
+import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.exception.CelebornIOException;
 import org.apache.celeborn.common.identity.UserIdentifier;
@@ -1585,7 +1586,12 @@ public class ShuffleClientImpl extends ShuffleClient {
 
   @Override
   public CelebornInputStream readPartition(
-      int shuffleId, int partitionId, int attemptNumber, int startMapIndex, 
int endMapIndex)
+      int shuffleId,
+      int partitionId,
+      int attemptNumber,
+      int startMapIndex,
+      int endMapIndex,
+      MetricsCallback metricsCallback)
       throws IOException {
     ReduceFileGroups fileGroups = loadFileGroup(shuffleId, partitionId);
 
@@ -1604,7 +1610,8 @@ public class ShuffleClientImpl extends ShuffleClient {
           attemptNumber,
           startMapIndex,
           endMapIndex,
-          fetchExcludedWorkers);
+          fetchExcludedWorkers,
+          metricsCallback);
     }
   }
 
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
index 996dd6e9e..22cc54a5f 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/CelebornInputStream.java
@@ -56,7 +56,8 @@ public abstract class CelebornInputStream extends InputStream 
{
       int attemptNumber,
       int startMapIndex,
       int endMapIndex,
-      ConcurrentHashMap<String, Long> fetchExcludedWorkers)
+      ConcurrentHashMap<String, Long> fetchExcludedWorkers,
+      MetricsCallback metricsCallback)
       throws IOException {
     if (locations == null || locations.length == 0) {
       return emptyInputStream;
@@ -70,7 +71,8 @@ public abstract class CelebornInputStream extends InputStream 
{
           attemptNumber,
           startMapIndex,
           endMapIndex,
-          fetchExcludedWorkers);
+          fetchExcludedWorkers,
+          metricsCallback);
     }
   }
 
@@ -78,8 +80,6 @@ public abstract class CelebornInputStream extends InputStream 
{
     return emptyInputStream;
   }
 
-  public abstract void setCallback(MetricsCallback callback);
-
   private static final CelebornInputStream emptyInputStream =
       new CelebornInputStream() {
         @Override
@@ -92,9 +92,6 @@ public abstract class CelebornInputStream extends InputStream 
{
           return -1;
         }
 
-        @Override
-        public void setCallback(MetricsCallback callback) {}
-
         @Override
         public int totalPartitionsToRead() {
           return 0;
@@ -164,7 +161,8 @@ public abstract class CelebornInputStream extends 
InputStream {
         int attemptNumber,
         int startMapIndex,
         int endMapIndex,
-        ConcurrentHashMap<String, Long> fetchExcludedWorkers)
+        ConcurrentHashMap<String, Long> fetchExcludedWorkers,
+        MetricsCallback metricsCallback)
         throws IOException {
       this.conf = conf;
       this.clientFactory = clientFactory;
@@ -202,6 +200,7 @@ public abstract class CelebornInputStream extends 
InputStream {
       TransportConf transportConf =
           Utils.fromCelebornConf(conf, TransportModuleConstants.DATA_MODULE, 
0);
       retryWaitMs = transportConf.ioRetryWaitTimeMs();
+      this.callback = metricsCallback;
       moveToNextReader();
     }
 
@@ -418,7 +417,7 @@ public abstract class CelebornInputStream extends 
InputStream {
             logger.debug("Read local shuffle file {}", localHostAddress);
             containLocalRead = true;
             return new LocalPartitionReader(
-                conf, shuffleKey, location, clientFactory, startMapIndex, 
endMapIndex);
+                conf, shuffleKey, location, clientFactory, startMapIndex, 
endMapIndex, callback);
           } else {
             return new WorkerPartitionReader(
                 conf,
@@ -428,22 +427,18 @@ public abstract class CelebornInputStream extends 
InputStream {
                 startMapIndex,
                 endMapIndex,
                 fetchChunkRetryCnt,
-                fetchChunkMaxRetry);
+                fetchChunkMaxRetry,
+                callback);
           }
         case HDFS:
           return new DfsPartitionReader(
-              conf, shuffleKey, location, clientFactory, startMapIndex, 
endMapIndex);
+              conf, shuffleKey, location, clientFactory, startMapIndex, 
endMapIndex, callback);
         default:
           throw new CelebornIOException(
               String.format("Unknown storage info %s to read location %s", 
storageInfo, location));
       }
     }
 
-    public void setCallback(MetricsCallback callback) {
-      // callback must set before read()
-      this.callback = callback;
-    }
-
     @Override
     public int read() throws IOException {
       if (position < limit) {
@@ -539,8 +534,6 @@ public abstract class CelebornInputStream extends 
InputStream {
         return false;
       }
 
-      long startTime = System.nanoTime();
-
       boolean hasData = false;
       while (currentChunk.isReadable() || moveToNextChunk()) {
         currentChunk.readBytes(sizeBuf);
@@ -572,9 +565,7 @@ public abstract class CelebornInputStream extends 
InputStream {
           Set<Integer> batchSet = batchesRead.get(mapId);
           if (!batchSet.contains(batchId)) {
             batchSet.add(batchId);
-            if (callback != null) {
-              callback.incBytesRead(BATCH_HEADER_SIZE + size);
-            }
+            callback.incBytesRead(BATCH_HEADER_SIZE + size);
             if (shuffleCompressionEnabled) {
               // decompress data
               int originalLength = decompressor.getOriginalLen(compressedBuf);
@@ -598,9 +589,6 @@ public abstract class CelebornInputStream extends 
InputStream {
         }
       }
 
-      if (callback != null) {
-        callback.incReadTime(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - 
startTime));
-      }
       return hasData;
     }
 
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
index e7ffbf182..7acc4fc8c 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/DfsPartitionReader.java
@@ -64,6 +64,7 @@ public class DfsPartitionReader implements PartitionReader {
   private int currentChunkIndex = 0;
   private TransportClient client;
   private PbStreamHandler streamHandler;
+  private MetricsCallback metricsCallback;
 
   public DfsPartitionReader(
       CelebornConf conf,
@@ -71,12 +72,14 @@ public class DfsPartitionReader implements PartitionReader {
       PartitionLocation location,
       TransportClientFactory clientFactory,
       int startMapIndex,
-      int endMapIndex)
+      int endMapIndex,
+      MetricsCallback metricsCallback)
       throws IOException {
     shuffleChunkSize = conf.dfsReadChunkSize();
     fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
     results = new LinkedBlockingQueue<>();
 
+    this.metricsCallback = metricsCallback;
     this.location = location;
 
     final List<Long> chunkOffsets = new ArrayList<>();
@@ -224,7 +227,10 @@ public class DfsPartitionReader implements PartitionReader 
{
     try {
       while (chunk == null) {
         checkException();
+        Long startFetchWait = System.nanoTime();
         chunk = results.poll(500, TimeUnit.MILLISECONDS);
+        metricsCallback.incReadTime(
+            TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
         logger.debug("poll result with result size: {}", results.size());
       }
     } catch (InterruptedException e) {
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
index e8437d02a..97a3acb34 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/LocalPartitionReader.java
@@ -67,6 +67,7 @@ public class LocalPartitionReader implements PartitionReader {
   private AtomicBoolean pendingFetchTask = new AtomicBoolean(false);
   private PbStreamHandler streamHandler;
   private TransportClient client;
+  private MetricsCallback metricsCallback;
 
   public LocalPartitionReader(
       CelebornConf conf,
@@ -74,7 +75,8 @@ public class LocalPartitionReader implements PartitionReader {
       PartitionLocation location,
       TransportClientFactory clientFactory,
       int startMapIndex,
-      int endMapIndex)
+      int endMapIndex,
+      MetricsCallback metricsCallback)
       throws IOException {
     if (readLocalShufflePool == null) {
       synchronized (LocalPartitionReader.class) {
@@ -88,6 +90,7 @@ public class LocalPartitionReader implements PartitionReader {
     fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
     results = new LinkedBlockingQueue<>();
     this.location = location;
+    this.metricsCallback = metricsCallback;
     long fetchTimeoutMs = conf.clientFetchTimeoutMs();
     try {
       client = clientFactory.createClient(location.getHost(), 
location.getFetchPort(), 0);
@@ -199,7 +202,10 @@ public class LocalPartitionReader implements 
PartitionReader {
     try {
       while (chunk == null) {
         checkException();
+        Long startFetchWait = System.nanoTime();
         chunk = results.poll(100, TimeUnit.MILLISECONDS);
+        metricsCallback.incReadTime(
+            TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
         logger.debug("Poll result with result size: {}", results.size());
       }
     } catch (InterruptedException e) {
diff --git 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
index b97da1009..080643814 100644
--- 
a/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
+++ 
b/client/src/main/java/org/apache/celeborn/client/read/WorkerPartitionReader.java
@@ -51,6 +51,7 @@ public class WorkerPartitionReader implements PartitionReader 
{
   private final TransportClientFactory clientFactory;
   private PbStreamHandler streamHandler;
   private TransportClient client;
+  private MetricsCallback metricsCallback;
 
   private int returnedChunks;
   private int chunkIndex;
@@ -76,11 +77,13 @@ public class WorkerPartitionReader implements 
PartitionReader {
       int startMapIndex,
       int endMapIndex,
       int fetchChunkRetryCnt,
-      int fetchChunkMaxRetry)
+      int fetchChunkMaxRetry,
+      MetricsCallback metricsCallback)
       throws IOException, InterruptedException {
     fetchMaxReqsInFlight = conf.clientFetchMaxReqsInFlight();
     results = new LinkedBlockingQueue<>();
     fetchTimeoutMs = conf.clientFetchTimeoutMs();
+    this.metricsCallback = metricsCallback;
     // only add the buffer to results queue if this reader is not closed.
     callback =
         new ChunkReceivedCallback() {
@@ -144,7 +147,10 @@ public class WorkerPartitionReader implements 
PartitionReader {
     try {
       while (chunk == null) {
         checkException();
+        Long startFetchWait = System.nanoTime();
         chunk = results.poll(500, TimeUnit.MILLISECONDS);
+        metricsCallback.incReadTime(
+            TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - startFetchWait));
       }
     } catch (InterruptedException e) {
       logger.error("PartitionReader thread interrupted while polling data.");
diff --git 
a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java 
b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
index 252452687..339b8b859 100644
--- a/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
+++ b/client/src/test/java/org/apache/celeborn/client/DummyShuffleClient.java
@@ -34,6 +34,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.celeborn.client.read.CelebornInputStream;
+import org.apache.celeborn.client.read.MetricsCallback;
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.protocol.PartitionLocation;
 import org.apache.celeborn.common.rpc.RpcEndpointRef;
@@ -112,7 +113,12 @@ public class DummyShuffleClient extends ShuffleClient {
 
   @Override
   public CelebornInputStream readPartition(
-      int shuffleId, int partitionId, int attemptNumber, int startMapIndex, 
int endMapIndex) {
+      int shuffleId,
+      int partitionId,
+      int attemptNumber,
+      int startMapIndex,
+      int endMapIndex,
+      MetricsCallback metricsCallback) {
     return null;
   }
 
diff --git 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
index 33bb24d62..f7678452f 100644
--- 
a/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
+++ 
b/client/src/test/scala/org/apache/celeborn/client/WithShuffleClientSuite.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
 import org.junit.Assert
 
 import org.apache.celeborn.CelebornFunSuite
+import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.util.JavaUtils.timeOutOrMeetCondition
@@ -140,12 +141,17 @@ trait WithShuffleClientSuite extends CelebornFunSuite {
     // reduce file group size (for empty partitions)
     Assert.assertEquals(shuffleClient.getReduceFileGroupsMap.size(), 0)
 
+    val metricsCallback = new MetricsCallback {
+      override def incBytesRead(bytesWritten: Long): Unit = {}
+      override def incReadTime(time: Long): Unit = {}
+    }
+
     // reduce normal empty CelebornInputStream
-    var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0, 
Integer.MAX_VALUE)
+    var stream = shuffleClient.readPartition(shuffleId, 1, 1, 0, 
Integer.MAX_VALUE, metricsCallback)
     Assert.assertEquals(stream.read(), -1)
 
     // reduce normal null partition for CelebornInputStream
-    stream = shuffleClient.readPartition(shuffleId, 3, 1, 0, Integer.MAX_VALUE)
+    stream = shuffleClient.readPartition(shuffleId, 3, 1, 0, 
Integer.MAX_VALUE, metricsCallback)
     Assert.assertEquals(stream.read(), -1)
   }
 
diff --git 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
index 85fc3201c..f8cbec595 100644
--- 
a/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
+++ 
b/worker/src/test/scala/org/apache/celeborn/service/deploy/cluster/ReadWriteTestBase.scala
@@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll
 import org.scalatest.funsuite.AnyFunSuite
 
 import org.apache.celeborn.client.{LifecycleManager, ShuffleClientImpl}
+import org.apache.celeborn.client.read.MetricsCallback
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.identity.UserIdentifier
 import org.apache.celeborn.common.internal.Logging
@@ -102,7 +103,11 @@ trait ReadWriteTestBase extends AnyFunSuite
 
     shuffleClient.mapperEnd(1, 0, 0, 1)
 
-    val inputStream = shuffleClient.readPartition(1, 0, 0, 0, 
Integer.MAX_VALUE)
+    val metricsCallback = new MetricsCallback {
+      override def incBytesRead(bytesWritten: Long): Unit = {}
+      override def incReadTime(time: Long): Unit = {}
+    }
+    val inputStream = shuffleClient.readPartition(1, 0, 0, 0, 
Integer.MAX_VALUE, metricsCallback)
     val outputStream = new ByteArrayOutputStream()
 
     var b = inputStream.read()

Reply via email to