This is an automated email from the ASF dual-hosted git repository.

rexxiong pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 30f7da556 [CELEBORN-1490][CIP-6] Enrich register shuffle method
30f7da556 is described below

commit 30f7da556fc5e08c927f746b8ae72b4ab44801ab
Author: Weijie Guo <[email protected]>
AuthorDate: Tue Sep 10 18:09:29 2024 +0800

    [CELEBORN-1490][CIP-6] Enrich register shuffle method
    
    ### What changes were proposed in this pull request?
    
    Enrich register shuffle related message to support segment-based shuffle.
    
    Note: This version of CIP-6 still only works in blocking mode, but we have 
extended related fields to give it the potential to read while writing. Any 
subsequent changes needed to support reading while writing are recorded in TODOs.
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    No need.
    
    Closes #2719 from reswqa/cip6-3-pr.
    
    Lead-authored-by: Weijie Guo <[email protected]>
    Co-authored-by: codenohup <[email protected]>
    Signed-off-by: Shuang <[email protected]>
---
 .../celeborn/client/ChangePartitionManager.scala   | 17 +++++++---
 .../org/apache/celeborn/client/CommitManager.scala | 14 ++++++--
 .../apache/celeborn/client/LifecycleManager.scala  | 37 +++++++++++++++-------
 .../celeborn/client/commit/CommitHandler.scala     | 17 ++++++++--
 .../client/commit/MapPartitionCommitHandler.scala  | 23 +++++++++++++-
 .../commit/ReducePartitionCommitHandler.scala      |  7 ++--
 .../client/LifecycleManagerCommitFilesSuite.scala  |  4 +--
 7 files changed, 93 insertions(+), 26 deletions(-)

diff --git 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index c08c967c3..46628aaab 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -103,7 +103,8 @@ class ChangePartitionManager(
                       if (distinctPartitions.nonEmpty) {
                         handleRequestPartitions(
                           shuffleId,
-                          distinctPartitions)
+                          distinctPartitions,
+                          
lifecycleManager.commitManager.isSegmentGranularityVisible(shuffleId))
                       }
                     }
                   }
@@ -153,7 +154,8 @@ class ChangePartitionManager(
       partitionId: Int,
       oldEpoch: Int,
       oldPartition: PartitionLocation,
-      cause: Option[StatusCode] = None): Unit = {
+      cause: Option[StatusCode] = None,
+      isSegmentGranularityVisible: Boolean): Unit = {
 
     val changePartition = ChangePartitionRequest(
       context,
@@ -195,7 +197,7 @@ class ChangePartitionManager(
       }
     }
     if (!batchHandleChangePartitionEnabled) {
-      handleRequestPartitions(shuffleId, Array(changePartition))
+      handleRequestPartitions(shuffleId, Array(changePartition), 
isSegmentGranularityVisible)
     }
   }
 
