This is an automated email from the ASF dual-hosted git repository.
fchen pushed a commit to branch branch-0.4
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/branch-0.4 by this push:
new 94bee045c [CELEBORN-1544][0.4] ShuffleWriter needs to call close
finally to avoid memory leaks
94bee045c is described below
commit 94bee045c0b0085401dfb9034eacbbd2031b3425
Author: sychen <[email protected]>
AuthorDate: Mon Sep 2 13:50:25 2024 +0800
[CELEBORN-1544][0.4] ShuffleWriter needs to call close finally to avoid
memory leaks
Backport CELEBORN-1544 (https://github.com/apache/celeborn/pull/2661 and
https://github.com/apache/celeborn/pull/2663) to branch-0.4
### What changes were proposed in this pull request?
This PR aims to fix a possible memory leak in ShuffleWriter.
### Why are the changes needed?
When `spark.speculation=true` is enabled or the executing SQL is killed, the
task may be interrupted, in which case `ShuffleWriter` may never call close.
As a result, `DataPusher#idleQueue` keeps occupying memory (
`celeborn.client.push.buffer.max.size` * `celeborn.client.push.queue.capacity`
) and the instance is never released.
```java
Thread 537 (DataPusher-78931):
State: TIMED_WAITING
Blocked count: 0
Waited count: 16337
IsDaemon: true
Stack:
java.lang.Thread.sleep(Native Method)
org.apache.celeborn.client.write.DataPushQueue.takePushTasks(DataPushQueue.java:135)
org.apache.celeborn.client.write.DataPusher$1.run(DataPusher.java:122)
```
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Production testing
#### Current
<img width="547" alt="image"
src="https://github.com/user-attachments/assets/d6f64257-144e-4139-96c6-518ca5f1bfd2">
#### PR
<img width="479" alt="image"
src="https://github.com/user-attachments/assets/e4ff62ec-5b9d-47a4-a36c-1d13bf378cbc">
Closes #2718 from pan3793/CELEBORN-1544-0.4.
Authored-by: sychen <[email protected]>
Signed-off-by: Fu Chen <[email protected]>
---
.../spark/shuffle/celeborn/SortBasedPusher.java | 6 ++--
.../shuffle/celeborn/SortBasedPusherSuiteJ.java | 2 +-
.../shuffle/celeborn/HashBasedShuffleWriter.java | 15 +++++++++
.../shuffle/celeborn/SortBasedShuffleWriter.java | 36 +++++++++++++++-------
.../shuffle/celeborn/HashBasedShuffleWriter.java | 15 +++++++++
.../shuffle/celeborn/SortBasedShuffleWriter.java | 26 +++++++++++++---
6 files changed, 82 insertions(+), 18 deletions(-)
diff --git
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
index 3b051c3e7..93a9095a9 100644
---
a/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
+++
b/client-spark/common/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedPusher.java
@@ -399,13 +399,15 @@ public class SortBasedPusher extends MemoryConsumer {
taskContext.taskMetrics().incMemoryBytesSpilled(freedBytes);
}
- public void close() throws IOException {
+ public void close(boolean throwTaskKilledOnInterruption) throws IOException {
cleanupResources();
try {
dataPusher.waitOnTermination();
sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
} catch (InterruptedException e) {
- TaskInterruptedHelper.throwTaskKillException();
+ if (throwTaskKilledOnInterruption) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
}
}
diff --git
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
index 0962c98c4..73c15bb70 100644
---
a/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
+++
b/client-spark/common/src/test/java/org/apache/spark/shuffle/celeborn/SortBasedPusherSuiteJ.java
@@ -127,7 +127,7 @@ public class SortBasedPusherSuiteJ {
!pusher.insertRecord(
row5k.getBaseObject(), row5k.getBaseOffset(),
row5k.getSizeInBytes(), 0, true));
- pusher.close();
+ pusher.close(true);
assertEquals(taskContext.taskMetrics().memoryBytesSpilled(), 2097152);
}
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index 06d7ccc72..6db620b41 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -164,6 +164,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
@@ -177,8 +178,13 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
write0(records);
}
close();
+ needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
}
}
@@ -316,6 +322,15 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incWriteTime(System.nanoTime() - start);
}
+ private void cleanupPusher() throws IOException {
+ try {
+ dataPusher.waitOnTermination();
+ sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+ } catch (InterruptedException e) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
+ }
+
private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
dataPusher.waitOnTermination();
diff --git
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index a8bd23c21..58ee5dddd 100644
---
a/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-2/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -145,18 +145,26 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
- if (canUseFastWrite()) {
- fastWrite0(records);
- } else if (dep.mapSideCombine()) {
- if (dep.aggregator().isEmpty()) {
- throw new UnsupportedOperationException(
- "When using map side combine, an aggregator must be specified.");
+ boolean needCleanupPusher = true;
+ try {
+ if (canUseFastWrite()) {
+ fastWrite0(records);
+ } else if (dep.mapSideCombine()) {
+ if (dep.aggregator().isEmpty()) {
+ throw new UnsupportedOperationException(
+ "When using map side combine, an aggregator must be specified.");
+ }
+ write0(dep.aggregator().get().combineValuesByKey(records,
taskContext));
+ } else {
+ write0(records);
+ }
+ close();
+ needCleanupPusher = false;
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
}
- write0(dep.aggregator().get().combineValuesByKey(records, taskContext));
- } else {
- write0(records);
}
- close();
}
@VisibleForTesting
@@ -290,11 +298,17 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ if (pusher != null) {
+ pusher.close(false);
+ }
+ }
+
private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData();
- pusher.close();
+ pusher.close(true);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
index c3808b6f1..127ffb634 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/HashBasedShuffleWriter.java
@@ -161,6 +161,7 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
@Override
public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
try {
if (canUseFastWrite()) {
fastWrite0(records);
@@ -174,8 +175,13 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
write0(records);
}
close();
+ needCleanupPusher = false;
} catch (InterruptedException e) {
TaskInterruptedHelper.throwTaskKillException();
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
}
}
@@ -355,6 +361,15 @@ public class HashBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ try {
+ dataPusher.waitOnTermination();
+ sendBufferPool.returnPushTaskQueue(dataPusher.getIdleQueue());
+ } catch (InterruptedException e) {
+ TaskInterruptedHelper.throwTaskKillException();
+ }
+ }
+
private void close() throws IOException, InterruptedException {
// here we wait for all the in-flight batches to return which sent by
dataPusher thread
long pushMergedDataTime = System.nanoTime();
diff --git
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
index 7984f9ec8..95664ce63 100644
---
a/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
+++
b/client-spark/spark-3/src/main/java/org/apache/spark/shuffle/celeborn/SortBasedShuffleWriter.java
@@ -174,8 +174,7 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
return peakMemoryUsedBytes;
}
- @Override
- public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ void doWrite(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
if (canUseFastWrite()) {
fastWrite0(records);
} else if (dep.mapSideCombine()) {
@@ -187,7 +186,20 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
} else {
write0(records);
}
- close();
+ }
+
+ @Override
+ public void write(scala.collection.Iterator<Product2<K, V>> records) throws
IOException {
+ boolean needCleanupPusher = true;
+ try {
+ doWrite(records);
+ close();
+ needCleanupPusher = false;
+ } finally {
+ if (needCleanupPusher) {
+ cleanupPusher();
+ }
+ }
}
@VisibleForTesting
@@ -311,11 +323,17 @@ public class SortBasedShuffleWriter<K, V, C> extends
ShuffleWriter<K, V> {
writeMetrics.incBytesWritten(bytesWritten);
}
+ private void cleanupPusher() throws IOException {
+ if (pusher != null) {
+ pusher.close(false);
+ }
+ }
+
private void close() throws IOException {
logger.info("Memory used {}", Utils.bytesToString(pusher.getUsed()));
long pushStartTime = System.nanoTime();
pusher.pushData();
- pusher.close();
+ pusher.close(true);
shuffleClient.pushMergedData(shuffleId, mapId, encodedAttemptId);
writeMetrics.incWriteTime(System.nanoTime() - pushStartTime);