This is an automated email from the ASF dual-hosted git repository.

xianjin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new e0996f2ba [#1824] feat(spark): Support map side combine of shuffle 
writer (#1825)
e0996f2ba is described below

commit e0996f2bad9f1f53c36a6863f23b7365969b2ae7
Author: Zhen Wang <643348...@qq.com>
AuthorDate: Wed Jun 26 21:22:07 2024 +0800

    [#1824] feat(spark): Support map side combine of shuffle writer (#1825)
    
    ### What changes were proposed in this pull request?
    Support map side combine of shuffle writer
    
    ### Why are the changes needed?
    Fix: #1824
    
    ### Does this PR introduce _any_ user-facing change?
    Yes. Adds an opt-in shuffle writer behavior (map side combine), enabled via `spark.rss.client.mapSideCombine.enabled` (default `false`).
    
    ### How was this patch tested?
    Added integration test `MapSideCombineTest`.
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |   6 +
 .../spark/shuffle/writer/RssShuffleWriter.java     |  28 ++--
 docs/client_guide/spark_client_guide.md            |  14 +-
 .../uniffle/test/WriteAndReadMetricsTest.java      |  34 +----
 .../listener/WriteAndReadMetricsSparkListener.java |  52 +++++++
 .../apache/uniffle/test/MapSideCombineTest.java    | 150 +++++++++++++++++++++
 6 files changed, 240 insertions(+), 44 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index b47707d16..ba5e414cc 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -113,6 +113,12 @@ public class RssSparkConfig {
           .defaultValue(1)
           .withDescription("The block retry max times when partition reassign 
is enabled.");
 
+  public static final ConfigOption<Boolean> 
RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED =
+      ConfigOptions.key("rss.client.mapSideCombine.enabled")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether to enable map side combine of shuffle 
writer.");
+
   public static final String SPARK_RSS_CONFIG_PREFIX = "spark.";
 
   public static final ConfigEntry<Integer> RSS_PARTITION_NUM_PER_RANGE =
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index 3dfc2fd62..50eb47001 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -41,6 +41,7 @@ import java.util.stream.Collectors;
 import scala.Function1;
 import scala.Option;
 import scala.Product2;
+import scala.Tuple2;
 import scala.collection.Iterator;
 
 import com.google.common.annotations.VisibleForTesting;
@@ -91,6 +92,7 @@ import 
org.apache.uniffle.common.exception.RssWaitFailedException;
 import org.apache.uniffle.common.rpc.StatusCode;
 import org.apache.uniffle.storage.util.StorageType;
 
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_PARTITION_REASSIGN_BLOCK_RETRY_MAX_TIMES;
 import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_RESUBMIT_STAGE_WITH_WRITE_FAILURE_ENABLED;
 
@@ -289,25 +291,27 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private void writeImpl(Iterator<Product2<K, V>> records) {
     List<ShuffleBlockInfo> shuffleBlockInfos;
     boolean isCombine = shuffleDependency.mapSideCombine();
-    Function1<V, C> createCombiner = null;
+
+    Iterator<? extends Product2<K, ?>> iterator = records;
     if (isCombine) {
-      createCombiner = shuffleDependency.aggregator().get().createCombiner();
+      if 
(RssSparkConfig.toRssConf(sparkConf).get(RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED)) {
+        iterator = 
shuffleDependency.aggregator().get().combineValuesByKey(records, taskContext);
+      } else {
+        Function1<V, C> combiner = 
shuffleDependency.aggregator().get().createCombiner();
+        iterator =
+            records.map(
+                (Function1<Product2<K, V>, Product2<K, C>>)
+                    x -> new Tuple2<>(x._1(), combiner.apply(x._2())));
+      }
     }
     long recordCount = 0;
-    while (records.hasNext()) {
+    while (iterator.hasNext()) {
       recordCount++;
-
       checkDataIfAnyFailure();
-
-      Product2<K, V> record = records.next();
+      Product2<K, ?> record = iterator.next();
       K key = record._1();
       int partition = getPartition(key);
-      if (isCombine) {
-        Object c = createCombiner.apply(record._2());
-        shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), c);
-      } else {
-        shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), 
record._2());
-      }
+      shuffleBlockInfos = bufferManager.addRecord(partition, record._1(), 
record._2());
       if (shuffleBlockInfos != null && !shuffleBlockInfos.isEmpty()) {
         processShuffleBlockInfos(shuffleBlockInfos);
       }
diff --git a/docs/client_guide/spark_client_guide.md 
b/docs/client_guide/spark_client_guide.md
index dc6f9a9d9..e7d57141e 100644
--- a/docs/client_guide/spark_client_guide.md
+++ b/docs/client_guide/spark_client_guide.md
@@ -165,4 +165,16 @@ spark.rss.client.reassign.enabled                  true
 spark.rss.client.reassign.maxReassignServerNum     10
 # The block retry max times when partition reassign is enabled. 
 spark.rss.client.reassign.blockRetryMaxTimes       1
-```
\ No newline at end of file
+```
+
+### Map side combine
+
+Map side combine is a feature for RDD aggregation operators (e.g. `reduceByKey`, `combineByKey`, `aggregateByKey`, `foldByKey`) that combines shuffle data on the map side before sending it to the shuffle server, which reduces the amount of data transmitted and the load on the shuffle server.
+
+Enable this feature with the following configuration:
+
+| Property Name                           | Default | Description                                           |
+|-----------------------------------------|---------|-------------------------------------------------------|
+| spark.rss.client.mapSideCombine.enabled | false   | Whether to enable map side combine of shuffle writer. |
+
+**Note**: Map side combine must process all of a task's shuffle write data before any of it is sent to the shuffle server, which may cause data spills and delay shuffle writes.
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
index ec03ddbc9..b48de6435 100644
--- 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/WriteAndReadMetricsTest.java
@@ -21,15 +21,14 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-import org.apache.spark.executor.TaskMetrics;
-import org.apache.spark.scheduler.SparkListener;
-import org.apache.spark.scheduler.SparkListenerTaskEnd;
 import org.apache.spark.sql.Dataset;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
 import org.apache.spark.sql.functions;
 import org.junit.jupiter.api.Test;
 
+import org.apache.uniffle.test.listener.WriteAndReadMetricsSparkListener;
+
 public class WriteAndReadMetricsTest extends SimpleTestBase {
 
   @Test
@@ -63,6 +62,7 @@ public class WriteAndReadMetricsTest extends SimpleTestBase {
 
     // take a rest to make sure all task metrics are updated before read 
stageData
     Thread.sleep(100);
+
     for (int stageId : 
spark.sparkContext().statusTracker().getJobInfo(0).get().stageIds()) {
       long writeRecords = listener.getWriteRecords(stageId);
       long readRecords = listener.getReadRecords(stageId);
@@ -72,32 +72,4 @@ public class WriteAndReadMetricsTest extends SimpleTestBase {
 
     return result;
   }
-
-  private static class WriteAndReadMetricsSparkListener extends SparkListener {
-    private HashMap<Integer, Long> stageIdToWriteRecords = new HashMap<>();
-    private HashMap<Integer, Long> stageIdToReadRecords = new HashMap<>();
-
-    @Override
-    public void onTaskEnd(SparkListenerTaskEnd event) {
-      int stageId = event.stageId();
-      TaskMetrics taskMetrics = event.taskMetrics();
-      if (taskMetrics != null) {
-        long writeRecords = taskMetrics.shuffleWriteMetrics().recordsWritten();
-        long readRecords = taskMetrics.shuffleReadMetrics().recordsRead();
-        // Accumulate writeRecords and readRecords for the given stageId
-        stageIdToWriteRecords.put(
-            stageId, stageIdToWriteRecords.getOrDefault(stageId, 0L) + 
writeRecords);
-        stageIdToReadRecords.put(
-            stageId, stageIdToReadRecords.getOrDefault(stageId, 0L) + 
readRecords);
-      }
-    }
-
-    public long getWriteRecords(int stageId) {
-      return stageIdToWriteRecords.getOrDefault(stageId, 0L);
-    }
-
-    public long getReadRecords(int stageId) {
-      return stageIdToReadRecords.getOrDefault(stageId, 0L);
-    }
-  }
 }
diff --git 
a/integration-test/spark-common/src/test/java/org/apache/uniffle/test/listener/WriteAndReadMetricsSparkListener.java
 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/listener/WriteAndReadMetricsSparkListener.java
new file mode 100644
index 000000000..0148ff6ac
--- /dev/null
+++ 
b/integration-test/spark-common/src/test/java/org/apache/uniffle/test/listener/WriteAndReadMetricsSparkListener.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test.listener;
+
+import java.util.HashMap;
+
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.scheduler.SparkListener;
+import org.apache.spark.scheduler.SparkListenerTaskEnd;
+
+public class WriteAndReadMetricsSparkListener extends SparkListener {
+  private HashMap<Integer, Long> stageIdToWriteRecords = new HashMap<>();
+  private HashMap<Integer, Long> stageIdToReadRecords = new HashMap<>();
+
+  @Override
+  public void onTaskEnd(SparkListenerTaskEnd event) {
+    int stageId = event.stageId();
+    TaskMetrics taskMetrics = event.taskMetrics();
+    if (taskMetrics != null) {
+      long writeRecords = taskMetrics.shuffleWriteMetrics().recordsWritten();
+      long readRecords = taskMetrics.shuffleReadMetrics().recordsRead();
+      // Accumulate writeRecords and readRecords for the given stageId
+      stageIdToWriteRecords.put(
+          stageId, stageIdToWriteRecords.getOrDefault(stageId, 0L) + 
writeRecords);
+      stageIdToReadRecords.put(
+          stageId, stageIdToReadRecords.getOrDefault(stageId, 0L) + 
readRecords);
+    }
+  }
+
+  public long getWriteRecords(int stageId) {
+    return stageIdToWriteRecords.getOrDefault(stageId, 0L);
+  }
+
+  public long getReadRecords(int stageId) {
+    return stageIdToReadRecords.getOrDefault(stageId, 0L);
+  }
+}
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
new file mode 100644
index 000000000..e2aac274a
--- /dev/null
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/MapSideCombineTest.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.TimeoutException;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import scala.Tuple2;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaPairRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.apache.spark.sql.SparkSession;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.StorageType;
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.test.listener.WriteAndReadMetricsSparkListener;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+public class MapSideCombineTest extends SparkIntegrationTestBase {
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    CoordinatorConf coordinatorConf = getCoordinatorConf();
+    createCoordinatorServer(coordinatorConf);
+    ShuffleServerConf grpcShuffleServerConf = 
getShuffleServerConf(ServerType.GRPC);
+    createShuffleServer(grpcShuffleServerConf);
+    ShuffleServerConf nettyShuffleServerConf = 
getShuffleServerConf(ServerType.GRPC_NETTY);
+    createShuffleServer(nettyShuffleServerConf);
+    startServers();
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE_HDFS.name());
+    sparkConf.set(RssSparkConfig.RSS_REMOTE_STORAGE_PATH.key(), HDFS_URI + 
"rss/test");
+    sparkConf.set("spark." + 
RssSparkConfig.RSS_CLIENT_MAP_SIDE_COMBINE_ENABLED.key(), "true");
+  }
+
+  @Test
+  public void resultCompareTest() throws Exception {
+    run();
+  }
+
+  @Override
+  Map runTest(SparkSession spark, String fileName) throws Exception {
+    Thread.sleep(4000);
+
+    WriteAndReadMetricsSparkListener listener = new 
WriteAndReadMetricsSparkListener();
+    spark.sparkContext().addSparkListener(listener);
+
+    Map<String, Object> result = Maps.newHashMap();
+    JavaSparkContext sc = new JavaSparkContext(spark.sparkContext());
+    List<Integer> data = Stream.iterate(1, n -> n + 
1).limit(10000).collect(Collectors.toList());
+    JavaPairRDD<Integer, Integer> dataSourceRdd =
+        sc.parallelize(data, 10).mapToPair(x -> new Tuple2<>(x % 10, 1));
+    int jobId = -1;
+
+    // reduceByKey
+    checkMapSideCombine(
+        spark,
+        listener,
+        "reduceByKey",
+        dataSourceRdd.reduceByKey(Integer::sum).collectAsMap(),
+        result,
+        ++jobId);
+
+    // combineByKey
+    checkMapSideCombine(
+        spark,
+        listener,
+        "combineByKey",
+        dataSourceRdd.combineByKey(x -> 1, Integer::sum, 
Integer::sum).collectAsMap(),
+        result,
+        ++jobId);
+
+    // aggregateByKey
+    checkMapSideCombine(
+        spark,
+        listener,
+        "aggregateByKey",
+        dataSourceRdd.aggregateByKey(10, Integer::sum, 
Integer::sum).collectAsMap(),
+        result,
+        ++jobId);
+
+    // foldByKey
+    checkMapSideCombine(
+        spark,
+        listener,
+        "foldByKey",
+        dataSourceRdd.foldByKey(10, Integer::sum).collectAsMap(),
+        result,
+        ++jobId);
+
+    // countByKey
+    checkMapSideCombine(spark, listener, "countByKey", 
dataSourceRdd.countByKey(), result, ++jobId);
+
+    return result;
+  }
+
+  private <K, V> void checkMapSideCombine(
+      SparkSession spark,
+      WriteAndReadMetricsSparkListener listener,
+      String method,
+      Map<K, V> rddResult,
+      Map<String, Object> result,
+      int jobId)
+      throws TimeoutException {
+    rddResult.forEach((key, value) -> result.put(method + "-result-value-" + 
key, value));
+
+    spark.sparkContext().listenerBus().waitUntilEmpty();
+
+    for (int stageId : 
spark.sparkContext().statusTracker().getJobInfo(jobId).get().stageIds()) {
+      long writeRecords = listener.getWriteRecords(stageId);
+      long readRecords = listener.getReadRecords(stageId);
+      result.put(stageId + "-write-records", writeRecords);
+      result.put(stageId + "-read-records", readRecords);
+    }
+
+    // each job has two stages, so each job start stageId = jobId * 2
+    int shuffleStageId = jobId * 2;
+    // check map side combine
+    assertEquals(100L, result.get(shuffleStageId + "-write-records"));
+  }
+}

Reply via email to