Github user Dooyoung-Hwang commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22219#discussion_r212866466
  
    --- Diff: 
sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
    @@ -329,49 +329,52 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] 
with Logging with Serializ
        *
        * This is modeled after `RDD.take` but never runs any job locally on 
the driver.
        */
    -  def executeTake(n: Int): Array[InternalRow] = {
    +  def executeTake(n: Int): Array[InternalRow] = 
executeTakeIterator(n)._2.toArray
    +
    +  private[spark] def executeTakeIterator(n: Int): (Long, 
Iterator[InternalRow]) = {
         if (n == 0) {
    -      return new Array[InternalRow](0)
    +      return (0, Iterator.empty)
         }
     
    -    val childRDD = getByteArrayRdd(n).map(_._2)
    -
    -    val buf = new ArrayBuffer[InternalRow]
    +    val childRDD = getByteArrayRdd(n)
    +    val encodedBuf = new ArrayBuffer[Array[Byte]]
         val totalParts = childRDD.partitions.length
    +    var scannedRowCount = 0L
         var partsScanned = 0
    -    while (buf.size < n && partsScanned < totalParts) {
    +    while (scannedRowCount < n && partsScanned < totalParts) {
           // The number of partitions to try in this iteration. It is ok for 
this number to be
           // greater than totalParts because we actually cap it at totalParts 
in runJob.
    -      var numPartsToTry = 1L
    -      if (partsScanned > 0) {
    +      val numPartsToTry = if (partsScanned > 0) {
             // If we didn't find any rows after the previous iteration, 
quadruple and retry.
             // Otherwise, interpolate the number of partitions we need to try, 
but overestimate
             // it by 50%. We also cap the estimation in the end.
             val limitScaleUpFactor = 
Math.max(sqlContext.conf.limitScaleUpFactor, 2)
    -        if (buf.isEmpty) {
    -          numPartsToTry = partsScanned * limitScaleUpFactor
    +        if (scannedRowCount == 0) {
    +          partsScanned * limitScaleUpFactor
             } else {
    -          val left = n - buf.size
    +          val left = n - scannedRowCount
               // As left > 0, numPartsToTry is always >= 1
    -          numPartsToTry = Math.ceil(1.5 * left * partsScanned / 
buf.size).toInt
    -          numPartsToTry = Math.min(numPartsToTry, partsScanned * 
limitScaleUpFactor)
    +          Math.min(Math.ceil(1.5 * left * partsScanned / 
scannedRowCount).toInt,
    +            partsScanned * limitScaleUpFactor)
             }
    +      } else {
    +        1L
    --- End diff --
    
    @kiszk 
    Yeah, I'll prepare test cases.
    
    @viirya 
    The code above was changed to run decodeUnsafeRows lazily in order to reduce 
peak memory usage. Changing numPartsToTry from a var to a val is a refactoring 
that could be separated from this patch. If reviewers would like me to revert 
[this refactoring](https://github.com/apache/spark/pull/22219/commits/91617caaab56760ea2f64f3da7486fbf445d7aa9),
 I can split it out into a separate, trivial pull request.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to