This is an automated email from the ASF dual-hosted git repository.

mridulm80 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 11d1c08e2 [CELEBORN-2160] Speedup CommitHandler.finishMapperAttempt
11d1c08e2 is described below

commit 11d1c08e2e2b43b5e4dacd9f597ab1a75fe7319c
Author: Mridul Muralidharan <mridulatgmail.com>
AuthorDate: Wed Oct 1 23:30:33 2025 -0500

    [CELEBORN-2160] Speedup CommitHandler.finishMapperAttempt
    
    ### What changes were proposed in this pull request?
    Speed up `finishMapperAttempt` by making its complexity constant instead of 
linear in the number of mappers.
    
    ### Why are the changes needed?
    As detailed in 
[CELEBORN-2160](https://issues.apache.org/jira/browse/CELEBORN-2160), when 
there are a large number of concurrent 'mapper' tasks for a stage with a large 
number of mapper partitions, the linear (in the number of mappers) complexity 
of `finishMapperAttempt` causes the RPC to time out and fail.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    Existing tests.
    
    Closes #3487 from mridulm/speedup-finish-mapper-attempts.
    
    Authored-by: Mridul Muralidharan <mridulatgmail.com>
    Signed-off-by: Mridul Muralidharan <mridul<at>gmail.com>
---
 .../org/apache/celeborn/client/ClientUtils.scala   | 19 ----------------
 .../org/apache/celeborn/client/CommitManager.scala |  4 ++++
 .../apache/celeborn/client/LifecycleManager.scala  |  2 +-
 .../celeborn/client/commit/CommitHandler.scala     |  2 ++
 .../client/commit/MapPartitionCommitHandler.scala  |  5 +++++
 .../commit/ReducePartitionCommitHandler.scala      | 26 ++++++++++++++++++++--
 .../tests/spark/memory/MemorySkewJoinSuite.scala   |  5 +++--
 7 files changed, 39 insertions(+), 24 deletions(-)

diff --git a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala 
b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
index b071eff3b..bf214ba15 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ClientUtils.scala
@@ -21,25 +21,6 @@ import org.apache.celeborn.common.CelebornConf
 
 object ClientUtils {
 
-  /**
-   * Check if all the mapper attempts are finished. If any of the attempts is 
not finished, return false.
-   * This method checks the attempts array in reverse order, which can be 
faster if the unfinished attempts
-   * are more likely to be towards the end of the array.
-   *
-   * @param attempts The mapper finished attemptId array. An attempt ID of -1 
indicates that the mapper is not finished.
-   * @return True if all mapper attempts are finished, false otherwise.
-   */
-  def areAllMapperAttemptsFinished(attempts: Array[Int]): Boolean = {
-    var i = attempts.length - 1
-    while (i >= 0) {
-      if (attempts(i) < 0) {
-        return false
-      }
-      i -= 1
-    }
-    true
-  }
-
   /**
    * If startMapIndex > endMapIndex, means partition is skew partition.
    * locations will split to sub-partitions with startMapIndex size.
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 1a71d3ccd..a4df22fb5 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -215,6 +215,10 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
     getCommitHandler(shuffleId).getMapperAttempts(shuffleId)
   }
 
+  def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+    getCommitHandler(shuffleId).areAllMapperAttemptsFinished(shuffleId)
+  }
+
   def finishMapperAttempt(
       shuffleId: Int,
       mapId: Int,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index a4f5c5b4c..bb008edab 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -997,7 +997,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
     }
 
     def areAllMapTasksEnd(shuffleId: Int): Boolean = {
-      
ClientUtils.areAllMapperAttemptsFinished(commitManager.getMapperAttempts(shuffleId))
+      commitManager.areAllMapperAttemptsFinished(shuffleId)
     }
 
     shuffleIds.synchronized {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 735ebb99c..b4c2e87a1 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -206,6 +206,8 @@ abstract class CommitHandler(
 
   def getMapperAttempts(shuffleId: Int): Array[Int]
 
+  def areAllMapperAttemptsFinished(shuffleId: Int): Boolean
+
   /**
    * return (thisMapperAttemptedFinishedSuccessOrNot, allMapperFinishedOrNot)
    */
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 490b9450a..dd05e7aaf 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -117,6 +117,11 @@ class MapPartitionCommitHandler(
     Array.empty
   }
 
+  override def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+    // see getMapperAttempts. !getMapperAttempts.exists(_ < 0) is always true
+    true
+  }
+
   override def removeExpiredShuffle(shuffleId: Int): Unit = {
     inProcessMapPartitionEndIds.remove(shuffleId)
     shuffleSucceedPartitionIds.remove(shuffleId)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 6d8ee3c2a..3c332c9a8 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -75,6 +75,8 @@ class ReducePartitionCommitHandler(
   private val stageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val inProcessStageEndShuffleSet = ConcurrentHashMap.newKeySet[Int]()
   private val shuffleMapperAttempts = JavaUtils.newConcurrentHashMap[Int, 
Array[Int]]()
+  // TODO: Move this to native Int -> Int Map
+  private val shuffleToCompletedMappers = JavaUtils.newConcurrentHashMap[Int, 
Int]()
   private val stageEndTimeout = conf.clientPushStageEndTimeout
   private val mockShuffleLost = conf.testMockShuffleLost
   private val mockShuffleLostShuffle = conf.testMockShuffleLostShuffle
@@ -161,6 +163,7 @@ class ReducePartitionCommitHandler(
     stageEndShuffleSet.remove(shuffleId)
     inProcessStageEndShuffleSet.remove(shuffleId)
     shuffleMapperAttempts.remove(shuffleId)
+    shuffleToCompletedMappers.remove(shuffleId)
     commitMetadataForReducer.remove(shuffleId)
     skewPartitionCompletenessValidator.remove(shuffleId)
     super.removeExpiredShuffle(shuffleId)
@@ -271,6 +274,21 @@ class ReducePartitionCommitHandler(
     shuffleMapperAttempts.get(shuffleId)
   }
 
+  override def areAllMapperAttemptsFinished(shuffleId: Int): Boolean = {
+    val attempts = shuffleMapperAttempts.get(shuffleId)
+    if (null != attempts) {
+      attempts.length == shuffleToCompletedMappers.get(shuffleId)
+    } else {
+      false
+    }
+  }
+
+  private val valueIncrementFunction = new function.BiFunction[Int, Int, 
Int]() {
+    override def apply(key: Int, value: Int): Int = {
+      value + 1
+    }
+  }
+
   override def finishMapperAttempt(
       shuffleId: Int,
       mapId: Int,
@@ -291,6 +309,9 @@ class ReducePartitionCommitHandler(
       val attempts = shuffleMapperAttempts.get(shuffleId)
       if (attempts(mapId) < 0) {
         attempts(mapId) = attemptId
+        // increment completed mappers
+        val completedMappers =
+          shuffleToCompletedMappers.compute(shuffleId, valueIncrementFunction)
 
         if (null != pushFailedBatches && !pushFailedBatches.isEmpty) {
           val pushFailedBatchesMap = shufflePushFailedBatches.computeIfAbsent(
@@ -304,7 +325,7 @@ class ReducePartitionCommitHandler(
           }
         }
         // Mapper with this attemptId finished, also check all other mapper 
finished or not.
-        (true, ClientUtils.areAllMapperAttemptsFinished(attempts))
+        (true, completedMappers == attempts.length)
       } else {
         // Mapper with another attemptId finished, skip this request
         (false, false)
@@ -391,8 +412,9 @@ class ReducePartitionCommitHandler(
     shuffleMapperAttempts.synchronized {
       if (!shuffleMapperAttempts.containsKey(shuffleId)) {
         val attempts = new Array[Int](numMappers)
-        0 until numMappers foreach (idx => attempts(idx) = -1)
+        util.Arrays.fill(attempts, -1)
         shuffleMapperAttempts.put(shuffleId, attempts)
+        shuffleToCompletedMappers.put(shuffleId, 0)
       }
       if (shuffleIntegrityCheckEnabled) {
         commitMetadataForReducer.put(shuffleId, Array.fill(numPartitions)(new 
CommitMetadata()))
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
index 9c777c1e3..03f6c7ff7 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/spark/memory/MemorySkewJoinSuite.scala
@@ -22,6 +22,7 @@ import java.io.File
 import scala.util.Random
 
 import org.apache.spark.SparkConf
+import org.apache.spark.network.util.JavaUtils
 import org.apache.spark.sql.{Row, SparkSession}
 import org.apache.spark.sql.internal.SQLConf
 import org.scalatest.BeforeAndAfterEach
@@ -100,8 +101,8 @@ class MemorySkewJoinSuite extends AnyFunSuite
       })
       .toDF("key", "fa", "fb", "fc", "fd")
     df2.createOrReplaceTempView("view2")
-    new File("./df1").delete()
-    new File("./df2").delete()
+    JavaUtils.deleteRecursively(new File("./df1"))
+    JavaUtils.deleteRecursively(new File("./df2"))
     df.write.parquet("./df1")
     df2.write.parquet("./df2")
     sparkSession.close()

Reply via email to