This is an automated email from the ASF dual-hosted git repository.

fchen pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/branch-0.4 by this push:
     new 94bee045c [CELEBORN-1544][0.4] ShuffleWriter needs to call close finally to avoid memory leaks
94bee045c is described below

commit 94bee045c0b0085401dfb9034eacbbd2031b3425
Author: sychen <[email protected]>
AuthorDate: Mon Sep 2 13:50:25 2024 +0800

    [CELEBORN-1544][0.4] ShuffleWriter needs to call close finally to avoid memory leaks
    
    Backport CELEBORN-1544 (https://github.com/apache/celeborn/pull/2661 and https://github.com/apache/celeborn/pull/2663) to branch-0.4
    
    ### What changes were proposed in this pull request?
    This PR aims to fix a possible memory leak in ShuffleWriter.
    
    ### Why are the changes needed?
    When `spark.speculation=true` is enabled, or a running SQL query is killed, the task may be interrupted and `ShuffleWriter` may never call close.
    In that case, `DataPusher#idleQueue` still occupies memory ( `celeborn.client.push.buffer.max.size` * `celeborn.client.push.queue.capacity` ) and the instance is never released.
    
    ```java
    Thread 537 (DataPusher-78931):
      State: TIMED_WAITING
      Blocked count: 0
      Waited count: 16337
      IsDaemon: true
      Stack:
        java.lang.Thread.sleep(Native Method)
        org.apache.celeborn.client.write.DataPushQueue.takePushTasks(DataPushQueue.java:135)
        org.apache.celeborn.client.write.DataPusher$1.run(DataPusher.java:122)
    ```
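    
    The fix below guards every `ShuffleWriter#write` path so that pusher resources are released even on an abnormal exit. The following is a condensed, illustrative sketch of that pattern only; the method bodies are placeholders, not the actual Celeborn implementation:
    
    ```java
    import java.io.IOException;
    
    // Illustrative sketch of the guard pattern applied by this change.
    // Names mirror the patch; bodies are stand-ins.
    final class GuardedWriteSketch {
    
      void write() throws IOException {
        boolean needCleanupPusher = true;
        try {
          doWrite();               // may throw, or the task may be interrupted
          close();                 // normal path: flush and release the pusher
          needCleanupPusher = false;
        } finally {
          if (needCleanupPusher) { // abnormal exit: still release the pusher,
            cleanupPusher();       // so DataPusher#idleQueue is not leaked
          }
        }
      }
    
      void doWrite() throws IOException { /* push the shuffle records */ }
    
      void close() throws IOException { /* wait for in-flight batches, return buffers */ }
    
      void cleanupPusher() throws IOException { /* return the push task queue to the pool */ }
    }
    ```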
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    Production testing
    
    #### Current
    <img width="547" alt="image" src="https://github.com/user-attachments/assets/d6f64257-144e-4139-96c6-518ca5f1bfd2">
    
    #### PR
    <img width="479" alt="image" src="https://github.com/user-attachments/assets/e4ff62ec-5b9d-47a4-a36c-1d13bf378cbc">
    
    Closes #2718 from pan3793/CELEBORN-1544-0.4.
    
    Authored-by: sychen <[email protected]>
    Signed-off-by: Fu Chen <[email protected]>
---
 .../spark/shuffle/celeborn/SortBasedPusher.java    |  6 ++--
 .../shuffle/celeborn/SortBasedPusherSuiteJ.java    |  2 +-
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 15 +++++++++
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 36 +++++++++++++++-------
 .../shuffle/celeborn/HashBasedShuffleWriter.java   | 15 +++++++++
 .../shuffle/celeborn/SortBasedShuffleWriter.java   | 26 +++++++++++++---
 6 files changed, 82 insertions(+), 18 deletions(-)

diff --git a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index 3b051c3e7..93a9095a9 100644
--- a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++ b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -399,13 +399,15 @@ public class SortBasedPusher extends MemoryConsumer {
     taskContext.taskMetrics().incMemoryBytesSpilled(freedBytes);
   }
 
-  public void close() throws IOException {
+  public void close(boolean throwTaskKilledOnInterruption) throws IOException {
     cleanupResources();
     try {
       dataPusher.waitOnTermination();
       sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
     } catch (InterruptedException e) {
-      TaskInterruptedHelper.throwTaskKillException();
+      if (throwTaskKilledOnInterruption) {
+        TaskInterruptedHelper.throwTaskKillException();
+      }
     }
   }
 
diff --git a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
index 0962c98c4..73c15bb70 100644
--- a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
+++ b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
@@ -127,7 +127,7 @@ public class SortBasedPusherSuiteJ {
         !pusher.insertRecord(
             row5k.getBaseObject(), row5k.getBaseOffset(), row5k.getSizeInBytes(), 0, true));
 
-    pusher.close();
+    pusher.close(true);
 
     assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
   }
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 06d7ccc72..6db620b41 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -164,6 +164,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean needCleanupPusher = true;
     try {
       if (canUseFastWrite()) {
         fastWrite0(records);
@@ -177,8 +178,13 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         write0(records);
       }
       close();
+      needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
     }
   }
 
@@ -316,6 +322,15 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     writeMetrics.incWriteTime(System.nanoTime() - start);
   }
 
+  private void cleanupPusher() throws IOException {
+    try {
+      dataPusher.waitOnTermination();
+      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    }
+  }
+
   private void close() throws IOException, InterruptedException {
    // here we wait for all the in-flight batches to return which sent by dataPusher thread
     dataPusher.waitOnTermination();
diff --git a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index a8bd23c21..58ee5dddd 100644
--- a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -145,18 +145,26 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
-    if (canUseFastWrite()) {
-      fastWrite0(records);
-    } else if (dep.mapSideCombine()) {
-      if (dep.aggregator().isEmpty()) {
-        throw new UnsupportedOperationException(
-            "When using map side combine, an aggregator must be specified.");
+    boolean needCleanupPusher = true;
+    try {
+      if (canUseFastWrite()) {
+        fastWrite0(records);
+      } else if (dep.mapSideCombine()) {
+        if (dep.aggregator().isEmpty()) {
+          throw new UnsupportedOperationException(
+              "When using map side combine, an aggregator must be specified.");
+        }
+        write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
+      } else {
+        write0(records);
+      }
+      close();
+      needCleanupPusher = false;
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
       }
-      write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
-    } else {
-      write0(records);
     }
-    close();
   }
 
   @VisibleForTesting
@@ -290,11 +298,17 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    if (pusher != null) {
+      pusher.close(false);
+    }
+  }
+
   private void close() throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData();
-    pusher.close();
+    pusher.close(true);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
 
     shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index c3808b6f1..127ffb634 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -161,6 +161,7 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
 
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean needCleanupPusher = true;
     try {
       if (canUseFastWrite()) {
         fastWrite0(records);
@@ -174,8 +175,13 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
         write0(records);
       }
       close();
+      needCleanupPusher = false;
     } catch (InterruptedException e) {
       TaskInterruptedHelper.throwTaskKillException();
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
     }
   }
 
@@ -355,6 +361,15 @@ public class HashBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    try {
+      dataPusher.waitOnTermination();
+      sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+    } catch (InterruptedException e) {
+      TaskInterruptedHelper.throwTaskKillException();
+    }
+  }
+
   private void close() throws IOException, InterruptedException {
    // here we wait for all the in-flight batches to return which sent by dataPusher thread
     long pushMergedDataTime = System.nanoTime();
diff --git a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 7984f9ec8..95664ce63 100644
--- a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++ b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -174,8 +174,7 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     return peakMemoryUsedBytes;
   }
 
-  @Override
-  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+  void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
     if (canUseFastWrite()) {
       fastWrite0(records);
     } else if (dep.mapSideCombine()) {
@@ -187,7 +186,20 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     } else {
       write0(records);
     }
-    close();
+  }
+
+  @Override
+  public void write(scala.collection.Iterator<Product2<K, V>> records) throws IOException {
+    boolean needCleanupPusher = true;
+    try {
+      doWrite(records);
+      close();
+      needCleanupPusher = false;
+    } finally {
+      if (needCleanupPusher) {
+        cleanupPusher();
+      }
+    }
   }
 
   @VisibleForTesting
@@ -311,11 +323,17 @@ public class SortBasedShuffleWriter<K, V, C> extends ShuffleWriter<K, V> {
     writeMetrics.incBytesWritten(bytesWritten);
   }
 
+  private void cleanupPusher() throws IOException {
+    if (pusher != null) {
+      pusher.close(false);
+    }
+  }
+
   private void close() throws IOException {
     logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
     long pushStartTime = System.nanoTime();
     pusher.pushData();
-    pusher.close();
+    pusher.close(true);
 
     shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
     writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
