Github user zsxwing commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19763#discussion_r152084974
  
    --- Diff: core/src/main/scala/org/apache/spark/MapOutputTracker.scala ---
    @@ -472,16 +475,48 @@ private[spark] class MapOutputTrackerMaster(
         shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
       }
     
    +  /**
    +   * To equally divide n elements into m buckets, basically each bucket should have n/m elements,
    +   * for the remaining n%m elements, add one more element to each of the first n%m buckets.
    +   */
    +  def equallyDivide(numElements: Int, numBuckets: Int): Iterator[Seq[Int]] = {
    +    val elementsPerBucket = numElements / numBuckets
    +    val remaining = numElements % numBuckets
    +    if (remaining == 0) {
    +      0.until(numElements).grouped(elementsPerBucket)
    +    } else {
    +      val splitPoint = (elementsPerBucket + 1) * remaining
    +      0.to(splitPoint).grouped(elementsPerBucket + 1) ++
    +        (splitPoint + 1).until(numElements).grouped(elementsPerBucket)
    +    }
    +  }
    +
       /**
        * Return statistics about all of the outputs for a given shuffle.
        */
       def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
         shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
           val totalSizes = new Array[Long](dep.partitioner.numPartitions)
    -      for (s <- statuses) {
    -        for (i <- 0 until totalSizes.length) {
    -          totalSizes(i) += s.getSizeForBlock(i)
    +      val parallelAggThreshold = conf.get(
    +        SHUFFLE_MAP_OUTPUT_STATISTICS_PARALLEL_AGGREGATION_THRESHOLD)
    +      if (statuses.length * totalSizes.length < parallelAggThreshold) {
    +        for (s <- statuses) {
    +          for (i <- 0 until totalSizes.length) {
    +            totalSizes(i) += s.getSizeForBlock(i)
    +          }
    +        }
    +      } else {
    +        val parallelism = conf.get(SHUFFLE_MAP_OUTPUT_STATISTICS_PARALLELISM)
    --- End diff --
    
    How about setting `parallelism = math.min(Runtime.getRuntime.availableProcessors(), statuses.length.toLong * totalSizes.length / parallelAggThreshold)` rather than introducing a new config, such as:
    ```
        val parallelism = math.min(
          Runtime.getRuntime.availableProcessors(),
          statuses.length.toLong * totalSizes.length / parallelAggThreshold + 1)
        if (parallelism <= 1) {
          ...
        } else {
          ...
        }
    ```
    



---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to