This is an automated email from the ASF dual-hosted git repository.
zhouky pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-celeborn.git
The following commit(s) were added to refs/heads/main by this push:
new 13769f0f [CELEBORN-121] Refactor batchHandleCommitPartition (#1089)
13769f0f is described below
commit 13769f0f0a2401aacc66162a2f6006816a175ca6
Author: Shuang <[email protected]>
AuthorDate: Mon Dec 19 12:39:39 2022 +0800
[CELEBORN-121] Refactor batchHandleCommitPartition (#1089)
---
.../celeborn/client/ChangePartitionManager.scala | 1 -
.../org/apache/celeborn/client/CommitManager.scala | 122 +++------------------
.../celeborn/client/commit/CommitHandler.scala | 60 +++++++++-
.../client/commit/MapPartitionCommitHandler.scala | 40 ++++++-
.../commit/ReducePartitionCommitHandler.scala | 11 +-
5 files changed, 121 insertions(+), 113 deletions(-)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
index 960ac1e6..0547692c 100644
---
a/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/ChangePartitionManager.scala
@@ -150,7 +150,6 @@ class ChangePartitionManager(
inBatchPartitions.computeIfAbsent(shuffleId, inBatchShuffleIdRegisterFunc)
lifecycleManager.commitManager.registerCommitPartitionRequest(
- applicationId,
shuffleId,
oldPartition,
cause)
diff --git
a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
index 21551871..d88389d1 100644
--- a/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
+++ b/client/src/main/scala/org/apache/celeborn/client/CommitManager.scala
@@ -38,11 +38,6 @@ import org.apache.celeborn.common.protocol.message.StatusCode
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.ThreadUtils
-case class CommitPartitionRequest(
- applicationId: String,
- shuffleId: Int,
- partition: PartitionLocation)
-
case class ShuffleCommittedInfo(
committedMasterIds: ConcurrentHashMap[Int, util.List[String]],
committedSlaveIds: ConcurrentHashMap[Int, util.List[String]],
@@ -52,8 +47,8 @@ case class ShuffleCommittedInfo(
committedSlaveStorageInfos: ConcurrentHashMap[String, StorageInfo],
committedMapIdBitmap: ConcurrentHashMap[String, RoaringBitmap],
currentShuffleFileCount: LongAdder,
- commitPartitionRequests: util.Set[CommitPartitionRequest],
- handledCommitPartitionRequests: util.Set[PartitionLocation],
+ unHandledPartitionLocations: util.Set[PartitionLocation],
+ handledPartitionLocations: util.Set[PartitionLocation],
allInFlightCommitRequestNum: AtomicInteger,
partitionInFlightCommitRequestNum: ConcurrentHashMap[Int, AtomicInteger])
@@ -91,104 +86,20 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
committedPartitionInfo.asScala.foreach { case (shuffleId,
shuffleCommittedInfo) =>
batchHandleCommitPartitionExecutors.submit {
new Runnable {
- val partitionType =
lifecycleManager.getPartitionType(shuffleId)
val commitHandler = getCommitHandler(shuffleId)
- def incrementInflightNum(workerToRequests: Map[
- WorkerInfo,
- collection.Set[PartitionLocation]]): Unit = {
- if (partitionType == PartitionType.MAP) {
- workerToRequests.foreach {
- case (_, partitions) =>
- partitions.groupBy(_.getId).foreach { case (id, _) =>
- val atomicInteger = shuffleCommittedInfo
- .partitionInFlightCommitRequestNum
- .computeIfAbsent(id, (k: Int) => new
AtomicInteger(0))
- atomicInteger.incrementAndGet()
- }
- }
- }
- shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
- workerToRequests.size)
- }
-
- def decrementInflightNum(
- workerToRequests: Map[WorkerInfo,
collection.Set[PartitionLocation]])
- : Unit = {
- if (partitionType == PartitionType.MAP) {
- workerToRequests.foreach {
- case (_, partitions) =>
- partitions.groupBy(_.getId).foreach { case (id, _) =>
-
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(
- id).decrementAndGet()
- }
- }
- }
- shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(
- -workerToRequests.size)
- }
-
- def getUnCommitPartitionRequests(
- commitPartitionRequests:
util.Set[CommitPartitionRequest])
- : scala.collection.mutable.Set[CommitPartitionRequest] =
{
- if (partitionType == PartitionType.MAP) {
- commitPartitionRequests.asScala.filterNot { request =>
- shuffleCommittedInfo.handledCommitPartitionRequests
- .contains(request.partition) &&
commitHandler.isPartitionInProcess(
- shuffleId,
- request.partition.getId)
- }
- } else {
- commitPartitionRequests.asScala.filterNot { request =>
- shuffleCommittedInfo.handledCommitPartitionRequests
- .contains(request.partition)
- }
- }
- }
-
override def run(): Unit = {
- val workerToRequests = shuffleCommittedInfo.synchronized {
- // When running to here, if handleStageEnd got lock
first and commitFiles,
- // then this batch get this lock,
commitPartitionRequests may contains
- // partitions which are already committed by stageEnd
process.
- // But inProcessStageEndShuffleSet should have contain
this shuffle id,
- // can directly return.
- if (commitHandler.isStageEndOrInProcess(shuffleId)) {
- logWarning(s"Shuffle $shuffleId ended or during
processing stage end.")
- shuffleCommittedInfo.commitPartitionRequests.clear()
- Map.empty[WorkerInfo, Set[PartitionLocation]]
- } else {
- val currentBatch =
-
getUnCommitPartitionRequests(shuffleCommittedInfo.commitPartitionRequests)
- shuffleCommittedInfo.commitPartitionRequests.clear()
- currentBatch.foreach { commitPartitionRequest =>
- shuffleCommittedInfo.handledCommitPartitionRequests
- .add(commitPartitionRequest.partition)
- if (commitPartitionRequest.partition.getPeer !=
null) {
- shuffleCommittedInfo.handledCommitPartitionRequests
- .add(commitPartitionRequest.partition.getPeer)
- }
- }
-
- if (currentBatch.nonEmpty) {
- logWarning(s"Commit current batch HARD_SPLIT
partitions for $shuffleId: " +
-
s"${currentBatch.map(_.partition.getUniqueId).mkString("[", ",", "]")}")
- val workerToRequests = currentBatch.flatMap {
request =>
- if (request.partition.getPeer != null) {
- Seq(request.partition, request.partition.getPeer)
- } else {
- Seq(request.partition)
- }
- }.groupBy(_.getWorker)
- incrementInflightNum(workerToRequests)
- workerToRequests
- } else {
- Map.empty[WorkerInfo, Set[PartitionLocation]]
- }
- }
+ var workerToRequests: Map[WorkerInfo,
collection.Set[PartitionLocation]] = null
+ shuffleCommittedInfo.synchronized {
+ workerToRequests =
+ commitHandler.batchUnHandledRequests(shuffleId,
shuffleCommittedInfo)
+ // when batch commit thread starts to commit these
requests, we should increment inFlightNum,
+ // then stage/partition end would be able to recognize
all requests are over.
+ commitHandler.incrementInFlightNum(shuffleCommittedInfo,
workerToRequests)
}
+
if (workerToRequests.nonEmpty) {
val commitFilesFailedWorkers = new ShuffleFailedWorkers()
- val parallelism = workerToRequests.size
+ val parallelism = Math.min(workerToRequests.size,
conf.rpcMaxParallelism)
try {
ThreadUtils.parmap(
workerToRequests.to,
@@ -226,7 +137,8 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
}
lifecycleManager.recordWorkerFailure(commitFilesFailedWorkers)
} finally {
- decrementInflightNum(workerToRequests)
+ // when batch commit thread ends, we need
decrementInFlightNum
+
commitHandler.decrementInFlightNum(shuffleCommittedInfo, workerToRequests)
}
}
}
@@ -258,7 +170,7 @@ class CommitManager(appId: String, val conf: CelebornConf,
lifecycleManager: Lif
new ConcurrentHashMap[String, StorageInfo](),
new ConcurrentHashMap[String, RoaringBitmap](),
new LongAdder,
- new util.HashSet[CommitPartitionRequest](),
+ new util.HashSet[PartitionLocation](),
new util.HashSet[PartitionLocation](),
new AtomicInteger(),
new ConcurrentHashMap[Int, AtomicInteger]()))
@@ -270,15 +182,13 @@ class CommitManager(appId: String, val conf:
CelebornConf, lifecycleManager: Lif
}
def registerCommitPartitionRequest(
- applicationId: String,
shuffleId: Int,
- partition: PartitionLocation,
+ partitionLocation: PartitionLocation,
cause: Option[StatusCode]): Unit = {
if (batchHandleCommitPartitionEnabled && cause.isDefined && cause.get ==
StatusCode.HARD_SPLIT) {
val shuffleCommittedInfo = committedPartitionInfo.get(shuffleId)
shuffleCommittedInfo.synchronized {
- shuffleCommittedInfo.commitPartitionRequests
- .add(CommitPartitionRequest(applicationId, shuffleId, partition))
+ shuffleCommittedInfo.unHandledPartitionLocations.add(partitionLocation)
}
}
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
index dd38c965..81c33426 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/CommitHandler.scala
@@ -22,6 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.atomic.{AtomicLong, LongAdder}
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers,
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
@@ -67,9 +68,64 @@ abstract class CommitHandler(
def isPartitionInProcess(shuffleId: Int, partitionId: Int): Boolean = false
+ def batchUnHandledRequests(shuffleId: Int, shuffleCommittedInfo:
ShuffleCommittedInfo)
+ : Map[WorkerInfo, collection.Set[PartitionLocation]] = {
+ // When running to here, if handleStageEnd got the lock first and called commitFiles,
+ // then by the time this batch gets the lock, commitPartitionRequests may contain
+ // partitions which are already committed by the stageEnd process.
+ // But inProcessStageEndShuffleSet should contain this shuffle id,
+ // so we can directly return empty.
+ if (this.isStageEndOrInProcess(shuffleId)) {
+ logWarning(s"Shuffle $shuffleId ended or during processing stage end.")
+ shuffleCommittedInfo.unHandledPartitionLocations.clear()
+ Map.empty[WorkerInfo, Set[PartitionLocation]]
+ } else {
+ val currentBatch = this.getUnHandledPartitionLocations(shuffleId,
shuffleCommittedInfo)
+ shuffleCommittedInfo.unHandledPartitionLocations.clear()
+ currentBatch.foreach { partitionLocation =>
+ shuffleCommittedInfo.handledPartitionLocations.add(partitionLocation)
+ if (partitionLocation.getPeer != null) {
+
shuffleCommittedInfo.handledPartitionLocations.add(partitionLocation.getPeer)
+ }
+ }
+
+ if (currentBatch.nonEmpty) {
+ logWarning(s"Commit current batch HARD_SPLIT partitions for
$shuffleId: " +
+ s"${currentBatch.map(_.getUniqueId).mkString("[", ",", "]")}")
+ val workerToRequests = currentBatch.flatMap { partitionLocation =>
+ if (partitionLocation.getPeer != null) {
+ Seq(partitionLocation, partitionLocation.getPeer)
+ } else {
+ Seq(partitionLocation)
+ }
+ }.groupBy(_.getWorker)
+ workerToRequests
+ } else {
+ Map.empty[WorkerInfo, Set[PartitionLocation]]
+ }
+ }
+ }
+
+ protected def getUnHandledPartitionLocations(
+ shuffleId: Int,
+ shuffleCommittedInfo: ShuffleCommittedInfo):
mutable.Set[PartitionLocation]
+
+ def incrementInFlightNum(
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]):
Unit = {
+
shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(workerToRequests.size)
+ }
+
+ def decrementInFlightNum(
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]):
Unit = {
+
shuffleCommittedInfo.allInFlightCommitRequestNum.addAndGet(-workerToRequests.size)
+ }
+
/**
* when someone calls tryFinalCommit, the function will return true if no one
has ever done the final commit before,
* otherwise it will return false.
+ *
* @return
*/
def tryFinalCommit(
@@ -119,10 +175,10 @@ abstract class CommitHandler(
val (masterIds, slaveIds) = shuffleCommittedInfo.synchronized {
(
masterParts.asScala
-
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+
.filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
.map(_.getUniqueId).asJava,
slaveParts.asScala
-
.filterNot(shuffleCommittedInfo.handledCommitPartitionRequests.contains)
+
.filterNot(shuffleCommittedInfo.handledPartitionLocations.contains)
.map(_.getUniqueId).asJava)
}
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
index 2b25080b..a1685276 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/MapPartitionCommitHandler.scala
@@ -20,8 +20,10 @@ package org.apache.celeborn.client.commit
import java.util
import java.util.Collections
import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers,
ShuffleFailedWorkers, ShuffleFileGroups}
@@ -29,7 +31,7 @@ import org.apache.celeborn.client.ShuffleCommittedInfo
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
// Can Remove this if celeborn don't support scala211 in future
import org.apache.celeborn.common.util.FunctionConverter._
import org.apache.celeborn.common.util.Utils
@@ -107,6 +109,39 @@ class MapPartitionCommitHandler(
dataCommitSuccess
}
+ override def getUnHandledPartitionLocations(
+ shuffleId: Int,
+ shuffleCommittedInfo: ShuffleCommittedInfo):
mutable.Set[PartitionLocation] = {
+ shuffleCommittedInfo.unHandledPartitionLocations.asScala.filterNot {
partitionLocation =>
+
shuffleCommittedInfo.handledPartitionLocations.contains(partitionLocation) &&
+ this.isPartitionInProcess(shuffleId, partitionLocation.getId)
+ }
+ }
+
+ override def incrementInFlightNum(
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]):
Unit = {
+ workerToRequests.foreach {
+ case (_, partitions) =>
+ partitions.groupBy(_.getId).foreach { case (id, _) =>
+ val atomicInteger =
shuffleCommittedInfo.partitionInFlightCommitRequestNum
+ .computeIfAbsent(id, (k: Int) => new AtomicInteger(0))
+ atomicInteger.incrementAndGet()
+ }
+ }
+ }
+
+ override def decrementInFlightNum(
+ shuffleCommittedInfo: ShuffleCommittedInfo,
+ workerToRequests: Map[WorkerInfo, collection.Set[PartitionLocation]]):
Unit = {
+ workerToRequests.foreach {
+ case (_, partitions) =>
+ partitions.groupBy(_.getId).foreach { case (id, _) =>
+
shuffleCommittedInfo.partitionInFlightCommitRequestNum.get(id).decrementAndGet()
+ }
+ }
+ }
+
override def getShuffleMapperAttempts(shuffleId: Int): Array[Int] = {
// map partition now returns an empty mapper attempts array, as map partition
doesn't prevent another mapper from committing files
// even when the same mapper id with another attemptId succeeds in the lifecycle
manager.
@@ -165,8 +200,7 @@ class MapPartitionCommitHandler(
partitionIds: ConcurrentHashMap[String, WorkerInfo],
partitionId: Int): util.Map[String, WorkerInfo] = {
partitionIds.asScala.filter(p =>
- Utils.splitPartitionLocationUniqueId(p._1)._1 ==
- partitionId).asJava
+ Utils.splitPartitionLocationUniqueId(p._1)._1 == partitionId).asJava
}
private def getPartitionUniqueIds(
diff --git
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
index 34e6758e..11566063 100644
---
a/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
+++
b/client/src/main/scala/org/apache/celeborn/client/commit/ReducePartitionCommitHandler.scala
@@ -21,6 +21,7 @@ import java.util
import java.util.concurrent.ConcurrentHashMap
import scala.collection.JavaConverters._
+import scala.collection.mutable
import org.apache.celeborn.client.CommitManager.CommittedPartitionInfo
import org.apache.celeborn.client.LifecycleManager.{ShuffleAllocatedWorkers,
ShuffleFailedWorkers, ShuffleFileGroups, ShuffleMapperAttempts}
@@ -28,7 +29,7 @@ import org.apache.celeborn.client.ShuffleCommittedInfo
import org.apache.celeborn.common.CelebornConf
import org.apache.celeborn.common.internal.Logging
import org.apache.celeborn.common.meta.{PartitionLocationInfo, WorkerInfo}
-import org.apache.celeborn.common.protocol.PartitionType
+import org.apache.celeborn.common.protocol.{PartitionLocation, PartitionType}
/**
* This commit handler is for ReducePartition ShuffleType, which means that a
Reduce Partition contains all data
@@ -150,6 +151,14 @@ class ReducePartitionCommitHandler(
(dataLost, parallelCommitResult.commitFilesFailedWorkers)
}
+ override def getUnHandledPartitionLocations(
+ shuffleId: Int,
+ shuffleCommittedInfo: ShuffleCommittedInfo):
mutable.Set[PartitionLocation] = {
+ shuffleCommittedInfo.unHandledPartitionLocations.asScala.filterNot {
partitionLocation =>
+
shuffleCommittedInfo.handledPartitionLocations.contains(partitionLocation)
+ }
+ }
+
override def finalPartitionCommit(
shuffleId: Int,
partitionId: Int,