This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new 7f1586e9c [#2579] fix(spark): Correct partition length for overlapping 
compression (#2580)
7f1586e9c is described below

commit 7f1586e9ca1935cac54cc3f9efd2a2e98582a2ec
Author: Junfan Zhang <[email protected]>
AuthorDate: Mon Aug 18 14:41:19 2025 +0800

    [#2579] fix(spark): Correct partition length for overlapping compression 
(#2580)
    
    ### What changes were proposed in this pull request?
    
    This PR fixes the partition length calculation for overlapping compression. 
Previously, when overlapping compression was enabled, the partition length was 
recorded as the uncompressed length, which broke the default Spark semantics.
    This change aligns the behavior with ESS semantics and updates the 
partition length only after the event has been successfully processed.
    
    ### Why are the changes needed?
    
    To fix the incorrect semantics of partition length for overlapping 
compression, which will affect the AQE rules.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No.
    
    ### How was this patch tested?
    
    Existing tests.
---
 .../shuffle/writer/PartitionLengthStatistic.java   | 56 ++++++++++++++++++
 .../spark/shuffle/writer/RssShuffleWriter.java     | 39 +++++--------
 .../writer/PartitionLengthStatisticTest.java       | 67 ++++++++++++++++++++++
 3 files changed, 138 insertions(+), 24 deletions(-)

diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/PartitionLengthStatistic.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/PartitionLengthStatistic.java
new file mode 100644
index 000000000..3ead6052b
--- /dev/null
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/PartitionLengthStatistic.java
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import java.util.Arrays;
+import java.util.concurrent.atomic.AtomicLong;
+
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+import org.apache.uniffle.common.exception.RssException;
+
+public class PartitionLengthStatistic {
+  private static final Logger LOGGER = 
LoggerFactory.getLogger(PartitionLengthStatistic.class);
+  private final AtomicLong[] partitionLens;
+
+  public PartitionLengthStatistic(int numPartitions) {
+    this.partitionLens = new AtomicLong[numPartitions];
+    for (int i = 0; i < numPartitions; i++) {
+      partitionLens[i] = new AtomicLong(0);
+    }
+  }
+
+  public void inc(ShuffleBlockInfo block) {
+    int partitionId = block.getPartitionId();
+    if (partitionId >= partitionLens.length) {
+      throw new RssException(
+          "Partition ID "
+              + partitionId
+              + " is out of bounds (should be less than "
+              + partitionLens.length
+              + ")");
+    }
+    partitionLens[block.getPartitionId()].addAndGet(block.getLength());
+  }
+
+  public long[] toArray() {
+    return Arrays.stream(this.partitionLens).mapToLong(x -> x.get()).toArray();
+  }
+}
diff --git 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
index b6776fa2f..382643c52 100644
--- 
a/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
+++ 
b/client-spark/spark3/src/main/java/org/apache/spark/shuffle/writer/RssShuffleWriter.java
@@ -79,7 +79,6 @@ import 
org.apache.uniffle.client.request.RssReportShuffleWriteMetricRequest;
 import 
org.apache.uniffle.client.response.RssReassignOnBlockSendFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteFailureResponse;
 import org.apache.uniffle.client.response.RssReportShuffleWriteMetricResponse;
-import org.apache.uniffle.common.DeferredCompressedBlock;
 import org.apache.uniffle.common.ReceivingFailureServer;
 import org.apache.uniffle.common.ShuffleBlockInfo;
 import org.apache.uniffle.common.ShuffleServerInfo;
@@ -118,7 +117,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
   private Map<ShuffleServerInfo, Map<Integer, Set<Long>>> 
serverToPartitionToBlockIds;
   private final ShuffleWriteClient shuffleWriteClient;
   private final Set<ShuffleServerInfo> shuffleServersForData;
-  private final long[] partitionLengths;
+  private final PartitionLengthStatistic partitionLengthStatistic;
   // Gluten needs this variable
   protected final boolean isMemoryShuffleEnabled;
   private final Function<String, Boolean> taskFailureCallback;
@@ -220,8 +219,7 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
     this.serverToPartitionToBlockIds = Maps.newHashMap();
     this.shuffleWriteClient = shuffleWriteClient;
     this.shuffleServersForData = shuffleHandleInfo.getServers();
-    this.partitionLengths = new long[partitioner.numPartitions()];
-    Arrays.fill(partitionLengths, 0);
+    this.partitionLengthStatistic = new 
PartitionLengthStatistic(partitioner.numPartitions());
     this.isMemoryShuffleEnabled =
         
isMemoryShuffleEnabled(sparkConf.get(RssSparkConfig.RSS_STORAGE_TYPE.key()));
     this.taskFailureCallback = taskFailureCallback;
@@ -478,7 +476,6 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                               shuffleServerInfo, k -> Maps.newHashMap());
                       pToBlockIds.computeIfAbsent(partitionId, v -> 
Sets.newHashSet()).add(blockId);
                     });
-            partitionLengths[partitionId] += getBlockLength(sbi);
           });
       return postBlockEvent(shuffleBlockInfoList);
     }
