This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 77e6ab1f5 [#2494] feat(spark): Overlapping compression to avoid block 
shuffle writing (#2511)
77e6ab1f5 is described below

commit 77e6ab1f540057b123ed0374ed598b035b3b6906
Author: Junfan Zhang <[email protected]>
AuthorDate: Thu Jul 3 11:02:21 2025 +0800

    [#2494] feat(spark): Overlapping compression to avoid block shuffle writing 
(#2511)
    
    ### What changes were proposed in this pull request?
    
    Support overlapping compression to avoid blocking shuffle writing
    
    ### Why are the changes needed?
    
    For #2494.
    
    After applying the proposed improvements, the client observed that shuffle 
performance doubled in a 100GB Terasort benchmark. Specifically, the overall 
shuffle speed improved by over 100%, significantly reducing the job’s runtime 
and demonstrating the effectiveness of the optimizations in large-scale sorting 
workloads.
    
    
![image](https://github.com/user-attachments/assets/7e9ca515-516a-4f4c-8022-ca6e04de397a)
    
    
![image](https://github.com/user-attachments/assets/9cc2a4ea-a7c9-4a90-a54c-9909a7710791)
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. More configs and options will be added after this feature is stable.
    
    ### How was this patch tested?
    
    Unit tests
---
 .../org/apache/spark/shuffle/RssSparkConfig.java   |   6 ++
 .../spark/shuffle/writer/WriteBufferManager.java   |  66 ++++++++++++-
 .../spark/shuffle/writer/RssShuffleWriter.java     |  12 ++-
 .../uniffle/common/DeferredCompressedBlock.java    | 106 +++++++++++++++++++++
 .../apache/uniffle/common/ShuffleBlockInfo.java    |  29 ++++--
 .../common/DeferredCompressedBlockTest.java        |  57 +++++++++++
 .../uniffle/test/CompressionOverlappingTest.java   |  94 ++++++++++++++++++
 7 files changed, 361 insertions(+), 9 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
index 79baab59c..66b964ac9 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/RssSparkConfig.java
@@ -39,6 +39,12 @@ import org.apache.uniffle.common.config.RssConf;
 
 public class RssSparkConfig {
 
+  public static final ConfigOption<Boolean> 
RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED =
+      ConfigOptions.key("rss.client.write.overlappingCompressionEnable")
+          .booleanType()
+          .defaultValue(false)
+          .withDescription("Whether to overlapping compress shuffle blocks.");
+
   public static final ConfigOption<Boolean> 
RSS_READ_REORDER_MULTI_SERVERS_ENABLED =
       ConfigOptions.key("rss.client.read.reorderMultiServersEnable")
           .booleanType()
diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
index 39b82e496..31d32e9b9 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/WriteBufferManager.java
@@ -48,6 +48,7 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
 import org.apache.uniffle.client.common.ShuffleServerPushCostTracker;
+import org.apache.uniffle.common.DeferredCompressedBlock;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
 import org.apache.uniffle.common.compression.Codec;
@@ -57,6 +58,8 @@ import org.apache.uniffle.common.exception.RssException;
 import org.apache.uniffle.common.util.BlockIdLayout;
 import org.apache.uniffle.common.util.ChecksumUtils;
 
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
+
 public class WriteBufferManager extends MemoryConsumer {
 
   private static final Logger LOG = 
LoggerFactory.getLogger(WriteBufferManager.class);
@@ -108,6 +111,7 @@ public class WriteBufferManager extends MemoryConsumer {
   private Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc;
   private int stageAttemptNumber;
   private ShuffleServerPushCostTracker shuffleServerPushCostTracker;
+  private boolean overlappingCompressionEnabled;
 
   public WriteBufferManager(
       int shuffleId,
@@ -183,6 +187,8 @@ public class WriteBufferManager extends MemoryConsumer {
     this.requireMemoryInterval = 
bufferManagerOptions.getRequireMemoryInterval();
     this.requireMemoryRetryMax = 
bufferManagerOptions.getRequireMemoryRetryMax();
     this.arrayOutputStream = new 
WrappedByteArrayOutputStream(serializerBufferSize);
+    this.overlappingCompressionEnabled =
+        rssConf.getBoolean(RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED);
     // in columnar shuffle, the serializer here is never used
     this.isRowBased = rssConf.getBoolean(RssSparkConfig.RSS_ROW_BASED);
     if (isRowBased) {
@@ -420,8 +426,59 @@ public class WriteBufferManager extends MemoryConsumer {
     return result;
   }
 
+  protected ShuffleBlockInfo createDeferredCompressedBlock(
+      int partitionId, WriterBuffer writerBuffer) {
+    byte[] data = writerBuffer.getData();
+    final int uncompressLength = data.length;
+    final int memoryUsed = writerBuffer.getMemoryUsed();
+
+    this.blockCounter.incrementAndGet();
+    this.uncompressedDataLen += uncompressLength;
+    this.inSendListBytes.addAndGet(memoryUsed);
+
+    final long blockId =
+        blockIdLayout.getBlockId(getNextSeqNo(partitionId), partitionId, 
taskAttemptId);
+
+    Function<DeferredCompressedBlock, DeferredCompressedBlock> rebuildFunction 
=
+        block -> {
+          byte[] compressed = data;
+          if (codec.isPresent()) {
+            long start = System.currentTimeMillis();
+            compressed = codec.get().compress(data);
+            this.compressTime += System.currentTimeMillis() - start;
+          }
+          this.compressedDataLen += compressed.length;
+          this.shuffleWriteMetrics.incBytesWritten(compressed.length);
+          final long crc32 = ChecksumUtils.getCrc32(compressed);
+
+          block.reset(compressed, compressed.length, crc32);
+          return block;
+        };
+
+    int estimatedCompressedSize = data.length;
+    if (codec.isPresent()) {
+      estimatedCompressedSize = codec.get().maxCompressedLength(data.length);
+    }
+
+    return new DeferredCompressedBlock(
+        shuffleId,
+        partitionId,
+        blockId,
+        partitionAssignmentRetrieveFunc.apply(partitionId),
+        uncompressLength,
+        memoryUsed,
+        taskAttemptId,
+        partitionAssignmentRetrieveFunc,
+        rebuildFunction,
+        estimatedCompressedSize);
+  }
+
   // transform records to shuffleBlock
   protected ShuffleBlockInfo createShuffleBlock(int partitionId, WriterBuffer 
wb) {
+    if (overlappingCompressionEnabled) {
+      return createDeferredCompressedBlock(partitionId, wb);
+    }
+
     byte[] data = wb.getData();
     final int uncompressLength = data.length;
     byte[] compressed = data;
@@ -516,13 +573,20 @@ public class WriteBufferManager extends MemoryConsumer {
     block.getData().release();
   }
 
+  private int getBlockLayoutLength(ShuffleBlockInfo block) {
+    if (block instanceof DeferredCompressedBlock) {
+      return ((DeferredCompressedBlock) block).getEstimatedLayoutSize();
+    }
+    return block.getSize();
+  }
+
   public List<AddBlockEvent> buildBlockEvents(List<ShuffleBlockInfo> 
shuffleBlockInfoList) {
     long totalSize = 0;
     List<AddBlockEvent> events = new ArrayList<>();
     List<ShuffleBlockInfo> shuffleBlockInfosPerEvent = Lists.newArrayList();
     for (ShuffleBlockInfo sbi : shuffleBlockInfoList) {
       sbi.withCompletionCallback((block, isSuccessful) -> 
this.releaseBlockResource(block));
-      totalSize += sbi.getSize();
+      totalSize += getBlockLayoutLength(sbi);
       shuffleBlockInfosPerEvent.add(sbi);
       // split shuffle data according to the size
       if (totalSize > sendSizeLimit) {
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index a96a29b69..761ca368b 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -79,6 +79,7 @@ import 
org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
 import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
+import org.apache.uniffle.common.DeferredCompressedBlock;
 import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -477,7 +478,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                               shuffleServerInfo, k -> Maps.newHashMap());
                       pToBlockIds.computeIfAbsent(partitionId, v -> 
Sets.newHashSet()).add(blockId);
                     });
-            partitionLengths[partitionId] += sbi.getLength();
+            partitionLengths[partitionId] += getBlockLength(sbi);
           });
       return postBlockEvent(shuffleBlockInfoList);
     }
@@ -851,10 +852,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                     .get(s)
                     .get(block.getPartitionId())
                     .remove(block.getBlockId()));
-    partitionLengths[block.getPartitionId()] -= block.getLength();
+    partitionLengths[block.getPartitionId()] -= getBlockLength(block);
     blockIds.remove(block.getBlockId());
   }
 
+  private long getBlockLength(ShuffleBlockInfo block) {
+    if (block instanceof DeferredCompressedBlock) {
+      return block.getUncompressLength();
+    }
+    return block.getLength();
+  }
+
   @VisibleForTesting
   protected void sendCommit() {
     ExecutorService executor = Executors.newSingleThreadExecutor();
diff --git 
a/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java 
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
new file mode 100644
index 000000000..4fcbd11f9
--- /dev/null
+++ 
b/common/src/main/java/org/apache/uniffle/common/DeferredCompressedBlock.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.List;
+import java.util.function.Function;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+
+/**
+ * This class is to deferred compress the block data to avoid blocking the 
main thread progress. And
+ * so, the below override methods should be invoked in the background thread. 
Because it will
+ * trigger the underlying the compression initialization for the data.
+ */
+public class DeferredCompressedBlock extends ShuffleBlockInfo {
+  private final Function<DeferredCompressedBlock, DeferredCompressedBlock> 
rebuildFunction;
+  private int estimatedCompressedSize;
+  private boolean isInitialized = false;
+
+  public DeferredCompressedBlock(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc,
+      Function<DeferredCompressedBlock, DeferredCompressedBlock> 
rebuildFunction,
+      int estimatedCompressedSize) {
+    super(
+        shuffleId,
+        partitionId,
+        blockId,
+        shuffleServerInfos,
+        uncompressLength,
+        freeMemory,
+        taskAttemptId,
+        partitionAssignmentRetrieveFunc);
+    this.rebuildFunction = rebuildFunction;
+    this.estimatedCompressedSize = estimatedCompressedSize;
+  }
+
+  public void reset(byte[] data, int length, long crc) {
+    super.length = length;
+    super.crc = crc;
+    super.data = Unpooled.wrappedBuffer(data);
+  }
+
+  private void initialize() {
+    if (!isInitialized) {
+      rebuildFunction.apply(this);
+      isInitialized = true;
+    }
+  }
+
+  public int getEstimatedLayoutSize() {
+    return estimatedCompressedSize + 3 * 8 + 2 * 4;
+  }
+
+  @Override
+  public int getLength() {
+    initialize();
+    return super.getLength();
+  }
+
+  @Override
+  public int getSize() {
+    initialize();
+    return super.getSize();
+  }
+
+  @Override
+  public long getCrc() {
+    initialize();
+    return super.getCrc();
+  }
+
+  @Override
+  public ByteBuf getData() {
+    initialize();
+    return super.getData();
+  }
+
+  @Override
+  public synchronized void copyDataTo(ByteBuf to) {
+    initialize();
+    super.copyDataTo(to);
+  }
+}
diff --git 
a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java 
b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
index a38e9d206..c429ea7a6 100644
--- a/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
+++ b/common/src/main/java/org/apache/uniffle/common/ShuffleBlockInfo.java
@@ -26,23 +26,21 @@ import io.netty.buffer.Unpooled;
 import org.apache.uniffle.common.util.ByteBufUtils;
 
 public class ShuffleBlockInfo {
-
   private int partitionId;
   private long blockId;
-  private int length;
   private int shuffleId;
-  private long crc;
   private long taskAttemptId;
-  private ByteBuf data;
   private List<ShuffleServerInfo> shuffleServerInfos;
   private int uncompressLength;
   private long freeMemory;
   private int retryCnt = 0;
-
   private transient BlockCompletionCallback completionCallback;
-
   private Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc;
 
+  protected int length;
+  protected long crc;
+  protected ByteBuf data;
+
   public ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
@@ -69,6 +67,25 @@ public class ShuffleBlockInfo {
     this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
   }
 
+  protected ShuffleBlockInfo(
+      int shuffleId,
+      int partitionId,
+      long blockId,
+      List<ShuffleServerInfo> shuffleServerInfos,
+      int uncompressLength,
+      long freeMemory,
+      long taskAttemptId,
+      Function<Integer, List<ShuffleServerInfo>> 
partitionAssignmentRetrieveFunc) {
+    this.shuffleId = shuffleId;
+    this.partitionId = partitionId;
+    this.blockId = blockId;
+    this.shuffleServerInfos = shuffleServerInfos;
+    this.uncompressLength = uncompressLength;
+    this.freeMemory = freeMemory;
+    this.taskAttemptId = taskAttemptId;
+    this.partitionAssignmentRetrieveFunc = partitionAssignmentRetrieveFunc;
+  }
+
   public ShuffleBlockInfo(
       int shuffleId,
       int partitionId,
diff --git 
a/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
 
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
new file mode 100644
index 000000000..5ec2e851d
--- /dev/null
+++ 
b/common/src/test/java/org/apache/uniffle/common/DeferredCompressedBlockTest.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.common;
+
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import org.junit.jupiter.api.Test;
+
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public class DeferredCompressedBlockTest {
+
+  @Test
+  public void testDeferredCompressedBlock() {
+    AtomicBoolean isInitialized = new AtomicBoolean(false);
+    DeferredCompressedBlock block =
+        new DeferredCompressedBlock(
+            1,
+            1,
+            1,
+            null,
+            0,
+            1,
+            1,
+            null,
+            deferredCompressedBlock -> {
+              isInitialized.set(true);
+              deferredCompressedBlock.reset(new byte[10], 10, 10);
+              return deferredCompressedBlock;
+            },
+            10);
+
+    // case1: some params accessing won't trigger initialization
+    block.getBlockId();
+    assertFalse(isInitialized.get());
+
+    // case2
+    block.getLength();
+    assertTrue(isInitialized.get());
+  }
+}
diff --git 
a/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
new file mode 100644
index 000000000..c534969ed
--- /dev/null
+++ 
b/integration-test/spark3/src/test/java/org/apache/uniffle/test/CompressionOverlappingTest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.uniffle.test;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Map;
+
+import com.google.common.collect.Maps;
+import org.apache.spark.SparkConf;
+import org.apache.spark.shuffle.RssSparkConfig;
+import org.junit.jupiter.api.BeforeAll;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.rpc.ServerType;
+import org.apache.uniffle.coordinator.CoordinatorConf;
+import org.apache.uniffle.server.ShuffleServerConf;
+import org.apache.uniffle.storage.util.StorageType;
+
+import static 
org.apache.spark.shuffle.RssSparkConfig.RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED;
+
+public class CompressionOverlappingTest extends SparkSQLTest {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(CompressionOverlappingTest.class);
+
+  @BeforeAll
+  public static void setupServers() throws Exception {
+    LOGGER.info("Setup servers");
+
+    // for coordinator
+    CoordinatorConf coordinatorConf = coordinatorConfWithoutPort();
+    coordinatorConf.setLong("rss.coordinator.app.expired", 5000);
+    Map<String, String> dynamicConf = Maps.newHashMap();
+    dynamicConf.put(RssSparkConfig.RSS_STORAGE_TYPE.key(), 
StorageType.MEMORY_LOCALFILE.name());
+    addDynamicConf(coordinatorConf, dynamicConf);
+    storeCoordinatorConf(coordinatorConf);
+
+    // starting 3 nodes with grpc
+    for (int i = 0; i < 3; i++) {
+      storeShuffleServerConf(buildShuffleServerConf(ServerType.GRPC, i));
+    }
+    // starting 3 nodes with grpc-netty
+    for (int i = 0; i < 3; i++) {
+      storeShuffleServerConf(buildShuffleServerConf(ServerType.GRPC_NETTY, i));
+    }
+    startServersWithRandomPorts();
+  }
+
+  private static ShuffleServerConf buildShuffleServerConf(ServerType 
serverType, int index)
+      throws IOException {
+    Path tempDir = Files.createTempDirectory(serverType + "-" + index);
+    String dataPath = tempDir.toAbsolutePath().toString();
+
+    ShuffleServerConf shuffleServerConf = shuffleServerConfWithoutPort(0, 
null, serverType);
+    shuffleServerConf.setLong("rss.server.heartbeat.interval", 5000);
+    shuffleServerConf.setLong("rss.server.app.expired.withoutHeartbeat", 4000);
+    shuffleServerConf.setString("rss.storage.basePath", dataPath);
+    shuffleServerConf.setString("rss.storage.type", 
StorageType.MEMORY_LOCALFILE.name());
+    return shuffleServerConf;
+  }
+
+  @Override
+  public void updateSparkConfCustomer(SparkConf sparkConf) {
+    sparkConf.set("spark.sql.shuffle.partitions", "4");
+    String overlappingOptionKey = 
RSS_WRITE_OVERLAPPING_COMPRESSION_ENABLED.key();
+    sparkConf.set("spark." + overlappingOptionKey, "true");
+  }
+
+  @Override
+  public void updateRssStorage(SparkConf sparkConf) {
+    // ignore
+  }
+
+  @Override
+  public void checkShuffleData() throws Exception {
+    // ignore
+  }
+}

Reply via email to