This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 11881ab92 [#2640] feat(spark): Involve background prefetch time in 
spark UI (#2641)
11881ab92 is described below

commit 11881ab92706bb8cf3a2bad556ed5e545ac890cb
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Oct 15 09:37:57 2025 +0800

    [#2640] feat(spark): Involve background prefetch time in spark UI (#2641)
    
    ### What changes were proposed in this pull request?
    
    This PR includes the background prefetch time in the Spark UI's shuffle read-time table, shown as a new "Background Fetch" column.
    
    ### Why are the changes needed?
    
    Addresses #2640.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Tested with an internal job.
---
 .../shuffle/reader/RssShuffleDataIterator.java     |  1 +
 .../scala/org/apache/spark/ui/ShufflePage.scala    |  5 +-
 .../uniffle/client/impl/ShuffleReadClientImpl.java | 16 ++++++-
 .../apache/uniffle/common/ShuffleReadTimes.java    | 14 +++++-
 proto/src/main/proto/Rss.proto                     |  1 +
 .../storage/handler/ClientReadHandlerMetric.java   | 20 ++++++++
 .../handler/impl/ComposedClientReadHandler.java    | 56 +++++++++++++++-------
 .../impl/PrefetchableClientReadHandler.java        | 17 ++++++-
 8 files changed, 110 insertions(+), 20 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
index 3fdb3bb47..c7e4a7559 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/reader/RssShuffleDataIterator.java
@@ -166,6 +166,7 @@ public class RssShuffleDataIterator<K, C> extends 
AbstractIterator<Product2<K, C
         // finish reading records, check data consistent
         shuffleReadClient.checkProcessedBlockIds();
         shuffleReadClient.logStatics();
+        shuffleReadClient.getShuffleReadTimes();
         String decInfo =
             !codec.isPresent()
                 ? "."
diff --git 
a/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala 
b/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
index 24b8f5be2..46359e1e6 100644
--- 
a/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
+++ 
b/client-spark/extension/src/main/scala/org/apache/spark/ui/ShufflePage.scala
@@ -48,6 +48,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("") 
with Logging {
     <td>{kv(4)}</td>
     <td>{kv(5)}</td>
     <td>{kv(6)}</td>
+    <td>{kv(7)}</td>
   </tr>
 
   private def shuffleWriteTimesRow(kv: Seq[String]) = <tr>
@@ -160,7 +161,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("") 
with Logging {
     val readTimes = runtimeStatusStore.shuffleReadTimes().times
     val readTotal = if (readTimes.getTotal <= 0) -1 else readTimes.getTotal
     val readTimesUI = UIUtils.listingTable(
-      Seq("Total", "Fetch", "Copy", "CRC", "Deserialize", "Decompress", 
"Background Decompress"),
+      Seq("Total", "Fetch", "Copy", "CRC", "Deserialize", "Decompress", 
"Background Decompress", "Background Fetch"),
       shuffleReadTimesRow,
       Seq(
         Seq(
@@ -171,6 +172,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("") 
with Logging {
           UIUtils.formatDuration(readTimes.getDeserialize),
           UIUtils.formatDuration(readTimes.getDecompress),
           UIUtils.formatDuration(readTimes.getBackgroundDecompress),
+          UIUtils.formatDuration(readTimes.getBackgroundFetch),
         ),
         Seq(
           1,
@@ -180,6 +182,7 @@ class ShufflePage(parent: ShuffleTab) extends WebUIPage("") 
with Logging {
           readTimes.getDeserialize.toDouble / readTotal,
           readTimes.getDecompress.toDouble / readTotal,
           readTimes.getBackgroundDecompress.toDouble / readTotal,
+          readTimes.getBackgroundFetch.toDouble / readTotal,
         ).map(x => roundToTwoDecimals(x).toString)
       ),
       fixedWidth = true
diff --git 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
index d61283742..a2bb18fa1 100644
--- 
a/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
+++ 
b/client/src/main/java/org/apache/uniffle/client/impl/ShuffleReadClientImpl.java
@@ -53,7 +53,9 @@ import org.apache.uniffle.common.util.ChecksumUtils;
 import org.apache.uniffle.common.util.IdHelper;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.storage.factory.ShuffleHandlerFactory;
+import org.apache.uniffle.storage.handler.ClientReadHandlerMetric;
 import org.apache.uniffle.storage.handler.api.ClientReadHandler;
+import org.apache.uniffle.storage.handler.impl.AbstractClientReadHandler;
 import org.apache.uniffle.storage.handler.impl.ShuffleServerReadCostTracker;
 import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
 
@@ -400,7 +402,19 @@ public class ShuffleReadClientImpl implements 
ShuffleReadClient {
     if (decompressionWorker != null) {
       backgroundDecompressionTime = decompressionWorker.decompressionMillis();
     }
+
+    long backgroundFetchTime = 0;
+    if (clientReadHandler instanceof AbstractClientReadHandler) {
+      ClientReadHandlerMetric metric =
+          ((AbstractClientReadHandler) 
clientReadHandler).getReadHandlerMetric();
+      backgroundFetchTime += metric.getPrefetchTime();
+    }
+
     return new ShuffleReadTimes(
-        readDataTime.get(), copyTime.get(), crcCheckTime.get(), 
backgroundDecompressionTime);
+        readDataTime.get(),
+        copyTime.get(),
+        crcCheckTime.get(),
+        backgroundDecompressionTime,
+        backgroundFetchTime);
   }
 }
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleReadTimes.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleReadTimes.java
index 1c9523355..92f488372 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleReadTimes.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleReadTimes.java
@@ -22,19 +22,24 @@ import org.apache.uniffle.proto.RssProtos;
 /** The unit is millis */
 public class ShuffleReadTimes {
   private long fetch;
+  private long backgroundFetch;
+
   private long crc;
   private long copy;
   private long deserialize;
+
   private long decompress;
   private long backgroundDecompress;
 
   public ShuffleReadTimes() {}
 
-  public ShuffleReadTimes(long fetch, long crc, long copy, long 
backgroundDecompress) {
+  public ShuffleReadTimes(
+      long fetch, long crc, long copy, long backgroundDecompress, long 
backgroundFetch) {
     this.fetch = fetch;
     this.crc = crc;
     this.copy = copy;
     this.backgroundDecompress = backgroundDecompress;
+    this.backgroundFetch = backgroundFetch;
   }
 
   public long getFetch() {
@@ -69,6 +74,10 @@ public class ShuffleReadTimes {
     return backgroundDecompress;
   }
 
+  public long getBackgroundFetch() {
+    return backgroundFetch;
+  }
+
   public void merge(ShuffleReadTimes other) {
     if (other == null) {
       return;
@@ -79,6 +88,7 @@ public class ShuffleReadTimes {
     this.deserialize += other.deserialize;
     this.decompress += other.decompress;
     this.backgroundDecompress += other.backgroundDecompress;
+    this.backgroundFetch += other.backgroundFetch;
   }
 
   public long getTotal() {
@@ -93,6 +103,7 @@ public class ShuffleReadTimes {
         .setDecompress(decompress)
         .setDeserialize(deserialize)
         .setBackgroundDecompress(backgroundDecompress)
+        .setBackgroundFetch(backgroundFetch)
         .build();
   }
 
@@ -104,6 +115,7 @@ public class ShuffleReadTimes {
     time.decompress = proto.getDecompress();
     time.deserialize = proto.getDeserialize();
     time.backgroundDecompress = proto.getBackgroundDecompress();
+    time.backgroundFetch = proto.getBackgroundFetch();
     return time;
   }
 }
diff --git a/proto/src/main/proto/Rss.proto b/proto/src/main/proto/Rss.proto
index 0950c5e4c..0da250345 100644
--- a/proto/src/main/proto/Rss.proto
+++ b/proto/src/main/proto/Rss.proto
@@ -657,6 +657,7 @@ message ShuffleReadTimes {
   int64 deserialize = 4;
   int64 decompress = 5;
   int64 backgroundDecompress = 6;
+  int64 backgroundFetch = 7;
 }
 
 message ReportShuffleReadMetricResponse {
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/ClientReadHandlerMetric.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/ClientReadHandlerMetric.java
index b59b28f73..45a7e5a58 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/ClientReadHandlerMetric.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/ClientReadHandlerMetric.java
@@ -28,6 +28,16 @@ public class ClientReadHandlerMetric {
   private long skippedReadLength = 0L;
   private long skippedReadUncompressLength = 0L;
 
+  private long prefetchTime = 0L;
+
+  public void setPrefetchTime(long prefetchTime) {
+    this.prefetchTime = prefetchTime;
+  }
+
+  public long getPrefetchTime() {
+    return prefetchTime;
+  }
+
   public long getReadBlockNum() {
     return readBlockNum;
   }
@@ -103,4 +113,14 @@ public class ClientReadHandlerMetric {
         skippedReadLength,
         skippedReadUncompressLength);
   }
+
+  public void merge(ClientReadHandlerMetric other) {
+    this.readBlockNum += other.readBlockNum;
+    this.readLength += other.readLength;
+    this.readUncompressLength += other.readUncompressLength;
+    this.skippedReadBlockNum += other.skippedReadBlockNum;
+    this.skippedReadLength += other.skippedReadLength;
+    this.skippedReadUncompressLength += other.skippedReadUncompressLength;
+    this.prefetchTime += other.prefetchTime;
+  }
 }
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
index 7c76f55cf..8ca5229b9 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/ComposedClientReadHandler.java
@@ -63,16 +63,9 @@ public class ComposedClientReadHandler extends 
AbstractClientReadHandler {
   private final ShuffleServerInfo serverInfo;
   private final Map<Tier, Supplier<ClientReadHandler>> supplierMap = new 
EnumMap<>(Tier.class);
   private final Map<Tier, ClientReadHandler> handlerMap = new 
EnumMap<>(Tier.class);
-  private final Map<Tier, ClientReadHandlerMetric> metricsMap = new 
EnumMap<>(Tier.class);
   private Tier currentTier = Tier.VALUES[0]; // == Tier.HOT
   private final int numTiers;
 
-  {
-    for (Tier tier : Tier.VALUES) {
-      metricsMap.put(tier, new ClientReadHandlerMetric());
-    }
-  }
-
   public ComposedClientReadHandler(ShuffleServerInfo serverInfo, 
ClientReadHandler... handlers) {
     Preconditions.checkArgument(
         handlers.length <= Tier.VALUES.length,
@@ -100,14 +93,30 @@ public class ComposedClientReadHandler extends 
AbstractClientReadHandler {
     }
   }
 
-  @Override
-  public ShuffleDataResult readShuffleData() {
+  private ClientReadHandlerMetric getMetric(Tier tier) {
+    ClientReadHandler handler = getHandler(tier);
+    if (handler != null && handler instanceof AbstractClientReadHandler) {
+      return ((AbstractClientReadHandler) handler).getReadHandlerMetric();
+    }
+    return new ClientReadHandlerMetric();
+  }
+
+  private ClientReadHandler getOrCreateHandler(Tier tier) {
     ClientReadHandler handler =
-        handlerMap.computeIfAbsent(
-            currentTier, key -> supplierMap.getOrDefault(key, () -> 
null).get());
+        handlerMap.computeIfAbsent(tier, key -> supplierMap.getOrDefault(key, 
() -> null).get());
     if (handler == null) {
       throw new RssException("Unexpected null when getting " + 
currentTier.name() + " handler");
     }
+    return handler;
+  }
+
+  private ClientReadHandler getHandler(Tier tier) {
+    return handlerMap.get(tier);
+  }
+
+  @Override
+  public ShuffleDataResult readShuffleData() {
+    ClientReadHandler handler = getOrCreateHandler(currentTier);
     ShuffleDataResult shuffleDataResult;
     try {
       shuffleDataResult = handler.readShuffleData();
@@ -147,8 +156,11 @@ public class ComposedClientReadHandler extends 
AbstractClientReadHandler {
     if (bs == null) {
       return;
     }
-    super.updateConsumedBlockInfo(bs, isSkippedMetrics);
-    updateBlockMetric(metricsMap.get(currentTier), bs, isSkippedMetrics);
+    ClientReadHandler handler = getHandler(currentTier);
+    if (handler == null) {
+      throw new RssException("Unexpected null when getting " + 
currentTier.name() + " handler");
+    }
+    handler.updateConsumedBlockInfo(bs, isSkippedMetrics);
   }
 
   @Override
@@ -188,7 +200,7 @@ public class ComposedClientReadHandler extends 
AbstractClientReadHandler {
       Function<ClientReadHandlerMetric, Long> skipped) {
     StringBuilder sb =
         new StringBuilder("Client read ")
-            .append(consumed.apply(readHandlerMetric))
+            .append(consumed.apply(getReadHandlerMetric()))
             .append(" ")
             .append(name)
             .append(" from [")
@@ -198,16 +210,28 @@ public class ComposedClientReadHandler extends 
AbstractClientReadHandler {
       sb.append(" ")
           .append(tier.name().toLowerCase())
           .append(":")
-          .append(consumed.apply(metricsMap.get(tier)));
+          .append(consumed.apply(getMetric(tier)));
     }
     sb.append(" ], Skipped[");
     for (Tier tier : Tier.VALUES) {
       sb.append(" ")
           .append(tier.name().toLowerCase())
           .append(":")
-          .append(skipped.apply(metricsMap.get(tier)));
+          .append(skipped.apply(getMetric(tier)));
     }
     sb.append(" ]");
     return sb.toString();
   }
+
+  @Override
+  public ClientReadHandlerMetric getReadHandlerMetric() {
+    ClientReadHandlerMetric metric = new ClientReadHandlerMetric();
+    for (Tier tier : Tier.VALUES) {
+      ClientReadHandlerMetric tierMetric = getMetric(tier);
+      if (tierMetric != null) {
+        metric.merge(tierMetric);
+      }
+    }
+    return metric;
+  }
 }
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
index a8386ea4f..bc64ab04d 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
@@ -31,6 +31,7 @@ import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.exception.RssException;
+import org.apache.uniffle.storage.handler.ClientReadHandlerMetric;
 
 public abstract class PrefetchableClientReadHandler extends 
AbstractClientReadHandler {
   private static final Logger LOG = 
LoggerFactory.getLogger(PrefetchableClientReadHandler.class);
@@ -144,13 +145,27 @@ public abstract class PrefetchableClientReadHandler 
extends AbstractClientReadHa
     }
   }
 
+  private long getBackgroundFetchTime() {
+    long fetch = 0;
+    if (fetchTime != null) {
+      fetch = fetchTime.get();
+    }
+    return fetch;
+  }
+
   @Override
   public void logConsumedBlockInfo() {
     LOG.info(
         "Metrics for shuffleId[{}], partitionId[{}], background fetch cost {} 
ms",
         shuffleId,
         partitionId,
-        fetchTime);
+        getBackgroundFetchTime());
     super.logConsumedBlockInfo();
   }
+
+  @Override
+  public ClientReadHandlerMetric getReadHandlerMetric() {
+    readHandlerMetric.setPrefetchTime(getBackgroundFetchTime());
+    return readHandlerMetric;
+  }
 }

Reply via email to