This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/branch-0.3 by this push:
new 5e6ce1664 [CELEBORN-659][SPARK][TEST] Refine RssShuffleWriterSuiteJ
5e6ce1664 is described below
commit 5e6ce1664ee882537622729267b3d43648fd0751
Author: Fu Chen <[email protected]>
AuthorDate: Mon Jun 12 13:48:52 2023 +0800
[CELEBORN-659][SPARK][TEST] Refine RssShuffleWriterSuiteJ
### What changes were proposed in this pull request?
1. Renamed `RssShuffleWriterSuiteJ` to `CelebornShuffleWriterSuiteBase`,
which now serves as an abstract base class.
2. Added two new classes, `HashBasedShuffleWriterSuiteJ` and
`SortBasedShuffleWriterSuiteJ`. These classes extend
`CelebornShuffleWriterSuiteBase` and provide suites for testing hash-based and
sort-based shuffle writers.
### Why are the changes needed?
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
Closes #1570 from cfmcgrady/sort-based-writer-suite.
Authored-by: Fu Chen <[email protected]>
Signed-off-by: zky.zhoukeyong <[email protected]>
(cherry picked from commit cc716506f9ca882539b7e875b629735ee19327e1)
Signed-off-by: zky.zhoukeyong <[email protected]>
---
...eJ.java => CelebornShuffleWriterSuiteBase.java} | 233 ++++++++-------------
.../celeborn/HashBasedShuffleWriterSuiteJ.java | 40 ++++
.../celeborn/SortBasedShuffleWriterSuiteJ.java | 39 ++++
.../shuffle/celeborn/SortBasedShuffleWriter.java | 19 ++
...eJ.java => CelebornShuffleWriterSuiteBase.java} | 39 ++--
.../celeborn/HashBasedShuffleWriterSuiteJ.java | 42 ++++
.../celeborn/SortBasedShuffleWriterSuiteJ.java | 41 ++++
7 files changed, 291 insertions(+), 162 deletions(-)
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
similarity index 64%
rename from
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
rename to
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index 33f3ef170..103cdc129 100644
---
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -50,6 +50,7 @@ import org.apache.spark.memory.UnifiedMemoryManager;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -79,9 +80,9 @@ import org.apache.celeborn.common.identity.UserIdentifier;
import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
-public class RssShuffleWriterSuiteJ {
+public abstract class CelebornShuffleWriterSuiteBase {
- private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleWriterSuiteJ.class);
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornShuffleWriterSuiteBase.class);
private static final String NORMAL_RECORD = "hello, world";
private static final String GIANT_RECORD = getGiantRecord();
@@ -90,19 +91,17 @@ public class RssShuffleWriterSuiteJ {
private final String appId = "appId";
private final String host = "host";
private final int port = 0;
+ private final int shuffleId = 0;
private final UserIdentifier userIdentifier = new UserIdentifier("mock",
"mock");
- private final int shuffleId = 0;
private final int numMaps = 10;
- private final int numPartitions = 10;
- private final int mapId = 0;
+ protected final int numPartitions = 10;
private final SparkConf sparkConf = new SparkConf(false);
private final BlockManagerId bmId = BlockManagerId.apply("execId", "host",
1, None$.empty());
- private final UnifiedMemoryManager memoryManager =
UnifiedMemoryManager.apply(sparkConf, 1);
- private final TaskMemoryManager taskMemoryManager = new
TaskMemoryManager(memoryManager, 0);
-
+ private final TaskMemoryManager taskMemoryManager =
+ new TaskMemoryManager(UnifiedMemoryManager.apply(sparkConf, 1), 0);
private final MapStatus mapStatus =
SparkUtils.createMapStatus(bmId, new long[numPartitions], new
long[numPartitions]);
@@ -122,11 +121,11 @@ public class RssShuffleWriterSuiteJ {
private static File tempDir = null;
- public RssShuffleWriterSuiteJ() throws IOException {}
+ public CelebornShuffleWriterSuiteBase() throws IOException {}
@BeforeClass
public static void beforeAll() {
- tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
"rss_test");
+ tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"),
"celeborn_test");
}
@AfterClass
@@ -145,11 +144,11 @@ public class RssShuffleWriterSuiteJ {
Mockito.doReturn(shuffleId).when(dependency).shuffleId();
Mockito.doReturn(metrics).when(taskContext).taskMetrics();
+ Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
Mockito.doReturn(bmId).when(blockManager).shuffleServerId();
Mockito.doReturn(blockManager).when(env).blockManager();
Mockito.doReturn(sparkConf).when(env).conf();
- Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
SparkEnv.set(env);
}
@@ -157,16 +156,14 @@ public class RssShuffleWriterSuiteJ {
public void testEmptyBlock() throws Exception {
final KryoSerializer serializer = new KryoSerializer(sparkConf);
final CelebornConf conf = new CelebornConf();
- check(0, conf, serializer, true);
- check(0, conf, serializer, false);
+ check(0, conf, serializer);
}
@Test
public void testEmptyBlockWithFastWrite() throws Exception {
final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
final CelebornConf conf = new CelebornConf();
- check(0, conf, serializer, true);
- check(0, conf, serializer, false);
+ check(0, conf, serializer);
}
@Test
@@ -174,8 +171,7 @@ public class RssShuffleWriterSuiteJ {
final KryoSerializer serializer = new KryoSerializer(sparkConf);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "1024");
- check(10000, conf, serializer, true);
- check(10000, conf, serializer, false);
+ check(10000, conf, serializer);
}
@Test
@@ -183,8 +179,7 @@ public class RssShuffleWriterSuiteJ {
final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "1024");
- check(10000, conf, serializer, true);
- check(10000, conf, serializer, false);
+ check(10000, conf, serializer);
}
@Test
@@ -192,8 +187,7 @@ public class RssShuffleWriterSuiteJ {
final KryoSerializer serializer = new KryoSerializer(sparkConf);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "5");
- check(10000, conf, serializer, true);
- check(10000, conf, serializer, false);
+ check(10000, conf, serializer);
}
@Test
@@ -201,8 +195,7 @@ public class RssShuffleWriterSuiteJ {
final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "5");
- check(10000, conf, serializer, true);
- check(10000, conf, serializer, false);
+ check(10000, conf, serializer);
}
@Test
@@ -210,8 +203,7 @@ public class RssShuffleWriterSuiteJ {
final KryoSerializer serializer = new KryoSerializer(sparkConf);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "128");
- check(2 << 30, conf, serializer, true);
- check(2 << 30, conf, serializer, false);
+ check(2 << 30, conf, serializer);
}
@Test
@@ -219,15 +211,11 @@ public class RssShuffleWriterSuiteJ {
final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
final CelebornConf conf =
new
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "128");
- check(2 << 30, conf, serializer, true);
- check(2 << 30, conf, serializer, false);
+ check(2 << 30, conf, serializer);
}
private void check(
- final int approximateSize,
- final CelebornConf conf,
- final Serializer serializer,
- final boolean hashWriter)
+ final int approximateSize, final CelebornConf conf, final Serializer
serializer)
throws Exception {
final boolean useUnsafe = serializer instanceof UnsafeRowSerializer;
@@ -242,134 +230,75 @@ public class RssShuffleWriterSuiteJ {
final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
((DummyShuffleClient) client).initReducePartitionMap(shuffleId,
numPartitions, 1);
- if (hashWriter) {
- final HashBasedShuffleWriter<Integer, String, String> writer =
- new HashBasedShuffleWriter<>(
- handle, mapId, taskContext, conf, client, SendBufferPool.get(1));
+ final ShuffleWriter<Integer, String> writer =
+ createShuffleWriter(handle, taskContext, conf, client);
+
+ if (writer instanceof SortBasedShuffleWriter) {
+ assertEquals(useUnsafe, ((SortBasedShuffleWriter)
writer).canUseFastWrite());
+ } else if (writer instanceof HashBasedShuffleWriter) {
+ assertEquals(useUnsafe, ((HashBasedShuffleWriter)
writer).canUseFastWrite());
+ }
- AtomicInteger total = new AtomicInteger(0);
- Iterator iterator = getIterator(approximateSize, total, useUnsafe,
false);
+ AtomicInteger total = new AtomicInteger(0);
+ Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
- int expectChecksum = 0;
- for (int i = 0; i < total.intValue(); ++i) {
- expectChecksum ^= i;
- }
+ int expectChecksum = 0;
+ for (int i = 0; i < total.intValue(); ++i) {
+ expectChecksum ^= i;
+ }
- writer.write(iterator);
- Option<MapStatus> status = writer.stop(true);
- client.shutdown();
-
- assertNotNull(status);
- assertTrue(status.isDefined());
- assertEquals(bmId, status.get().location());
-
- ShuffleWriteMetrics metrics =
taskContext.taskMetrics().shuffleWriteMetrics();
- assertEquals(metrics.recordsWritten(), total.intValue());
- assertEquals(metrics.bytesWritten(), tempFile.length());
-
- try (FileInputStream fis = new FileInputStream(tempFile)) {
- Iterator it =
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
- int checksum = 0;
- while (it.hasNext()) {
- Product2<Integer, ?> record;
- if (useUnsafe) {
- record = (Product2<Integer, UnsafeRow>) it.next();
- } else {
- record = (Product2<Integer, String>) it.next();
- }
-
- assertNotNull(record);
- assertNotNull(record._1());
- assertNotNull(record._2());
-
- int key;
- String value;
-
- if (useUnsafe) {
- UnsafeRow row = (UnsafeRow) record._2();
-
- key = row.getInt(0);
- value = row.getString(1);
- } else {
- key = record._1();
- value = (String) record._2();
- }
-
- checksum ^= key;
- total.decrementAndGet();
-
- assertTrue(
- "value should equals to normal record or giant record with key.",
- value.equals(key + ": " + NORMAL_RECORD) || value.equals(key +
": " + GIANT_RECORD));
+ writer.write(iterator);
+ Option<MapStatus> status = writer.stop(true);
+ client.shutdown();
+
+ assertNotNull(status);
+ assertTrue(status.isDefined());
+ assertEquals(bmId, status.get().location());
+
+ ShuffleWriteMetrics metrics =
taskContext.taskMetrics().shuffleWriteMetrics();
+ assertEquals(metrics.recordsWritten(), total.intValue());
+ assertEquals(metrics.bytesWritten(), tempFile.length());
+
+ try (FileInputStream fis = new FileInputStream(tempFile)) {
+ Iterator it =
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
+ int checksum = 0;
+ while (it.hasNext()) {
+ Product2<Integer, ?> record;
+ if (useUnsafe) {
+ record = (Product2<Integer, UnsafeRow>) it.next();
+ } else {
+ record = (Product2<Integer, String>) it.next();
}
- assertEquals(0, total.intValue());
- assertEquals(expectChecksum, checksum);
- } catch (Exception e) {
- e.printStackTrace();
- fail("Should read with no exception.");
- }
- } else {
- final SortBasedShuffleWriter<Integer, String, String> writer =
- new SortBasedShuffleWriter<>(
- dependency, appId, numPartitions, taskContext, conf, client,
null);
- AtomicInteger total = new AtomicInteger(0);
- Iterator iterator = getIterator(approximateSize, total, useUnsafe,
false);
+ assertNotNull(record);
+ assertNotNull(record._1());
+ assertNotNull(record._2());
- int expectChecksum = 0;
- for (int i = 0; i < total.intValue(); ++i) {
- expectChecksum ^= i;
- }
+ int key;
+ String value;
- writer.write(iterator);
- Option<MapStatus> status = writer.stop(true);
- client.shutdown();
-
- assertNotNull(status);
- assertTrue(status.isDefined());
- assertEquals(bmId, status.get().location());
-
- try (FileInputStream fis = new FileInputStream(tempFile)) {
- Iterator it =
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
- int checksum = 0;
- while (it.hasNext()) {
- Product2<Integer, ?> record;
- if (useUnsafe) {
- record = (Product2<Integer, UnsafeRow>) it.next();
- } else {
- record = (Product2<Integer, String>) it.next();
- }
-
- assertNotNull(record);
- assertNotNull(record._1());
- assertNotNull(record._2());
-
- int key;
- String value;
-
- if (useUnsafe) {
- UnsafeRow row = (UnsafeRow) record._2();
-
- key = row.getInt(0);
- value = row.getString(1);
- } else {
- key = record._1();
- value = (String) record._2();
- }
-
- checksum ^= key;
- total.decrementAndGet();
-
- assertTrue(
- "value should equals to normal record or giant record with key.",
- value.equals(key + ": " + NORMAL_RECORD) || value.equals(key +
": " + GIANT_RECORD));
+ if (useUnsafe) {
+ UnsafeRow row = (UnsafeRow) record._2();
+
+ key = row.getInt(0);
+ value = row.getString(1);
+ } else {
+ key = record._1();
+ value = (String) record._2();
}
- assertEquals(0, total.intValue());
- assertEquals(expectChecksum, checksum);
- } catch (Exception e) {
- e.printStackTrace();
- fail("Should read with no exception.");
+
+ checksum ^= key;
+ total.decrementAndGet();
+
+ assertTrue(
+ "value should equals to normal record or giant record with key.",
+ value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + ":
" + GIANT_RECORD));
}
+ assertEquals(0, total.intValue());
+ assertEquals(expectChecksum, checksum);
+ } catch (Exception e) {
+ e.printStackTrace();
+ fail("Should read with no exception.");
}
}
@@ -423,4 +352,8 @@ public class RssShuffleWriterSuiteJ {
int numCopies = (128 + NORMAL_RECORD.length() - 1) /
NORMAL_RECORD.length();
return String.join("/", Collections.nCopies(numCopies, NORMAL_RECORD));
}
+
+ protected abstract ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
+ throws IOException;
}
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..0a9a86362
--- /dev/null
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+
+ public HashBasedShuffleWriterSuiteJ() throws IOException {}
+
+ @Override
+ protected ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
+ throws IOException {
+ // this test case is independent of the `mapId` value
+ return new HashBasedShuffleWriter<Integer, String, String>(
+ handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1));
+ }
+}
diff --git
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..1d7274444
--- /dev/null
+++
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+
+ public SortBasedShuffleWriterSuiteJ() throws IOException {}
+
+ @Override
+ protected ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle, TaskContext context, CelebornConf conf,
ShuffleClient client)
+ throws IOException {
+ return new SortBasedShuffleWriter<Integer, String, String>(
+ handle.dependency(), handle.appUniqueId(), numPartitions, context,
conf, client, null);
+ }
+}
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 0cc7114cf..47d8eb41e 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -172,6 +172,25 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
}
}
+ public SortBasedShuffleWriter(
+ RssShuffleHandle<K, V, C> handle,
+ TaskContext taskContext,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics,
+ ExecutorService executorService)
+ throws IOException {
+ this(
+ handle.dependency(),
+ handle.appUniqueId(),
+ handle.numMappers(),
+ taskContext,
+ conf,
+ client,
+ metrics,
+ executorService);
+ }
+
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
if (canUseFastWrite()) {
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
similarity index 90%
rename from
client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
rename to
client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index 91d2e043c..6995ae663 100644
---
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -45,9 +45,13 @@ import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.UnifiedMemoryManager;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -77,9 +81,9 @@ import org.apache.celeborn.common.util.JavaUtils;
import org.apache.celeborn.common.util.Utils;
import org.apache.celeborn.reflect.DynConstructors;
-public class RssShuffleWriterSuiteJ {
+public abstract class CelebornShuffleWriterSuiteBase {
- private static final Logger LOG =
LoggerFactory.getLogger(RssShuffleWriterSuiteJ.class);
+ private static final Logger LOG =
LoggerFactory.getLogger(CelebornShuffleWriterSuiteBase.class);
private static final String NORMAL_RECORD = "hello, world";
private static final String GIANT_RECORD = getGiantRecord();
@@ -93,10 +97,13 @@ public class RssShuffleWriterSuiteJ {
private final UserIdentifier userIdentifier = new UserIdentifier("mock",
"mock");
private final int numMaps = 10;
- private final Integer numPartitions = 10;
+ private final int numPartitions = 10;
private final SparkConf sparkConf = new SparkConf(false);
private final BlockManagerId bmId = BlockManagerId.apply("execId", "host",
1, None$.empty());
+ private final TaskMemoryManager taskMemoryManager =
+ new TaskMemoryManager(UnifiedMemoryManager.apply(sparkConf, 1), 0);
+
@Mock(answer = Answers.RETURNS_SMART_NULLS)
private TaskContext taskContext = null;
@@ -134,6 +141,7 @@ public class RssShuffleWriterSuiteJ {
Mockito.doReturn(shuffleId).when(dependency).shuffleId();
Mockito.doReturn(metrics).when(taskContext).taskMetrics();
+ Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
Mockito.doReturn(bmId).when(blockManager).shuffleServerId();
Mockito.doReturn(blockManager).when(env).blockManager();
@@ -228,15 +236,14 @@ public class RssShuffleWriterSuiteJ {
final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
((DummyShuffleClient) client).initReducePartitionMap(shuffleId,
numPartitions, 1);
- final HashBasedShuffleWriter<Integer, String, String> writer =
- new HashBasedShuffleWriter<>(
- handle,
- taskContext,
- conf,
- client,
- metrics.shuffleWriteMetrics(),
- SendBufferPool.get(1));
- assertEquals(useUnsafe, writer.canUseFastWrite());
+ final ShuffleWriter<Integer, String> writer =
+ createShuffleWriter(handle, taskContext, conf, client,
metrics.shuffleWriteMetrics());
+
+ if (writer instanceof SortBasedShuffleWriter) {
+ assertEquals(useUnsafe, ((SortBasedShuffleWriter)
writer).canUseFastWrite());
+ } else if (writer instanceof HashBasedShuffleWriter) {
+ assertEquals(useUnsafe, ((HashBasedShuffleWriter)
writer).canUseFastWrite());
+ }
AtomicInteger total = new AtomicInteger(0);
Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
@@ -351,4 +358,12 @@ public class RssShuffleWriterSuiteJ {
int numCopies = (128 + NORMAL_RECORD.length() - 1) /
NORMAL_RECORD.length();
return String.join("/", Collections.nCopies(numCopies, NORMAL_RECORD));
}
+
+ protected abstract ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle,
+ TaskContext context,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics)
+ throws IOException;
}
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..d44960dd5
--- /dev/null
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class HashBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+
+ @Override
+ protected ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle,
+ TaskContext context,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics)
+ throws IOException {
+ return new HashBasedShuffleWriter<Integer, String, String>(
+ handle, context, conf, client, metrics, SendBufferPool.get(1));
+ }
+}
diff --git
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..96d5ed1c6
--- /dev/null
+++
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class SortBasedShuffleWriterSuiteJ extends
CelebornShuffleWriterSuiteBase {
+ @Override
+ protected ShuffleWriter<Integer, String> createShuffleWriter(
+ RssShuffleHandle handle,
+ TaskContext context,
+ CelebornConf conf,
+ ShuffleClient client,
+ ShuffleWriteMetricsReporter metrics)
+ throws IOException {
+ return new SortBasedShuffleWriter<Integer, String, String>(
+ handle, context, conf, client, metrics, null);
+ }
+}