This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch branch-0.3
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/branch-0.3 by this push:
     new 5e6ce1664 [CELEBORN-659][SPARK][TEST] Refine RssShuffleWriterSuiteJ
5e6ce1664 is described below

commit 5e6ce1664ee882537622729267b3d43648fd0751
Author: Fu Chen <[email protected]>
AuthorDate: Mon Jun 12 13:48:52 2023 +0800

    [CELEBORN-659][SPARK][TEST] Refine RssShuffleWriterSuiteJ
    
    ### What changes were proposed in this pull request?
    
    1. Renamed `RssShuffleWriterSuiteJ` to `CelebornShuffleWriterSuiteBase`, 
which now serves as an abstract base class.
    2. Added two new classes, `HashBasedShuffleWriterSuiteJ` and 
`SortBasedShuffleWriterSuiteJ`. These classes extend 
`CelebornShuffleWriterSuiteBase` and provide test suites for hash-based and 
sort-based shuffle writers.
    
    ### Why are the changes needed?
    
    ### Does this PR introduce _any_ user-facing change?
    
    ### How was this patch tested?
    
    Closes #1570 from cfmcgrady/sort-based-writer-suite.
    
    Authored-by: Fu Chen <[email protected]>
    Signed-off-by: zky.zhoukeyong <[email protected]>
    (cherry picked from commit cc716506f9ca882539b7e875b629735ee19327e1)
    Signed-off-by: zky.zhoukeyong <[email protected]>
---
 ...eJ.java => CelebornShuffleWriterSuiteBase.java} | 233 ++++++++-------------
 .../celeborn/HashBasedShuffleWriterSuiteJ.java     |  40 ++++
 .../celeborn/SortBasedShuffleWriterSuiteJ.java     |  39 ++++
 .../shuffle/celeborn/SortBasedShuffleWriter.java   |  19 ++
 ...eJ.java => CelebornShuffleWriterSuiteBase.java} |  39 ++--
 .../celeborn/HashBasedShuffleWriterSuiteJ.java     |  42 ++++
 .../celeborn/SortBasedShuffleWriterSuiteJ.java     |  41 ++++
 7 files changed, 291 insertions(+), 162 deletions(-)

diff --git 
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
similarity index 64%
rename from 
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
rename to 
client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index 33f3ef170..103cdc129 100644
--- 
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
+++ 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -50,6 +50,7 @@ import org.apache.spark.memory.UnifiedMemoryManager;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -79,9 +80,9 @@ import org.apache.celeborn.common.identity.UserIdentifier;
 import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.common.util.Utils;
 
