This is an automated email from the ASF dual-hosted git repository.

zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git


The following commit(s) were added to refs/heads/main by this push:
     new 13769f0f [CELEBORN-121] Refactor batchHandleCommitPartition (#1089)
13769f0f is described below

commit 13769f0f0a2401aacc66162a2f6006816a175ca6
Author: Shuang <[email protected]>
AuthorDate: Mon Dec 19 12:39:39 2022 +0800

    [CELEBORN-121] Refactor batchHandleCommitPartition (#1089)
---
 .../celeborn/client/ChangePartitionManager.scala   |   1 -
 .../org/apache/celeborn/client/CommitManager.scala | 122 +++------------------
 .../celeborn/client/commit/CommitHandler.scala     |  60 +++++++++-
 .../client/commit/MapPartitionCommitHandler.scala  |  40 ++++++-
 .../commit/ReducePartitionCommitHandler.scala      |  11 +-
 5 files changed, 121 insertions(+), 113 deletions(-)

diff --git a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 960ac1e6..0547692c 100644
--- a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -150,7 +150,6 @@ class ChangePartitionManager(
     inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
 
     lifecycleManager.commitManager.registerCommitPartitionRequest(
-      applicationId,
       shuffleId,
       oldPartition,
       cause)
diff --git a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 21551871..d88389d1 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -38,11 +38,6 @@ import org.apache.celeborn.common.protocol.message.StatusCode
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.ThreadUtils
 
-case class CommitPartitionRequest(
-    applicationId: String,
-    shuffleId: Int,
-    partition: PartitionLocation)
-
 case class ShuffleCommittedInfo(
     committedMasterIds: ConcurrentHashMap[Int, util.List[String]],
     committedSlaveIds: ConcurrentHashMap[Int, util.List[String]],
@@ -52,8 +47,8 @@ case class ShuffleCommittedInfo(
     committedSlaveStorageInfos: ConcurrentHashMap[String, StorageInfo],
     committedMapIdBitmap: ConcurrentHashMap[String, RoaringBitmap],
     currentShuffleFileCount: LongAdder,
-    commitPartitionRequests: util.Set[CommitPartitionRequest],
-    handledCommitPartitionRequests: util.Set[PartitionLocation],
+    unHandledPartitionLocations: util.Set[PartitionLocation],
+    handledPartitionLocations: util.Set[PartitionLocation],
     allInFlightCommitRequestNum: AtomicInteger,
     partitionInFlightCommitRequestNum: ConcurrentHashMap[Int, AtomicInteger])
 
@@ -91,104 +86,20 @@ class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: Lif
             committedPartitionInfo.asScala.foreach { case (shuffleId, 
shuffleCommittedInfo) =>
               batchHandleCommitPartitionExecutors.submit {
                 new Runnable {
-                  val partitionType = 
lifecycleManager.getPartitionType(shuffleId)
                   val commitHandler = getCommitHandler(shuffleId)
-                  def incrementInflightNum(workerToRequests: Map[
-                    WorkerInfo,
-                    collection.Set[PartitionLocation]]): Unit = {
-                    if (partitionType == PartitionType.MAP) {
-                      workerToRequests.foreach {
-                        case (_, partitions) =>
-                          partitions.groupBy(_.getId).foreach { case (id, _) =>
-                            val atomicInteger = shuffleCommittedInfo
-                              .partitionInFlightCommitRequestNum
-                              .computeIfAbsent(id, (k: Int) => new 
AtomicInteger(0))
-                            atomicInteger.incrementAndGet()
-                          }
-                      }
-                    }
-                    shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
-                      workerToRequests.size)
-                  }
-
-                  def decrementInflightNum(
-                      workerToRequests: Map[WorkerInfo, 
collection.Set[PartitionLocation]])
-                      : Unit = {
-                    if (partitionType == PartitionType.MAP) {
-                      workerToRequests.foreach {
-                        case (_, partitions) =>
-                          partitions.groupBy(_.getId).foreach { case (id, _) =>
-                            
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
-                              id).decrementAndGet()
-                          }
-                      }
-                    }
-                    shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
-                      -workerToRequests.size)
-                  }
-
-                  def getUnCommitPartitionRequests(
-                      commitPartitionRequests: 
util.Set[CommitPartitionRequest])
-                      : scala.collection.mutable.Set[CommitPartitionRequest] = 
{
-                    if (partitionType == PartitionType.MAP) {
-                      commitPartitionRequests.asScala.filterNot { request =>
-                        shuffleCommittedInfo.handledCommitPartitionRequests
-                          .contains(request.partition) && 
commitHandler.isPartitionInProcess(
-                          shuffleId,
-                          request.partition.getId)
-                      }
-                    } else {
-                      commitPartitionRequests.asScala.filterNot { request =>
-                        shuffleCommittedInfo.handledCommitPartitionRequests
-                          .contains(request.partition)
-                      }
-                    }
-                  }
-
                   override def run(): Unit = {
-                    val workerToRequests = shuffleCommittedInfo.synchronized {
-                      // When running to here, if handleStageEnd got lock 
first and commitFiles,
-                      // then this batch get this lock, 
commitPartitionRequests may contains
-                      // partitions which are already committed by stageEnd 
process.
-                      // But inProcessStageEndShuffleSet should have contain 
this shuffle id,
-                      // can directly return.
-                      if (commitHandler.isStageEndOrInProcess(shuffleId)) {
-                        logWarning(s"Shuffle $shuffleId ended or during 
processing stage end.")
-                        shuffleCommittedInfo.commitPartitionRequests.clear()
-                        Map.empty[WorkerInfo, Set[PartitionLocation]]
-                      } else {
-                        val currentBatch =
-                          
getUnCommitPartitionRequests(shuffleCommittedInfo.commitPartitionRequests)
-                        shuffleCommittedInfo.commitPartitionRequests.clear()
-                        currentBatch.foreach { commitPartitionRequest =>
-                          shuffleCommittedInfo.handledCommitPartitionRequests
-                            .add(commitPartitionRequest.partition)
-                          if (commitPartitionRequest.partition.getPeer != 
null) {
-                            shuffleCommittedInfo.handledCommitPartitionRequests
-                              .add(commitPartitionRequest.partition.getPeer)
-                          }
-                        }
-
-                        if (currentBatch.nonEmpty) {
-                          logWarning(s"Commit current batch HARD_SPLIT 
partitions for $shuffleId: " +
-                            
s"${currentBatch.map(_.partition.getUniqueId).mkString("[", ",", "]")}")
-                          val workerToRequests = currentBatch.flatMap { 
request =>
-                            if (request.partition.getPeer != null) {
-                              Seq(request.partition, request.partition.getPeer)
-                            } else {
-                              Seq(request.partition)
-                            }
-                          }.groupBy(_.getWorker)
-                          incrementInflightNum(workerToRequests)
-                          workerToRequests
-                        } else {
-                          Map.empty[WorkerInfo, Set[PartitionLocation]]
-                        }
-                      }
+                    var workerToRequests: Map[WorkerInfo, 
collection.Set[PartitionLocation]] = null
+                    shuffleCommittedInfo.synchronized {
+                      workerToRequests =
+                        commitHandler.batchUnHandledRequests(shuffleId, 
shuffleCommittedInfo)
+                      // when batch commit thread starts to commit these 
requests, we should increment inFlightNum,
+                      // then stage/partition end would be able to recognize 
all requests are over.
+                      commitHandler.incrementInFlightNum(shuffleCommittedInfo, 
workerToRequests)
                     }
+
                     if (workerToRequests.nonEmpty) {
                       val commitFilesFailedWorkers = new ShuffleFailedWorkers()
-                      val parallelism = workerToRequests.size
+                      val parallelism = Math.min(workerToRequests.size, 
conf.rpcMaxParallelism)
                       try {
                         ThreadUtils.parmap(
                           workerToRequests.to,
@@ -226,7 +137,8 @@ class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: Lif
                         }
                         
lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
                       } finally {
-                        decrementInflightNum(workerToRequests)
+                        // when batch commit thread ends, we need 
decrementInFlightNum
+                        
commitHandler.decrementInFlightNum(shuffleCommittedInfo, workerToRequests)
                       }
                     }
                   }
@@ -258,7 +170,7 @@ class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: Lif
         new ConcurrentHashMap[String, StorageInfo](),
         new ConcurrentHashMap[String, RoaringBitmap](),
         new LongAdder,
-        new util.HashSet[CommitPartitionRequest](),
+        new util.HashSet[PartitionLocation](),
         new util.HashSet[PartitionLocation](),
         new AtomicInteger(),
         new ConcurrentHashMap[Int, AtomicInteger]()))
@@ -270,15 +182,13 @@ class CommitManager(appId: String, val conf: CelebornConf, lifecycleManager: Lif
   }
 
   def registerCommitPartitionRequest(
-      applicationId: String,
       shuffleId: Int,
-      partition: PartitionLocation,
+      partitionLocation: PartitionLocation,
       cause: Option[StatusCode]): Unit = {
     if (batchHandleCommitPartitionEnabled && cause.isDefined && cause.get == 
StatusCode.HARD_SPLIT) {
       val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
       shuffleCommittedInfo.synchronized {
-        shuffleCommittedInfo.commitPartitionRequests
-          .add(CommitPartitionRequest(applicationId, shuffleId, partition))
+        shuffleCommittedInfo.unHandledPartitionLocations.add(partitionLocation)
       }
     }
   }
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index dd38c965..81c33426 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.{AtomicLong, LongAdder}
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
@@ -67,9 +68,64 @@ abstract class CommitHandler(
 
   def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = false
 
+  def batchUnHandledRequests(shuffleId: Int, shuffleCommittedInfo: 
ShuffleCommittedInfo)
+      : Map[WorkerInfo, collection.Set[PartitionLocation]] = {
+    // When running to here, if handleStageEnd got lock first and commitFiles,
+    // then this batch get this lock, commitPartitionRequests may contains
+    // partitions which are already committed by stageEnd process.
+    // But inProcessStageEndShuffleSet should have contain this shuffle id,
+    // can directly return empty.
+    if (this.isStageEndOrInProcess(shuffleId)) {
+      logWarning(s"Shuffle $shuffleId ended or during processing stage end.")
+      shuffleCommittedInfo.unHandledPartitionLocations.clear()
+      Map.empty[WorkerInfo, Set[PartitionLocation]]
+    } else {
+      val currentBatch = this.getUnHandledPartitionLocations(shuffleId, 
shuffleCommittedInfo)
+      shuffleCommittedInfo.unHandledPartitionLocations.clear()
+      currentBatch.foreach { partitionLocation =>
+        shuffleCommittedInfo.handledPartitionLocations.add(partitionLocation)
+        if (partitionLocation.getPeer != null) {
+          
shuffleCommittedInfo.handledPartitionLocations.add(partitionLocation.getPeer)
+        }
+      }
+
+      if (currentBatch.nonEmpty) {
+        logWarning(s"Commit current batch HARD_SPLIT partitions for 
$shuffleId: " +
+          s"${currentBatch.map(_.getUniqueId).mkString("[", ",", "]")}")
+        val workerToRequests = currentBatch.flatMap { partitionLocation =>
+          if (partitionLocation.getPeer != null) {
+            Seq(partitionLocation, partitionLocation.getPeer)
+          } else {
+            Seq(partitionLocation)
+          }
+        }.groupBy(_.getWorker)
+        workerToRequests
+      } else {
+        Map.empty[WorkerInfo, Set[PartitionLocation]]
+      }
+    }
+  }
+
+  protected def getUnHandledPartitionLocations(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo): 
mutable.Set[PartitionLocation]
+
+  def incrementInFlightNum(
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]): 
Unit = {
+    
shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(workerToRequests.size)
+  }
+
+  def decrementInFlightNum(
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]): 
Unit = {
+    
shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(-workerToRequests.size)
+  }
+
   /**
    * when someone calls tryFinalCommit, the function will return true if there 
is no one ever do final commit before,
    * otherwise it will return false.
+   *
    * @return
    */
   def tryFinalCommit(
@@ -119,10 +175,10 @@ abstract class CommitHandler(
         val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
           (
             masterParts.asScala
-              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              
.filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
               .map(_.getUniqueId).asJava,
             slaveParts.asScala
-              
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+              
.filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
               .map(_.getUniqueId).asJava)
         }
 
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 2b25080b..a1685276 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -20,8 +20,10 @@ package org.apache.celeborn.client.commit
 import java.util
 import java.util.Collections
 import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups}
@@ -29,7 +31,7 @@ import org.apache.celeborn.client.ShuffleCommittedInfo
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 // Can Remove this if celeborn don't support scala211 in future
 import org.apache.celeborn.common.util.FunctionConverter._
 import org.apache.celeborn.common.util.Utils
@@ -107,6 +109,39 @@ class MapPartitionCommitHandler(
     dataCommitSuccess
   }
 
+  override def getUnHandledPartitionLocations(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo): 
mutable.Set[PartitionLocation] = {
+    shuffleCommittedInfo.unHandledPartitionLocations.asScala.filterNot { 
partitionLocation =>
+      
shuffleCommittedInfo.handledPartitionLocations.contains(partitionLocation) &&
+      this.isPartitionInProcess(shuffleId, partitionLocation.getId)
+    }
+  }
+
+  override def incrementInFlightNum(
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]): 
Unit = {
+    workerToRequests.foreach {
+      case (_, partitions) =>
+        partitions.groupBy(_.getId).foreach { case (id, _) =>
+          val atomicInteger = 
shuffleCommittedInfo.partitionInFlightCommitRequestNum
+            .computeIfAbsent(id, (k: Int) => new AtomicInteger(0))
+          atomicInteger.incrementAndGet()
+        }
+    }
+  }
+
+  override def decrementInFlightNum(
+      shuffleCommittedInfo: ShuffleCommittedInfo,
+      workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]): 
Unit = {
+    workerToRequests.foreach {
+      case (_, partitions) =>
+        partitions.groupBy(_.getId).foreach { case (id, _) =>
+          
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(id).decrementAndGet()
+        }
+    }
+  }
+
   override def getShuffleMapperAttempts(shuffleId: Int): Array[Int] = {
     // map partition now return empty mapper attempts array as map partition 
don't prevent other mapper commit file
     // even the same mapper id with another attemptId success in lifecycle 
manager.
@@ -165,8 +200,7 @@ class MapPartitionCommitHandler(
       partitionIds: ConcurrentHashMap[String, WorkerInfo],
       partitionId: Int): util.Map[String, WorkerInfo] = {
     partitionIds.asScala.filter(p =>
-      Utils.splitPartitionLocationUniqueId(p._1)._1 ==
-        partitionId).asJava
+      Utils.splitPartitionLocationUniqueId(p._1)._1 == partitionId).asJava
   }
 
   private def getPartitionUniqueIds(
diff --git a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 34e6758e..11566063 100644
--- a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -21,6 +21,7 @@ import java.util
 import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.JavaConverters._
+import scala.collection.mutable
 
 import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
 import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers, 
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
@@ -28,7 +29,7 @@ import org.apache.celeborn.client.ShuffleCommittedInfo
 import org.apache.celeborn.common.CelebornConf
 import org.apache.celeborn.common.internal.Logging
 import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
 
 /**
  * This commit handler is for ReducePartition ShuffleType, which means that a 
Reduce Partition contains all data
@@ -150,6 +151,14 @@ class ReducePartitionCommitHandler(
     (dataLost, parallelCommitResult.commitFilesFailedWorkers)
   }
 
+  override def getUnHandledPartitionLocations(
+      shuffleId: Int,
+      shuffleCommittedInfo: ShuffleCommittedInfo): 
mutable.Set[PartitionLocation] = {
+    shuffleCommittedInfo.unHandledPartitionLocations.asScala.filterNot { 
partitionLocation =>
+      
shuffleCommittedInfo.handledPartitionLocations.contains(partitionLocation)
+    }
+  }
+
   override def finalPartitionCommit(
       shuffleId: Int,
       partitionId: Int,

Reply via email to