Github user cloud-fan commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19763#discussion_r152193763
  
    --- Diff: core/src/main/scala/org/apache/spark/MapOutputTracker.scala ---
    @@ -472,15 +475,66 @@ private[spark] class MapOutputTrackerMaster(
         shuffleStatuses.get(shuffleId).map(_.findMissingPartitions())
       }
     
    +  /**
     +   * Grouped function for a Range. This avoids the traversal of all elements of the
     +   * Range that IterableLike's grouped function would perform.
    +   */
    +  def rangeGrouped(range: Range, size: Int): Seq[Range] = {
    +    val start = range.start
    +    val step = range.step
    +    val end = range.end
    +    for (i <- start.until(end, size * step)) yield {
    +      i.until(i + size * step, step)
    +    }
    +  }
    +
    +  /**
     +   * To equally divide n elements into m buckets, each bucket should have n/m elements;
     +   * for the remaining n%m elements, add one more element to each of the first n%m buckets.
    +   */
    +  def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
    +    val elementsPerBucket = numElements / numBuckets
    +    val remaining = numElements % numBuckets
    +    val splitPoint = (elementsPerBucket + 1) * remaining
    +    if (elementsPerBucket == 0) {
    +      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
    +    } else {
    +      rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
    +        rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
    +    }
    +  }
    +
       /**
        * Return statistics about all of the outputs for a given shuffle.
        */
       def getStatistics(dep: ShuffleDependency[_, _, _]): MapOutputStatistics = {
         shuffleStatuses(dep.shuffleId).withMapStatuses { statuses =>
           val totalSizes = new Array[Long](dep.partitioner.numPartitions)
    -      for (s <- statuses) {
    -        for (i <- 0 until totalSizes.length) {
    -          totalSizes(i) += s.getSizeForBlock(i)
    +      val parallelAggThreshold = conf.get(
    +        SHUFFLE_MAP_OUTPUT_PARALLEL_AGGREGATION_THRESHOLD)
    +      val parallelism = math.min(
    +        Runtime.getRuntime.availableProcessors(),
    +        statuses.length * totalSizes.length / parallelAggThreshold + 1)
    +      if (parallelism <= 1) {
    +        for (s <- statuses) {
    +          for (i <- 0 until totalSizes.length) {
    +            totalSizes(i) += s.getSizeForBlock(i)
    +          }
    +        }
    +      } else {
     +        val threadPool = ThreadUtils.newDaemonFixedThreadPool(parallelism, "map-output-aggregate")
     +        try {
     +          implicit val executionContext = ExecutionContext.fromExecutor(threadPool)
     +          val mapStatusSubmitTasks = equallyDivide(totalSizes.length, parallelism).map {
     +            reduceIds => Future {
     +              for (s <- statuses; i <- reduceIds) {
     +                totalSizes(i) += s.getSizeForBlock(i)
     +              }
     +            }
     +          }
     +          ThreadUtils.awaitResult(Future.sequence(mapStatusSubmitTasks), Duration.Inf)
     +        } finally {
     +          threadPool.shutdown()
    --- End diff ---
    
    cc @zsxwing do we really need to shut down the thread pool every time? This method may be called many times; would it be better to cache this thread pool, like the dispatcher thread pool?
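
    One possible shape for such a cached pool, as a rough sketch only (it uses plain java.util.concurrent instead of Spark's ThreadUtils, and the object and member names are made up for illustration): the pool is created once, reused by every getStatistics call, and shut down a single time when the owning tracker stops.

        import java.util.concurrent.{Executors, ThreadFactory, TimeUnit}
        import scala.concurrent.ExecutionContext

        // Illustrative only: one pool per tracker, reused across getStatistics calls.
        object CachedAggregationPool {
          private lazy val threadPool = {
            val factory = new ThreadFactory {
              override def newThread(r: Runnable): Thread = {
                val t = new Thread(r, "map-output-aggregate")
                t.setDaemon(true) // daemon, so it can never block JVM shutdown
                t
              }
            }
            Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors(), factory)
          }

          // Shared ExecutionContext for the Future-based size aggregation.
          lazy val executionContext: ExecutionContext = ExecutionContext.fromExecutor(threadPool)

          // Called once, e.g. when the tracker itself is stopped.
          def shutdown(): Unit = {
            threadPool.shutdown()
            threadPool.awaitTermination(10, TimeUnit.SECONDS)
          }
        }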
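
    For reference on the diff itself, the parallel path is only taken when statuses.length * totalSizes.length / parallelAggThreshold + 1 exceeds 1 (integer division, further capped by the number of available cores). Assuming a threshold of 10^7 purely for illustration, a shuffle with 10,000 map outputs and 2,000 reduce partitions gives 10000 * 2000 / 10^7 + 1 = 3, so at most three aggregation threads, while a 100 x 100 shuffle gives 1 and stays on the single-threaded loop.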
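
    The bucketing done by the two helpers can also be checked with a small standalone sketch (same logic as in the diff; the object and main harness are only for illustration):

        object EquallyDivideSketch {
          // Same bucketing logic as the rangeGrouped/equallyDivide helpers in the diff.
          def rangeGrouped(range: Range, size: Int): Seq[Range] = {
            val step = range.step
            for (i <- range.start.until(range.end, size * step)) yield {
              i.until(i + size * step, step)
            }
          }

          def equallyDivide(numElements: Int, numBuckets: Int): Seq[Seq[Int]] = {
            val elementsPerBucket = numElements / numBuckets
            val remaining = numElements % numBuckets
            val splitPoint = (elementsPerBucket + 1) * remaining
            if (elementsPerBucket == 0) {
              rangeGrouped(0.until(splitPoint), elementsPerBucket + 1)
            } else {
              rangeGrouped(0.until(splitPoint), elementsPerBucket + 1) ++
                rangeGrouped(splitPoint.until(numElements), elementsPerBucket)
            }
          }

          def main(args: Array[String]): Unit = {
            // 10 reduce ids over 3 buckets: the first 10 % 3 = 1 bucket gets 4 ids,
            // the remaining buckets get 3 each.
            equallyDivide(10, 3).foreach(ids => println(ids.mkString(",")))
            // prints: 0,1,2,3  then  4,5,6  then  7,8,9
          }
        }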


---
