This is an automated email from the ASF dual-hosted git repository.

nicholasjiang pushed a commit to branch branch-0.6
in repository https://gitbox.apache.org/repos/asf/celeborn.git
commit a80904e4acb12591e1323a54f12897dd01827cce
Author: TheodoreLx <[email protected]>
AuthorDate: Thu Nov 20 11:19:37 2025 +0800

    [CELEBORN-2152] Support merge buffers on the worker side to improve memory utilization

    ### What changes were proposed in this pull request?

    Adds a configuration option that copies the body buffer of a push data message into a newly
    allocated buffer on the worker before writing. The new buffer is sized exactly to the payload,
    so its internal space utilization is 100%, which in turn significantly improves the overall
    NettyMemory utilization.

    ### Why are the changes needed?

    On the worker, Netty uses AdaptiveRecvByteBufAllocator to decide in advance how large a buffer
    to allocate when reading data from the socket. In certain network environments, however, the
    buffer size predicted and allocated by AdaptiveRecvByteBufAllocator can differ significantly
    from the amount of data actually read from the socket. A large buffer may then be allocated
    while only a small amount of data is read into it, ultimately leading to very low overall
    memory [...]

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    Cluster test.

    ### Performance Test

    <img width="1697" height="700" alt="image" src="https://github.com/user-attachments/assets/56495d08-6da7-4d43-8e8a-da87a33ccf90" />

    Closes #3479 from TheodoreLx/merge-push-buffer.

    Authored-by: TheodoreLx <[email protected]>
    Signed-off-by: SteNicholas <[email protected]>
    (cherry picked from commit cc0d1ba70a4dd923f4cf69de5ec25bbd23a87c97)
    Signed-off-by: SteNicholas <[email protected]>
---
 .../org/apache/celeborn/common/CelebornConf.scala | 19 +++++++++++
 docs/configuration/worker.md                      |  2 ++
 .../service/deploy/worker/PushDataHandler.scala   | 38 +++++++++++++++++++---
 3 files changed, 54 insertions(+), 5 deletions(-)

diff --git a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
index 3345792a5..bfd590c08 100644
--- a/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
+++ b/common/src/main/scala/org/apache/celeborn/common/CelebornConf.scala
@@ -1386,6 +1386,8 @@ class CelebornConf(loadDefaults: Boolean) extends Cloneable with Logging with Se
     get(WORKER_MEMORY_FILE_STORAGE_EVICT_AGGRESSIVE_MODE_ENABLED)
   def workerMemoryFileStorageEvictRatio: Double = get(WORKER_MEMORY_FILE_STORAGE_EVICT_RATIO)
+  def workerPushDataMergeBufferEnabled: Boolean = get(WORKER_PUSH_DATA_MERGE_BUFFER_ENABLED)
+  def workerDirectMemoryRatioToMergeBuffer: Double = get(WORKER_DIRECT_MEMORY_RATIO_TO_MERGE_BUFFER)

   // //////////////////////////////////////////////////////
   //               Rate Limit controller                  //
@@ -4132,6 +4134,23 @@ object CelebornConf extends Logging {
       .doubleConf
       .createWithDefault(0.5)

+  val WORKER_PUSH_DATA_MERGE_BUFFER_ENABLED: ConfigEntry[Boolean] =
+    buildConf("celeborn.worker.pushdata.mergeBuffer.enabled")
+      .categories("worker")
+      .version("0.6.2")
+      .doc("enable merge low utilization push data's body buffer before write")
+      .booleanConf
+      .createWithDefault(false)
+
+  val WORKER_DIRECT_MEMORY_RATIO_TO_MERGE_BUFFER: ConfigEntry[Double] = {
+    buildConf("celeborn.worker.directMemoryRatioToMergeBuffer")
+      .categories("worker")
+      .version("0.6.2")
+      .doc("If direct memory usage is above this limit, the worker will merge low utilization push data's body buffer")
+      .doubleConf
+      .createWithDefault(0.4)
+  }
+
   val WORKER_CONGESTION_CONTROL_ENABLED: ConfigEntry[Boolean] =
     buildConf("celeborn.worker.congestionControl.enabled")
       .categories("worker")
diff --git a/docs/configuration/worker.md b/docs/configuration/worker.md
index be6afdff8..d9db38d0b 100644
--- a/docs/configuration/worker.md
+++ b/docs/configuration/worker.md
@@ -78,6 +78,7 @@ license: |
 | celeborn.worker.decommission.forceExitTimeout | 6h | false | The wait time of waiting for all the shuffle expire during worker decommission. | 0.4.0 | |
 | celeborn.worker.directMemoryRatioForMemoryFileStorage | 0.0 | false | Max ratio of direct memory to store shuffle data. This feature is experimental and disabled by default. | 0.5.0 | |
 | celeborn.worker.directMemoryRatioForReadBuffer | 0.1 | false | Max ratio of direct memory for read buffer | 0.2.0 | |
+| celeborn.worker.directMemoryRatioToMergeBuffer | 0.4 | false | If direct memory usage is above this limit, the worker will merge low utilization push data's body buffer | 0.6.2 | |
 | celeborn.worker.directMemoryRatioToPauseReceive | 0.85 | false | If direct memory usage reaches this limit, the worker will stop to receive data from Celeborn shuffle clients. | 0.2.0 | |
 | celeborn.worker.directMemoryRatioToPauseReplicate | 0.95 | false | If direct memory usage reaches this limit, the worker will stop to receive replication data from other workers. This value should be higher than celeborn.worker.directMemoryRatioToPauseReceive. | 0.2.0 | |
 | celeborn.worker.directMemoryRatioToResume | 0.7 | false | If direct memory usage is less than this limit, worker will resume. | 0.2.0 | |
@@ -161,6 +162,7 @@ license: |
 | celeborn.worker.push.heartbeat.enabled | false | false | enable the heartbeat from worker to client when pushing data | 0.3.0 | |
 | celeborn.worker.push.io.threads | <undefined> | false | Netty IO thread number of worker to handle client push data. The default threads number is the number of flush thread. | 0.2.0 | |
 | celeborn.worker.push.port | 0 | false | Server port for Worker to receive push data request from ShuffleClient. | 0.2.0 | |
+| celeborn.worker.pushdata.mergeBuffer.enabled | false | false | enable merge low utilization push data's body buffer before write | 0.6.2 | |
 | celeborn.worker.readBuffer.allocationWait | 50ms | false | The time to wait when buffer dispatcher can not allocate a buffer. | 0.3.0 | |
 | celeborn.worker.readBuffer.processTimeout | 600s | false | Timeout for buffer dispatcher to process a read buffer request. | 0.6.2 | |
 | celeborn.worker.readBuffer.target.changeThreshold | 1mb | false | The target ratio for pre read memory usage. | 0.3.0 | |
diff --git a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
index 892996780..436cef2d5 100644
--- a/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
+++ b/worker/src/main/scala/org/apache/celeborn/service/deploy/worker/PushDataHandler.scala
@@ -45,6 +45,7 @@ import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.unsafe.Platform
 import org.apache.celeborn.common.util.{ExceptionUtils, Utils}
 import org.apache.celeborn.service.deploy.worker.congestcontrol.CongestionController
+import org.apache.celeborn.service.deploy.worker.memory.MemoryManager
 import org.apache.celeborn.service.deploy.worker.storage.{LocalFlusher, PartitionDataWriter, StorageManager}

 class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler with Logging {
@@ -64,6 +65,8 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
   private var storageManager: StorageManager = _
   private var workerPartitionSplitEnabled: Boolean = _
   private var workerReplicateRandomConnectionEnabled: Boolean = _
+  private var workerPushDataMergeBufferEnabled: Boolean = _
+  private var workerDirectMemoryRatioToMergeBuffer: Double = _
   private var testPushPrimaryDataTimeout: Boolean = _
   private var testPushReplicaDataTimeout: Boolean = _

@@ -83,7 +86,8 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
     shutdown = worker.shutdown
     workerPartitionSplitEnabled = worker.conf.workerPartitionSplitEnabled
     workerReplicateRandomConnectionEnabled = worker.conf.workerReplicateRandomConnectionEnabled
-
+    workerPushDataMergeBufferEnabled = worker.conf.workerPushDataMergeBufferEnabled
+    workerDirectMemoryRatioToMergeBuffer = worker.conf.workerDirectMemoryRatioToMergeBuffer
     testPushPrimaryDataTimeout = worker.conf.testPushPrimaryDataTimeout
     testPushReplicaDataTimeout = worker.conf.testPushReplicaDataTimeout
     registered = Some(worker.registered)
@@ -1492,6 +1496,26 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
       hardSplitIndexes: Array[Int] = Array.empty[Int]): Unit = {
     val length = fileWriters.length
     val result = new Array[StatusCode](length)
+
+    var finalBody: ByteBuf = body
+    var copyBody: ByteBuf = null
+    if (workerPushDataMergeBufferEnabled &&
+      MemoryManager.instance().workerMemoryUsageRatio() > workerDirectMemoryRatioToMergeBuffer) {
+      val numBytes = body.readableBytes()
+      try {
+        copyBody = body.alloc.directBuffer(numBytes)
+        // this method do not increase the readerIndex of source buffer, when oom
+        // happens, we can fall back to the original buffer
+        copyBody.writeBytes(body, body.readerIndex, numBytes)
+        finalBody = copyBody
+      } catch {
+        case e: OutOfMemoryError =>
+          logError(s"caught oom when consolidate data failed, size: $numBytes", e)
+        case e: Throwable =>
+          logError(s"consolidate data failed, size: $numBytes", e)
+      }
+    }
+
     def writeData(
         fileWriter: PartitionDataWriter,
         body: ByteBuf,
@@ -1539,14 +1563,14 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           } else {
             fileWriter = fileWriters(index)
             if (!writePromise.isCompleted) {
-              val offset = body.readerIndex() + batchOffsets(index)
+              val offset = finalBody.readerIndex() + batchOffsets(index)
               val length = if (index == fileWriters.length - 1) {
-                body.readableBytes() - batchOffsets(index)
+                finalBody.readableBytes() - batchOffsets(index)
               } else {
                 batchOffsets(index + 1) - batchOffsets(index)
               }
-              val batchBody = body.slice(offset, length)
+              val batchBody = finalBody.slice(offset, length)
               writeData(fileWriter, batchBody, shuffleKey, index)
             } else {
               fileWriter.decrementPendingWrites()
@@ -1555,12 +1579,16 @@ class PushDataHandler(val workerSource: WorkerSource) extends BaseMessageHandler
           index += 1
         }
       case _ =>
-        writeData(fileWriters.head, body, shuffleKey, 0)
+        writeData(fileWriters.head, finalBody, shuffleKey, 0)
     }
     if (!writePromise.isCompleted) {
       workerSource.incCounter(WorkerSource.WRITE_DATA_SUCCESS_COUNT)
       writePromise.success(result)
     }
+    // manually release copyBody to avoid memory leak
+    if (copyBody != null) {
+      copyBody.release()
+    }
   }

   private def nextValueOrElse(iterator: Iterator[Int], defaultValue: Int): Int = {
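
The handler change above boils down to a copy-on-pressure pattern over Netty ByteBufs. The following standalone Scala sketch illustrates the same idea outside of Celeborn: when direct memory pressure is above the configured ratio, the readable bytes are copied into a buffer sized exactly to `readableBytes()` so the over-sized receive buffer can be released early. `MergeBufferSketch`, `maybeCompact`, and `currentDirectMemoryRatio` are illustrative names, and the ratio check merely stands in for `MemoryManager.instance().workerMemoryUsageRatio()`; this is a sketch, not the worker's implementation.

```scala
import io.netty.buffer.{ByteBuf, PooledByteBufAllocator}

object MergeBufferSketch {

  // Hypothetical stand-in for MemoryManager.instance().workerMemoryUsageRatio();
  // fixed here only so the example runs on its own.
  def currentDirectMemoryRatio(): Double = 0.5

  // If merging is enabled and direct memory pressure exceeds the threshold,
  // copy the readable bytes into a buffer of exactly that size; otherwise
  // return the original buffer untouched.
  def maybeCompact(body: ByteBuf, mergeEnabled: Boolean, ratioToMerge: Double): ByteBuf = {
    if (!mergeEnabled || currentDirectMemoryRatio() <= ratioToMerge) {
      return body
    }
    val numBytes = body.readableBytes()
    try {
      val compact = body.alloc().directBuffer(numBytes)
      // writeBytes(src, srcIndex, length) does not advance body's readerIndex,
      // so the caller can still fall back to the original buffer on failure.
      compact.writeBytes(body, body.readerIndex(), numBytes)
      compact
    } catch {
      case _: OutOfMemoryError => body // fall back to the original, over-sized buffer
    }
  }

  def main(args: Array[String]): Unit = {
    val alloc = PooledByteBufAllocator.DEFAULT
    // Simulate an over-sized receive buffer: 64 KiB of capacity, 1 KiB actually read.
    val received = alloc.directBuffer(64 * 1024)
    received.writeBytes(Array.fill[Byte](1024)(1.toByte))

    val body = maybeCompact(received, mergeEnabled = true, ratioToMerge = 0.4)
    if (body ne received) {
      // The large receive buffer can go back to the pool right away.
      received.release()
    }
    println(s"capacity=${body.capacity()} readable=${body.readableBytes()}")
    body.release()
  }
}
```

The same trade-off the patch makes applies here: the copy costs one extra memcpy per push, in exchange for returning the over-allocated receive buffer to the pool sooner.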
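For operators, enabling the feature comes down to the two keys shown in the diff. A minimal `celeborn-defaults.conf` sketch for the worker follows; the threshold shown is simply the default from the diff, and the right value depends on your workload:

```properties
# Copy low-utilization push data bodies into exactly-sized buffers before writing (off by default).
celeborn.worker.pushdata.mergeBuffer.enabled true
# Only merge when worker direct memory usage is above this ratio.
celeborn.worker.directMemoryRatioToMergeBuffer 0.4
```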
