This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 81e00bec9 [#291] feat(client): Introduce PrefetchableClientReadHandler 
to support async read (#2365)
81e00bec9 is described below

commit 81e00bec9ef23dc53716e3193cebc8c75477cd52
Author: Junfan Zhang <[email protected]>
AuthorDate: Sat Feb 8 17:22:51 2025 +0800

    [#291] feat(client): Introduce PrefetchableClientReadHandler to support 
async read (#2365)
    
    ### What changes were proposed in this pull request?
    
    1. Introduce PrefetchableClientReadHandler to support async read. It is 
disabled by default.
    2. Apply it to the memory/localfile/hdfs read handlers
    
    ### Why are the changes needed?
    
    Recently I found that some important Spark jobs are slow due to the large 
number of shuffle read operations. If we support async read, the performance of 
these jobs will be improved.
    
    So this PR is the follow-up to #291, from almost 3 years ago!
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. Some configs are introduced
    1. `rss.client.read.prefetch.enabled`
    2. `rss.client.read.prefetch.capacity`
    3. `rss.client.read.prefetch.timeoutSec`
    
    ### How was this patch tested?
    
    1. Unit tests
    
    ---------
    
    Co-authored-by: Junfan Zhang <[email protected]>
---
 .../uniffle/common/config/RssClientConf.java       |  18 +++
 docs/client_guide/client_guide.md                  |   4 +
 .../storage/factory/ShuffleHandlerFactory.java     |   9 +-
 .../handler/impl/DataSkippableReadHandler.java     |   9 +-
 .../handler/impl/HadoopClientReadHandler.java      |  12 +-
 .../handler/impl/HadoopShuffleReadHandler.java     |  10 +-
 .../handler/impl/LocalFileClientReadHandler.java   |  11 +-
 .../handler/impl/MemoryClientReadHandler.java      |  20 ++-
 .../impl/PrefetchableClientReadHandler.java        | 141 +++++++++++++++++++++
 .../request/CreateShuffleReadHandlerRequest.java   |  14 ++
 .../impl/PrefetchableClientReadHandlerTest.java    | 117 +++++++++++++++++
 11 files changed, 346 insertions(+), 19 deletions(-)

diff --git 
a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java 
b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
index c366f1203..1f803f167 100644
--- a/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
+++ b/common/src/main/java/org/apache/uniffle/common/config/RssClientConf.java
@@ -330,4 +330,22 @@ public class RssClientConf {
           .intType()
           .defaultValue(-1)
           .withDescription("the event loop threads of netty impl for grpc");
+
+  /**
+   * Switch for the client-side read prefetch. Defaults to {@code false}; when it is {@code true},
+   * {@code CreateShuffleReadHandlerRequest#getPrefetchOption()} returns a populated option.
+   */
+  public static final ConfigOption<Boolean> RSS_CLIENT_PREFETCH_ENABLED =
+      ConfigOptions.key("rss.client.read.prefetch.enabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Read prefetch switch that will be disabled by 
default");
+
+  /**
+   * Prefetch capacity, default 4. {@code PrefetchableClientReadHandler} uses it as the bound of its
+   * result queue and rejects values that are not positive.
+   */
+  public static final ConfigOption<Integer> RSS_CLIENT_PREFETCH_CAPACITY =
+      ConfigOptions.key("rss.client.read.prefetch.capacity")
+          .intType()
+          .defaultValue(4)
+          .withDescription("Read prefetch capacity");
+
+  /**
+   * Prefetch timeout in seconds, default 120. {@code PrefetchableClientReadHandler} fails a read
+   * once it has waited longer than this for a prefetched result.
+   */
+  public static final ConfigOption<Integer> READ_CLIENT_PREFETCH_TIMEOUT_SEC =
+      ConfigOptions.key("rss.client.read.prefetch.timeoutSec")
+          .intType()
+          .defaultValue(120)
+          .withDescription("Read prefetch timeout seconds");
 }
diff --git a/docs/client_guide/client_guide.md 
b/docs/client_guide/client_guide.md
index c926deed1..18c11e46e 100644
--- a/docs/client_guide/client_guide.md
+++ b/docs/client_guide/client_guide.md
@@ -63,6 +63,10 @@ The important configuration of client is listed as 
following. These configuratio
 | <client_type>.rss.client.blockIdManagerClass                    | -          
                            | The block id manager class of server for this 
application, the implementation of this interface to manage the shuffle block 
ids                                                                             
                                                                                
                                                                                
                    [...]
 | <client_type>.rss.client.reportExcludeProperties                | -          
                            | The value of exclude properties specify a list of 
client configuration properties that should not be reported to the coordinator 
by the DelegationRssShuffleManager.                                             
                                                                                
                                                                                
               [...]
 | <client_type>.rss.client.reportIncludeProperties                | -          
                            | The value of include properties specify a list of 
client configuration properties that should be exclusively reported to the 
coordinator by the DelegationRssShuffleManager.                                 
                                                                                
                                                                                
                   [...]
+| <client_type>.rss.client.read.prefetch.enabled                  | false      
                            | Read prefetch switch that will be disabled by 
default                                                                         
                                                                                
                                                                                
                                                                                
                  [...]
+| <client_type>.rss.client.read.prefetch.capacity                 | 4          
                            | Read prefetch capacity                            
                                                                                
                                                                                
                                                                                
                                                                                
              [...]
+| <client_type>.rss.client.read.prefetch.timeoutSec               | 120        
                            | Read prefetch timeout seconds                     
                                                                                
                                                                                
                                                                                
                                                                                
              [...]
+
 
 Notice:
 
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
index 4edca3b34..b68d77a19 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/factory/ShuffleHandlerFactory.java
@@ -136,7 +136,8 @@ public class ShuffleHandlerFactory {
             shuffleServerClient,
             expectTaskIds,
             request.getRetryMax(),
-            request.getRetryIntervalMax());
+            request.getRetryIntervalMax(),
+            request.getPrefetchOption());
     return memoryClientReadHandler;
   }
 
@@ -159,7 +160,8 @@ public class ShuffleHandlerFactory {
         request.getDistributionType(),
         request.getExpectTaskIds(),
         request.getRetryMax(),
-        request.getRetryIntervalMax());
+        request.getRetryIntervalMax(),
+        request.getPrefetchOption());
   }
 
   private ClientReadHandler getHadoopClientReadHandler(
@@ -179,7 +181,8 @@ public class ShuffleHandlerFactory {
         request.getDistributionType(),
         request.getExpectTaskIds(),
         ssi.getId(),
-        request.isOffHeapEnabled());
+        request.isOffHeapEnabled(),
+        request.getPrefetchOption());
   }
 
   public ShuffleDeleteHandler createShuffleDeleteHandler(
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
index 220e02997..ae45a7505 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/DataSkippableReadHandler.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.storage.handler.impl;
 
 import java.util.List;
+import java.util.Optional;
 
 import com.google.common.collect.Lists;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -30,7 +31,7 @@ import org.apache.uniffle.common.ShuffleDataSegment;
 import org.apache.uniffle.common.ShuffleIndexResult;
 import org.apache.uniffle.common.segment.SegmentSplitterFactory;
 
-public abstract class DataSkippableReadHandler extends 
AbstractClientReadHandler {
+public abstract class DataSkippableReadHandler extends 
PrefetchableClientReadHandler {
   private static final Logger LOG = 
LoggerFactory.getLogger(DataSkippableReadHandler.class);
 
   protected List<ShuffleDataSegment> shuffleDataSegments = 
Lists.newArrayList();
@@ -50,7 +51,9 @@ public abstract class DataSkippableReadHandler extends 
AbstractClientReadHandler
       Roaring64NavigableMap expectBlockIds,
       Roaring64NavigableMap processBlockIds,
       ShuffleDataDistributionType distributionType,
-      Roaring64NavigableMap expectTaskIds) {
+      Roaring64NavigableMap expectTaskIds,
+      Optional<PrefetchOption> prefetchOption) {
+    super(prefetchOption);
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
@@ -66,7 +69,7 @@ public abstract class DataSkippableReadHandler extends 
AbstractClientReadHandler
   protected abstract ShuffleDataResult readShuffleData(ShuffleDataSegment 
segment);
 
   @Override
-  public ShuffleDataResult readShuffleData() {
+  public ShuffleDataResult doReadShuffleData() {
     if (shuffleDataSegments.isEmpty()) {
       ShuffleIndexResult shuffleIndexResult = readShuffleIndex();
       if (shuffleIndexResult == null || shuffleIndexResult.isEmpty()) {
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
index 1cf636e86..ec444b2b4 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopClientReadHandler.java
@@ -20,6 +20,7 @@ package org.apache.uniffle.storage.handler.impl;
 import java.io.FileNotFoundException;
 import java.util.Collections;
 import java.util.List;
+import java.util.Optional;
 import java.util.stream.Collectors;
 
 import com.google.common.collect.Lists;
@@ -55,6 +56,7 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
   private ShuffleDataDistributionType distributionType;
   private Roaring64NavigableMap expectTaskIds;
   private boolean offHeapEnable = false;
+  private Optional<PrefetchableClientReadHandler.PrefetchOption> 
prefetchOption;
 
   public HadoopClientReadHandler(
       String appId,
@@ -71,7 +73,8 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
       String shuffleServerId,
-      boolean offHeapEnable) {
+      boolean offHeapEnable,
+      Optional<PrefetchableClientReadHandler.PrefetchOption> prefetchOption) {
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
@@ -87,6 +90,7 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
     this.expectTaskIds = expectTaskIds;
     this.shuffleServerId = shuffleServerId;
     this.offHeapEnable = offHeapEnable;
+    this.prefetchOption = prefetchOption;
   }
 
   // Only for test
@@ -117,7 +121,8 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
         ShuffleDataDistributionType.NORMAL,
         Roaring64NavigableMap.bitmapOf(),
         null,
-        false);
+        false,
+        Optional.empty());
   }
 
   protected void init(String fullShufflePath) {
@@ -174,7 +179,8 @@ public class HadoopClientReadHandler extends 
AbstractClientReadHandler {
                   hadoopConf,
                   distributionType,
                   expectTaskIds,
-                  offHeapEnable);
+                  offHeapEnable,
+                  prefetchOption);
           readHandlers.add(handler);
         } catch (Exception e) {
           LOG.warn("Can't create ShuffleReaderHandler for " + filePrefix, e);
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
index c7af921b4..f3ecc16cf 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/HadoopShuffleReadHandler.java
@@ -20,6 +20,7 @@ package org.apache.uniffle.storage.handler.impl;
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.Optional;
 
 import org.apache.hadoop.conf.Configuration;
 import org.apache.hadoop.fs.Path;
@@ -57,7 +58,8 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
       Configuration conf,
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
-      boolean offHeapEnabled)
+      boolean offHeapEnabled,
+      Optional<PrefetchOption> prefetchOption)
       throws Exception {
     super(
         appId,
@@ -67,7 +69,8 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
         expectBlockIds,
         processBlockIds,
         distributionType,
-        expectTaskIds);
+        expectTaskIds,
+        prefetchOption);
     this.filePrefix = filePrefix;
     this.indexReader =
         
createHadoopReader(ShuffleStorageUtils.generateIndexFileName(filePrefix), conf);
@@ -98,7 +101,8 @@ public class HadoopShuffleReadHandler extends 
DataSkippableReadHandler {
         conf,
         ShuffleDataDistributionType.NORMAL,
         Roaring64NavigableMap.bitmapOf(),
-        false);
+        false,
+        Optional.empty());
   }
 
   @Override
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
index 4772dcf21..a5464678b 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/LocalFileClientReadHandler.java
@@ -17,6 +17,8 @@
 
 package org.apache.uniffle.storage.handler.impl;
 
