This is an automated email from the ASF dual-hosted git repository.

zuston pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/uniffle.git


The following commit(s) were added to refs/heads/master by this push:
     new dac1065fe [#2411] fix(spark): Spill memory corresponding to 
successfully sent blocks (#2415)
dac1065fe is described below

commit dac1065fe465faeae59c70a119b52a7f203a4810
Author: summaryzb <[email protected]>
AuthorDate: Fri Mar 21 13:55:48 2025 +0800

    [#2411] fix(spark): Spill memory corresponding to successfully sent blocks 
(#2415)
    
    ### What changes were proposed in this pull request?
    As title
    
    ### Why are the changes needed?
    Before this PR, the Spark client spilled more memory than it actually sent
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    UT
---
 .../java/org/apache/spark/shuffle/writer/DataPusher.java   | 14 ++++++++++----
 .../spark/shuffle/writer/WriteBufferManagerTest.java       | 10 ++++++++--
 2 files changed, 18 insertions(+), 6 deletions(-)

diff --git 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
index bdf0cf849..c55216d26 100644
--- 
a/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
+++ 
b/client-spark/common/src/main/java/org/apache/spark/shuffle/writer/DataPusher.java
@@ -100,10 +100,7 @@ public class DataPusher implements Closeable {
                 putFailedBlockSendTracker(
                     taskToFailedBlockSendTracker, taskId, 
result.getFailedBlockSendTracker());
               } finally {
-                Set<Long> succeedBlockIds =
-                    result.getSuccessBlockIds() == null
-                        ? Collections.emptySet()
-                        : result.getSuccessBlockIds();
+                Set<Long> succeedBlockIds = getSucceedBlockIds(result);
                 for (ShuffleBlockInfo block : shuffleBlockInfoList) {
                   
block.executeCompletionCallback(succeedBlockIds.contains(block.getBlockId()));
                 }
@@ -114,7 +111,9 @@ public class DataPusher implements Closeable {
                   runnable.run();
                 }
               }
+              Set<Long> succeedBlockIds = getSucceedBlockIds(result);
               return shuffleBlockInfoList.stream()
+                  .filter(x -> succeedBlockIds.contains(x.getBlockId()))
                   .map(x -> x.getFreeMemory())
                   .reduce((a, b) -> a + b)
                   .get();
@@ -127,6 +126,13 @@ public class DataPusher implements Closeable {
             });
   }
 
+  private Set<Long> getSucceedBlockIds(SendShuffleDataResult result) {
+    if (result == null || result.getSuccessBlockIds() == null) {
+      return Collections.emptySet();
+    }
+    return result.getSuccessBlockIds();
+  }
+
   private synchronized void putBlockId(
       Map<String, Set<Long>> taskToBlockIds, String taskAttemptId, Set<Long> 
blockIds) {
     if (blockIds == null || blockIds.isEmpty()) {
diff --git 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
index 97639d3c4..501b57e44 100644
--- 
a/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
+++ 
b/client-spark/common/src/test/java/org/apache/spark/shuffle/writer/WriteBufferManagerTest.java
@@ -487,7 +487,12 @@ public class WriteBufferManagerTest {
           List<AddBlockEvent> events = wbm.buildBlockEvents(blocks);
           for (AddBlockEvent event : events) {
             event.getProcessedCallbackChain().stream().forEach(x -> x.run());
-            sum += event.getShuffleDataInfoList().stream().mapToLong(x -> 
x.getFreeMemory()).sum();
+            // simulate: the block for partition 2 send failed
+            sum +=
+                event.getShuffleDataInfoList().stream()
+                    .filter(x -> x.getPartitionId() <= 1)
+                    .mapToLong(x -> x.getFreeMemory())
+                    .sum();
           }
           return Arrays.asList(CompletableFuture.completedFuture(sum));
         };
@@ -502,10 +507,11 @@ public class WriteBufferManagerTest {
     wbm.addRecord(1, testKey, testValue);
     wbm.addRecord(1, testKey, testValue);
     wbm.addRecord(1, testKey, testValue);
+    wbm.addRecord(2, testKey, testValue);
 
     long releasedSize = wbm.spill(1000, wbm);
     assertEquals(64, releasedSize);
-    assertEquals(96, wbm.getUsedBytes());
+    assertEquals(128, wbm.getUsedBytes());
     assertEquals(0, wbm.getBuffers().keySet().toArray()[0]);
   }
 

Reply via email to