This is an automated email from the ASF dual-hosted git repository.
mridulm80 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 11d1c08e2 [CELEBORN-2160] Speedup CommitHandler.finishMapperAttempt
11d1c08e2 is described below
commit 11d1c08e2e2b43b5e4dacd9f597ab1a75fe7319c
Author: Mridul Muralidharan <mridul<at>gmail.com>
AuthorDate: Wed Oct 1 23:30:33 2025 -0500
[CELEBORN-2160] Speedup CommitHandler.finishMapperAttempt
### What changes were proposed in this pull request?
Speed up `finishMapperAttempt` by making its complexity constant instead of
linear in the number of mappers.
### Why are the changes needed?
As detailed in
[CELEBORN-2160](https://issues.apache.org/jira/browse/CELEBORN-2160), when
there are a large number of concurrent 'mapper' tasks for a stage with a large
number of mapper partitions, the linear (in the number of mappers) complexity of
`finishMapperAttempt` results in the RPC timing out and failing.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Existing tests.
Closes #3487 from mridulm/speedup-finish-mapper-attempts.
Authored-by: Mridul Muralidharan <mridul<at>gmail.com>
Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
.../org/apache/celeborn/client/ClientUtils.scala | 19 ----------------
.../org/apache/celeborn/client/CommitManager.scala | 4 ++++
.../apache/celeborn/client/LifecycleManager.scala | 2 +-
.../celeborn/client/commit/CommitHandler.scala | 2 ++
.../client/commit/MapPartitionCommitHandler.scala | 5 +++++
.../commit/ReducePartitionCommitHandler.scala | 26 ++++++++++++++++++++--
.../tests/spark/memory/MemorySkewJoinSuite.scala | 5 +++--
7 files changed, 39 insertions(+), 24 deletions(-)
diff --git a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
index b071eff3b..bf214ba15 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
@@ -21,25 +21,6 @@ import org.apache.celeborn.common.CelebornConf
object ClientUtils {
- /**
- * Check if all the mapper attempts are finished. If any of the attempts is
not finished, return false.
- * This method checks the attempts array in reverse order, which can be
faster if the unfinished attempts
- * are more likely to be towards the end of the array.
- *
- * @param attempts The mapper finished attemptId array. An attempt ID of -1
indicates that the mapper is not finished.
- * @return True if all mapper attempts are finished, false otherwise.
- */
- def areAllMapperAttemptsFinished(attempts: Array[Int]): Boolean = {
- var i = attempts.length - 1
- while (i >= 0) {
- if (attempts(i) < 0) {
- return false
- }
- i -= 1
- }
- true
- }
-
/**
* If startMapIndex > endMapIndex, means partition is skew partition.
* locations will split to sub-partitions with startMapIndex size.
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 1a71d3ccd..a4df22fb5 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -215,6 +215,10 @@ class CommitManager(appUniqueId: String, val conf:
CelebornConf, lifecycleManage
getCommitHandler(shuffleId).getMapperAttempts(shuffleId)
}
+ def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+ getCommitHandler(shuffleId).areAllMapperAttemptsFinished(shuffleId)
+ }
+
def finishMapperAttempt(
shuffleId: Int,
mapId: Int,
diff --git
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index a4f5c5b4c..bb008edab 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -997,7 +997,7 @@ class LifecycleManager(val appUniqueId: String, val conf:
CelebornConf) extends
}
def areAllMapTasksEnd(shuffleId: Int): Boolean = {
-
ClientUtils.areAllMapperAttemptsFinished(commitManager.getMapperAttempts(shuffleId))
+ commitManager.areAllMapperAttemptsFinished(shuffleId)
}
shuffleIds.synchronized {
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 735ebb99c..b4c2e87a1 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -206,6 +206,8 @@ abstract class CommitHandler(
def getMapperAttempts(shuffleId: Int): Array[Int]
+ def areAllMapperAttemptsFinished(shuffleId: Int): Boolean
+
/**
* return (thisMapperAttemptedFinishedSuccessOrNot, allMapperFinishedOrNot)
*/
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 490b9450a..dd05e7aaf 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -117,6 +117,11 @@ class MapPartitionCommitHandler(
Array.empty
}
+ override def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+ // see getMapperAttempts. !getMapperAttempts.exists(_ < -1) is always true
+ true
+ }
+
override def removeExpiredShuffle(shuffleId: Int): Unit = {
inProcessMapPartitionEndIds.remove(shuffleId)
shuffleSucceedPartitionIds.remove(shuffleId)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 6d8ee3c2a..3c332c9a8 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -75,6 +75,8 @@ class ReducePartitionCommitHandler(
private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int,
Array[Int]]()
+ // TODO: Move this to native Int -> Int Map
+ private val shuffleToCompletedMappers = JavaUtils.newConcurrentHashMap[Int,
Int]()
private val stageEndTimeout = conf.clientPushStageEndTimeout
private val mockShuffleLost = conf.testMockShuffleLost
private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle
@@ -161,6 +163,7 @@ class ReducePartitionCommitHandler(
stageEndShuffleSet.remove(shuffleId)
inProcessStageEndShuffleSet.remove(shuffleId)
shuffleMapperAttempts.remove(shuffleId)
+ shuffleToCompletedMappers.remove(shuffleId)
commitMetadataForReducer.remove(shuffleId)
skewPartitionCompletenessValidator.remove(shuffleId)
super.removeExpiredShuffle(shuffleId)
@@ -271,6 +274,21 @@ class ReducePartitionCommitHandler(
shuffleMapperAttempts.get(shuffleId)
}
+ override def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+ val attempts = shuffleMapperAttempts.get(shuffleId)
+ if (null != attempts) {
+ attempts.length == shuffleToCompletedMappers.get(shuffleId)
+ } else {
+ false
+ }
+ }
+
+ private val valueIncrementFunction = new function.BiFunction[Int, Int,
Int]() {
+ override def apply(key: Int, value: Int): Int = {
+ value + 1
+ }
+ }
+
override def finishMapperAttempt(
shuffleId: Int,
mapId: Int,
@@ -291,6 +309,9 @@ class ReducePartitionCommitHandler(
val attempts = shuffleMapperAttempts.get(shuffleId)
if (attempts(mapId) < 0) {
attempts(mapId) = attemptId
+ // increment completed mappers
+ val completedMappers =
+ shuffleToCompletedMappers.compute(shuffleId, valueIncrementFunction)
if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
@@ -304,7 +325,7 @@ class ReducePartitionCommitHandler(
}
}
// Mapper with this attemptId finished, also check all other mapper
finished or not.
- (true, ClientUtils.areAllMapperAttemptsFinished(attempts))
+ (true, completedMappers == attempts.length)
} else {
// Mapper with another attemptId finished, skip this request
(false, false)
@@ -391,8 +412,9 @@ class ReducePartitionCommitHandler(
shuffleMapperAttempts.synchronized {
if (!shuffleMapperAttempts.containsKey(shuffleId)) {
val attempts = new Array[Int](numMappers)
- 0 until numMappers foreach (idx => attempts(idx) = -1)
+ util.Arrays.fill(attempts, -1)
shuffleMapperAttempts.put(shuffleId, attempts)
+ shuffleToCompletedMappers.put(shuffleId, 0)
}
if (shuffleIntegrityCheckEnabled) {
commitMetadataForReducer.put(shuffleId, Array.fill(numPartitions)(new
CommitMetadata()))
diff --git
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
index 9c777c1e3..03f6c7ff7 100644
---
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
+++
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
import scala.util.Random
import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.internal.SQLConf
import org.scalatest.BeforeAndAfterEach
@@ -100,8 +101,8 @@ class MemorySkewJoinSuite extends AnyFunSuite
})
.toDF("key", "fa", "fb", "fc", "fd")
df2.createOrReplaceTempView("view2")
- new File("./df1").delete()
- new File("./df2").delete()
+ JavaUtils.deleteRecursively(new File("./df1"))
+ JavaUtils.deleteRecursively(new File("./df2"))
df.write.parquet("./df1")
df2.write.parquet("./df2")
sparkSession.close()