This is an automated email from the ASF dual-hosted git repository.

xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 53212923 [Improvement] Avoid selecting storage which has reached the high watermark (#424)
53212923 is described below

commit 53212923328194fafa5ae94aa216ab04f06fbd13
Author: Junfan Zhang <[email protected]>
AuthorDate: Wed Dec 21 20:08:38 2022 +0800

    [Improvement] Avoid selecting storage which has reached the high watermark (#424)
    
    ### What changes were proposed in this pull request?
    1. Cache the storage selected for each partition instead of re-selecting it on
       every call, so that repeated selections for the same partition are idempotent.
    2. Avoid selecting a storage that has reached the high watermark. This builds on
       the selection cache introduced in (1).
    
    ### Why are the changes needed?
    In the current codebase, LocalStorageManager may select a local storage that has
    already reached the high watermark.
    
    This is undesirable: it causes many apps to fall back to HDFS, because the local
    storage selected for them has already reached the high watermark.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    1. UTs
---
 .../java/org/apache/uniffle/common/UnionKey.java   |  36 +++---
 .../org/apache/uniffle/common/UnionKeyTest.java    |  49 ++++++++
 .../uniffle/test/ShuffleServerWithLocalTest.java   |  24 ++--
 .../uniffle/server/ShuffleDataReadEvent.java       |  10 +-
 .../uniffle/server/ShuffleServerGrpcService.java   |  24 ++--
 .../apache/uniffle/server/ShuffleTaskManager.java  |  37 +++++-
 .../server/buffer/ShuffleBufferManager.java        |   2 +-
 .../server/storage/LocalStorageManager.java        | 140 ++++++++++++---------
 .../server/storage/LocalStorageManagerTest.java    |  41 +++++-
 .../uniffle/storage/common/LocalStorage.java       |   2 +-
 .../uniffle/storage/common/LocalStorageMeta.java   |   5 +
 11 files changed, 268 insertions(+), 102 deletions(-)

diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java 
b/common/src/main/java/org/apache/uniffle/common/UnionKey.java
similarity index 56%
copy from 
server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java
copy to common/src/main/java/org/apache/uniffle/common/UnionKey.java
index fb3932fa..18bf1f61 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java
+++ b/common/src/main/java/org/apache/uniffle/common/UnionKey.java
@@ -15,29 +15,31 @@
  * limitations under the License.
  */
 
-package org.apache.uniffle.server;
+package org.apache.uniffle.common;
 
-public class ShuffleDataReadEvent {
+import org.apache.commons.lang3.StringUtils;
 
-  private String appId;
-  private int shuffleId;
-  private int startPartition;
-
-  public ShuffleDataReadEvent(String appId, int shuffleId, int startPartition) 
{
-    this.appId = appId;
-    this.shuffleId = shuffleId;
-    this.startPartition = startPartition;
-  }
+/**
+ * This class is to wrap multi elements to be as union key.
+ */
+public class UnionKey {
+  private static final String SPLIT_KEY = "_";
 
-  public String getAppId() {
-    return appId;
+  public static String buildKey(Object... factors) {
+    return StringUtils.join(factors, SPLIT_KEY);
   }
 
-  public int getShuffleId() {
-    return shuffleId;
+  public static boolean startsWith(String key, Object... factors) {
+    if (key == null) {
+      return false;
+    }
+    return key.startsWith(buildKey(factors));
   }
 
-  public int getStartPartition() {
-    return startPartition;
+  public static boolean sameWith(String key, Object... factors) {
+    if (key == null) {
+      return false;
+    }
+    return key.equals(buildKey(factors));
   }
 }
diff --git a/common/src/test/java/org/apache/uniffle/common/UnionKeyTest.java 
b/common/src/test/java/org/apache/uniffle/common/UnionKeyTest.java
new file mode 100644
index 00000000..3bf3f8d1
--- /dev/null
+++ b/common/src/test/java/org/apache/uniffle/common/UnionKeyTest.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class UnionKeyTest {
+
+  @Test
+  public void test() {
+    Object[] elements = new Object[]{
+        "appId",
+        1,
+        1
+    };
+
+    String key = UnionKey.buildKey(elements);
+    assertEquals(key, "appId_1_1");
+
+    assertTrue(UnionKey.sameWith(key, elements));
+    assertFalse(UnionKey.sameWith(null, elements));
+
+    assertFalse(UnionKey.startsWith(null, elements));
+    assertFalse(UnionKey.startsWith(key, new Object[]{"appId", "app"}));
+    assertTrue(UnionKey.startsWith(key, elements));
+    assertTrue(UnionKey.startsWith(key, new Object[]{"appId"}));
+    assertTrue(UnionKey.startsWith(key, new Object[]{"appId", 1}));
+    assertTrue(UnionKey.startsWith(key, new Object[]{"appId", 1, 1}));
+  }
+}
diff --git 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalTest.java
 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalTest.java
index 3f142272..2bb31b36 100644
--- 
a/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalTest.java
+++ 
b/integration-test/common/src/test/java/org/apache/uniffle/test/ShuffleServerWithLocalTest.java
@@ -87,9 +87,13 @@ public class ShuffleServerWithLocalTest extends 
ShuffleReadWriteBase {
   public void localWriteReadTest() throws Exception {
     String testAppId = "localWriteReadTest";
     RssRegisterShuffleRequest rrsr = new RssRegisterShuffleRequest(testAppId, 
0,
-        Lists.newArrayList(new PartitionRange(0, 1)), "");
+        Lists.newArrayList(new PartitionRange(0, 0)), "");
     shuffleServerClient.registerShuffle(rrsr);
-    rrsr = new RssRegisterShuffleRequest(testAppId, 0, Lists.newArrayList(new 
PartitionRange(2, 3)), "");
+    rrsr = new RssRegisterShuffleRequest(testAppId, 0, Lists.newArrayList(new 
PartitionRange(1, 1)), "");
+    shuffleServerClient.registerShuffle(rrsr);
+    rrsr = new RssRegisterShuffleRequest(testAppId, 0, Lists.newArrayList(new 
PartitionRange(2, 2)), "");
+    shuffleServerClient.registerShuffle(rrsr);
+    rrsr = new RssRegisterShuffleRequest(testAppId, 0, Lists.newArrayList(new 
PartitionRange(3, 3)), "");
     shuffleServerClient.registerShuffle(rrsr);
 
     Map<Long, byte[]> expectedData = Maps.newHashMap();
@@ -113,20 +117,20 @@ public class ShuffleServerWithLocalTest extends 
ShuffleReadWriteBase {
     final Set<Long> expectedBlockIds3 = transBitmapToSet(bitmaps[2]);
     final Set<Long> expectedBlockIds4 = transBitmapToSet(bitmaps[3]);
     ShuffleDataResult sdr  = readShuffleData(
-        shuffleServerClient, testAppId, 0, 0, 2,
-        10, 1000, 0);
+        shuffleServerClient, testAppId, 0, 0, 1,
+        4, 1000, 0);
     validateResult(sdr, expectedBlockIds1, expectedData, 0);
     sdr  = readShuffleData(
-        shuffleServerClient, testAppId, 0, 1, 2,
-        10, 1000, 0);
+        shuffleServerClient, testAppId, 0, 1, 1,
+        4, 1000, 0);
     validateResult(sdr, expectedBlockIds2, expectedData, 1);
     sdr  = readShuffleData(
-        shuffleServerClient, testAppId, 0, 2, 2,
-        10, 1000, 0);
+        shuffleServerClient, testAppId, 0, 2, 1,
+        4, 1000, 0);
     validateResult(sdr, expectedBlockIds3, expectedData, 2);
     sdr  = readShuffleData(
-        shuffleServerClient, testAppId, 0, 3, 2,
-        10, 1000, 0);
+        shuffleServerClient, testAppId, 0, 3, 1,
+        4, 1000, 0);
     validateResult(sdr, expectedBlockIds4, expectedData, 3);
 
     assertNotNull(shuffleServers.get(0).getShuffleTaskManager()
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java
index fb3932fa..0f1804ef 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleDataReadEvent.java
@@ -21,12 +21,14 @@ public class ShuffleDataReadEvent {
 
   private String appId;
   private int shuffleId;
+  private int partitionId;
   private int startPartition;
 
-  public ShuffleDataReadEvent(String appId, int shuffleId, int startPartition) 
{
+  public ShuffleDataReadEvent(String appId, int shuffleId, int partitionId, 
int startPartitionOfRange) {
     this.appId = appId;
     this.shuffleId = shuffleId;
-    this.startPartition = startPartition;
+    this.partitionId = partitionId;
+    this.startPartition = startPartitionOfRange;
   }
 
   public String getAppId() {
@@ -37,6 +39,10 @@ public class ShuffleDataReadEvent {
     return shuffleId;
   }
 
+  public int getPartitionId() {
+    return partitionId;
+  }
+
   public int getStartPartition() {
     return startPartition;
   }
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
index 3de560ce..054ffe55 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/ShuffleServerGrpcService.java
@@ -77,7 +77,9 @@ import 
org.apache.uniffle.proto.RssProtos.ShufflePartitionRange;
 import org.apache.uniffle.proto.RssProtos.ShuffleRegisterRequest;
 import org.apache.uniffle.proto.RssProtos.ShuffleRegisterResponse;
 import org.apache.uniffle.proto.ShuffleServerGrpc.ShuffleServerImplBase;
+import org.apache.uniffle.storage.common.Storage;
 import org.apache.uniffle.storage.common.StorageReadMetrics;
+import org.apache.uniffle.storage.util.ShuffleStorageUtils;
 
 public class ShuffleServerGrpcService extends ShuffleServerImplBase {
 
@@ -511,9 +513,15 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], 
partitionId["
         + partitionId + "]" + "offset[" + offset + "]" + "length[" + length + 
"]";
 
-    shuffleServer.getStorageManager()
-        .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId))
-        .updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, 
partitionNumPerRange, partitionNum);
+    Storage storage = shuffleServer
+        .getStorageManager()
+        .selectStorage(
+            new ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0])
+        );
+    if (storage != null) {
+      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    }
 
     if 
(shuffleServer.getShuffleBufferManager().requireReadMemoryWithRetry(length)) {
       try {
@@ -571,10 +579,12 @@ public class ShuffleServerGrpcService extends 
ShuffleServerImplBase {
     String requestInfo = "appId[" + appId + "], shuffleId[" + shuffleId + "], 
partitionId["
         + partitionId + "]";
 
-    shuffleServer.getStorageManager()
-        .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId))
-        .updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
-
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, 
partitionNumPerRange, partitionNum);
+    Storage storage = shuffleServer.getStorageManager()
+        .selectStorage(new ShuffleDataReadEvent(appId, shuffleId, partitionId, 
range[0]));
+    if (storage != null) {
+      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+    }
     // Index file is expected small size and won't cause oom problem with the 
assumed size. An index segment is 40B,
     // with the default size - 2MB, it can support 50k blocks for shuffle data.
     long assumedFileSize = shuffleServer
diff --git 
a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java 
b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
index 21d855f8..5e74eabf 100644
--- a/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
+++ b/server/src/main/java/org/apache/uniffle/server/ShuffleTaskManager.java
@@ -36,6 +36,7 @@ import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Queues;
+import com.google.common.collect.Range;
 import com.google.common.collect.Sets;
 import org.apache.commons.collections.CollectionUtils;
 import org.roaringbitmap.longlong.LongIterator;
@@ -51,10 +52,12 @@ import org.apache.uniffle.common.ShuffleIndexResult;
 import org.apache.uniffle.common.ShufflePartitionedBlock;
 import org.apache.uniffle.common.ShufflePartitionedData;
 import org.apache.uniffle.common.config.RssBaseConf;
+import org.apache.uniffle.common.exception.FileNotFoundException;
 import org.apache.uniffle.common.util.Constants;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.common.util.ThreadUtils;
 import org.apache.uniffle.server.buffer.PreAllocatedBufferInfo;
+import org.apache.uniffle.server.buffer.ShuffleBuffer;
 import org.apache.uniffle.server.buffer.ShuffleBufferManager;
 import org.apache.uniffle.server.event.AppPurgeEvent;
 import org.apache.uniffle.server.event.PurgeEvent;
@@ -63,6 +66,7 @@ import org.apache.uniffle.server.storage.StorageManager;
 import org.apache.uniffle.storage.common.Storage;
 import org.apache.uniffle.storage.common.StorageReadMetrics;
 import org.apache.uniffle.storage.request.CreateShuffleReadHandlerRequest;
+import org.apache.uniffle.storage.util.ShuffleStorageUtils;
 
 public class ShuffleTaskManager {
 
@@ -313,9 +317,25 @@ public class ShuffleTaskManager {
   public byte[] getFinishedBlockIds(String appId, Integer shuffleId, 
Set<Integer> partitions) throws IOException {
     refreshAppId(appId);
     for (int partitionId : partitions) {
-      Storage storage = storageManager.selectStorage(new 
ShuffleDataReadEvent(appId, shuffleId, partitionId));
+      Map.Entry<Range<Integer>, ShuffleBuffer> entry =
+          shuffleBufferManager.getShuffleBufferEntry(appId, shuffleId, 
partitionId);
+      if (entry == null) {
+        LOG.error("The empty shuffle buffer, this should not happen. appId: 
{}, shuffleId: {}, partition: {}",
+            appId, shuffleId, partitionId);
+        continue;
+      }
+      Storage storage = storageManager.selectStorage(
+          new ShuffleDataReadEvent(
+              appId,
+              shuffleId,
+              partitionId,
+              entry.getKey().lowerEndpoint()
+          )
+      );
       // update shuffle's timestamp that was recently read.
-      storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+      if (storage != null) {
+        storage.updateReadMetrics(new StorageReadMetrics(appId, shuffleId));
+      }
     }
     Map<Integer, Roaring64NavigableMap[]> shuffleIdToPartitions = 
partitionsToBlockIds.get(appId);
     if (shuffleIdToPartitions == null) {
@@ -382,7 +402,11 @@ public class ShuffleTaskManager {
     request.setPartitionNum(partitionNum);
     request.setStorageType(storageType);
     request.setRssBaseConf(conf);
-    Storage storage = storageManager.selectStorage(new 
ShuffleDataReadEvent(appId, shuffleId, partitionId));
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, 
partitionNumPerRange, partitionNum);
+    Storage storage = storageManager.selectStorage(new 
ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0]));
+    if (storage == null) {
+      throw new FileNotFoundException("No such data stored in current storage 
manager.");
+    }
 
     return storage.getOrCreateReadHandler(request).getShuffleData(offset, 
length);
   }
@@ -403,8 +427,11 @@ public class ShuffleTaskManager {
     request.setPartitionNum(partitionNum);
     request.setStorageType(storageType);
     request.setRssBaseConf(conf);
-
-    Storage storage = storageManager.selectStorage(new 
ShuffleDataReadEvent(appId, shuffleId, partitionId));
+    int[] range = ShuffleStorageUtils.getPartitionRange(partitionId, 
partitionNumPerRange, partitionNum);
+    Storage storage = storageManager.selectStorage(new 
ShuffleDataReadEvent(appId, shuffleId, partitionId, range[0]));
+    if (storage == null) {
+      throw new FileNotFoundException("No such data in current storage 
manager.");
+    }
     return storage.getOrCreateReadHandler(request).getShuffleIndex();
   }
 
diff --git 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
index 366d10eb..500542a0 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/buffer/ShuffleBufferManager.java
@@ -132,7 +132,7 @@ public class ShuffleBufferManager {
     shuffleIdToSize.get(shuffleId).addAndGet(size);
   }
 
-  protected Entry<Range<Integer>, ShuffleBuffer> getShuffleBufferEntry(
+  public Entry<Range<Integer>, ShuffleBuffer> getShuffleBufferEntry(
       String appId, int shuffleId, int partitionId) {
     Map<Integer, RangeMap<Integer, ShuffleBuffer>> shuffleIdToBuffers = 
bufferPool.get(appId);
     if (shuffleIdToBuffers == null) {
diff --git 
a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java
 
b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java
index ee099939..091a32d9 100644
--- 
a/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java
+++ 
b/server/src/main/java/org/apache/uniffle/server/storage/LocalStorageManager.java
@@ -22,7 +22,9 @@ import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
 import java.util.HashSet;
+import java.util.Iterator;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
@@ -30,11 +32,11 @@ import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Executors;
 import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import com.google.common.annotations.VisibleForTesting;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Sets;
+import com.google.common.collect.Maps;
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.lang3.StringUtils;
 import org.apache.hadoop.conf.Configuration;
@@ -42,6 +44,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.UnionKey;
 import org.apache.uniffle.common.util.RssUtils;
 import org.apache.uniffle.server.Checker;
 import org.apache.uniffle.server.LocalStorageChecker;
@@ -69,8 +72,8 @@ public class LocalStorageManager extends SingleStorageManager 
{
   private final List<LocalStorage> localStorages;
   private final List<String> storageBasePaths;
   private final LocalStorageChecker checker;
-  private List<LocalStorage> unCorruptedStorages = Lists.newArrayList();
-  private final Set<String> corruptedStorages = Sets.newConcurrentHashSet();
+
+  private final Map<String, LocalStorage> partitionsOfStorage;
 
   @VisibleForTesting
   LocalStorageManager(ShuffleServerConf conf) {
@@ -79,6 +82,7 @@ public class LocalStorageManager extends SingleStorageManager 
{
     if (CollectionUtils.isEmpty(storageBasePaths)) {
       throw new IllegalArgumentException("Base path dirs must not be empty");
     }
+    this.partitionsOfStorage = Maps.newConcurrentMap();
     long shuffleExpiredTimeoutMs = 
conf.get(ShuffleServerConf.SHUFFLE_EXPIRED_TIMEOUT_MS);
     long capacity = conf.getSizeAsBytes(ShuffleServerConf.DISK_CAPACITY);
     double highWaterMarkOfWrite = 
conf.get(ShuffleServerConf.HIGH_WATER_MARK_OF_WRITE);
@@ -139,33 +143,54 @@ public class LocalStorageManager extends 
SingleStorageManager {
 
   @Override
   public Storage selectStorage(ShuffleDataFlushEvent event) {
-    LocalStorage storage = 
localStorages.get(ShuffleStorageUtils.getStorageIndex(
-        localStorages.size(),
-        event.getAppId(),
-        event.getShuffleId(),
-        event.getStartPartition()));
-    if (storage.containsWriteHandler(event.getAppId(), event.getShuffleId(), 
event.getStartPartition())
-        && storage.isCorrupted()) {
-      LOG.error("storage " + storage.getBasePath() + " is corrupted");
-    }
-    if (storage.isCorrupted()) {
-      storage = getRepairedStorage(event.getAppId(), event.getShuffleId(), 
event.getStartPartition());
+    String appId = event.getAppId();
+    int shuffleId = event.getShuffleId();
+    int partitionId = event.getStartPartition();
+
+    LocalStorage storage = partitionsOfStorage.get(UnionKey.buildKey(appId, 
shuffleId, partitionId));
+    if (storage != null) {
+      if (storage.isCorrupted()) {
+        if (storage.containsWriteHandler(appId, shuffleId, partitionId)) {
+          LOG.error("LocalStorage: {} is corrupted. Switching another storage 
for event: {}, some data will be lost",
+              storage.getBasePath(), event);
+        }
+      } else {
+        return storage;
+      }
     }
-    event.setUnderStorage(storage);
-    return storage;
+
+    List<LocalStorage> candidates = localStorages
+        .stream()
+        .filter(x -> x.canWrite() && !x.isCorrupted())
+        .collect(Collectors.toList());
+    final LocalStorage selectedStorage = candidates.get(
+        ShuffleStorageUtils.getStorageIndex(
+            candidates.size(),
+            appId,
+            shuffleId,
+            partitionId
+        )
+    );
+    return partitionsOfStorage.compute(
+        UnionKey.buildKey(appId, shuffleId, partitionId),
+        (key, localStorage) -> {
+          // If this is the first time to select storage or existing storage 
is corrupted,
+          // we should refresh the cache.
+          if (localStorage == null || localStorage.isCorrupted()) {
+            event.setUnderStorage(selectedStorage);
+            return selectedStorage;
+          }
+          return localStorage;
+        });
   }
 
   @Override
   public Storage selectStorage(ShuffleDataReadEvent event) {
+    String appId = event.getAppId();
+    int shuffleId = event.getShuffleId();
+    int partitionId = event.getStartPartition();
 
-    LocalStorage storage = 
localStorages.get(ShuffleStorageUtils.getStorageIndex(
-        localStorages.size(),
-        event.getAppId(),
-        event.getShuffleId(),
-        event.getStartPartition()));
-    if (storage.isCorrupted()) {
-      storage = getRepairedStorage(event.getAppId(), event.getShuffleId(), 
event.getStartPartition());
-    }
+    LocalStorage storage = partitionsOfStorage.get(UnionKey.buildKey(appId, 
shuffleId, partitionId));
     return storage;
   }
 
@@ -186,6 +211,9 @@ public class LocalStorageManager extends 
SingleStorageManager {
     String user = event.getUser();
     List<Integer> shuffleSet = 
Optional.ofNullable(event.getShuffleIds()).orElse(Collections.emptyList());
 
+    // Remove partitions to storage mapping cache
+    cleanupStorageSelectionCache(event);
+
     for (LocalStorage storage : localStorages) {
       if (event instanceof AppPurgeEvent) {
         storage.removeHandlers(appId);
@@ -217,6 +245,37 @@ public class LocalStorageManager extends 
SingleStorageManager {
     deleteHandler.delete(deletePaths.toArray(new String[deletePaths.size()]), 
appId, user);
   }
 
+  private void cleanupStorageSelectionCache(PurgeEvent event) {
+    Function<String, Boolean> deleteConditionFunc = null;
+    if (event instanceof AppPurgeEvent) {
+      deleteConditionFunc = partitionUnionKey -> 
UnionKey.startsWith(partitionUnionKey, event.getAppId());
+    } else if (event instanceof ShufflePurgeEvent) {
+      deleteConditionFunc =
+          partitionUnionKey -> UnionKey.startsWith(
+              partitionUnionKey,
+              event.getAppId(),
+              event.getShuffleIds()
+          );
+    }
+    long startTime = System.currentTimeMillis();
+    deleteElement(
+        partitionsOfStorage,
+        deleteConditionFunc
+    );
+    LOG.info("Cleaning the storage selection cache costs: {}(ms) for event: 
{}",
+        System.currentTimeMillis() - startTime, event);
+  }
+
+  private <K, V> void deleteElement(Map<K, V> map, Function<K, Boolean> 
deleteConditionFunc) {
+    Iterator<Map.Entry<K, V>> iterator = map.entrySet().iterator();
+    while (iterator.hasNext()) {
+      Map.Entry<K, V> entry = iterator.next();
+      if (deleteConditionFunc.apply(entry.getKey())) {
+        iterator.remove();
+      }
+    }
+  }
+
   @Override
   public void registerRemoteStorage(String appId, RemoteStorageInfo 
remoteStorageInfo) {
     // ignore
@@ -246,37 +305,6 @@ public class LocalStorageManager extends 
SingleStorageManager {
     }
   }
 
-  void repair() {
-    boolean hasNewCorruptedStorage = false;
-    for (LocalStorage storage : localStorages) {
-      if (storage.isCorrupted() && 
!corruptedStorages.contains(storage.getBasePath())) {
-        hasNewCorruptedStorage = true;
-        corruptedStorages.add(storage.getBasePath());
-      }
-    }
-    if (hasNewCorruptedStorage) {
-      List<LocalStorage> healthyStorages = Lists.newArrayList();
-      for (LocalStorage storage : localStorages) {
-        if (!storage.isCorrupted()) {
-          healthyStorages.add(storage);
-        }
-      }
-      unCorruptedStorages = healthyStorages;
-    }
-  }
-
-  private synchronized LocalStorage getRepairedStorage(String appId, int 
shuffleId, int partitionId) {
-    repair();
-    if (unCorruptedStorages.isEmpty()) {
-      throw new RuntimeException("No enough storages");
-    }
-    return unCorruptedStorages.get(ShuffleStorageUtils.getStorageIndex(
-        unCorruptedStorages.size(),
-        appId,
-        shuffleId,
-        partitionId));
-  }
-
   public List<LocalStorage> getStorages() {
     return localStorages;
   }
diff --git 
a/server/src/test/java/org/apache/uniffle/server/storage/LocalStorageManagerTest.java
 
b/server/src/test/java/org/apache/uniffle/server/storage/LocalStorageManagerTest.java
index 719b7b2c..be42a215 100644
--- 
a/server/src/test/java/org/apache/uniffle/server/storage/LocalStorageManagerTest.java
+++ 
b/server/src/test/java/org/apache/uniffle/server/storage/LocalStorageManagerTest.java
@@ -71,9 +71,35 @@ public class LocalStorageManagerTest {
     );
   }
 
+  @Test
+  public void testStorageSelectionWhenReachingHighWatermark() {
+    String[] storagePaths = {
+        "/tmp/rss-data1",
+        "/tmp/rss-data2",
+        "/tmp/rss-data3"
+    };
+
+    ShuffleServerConf conf = new ShuffleServerConf();
+    conf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, 
Arrays.asList(storagePaths));
+    conf.setLong(ShuffleServerConf.DISK_CAPACITY, 1024L);
+    conf.setString(ShuffleServerConf.RSS_STORAGE_TYPE, 
StorageType.LOCALFILE.name());
+
+    LocalStorageManager localStorageManager = new LocalStorageManager(conf);
+
+    String appId = "testStorageSelectionWhenReachingHighWatermark";
+    ShuffleDataFlushEvent dataFlushEvent = toDataFlushEvent(appId, 1, 1);
+    Storage storage1 = localStorageManager.selectStorage(dataFlushEvent);
+
+    ((LocalStorage) storage1).getMetaData().setSize(999);
+    localStorageManager = new LocalStorageManager(conf);
+    Storage storage2 = localStorageManager.selectStorage(dataFlushEvent);
+
+    assertNotEquals(storage1, storage2);
+  }
+
   @Test
   public void testStorageSelection() {
-    String[] storagePaths = {"/tmp/rss-data1", "/tmp/rss-data2"};
+    String[] storagePaths = {"/tmp/rss-data1", "/tmp/rss-data2", 
"/tmp/rss-data3"};
 
     ShuffleServerConf conf = new ShuffleServerConf();
     conf.set(ShuffleServerConf.RSS_STORAGE_BASE_PATH, 
Arrays.asList(storagePaths));
@@ -94,7 +120,7 @@ public class LocalStorageManagerTest {
     ShuffleDataFlushEvent dataFlushEvent2 = toDataFlushEvent(appId, 1, 1);
     Storage storage2 = localStorageManager.selectStorage(dataFlushEvent2);
 
-    ShuffleDataReadEvent dataReadEvent = new ShuffleDataReadEvent(appId, 1, 1);
+    ShuffleDataReadEvent dataReadEvent = new ShuffleDataReadEvent(appId, 1, 1, 
1);
     Storage storage3 = localStorageManager.selectStorage(dataReadEvent);
     assertEquals(storage1, storage2);
     assertEquals(storage1, storage3);
@@ -108,12 +134,21 @@ public class LocalStorageManagerTest {
 
     // case3: one storage is corrupted when it happened after the original 
event has been written,
     // so it will switch to another storage for write and read event.
-    LocalStorage mockedStorage = spy((LocalStorage)storage1);
+    LocalStorage mockedStorage = spy((LocalStorage)storage4);
     when(mockedStorage.containsWriteHandler(appId, 1, 1)).thenReturn(true);
     Storage storage5 = localStorageManager.selectStorage(dataFlushEvent1);
     Storage storage6 = localStorageManager.selectStorage(dataReadEvent);
     assertNotEquals(storage1, storage5);
+    assertEquals(storage4, storage5);
     assertEquals(storage5, storage6);
+
+    // case4: one storage is corrupted when it happened after the original 
event has been written,
+    // but before reading this partition, another storage corrupted, it still 
could read the original data.
+    Storage storage7 = localStorageManager.selectStorage(dataFlushEvent1);
+    Storage restStorage = storages.stream().filter(x -> !x.isCorrupted() && x 
!= storage7).findFirst().get();
+    ((LocalStorage)restStorage).markCorrupted();
+    Storage storage8 = localStorageManager.selectStorage(dataReadEvent);
+    assertEquals(storage7, storage8);
   }
 
   @Test
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java 
b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java
index 2941a3ec..f9f63a3d 100644
--- a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java
+++ b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorage.java
@@ -213,7 +213,7 @@ public class LocalStorage extends AbstractStorage {
   }
 
   @VisibleForTesting
-  LocalStorageMeta getMetaData() {
+  public LocalStorageMeta getMetaData() {
     return metaData;
   }
 
diff --git 
a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorageMeta.java 
b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorageMeta.java
index cdfea458..81fa6b02 100644
--- 
a/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorageMeta.java
+++ 
b/storage/src/main/java/org/apache/uniffle/storage/common/LocalStorageMeta.java
@@ -27,6 +27,7 @@ import java.util.concurrent.locks.ReadWriteLock;
 import java.util.concurrent.locks.ReentrantReadWriteLock;
 import java.util.stream.Collectors;
 
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.collect.Maps;
 import org.roaringbitmap.RoaringBitmap;
 import org.slf4j.Logger;
@@ -160,6 +161,10 @@ public class LocalStorageMeta {
     return shuffleMetaMap.keySet();
   }
 
+  @VisibleForTesting
+  public void setSize(long diskSize) {
+    this.size.set(diskSize);
+  }
 
   /**
    *  If the method is implemented as below:

Reply via email to