This is an automated email from the ASF dual-hosted git repository.
xianjingfeng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 0a2cec93 [#615] improvement: Reduce task binary by removing
'partitionToServers' from RssShuffleHandle (#637)
0a2cec93 is described below
commit 0a2cec9386728700ca73df0974a6fd6674ba00a8
Author: jiafu zhang <[email protected]>
AuthorDate: Tue Mar 14 10:02:02 2023 +0800
[#615] improvement: Reduce task binary by removing 'partitionToServers'
from RssShuffleHandle (#637)
### What changes were proposed in this pull request?
Move the partition -> shuffle servers mapping from a direct field of
RssShuffleHandle to a broadcast variable to reduce task binary size.
### Why are the changes needed?
To reduce task delay and task serialize/deserialize time by reducing task
binary size.
### Does this PR introduce any user-facing change?
No.
### How was this patch tested?
Tested with a 10000-partition shuffle. Task binary size was reduced from more
than 670KB to less than 6KB.
Tested with multiple shuffle stages in the same job to verify the
ShuffleHandleInfo cache logic.
---
.../org/apache/spark/shuffle/RssShuffleHandle.java | 35 +++++--------
.../apache/spark/shuffle/RssSparkShuffleUtils.java | 36 +++++++++++++
...ssShuffleHandle.java => ShuffleHandleInfo.java} | 60 +++++++++-------------
.../apache/spark/shuffle/RssShuffleManager.java | 11 ++--
.../apache/spark/shuffle/RssShuffleManager.java | 13 +++--
5 files changed, 90 insertions(+), 65 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
index 99f214ce..83b7db3e 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
@@ -21,8 +21,8 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
-import com.google.common.collect.Sets;
import org.apache.spark.ShuffleDependency;
+import org.apache.spark.broadcast.Broadcast;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -32,29 +32,19 @@ public class RssShuffleHandle<K, V, C> extends
ShuffleHandle {
private String appId;
private int numMaps;
private ShuffleDependency<K, V, C> dependency;
- private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
- // shuffle servers which is for store shuffle data
- private Set<ShuffleServerInfo> shuffleServersForData;
- // remoteStorage used for this job
- private RemoteStorageInfo remoteStorage;
+ private Broadcast<ShuffleHandleInfo> handlerInfoBd;
public RssShuffleHandle(
int shuffleId,
String appId,
int numMaps,
ShuffleDependency<K, V, C> dependency,
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
- RemoteStorageInfo remoteStorage) {
+ Broadcast<ShuffleHandleInfo> handlerInfoBd) {
super(shuffleId);
this.appId = appId;
this.numMaps = numMaps;
this.dependency = dependency;
- this.partitionToServers = partitionToServers;
- this.remoteStorage = remoteStorage;
- shuffleServersForData = Sets.newHashSet();
- for (List<ShuffleServerInfo> ssis : partitionToServers.values()) {
- shuffleServersForData.addAll(ssis);
- }
+ this.handlerInfoBd = handlerInfoBd;
}
public String getAppId() {
@@ -69,19 +59,20 @@ public class RssShuffleHandle<K, V, C> extends
ShuffleHandle {
return dependency;
}
- public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
- return partitionToServers;
- }
-
public int getShuffleId() {
return shuffleId();
}
- public Set<ShuffleServerInfo> getShuffleServersForData() {
- return shuffleServersForData;
+ public RemoteStorageInfo getRemoteStorage() {
+ return handlerInfoBd.value().getRemoteStorage();
}
- public RemoteStorageInfo getRemoteStorage() {
- return remoteStorage;
+ public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
+ return handlerInfoBd.value().getPartitionToServers();
}
+
+ public Set<ShuffleServerInfo> getShuffleServersForData() {
+ return handlerInfoBd.value().getShuffleServersForData();
+ }
+
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
index d7d68330..dabe0e4b 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkShuffleUtils.java
@@ -27,22 +27,31 @@ import java.util.Set;
import org.apache.commons.lang3.StringUtils;
import org.apache.hadoop.conf.Configuration;
import org.apache.spark.SparkConf;
+import org.apache.spark.SparkContext;
+import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.deploy.SparkHadoopUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import scala.reflect.ClassTag;
import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.util.ClientUtils;
import org.apache.uniffle.common.ClientType;
import org.apache.uniffle.common.RemoteStorageInfo;
+import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.Constants;
+
public class RssSparkShuffleUtils {
private static final Logger LOG =
LoggerFactory.getLogger(RssSparkShuffleUtils.class);
+ public static final ClassTag<ShuffleHandleInfo>
SHUFFLE_HANDLER_INFO_CLASS_TAG = scala.reflect.ClassTag$.MODULE$
+ .apply(ShuffleHandleInfo.class);
+ public static final ClassTag<byte[]> BYTE_ARRAY_CLASS_TAG =
scala.reflect.ClassTag$.MODULE$.apply(byte[].class);
+
public static Configuration newHadoopConfiguration(SparkConf sparkConf) {
SparkHadoopUtil util = new SparkHadoopUtil();
Configuration conf = util.newConfiguration(sparkConf);
@@ -190,4 +199,31 @@ public class RssSparkShuffleUtils {
int taskConcurrencyPerServer =
sparkConf.get(RssSparkConfig.RSS_ESTIMATE_TASK_CONCURRENCY_PER_SERVER);
return (int) Math.ceil(estimateTaskConcurrency * 1.0 /
taskConcurrencyPerServer);
}
+
+ /**
+ * Get current active {@link SparkContext}. It should be called inside
Driver since we don't mean to create any
+ * new {@link SparkContext} here.
+ *
+ * Note: We could use "SparkContext.getActive()" instead of
"SparkContext.getOrCreate()" if the "getActive" method
+ * is not declared as package private in Scala.
+ * @return Active SparkContext created by Driver.
+ */
+ public static SparkContext getActiveSparkContext() {
+ return SparkContext.getOrCreate();
+ }
+
+ /**
+ * create broadcast variable of {@link ShuffleHandleInfo}
+ *
+ * @param sc expose for easy unit-test
+ * @param shuffleId
+ * @param partitionToServers
+ * @param storageInfo
+ * @return Broadcast variable registered for auto cleanup
+ */
+ public static Broadcast<ShuffleHandleInfo>
broadcastShuffleHdlInfo(SparkContext sc, int shuffleId,
+ Map<Integer, List<ShuffleServerInfo>> partitionToServers,
RemoteStorageInfo storageInfo) {
+ ShuffleHandleInfo handleInfo = new ShuffleHandleInfo(shuffleId,
partitionToServers, storageInfo);
+ return sc.broadcast(handleInfo, SHUFFLE_HANDLER_INFO_CLASS_TAG);
+ }
}
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
similarity index 66%
copy from
client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
copy to
client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
index 99f214ce..f045db46 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssShuffleHandle.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/ShuffleHandleInfo.java
@@ -17,66 +17,52 @@
package org.apache.spark.shuffle;
+import java.io.Serializable;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.collect.Sets;
-import org.apache.spark.ShuffleDependency;
import org.apache.uniffle.common.RemoteStorageInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
-public class RssShuffleHandle<K, V, C> extends ShuffleHandle {
+/**
+ * Class for holding,
+ * 1. partition ID -> shuffle servers mapping.
+ * 2. remote storage info
+ *
+ * It's to be broadcast to executors and referenced by shuffle tasks.
+ */
+public class ShuffleHandleInfo implements Serializable {
+
+ private int shuffleId;
- private String appId;
- private int numMaps;
- private ShuffleDependency<K, V, C> dependency;
private Map<Integer, List<ShuffleServerInfo>> partitionToServers;
// shuffle servers which is for store shuffle data
private Set<ShuffleServerInfo> shuffleServersForData;
// remoteStorage used for this job
private RemoteStorageInfo remoteStorage;
- public RssShuffleHandle(
- int shuffleId,
- String appId,
- int numMaps,
- ShuffleDependency<K, V, C> dependency,
- Map<Integer, List<ShuffleServerInfo>> partitionToServers,
- RemoteStorageInfo remoteStorage) {
- super(shuffleId);
- this.appId = appId;
- this.numMaps = numMaps;
- this.dependency = dependency;
+ public static final ShuffleHandleInfo EMPTY_HANDLE_INFO = new
ShuffleHandleInfo(-1, Collections.EMPTY_MAP,
+ RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+
+ public ShuffleHandleInfo(int shuffleId, Map<Integer,
List<ShuffleServerInfo>> partitionToServers,
+ RemoteStorageInfo storageInfo) {
+ this.shuffleId = shuffleId;
this.partitionToServers = partitionToServers;
- this.remoteStorage = remoteStorage;
- shuffleServersForData = Sets.newHashSet();
+ this.shuffleServersForData = Sets.newHashSet();
for (List<ShuffleServerInfo> ssis : partitionToServers.values()) {
- shuffleServersForData.addAll(ssis);
+ this.shuffleServersForData.addAll(ssis);
}
- }
-
- public String getAppId() {
- return appId;
- }
-
- public int getNumMaps() {
- return numMaps;
- }
-
- public ShuffleDependency<K, V, C> getDependency() {
- return dependency;
+ this.remoteStorage = storageInfo;
}
public Map<Integer, List<ShuffleServerInfo>> getPartitionToServers() {
return partitionToServers;
}
- public int getShuffleId() {
- return shuffleId();
- }
-
public Set<ShuffleServerInfo> getShuffleServersForData() {
return shuffleServersForData;
}
@@ -84,4 +70,8 @@ public class RssShuffleHandle<K, V, C> extends ShuffleHandle {
public RemoteStorageInfo getRemoteStorage() {
return remoteStorage;
}
+
+ public int getShuffleId() {
+ return shuffleId;
+ }
}
diff --git
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 9b14f52c..2ecb6f8f 100644
---
a/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark2/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -36,6 +36,7 @@ import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.reader.RssShuffleReader;
import org.apache.spark.shuffle.writer.AddBlockEvent;
@@ -242,12 +243,14 @@ public class RssShuffleManager implements ShuffleManager {
if (dependency.partitioner().numPartitions() == 0) {
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "],
partitionNum is 0, "
+ "return the empty RssShuffleHandle directly");
+ Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
+ RssSparkShuffleUtils.getActiveSparkContext(), shuffleId,
Collections.emptyMap(),
+ RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
return new RssShuffleHandle<>(shuffleId,
appId,
dependency.rdd().getNumPartitions(),
dependency,
- Collections.emptyMap(),
- RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+ hdlInfoBd);
}
String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
@@ -281,8 +284,10 @@ public class RssShuffleManager implements ShuffleManager {
startHeartbeat();
+ Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
+ RssSparkShuffleUtils.getActiveSparkContext(), shuffleId,
partitionToServers, remoteStorage);
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "],
partitionNum[" + partitionToServers.size() + "]");
- return new RssShuffleHandle(shuffleId, appId, numMaps, dependency,
partitionToServers, remoteStorage);
+ return new RssShuffleHandle(shuffleId, appId, numMaps, dependency,
hdlInfoBd);
}
private void startHeartbeat() {
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
index 5b21ffda..e70026ac 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/RssShuffleManager.java
@@ -42,6 +42,7 @@ import org.apache.spark.ShuffleDependency;
import org.apache.spark.SparkConf;
import org.apache.spark.SparkEnv;
import org.apache.spark.TaskContext;
+import org.apache.spark.broadcast.Broadcast;
import org.apache.spark.executor.ShuffleReadMetrics;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.shuffle.reader.RssShuffleReader;
@@ -293,7 +294,6 @@ public class RssShuffleManager implements ShuffleManager {
heartBeatScheduledExecutorService = null;
}
-
// This method is called in Spark driver side,
// and Spark driver will make some decision according to coordinator,
// e.g. determining what RSS servers to use.
@@ -322,12 +322,14 @@ public class RssShuffleManager implements ShuffleManager {
if (dependency.partitioner().numPartitions() == 0) {
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "],
partitionNum is 0, "
+ "return the empty RssShuffleHandle directly");
+ Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
+ RssSparkShuffleUtils.getActiveSparkContext(), shuffleId,
Collections.emptyMap(),
+ RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
return new RssShuffleHandle<>(shuffleId,
id.get(),
dependency.rdd().getNumPartitions(),
dependency,
- Collections.emptyMap(),
- RemoteStorageInfo.EMPTY_REMOTE_STORAGE);
+ hdlInfoBd);
}
String storageType = sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key());
@@ -363,14 +365,15 @@ public class RssShuffleManager implements ShuffleManager {
}
startHeartbeat();
+ Broadcast<ShuffleHandleInfo> hdlInfoBd =
RssSparkShuffleUtils.broadcastShuffleHdlInfo(
+ RssSparkShuffleUtils.getActiveSparkContext(), shuffleId,
partitionToServers, remoteStorage);
LOG.info("RegisterShuffle with ShuffleId[" + shuffleId + "],
partitionNum[" + partitionToServers.size()
+ "], shuffleServerForResult: " + partitionToServers);
return new RssShuffleHandle<>(shuffleId,
id.get(),
dependency.rdd().getNumPartitions(),
dependency,
- partitionToServers,
- remoteStorage);
+ hdlInfoBd);
}
@Override