-public class RssShuffleWriterSuiteJ {
+public abstract class CelebornShuffleWriterSuiteBase {
 
-  private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleWriterSuiteJ.class);
+  private static final Logger LOG = 
LoggerFactory.getLogger(CelebornShuffleWriterSuiteBase.class);
   private static final String NORMAL_RECORD = "hello, world";
   private static final String GIANT_RECORD = getGiantRecord();
 
@@ -90,19 +91,17 @@ public class RssShuffleWriterSuiteJ {
   private final String appId = "appId";
   private final String host = "host";
   private final int port = 0;
+  private final int shuffleId = 0;
 
   private final UserIdentifier userIdentifier = new UserIdentifier("mock", 
"mock");
 
-  private final int shuffleId = 0;
   private final int numMaps = 10;
-  private final int numPartitions = 10;
-  private final int mapId = 0;
+  protected final int numPartitions = 10;
   private final SparkConf sparkConf = new SparkConf(false);
   private final BlockManagerId bmId = BlockManagerId.apply("execId", "host", 
1, None$.empty());
 
-  private final UnifiedMemoryManager memoryManager = 
UnifiedMemoryManager.apply(sparkConf, 1);
-  private final TaskMemoryManager taskMemoryManager = new 
TaskMemoryManager(memoryManager, 0);
-
+  private final TaskMemoryManager taskMemoryManager =
+      new TaskMemoryManager(UnifiedMemoryManager.apply(sparkConf, 1), 0);
   private final MapStatus mapStatus =
       SparkUtils.createMapStatus(bmId, new long[numPartitions], new 
long[numPartitions]);
 
@@ -122,11 +121,11 @@ public class RssShuffleWriterSuiteJ {
 
   private static File tempDir = null;
 
-  public RssShuffleWriterSuiteJ() throws IOException {}
+  public CelebornShuffleWriterSuiteBase() throws IOException {}
 
   @BeforeClass
   public static void beforeAll() {
-    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"rss_test");
+    tempDir = Utils.createTempDir(System.getProperty("java.io.tmpdir"), 
"celeborn_test");
   }
 
   @AfterClass
@@ -145,11 +144,11 @@ public class RssShuffleWriterSuiteJ {
     Mockito.doReturn(shuffleId).when(dependency).shuffleId();
 
     Mockito.doReturn(metrics).when(taskContext).taskMetrics();
+    Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
 
     Mockito.doReturn(bmId).when(blockManager).shuffleServerId();
     Mockito.doReturn(blockManager).when(env).blockManager();
     Mockito.doReturn(sparkConf).when(env).conf();
-    Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
     SparkEnv.set(env);
   }
 
@@ -157,16 +156,14 @@ public class RssShuffleWriterSuiteJ {
   public void testEmptyBlock() throws Exception {
     final KryoSerializer serializer = new KryoSerializer(sparkConf);
     final CelebornConf conf = new CelebornConf();
-    check(0, conf, serializer, true);
-    check(0, conf, serializer, false);
+    check(0, conf, serializer);
   }
 
   @Test
   public void testEmptyBlockWithFastWrite() throws Exception {
     final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
     final CelebornConf conf = new CelebornConf();
-    check(0, conf, serializer, true);
-    check(0, conf, serializer, false);
+    check(0, conf, serializer);
   }
 
   @Test
@@ -174,8 +171,7 @@ public class RssShuffleWriterSuiteJ {
     final KryoSerializer serializer = new KryoSerializer(sparkConf);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "1024");
-    check(10000, conf, serializer, true);
-    check(10000, conf, serializer, false);
+    check(10000, conf, serializer);
   }
 
   @Test
@@ -183,8 +179,7 @@ public class RssShuffleWriterSuiteJ {
     final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "1024");
-    check(10000, conf, serializer, true);
-    check(10000, conf, serializer, false);
+    check(10000, conf, serializer);
   }
 
   @Test
@@ -192,8 +187,7 @@ public class RssShuffleWriterSuiteJ {
     final KryoSerializer serializer = new KryoSerializer(sparkConf);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "5");
-    check(10000, conf, serializer, true);
-    check(10000, conf, serializer, false);
+    check(10000, conf, serializer);
   }
 
   @Test
@@ -201,8 +195,7 @@ public class RssShuffleWriterSuiteJ {
     final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "5");
-    check(10000, conf, serializer, true);
-    check(10000, conf, serializer, false);
+    check(10000, conf, serializer);
   }
 
   @Test
@@ -210,8 +203,7 @@ public class RssShuffleWriterSuiteJ {
     final KryoSerializer serializer = new KryoSerializer(sparkConf);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "128");
-    check(2 << 30, conf, serializer, true);
-    check(2 << 30, conf, serializer, false);
+    check(2 << 30, conf, serializer);
   }
 
   @Test
