This is an automated email from the ASF dual-hosted git repository.
zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git
The following commit(s) were added to refs/heads/master by this push:
new 77e6ab1f5 [#2494] feat(spark): Overlapping compression to avoid block
shuffle writing (#2511)
77e6ab1f5 is described below
commit 77e6ab1f540057b123ed0374ed598b035b3b6906
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Jul 3 11:02:21 2025 +0800
[#2494] feat(spark): Overlapping compression to avoid block shuffle writing
(#2511)
### What changes were proposed in this pull request?
Support overlapping (deferred) compression so that compressing shuffle blocks does not block shuffle writing
### Why are the changes needed?
for #2494 .
After applying the proposed improvements, the client observed that shuffle
performance doubled in a 100GB Terasort benchmark. Specifically, the overall
shuffle speed improved by over 100%, significantly reducing the job’s runtime
and demonstrating the effectiveness of the optimizations in large-scale sorting
workloads.


### Does this PR introduce _any_ user-facing change?
Yes. More configs and options will be added after this feature is stable
### How was this patch tested?
Unit tests
---
.../org/apache/spark/shuffle/RssSparkConfig.java | 6 ++
.../spark/shuffle/writer/WriteBufferManager.java | 66 ++++++++++++-
.../spark/shuffle/writer/RssShuffleWriter.java | 12 ++-
.../uniffle/common/DeferredCompressedBlock.java | 106 +++++++++++++++++++++
.../apache/uniffle/common/ShuffleBlockInfo.java | 29 ++++--
.../common/DeferredCompressedBlockTest.java | 57 +++++++++++
.../uniffle/test/CompressionOverlappingTest.java | 94 ++++++++++++++++++
7 files changed, 361 insertions(+), 9 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 79baab59c..66b964ac9 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@ import org.apache.uniffle.common.config.RssConf;
public class RssSparkConfig {
+ public static final ConfigOption<Boolean>
RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED =
+ ConfigOptions.key("rss.client.write.overlappingCompressionEnable")
+ .booleanType()
+ .defaultValue(false)
+ .withDescription("Whether to enable overlapping compression of shuffle blocks.");
+
public static final ConfigOption<Boolean>
RSS_READ_REORDER_MULTI_SERVERS_ENABLED =
ConfigOptions.key("rss.client.read.reorderMultiServersEnable")
.booleanType()
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 39b82e496..31d32e9b9 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -48,6 +48,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
+import org.apache.uniffle.common.DeferredCompressedBlock;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
import org.apache.uniffle.common.compression.Codec;
@@ -57,6 +58,8 @@ import org.apache.uniffle.common.exception.RssException;
import org.apache.uniffle.common.util.BlockIdLayout;
import org.apache.uniffle.common.util.ChecksumUtils;
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
+
public class WriteBufferManager extends MemoryConsumer {
private static final Logger LOG =
LoggerFactory.getLogger(WriteBufferManager.class);
@@ -108,6 +111,7 @@ public class WriteBufferManager extends MemoryConsumer {
private Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc;
private int stageAttemptNumber;
private ShuffleServerPushCostTracker shuffleServerPushCostTracker;
+ private boolean overlappingCompressionEnabled;
public WriteBufferManager(
int shuffleId,
@@ -183,6 +187,8 @@ public class WriteBufferManager extends MemoryConsumer {
this.requireMemoryInterval =
bufferManagerOptions.getRequireMemoryInterval();
this.requireMemoryRetryMax =
bufferManagerOptions.getRequireMemoryRetryMax();
this.arrayOutputStream = new
WrappedByteArrayOutputStream(serializerBufferSize);
+ this.overlappingCompressionEnabled =
+ rssConf.getBoolean(RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED);
// in columnar shuffle, the serializer here is never used
this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
if (isRowBased) {
@@ -420,8 +426,59 @@ public class WriteBufferManager extends MemoryConsumer {
return result;
}
+ protected ShuffleBlockInfo createDeferredCompressedBlock(
+ int partitionId, WriterBuffer writerBuffer) {
+ byte[] data = writerBuffer.getData();
+ final int uncompressLength = data.length;
+ final int memoryUsed = writerBuffer.getMemoryUsed();
+
+ this.blockCounter.incrementAndGet();
+ this.uncompressedDataLen += uncompressLength;
+ this.inSendListBytes.addAndGet(memoryUsed);
+
+ final long blockId =
+ blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId,
taskAttemptId);
+
+ Function<DeferredCompressedBlock, DeferredCompressedBlock> rebuildFunction
=
+ block -> {
+ byte[] compressed = data;
+ if (codec.isPresent()) {
+ long start = System.currentTimeMillis();
+ compressed = codec.get().compress(data);
+ this.compressTime += System.currentTimeMillis() - start;
+ }
+ this.compressedDataLen += compressed.length;
+ this.shuffleWriteMetrics.incBytesWritten(compressed.length);
+ final long crc32 = ChecksumUtils.getCrc32(compressed);
+
+ block.reset(compressed, compressed.length, crc32);
+ return block;
+ };
+
+ int estimatedCompressedSize = data.length;
+ if (codec.isPresent()) {
+ estimatedCompressedSize = codec.get().maxCompressedLength(data.length);
+ }
+
+ return new DeferredCompressedBlock(
+ shuffleId,
+ partitionId,
+ blockId,
+ partitionAssignmentRetrieveFunc.apply(partitionId),
+ uncompressLength,
+ memoryUsed,
+ taskAttemptId,
+ partitionAssignmentRetrieveFunc,
+ rebuildFunction,
+ estimatedCompressedSize);
+ }
+
// transform records to shuffleBlock
protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer
wb) {
+ if (overlappingCompressionEnabled) {
+ return createDeferredCompressedBlock(partitionId, wb);
+ }
+
byte[] data = wb.getData();
final int uncompressLength = data.length;
byte[] compressed = data;
@@ -516,13 +573,20 @@ public class WriteBufferManager extends MemoryConsumer {
block.getData().release();
}
+ private int getBlockLayoutLength(ShuffleBlockInfo block) {
+ if (block instanceof DeferredCompressedBlock) {
+ return ((DeferredCompressedBlock) block).getEstimatedLayoutSize();
+ }
+ return block.getSize();
+ }
+
public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo>
shuffleBlockInfoList) {
long totalSize = 0;
List<AddBlockEvent> events = new ArrayList<>();
List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
sbi.withCompletionCallback((block, isSuccessful) ->
this.releaseBlockResource(block));
- totalSize += sbi.getSize();
+ totalSize += getBlockLayoutLength(sbi);
shuffleBlockInfosPerEvent.add(sbi);
// split shuffle data according to the size
if (totalSize > sendSizeLimit) {
diff --git
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index a96a29b69..761ca368b 100644
---
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -79,6 +79,7 @@ import
org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
import
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
+import org.apache.uniffle.common.DeferredCompressedBlock;
import org.apache.uniffle.common.ReceivingFailureServer;
import org.apache.uniffle.common.ShuffleBlockInfo;
import org.apache.uniffle.common.ShuffleServerInfo;
@@ -477,7 +478,7 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
shuffleServerInfo, k -> Maps.newHashMap());
pToBlockIds.computeIfAbsent(partitionId, v ->
Sets.newHashSet()).add(blockId);
});
- partitionLengths[partitionId] += sbi.getLength();
+ partitionLengths[partitionId] += getBlockLength(sbi);
});
return postBlockEvent(shuffleBlockInfoList);
}
@@ -851,10 +852,17 @@ public class RssShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
.get(s)
.get(block.getPartitionId())
.remove(block.getBlockId()));
- partitionLengths[block.getPartitionId()] -= block.getLength();
+ partitionLengths[block.getPartitionId()] -= getBlockLength(block);
blockIds.remove(block.getBlockId());
}
+ private long getBlockLength(ShuffleBlockInfo block) {
+ if (block instanceof DeferredCompressedBlock) {
+ return block.getUncompressLength();
+ }
+ return block.getLength();
+ }
+
@VisibleForTesting
protected void sendCommit() {
ExecutorService executor = Executors.newSingleThreadExecutor();
diff --git
a/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
new file mode 100644
index 000000000..4fcbd11f9
--- /dev/null
+++
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.List;
+import java.util.function.Function;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/**
+ * This class defers compression of the block data to avoid blocking the main thread's progress.
+ * Hence, the overridden methods below should be invoked in a background thread, because they
+ * will trigger the underlying compression initialization for the data.
+ */
+public class DeferredCompressedBlock extends ShuffleBlockInfo {
+ private final Function<DeferredCompressedBlock, DeferredCompressedBlock>
rebuildFunction;
+ private int estimatedCompressedSize;
+ private boolean isInitialized = false;
+
+ public DeferredCompressedBlock(
+ int shuffleId,
+ int partitionId,
+ long blockId,
+ List<ShuffleServerInfo> shuffleServerInfos,
+ int uncompressLength,
+ long freeMemory,
+ long taskAttemptId,
+ Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc,
+ Function<DeferredCompressedBlock, DeferredCompressedBlock>
rebuildFunction,
+ int estimatedCompressedSize) {
+ super(
+ shuffleId,
+ partitionId,
+ blockId,
+ shuffleServerInfos,
+ uncompressLength,
+ freeMemory,
+ taskAttemptId,
+ partitionAssignmentRetrieveFunc);
+ this.rebuildFunction = rebuildFunction;
+ this.estimatedCompressedSize = estimatedCompressedSize;
+ }
+
+ public void reset(byte[] data, int length, long crc) {
+ super.length = length;
+ super.crc = crc;
+ super.data = Unpooled.wrappedBuffer(data);
+ }
+
+ private void initialize() {
+ if (!isInitialized) {
+ rebuildFunction.apply(this);
+ isInitialized = true;
+ }
+ }
+
+ public int getEstimatedLayoutSize() {
+ return estimatedCompressedSize + 3 * 8 + 2 * 4;
+ }
+
+ @Override
+ public int getLength() {
+ initialize();
+ return super.getLength();
+ }
+
+ @Override
+ public int getSize() {
+ initialize();
+ return super.getSize();
+ }
+
+ @Override
+ public long getCrc() {
+ initialize();
+ return super.getCrc();
+ }
+
+ @Override
+ public ByteBuf getData() {
+ initialize();
+ return super.getData();
+ }
+
+ @Override
+ public synchronized void copyDataTo(ByteBuf to) {
+ initialize();
+ super.copyDataTo(to);
+ }
+}
diff --git
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index a38e9d206..c429ea7a6 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -26,23 +26,21 @@ import io.netty.buffer.Unpooled;
import org.apache.uniffle.common.util.ByteBufUtils;
public class ShuffleBlockInfo {
-
private int partitionId;
private long blockId;
- private int length;
private int shuffleId;
- private long crc;
private long taskAttemptId;
- private ByteBuf data;
private List<ShuffleServerInfo> shuffleServerInfos;
private int uncompressLength;
private long freeMemory;
private int retryCnt = 0;
-
private transient BlockCompletionCallback completionCallback;
-
private Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc;
+ protected int length;
+ protected long crc;
+ protected ByteBuf data;
+
public ShuffleBlockInfo(
int shuffleId,
int partitionId,
@@ -69,6 +67,25 @@ public class ShuffleBlockInfo {
this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
}
+ protected ShuffleBlockInfo(
+ int shuffleId,
+ int partitionId,
+ long blockId,
+ List<ShuffleServerInfo> shuffleServerInfos,
+ int uncompressLength,
+ long freeMemory,
+ long taskAttemptId,
+ Function<Integer, List<ShuffleServerInfo>>
partitionAssignmentRetrieveFunc) {
+ this.shuffleId = shuffleId;
+ this.partitionId = partitionId;
+ this.blockId = blockId;
+ this.shuffleServerInfos = shuffleServerInfos;
+ this.uncompressLength = uncompressLength;
+ this.freeMemory = freeMemory;
+ this.taskAttemptId = taskAttemptId;
+ this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+ }
+
public ShuffleBlockInfo(
int shuffleId,
int partitionId,
diff --git
a/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
new file mode 100644
index 000000000..5ec2e851d
--- /dev/null
+++
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DeferredCompressedBlockTest {
+
+ @Test
+ public void testDeferredCompressedBlock() {
+ AtomicBoolean isInitialized = new AtomicBoolean(false);
+ DeferredCompressedBlock block =
+ new DeferredCompressedBlock(
+ 1,
+ 1,
+ 1,
+ null,
+ 0,
+ 1,
+ 1,
+ null,
+ deferredCompressedBlock -> {
+ isInitialized.set(true);
+ deferredCompressedBlock.reset(new byte[10], 10, 10);
+ return deferredCompressedBlock;
+ },
+ 10);
+
+ // case1: some params accessing won't trigger initialization
+ block.getBlockId();
+ assertFalse(isInitialized.get());
+
+ // case2
+ block.getLength();
+ assertTrue(isInitialized.get());
+ }
+}
diff --git
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
new file mode 100644
index 000000000..c534969ed
--- /dev/null
+++
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.BeforeAll;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static
org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
+
+public class CompressionOverlappingTest extends SparkSQLTest {
+ private static final Logger LOGGER =
LoggerFactory.getLogger(CompressionOverlappingTest.class);
+
+ @BeforeAll
+ public static void setupServers() throws Exception {
+ LOGGER.info("Setup servers");
+
+ // for coordinator
+ CoordinatorConf coordinatorConf = coordinatorConfWithoutPort();
+ coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
+ Map<String, String> dynamicConf = Maps.newHashMap();
+ dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(),
StorageType.MEMORY_LOCALFILE.name());
+ addDynamicConf(coordinatorConf, dynamicConf);
+ storeCoordinatorConf(coordinatorConf);
+
+ // starting 3 nodes with grpc
+ for (int i = 0; i < 3; i++) {
+ storeShuffleServerConf(buildShuffleServerConf(ServerType.GRPC, i));
+ }
+ // starting 3 nodes with grpc-netty
+ for (int i = 0; i < 3; i++) {
+ storeShuffleServerConf(buildShuffleServerConf(ServerType.GRPC_NETTY, i));
+ }
+ startServersWithRandomPorts();
+ }
+
+ private static ShuffleServerConf buildShuffleServerConf(ServerType
serverType, int index)
+ throws IOException {
+ Path tempDir = Files.createTempDirectory(serverType + "-" + index);
+ String dataPath = tempDir.toAbsolutePath().toString();
+
+ ShuffleServerConf shuffleServerConf = shuffleServerConfWithoutPort(0,
null, serverType);
+ shuffleServerConf.setLong("rss.server.heartbeat.interval", 5000);
+ shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 4000);
+ shuffleServerConf.setString("rss.storage.basePath", dataPath);
+ shuffleServerConf.setString("rss.storage.type",
StorageType.MEMORY_LOCALFILE.name());
+ return shuffleServerConf;
+ }
+
+ @Override
+ public void updateSparkConfCustomer(SparkConf sparkConf) {
+ sparkConf.set("spark.sql.shuffle.partitions", "4");
+ String overlappingOptionKey =
RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED.key();
+ sparkConf.set("spark." + overlappingOptionKey, "true");
+ }
+
+ @Override
+ public void updateRssStorage(SparkConf sparkConf) {
+ // ignore
+ }
+
+ @Override
+ public void checkShuffleData() throws Exception {
+ // ignore
+ }
+}