@@ -489,16 +486,17 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
       List<ShuffleBlockInfo> shuffleBlockInfoList) {
     List<CompletableFuture<Long>> futures = new ArrayList<>();
     for (AddBlockEvent event : 
bufferManager.buildBlockEvents(shuffleBlockInfoList)) {
-      if (blockFailSentRetryEnabled) {
-        // do nothing if failed.
-        for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
-          block.withCompletionCallback(
-              (completionBlock, isSuccessful) -> {
-                if (isSuccessful) {
-                  bufferManager.releaseBlockResource(completionBlock);
-                }
-              });
-        }
+      for (ShuffleBlockInfo block : event.getShuffleDataInfoList()) {
+        block.withCompletionCallback(
+            (b, isSuccessful) -> {
+              // If partition reassignment is enabled, the block is only 
released upon successful
+              // completion.
+              // Otherwise, the block is released immediately once completed.
+              if (!blockFailSentRetryEnabled || isSuccessful) {
+                bufferManager.releaseBlockResource(b);
+                partitionLengthStatistic.inc(b);
+              }
+            });
       }
       event.addCallback(
           () -> {
@@ -865,17 +863,9 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                     .get(s)
                     .get(block.getPartitionId())
                     .remove(block.getBlockId()));
-    partitionLengths[block.getPartitionId()] -= getBlockLength(block);
     blockIds.remove(block.getBlockId());
   }
 
-  private long getBlockLength(ShuffleBlockInfo block) {
-    if (block instanceof DeferredCompressedBlock) {
-      return block.getUncompressLength();
-    }
-    return block.getLength();
-  }
-
   @VisibleForTesting
   protected void sendCommit() {
     ExecutorService executor = Executors.newSingleThreadExecutor();
@@ -946,7 +936,8 @@ public class RssShuffleWriter<K, V, C> extends 
ShuffleWriter<K, V> {
                 DUMMY_HOST,
                 DUMMY_PORT,
                 Option.apply(Long.toString(taskAttemptId)));
-        MapStatus mapStatus = MapStatus.apply(blockManagerId, 
partitionLengths, taskAttemptId);
+        MapStatus mapStatus =
+            MapStatus.apply(blockManagerId, 
partitionLengthStatistic.toArray(), taskAttemptId);
         return Option.apply(mapStatus);
       } else {
         return Option.empty();
diff --git 
a/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/PartitionLengthStatisticTest.java
 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/PartitionLengthStatisticTest.java
new file mode 100644
index 000000000..66f3d58aa
--- /dev/null
+++ 
b/client-spark/spark3/src/test/java/org/apache/spark/shuffle/writer/PartitionLengthStatisticTest.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.writer;
+
+import io.netty.buffer.Unpooled;
+import org.junit.jupiter.api.Test;
+
+import org.apache.uniffle.common.ShuffleBlockInfo;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.fail;
+
+public class PartitionLengthStatisticTest {
+
+  @Test
+  public void test() {
+    int numPartitions = 10;
+    PartitionLengthStatistic statistic = new 
PartitionLengthStatistic(numPartitions);
+    // case1
+    assertEquals(numPartitions, statistic.toArray().length);
+    assertEquals(0, statistic.toArray()[0]);
+
+    // case2
+    ShuffleBlockInfo blockInfo =
+        new ShuffleBlockInfo(
+            1, 1, 1, 100, 123, Unpooled.wrappedBuffer(new byte[100]).retain(), 
null, 5, 0, 1);
+    statistic.inc(blockInfo);
+    assertEquals(numPartitions, statistic.toArray().length);
+    assertEquals(0, statistic.toArray()[0]);
+    assertEquals(100, statistic.toArray()[1]);
+
+    // case3
+    blockInfo =
+        new ShuffleBlockInfo(
+            1,
+            numPartitions,
+            1,
+            100,
+            123,
+            Unpooled.wrappedBuffer(new byte[100]).retain(),
+            null,
+            5,
+            0,
+            1);
+    try {
+      statistic.inc(blockInfo);
+      fail();
+    } catch (Exception e) {
+      // ignore
+    }
+  }
+}

Reply via email to