@@ -219,15 +211,11 @@ public class RssShuffleWriterSuiteJ {
     final UnsafeRowSerializer serializer = new UnsafeRowSerializer(2, null);
     final CelebornConf conf =
         new 
CelebornConf().set(CelebornConf.CLIENT_PUSH_BUFFER_MAX_SIZE().key(), "128");
-    check(2 << 30, conf, serializer, true);
-    check(2 << 30, conf, serializer, false);
+    check(2 << 30, conf, serializer);
   }
 
   private void check(
-      final int approximateSize,
-      final CelebornConf conf,
-      final Serializer serializer,
-      final boolean hashWriter)
+      final int approximateSize, final CelebornConf conf, final Serializer 
serializer)
       throws Exception {
     final boolean useUnsafe = serializer instanceof UnsafeRowSerializer;
 
@@ -242,134 +230,75 @@ public class RssShuffleWriterSuiteJ {
     final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
     ((DummyShuffleClient) client).initReducePartitionMap(shuffleId, 
numPartitions, 1);
 
-    if (hashWriter) {
-      final HashBasedShuffleWriter<Integer, String, String> writer =
-          new HashBasedShuffleWriter<>(
-              handle, mapId, taskContext, conf, client, SendBufferPool.get(1));
+    final ShuffleWriter<Integer, String> writer =
+        createShuffleWriter(handle, taskContext, conf, client);
+
+    if (writer instanceof SortBasedShuffleWriter) {
+      assertEquals(useUnsafe, ((SortBasedShuffleWriter) 
writer).canUseFastWrite());
+    } else if (writer instanceof HashBasedShuffleWriter) {
+      assertEquals(useUnsafe, ((HashBasedShuffleWriter) 
writer).canUseFastWrite());
+    }
 
-      AtomicInteger total = new AtomicInteger(0);
-      Iterator iterator = getIterator(approximateSize, total, useUnsafe, 
false);
+    AtomicInteger total = new AtomicInteger(0);
+    Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
 
-      int expectChecksum = 0;
-      for (int i = 0; i < total.intValue(); ++i) {
-        expectChecksum ^= i;
-      }
+    int expectChecksum = 0;
+    for (int i = 0; i < total.intValue(); ++i) {
+      expectChecksum ^= i;
+    }
 
-      writer.write(iterator);
-      Option<MapStatus> status = writer.stop(true);
-      client.shutdown();
-
-      assertNotNull(status);
-      assertTrue(status.isDefined());
-      assertEquals(bmId, status.get().location());
-
-      ShuffleWriteMetrics metrics = 
taskContext.taskMetrics().shuffleWriteMetrics();
-      assertEquals(metrics.recordsWritten(), total.intValue());
-      assertEquals(metrics.bytesWritten(), tempFile.length());
-
-      try (FileInputStream fis = new FileInputStream(tempFile)) {
-        Iterator it = 
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
-        int checksum = 0;
-        while (it.hasNext()) {
-          Product2<Integer, ?> record;
-          if (useUnsafe) {
-            record = (Product2<Integer, UnsafeRow>) it.next();
-          } else {
-            record = (Product2<Integer, String>) it.next();
-          }
-
-          assertNotNull(record);
-          assertNotNull(record._1());
-          assertNotNull(record._2());
-
-          int key;
-          String value;
-
-          if (useUnsafe) {
-            UnsafeRow row = (UnsafeRow) record._2();
-
-            key = row.getInt(0);
-            value = row.getString(1);
-          } else {
-            key = record._1();
-            value = (String) record._2();
-          }
-
-          checksum ^= key;
-          total.decrementAndGet();
-
-          assertTrue(
-              "value should equals to normal record or giant record with key.",
-              value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + 
": " + GIANT_RECORD));
+    writer.write(iterator);
+    Option<MapStatus> status = writer.stop(true);
+    client.shutdown();
+
+    assertNotNull(status);
+    assertTrue(status.isDefined());
+    assertEquals(bmId, status.get().location());
+
+    ShuffleWriteMetrics metrics = 
taskContext.taskMetrics().shuffleWriteMetrics();
+    assertEquals(metrics.recordsWritten(), total.intValue());
+    assertEquals(metrics.bytesWritten(), tempFile.length());
+
+    try (FileInputStream fis = new FileInputStream(tempFile)) {
+      Iterator it = 
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
+      int checksum = 0;
+      while (it.hasNext()) {
+        Product2<Integer, ?> record;
+        if (useUnsafe) {
+          record = (Product2<Integer, UnsafeRow>) it.next();
+        } else {
+          record = (Product2<Integer, String>) it.next();
         }
-        assertEquals(0, total.intValue());
-        assertEquals(expectChecksum, checksum);
-      } catch (Exception e) {
-        e.printStackTrace();
-        fail("Should read with no exception.");
-      }
-    } else {
-      final SortBasedShuffleWriter<Integer, String, String> writer =
-          new SortBasedShuffleWriter<>(
-              dependency, appId, numPartitions, taskContext, conf, client, 
null);
 
-      AtomicInteger total = new AtomicInteger(0);
-      Iterator iterator = getIterator(approximateSize, total, useUnsafe, 
false);
+        assertNotNull(record);
+        assertNotNull(record._1());
+        assertNotNull(record._2());
 
-      int expectChecksum = 0;
-      for (int i = 0; i < total.intValue(); ++i) {
-        expectChecksum ^= i;
-      }
+        int key;
+        String value;
 
-      writer.write(iterator);
-      Option<MapStatus> status = writer.stop(true);
-      client.shutdown();
-
-      assertNotNull(status);
-      assertTrue(status.isDefined());
-      assertEquals(bmId, status.get().location());
-
-      try (FileInputStream fis = new FileInputStream(tempFile)) {
-        Iterator it = 
serializer.newInstance().deserializeStream(fis).asKeyValueIterator();
-        int checksum = 0;
-        while (it.hasNext()) {
-          Product2<Integer, ?> record;
-          if (useUnsafe) {
-            record = (Product2<Integer, UnsafeRow>) it.next();
-          } else {
-            record = (Product2<Integer, String>) it.next();
-          }
-
-          assertNotNull(record);
-          assertNotNull(record._1());
-          assertNotNull(record._2());
-
-          int key;
-          String value;
-
-          if (useUnsafe) {
-            UnsafeRow row = (UnsafeRow) record._2();
-
-            key = row.getInt(0);
-            value = row.getString(1);
-          } else {
-            key = record._1();
-            value = (String) record._2();
-          }
-
-          checksum ^= key;
-          total.decrementAndGet();
-
-          assertTrue(
-              "value should equals to normal record or giant record with key.",
-              value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + 
": " + GIANT_RECORD));
+        if (useUnsafe) {
+          UnsafeRow row = (UnsafeRow) record._2();
+
+          key = row.getInt(0);
+          value = row.getString(1);
+        } else {
+          key = record._1();
+          value = (String) record._2();
         }
-        assertEquals(0, total.intValue());
-        assertEquals(expectChecksum, checksum);
-      } catch (Exception e) {
-        e.printStackTrace();
-        fail("Should read with no exception.");
+
+        checksum ^= key;
+        total.decrementAndGet();
+
+        assertTrue(
+            "value should equals to normal record or giant record with key.",
+            value.equals(key + ": " + NORMAL_RECORD) || value.equals(key + ": 
" + GIANT_RECORD));
       }
+      assertEquals(0, total.intValue());
+      assertEquals(expectChecksum, checksum);
+    } catch (Exception e) {
+      e.printStackTrace();
+      fail("Should read with no exception.");
     }
   }
 
@@ -423,4 +352,8 @@ public class RssShuffleWriterSuiteJ {
     int numCopies = (128 + NORMAL_RECORD.length() - 1) / 
NORMAL_RECORD.length();
     return String.join("/", Collections.nCopies(numCopies, NORMAL_RECORD));
   }
+
+  protected abstract ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle, TaskContext context, CelebornConf conf, 
ShuffleClient client)
+      throws IOException;
 }
diff --git 
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..0a9a86362
--- /dev/null
+++ 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class HashBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase {
+
+  public HashBasedShuffleWriterSuiteJ() throws IOException {}
+
+  @Override
+  protected ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle, TaskContext context, CelebornConf conf, 
ShuffleClient client)
+      throws IOException {
+    // this test case is independent of the `mapId` value
+    return new HashBasedShuffleWriter<Integer, String, String>(
+        handle, /*mapId=*/ 0, context, conf, client, SendBufferPool.get(1));
+  }
+}
diff --git 
a/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..1d7274444
--- /dev/null
+++ 
b/client-spark/spark-2/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class SortBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase {
+
+  public SortBasedShuffleWriterSuiteJ() throws IOException {}
+
+  @Override
+  protected ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle, TaskContext context, CelebornConf conf, 
ShuffleClient client)
+      throws IOException {
+    return new SortBasedShuffleWriter<Integer, String, String>(
+        handle.dependency(), handle.appUniqueId(), numPartitions, context, 
conf, client, null);
+  }
+}
diff --git 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 0cc7114cf..47d8eb41e 100644
--- 
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ 
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -172,6 +172,25 @@ public class SortBasedShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     }
   }
 
