This is an automated email from the ASF dual-hosted git repository.

ethanfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 765265a87 [CELEBORN-2031] Interruption Aware Slot Selection
765265a87 is described below

commit 765265a87d00cbc9e63e5763a158332c6d28bc62
Author: Aravind Patnam <[email protected]>
AuthorDate: Tue Jul 15 17:33:00 2025 +0800

    [CELEBORN-2031] Interruption Aware Slot Selection
    
    ### What changes were proposed in this pull request?
    This PR is part of [CIP17: Interruption Aware Slot Selection](https://cwiki.apache.org/confluence/display/CELEBORN/CIP-17%3A+Interruption+Aware+Slot+Selection).
    
    It changes the slot selection logic to prioritize workers that do not have an interruption scheduled "soon". See more context about the slot selection logic [here](https://cwiki.apache.org/confluence/pages/viewpage.action?pageId=362056201#CIP17:InterruptionAwareSlotSelection-SlotsAllocator).
    
    ### Why are the changes needed?
    See [CIP17: Interruption Aware Slot Selection](https://cwiki.apache.org/confluence/display/CELEBORN/CIP-17%3A+Interruption+Aware+Slot+Selection).
    
    ### Does this PR introduce _any_ user-facing change?
    no
    
    ### How was this patch tested?
    Unit tests. This has also been running in production in our cluster for the last 4-5 months.
    
    Closes #3347 from akpatnam25/CELEBORN-2031-impl.
    
    Authored-by: Aravind Patnam <[email protected]>
    Signed-off-by: mingji <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala  |  23 ++
 .../apache/celeborn/common/meta/WorkerInfo.scala   |   3 +
 .../celeborn/common/meta/WorkerInfoSuite.scala     |   5 +
 docs/configuration/master.md                       |   2 +
 .../service/deploy/master/SlotsAllocator.java      | 287 +++++++++++++++++----
 .../celeborn/service/deploy/master/Master.scala    |  11 +-
 .../deploy/master/SlotsAllocatorJmhBenchmark.java  |   9 +-
 .../master/SlotsAllocatorRackAwareSuiteJ.java      |   8 +-
 .../deploy/master/SlotsAllocatorSuiteJ.java        | 261 ++++++++++++++++++-
 9 files changed, 543 insertions(+), 66 deletions(-)

diff --git 
a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala 
b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index c73832259..af728b830 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -655,6 +655,10 @@ class CelebornConf(loadDefaults: Boolean) extends 
Cloneable with Logging with Se
   // //////////////////////////////////////////////////////
   def masterSlotAssignPolicy: SlotsAssignPolicy =
     SlotsAssignPolicy.valueOf(get(MASTER_SLOT_ASSIGN_POLICY))
+
+  def masterSlotAssignInterruptionAware: Boolean = 
get(MASTER_SLOT_ASSIGN_INTERRUPTION_AWARE)
+  def masterSlotsAssignInterruptionAwareThreshold: Int =
+    get(MASTER_SLOT_INTERRUPTION_AWARE_THRESHOLD)
   def availableStorageTypes: Int = {
     val types = 
get(ACTIVE_STORAGE_TYPES).split(",").map(StorageInfo.Type.valueOf).toList
     StorageInfo.getAvailableTypes(types.asJava)
@@ -2936,6 +2940,25 @@ object CelebornConf extends Logging {
         SlotsAssignPolicy.LOADAWARE.name))
       .createWithDefault(SlotsAssignPolicy.ROUNDROBIN.name)
 
+  val MASTER_SLOT_ASSIGN_INTERRUPTION_AWARE: ConfigEntry[Boolean] =
+    buildConf("celeborn.master.slot.assign.interruptionAware")
+      .categories("master")
+      .version("0.7.0")
+      .doc("If this is set to true, Celeborn master will prioritize partition 
placement on workers that are not " +
+        "in scope for maintenance soon.")
+      .booleanConf
+      .createWithDefault(false)
+
+  val MASTER_SLOT_INTERRUPTION_AWARE_THRESHOLD: ConfigEntry[Int] =
+    buildConf("celeborn.master.slot.assign.interruptionAware.threshold")
+      .categories("master")
+      .doc("This controls what percentage of hosts would be selected for slot 
selection in the first iteration " +
+        "of creating partitions. Default is 50%.")
+      .version("0.7.0")
+      .intConf
+      .checkValue(v => v >= 0 && v <= 100, "This value must be a percentage.")
+      .createWithDefault(50)
+
   val MASTER_SLOT_ASSIGN_LOADAWARE_DISKGROUP_NUM: ConfigEntry[Int] =
     buildConf("celeborn.master.slot.assign.loadAware.numDiskGroups")
       .withAlternative("celeborn.slots.assign.loadAware.numDiskGroups")
diff --git 
a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala 
b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
index d6078ba75..26e94e360 100644
--- a/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/meta/WorkerInfo.scala
@@ -289,9 +289,12 @@ class WorkerInfo(
        |WorkerRef: $endpoint
        |WorkerStatus: $workerStatus
        |NetworkLocation: $networkLocation
+       |NextInterruptionNotice: ${if (hasInterruptionNotice) 
nextInterruptionNotice else "None"}
        |""".stripMargin
   }
 
+  def hasInterruptionNotice: Boolean = nextInterruptionNotice != Long.MaxValue
+
   override def equals(other: Any): Boolean = other match {
     case that: WorkerInfo =>
       host == that.host &&
diff --git 
a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala 
b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
index 242856500..a00f4046c 100644
--- 
a/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
+++ 
b/common/src/test/scala/org/apache/celeborn/common/meta/WorkerInfoSuite.scala
@@ -236,6 +236,7 @@ class WorkerInfoSuite extends CelebornFunSuite {
     val worker2 =
       new WorkerInfo("h2", 20001, 20002, 20003, 2000, 20004, null, null)
     worker2.networkLocation = "/1"
+    worker2.nextInterruptionNotice = 1000L
 
     val worker3 = new WorkerInfo(
       "h3",
@@ -312,6 +313,7 @@ class WorkerInfoSuite extends CelebornFunSuite {
            |UserResourceConsumption: empty
            |WorkerRef: null
            |NetworkLocation: /default-rack
+           |NextInterruptionNotice: None
            |""".stripMargin
 
       val exp2 =
@@ -328,6 +330,7 @@ class WorkerInfoSuite extends CelebornFunSuite {
           |UserResourceConsumption: empty
           |WorkerRef: null
           |NetworkLocation: /1
+          |NextInterruptionNotice: 1000
           |""".stripMargin
       val exp3 =
         s"""
@@ -343,6 +346,7 @@ class WorkerInfoSuite extends CelebornFunSuite {
            |UserResourceConsumption: empty
            |WorkerRef: null
            |NetworkLocation: /default-rack
+           |NextInterruptionNotice: None
            |""".stripMargin
       val exp4 =
         s"""
@@ -362,6 +366,7 @@ class WorkerInfoSuite extends CelebornFunSuite {
            |  UserIdentifier: `tenant1`.`name1`, ResourceConsumption: 
ResourceConsumption(diskBytesWritten: 20.0 MiB, diskFileCount: 1, 
hdfsBytesWritten: 50.0 MiB, hdfsFileCount: 1, subResourceConsumptions: 
(application_1697697127390_2171854 -> ResourceConsumption(diskBytesWritten: 
20.0 MiB, diskFileCount: 1, hdfsBytesWritten: 50.0 MiB, hdfsFileCount: 1, 
subResourceConsumptions: empty)))
            |WorkerRef: null
            |NetworkLocation: /default-rack
+           |NextInterruptionNotice: None
            |""".stripMargin
 
       assertEquals(
diff --git a/docs/configuration/master.md b/docs/configuration/master.md
index f24d89e52..fd4af6a8e 100644
--- a/docs/configuration/master.md
+++ b/docs/configuration/master.md
@@ -72,6 +72,8 @@ license: |
 | celeborn.master.rackResolver.refresh.interval | 30s | false | Interval for 
refreshing the node rack information periodically. | 0.5.0 |  | 
 | celeborn.master.send.applicationMeta.threads | 8 | false | Number of threads 
used by the Master to send ApplicationMeta to Workers. | 0.5.0 |  | 
 | celeborn.master.slot.assign.extraSlots | 2 | false | Extra slots number when 
master assign slots. Provided enough workers are available. | 0.3.0 | 
celeborn.slots.assign.extraSlots | 
+| celeborn.master.slot.assign.interruptionAware | false | false | If this is 
set to true, Celeborn master will prioritize partition placement on workers 
that are not in scope for maintenance soon. | 0.7.0 |  | 
+| celeborn.master.slot.assign.interruptionAware.threshold | 50 | false | This 
controls what percentage of hosts would be selected for slot selection in the 
first iteration of creating partitions. Default is 50%. | 0.7.0 |  | 
 | celeborn.master.slot.assign.loadAware.diskGroupGradient | 0.1 | false | This 
value means how many more workload will be placed into a faster disk group than 
a slower group. | 0.3.0 | celeborn.slots.assign.loadAware.diskGroupGradient | 
 | celeborn.master.slot.assign.loadAware.fetchTimeWeight | 1.0 | false | Weight 
of average fetch time when calculating ordering in load-aware assignment 
strategy | 0.3.0 | celeborn.slots.assign.loadAware.fetchTimeWeight | 
 | celeborn.master.slot.assign.loadAware.flushTimeWeight | 0.0 | false | Weight 
of average flush time when calculating ordering in load-aware assignment 
strategy | 0.3.0 | celeborn.slots.assign.loadAware.flushTimeWeight | 
diff --git 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
index 5e35d881d..450868895 100644
--- 
a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
+++ 
b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
@@ -22,6 +22,7 @@ import java.util.function.IntUnaryOperator;
 import java.util.stream.Collectors;
 
 import scala.Tuple2;
+import scala.Tuple3;
 
 import org.apache.commons.lang3.StringUtils;
 import org.roaringbitmap.RoaringBitmap;
@@ -56,7 +57,9 @@ public class SlotsAllocator {
           List<Integer> partitionIds,
           boolean shouldReplicate,
           boolean shouldRackAware,
-          int availableStorageTypes) {
+          int availableStorageTypes,
+          boolean interruptionAware,
+          int interruptionAwareThreshold) {
     if (partitionIds.isEmpty()) {
       return new HashMap<>();
     }
@@ -101,7 +104,9 @@ public class SlotsAllocator {
         slotsRestrictions,
         shouldReplicate,
         shouldRackAware,
-        availableStorageTypes);
+        availableStorageTypes,
+        interruptionAware,
+        interruptionAwareThreshold);
   }
 
   /**
@@ -119,7 +124,9 @@ public class SlotsAllocator {
           double diskGroupGradient,
           double flushTimeWeight,
           double fetchTimeWeight,
-          int availableStorageTypes) {
+          int availableStorageTypes,
+          boolean interruptionAware,
+          int interruptionAwareThreshold) {
     if (partitionIds.isEmpty()) {
       return new HashMap<>();
     }
@@ -130,7 +137,13 @@ public class SlotsAllocator {
         || StorageInfo.S3Only(availableStorageTypes)
         || StorageInfo.OSSOnly(availableStorageTypes)) {
       return offerSlotsRoundRobin(
-          workers, partitionIds, shouldReplicate, shouldRackAware, 
availableStorageTypes);
+          workers,
+          partitionIds,
+          shouldReplicate,
+          shouldRackAware,
+          availableStorageTypes,
+          interruptionAware,
+          interruptionAwareThreshold);
     }
 
     List<DiskInfo> usableDisks = new ArrayList<>();
@@ -165,7 +178,13 @@ public class SlotsAllocator {
           StringUtils.join(partitionIds, ','),
           noUsableDisks ? "usable disks" : "available slots");
       return offerSlotsRoundRobin(
-          workers, partitionIds, shouldReplicate, shouldRackAware, 
availableStorageTypes);
+          workers,
+          partitionIds,
+          shouldReplicate,
+          shouldRackAware,
+          availableStorageTypes,
+          interruptionAware,
+          interruptionAwareThreshold);
     }
 
     if (!initialized) {
@@ -183,7 +202,9 @@ public class SlotsAllocator {
         slotsRestrictions,
         shouldReplicate,
         shouldRackAware,
-        availableStorageTypes);
+        availableStorageTypes,
+        interruptionAware,
+        interruptionAwareThreshold);
   }
 
   private static StorageInfo getStorageInfo(
@@ -246,11 +267,60 @@ public class SlotsAllocator {
     return storageInfo;
   }
 
+  /**
+   * If interruptionAware = true, select workers based on 2 main criteria: <br>
+   * 1. Workers that have no nextInterruptionNotice are the first priority and 
are included in the
+   * 1st pass for slot selection. <br>
+   * 2. Workers that have a later interruption notice are a little less 
deprioritized, and are
+   * included in the 2nd pass for slot selection. This is determined by 
nextInterruptionNotice above
+   * a certain percentage threshold.<br>
+   * All other workers are considered least priority, and are only included 
for slot selection in
+   * the worst case. <br>
+   */
+  static Tuple3<List<WorkerInfo>, List<WorkerInfo>, List<WorkerInfo>>
+      prioritizeWorkersBasedOnInterruptionNotice(
+          List<WorkerInfo> workers,
+          boolean shouldReplicate,
+          boolean shouldRackAware,
+          double percentileThreshold) {
+    Map<Boolean, List<WorkerInfo>> partitioned =
+        
workers.stream().collect(Collectors.partitioningBy(WorkerInfo::hasInterruptionNotice));
+    List<WorkerInfo> workersWithInterruptions = partitioned.get(true);
+    List<WorkerInfo> workersWithoutInterruptions = partitioned.get(false);
+    // Timestamps towards the boundary of `percentileThreshold` might be the 
same. Given this
+    // is a stable sort, it makes sense to randomize these hosts so that the 
same hosts are not
+    // consistently selected.
+    Collections.shuffle(workersWithInterruptions);
+    workersWithInterruptions.sort(
+        
Comparator.comparingLong(WorkerInfo::nextInterruptionNotice).reversed());
+    int requiredNodes =
+        (int) Math.floor((percentileThreshold * 
workersWithInterruptions.size()) / 100.0);
+
+    List<WorkerInfo> workersWithLateInterruptions =
+        new ArrayList<>(workersWithInterruptions.subList(0, requiredNodes));
+    List<WorkerInfo> workersWithEarlyInterruptions =
+        new ArrayList<>(
+            workersWithInterruptions.subList(requiredNodes, 
workersWithInterruptions.size()));
+    if (shouldReplicate && shouldRackAware) {
+      return Tuple3.apply(
+          generateRackAwareWorkers(workersWithoutInterruptions),
+          Collections.unmodifiableList(workersWithLateInterruptions),
+          generateRackAwareWorkers(workersWithEarlyInterruptions));
+    }
+    return Tuple3.apply(
+        Collections.unmodifiableList(workersWithoutInterruptions),
+        Collections.unmodifiableList(workersWithLateInterruptions),
+        Collections.unmodifiableList(workersWithEarlyInterruptions));
+  }
+
   /**
    * Progressive locate slots for all partitions <br>
-   * 1. try to allocate for all partitions under restrictions <br>
-   * 2. allocate remain partitions to all workers <br>
-   * 3. allocate remain partitions to all workers again without considering 
rack aware <br>
+   * 1. try to allocate for all partitions under restrictions, on workers with 
no interruption
+   * notice if interruptionAware = true. <br>
+   * 2. try to allocate for all partitions, and attempt the replica selection 
to be
+   * interruptionAware if interruptionAware = true <br>
+   * 3. allocate remain partitions to all workers <br>
+   * 4. allocate remain partitions to all workers again without considering 
rack aware <br>
    */
   private static Map<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>>
       locateSlots(
@@ -259,7 +329,9 @@ public class SlotsAllocator {
           Map<WorkerInfo, List<UsableDiskInfo>> slotRestrictions,
           boolean shouldReplicate,
           boolean shouldRackAware,
-          int availableStorageTypes) {
+          int availableStorageTypes,
+          boolean interruptionAware,
+          int interruptionAwareThreshold) {
 
     List<WorkerInfo> workersFromSlotRestrictions = new 
ArrayList<>(slotRestrictions.keySet());
     List<WorkerInfo> workers = workersList;
@@ -270,29 +342,97 @@ public class SlotsAllocator {
 
     Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> 
slots =
         new HashMap<>();
-
+    List<WorkerInfo> workersWithoutInterruptions;
+    List<WorkerInfo> workersWithLateInterruptions;
+    List<WorkerInfo> workersWithEarlyInterruptions;
+    if (interruptionAware) {
+      Tuple3<List<WorkerInfo>, List<WorkerInfo>, List<WorkerInfo>>
+          workersBasedOnInterruptionNotice =
+              prioritizeWorkersBasedOnInterruptionNotice(
+                  workersFromSlotRestrictions,
+                  shouldReplicate,
+                  shouldRackAware,
+                  interruptionAwareThreshold);
+      workersWithoutInterruptions = workersBasedOnInterruptionNotice._1();
+      workersWithLateInterruptions = workersBasedOnInterruptionNotice._2();
+      workersWithEarlyInterruptions = workersBasedOnInterruptionNotice._3();
+    } else {
+      workersWithoutInterruptions = workersFromSlotRestrictions;
+      workersWithLateInterruptions = null;
+      workersWithEarlyInterruptions = null;
+    }
+    // In the first pass, we try to place all partitions (primary and replica) 
from
+    // `workersWithoutInterruptions`.
     List<Integer> remain =
         roundRobin(
             slots,
             partitionIds,
-            workersFromSlotRestrictions,
+            workersWithoutInterruptions,
+            workersWithoutInterruptions,
             slotRestrictions,
             shouldReplicate,
             shouldRackAware,
-            availableStorageTypes);
+            availableStorageTypes,
+            true);
+    logger.debug(
+        "Remaining number of partitionIds after 1st pass slot selection: {}", 
remain.size());
+    // Do an extra pass for partition placement if interruptionAware = true, 
to see if we can
+    // assign the remaining partitions with slot restriction still set in 
place. The goal during
+    // this pass
+    // is to see if we can place primary from `workersWithoutInterruptions +
+    // workersWithLateInterruptions`, while replica can be in
+    // `workersWithEarlyInterruptions`.
+    // This is to avoid the degenerate case in which both primary and replica 
may end up in
+    // `workersWithEarlyInterruptions`.
+    if (interruptionAware && !remain.isEmpty()) {
+      List<WorkerInfo> primaryWorkerCandidates = new 
ArrayList<>(workersWithoutInterruptions);
+      primaryWorkerCandidates.addAll(workersWithLateInterruptions);
+      if (shouldReplicate && shouldRackAware) {
+        primaryWorkerCandidates = 
generateRackAwareWorkers(primaryWorkerCandidates);
+      }
+      remain =
+          roundRobin(
+              slots,
+              remain,
+              primaryWorkerCandidates,
+              workersWithEarlyInterruptions,
+              null,
+              shouldReplicate,
+              shouldRackAware,
+              availableStorageTypes,
+              false);
+      logger.debug(
+          "Remaining number of partitionIds after 2nd pass slot selection: 
{}", remain.size());
+    }
+    // If partitions are remaining from this point on, and interruptionAware = 
true, then
+    // this becomes the degenerate case where both primary and replica are 
likely chosen on
+    // workers with interruptions that are sooner.
     if (!remain.isEmpty()) {
       remain =
           roundRobin(
               slots,
               remain,
               workers,
+              workers,
               null,
               shouldReplicate,
               shouldRackAware,
-              availableStorageTypes);
+              availableStorageTypes,
+              true);
+      logger.debug(
+          "Remaining number of partitionIds after 3rd pass slot selection: 
{}", remain.size());
     }
     if (!remain.isEmpty()) {
-      roundRobin(slots, remain, workers, null, shouldReplicate, false, 
availableStorageTypes);
+      roundRobin(
+          slots,
+          remain,
+          workers,
+          workers,
+          null,
+          shouldReplicate,
+          false,
+          availableStorageTypes,
+          true);
     }
     return slots;
   }
@@ -341,22 +481,53 @@ public class SlotsAllocator {
     return Collections.unmodifiableList(result);
   }
 
+  /**
+   * Assigns slots in a roundrobin fashion given lists of primary and replica 
worker candidates and
+   * other restrictions.
+   *
+   * @param slots the slots that have been assigned for each partitionId
+   * @param partitionIds the partitionIds that require slot selection still
+   * @param primaryWorkers list of worker candidates that can be used for 
primary workers.
+   * @param replicaWorkers list of worker candidates that can be used for 
replica workers.
+   * @param slotsRestrictions restrictions for each available slot based on 
worker characteristics
+   * @param shouldReplicate if replication is enabled within the cluster
+   * @param shouldRackAware if rack-aware replication is enabled within the 
cluster.
+   * @param availableStorageTypes available storage types coming from the 
offer slots request.
+   * @param skipLocationsOnSameWorkerCheck if the worker candidates list for 
primaries and replicas
+   *     is the same. This is to prevent index mismatch while assigning slots 
across both lists.
+   * @return the partitionIds that were not able to be assigned slots in this 
iteration with the
+   *     current primary and replica worker candidates and slot restrictions.
+   */
   private static List<Integer> roundRobin(
       Map<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>> slots,
       List<Integer> partitionIds,
-      List<WorkerInfo> workers,
+      List<WorkerInfo> primaryWorkers,
+      List<WorkerInfo> replicaWorkers,
       Map<WorkerInfo, List<UsableDiskInfo>> slotsRestrictions,
       boolean shouldReplicate,
       boolean shouldRackAware,
-      int availableStorageTypes) {
+      int availableStorageTypes,
+      boolean skipLocationsOnSameWorkerCheck) {
+    if (primaryWorkers.isEmpty() || (shouldReplicate && 
replicaWorkers.isEmpty())) {
+      return partitionIds;
+    }
     // workerInfo -> (diskIndexForPrimaryAndReplica)
     Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
     List<Integer> partitionIdList = new LinkedList<>(partitionIds);
 
-    final int workerSize = workers.size();
-    final IntUnaryOperator incrementIndex = v -> (v + 1) % workerSize;
-    int primaryIndex = rand.nextInt(workerSize);
-    int replicaIndex = rand.nextInt(workerSize);
+    final int primaryWorkersSize = primaryWorkers.size();
+    final int replicaWorkersSize = replicaWorkers.size();
+    final IntUnaryOperator primaryWorkersIncrementIndex = v -> (v + 1) % 
primaryWorkersSize;
+    int primaryIndex = rand.nextInt(primaryWorkersSize);
+    final IntUnaryOperator replicaWorkersIncrementIndex;
+    int replicaIndex;
+    if (shouldReplicate) {
+      replicaWorkersIncrementIndex = v -> (v + 1) % replicaWorkersSize;
+      replicaIndex = rand.nextInt(replicaWorkersSize);
+    } else {
+      replicaWorkersIncrementIndex = null;
+      replicaIndex = -1;
+    }
 
     ListIterator<Integer> iter = 
partitionIdList.listIterator(partitionIdList.size());
     // Iterate from the end to preserve O(1) removal of processed partitions.
@@ -370,85 +541,103 @@ public class SlotsAllocator {
       StorageInfo storageInfo;
       if (slotsRestrictions != null && !slotsRestrictions.isEmpty()) {
         // this means that we'll select a mount point
-        while (!haveUsableSlots(slotsRestrictions, workers, nextPrimaryInd)) {
-          nextPrimaryInd = incrementIndex.applyAsInt(nextPrimaryInd);
+        while (!haveUsableSlots(slotsRestrictions, primaryWorkers, 
nextPrimaryInd)) {
+          nextPrimaryInd = 
primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd);
           if (nextPrimaryInd == primaryIndex) {
             break outer;
           }
         }
         storageInfo =
             getStorageInfo(
-                workers, nextPrimaryInd, slotsRestrictions, workerDiskIndex, 
availableStorageTypes);
+                primaryWorkers,
+                nextPrimaryInd,
+                slotsRestrictions,
+                workerDiskIndex,
+                availableStorageTypes);
       } else {
         if (StorageInfo.localDiskAvailable(availableStorageTypes)) {
-          while (!workers.get(nextPrimaryInd).haveDisk()) {
-            nextPrimaryInd = incrementIndex.applyAsInt(nextPrimaryInd);
+          while (!primaryWorkers.get(nextPrimaryInd).haveDisk()) {
+            nextPrimaryInd = 
primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd);
             if (nextPrimaryInd == primaryIndex) {
               break outer;
             }
           }
         }
         storageInfo =
-            getStorageInfo(workers, nextPrimaryInd, null, workerDiskIndex, 
availableStorageTypes);
+            getStorageInfo(
+                primaryWorkers, nextPrimaryInd, null, workerDiskIndex, 
availableStorageTypes);
       }
       PartitionLocation primaryPartition =
-          createLocation(partitionId, workers.get(nextPrimaryInd), null, 
storageInfo, true);
+          createLocation(partitionId, primaryWorkers.get(nextPrimaryInd), 
null, storageInfo, true);
 
       if (shouldReplicate) {
         int nextReplicaInd = replicaIndex;
         if (slotsRestrictions != null) {
-          while (nextReplicaInd == nextPrimaryInd
-              || !haveUsableSlots(slotsRestrictions, workers, nextReplicaInd)
-              || !satisfyRackAware(shouldRackAware, workers, nextPrimaryInd, 
nextReplicaInd)) {
-            nextReplicaInd = incrementIndex.applyAsInt(nextReplicaInd);
+          while ((nextReplicaInd == nextPrimaryInd && 
skipLocationsOnSameWorkerCheck)
+              || !haveUsableSlots(slotsRestrictions, replicaWorkers, 
nextReplicaInd)
+              || !satisfyRackAware(
+                  shouldRackAware,
+                  primaryWorkers,
+                  nextPrimaryInd,
+                  replicaWorkers,
+                  nextReplicaInd)) {
+            nextReplicaInd = 
replicaWorkersIncrementIndex.applyAsInt(nextReplicaInd);
             if (nextReplicaInd == replicaIndex) {
               break outer;
             }
           }
           storageInfo =
               getStorageInfo(
-                  workers,
+                  replicaWorkers,
                   nextReplicaInd,
                   slotsRestrictions,
                   workerDiskIndex,
                   availableStorageTypes);
         } else if (shouldRackAware) {
-          while (nextReplicaInd == nextPrimaryInd
-              || !satisfyRackAware(true, workers, nextPrimaryInd, 
nextReplicaInd)) {
-            nextReplicaInd = incrementIndex.applyAsInt(nextReplicaInd);
+          while ((nextReplicaInd == nextPrimaryInd && 
skipLocationsOnSameWorkerCheck)
+              || !satisfyRackAware(
+                  true, primaryWorkers, nextPrimaryInd, replicaWorkers, 
nextReplicaInd)) {
+            nextReplicaInd = 
replicaWorkersIncrementIndex.applyAsInt(nextReplicaInd);
             if (nextReplicaInd == replicaIndex) {
               break outer;
             }
           }
         } else {
           if (StorageInfo.localDiskAvailable(availableStorageTypes)) {
-            while (nextReplicaInd == nextPrimaryInd || 
!workers.get(nextReplicaInd).haveDisk()) {
-              nextReplicaInd = incrementIndex.applyAsInt(nextReplicaInd);
+            while ((nextReplicaInd == nextPrimaryInd && 
skipLocationsOnSameWorkerCheck)
+                || !replicaWorkers.get(nextReplicaInd).haveDisk()) {
+              nextReplicaInd = 
replicaWorkersIncrementIndex.applyAsInt(nextReplicaInd);
               if (nextReplicaInd == replicaIndex) {
                 break outer;
               }
             }
           }
           storageInfo =
-              getStorageInfo(workers, nextReplicaInd, null, workerDiskIndex, 
availableStorageTypes);
+              getStorageInfo(
+                  replicaWorkers, nextReplicaInd, null, workerDiskIndex, 
availableStorageTypes);
         }
         PartitionLocation replicaPartition =
             createLocation(
-                partitionId, workers.get(nextReplicaInd), primaryPartition, 
storageInfo, false);
+                partitionId,
+                replicaWorkers.get(nextReplicaInd),
+                primaryPartition,
+                storageInfo,
+                false);
         primaryPartition.setPeer(replicaPartition);
         Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
             slots.computeIfAbsent(
-                workers.get(nextReplicaInd),
+                replicaWorkers.get(nextReplicaInd),
                 v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>()));
         locations._2.add(replicaPartition);
-        replicaIndex = incrementIndex.applyAsInt(nextReplicaInd);
+        replicaIndex = replicaWorkersIncrementIndex.applyAsInt(nextReplicaInd);
       }
 
       Tuple2<List<PartitionLocation>, List<PartitionLocation>> locations =
           slots.computeIfAbsent(
-              workers.get(nextPrimaryInd), v -> new Tuple2<>(new 
ArrayList<>(), new ArrayList<>()));
+              primaryWorkers.get(nextPrimaryInd),
+              v -> new Tuple2<>(new ArrayList<>(), new ArrayList<>()));
       locations._1.add(primaryPartition);
-      primaryIndex = incrementIndex.applyAsInt(nextPrimaryInd);
+      primaryIndex = primaryWorkersIncrementIndex.applyAsInt(nextPrimaryInd);
       iter.remove();
     }
     return partitionIdList;
@@ -460,11 +649,15 @@ public class SlotsAllocator {
   }
 
   private static boolean satisfyRackAware(
-      boolean shouldRackAware, List<WorkerInfo> workers, int primaryIndex, int 
nextReplicaInd) {
+      boolean shouldRackAware,
+      List<WorkerInfo> primaryWorkers,
+      int primaryIndex,
+      List<WorkerInfo> replicaWorkers,
+      int nextReplicaInd) {
     return !shouldRackAware
         || !Objects.equals(
-            workers.get(primaryIndex).networkLocation(),
-            workers.get(nextReplicaInd).networkLocation());
+            primaryWorkers.get(primaryIndex).networkLocation(),
+            replicaWorkers.get(nextReplicaInd).networkLocation());
   }
 
   private static void initLoadAwareAlgorithm(int diskGroups, double 
diskGroupGradient) {
diff --git 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index d66191dd6..00fbd220e 100644
--- 
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++ 
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -229,6 +229,9 @@ private[celeborn] class Master(
     estimatedPartitionSizeForEstimationUpdateInterval,
     TimeUnit.MILLISECONDS)
   private val slotsAssignPolicy = conf.masterSlotAssignPolicy
+  private val slotsAssignInterruptionAware = 
conf.masterSlotAssignInterruptionAware
+  private val slotsAssignInterruptionAwareThreshold =
+    conf.masterSlotsAssignInterruptionAwareThreshold
 
   private var hadoopFs: util.Map[StorageInfo.Type, FileSystem] = _
   masterSource.addGauge(MasterSource.REGISTERED_SHUFFLE_COUNT) { () =>
@@ -939,14 +942,18 @@ private[celeborn] class Master(
               slotsAssignLoadAwareDiskGroupGradient,
               loadAwareFlushTimeWeight,
               loadAwareFetchTimeWeight,
-              requestSlots.availableStorageTypes)
+              requestSlots.availableStorageTypes,
+              slotsAssignInterruptionAware,
+              slotsAssignInterruptionAwareThreshold)
           } else {
             SlotsAllocator.offerSlotsRoundRobin(
               selectedWorkers,
               requestSlots.partitionIdList,
               requestSlots.shouldReplicate,
               requestSlots.shouldRackAware,
-              requestSlots.availableStorageTypes)
+              requestSlots.availableStorageTypes,
+              slotsAssignInterruptionAware,
+              slotsAssignInterruptionAwareThreshold)
           }
         }
       }
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorJmhBenchmark.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorJmhBenchmark.java
index 7ce80fee9..e972a399c 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorJmhBenchmark.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorJmhBenchmark.java
@@ -63,6 +63,7 @@ public class SlotsAllocatorJmhBenchmark {
               diskPartitionToSize,
               PARTITION_SIZE,
               NUM_NETWORK_LOCATIONS,
+              false,
               new Random());
     }
   }
@@ -78,7 +79,13 @@ public class SlotsAllocatorJmhBenchmark {
 
     blackhole.consume(
         SlotsAllocator.offerSlotsRoundRobin(
-            state.workers, state.partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK));
+            state.workers,
+            state.partitionIds,
+            true,
+            true,
+            StorageInfo.ALL_TYPES_AVAILABLE_MASK,
+            false,
+            0));
   }
 
   public static void main(String[] args) throws Exception {
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorRackAwareSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorRackAwareSuiteJ.java
index d8d7bfdd1..0e7e34ced 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorRackAwareSuiteJ.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorRackAwareSuiteJ.java
@@ -81,7 +81,7 @@ public class SlotsAllocatorRackAwareSuiteJ {
 
     Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> 
slots =
         SlotsAllocator.offerSlotsRoundRobin(
-            workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+            workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK, false, 0);
 
     Consumer<PartitionLocation> assertCustomer =
         new Consumer<PartitionLocation>() {
@@ -121,7 +121,7 @@ public class SlotsAllocatorRackAwareSuiteJ {
 
     Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>> 
slots =
         SlotsAllocator.offerSlotsRoundRobin(
-            workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+            workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK, false, 0);
 
     Consumer<PartitionLocation> assertConsumer =
         new Consumer<PartitionLocation>() {
@@ -208,7 +208,7 @@ public class SlotsAllocatorRackAwareSuiteJ {
 
       Map<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>> slots =
           SlotsAllocator.offerSlotsRoundRobin(
-              workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+              workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK, false, 0);
 
       Map<String, Long> numReplicaPerHost =
           slots.entrySet().stream()
@@ -246,7 +246,7 @@ public class SlotsAllocatorRackAwareSuiteJ {
 
         Map<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>> slots =
             SlotsAllocator.offerSlotsRoundRobin(
-                workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+                workers, partitionIds, true, true, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK, false, 0);
 
         Map<String, Long> numReplicaPerHost =
             slots.entrySet().stream()
diff --git 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
index f43acb927..b3b53f221 100644
--- 
a/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
+++ 
b/master/src/test/java/org/apache/celeborn/service/deploy/master/SlotsAllocatorSuiteJ.java
@@ -17,18 +17,19 @@
 
 package org.apache.celeborn.service.deploy.master;
 
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertFalse;
-import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.*;
 
 import java.util.*;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import scala.Tuple2;
+import scala.Tuple3;
 
 import com.google.common.collect.ImmutableMap;
 import org.junit.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.CsvSource;
 
 import org.apache.celeborn.common.CelebornConf;
 import org.apache.celeborn.common.meta.DiskInfo;
@@ -49,6 +50,7 @@ public class SlotsAllocatorSuiteJ {
             "/mnt/disk3", random.nextInt() + 90 * 1024 * 1024 * 1024L),
         assumedPartitionSize,
         3,
+        false,
         random);
   }
 
@@ -139,7 +141,7 @@ public class SlotsAllocatorSuiteJ {
     }
     final boolean shouldReplicate = true;
 
-    check(workers, partitionIds, shouldReplicate, true, true);
+    check(workers, partitionIds, shouldReplicate, true, true, false, 0);
   }
 
   private void check(
@@ -147,15 +149,17 @@ public class SlotsAllocatorSuiteJ {
       List<Integer> partitionIds,
       boolean shouldReplicate,
       boolean expectSuccess) {
-    check(workers, partitionIds, shouldReplicate, expectSuccess, false);
+    check(workers, partitionIds, shouldReplicate, expectSuccess, false, false, 
0);
   }
 
-  private void check(
+  private Map<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>> check(
       List<WorkerInfo> workers,
       List<Integer> partitionIds,
       boolean shouldReplicate,
       boolean expectSuccess,
-      boolean roundrobin) {
+      boolean roundrobin,
+      boolean interruptionAware,
+      int interruptionAwareThreshold) {
     String shuffleKey = "appId-1";
     CelebornConf conf = new CelebornConf();
     conf.set(CelebornConf.MASTER_SLOT_ASSIGN_LOADAWARE_DISKGROUP_NUM().key(), 
"2");
@@ -164,7 +168,13 @@ public class SlotsAllocatorSuiteJ {
     if (roundrobin) {
       slots =
           SlotsAllocator.offerSlotsRoundRobin(
-              workers, partitionIds, shouldReplicate, false, 
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+              workers,
+              partitionIds,
+              shouldReplicate,
+              false,
+              StorageInfo.ALL_TYPES_AVAILABLE_MASK,
+              interruptionAware,
+              interruptionAwareThreshold);
     } else {
       slots =
           SlotsAllocator.offerSlotsLoadAware(
@@ -176,7 +186,9 @@ public class SlotsAllocatorSuiteJ {
               conf.masterSlotAssignLoadAwareDiskGroupGradient(),
               conf.masterSlotAssignLoadAwareFlushTimeWeight(),
               conf.masterSlotAssignLoadAwareFetchTimeWeight(),
-              StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+              StorageInfo.ALL_TYPES_AVAILABLE_MASK,
+              interruptionAware,
+              interruptionAwareThreshold);
     }
     if (expectSuccess) {
       if (shouldReplicate) {
@@ -236,6 +248,7 @@ public class SlotsAllocatorSuiteJ {
       assertTrue(
           "Expect to fail to offer slots, but return " + slots.size() + " 
slots.", slots.isEmpty());
     }
+    return slots;
   }
 
   private void checkSlotsOnDFS(
@@ -273,7 +286,7 @@ public class SlotsAllocatorSuiteJ {
     if (roundRobin) {
       slots =
           SlotsAllocator.offerSlotsRoundRobin(
-              workers, partitionIds, shouldReplicate, false, 
availableStorageTypes);
+              workers, partitionIds, shouldReplicate, false, 
availableStorageTypes, false, 0);
     } else {
       slots =
           SlotsAllocator.offerSlotsLoadAware(
@@ -285,7 +298,9 @@ public class SlotsAllocatorSuiteJ {
               0.1,
               0,
               1,
-              StorageInfo.LOCAL_DISK_MASK | availableStorageTypes);
+              StorageInfo.LOCAL_DISK_MASK | availableStorageTypes,
+              false,
+              0);
     }
     int allocatedPartitionCount = 0;
     for (Map.Entry<WorkerInfo, Tuple2<List<PartitionLocation>, 
List<PartitionLocation>>>
@@ -466,12 +481,226 @@ public class SlotsAllocatorSuiteJ {
     checkSlotsOnDFS(workers, partitionIds, shouldReplicate, true, false, 
false, true);
   }
 
+  @ParameterizedTest
+  @CsvSource({"true, true", "true, false", "false, false", "false, true"})
+  public void testInterruptionAwareSlotSelection(boolean shouldReplicate, 
boolean shouldRackAware) {
+    long assumedPartitionSize = 64 * 1024 * 1024;
+    double interruptionAwarePercentileThreshold = 50;
+    Map<String, Long> diskPartitionToSize = new HashMap<>();
+    diskPartitionToSize.put("/mnt/disk", 512 * 1024 * 1024L); // 0.5gb disk 
space
+    // Cluster usable space is 50g, with 25g that will not be interrupted.
+    List<WorkerInfo> workers =
+        basePrepareWorkers(
+            100, true, diskPartitionToSize, assumedPartitionSize, 20, true, 
new Random());
+    Map<String, WorkerInfo> workersMap =
+        workers.stream().collect(Collectors.toMap(WorkerInfo::host, worker -> 
worker));
+    Tuple3<List<WorkerInfo>, List<WorkerInfo>, List<WorkerInfo>> 
prioritization =
+        SlotsAllocator.prioritizeWorkersBasedOnInterruptionNotice(
+            workers, shouldReplicate, shouldRackAware, 
interruptionAwarePercentileThreshold);
+    List<WorkerInfo> workersWithoutInterruptions = prioritization._1();
+    List<WorkerInfo> workersWithLateInterruptions = prioritization._2();
+    List<WorkerInfo> workersWithEarlyInterruptions = prioritization._3();
+    List<String> workersWithoutInterruptionsHosts = 
extractHosts(workersWithoutInterruptions);
+    List<String> workersWithLateInterruptionsHosts = 
extractHosts(workersWithLateInterruptions);
+    List<String> workersWithEarlyInterruptionsHosts = 
extractHosts(workersWithEarlyInterruptions);
+
+    assertEquals(50, workersWithoutInterruptionsHosts.size());
+    assertEquals(25, workersWithLateInterruptionsHosts.size());
+    assertEquals(25, workersWithEarlyInterruptionsHosts.size());
+    IntStream.range(0, 100)
+        .forEach(
+            i -> {
+              String host = "host" + i;
+              if (i % 2 == 0) {
+                assertTrue(workersWithoutInterruptionsHosts.contains(host));
+              } else if (i >= 51) {
+                assertTrue(workersWithLateInterruptionsHosts.contains(host));
+              } else {
+                assertTrue(workersWithEarlyInterruptionsHosts.contains(host));
+              }
+            });
+
+    // With replication enabled: 150 partitions * 128mb (64 primary, 64 
replica) is roughly 19gb.
+    // Both primaries and replicas should fit into 
workersWithoutInterruptions, since 19gb < 25gb
+    // uninterrupted capacity.
+    //
+    // With replication disabled: 150 partitions * 64mb is roughly 9gb.
+    // Similar to the above case, all primaries should fit into 
workersWithoutInterruptions since
+    // 9gb < 25gb uninterrupted capacity.
+    List<Integer> bestCasePartitionIds =
+        IntStream.range(0, 150).boxed().collect(Collectors.toList());
+    Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>>
+        slotsFromBestCasePartitionIds =
+            SlotsAllocator.offerSlotsRoundRobin(
+                workers,
+                bestCasePartitionIds,
+                shouldReplicate,
+                shouldRackAware,
+                StorageInfo.ALL_TYPES_AVAILABLE_MASK,
+                true,
+                (int) interruptionAwarePercentileThreshold);
+    slotsFromBestCasePartitionIds
+        .values()
+        .forEach(
+            primaryReplicaSlots -> {
+              List<PartitionLocation> primarySlots = primaryReplicaSlots._1;
+              List<PartitionLocation> replicaSlots = primaryReplicaSlots._2;
+              assertTrue(
+                  primarySlots.stream()
+                      .map(PartitionLocation::getHost)
+                      .allMatch(workersWithoutInterruptionsHosts::contains));
+              if (shouldReplicate) {
+                assertTrue(
+                    replicaSlots.stream()
+                        .map(PartitionLocation::getHost)
+                        .allMatch(workersWithoutInterruptionsHosts::contains));
+                if (shouldRackAware) {
+                  primarySlots.forEach(
+                      slot -> {
+                        WorkerInfo primary = workersMap.get(slot.getHost());
+                        WorkerInfo replica = 
workersMap.get(slot.getPeer().getHost());
+                        assertNotSame(primary.networkLocation(), 
replica.networkLocation());
+                      });
+                }
+              }
+            });
+
+    List<WorkerInfo> primaryWorkerCandidates =
+        combineWorkers(workersWithoutInterruptions, 
workersWithLateInterruptions);
+    List<String> primaryWorkerCandidatesHosts = 
extractHosts(primaryWorkerCandidates);
+
+    // With replication enabled: 300 partitions * 128mb (64 primary, 64 
replica) is roughly 38gb.
+    // In this case, primaries should be in workersWithoutInterruptions +
+    // workersWithLateInterruptions, while
+    // replicas can spill over into workersWithEarlyInterruptions.
+    //
+    // With replication disabled, we increase partitions to 600 to force this 
case:
+    // 600 partitions * 64mb is roughly 38gb.
+    // Similar to the above case, all primaries should be in 
workersWithoutInterruptions +
+    // workersWithLateInterruptions.
+    List<Integer> spillOverCasePartitionIds;
+    if (shouldReplicate) {
+      spillOverCasePartitionIds = IntStream.range(0, 
300).boxed().collect(Collectors.toList());
+    } else {
+      spillOverCasePartitionIds = IntStream.range(0, 
600).boxed().collect(Collectors.toList());
+    }
+    Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>>
+        slotsFromSpillOverCasePartitionIds =
+            SlotsAllocator.offerSlotsRoundRobin(
+                workers,
+                spillOverCasePartitionIds,
+                shouldReplicate,
+                shouldRackAware,
+                StorageInfo.ALL_TYPES_AVAILABLE_MASK,
+                true,
+                (int) interruptionAwarePercentileThreshold);
+    slotsFromSpillOverCasePartitionIds
+        .values()
+        .forEach(
+            primaryReplicaSlots -> {
+              List<PartitionLocation> primarySlots = primaryReplicaSlots._1;
+              List<PartitionLocation> replicaSlots = primaryReplicaSlots._2;
+              assertTrue(
+                  primarySlots.stream()
+                      .map(PartitionLocation::getHost)
+                      .allMatch(primaryWorkerCandidatesHosts::contains));
+              assertTrue(
+                  primarySlots.stream()
+                      .map(PartitionLocation::getHost)
+                      
.noneMatch(workersWithEarlyInterruptionsHosts::contains));
+              if (shouldReplicate) {
+                assertTrue(
+                    replicaSlots.stream()
+                        .map(PartitionLocation::getHost)
+                        .allMatch(
+                            host ->
+                                primaryWorkerCandidatesHosts.contains(host)
+                                    || 
workersWithEarlyInterruptionsHosts.contains(host)));
+                if (shouldRackAware) {
+                  primarySlots.forEach(
+                      slot -> {
+                        WorkerInfo primary = workersMap.get(slot.getHost());
+                        WorkerInfo replica = 
workersMap.get(slot.getPeer().getHost());
+                        assertNotSame(primary.networkLocation(), 
replica.networkLocation());
+                      });
+                }
+              }
+            });
+    // With the slot restrictions in place for LoadAware, we expect to spill 
replicas into
+    // workersWithEarlyInterruptionsHosts.
+    // But primaries should be in workersWithoutInterruptions + 
workersWithLateInterruptions.
+    Map<WorkerInfo, Tuple2<List<PartitionLocation>, List<PartitionLocation>>>
+        loadAwareBestCasePartitionIdsSlots =
+            check(
+                workers,
+                spillOverCasePartitionIds,
+                shouldReplicate,
+                true,
+                false,
+                true,
+                (int) interruptionAwarePercentileThreshold);
+    loadAwareBestCasePartitionIdsSlots
+        .values()
+        .forEach(
+            primaryReplicaSlots -> {
+              List<PartitionLocation> primarySlots = primaryReplicaSlots._1;
+              List<PartitionLocation> replicaSlots = primaryReplicaSlots._2;
+              assertTrue(
+                  primarySlots.stream()
+                      .map(PartitionLocation::getHost)
+                      .allMatch(primaryWorkerCandidatesHosts::contains));
+              assertTrue(
+                  primarySlots.stream()
+                      .map(PartitionLocation::getHost)
+                      
.noneMatch(workersWithEarlyInterruptionsHosts::contains));
+              if (shouldReplicate) {
+                assertTrue(
+                    replicaSlots.stream()
+                        .map(PartitionLocation::getHost)
+                        .allMatch(
+                            host ->
+                                primaryWorkerCandidatesHosts.contains(host)
+                                    || 
workersWithEarlyInterruptionsHosts.contains(host)));
+                if (shouldRackAware) {
+                  primarySlots.forEach(
+                      slot -> {
+                        WorkerInfo primary = workersMap.get(slot.getHost());
+                        WorkerInfo replica = 
workersMap.get(slot.getPeer().getHost());
+                        assertNotSame(primary.networkLocation(), 
replica.networkLocation());
+                      });
+                }
+              }
+            });
+  }
+
+  @Test
+  public void testInterruptionAwareSlotSelectionWithNoInterruptions() {
+    long assumedPartitionSize = 64 * 1024 * 1024;
+    Map<String, Long> diskPartitionToSize = new HashMap<>();
+    diskPartitionToSize.put("/mnt/disk", 512 * 1024 * 1024L); // 0.5gb disk 
space
+    // Cluster usable space is 50g, with 25g that will not be interrupted.
+    List<WorkerInfo> workers =
+        basePrepareWorkers(
+            100, true, diskPartitionToSize, assumedPartitionSize, 20, false, 
new Random());
+    List<Integer> partitionIds = IntStream.range(0, 
600).boxed().collect(Collectors.toList());
+    check(workers, partitionIds, true, true, false, true, 50);
+  }
+
+  private List<String> extractHosts(List<WorkerInfo> workers) {
+    return workers.stream().map(WorkerInfo::host).collect(Collectors.toList());
+  }
+
+  private List<WorkerInfo> combineWorkers(List<WorkerInfo>... workerLists) {
+    return 
Arrays.stream(workerLists).flatMap(List::stream).collect(Collectors.toList());
+  }
+
   static List<WorkerInfo> basePrepareWorkers(
       int numWorkers,
       boolean hasDisks,
       Map<String, Long> diskPartitionToSize,
       long assumedPartitionSize,
       int numNetworkLocations,
+      boolean hasInterruptions,
       Random random) {
     return IntStream.range(0, numWorkers)
         .mapToObj(
@@ -487,11 +716,19 @@ public class SlotsAllocatorSuiteJ {
                               random.nextInt(1000),
                               random.nextInt(1000),
                               0);
-                      diskInfo.maxSlots_$eq(diskInfo.actualUsableSpace() / 
assumedPartitionSize);
+                      diskInfo.availableSlots_$eq(
+                          diskInfo.actualUsableSpace() / assumedPartitionSize);
                       disks.put(diskMountPoint, diskInfo);
                     });
               }
               WorkerInfo worker = new WorkerInfo("host" + i, i, i, i, i, i, 
disks, null);
+              if (hasInterruptions) {
+                if (i % 2 == 0) {
+                  worker.nextInterruptionNotice_$eq(Long.MAX_VALUE);
+                } else {
+                  worker.nextInterruptionNotice_$eq(i);
+                }
+              }
               worker.networkLocation_$eq(String.valueOf(i % 
numNetworkLocations));
               return worker;
             })

Reply via email to