@@ -215,7 +217,8 @@ class ChangePartitionManager(
 
   def handleRequestPartitions(
       shuffleId: Int,
-      changePartitions: Array[ChangePartitionRequest]): Unit = {
+      changePartitions: Array[ChangePartitionRequest],
+      isSegmentGranularityVisible: Boolean): Unit = {
     val requestsMap = changePartitionRequests.get(shuffleId)
 
     val changes = changePartitions.map { change =>
@@ -296,7 +299,8 @@ class ChangePartitionManager(
     if (!lifecycleManager.reserveSlotsWithRetry(
         shuffleId,
         new util.HashSet(candidates.toSet.asJava),
-        newlyAllocatedLocations)) {
+        newlyAllocatedLocations,
+        isSegmentGranularityVisible = isSegmentGranularityVisible)) {
       logError(s"[Update partition] failed for $shuffleId.")
       replyFailure(StatusCode.RESERVE_SLOTS_FAILED)
       return
@@ -324,6 +328,9 @@ class ChangePartitionManager(
               s"shuffle $shuffleId, succeed partitions: " +
               s"$changes.")
           }
+
+          // TODO: should record the new partition locations and acknowledge 
the new partitionLocations to downstream task,
+          //  in scenario the downstream task start early before the upstream 
task.
           locations
       }
     replySuccess(newPrimaryLocations.toArray)
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 93178c047..601ebf4de 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -174,7 +174,10 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
     batchHandleCommitPartitionSchedulerThread.foreach(ThreadUtils.shutdown(_))
   }
 
-  def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
+  def registerShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      isSegmentGranularityVisible: Boolean): Unit = {
     committedPartitionInfo.put(
       shuffleId,
       ShuffleCommittedInfo(
@@ -191,7 +194,14 @@ class CommitManager(appUniqueId: String, val conf: 
CelebornConf, lifecycleManage
         new AtomicInteger(),
         JavaUtils.newConcurrentHashMap[Int, AtomicInteger]()))
 
-    getCommitHandler(shuffleId).registerShuffle(shuffleId, numMappers)
+    getCommitHandler(shuffleId).registerShuffle(
+      shuffleId,
+      numMappers,
+      isSegmentGranularityVisible);
+  }
+
+  def isSegmentGranularityVisible(shuffleId: Int): Boolean = {
+    getCommitHandler(shuffleId).isSegmentGranularityVisible(shuffleId);
   }
 
   def isMapperEnded(shuffleId: Int, mapId: Int): Boolean = {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala 
b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
index 7c22b0234..f1324f9cb 100644
--- a/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/LifecycleManager.scala
@@ -325,15 +325,17 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       val mapId = pb.getMapId
       val attemptId = pb.getAttemptId
       val partitionId = pb.getPartitionId
+      val isSegmentGranularityVisible = pb.getIsSegmentGranularityVisible
       logDebug(s"Received Register map partition task request, " +
-        s"$shuffleId, $numMappers, $mapId, $attemptId, $partitionId.")
+        s"$shuffleId, $numMappers, $mapId, $attemptId, $partitionId, 
$isSegmentGranularityVisible.")
       shufflePartitionType.putIfAbsent(shuffleId, PartitionType.MAP)
       offerAndReserveSlots(
         RegisterCallContext(context, partitionId),
         shuffleId,
         numMappers,
         numMappers,
-        partitionId)
+        partitionId,
+        isSegmentGranularityVisible)
 
     case pb: PbRevive =>
       val shuffleId = pb.getShuffleId
@@ -377,7 +379,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         shuffleId,
         partitionId,
         epoch,
-        oldPartition)
+        oldPartition,
+        isSegmentGranularityVisible = 
commitManager.isSegmentGranularityVisible(shuffleId))
 
     case MapperEnd(shuffleId, mapId, attemptId, numMappers, partitionId) =>
       logTrace(s"Received MapperEnd TaskEnd request, " +
@@ -496,7 +499,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       shuffleId: Int,
       numMappers: Int,
       numPartitions: Int,
-      partitionId: Int = -1): Unit = {
+      partitionId: Int = -1,
+      isSegmentGranularityVisible: Boolean = false): Unit = {
     val partitionType = getPartitionType(shuffleId)
     registeringShuffleRequest.synchronized {
       if (registeringShuffleRequest.containsKey(shuffleId)) {
@@ -575,7 +579,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
           shuffleId,
           partitionId,
           -1,
-          null)
+          null,
+          isSegmentGranularityVisible = 
commitManager.isSegmentGranularityVisible(shuffleId))
       }
     }
 
@@ -681,7 +686,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         shuffleId,
         candidatesWorkers,
         slots,
-        updateEpoch = false)
+        updateEpoch = false,
+        isSegmentGranularityVisible)
 
     // If reserve slots failed, clear allocated resources, reply 
ReserveSlotFailed and return.
     if (!reserveSlotsSuccess) {
@@ -703,7 +709,10 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       }
       shuffleAllocatedWorkers.put(shuffleId, allocatedWorkers)
       registeredShuffle.add(shuffleId)
-      commitManager.registerShuffle(shuffleId, numMappers)
+      commitManager.registerShuffle(
+        shuffleId,
+        numMappers,
+        isSegmentGranularityVisible)
 
       // Fifth, reply the allocated partition location to ShuffleClient.
       logInfo(s"Handle RegisterShuffle Success for $shuffleId.")
@@ -761,7 +770,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
         partitionIds.get(idx),
         oldEpochs.get(idx),
         oldPartitions.get(idx),
-        Some(causes.get(idx)))
+        Some(causes.get(idx)),
+        commitManager.isSegmentGranularityVisible(shuffleId))
     }
   }
 