+import java.util.Optional;
+
 import com.google.common.annotations.VisibleForTesting;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
 import org.slf4j.Logger;
@@ -55,7 +57,8 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
       ShuffleDataDistributionType distributionType,
       Roaring64NavigableMap expectTaskIds,
       int retryMax,
-      long retryIntervalMax) {
+      long retryIntervalMax,
+      Optional<PrefetchOption> prefetchOption) {
     super(
         appId,
         shuffleId,
@@ -64,7 +67,8 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
         expectBlockIds,
         processBlockIds,
         distributionType,
-        expectTaskIds);
+        expectTaskIds,
+        prefetchOption);
     this.shuffleServerClient = shuffleServerClient;
     this.partitionNumPerRange = partitionNumPerRange;
     this.partitionNum = partitionNum;
@@ -98,7 +102,8 @@ public class LocalFileClientReadHandler extends 
DataSkippableReadHandler {
         ShuffleDataDistributionType.NORMAL,
         Roaring64NavigableMap.bitmapOf(),
         1,
-        0);
+        0,
+        Optional.empty());
   }
 
   @Override
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
index f1fbe2361..b5ec28970 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/MemoryClientReadHandler.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.storage.handler.impl;
 
 import java.util.List;
+import java.util.Optional;
 
 import com.google.common.annotations.VisibleForTesting;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -32,7 +33,7 @@ import org.apache.uniffle.common.ShuffleDataResult;
 import org.apache.uniffle.common.exception.RssFetchFailedException;
 import org.apache.uniffle.common.util.Constants;
 