+  public SortBasedShuffleWriter(
+      RssShuffleHandle<K, V, C> handle,
+      TaskContext taskContext,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics,
+      ExecutorService executorService)
+      throws IOException {
+    this(
+        handle.dependency(),
+        handle.appUniqueId(),
+        handle.numMappers(),
+        taskContext,
+        conf,
+        client,
+        metrics,
+        executorService);
+  }
+
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws 
IOException {
     if (canUseFastWrite()) {
diff --git 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
similarity index 90%
rename from 
client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
rename to 
client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
index 91d2e043c..6995ae663 100644
--- 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/RssShuffleWriterSuiteJ.java
+++ 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/CelebornShuffleWriterSuiteBase.java
@@ -45,9 +45,13 @@ import org.apache.spark.SparkEnv;
 import org.apache.spark.TaskContext;
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.memory.UnifiedMemoryManager;
 import org.apache.spark.scheduler.MapStatus;
 import org.apache.spark.serializer.KryoSerializer;
 import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
 import org.apache.spark.sql.catalyst.InternalRow;
 import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow;
@@ -77,9 +81,9 @@ import org.apache.celeborn.common.util.JavaUtils;
 import org.apache.celeborn.common.util.Utils;
 import org.apache.celeborn.reflect.DynConstructors;
 
-public class RssShuffleWriterSuiteJ {
+public abstract class CelebornShuffleWriterSuiteBase {
 
-  private static final Logger LOG = 
LoggerFactory.getLogger(RssShuffleWriterSuiteJ.class);
+  private static final Logger LOG = 
LoggerFactory.getLogger(CelebornShuffleWriterSuiteBase.class);
   private static final String NORMAL_RECORD = "hello, world";
   private static final String GIANT_RECORD = getGiantRecord();
 
@@ -93,10 +97,13 @@ public class RssShuffleWriterSuiteJ {
   private final UserIdentifier userIdentifier = new UserIdentifier("mock", 
"mock");
 
   private final int numMaps = 10;
-  private final Integer numPartitions = 10;
+  private final int numPartitions = 10;
   private final SparkConf sparkConf = new SparkConf(false);
   private final BlockManagerId bmId = BlockManagerId.apply("execId", "host", 
1, None$.empty());
 
+  private final TaskMemoryManager taskMemoryManager =
+      new TaskMemoryManager(UnifiedMemoryManager.apply(sparkConf, 1), 0);
+
   @Mock(answer = Answers.RETURNS_SMART_NULLS)
   private TaskContext taskContext = null;
 
@@ -134,6 +141,7 @@ public class RssShuffleWriterSuiteJ {
     Mockito.doReturn(shuffleId).when(dependency).shuffleId();
 
     Mockito.doReturn(metrics).when(taskContext).taskMetrics();
+    Mockito.doReturn(taskMemoryManager).when(taskContext).taskMemoryManager();
 
     Mockito.doReturn(bmId).when(blockManager).shuffleServerId();
     Mockito.doReturn(blockManager).when(env).blockManager();
@@ -228,15 +236,14 @@ public class RssShuffleWriterSuiteJ {
     final ShuffleClient client = new DummyShuffleClient(conf, tempFile);
     ((DummyShuffleClient) client).initReducePartitionMap(shuffleId, 
numPartitions, 1);
 
-    final HashBasedShuffleWriter<Integer, String, String> writer =
-        new HashBasedShuffleWriter<>(
-            handle,
-            taskContext,
-            conf,
-            client,
-            metrics.shuffleWriteMetrics(),
-            SendBufferPool.get(1));
-    assertEquals(useUnsafe, writer.canUseFastWrite());
+    final ShuffleWriter<Integer, String> writer =
+        createShuffleWriter(handle, taskContext, conf, client, 
metrics.shuffleWriteMetrics());
+
+    if (writer instanceof SortBasedShuffleWriter) {
+      assertEquals(useUnsafe, ((SortBasedShuffleWriter) 
writer).canUseFastWrite());
+    } else if (writer instanceof HashBasedShuffleWriter) {
+      assertEquals(useUnsafe, ((HashBasedShuffleWriter) 
writer).canUseFastWrite());
+    }
 
     AtomicInteger total = new AtomicInteger(0);
     Iterator iterator = getIterator(approximateSize, total, useUnsafe, false);
@@ -351,4 +358,12 @@ public class RssShuffleWriterSuiteJ {
     int numCopies = (128 + NORMAL_RECORD.length() - 1) / 
NORMAL_RECORD.length();
     return String.join("/", Collections.nCopies(numCopies, NORMAL_RECORD));
   }
+
+  protected abstract ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle,
+      TaskContext context,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics)
+      throws IOException;
 }
diff --git 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..d44960dd5
--- /dev/null
+++ 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class HashBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase {
+
+  @Override
+  protected ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle,
+      TaskContext context,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics)
+      throws IOException {
+    return new HashBasedShuffleWriter<Integer, String, String>(
+        handle, context, conf, client, metrics, SendBufferPool.get(1));
+  }
+}
diff --git 
a/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
new file mode 100644
index 000000000..96d5ed1c6
--- /dev/null
+++ 
b/client-spark/spark-3/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriterSuiteJ.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.celeborn;
+
+import java.io.IOException;
+
+import org.apache.spark.TaskContext;
+import org.apache.spark.shuffle.ShuffleWriteMetricsReporter;
+import org.apache.spark.shuffle.ShuffleWriter;
+
+import org.apache.celeborn.client.ShuffleClient;
+import org.apache.celeborn.common.CelebornConf;
+
+public class SortBasedShuffleWriterSuiteJ extends 
CelebornShuffleWriterSuiteBase {
+  @Override
+  protected ShuffleWriter<Integer, String> createShuffleWriter(
+      RssShuffleHandle handle,
+      TaskContext context,
+      CelebornConf conf,
+      ShuffleClient client,
+      ShuffleWriteMetricsReporter metrics)
+      throws IOException {
+    return new SortBasedShuffleWriter<Integer, String, String>(
+        handle, context, conf, client, metrics, null);
+  }
+}

Reply via email to