This is an automated email from the ASF dual-hosted git repository. jiangxb1987 pushed a commit to branch branch-3.4 in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.4 by this push: new f68ece9e607 [SPARK-43043][CORE] Improve the performance of MapOutputTracker.updateMapOutput f68ece9e607 is described below commit f68ece9e6074cecdaf74ad9b39eae3c7dc2cfaf1 Author: Xingbo Jiang <xingbo.ji...@databricks.com> AuthorDate: Tue May 16 11:34:30 2023 -0700 [SPARK-43043][CORE] Improve the performance of MapOutputTracker.updateMapOutput ### What changes were proposed in this pull request? The PR changes the implementation of MapOutputTracker.updateMapOutput() to look up the MapStatus with the help of a mapping from mapId to mapIndex. Previously it performed a linear search over all map statuses, which could become a performance bottleneck when a large proportion of the blocks in the map are migrated. ### Why are the changes needed? To avoid a performance bottleneck when block decommissioning is enabled and many blocks are migrated within a short time window. ### Does this PR introduce _any_ user-facing change? No, it's a pure performance improvement. ### How was this patch tested? Manually tested. Closes #40690 from jiangxb1987/SPARK-43043. 
Lead-authored-by: Xingbo Jiang <xingbo.ji...@databricks.com> Co-authored-by: Jiang Xingbo <jiangxb1...@gmail.com> Signed-off-by: Xingbo Jiang <xingbo.ji...@databricks.com> (cherry picked from commit 66a2eb8f8957c22c69519b39be59beaaf931822b) Signed-off-by: Xingbo Jiang <xingbo.ji...@databricks.com> --- .../scala/org/apache/spark/MapOutputTracker.scala | 26 +++++++++++++++++----- .../apache/spark/util/collection/OpenHashMap.scala | 18 +++++++++++++++ .../spark/util/collection/OpenHashMapSuite.scala | 18 +++++++++++++++ 3 files changed, 56 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index fade0b86dd8..2dd3a903ee2 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -42,6 +42,7 @@ import org.apache.spark.scheduler.{MapStatus, MergeStatus, ShuffleOutputStatus} import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockId, ShuffleMergedBlockId} import org.apache.spark.util._ +import org.apache.spark.util.collection.OpenHashMap import org.apache.spark.util.io.{ChunkedByteBuffer, ChunkedByteBufferOutputStream} /** @@ -147,6 +148,12 @@ private class ShuffleStatus( private[this] var shufflePushMergerLocations: Seq[BlockManagerId] = Seq.empty + /** + * Mapping from a mapId to the mapIndex, this is required to reduce the searching overhead within + * the function updateMapOutput(mapId, bmAddress). + */ + private[this] val mapIdToMapIndex = new OpenHashMap[Long, Int]() + /** * Register a map output. If there is already a registered location for the map output then it * will be replaced by the new location. 
@@ -157,6 +164,14 @@ private class ShuffleStatus( invalidateSerializedMapOutputStatusCache() } mapStatuses(mapIndex) = status + mapIdToMapIndex(status.mapId) = mapIndex + } + + /** + * Get the map output that corresponding to a given mapId. + */ + def getMapStatus(mapId: Long): Option[MapStatus] = withReadLock { + mapIdToMapIndex.get(mapId).map(mapStatuses(_)) } /** @@ -164,15 +179,16 @@ private class ShuffleStatus( */ def updateMapOutput(mapId: Long, bmAddress: BlockManagerId): Unit = withWriteLock { try { - val mapStatusOpt = mapStatuses.find(x => x != null && x.mapId == mapId) + val mapIndex = mapIdToMapIndex.get(mapId) + val mapStatusOpt = mapIndex.map(mapStatuses(_)).flatMap(Option(_)) mapStatusOpt match { case Some(mapStatus) => logInfo(s"Updating map output for ${mapId} to ${bmAddress}") mapStatus.updateLocation(bmAddress) invalidateSerializedMapOutputStatusCache() case None => - val index = mapStatusesDeleted.indexWhere(x => x != null && x.mapId == mapId) - if (index >= 0 && mapStatuses(index) == null) { + if (mapIndex.map(mapStatusesDeleted).exists(_.mapId == mapId)) { + val index = mapIndex.get val mapStatus = mapStatusesDeleted(index) mapStatus.updateLocation(bmAddress) mapStatuses(index) = mapStatus @@ -1133,9 +1149,7 @@ private[spark] class MapOutputTrackerMaster( */ def getMapOutputLocation(shuffleId: Int, mapId: Long): Option[BlockManagerId] = { shuffleStatuses.get(shuffleId).flatMap { shuffleStatus => - shuffleStatus.withMapStatuses { mapStatues => - mapStatues.filter(_ != null).find(_.mapId == mapId).map(_.location) - } + shuffleStatus.getMapStatus(mapId).map(_.location) } } diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index 79e1a3562ae..e421a1f4746 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -75,6 +75,24 @@ class 
OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( } } + /** Get the value for a given key, return None if the key doesn't exist */ + def get(k: K): Option[V] = { + if (k == null) { + if (haveNullValue) { + Some(nullValue) + } else { + None + } + } else { + val pos = _keySet.getPos(k) + if (pos < 0) { + None + } else { + Some(_values(pos)) + } + } + } + /** Set the value for a key */ def update(k: K, v: V): Unit = { if (k == null) { diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index 08fed933640..1af99e9017c 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -231,4 +231,22 @@ class OpenHashMapSuite extends SparkFunSuite with Matchers { assert(map2("b") === 0.0) assert(map2("c") === null) } + + test("get") { + val map = new OpenHashMap[String, String]() + + // Get with normal/null keys. + map("1") = "1" + assert(map.get("1") === Some("1")) + assert(map.get("2") === None) + assert(map.get(null) === None) + map(null) = "hello" + assert(map.get(null) === Some("hello")) + + // Get with null values. + map("1") = null + assert(map.get("1") === Some(null)) + map(null) = null + assert(map.get(null) === Some(null)) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org