This is an automated email from the ASF dual-hosted git repository.
nicholasjiang pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new e119902b6 [CELEBORN-2268] Improve test coverage for MEMORY and S3
storage
e119902b6 is described below
commit e119902b6c538789679a277056ea58f91e19f455
Author: Enrico Olivelli <[email protected]>
AuthorDate: Fri Feb 27 15:19:22 2026 +0800
[CELEBORN-2268] Improve test coverage for MEMORY and S3 storage
### What changes were proposed in this pull request?
This commit only adds tests and some useful debug information about using
MEMORY and S3 storage.
### Why are the changes needed?
Because there is not enough code coverage of some configurations that may
occur in production, in particular:
- using MEMORY storage
- using only S3 storage
- using MEMORY with eviction to S3
There is an interesting case to test: when you configure MEMORY-to-S3
eviction and the dataset is small, it is important to ensure that no file
is created in S3.
### Does this PR resolve a correctness bug?
No.
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
It adds new integration tests.
Closes #3608 from eolivelli/fix-eviction-apache.
Authored-by: Enrico Olivelli <[email protected]>
Signed-off-by: SteNicholas <[email protected]>
---
.../spark/shuffle/celeborn/ShuffleManagerSpy.java | 81 ++++
.../spark/shuffle/celeborn/ShuffleManagerSpy.java | 89 ++++
.../celeborn/common/protocol/StorageInfo.java | 14 +-
.../service/deploy/master/SlotsAllocator.java | 16 +-
.../celeborn/service/deploy/master/Master.scala | 7 +-
.../deploy/master/BuildStorageInfoSuiteJ.java | 500 +++++++++++++++++++++
.../spark/EvictMemoryToTieredStorageTest.scala | 279 ++++++++++++
...StorageTest.scala => S3TieredStorageTest.scala} | 30 +-
.../celeborn/tests/spark/SparkTestBase.scala | 35 +-
.../service/deploy/worker/storage/TierWriter.scala | 5 +-
10 files changed, 1016 insertions(+), 40 deletions(-)
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
new file mode 100644
index 000000000..82909c4e5
--- /dev/null
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReader;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class ShuffleManagerSpy extends SparkShuffleManager {
+
+ private static final AtomicReference<OpenShuffleReaderCallback>
getShuffleReaderHook =
+ new AtomicReference<>();
+ private static final AtomicReference<CelebornConf> configurationHolder = new
AtomicReference<>();
+
+ public ShuffleManagerSpy(SparkConf conf, boolean isDriver) {
+ super(conf, isDriver);
+ }
+
+ @Override
+ public <K, C> ShuffleReader<K, C> getReader(
+ ShuffleHandle handle, int startPartition, int endPartition, TaskContext
context) {
+ OpenShuffleReaderCallback consumer = getShuffleReaderHook.get();
+ if (consumer != null) {
+ CelebornShuffleHandle celebornShuffleHandle = (CelebornShuffleHandle)
handle;
+ ShuffleClient client =
+ ShuffleClient.get(
+ celebornShuffleHandle.appUniqueId(),
+ celebornShuffleHandle.lifecycleManagerHost(),
+ celebornShuffleHandle.lifecycleManagerPort(),
+ configurationHolder.get(),
+ celebornShuffleHandle.userIdentifier());
+ consumer.accept(
+ celebornShuffleHandle.appUniqueId(),
+ celebornShuffleHandle.shuffleId(),
+ client,
+ startPartition,
+ endPartition);
+ }
+ return super.getReader(handle, startPartition, endPartition, context);
+ }
+
+ public interface OpenShuffleReaderCallback {
+ void accept(
+ String appId,
+ Integer shuffleId,
+ ShuffleClient client,
+ Integer startPartition,
+ Integer endPartition);
+ }
+
+ public static void interceptOpenShuffleReader(OpenShuffleReaderCallback
hook, CelebornConf conf) {
+ getShuffleReaderHook.set(hook);
+ configurationHolder.set(conf);
+ }
+
+ public static void resetHook() {
+ getShuffleReaderHook.set(null);
+ configurationHolder.set(null);
+ }
+}
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
new file mode 100644
index 000000000..1bb197da7
--- /dev/null
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/ShuffleManagerSpy.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.util.concurrent.atomic.AtomicReference;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleHandle;
+import org.apache.spark.shuffle.ShuffleReadMetricsReporter;
+import org.apache.spark.shuffle.ShuffleReader;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class ShuffleManagerSpy extends SparkShuffleManager {
+
+ private static final AtomicReference<OpenShuffleReaderCallback>
getShuffleReaderHook =
+ new AtomicReference<>();
+ private static final AtomicReference<CelebornConf> configurationHolder = new
AtomicReference<>();
+
+ public ShuffleManagerSpy(SparkConf conf, boolean isDriver) {
+ super(conf, isDriver);
+ }
+
+ @Override
+ public <K, C> ShuffleReader<K, C> getCelebornShuffleReader(
+ ShuffleHandle handle,
+ int startPartition,
+ int endPartition,
+ int startMapIndex,
+ int endMapIndex,
+ TaskContext context,
+ ShuffleReadMetricsReporter metrics) {
+ OpenShuffleReaderCallback consumer = getShuffleReaderHook.get();
+ if (consumer != null) {
+ CelebornShuffleHandle celebornShuffleHandle = (CelebornShuffleHandle)
handle;
+ ShuffleClient client =
+ ShuffleClient.get(
+ celebornShuffleHandle.appUniqueId(),
+ celebornShuffleHandle.lifecycleManagerHost(),
+ celebornShuffleHandle.lifecycleManagerPort(),
+ configurationHolder.get(),
+ celebornShuffleHandle.userIdentifier());
+ consumer.accept(
+ celebornShuffleHandle.appUniqueId(),
+ celebornShuffleHandle.shuffleId(),
+ client,
+ startPartition,
+ endPartition);
+ }
+ return super.getCelebornShuffleReader(
+ handle, startPartition, endPartition, startMapIndex, endMapIndex,
context, metrics);
+ }
+
+ public interface OpenShuffleReaderCallback {
+ void accept(
+ String appId,
+ Integer shuffleId,
+ ShuffleClient client,
+ Integer startPartition,
+ Integer endPartition);
+ }
+
+ public static void interceptOpenShuffleReader(OpenShuffleReaderCallback
hook, CelebornConf conf) {
+ getShuffleReaderHook.set(hook);
+ configurationHolder.set(conf);
+ }
+
+ public static void resetHook() {
+ getShuffleReaderHook.set(null);
+ configurationHolder.set(null);
+ }
+}
diff --git
a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
index b8d9428c3..1ab97309e 100644
--- a/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
+++ b/common/src/main/java/org/apache/celeborn/common/protocol/StorageInfo.java
@@ -175,18 +175,22 @@ public class StorageInfo implements Serializable {
+ '}';
}
- public boolean memoryAvailable() {
+ public static boolean memoryAvailable(int availableStorageTypes) {
return availableStorageTypes == ALL_TYPES_AVAILABLE_MASK
|| (availableStorageTypes & MEMORY_MASK) > 0;
}
+ public boolean memoryAvailable() {
+ return memoryAvailable(availableStorageTypes);
+ }
+
public static boolean localDiskAvailable(int availableStorageTypes) {
return availableStorageTypes == ALL_TYPES_AVAILABLE_MASK
|| (availableStorageTypes & LOCAL_DISK_MASK) > 0;
}
public boolean localDiskAvailable() {
- return StorageInfo.localDiskAvailable(availableStorageTypes);
+ return localDiskAvailable(availableStorageTypes);
}
public static boolean HDFSAvailable(int availableStorageTypes) {
@@ -195,7 +199,7 @@ public class StorageInfo implements Serializable {
}
public boolean HDFSAvailable() {
- return StorageInfo.HDFSAvailable(availableStorageTypes);
+ return HDFSAvailable(availableStorageTypes);
}
public static boolean HDFSOnly(int availableStorageTypes) {
@@ -221,11 +225,11 @@ public class StorageInfo implements Serializable {
}
public boolean OSSAvailable() {
- return StorageInfo.OSSAvailable(availableStorageTypes);
+ return OSSAvailable(availableStorageTypes);
}
public boolean S3Available() {
- return StorageInfo.S3Available(availableStorageTypes);
+ return S3Available(availableStorageTypes);
}
@Override
diff --git
a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
index 450868895..5580cd341 100644
---
a/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
+++
b/master/src/main/java/org/apache/celeborn/service/deploy/master/SlotsAllocator.java
@@ -207,7 +207,7 @@ public class SlotsAllocator {
interruptionAwareThreshold);
}
- private static StorageInfo getStorageInfo(
+ static StorageInfo buildStorageInfo(
List<WorkerInfo> workers,
int workerIndex,
Map<WorkerInfo, List<UsableDiskInfo>> restrictions,
@@ -260,8 +260,12 @@ public class SlotsAllocator {
storageInfo = new StorageInfo("", StorageInfo.Type.S3,
availableStorageTypes);
} else if (StorageInfo.OSSAvailable(availableStorageTypes)) {
storageInfo = new StorageInfo("", StorageInfo.Type.OSS,
availableStorageTypes);
- } else {
+ } else if (StorageInfo.HDFSAvailable(availableStorageTypes)) {
storageInfo = new StorageInfo("", StorageInfo.Type.HDFS,
availableStorageTypes);
+ } else if (StorageInfo.memoryAvailable(availableStorageTypes)) {
+ storageInfo = new StorageInfo("", StorageInfo.Type.MEMORY,
availableStorageTypes);
+ } else {
+ throw new IllegalStateException("no storage type available");
}
}
return storageInfo;
@@ -548,7 +552,7 @@ public class SlotsAllocator {
}
}
storageInfo =
- getStorageInfo(
+ buildStorageInfo(
primaryWorkers,
nextPrimaryInd,
slotsRestrictions,
@@ -564,7 +568,7 @@ public class SlotsAllocator {
}
}
storageInfo =
- getStorageInfo(
+ buildStorageInfo(
primaryWorkers, nextPrimaryInd, null, workerDiskIndex,
availableStorageTypes);
}
PartitionLocation primaryPartition =
@@ -587,7 +591,7 @@ public class SlotsAllocator {
}
}
storageInfo =
- getStorageInfo(
+ buildStorageInfo(
replicaWorkers,
nextReplicaInd,
slotsRestrictions,
@@ -613,7 +617,7 @@ public class SlotsAllocator {
}
}
storageInfo =
- getStorageInfo(
+ buildStorageInfo(
replicaWorkers, nextReplicaInd, null, workerDiskIndex,
availableStorageTypes);
}
PartitionLocation replicaPartition =
diff --git
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
index 0799af45c..346fab74a 100644
---
a/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
+++
b/master/src/main/scala/org/apache/celeborn/service/deploy/master/Master.scala
@@ -1019,9 +1019,12 @@ private[celeborn] class Master(
Utils.getSlotsPerDisk(slots.asInstanceOf[WorkerResource])
.asScala.map { case (worker, slots) => worker.toUniqueId -> slots
}.asJava,
requestSlots.requestId)
-
+ val primaryLocationsByType = slots.values.asScala
+ .flatMap(entry => entry._1.asScala) // ._1 extracts the primary location
+ .groupBy(l => l.getStorageInfo.getType)
+ .mapValues(locations => locations.size)
var offerSlotsMsg = s"Successfully offered slots for $numReducers reducers
of $shuffleKey" +
- s" on ${slots.size()} workers"
+ s" on ${slots.size()} workers, primary types: $primaryLocationsByType"
val workersNotSelected =
availableWorkers.asScala.filter(!slots.containsKey(_))
val offerSlotsExtraSize = Math.min(
Math.max(
diff --git
a/master/src/test/java/org/apache/celeborn/service/deploy/master/BuildStorageInfoSuiteJ.java
b/master/src/test/java/org/apache/celeborn/service/deploy/master/BuildStorageInfoSuiteJ.java
new file mode 100644
index 000000000..cefd81e1b
--- /dev/null
+++
b/master/src/test/java/org/apache/celeborn/service/deploy/master/BuildStorageInfoSuiteJ.java
@@ -0,0 +1,500 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.service.deploy.master;
+
+import static org.junit.Assert.*;
+
+import java.util.*;
+
+import org.junit.Test;
+
+import org.apache.celeborn.common.meta.DiskInfo;
+import org.apache.celeborn.common.meta.WorkerInfo;
+import org.apache.celeborn.common.protocol.StorageInfo;
+
+/**
+ * Unit tests for {@link SlotsAllocator#buildStorageInfo}.
+ *
+ * <p>The method has two main paths:
+ *
+ * <ol>
+ * <li><b>With restrictions</b> – a restrictions map is provided, and a disk
is selected from the
+ * worker's {@link SlotsAllocator.UsableDiskInfo} list.
+ * <li><b>Without restrictions</b> – the restrictions map is {@code null},
and storage type is
+ * derived from the {@code availableStorageTypes} bitmask.
+ * </ol>
+ */
+public class BuildStorageInfoSuiteJ {
+
+ //
---------------------------------------------------------------------------
+ // Helpers
+ //
---------------------------------------------------------------------------
+
+ private DiskInfo makeDiskInfo(String mountPoint, StorageInfo.Type
storageType) {
+ return new DiskInfo(mountPoint, 10L * 1024 * 1024 * 1024, 100, 100, 0,
storageType);
+ }
+
+ private WorkerInfo makeWorker(String host, Map<String, DiskInfo> disks) {
+ return new WorkerInfo(host, 9001, 9002, 9003, 9004, 9005, disks, null);
+ }
+
+ private Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictionsFor(
+ WorkerInfo worker, List<SlotsAllocator.UsableDiskInfo> diskList) {
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions = new
HashMap<>();
+ restrictions.put(worker, diskList);
+ return restrictions;
+ }
+
+ //
---------------------------------------------------------------------------
+ // Tests: restrictions != null
+ //
---------------------------------------------------------------------------
+
+ /**
+ * An HDD disk in the restrictions list should produce a StorageInfo with
the disk mount point.
+ */
+ @Test
+ public void testWithRestrictions_HDDDisk() {
+ DiskInfo disk = makeDiskInfo("/mnt/hdd1", StorageInfo.Type.HDD);
+ WorkerInfo worker = makeWorker("host1",
Collections.singletonMap("/mnt/hdd1", disk));
+
+ SlotsAllocator.UsableDiskInfo usable = new
SlotsAllocator.UsableDiskInfo(disk, 10);
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, Collections.singletonList(usable));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+
+ assertEquals(StorageInfo.Type.HDD, result.getType());
+ assertEquals("/mnt/hdd1", result.getMountPoint());
+ assertEquals(StorageInfo.ALL_TYPES_AVAILABLE_MASK,
result.availableStorageTypes);
+ assertEquals(9, usable.usableSlots); // consumed one slot
+ }
+
+ /**
+ * An SSD disk in the restrictions list should produce a StorageInfo with
the disk mount point.
+ */
+ @Test
+ public void testWithRestrictions_SSDDisk() {
+ DiskInfo disk = makeDiskInfo("/mnt/ssd1", StorageInfo.Type.SSD);
+ WorkerInfo worker = makeWorker("host1",
Collections.singletonMap("/mnt/ssd1", disk));
+
+ SlotsAllocator.UsableDiskInfo usable = new
SlotsAllocator.UsableDiskInfo(disk, 5);
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, Collections.singletonList(usable));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.LOCAL_DISK_MASK);
+
+ assertEquals(StorageInfo.Type.SSD, result.getType());
+ assertEquals("/mnt/ssd1", result.getMountPoint());
+ assertEquals(StorageInfo.LOCAL_DISK_MASK, result.availableStorageTypes);
+ assertEquals(4, usable.usableSlots);
+ }
+
+ /**
+ * An HDFS disk in the restrictions list should produce a StorageInfo with
an empty mount point
+ * and HDFS type, regardless of the actual mount-point string stored in
DiskInfo.
+ */
+ @Test
+ public void testWithRestrictions_HDFSDisk_emptyMountPoint() {
+ DiskInfo disk = makeDiskInfo("HDFS", StorageInfo.Type.HDFS);
+ WorkerInfo worker = makeWorker("host1", Collections.singletonMap("HDFS",
disk));
+
+ SlotsAllocator.UsableDiskInfo usable = new
SlotsAllocator.UsableDiskInfo(disk, 50);
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, Collections.singletonList(usable));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.HDFS_MASK);
+
+ assertEquals(StorageInfo.Type.HDFS, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.HDFS_MASK, result.availableStorageTypes);
+ assertEquals(49, usable.usableSlots);
+ }
+
+ /**
+ * An S3 disk in the restrictions list should produce a StorageInfo with an
empty mount point and
+ * S3 type.
+ */
+ @Test
+ public void testWithRestrictions_S3Disk_emptyMountPoint() {
+ DiskInfo disk = makeDiskInfo("S3", StorageInfo.Type.S3);
+ WorkerInfo worker = makeWorker("host1", Collections.singletonMap("S3",
disk));
+
+ SlotsAllocator.UsableDiskInfo usable = new
SlotsAllocator.UsableDiskInfo(disk, 20);
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, Collections.singletonList(usable));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.S3_MASK);
+
+ assertEquals(StorageInfo.Type.S3, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.S3_MASK, result.availableStorageTypes);
+ assertEquals(19, usable.usableSlots);
+ }
+
+ /**
+ * An OSS disk in the restrictions list should produce a StorageInfo with an
empty mount point and
+ * OSS type.
+ */
+ @Test
+ public void testWithRestrictions_OSSDisk_emptyMountPoint() {
+ DiskInfo disk = makeDiskInfo("OSS", StorageInfo.Type.OSS);
+ WorkerInfo worker = makeWorker("host1", Collections.singletonMap("OSS",
disk));
+
+ SlotsAllocator.UsableDiskInfo usable = new
SlotsAllocator.UsableDiskInfo(disk, 30);
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, Collections.singletonList(usable));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.OSS_MASK);
+
+ assertEquals(StorageInfo.Type.OSS, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.OSS_MASK, result.availableStorageTypes);
+ assertEquals(29, usable.usableSlots);
+ }
+
+ /**
+ * When the first disk in the restrictions list has zero usable slots, the
method must skip it and
+ * select the next disk with available capacity.
+ */
+ @Test
+ public void testWithRestrictions_skipExhaustedDisk() {
+ DiskInfo disk1 = makeDiskInfo("/mnt/disk1", StorageInfo.Type.HDD);
+ DiskInfo disk2 = makeDiskInfo("/mnt/disk2", StorageInfo.Type.HDD);
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ SlotsAllocator.UsableDiskInfo exhausted = new
SlotsAllocator.UsableDiskInfo(disk1, 0);
+ SlotsAllocator.UsableDiskInfo active = new
SlotsAllocator.UsableDiskInfo(disk2, 8);
+ List<SlotsAllocator.UsableDiskInfo> diskList = new ArrayList<>();
+ diskList.add(exhausted);
+ diskList.add(active);
+
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, diskList);
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0); // start at the exhausted disk
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+
+ assertEquals(StorageInfo.Type.HDD, result.getType());
+ assertEquals("/mnt/disk2", result.getMountPoint()); // disk2 was chosen
+ assertEquals(0, exhausted.usableSlots); // disk1 untouched
+ assertEquals(7, active.usableSlots); // disk2 consumed one slot
+ }
+
+ /**
+ * After assigning a slot on a local disk, the disk index stored in {@code
workerDiskIndex} must
+ * advance to the next disk (round-robin).
+ */
+ @Test
+ public void testWithRestrictions_localDiskAdvancesDiskIndex() {
+ DiskInfo disk1 = makeDiskInfo("/mnt/disk1", StorageInfo.Type.HDD);
+ DiskInfo disk2 = makeDiskInfo("/mnt/disk2", StorageInfo.Type.HDD);
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ List<SlotsAllocator.UsableDiskInfo> diskList = new ArrayList<>();
+ diskList.add(new SlotsAllocator.UsableDiskInfo(disk1, 10));
+ diskList.add(new SlotsAllocator.UsableDiskInfo(disk2, 10));
+
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, diskList);
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ restrictions,
+ workerDiskIndex,
+ StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+
+ assertEquals(Integer.valueOf(1), workerDiskIndex.get(worker));
+ }
+
+ /**
+ * For DFS storage types (HDFS, S3, OSS), the disk index in {@code
workerDiskIndex} must NOT be
+ * advanced after slot assignment, because there is only a single logical
endpoint.
+ */
+ @Test
+ public void testWithRestrictions_HDFSDiskDoesNotAdvanceDiskIndex() {
+ DiskInfo disk = makeDiskInfo("HDFS", StorageInfo.Type.HDFS);
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(
+ worker, Collections.singletonList(new
SlotsAllocator.UsableDiskInfo(disk, 10)));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, restrictions, workerDiskIndex,
StorageInfo.HDFS_MASK);
+
+ assertEquals(Integer.valueOf(0), workerDiskIndex.get(worker));
+ }
+
+ /**
+ * Consecutive calls with two local disks in the restrictions list must
cycle through the disks in
+ * order, demonstrating round-robin selection.
+ */
+ @Test
+ public void testWithRestrictions_roundRobinAcrossLocalDisks() {
+ DiskInfo disk1 = makeDiskInfo("/mnt/disk1", StorageInfo.Type.HDD);
+ DiskInfo disk2 = makeDiskInfo("/mnt/disk2", StorageInfo.Type.HDD);
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ SlotsAllocator.UsableDiskInfo usable1 = new
SlotsAllocator.UsableDiskInfo(disk1, 10);
+ SlotsAllocator.UsableDiskInfo usable2 = new
SlotsAllocator.UsableDiskInfo(disk2, 10);
+ List<SlotsAllocator.UsableDiskInfo> diskList = new ArrayList<>();
+ diskList.add(usable1);
+ diskList.add(usable2);
+
+ Map<WorkerInfo, List<SlotsAllocator.UsableDiskInfo>> restrictions =
+ restrictionsFor(worker, diskList);
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+ List<WorkerInfo> workers = Collections.singletonList(worker);
+
+ // First call: disk at index 0
+ StorageInfo result1 =
+ SlotsAllocator.buildStorageInfo(
+ workers, 0, restrictions, workerDiskIndex,
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+ assertEquals("/mnt/disk1", result1.getMountPoint());
+ assertEquals(9, usable1.usableSlots);
+ assertEquals(Integer.valueOf(1), workerDiskIndex.get(worker));
+
+ // Second call: disk at index 1
+ StorageInfo result2 =
+ SlotsAllocator.buildStorageInfo(
+ workers, 0, restrictions, workerDiskIndex,
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+ assertEquals("/mnt/disk2", result2.getMountPoint());
+ assertEquals(9, usable2.usableSlots);
+ assertEquals(Integer.valueOf(0), workerDiskIndex.get(worker)); // wrapped
around
+ }
+
+ //
---------------------------------------------------------------------------
+ // Tests: restrictions == null
+ //
---------------------------------------------------------------------------
+
+ /**
+ * When restrictions are {@code null} and all storage types are available
(mask = 0), the method
+ * must pick a local disk from the worker's disk map.
+ */
+ @Test
+ public void testWithoutRestrictions_allTypesAvailable_picksLocalDisk() {
+ DiskInfo disk = makeDiskInfo("/mnt/hdd1", StorageInfo.Type.HDD);
+ WorkerInfo worker = makeWorker("host1",
Collections.singletonMap("/mnt/hdd1", disk));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ null,
+ workerDiskIndex,
+ StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+
+ assertEquals(StorageInfo.Type.HDD, result.getType());
+ assertEquals("/mnt/hdd1", result.getMountPoint());
+ assertEquals(StorageInfo.ALL_TYPES_AVAILABLE_MASK,
result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and only LOCAL_DISK_MASK is set, the
method must pick an SSD
+ * disk from the worker and record its mount point.
+ */
+ @Test
+ public void testWithoutRestrictions_localDiskMask() {
+ DiskInfo disk = makeDiskInfo("/mnt/ssd1", StorageInfo.Type.SSD);
+ WorkerInfo worker = makeWorker("host1",
Collections.singletonMap("/mnt/ssd1", disk));
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0);
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker),
+ 0,
+ null,
+ workerDiskIndex,
+ StorageInfo.LOCAL_DISK_MASK);
+
+ assertEquals(StorageInfo.Type.SSD, result.getType());
+ assertEquals("/mnt/ssd1", result.getMountPoint());
+ assertEquals(StorageInfo.LOCAL_DISK_MASK, result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and only S3_MASK is set, the method
must return a
+ * StorageInfo with empty mount point and S3 type without touching any
worker disks.
+ */
+ @Test
+ public void testWithoutRestrictions_S3Only() {
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, null, new HashMap<>(),
StorageInfo.S3_MASK);
+
+ assertEquals(StorageInfo.Type.S3, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.S3_MASK, result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and only OSS_MASK is set, the method
must return a
+ * StorageInfo with empty mount point and OSS type.
+ */
+ @Test
+ public void testWithoutRestrictions_OSSOnly() {
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, null, new HashMap<>(),
StorageInfo.OSS_MASK);
+
+ assertEquals(StorageInfo.Type.OSS, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.OSS_MASK, result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and only HDFS_MASK is set, the method
must return a
+ * StorageInfo with empty mount point and HDFS type.
+ */
+ @Test
+ public void testWithoutRestrictions_HDFSOnly() {
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, null, new HashMap<>(),
StorageInfo.HDFS_MASK);
+
+ assertEquals(StorageInfo.Type.HDFS, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.HDFS_MASK, result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and only MEMORY_MASK is set, the
method must return a
+ * StorageInfo with empty mount point and MEMORY type.
+ */
+ @Test
+ public void testWithoutRestrictions_memoryOnly() {
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+
+ StorageInfo result =
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, null, new HashMap<>(),
StorageInfo.MEMORY_MASK);
+
+ assertEquals(StorageInfo.Type.MEMORY, result.getType());
+ assertEquals("", result.getMountPoint());
+ assertEquals(StorageInfo.MEMORY_MASK, result.availableStorageTypes);
+ }
+
+ /**
+ * When restrictions are {@code null} and the bitmask does not correspond to
any known storage
+ * type, the method must throw {@link IllegalStateException}.
+ */
+ @Test(expected = IllegalStateException.class)
+ public void testWithoutRestrictions_noValidStorageType_throwsIllegalState() {
+ WorkerInfo worker = makeWorker("host1", new HashMap<>());
+ // 0b100000 = 32 has none of the bits used by known storage types
+ int unknownMask = 0b100000;
+ SlotsAllocator.buildStorageInfo(
+ Collections.singletonList(worker), 0, null, new HashMap<>(),
unknownMask);
+ }
+
+ /**
+ * When restrictions are {@code null} and a local disk is selected, the disk
index in {@code
+ * workerDiskIndex} must advance so that the next call picks the following
disk (round-robin).
+ */
+ @Test
+ public void testWithoutRestrictions_localDiskAdvancesDiskIndex() {
+ DiskInfo disk1 = makeDiskInfo("/mnt/disk1", StorageInfo.Type.HDD);
+ DiskInfo disk2 = makeDiskInfo("/mnt/disk2", StorageInfo.Type.HDD);
+
+ // Use a LinkedHashMap so that disk1 comes before disk2 during stream
iteration.
+ Map<String, DiskInfo> disks = new LinkedHashMap<>();
+ disks.put("/mnt/disk1", disk1);
+ disks.put("/mnt/disk2", disk2);
+ WorkerInfo worker = makeWorker("host1", disks);
+
+ Map<WorkerInfo, Integer> workerDiskIndex = new HashMap<>();
+ workerDiskIndex.put(worker, 0); // force first disk
+ List<WorkerInfo> workers = Collections.singletonList(worker);
+
+ // First call selects disk1 and advances the index to 1
+ StorageInfo result1 =
+ SlotsAllocator.buildStorageInfo(
+ workers, 0, null, workerDiskIndex,
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+ assertEquals("/mnt/disk1", result1.getMountPoint());
+ assertEquals(Integer.valueOf(1), workerDiskIndex.get(worker));
+
+ // Second call selects disk2 and wraps the index back to 0
+ StorageInfo result2 =
+ SlotsAllocator.buildStorageInfo(
+ workers, 0, null, workerDiskIndex,
StorageInfo.ALL_TYPES_AVAILABLE_MASK);
+ assertEquals("/mnt/disk2", result2.getMountPoint());
+ assertEquals(Integer.valueOf(0), workerDiskIndex.get(worker));
+ }
+}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/EvictMemoryToTieredStorageTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/EvictMemoryToTieredStorageTest.scala
new file mode 100644
index 000000000..567a5b283
--- /dev/null
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/EvictMemoryToTieredStorageTest.scala
@@ -0,0 +1,279 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.celeborn.tests.spark
+
+import java.util.concurrent.CopyOnWriteArrayList
+
+import scala.collection.JavaConverters._
+import scala.collection.immutable
+import scala.util.Random
+
+import org.apache.commons.lang3.StringUtils
+import org.apache.spark.{SparkConf, SparkException}
+import org.apache.spark.shuffle.celeborn.{CelebornShuffleHandle,
ShuffleManagerSpy}
+import
org.apache.spark.shuffle.celeborn.ShuffleManagerSpy.OpenShuffleReaderCallback
+import org.apache.spark.sql.SparkSession
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.funsuite.AnyFunSuite
+import org.testcontainers.containers.MinIOContainer
+
+import org.apache.celeborn.client.ShuffleClient
+import org.apache.celeborn.common.CelebornConf
+import org.apache.celeborn.common.protocol.{PartitionLocation, ShuffleMode}
+import org.apache.celeborn.common.protocol.StorageInfo.Type
+
+class EvictMemoryToTieredStorageTest extends AnyFunSuite
+ with SparkTestBase
+ with BeforeAndAfterEach {
+
+ private var container: MinIOContainer = _;
+ private val seenPartitionLocationsOpenReader:
CopyOnWriteArrayList[PartitionLocation] =
+ new CopyOnWriteArrayList[PartitionLocation]()
+ private val seenPartitionLocationsUpdateFileGroups:
CopyOnWriteArrayList[PartitionLocation] =
+ new CopyOnWriteArrayList[PartitionLocation]()
+
+ override def beforeAll(): Unit = {
+
+ if (!isS3LibraryAvailable)
+ return
+
+ container = new MinIOContainer("minio/minio:RELEASE.2023-09-04T19-57-37Z");
+ container.start()
+
+ // create bucket using Minio command line tool
+ container.execInContainer(
+ "mc",
+ "alias",
+ "set",
+ "dockerminio",
+ "http://minio:9000",
+ container.getUserName,
+ container.getPassword)
+ container.execInContainer("mc", "mb", "dockerminio/sample-bucket")
+
+ System.setProperty("aws.accessKeyId", container.getUserName)
+ System.setProperty("aws.secretKey", container.getPassword)
+
+ val s3url = container.getS3URL
+ val augmentedConfiguration = Map(
+ CelebornConf.ACTIVE_STORAGE_TYPES.key -> "MEMORY,HDD,S3",
+ CelebornConf.WORKER_STORAGE_CREATE_FILE_POLICY.key -> "MEMORY,HDD,S3",
+ // CelebornConf.WORKER_STORAGE_EVICT_POLICY.key -> "MEMORY,S3",
+      // Note that in S3 (and MinIO) you cannot upload parts smaller than 5 MB, so we trigger
+      // eviction only when there is enough data.
+ CelebornConf.WORKER_MEMORY_FILE_STORAGE_MAX_FILE_SIZE.key -> "5MB",
+ "celeborn.worker.directMemoryRatioForMemoryFileStorage" -> "0.2", //
this is needed to use MEMORY storage
+ "celeborn.hadoop.fs.s3a.endpoint" -> s"$s3url",
+ "celeborn.hadoop.fs.s3a.aws.credentials.provider" ->
"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider",
+ "celeborn.hadoop.fs.s3a.access.key" -> container.getUserName,
+ "celeborn.hadoop.fs.s3a.secret.key" -> container.getPassword,
+ "celeborn.hadoop.fs.s3a.path.style.access" -> "true",
+ CelebornConf.S3_DIR.key -> "s3://sample-bucket/test/celeborn",
+ CelebornConf.S3_ENDPOINT_REGION.key -> "dummy-region")
+
+ setupMiniClusterWithRandomPorts(
+ masterConf = augmentedConfiguration,
+ workerConf = augmentedConfiguration,
+ workerNum = 1)
+
+ interceptLocationsSeenByClient()
+ }
+
+ override def beforeEach(): Unit = {
+ ShuffleClient.reset()
+ seenPartitionLocationsOpenReader.clear()
+ seenPartitionLocationsUpdateFileGroups.clear()
+ }
+
+  override def afterAll(): Unit = {
+    System.clearProperty("aws.accessKeyId")
+    System.clearProperty("aws.secretKey")
+    if (container != null) {
+      container.close()
+    }
+    super.afterAll() // always run base cleanup, even when MinIO never started
+    ShuffleManagerSpy.resetHook()
+  }
+
+ def updateSparkConfWithStorageTypes(
+ sparkConf: SparkConf,
+ mode: ShuffleMode,
+ storageTypes: String): SparkConf = {
+ val s3url = container.getS3URL
+ val newConf = sparkConf
+ .set("spark." + CelebornConf.ACTIVE_STORAGE_TYPES.key, storageTypes)
+ .set("spark." + CelebornConf.S3_DIR.key,
"s3://sample-bucket/test/celeborn")
+ .set("spark." + CelebornConf.S3_ENDPOINT_REGION.key, "dummy-region")
+ .set(
+ "spark." + CelebornConf.SHUFFLE_COMPRESSION_CODEC.key,
+ "none"
+ ) // we want predictable shuffle data size
+ .set("spark.celeborn.hadoop.fs.s3a.endpoint", s"$s3url")
+ .set(
+ "spark.celeborn.hadoop.fs.s3a.aws.credentials.provider",
+ "org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider")
+ .set("spark.celeborn.hadoop.fs.s3a.access.key", container.getUserName)
+ .set("spark.celeborn.hadoop.fs.s3a.secret.key", container.getPassword)
+ .set("spark.celeborn.hadoop.fs.s3a.path.style.access", "true")
+
+ super.updateSparkConf(newConf, mode)
+
+ sparkConf.set("spark.shuffle.manager",
"org.apache.spark.shuffle.celeborn.ShuffleManagerSpy")
+ }
+
+  def assumeS3LibraryIsLoaded(): Unit = {
+    assume(
+      isS3LibraryAvailable,
+      "Skipping test because AWS Hadoop client is not in the classpath (enable with -Paws)")
+  }
+
+ test("celeborn spark integration test - only memory") {
+ assumeS3LibraryIsLoaded()
+
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConfWithStorageTypes(sparkConf, ShuffleMode.HASH,
"MEMORY"))
+ .getOrCreate()
+ repartition(celebornSparkSession, partitions = 1)
+ // MEMORY partitions are not seen when opening the reader, but they are
seen when discovering the actual locations
+ validateLocationTypesSeenByClient(Type.MEMORY, 0, 2)
+ celebornSparkSession.stop()
+ }
+
+ test("celeborn spark integration test - only s3") {
+ assumeS3LibraryIsLoaded()
+
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConfWithStorageTypes(sparkConf, ShuffleMode.HASH,
"S3"))
+ .getOrCreate()
+
+ repartition(celebornSparkSession, partitions = 1)
+ validateLocationTypesSeenByClient(Type.S3, 2, 2)
+ celebornSparkSession.stop()
+ }
+
+ test("celeborn spark integration test - memory does not evict to s3") {
+ assumeS3LibraryIsLoaded()
+
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConfWithStorageTypes(sparkConf, ShuffleMode.HASH,
"MEMORY,S3"))
+ .getOrCreate()
+
+    // Small data set: no eviction to S3 should happen.
+ repartition(celebornSparkSession, partitions = 1)
+ validateLocationTypesSeenByClient(Type.MEMORY, 0, 2)
+ celebornSparkSession.stop()
+ }
+
+ test("celeborn spark integration test - memory evict to s3") {
+ assumeS3LibraryIsLoaded()
+
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConfWithStorageTypes(sparkConf, ShuffleMode.HASH,
"MEMORY,S3"))
+ .getOrCreate()
+
+ // we need to write enough to trigger eviction from MEMORY to S3
+ // we want the partition to not fit the memory storage
+ val sampleSeq: immutable.Seq[(String, Int)] = buildBigDataSet
+
+ repartition(celebornSparkSession, sequence = sampleSeq, partitions = 1)
+ validateLocationTypesSeenByClient(Type.S3, 2, 2)
+ celebornSparkSession.stop()
+ }
+
+ test("celeborn spark integration test - push fails no way of evicting") {
+ assumeS3LibraryIsLoaded()
+
+ val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
+ val celebornSparkSession = SparkSession.builder()
+ .config(updateSparkConfWithStorageTypes(sparkConf, ShuffleMode.HASH,
"MEMORY"))
+ .getOrCreate()
+
+ val sampleSeq: immutable.Seq[(String, Int)] = buildBigDataSet
+
+ // we want the partition to not fit the memory storage, the job fails
+ assertThrows[SparkException](
+ repartition(celebornSparkSession, sequence = sampleSeq, partitions = 1))
+
+ celebornSparkSession.stop()
+ }
+
+ private def buildBigDataSet = {
+ val big1KBString: String = StringUtils.repeat(' ', 1024)
+ val partitionSize = 10 * 1024 * 1024
+ val numValues = partitionSize / big1KBString.length
+ // we need to write enough to trigger eviction from MEMORY to S3
+ val sampleSeq: immutable.Seq[(String, Int)] = (1 to numValues)
+ .map(i => big1KBString + i) // all different keys
+ .toList
+ .map(v => (v.toUpperCase, Random.nextInt(12) + 1))
+ sampleSeq
+ }
+
+ def interceptLocationsSeenByClient(): Unit = {
+ val worker = getOneWorker()
+ ShuffleManagerSpy.interceptOpenShuffleReader(
+ new OpenShuffleReaderCallback {
+ override def accept(
+ appId: String,
+ shuffleId: java.lang.Integer,
+ client: ShuffleClient,
+ startPartition: java.lang.Integer,
+ endPartition: java.lang.Integer): Unit = {
+ logInfo(
+ s"Open Shuffle Reader for App $appId shuffleId $shuffleId
locations ${worker.controller.partitionLocationInfo.primaryPartitionLocations}")
+ val locations =
worker.controller.partitionLocationInfo.primaryPartitionLocations.get(
+ appId + "-" + shuffleId)
+ logInfo(s"Locations on openReader $locations")
+ seenPartitionLocationsOpenReader.addAll(locations.values());
+
+ val partitionIdList = List.range(startPartition.intValue(),
endPartition.intValue())
+ partitionIdList.foreach(partitionId => {
+ val fileGroups = client.updateFileGroup(shuffleId, partitionId)
+ val locationsForPartition =
fileGroups.partitionGroups.get(partitionId)
+ logInfo(s"locationsForPartition $partitionId
$locationsForPartition")
+
seenPartitionLocationsUpdateFileGroups.addAll(locationsForPartition)
+ })
+ }
+ },
+ worker.conf)
+ }
+
+ def validateLocationTypesSeenByClient(
+ storageType: Type,
+ numberAtOpenReader: Int,
+ numberAfterUpdateFileGroups: Int): Unit = {
+ seenPartitionLocationsOpenReader.asScala.foreach(location => {
+ assert(location.getStorageInfo.getType == storageType)
+      // The filePath is an empty string for MEMORY and S3 at this stage.
      assert(location.getStorageInfo.getFilePath == "")
+ })
+ seenPartitionLocationsUpdateFileGroups.asScala.foreach(location => {
+ assert(location.getStorageInfo.getType == storageType)
+ // at this stage for S3 the reader must know the URI
+ if (storageType == Type.S3)
+ assert(location.getStorageInfo.getFilePath startsWith "s3://")
+ })
+ assert(seenPartitionLocationsOpenReader.size == numberAtOpenReader)
+ assert(seenPartitionLocationsUpdateFileGroups.size ==
numberAfterUpdateFileGroups)
+ }
+
+}
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/S3TieredStorageTest.scala
similarity index 84%
rename from
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
rename to
tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/S3TieredStorageTest.scala
index 0b2ea3646..9d1ebbacd 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/s3/BasicEndToEndTieredStorageTest.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/S3TieredStorageTest.scala
@@ -27,25 +27,15 @@ import org.apache.celeborn.client.ShuffleClient
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.protocol.ShuffleMode
-class BasicEndToEndTieredStorageTest extends AnyFunSuite
+class S3TieredStorageTest extends AnyFunSuite
with SparkTestBase
with BeforeAndAfterEach {
- var container: MinIOContainer = null;
- val skipAWSTest = !isClassPresent("org.apache.hadoop.fs.s3a.S3AFileSystem")
-
- def isClassPresent(className: String): Boolean = {
- try {
- Class.forName(className)
- true
- } catch {
- case _: ClassNotFoundException => false
- }
- }
+ var container: MinIOContainer = _
override def beforeAll(): Unit = {
- if (skipAWSTest)
+ if (!isS3LibraryAvailable)
return
container = new MinIOContainer("minio/minio:RELEASE.2023-09-04T19-57-37Z");
@@ -67,9 +57,9 @@ class BasicEndToEndTieredStorageTest extends AnyFunSuite
val s3url = container.getS3URL
val augmentedConfiguration = Map(
- CelebornConf.ACTIVE_STORAGE_TYPES.key -> "MEMORY,S3",
- CelebornConf.WORKER_STORAGE_CREATE_FILE_POLICY.key -> "MEMORY,S3",
- CelebornConf.WORKER_STORAGE_EVICT_POLICY.key -> "MEMORY|S3",
+ CelebornConf.ACTIVE_STORAGE_TYPES.key -> "S3",
+ CelebornConf.WORKER_STORAGE_CREATE_FILE_POLICY.key -> "S3",
+ CelebornConf.WORKER_STORAGE_EVICT_POLICY.key -> "S3",
"celeborn.hadoop.fs.s3a.endpoint" -> s"$s3url",
"celeborn.hadoop.fs.s3a.aws.credentials.provider" ->
"org.apache.hadoop.fs.s3a.SimpleAWSCredentialsProvider",
"celeborn.hadoop.fs.s3a.access.key" -> container.getUserName,
@@ -100,7 +90,7 @@ class BasicEndToEndTieredStorageTest extends AnyFunSuite
override def updateSparkConf(sparkConf: SparkConf, mode: ShuffleMode):
SparkConf = {
val s3url = container.getS3URL
val newConf = sparkConf
- .set("spark." + CelebornConf.ACTIVE_STORAGE_TYPES.key, "MEMORY,S3")
+ .set("spark." + CelebornConf.ACTIVE_STORAGE_TYPES.key, "S3")
.set("spark." + CelebornConf.S3_DIR.key,
"s3://sample-bucket/test/celeborn")
.set("spark." + CelebornConf.S3_ENDPOINT_REGION.key, "dummy-region")
.set("spark.celeborn.hadoop.fs.s3a.endpoint", s"$s3url")
@@ -116,11 +106,9 @@ class BasicEndToEndTieredStorageTest extends AnyFunSuite
test("celeborn spark integration test - s3") {
assume(
- !skipAWSTest,
- "Skipping test because AWS Hadoop client is not in the classpath (enable
with -Paws")
+ isS3LibraryAvailable,
+ "Skipping test because AWS Hadoop client is not in the classpath (enable
with -Paws)")
- val s3url = container.getS3URL
- log.info(s"s3url $s3url");
val sparkConf = new
SparkConf().setAppName("celeborn-demo").setMaster("local[2]")
val celebornSparkSession = SparkSession.builder()
.config(updateSparkConf(sparkConf, ShuffleMode.HASH))
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
index dd58934ac..c857bd67b 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/SparkTestBase.scala
@@ -17,6 +17,7 @@
package org.apache.celeborn.tests.spark
+import scala.collection.immutable
import scala.util.Random
import org.apache.spark.{SPARK_VERSION, SparkConf}
@@ -25,6 +26,7 @@ import org.apache.spark.sql.internal.SQLConf
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
import org.scalatest.funsuite.AnyFunSuite
+import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.CelebornConf._
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.protocol.ShuffleMode
@@ -37,7 +39,9 @@ trait SparkTestBase extends AnyFunSuite
val Spark3OrNewer = SPARK_VERSION >= "3.0"
println(s"Spark version is $SPARK_VERSION, Spark3OrNewer: $Spark3OrNewer")
- private val sampleSeq = (1 to 78)
+ val isS3LibraryAvailable =
isClassPresent("org.apache.hadoop.fs.s3a.S3AFileSystem")
+
+ private val sampleSeq: immutable.Seq[(Char, Int)] = (1 to 78)
.map(Random.alphanumeric)
.toList
.map(v => (v.toUpper, Random.nextInt(12) + 1))
@@ -54,6 +58,10 @@ trait SparkTestBase extends AnyFunSuite
var workerDirs: Seq[String] = Seq.empty
+ def getOneWorker(): Worker = {
+ workerInfos.head._1
+ }
+
override def createWorker(map: Map[String, String]): Worker = {
val storageDir = createTmpDir()
this.synchronized {
@@ -89,9 +97,18 @@ trait SparkTestBase extends AnyFunSuite
resultWithOutCeleborn
}
- def repartition(sparkSession: SparkSession): collection.Map[Char, Int] = {
- val inputRdd = sparkSession.sparkContext.parallelize(sampleSeq, 2)
- val result = inputRdd.repartition(8).reduceByKey((acc, v) => acc +
v).collectAsMap()
+ def repartition(
+ sparkSession: SparkSession,
+ sequence: immutable.Seq[(Any, Int)] = sampleSeq,
+ parallelism: Integer = 2,
+ partitions: Integer = 8,
+ additionalFilter: (Any) => Boolean = (a) => true): collection.Map[Any,
Int] = {
+ val inputRdd = sparkSession.sparkContext.parallelize(sequence, parallelism)
+ val result = inputRdd
+ .repartition(partitions)
+ .reduceByKey((acc, v) => acc + v)
+ .filter(additionalFilter)
+ .collectAsMap()
result
}
@@ -109,4 +126,14 @@ trait SparkTestBase extends AnyFunSuite
val outMap = result.collect().map(row => row.getString(0) ->
row.getLong(1)).toMap
outMap
}
+
+ def isClassPresent(className: String): Boolean = {
+ try {
+ Class.forName(className)
+ true
+ } catch {
+ case _: ClassNotFoundException => false
+ }
+ }
+
}
diff --git
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
index 16fe3921f..10bf94455 100644
---
a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
+++
b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/storage/TierWriter.scala
@@ -298,13 +298,14 @@ class MemoryTierWriter(
override def evict(file: TierWriterBase): Unit = {
flushLock.synchronized {
+ val numBytes = flushBuffer.readableBytes()
+ logDebug(s"Evict ${Utils.bytesToString(
+ numBytes)} from memory to other tier ${file.filename} on
${file.storageType} for ${file.shuffleKey}")
// swap tier writer's flush buffer to memory tier writer's
// and handle its release
file.swapFlushBuffer(flushBuffer)
// close memory file writer after evict happened
file.flush(false, true)
- val numBytes = flushBuffer.readableBytes()
- logDebug(s"Evict $numBytes from memory to other tier")
MemoryManager.instance.releaseMemoryFileStorage(numBytes)
MemoryManager.instance.incrementDiskBuffer(numBytes)
storageManager.unregisterMemoryPartitionWriterAndFileInfo(fileInfo,
shuffleKey, filename)