-public class MemoryClientReadHandler extends AbstractClientReadHandler {
+public class MemoryClientReadHandler extends PrefetchableClientReadHandler {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(MemoryClientReadHandler.class);
   private long lastBlockId = Constants.INVALID_BLOCK_ID;
@@ -49,7 +50,9 @@ public class MemoryClientReadHandler extends 
AbstractClientReadHandler {
       ShuffleServerClient shuffleServerClient,
       Roaring64NavigableMap expectTaskIds,
       int retryMax,
-      long retryIntervalMax) {
+      long retryIntervalMax,
+      Optional<PrefetchableClientReadHandler.PrefetchOption> prefetchOption) {
+    super(prefetchOption);
     this.appId = appId;
     this.shuffleId = shuffleId;
     this.partitionId = partitionId;
@@ -68,11 +71,20 @@ public class MemoryClientReadHandler extends 
AbstractClientReadHandler {
       int readBufferSize,
       ShuffleServerClient shuffleServerClient,
       Roaring64NavigableMap expectTaskIds) {
-    this(appId, shuffleId, partitionId, readBufferSize, shuffleServerClient, 
expectTaskIds, 1, 0);
+    this(
+        appId,
+        shuffleId,
+        partitionId,
+        readBufferSize,
+        shuffleServerClient,
+        expectTaskIds,
+        1,
+        0,
+        Optional.empty());
   }
 
   @Override
-  public ShuffleDataResult readShuffleData() {
+  public ShuffleDataResult doReadShuffleData() {
     ShuffleDataResult result = null;
 
     RssGetInMemoryShuffleDataRequest request =
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
new file mode 100644
index 000000000..0bfafde74
--- /dev/null
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandler.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.handler.impl;
+
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.Executors;
+import java.util.concurrent.LinkedBlockingQueue;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.exception.RssException;
+
+/**
+ * Base class for client read handlers that can optionally fetch shuffle data ahead of the caller.
+ *
+ * <p>When constructed with a {@link PrefetchOption}, {@link #readShuffleData()} keeps background
+ * reads scheduled on a single worker thread and returns their results from a bounded queue. When
+ * constructed with an empty option it simply delegates to {@link #doReadShuffleData()} on the
+ * calling thread.
+ */
+public abstract class PrefetchableClientReadHandler extends AbstractClientReadHandler {
+  private static final Logger LOG = LoggerFactory.getLogger(PrefetchableClientReadHandler.class);
+
+  // How long one poll on the result queue blocks before the failure/timeout checks run again.
+  private static final long POLL_INTERVAL_MS = 10;
+
+  private boolean prefetchEnabled;
+  private int prefetchQueueCapacity;
+  private int prefetchTimeoutSec;
+  private LinkedBlockingQueue<Optional<ShuffleDataResult>> prefetchResultQueue;
+  private ExecutorService prefetchExecutors;
+  // Set once a background read has failed; every later readShuffleData() call fails fast.
+  private AtomicBoolean abnormalFetchTag;
+  // Set once a background read returned null, i.e. the underlying handler has no more data.
+  private AtomicBoolean finishedTag;
+  // Number of submitted background tasks that have not completed yet.
+  private AtomicInteger queueingNumber;
+  // Failure raised by a background read; written before abnormalFetchTag is set.
+  private volatile Throwable fetchFailureCause;
+
+  /**
+   * @param prefetchOptional prefetch settings; an empty value disables prefetch
+   * @throws RssException if prefetch is requested with a capacity that is not positive
+   */
+  public PrefetchableClientReadHandler(Optional<PrefetchOption> prefetchOptional) {
+    if (prefetchOptional.isPresent()) {
+      PrefetchOption option = prefetchOptional.get();
+      if (option.capacity <= 0) {
+        throw new RssException("Illegal prefetch capacity: " + option.capacity);
+      }
+      LOG.info("Prefetch is enabled, capacity: {}", option.capacity);
+      this.prefetchEnabled = true;
+      this.prefetchQueueCapacity = option.capacity;
+      this.prefetchTimeoutSec = option.timeoutSec;
+      this.prefetchResultQueue = new LinkedBlockingQueue<>(option.capacity);
+      // todo: support multi threads to prefetch
+      this.prefetchExecutors = Executors.newFixedThreadPool(1);
+      this.abnormalFetchTag = new AtomicBoolean(false);
+      this.finishedTag = new AtomicBoolean(false);
+      this.queueingNumber = new AtomicInteger(0);
+    } else {
+      this.prefetchEnabled = false;
+    }
+  }
+
+  /** Prefetch settings: the bound of the result queue and the per-read wait limit in seconds. */
+  public static class PrefetchOption {
+    private final int capacity;
+    private final int timeoutSec;
+
+    public PrefetchOption(int capacity, int timeoutSec) {
+      this.capacity = capacity;
+      this.timeoutSec = timeoutSec;
+    }
+  }
+
+  /**
+   * Performs one actual read. With prefetch enabled this runs on the background worker thread.
+   *
+   * @return the next chunk of shuffle data, or {@code null} when there is nothing more to read
+   */
+  protected abstract ShuffleDataResult doReadShuffleData();
+
+  /**
+   * Returns the next chunk of shuffle data, either read directly or taken from the prefetch queue.
+   *
+   * @return the next result, or {@code null} once a read has reported that no data is left
+   * @throws RssException if a background read failed, if no result arrived within the configured
+   *     timeout, or if the calling thread was interrupted while waiting
+   */
+  @Override
+  public ShuffleDataResult readShuffleData() {
+    if (!prefetchEnabled) {
+      return doReadShuffleData();
+    }
+
+    // Top up the background work so queued results plus pending tasks fill the capacity.
+    int free = prefetchQueueCapacity - prefetchResultQueue.size() - queueingNumber.get();
+    for (int i = 0; i < free; i++) {
+      queueingNumber.incrementAndGet();
+      prefetchExecutors.submit(this::prefetchOnce);
+    }
+
+    // Converted through TimeUnit so large second values cannot overflow int arithmetic.
+    long timeoutMs = TimeUnit.SECONDS.toMillis(prefetchTimeoutSec);
+    long start = System.currentTimeMillis();
+    while (true) {
+      if (abnormalFetchTag.get()) {
+        RssException failure = new RssException("Fast fail due to the fetch failure");
+        Throwable cause = fetchFailureCause;
+        if (cause != null) {
+          // Keep the original failure reachable for callers instead of only logging it.
+          failure.initCause(cause);
+        }
+        throw failure;
+      }
+
+      try {
+        Optional<ShuffleDataResult> polled =
+            prefetchResultQueue.poll(POLL_INTERVAL_MS, TimeUnit.MILLISECONDS);
+        if (polled != null) {
+          // An empty Optional is the end-of-data marker produced by a null read.
+          return polled.orElse(null);
+        }
+      } catch (InterruptedException e) {
+        // Returning null here would look like a normal end of data and silently drop the rest
+        // of the shuffle data, so restore the interrupt flag and fail instead.
+        Thread.currentThread().interrupt();
+        RssException interrupted =
+            new RssException("Interrupted while waiting for the prefetched shuffle data");
+        interrupted.initCause(e);
+        throw interrupted;
+      }
+
+      if (System.currentTimeMillis() - start > timeoutMs) {
+        throw new RssException("Unexpected duration of reading shuffle data. Fast fail!");
+      }
+    }
+  }
+
+  /** Runs a single background read and publishes its result, or records the failure. */
+  private void prefetchOnce() {
+    try {
+      if (abnormalFetchTag.get() || finishedTag.get()) {
+        return;
+      }
+      ShuffleDataResult result = doReadShuffleData();
+      if (result == null) {
+        finishedTag.set(true);
+      }
+      if (!prefetchResultQueue.offer(Optional.ofNullable(result))) {
+        // Dropping a fetched result would lose data, so surface it as a fetch failure.
+        throw new IllegalStateException("The prefetch result queue is unexpectedly full");
+      }
+    } catch (Exception e) {
+      fetchFailureCause = e;
+      abnormalFetchTag.set(true);
+      LOG.error("Errors on doing readShuffleData", e);
+    } finally {
+      queueingNumber.decrementAndGet();
+    }
+  }
+
+  /** Closes the parent handler and, when prefetch is enabled, shuts the worker pool down. */
+  @Override
+  public void close() {
+    super.close();
+    if (prefetchExecutors != null) {
+      prefetchExecutors.shutdown();
+    }
+  }
+}
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
 
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
index 9b73dc85a..9a36e4ad0 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/request/CreateShuffleReadHandlerRequest.java
@@ -18,6 +18,7 @@
 package org.apache.uniffle.storage.request;
 
 import java.util.List;
+import java.util.Optional;
 
 import org.apache.hadoop.conf.Configuration;
 import org.roaringbitmap.longlong.Roaring64NavigableMap;
@@ -26,8 +27,10 @@ import org.apache.uniffle.common.ClientType;
 import org.apache.uniffle.common.ShuffleDataDistributionType;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.config.RssClientConf;
 import org.apache.uniffle.common.config.RssConf;
 import org.apache.uniffle.common.util.IdHelper;
+import org.apache.uniffle.storage.handler.impl.PrefetchableClientReadHandler;
 
 public class CreateShuffleReadHandlerRequest {
 
@@ -242,4 +245,15 @@ public class CreateShuffleReadHandlerRequest {
   public void setClientType(ClientType clientType) {
     this.clientType = clientType;
   }
+
+  /**
+   * Builds the prefetch settings from the client configuration.
+   *
+   * @return the configured capacity and timeout when prefetch is enabled, otherwise empty
+   */
+  public Optional<PrefetchableClientReadHandler.PrefetchOption> getPrefetchOption() {
+    if (!clientConf.get(RssClientConf.RSS_CLIENT_PREFETCH_ENABLED)) {
+      return Optional.empty();
+    }
+    int capacity = clientConf.get(RssClientConf.RSS_CLIENT_PREFETCH_CAPACITY);
+    int timeoutSec = clientConf.get(RssClientConf.READ_CLIENT_PREFETCH_TIMEOUT_SEC);
+    return Optional.of(new PrefetchableClientReadHandler.PrefetchOption(capacity, timeoutSec));
+  }
 }
diff --git 
a/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandlerTest.java
 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandlerTest.java
new file mode 100644
index 000000000..9adac5a3d
--- /dev/null
+++ 
b/storage/src/test/java/org/apache/uniffle/storage/handler/impl/PrefetchableClientReadHandlerTest.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.storage.handler.impl;
+
+import java.util.Optional;
+import java.util.concurrent.atomic.AtomicInteger;
+
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleDataResult;
+import org.apache.uniffle.common.exception.RssException;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/** Tests for the prefetch and pass-through paths of {@link PrefetchableClientReadHandler}. */
+public class PrefetchableClientReadHandlerTest {
+
+  /**
+   * Handler stub that returns {@code readNum} results and then {@code null}. It can be told to
+   * sleep two seconds per read or to throw on every read.
+   */
+  static class MockedHandler extends PrefetchableClientReadHandler {
+    private final AtomicInteger readNum;
+    private final boolean markTimeout;
+    private final boolean markFetchFailure;
+
+    MockedHandler(
+        Optional<PrefetchOption> option,
+        int readNum,
+        boolean markTimeout,
+        boolean markFetchFailure) {
+      super(option);
+      this.readNum = new AtomicInteger(readNum);
+      this.markTimeout = markTimeout;
+      this.markFetchFailure = markFetchFailure;
+    }
+
+    @Override
+    protected ShuffleDataResult doReadShuffleData() {
+      if (markFetchFailure) {
+        throw new RssException("");
+      }
+
+      if (markTimeout) {
+        try {
+          Thread.sleep(2 * 1000L);
+        } catch (Exception e) {
+          // ignore
+        }
+      }
+      if (readNum.get() > 0) {
+        readNum.decrementAndGet();
+        return new ShuffleDataResult();
+      }
+      return null;
+    }
+  }
+
+  @Test
+  public void test_with_prefetch() {
+    PrefetchableClientReadHandler handler =
+        new MockedHandler(
+            Optional.of(new PrefetchableClientReadHandler.PrefetchOption(4, 10)), 10, false, false);
+    try {
+      int counter = 0;
+      while (handler.readShuffleData() != null) {
+        counter += 1;
+      }
+      assertEquals(10, counter);
+    } finally {
+      // Release the prefetch worker thread so it does not outlive the test.
+      handler.close();
+    }
+  }
+
+  @Test
+  public void test_with_timeout() {
+    PrefetchableClientReadHandler handler =
+        new MockedHandler(
+            Optional.of(new PrefetchableClientReadHandler.PrefetchOption(4, 1)), 10, true, false);
+    try {
+      handler.readShuffleData();
+      fail();
+    } catch (RssException e) {
+      // expected: every read sleeps longer than the one second timeout
+    } finally {
+      handler.close();
+    }
+  }
+
+  @Test
+  public void test_with_fetch_failure() {
+    PrefetchableClientReadHandler handler =
+        new MockedHandler(
+            Optional.of(new PrefetchableClientReadHandler.PrefetchOption(4, 1)), 10, false, true);
+    try {
+      handler.readShuffleData();
+      fail();
+    } catch (RssException e) {
+      // expected: the background read failure is reported to the caller
+    } finally {
+      handler.close();
+    }
+  }
+
+  @Test
+  public void test_without_prefetch() {
+    PrefetchableClientReadHandler handler = new MockedHandler(Optional.empty(), 10, true, false);
+    try {
+      // Without prefetch the slow read runs on the calling thread and no timeout applies.
+      if (handler.readShuffleData() == null) {
+        fail("Expected a result from the direct read path");
+      }
+    } finally {
+      handler.close();
+    }
+  }
+}

Reply via email to