@@ -1083,7 +1093,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
    */
   private def reserveSlots(
       shuffleId: Int,
-      slots: WorkerResource): util.List[WorkerInfo] = {
+      slots: WorkerResource,
+      isSegmentGranularityVisible: Boolean = false): util.List[WorkerInfo] = {
     val reserveSlotFailedWorkers = new ShuffleFailedWorkers()
     val failureInfos = new util.concurrent.CopyOnWriteArrayList[String]()
     val workerPartitionLocations = slots.asScala.filter(p => !p._2._1.isEmpty 
|| !p._2._2.isEmpty)
@@ -1106,7 +1117,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
             conf.pushDataTimeoutMs,
             if (getPartitionType(shuffleId) == PartitionType.MAP)
               conf.clientShuffleMapPartitionSplitEnabled
-            else true))
+            else true,
+            isSegmentGranularityVisible))
         futures.add((future, workerInfo))
       }(ec)
     }
@@ -1303,7 +1315,8 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       shuffleId: Int,
       candidates: util.HashSet[WorkerInfo],
       slots: WorkerResource,
-      updateEpoch: Boolean = true): Boolean = {
+      updateEpoch: Boolean = true,
+      isSegmentGranularityVisible: Boolean = false): Boolean = {
     var requestSlots = slots
     val reserveSlotsMaxRetries = conf.clientReserveSlotsMaxRetries
     val reserveSlotsRetryWait = conf.clientReserveSlotsRetryWait
@@ -1316,7 +1329,7 @@ class LifecycleManager(val appUniqueId: String, val conf: 
CelebornConf) extends
       }
       // reserve buffers
       logInfo(s"Try reserve slots for $shuffleId for $retryTimes times.")
-      val reserveFailedWorkers = reserveSlots(shuffleId, requestSlots)
+      val reserveFailedWorkers = reserveSlots(shuffleId, requestSlots, 
isSegmentGranularityVisible)
       if (reserveFailedWorkers.isEmpty) {
         success = true
       } else {
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index 8cc81972c..f03b82222 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -199,10 +199,20 @@ abstract class CommitHandler(
       partitionId: Int,
       recordWorkerFailure: ShuffleFailedWorkers => Unit): (Boolean, Boolean)
 
-  def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
+  def registerShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      isSegmentGranularityVisible: Boolean): Unit = {
+    // TODO: if isSegmentGranularityVisible is set to true, it is necessary to 
handle the pending
+    //  get partition request of downstream reduce task here, in scenarios 
which support
+    //  downstream task start early before the upstream task, e.g. flink 
hybrid shuffle.
     reducerFileGroupsMap.put(shuffleId, JavaUtils.newConcurrentHashMap())
   }
 
+  def isSegmentGranularityVisible(shuffleId: Int): Boolean = {
+    false
+  }
+
   def doParallelCommitFiles(
       shuffleId: Int,
       shuffleCommittedInfo: ShuffleCommittedInfo,
@@ -463,7 +473,8 @@ abstract class CommitHandler(
       primaryPartitionUniqueIds: util.Iterator[String],
       replicaPartitionUniqueIds: util.Iterator[String],
       primaryPartMap: ConcurrentHashMap[String, PartitionLocation],
-      replicaPartMap: ConcurrentHashMap[String, PartitionLocation]): Unit = {
+      replicaPartMap: ConcurrentHashMap[String, PartitionLocation],
+      isSegmentGranularityVisible: Boolean = false): Unit = {
     val committedPartitions = new util.HashMap[String, PartitionLocation]
     primaryPartitionUniqueIds.asScala.foreach { id =>
       val partitionLocation = primaryPartMap.get(id)
@@ -488,6 +499,8 @@ abstract class CommitHandler(
       }
     }
 
+    // TODO: if support upstream task write and downstream task read 
simultaneously,
+    //  should record the partition locations information in upstream task 
start time, rather than end time.
     committedPartitions.values().asScala.foreach { partition =>
       val partitionLocations = 
reducerFileGroupsMap.get(shuffleId).computeIfAbsent(
         partition.getId,
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index b799d0870..deb402d38 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -62,6 +62,9 @@ class MapPartitionCommitHandler(
   // shuffleId -> in processing partitionId set
   private val inProcessMapPartitionEndIds = 
JavaUtils.newConcurrentHashMap[Int, util.Set[Integer]]()
 
+  // shuffleId -> boolean, records whether the shuffle is visible at the 
segment level, facilitating future optimization of worker read and write 
processes
+  private val shuffleIsSegmentGranularityVisible = 
JavaUtils.newConcurrentHashMap[Int, Boolean]
+
   override def getPartitionType(): PartitionType = {
     PartitionType.MAP
   }
@@ -113,6 +116,7 @@ class MapPartitionCommitHandler(
   override def removeExpiredShuffle(shuffleId: Int): Unit = {
     inProcessMapPartitionEndIds.remove(shuffleId)
     shuffleSucceedPartitionIds.remove(shuffleId)
+    shuffleIsSegmentGranularityVisible.remove(shuffleId)
     super.removeExpiredShuffle(shuffleId)
   }
 
@@ -143,7 +147,8 @@ class MapPartitionCommitHandler(
         getPartitionUniqueIds(shuffleCommittedInfo.committedPrimaryIds, 
partitionId),
         getPartitionUniqueIds(shuffleCommittedInfo.committedReplicaIds, 
partitionId),
         parallelCommitResult.primaryPartitionLocationMap,
-        parallelCommitResult.replicaPartitionLocationMap)
+        parallelCommitResult.replicaPartitionLocationMap,
+        shuffleIsSegmentGranularityVisible.get(shuffleId))
     }
 
     (dataLost, parallelCommitResult.commitFilesFailedWorkers)
@@ -211,7 +216,23 @@ class MapPartitionCommitHandler(
     (dataCommitSuccess, false)
   }
 
+  override def registerShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      isSegmentGranularityVisible: Boolean): Unit = {
+    super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
+    shuffleIsSegmentGranularityVisible.put(shuffleId, 
isSegmentGranularityVisible)
+  }
+
+  override def isSegmentGranularityVisible(shuffleId: Int): Boolean = {
+    shuffleIsSegmentGranularityVisible.get(shuffleId)
+  }
+
   override def handleGetReducerFileGroup(context: RpcCallContext, shuffleId: 
Int): Unit = {
+    // TODO: if support the downstream map task start early before the 
upstream reduce task, it should
+    //  waiting the upstream task register shuffle, then reply these 
GetReducerFileGroup.
+    //  Note that flink hybrid shuffle should support it in the future.
+
     // we need obtain the last succeed partitionIds
     val lastSucceedPartitionIds =
       shuffleSucceedPartitionIds.getOrDefault(shuffleId, new 
util.HashSet[Integer]())
diff --git 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 23d6a7b8d..69b84058b 100644
--- 
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ 
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -253,8 +253,11 @@ class ReducePartitionCommitHandler(
     }
   }
 
-  override def registerShuffle(shuffleId: Int, numMappers: Int): Unit = {
-    super.registerShuffle(shuffleId, numMappers)
+  override def registerShuffle(
+      shuffleId: Int,
+      numMappers: Int,
+      isSegmentGranularityVisible: Boolean): Unit = {
+    super.registerShuffle(shuffleId, numMappers, isSegmentGranularityVisible)
     getReducerFileGroupRequest.put(shuffleId, new 
util.HashSet[RpcCallContext]())
     initMapperAttempts(shuffleId, numMappers)
   }
diff --git 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
index 8bdc4beeb..90192e283 100644
--- 
a/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
+++ 
b/tests/spark-it/src/test/scala/org/apache/celeborn/tests/client/LifecycleManagerCommitFilesSuite.scala
@@ -66,7 +66,7 @@ class LifecycleManagerCommitFilesSuite extends 
WithShuffleClientSuite with MiniC
       res.workerResource,
       updateEpoch = false)
 
-    lifecycleManager.commitManager.registerShuffle(shuffleId, 1)
+    lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
     0 until 10 foreach { partitionId =>
       lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, 
partitionId)
     }
@@ -116,7 +116,7 @@ class LifecycleManagerCommitFilesSuite extends 
WithShuffleClientSuite with MiniC
       res.workerResource,
       updateEpoch = false)
 
-    lifecycleManager.commitManager.registerShuffle(shuffleId, 1)
+    lifecycleManager.commitManager.registerShuffle(shuffleId, 1, false)
     0 until 10 foreach { partitionId =>
       lifecycleManager.commitManager.finishMapperAttempt(shuffleId, 0, 0, 1, 
partitionId)
     }

Reply via email to