GitHub user dongjoon-hyun commented on a diff in the pull request:

    https://github.com/apache/spark/pull/22347#discussion_r222158041
  
    --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala ---
    @@ -348,30 +349,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializable {
             // Otherwise, interpolate the number of partitions we need to try, but overestimate
             // it by 50%. We also cap the estimation in the end.
             val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2)
    -        if (buf.isEmpty) {
    +        if (scannedRowCount == 0) {
               numPartsToTry = partsScanned * limitScaleUpFactor
             } else {
    -          val left = n - buf.size
    +          val left = n - scannedRowCount
               // As left > 0, numPartsToTry is always >= 1
    -          numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt
    +          numPartsToTry = Math.ceil(1.5 * left * partsScanned / scannedRowCount).toInt
               numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor)
             }
           }
     
           val p = partsScanned.until(math.min(partsScanned + numPartsToTry, totalParts).toInt)
           val sc = sqlContext.sparkContext
    -      val res = sc.runJob(childRDD,
    -        (it: Iterator[Array[Byte]]) => if (it.hasNext) it.next() else Array.empty[Byte], p)
    -
    -      buf ++= res.flatMap(decodeUnsafeRows)
    +      val res = sc.runJob(childRDD, (it: Iterator[(Long, Array[Byte])]) =>
    +        if (it.hasNext) it.next() else (0L, Array.empty[Byte]), p)
     
    +      buf ++= res.map(_._2)
    +      scannedRowCount += res.map(_._1).sum
           partsScanned += p.size
         }
     
    -    if (buf.size > n) {
    -      buf.take(n).toArray
    +    if (scannedRowCount > n) {
    +      buf.toArray.view.flatMap(decodeUnsafeRows).take(n).force
    --- End diff ---
    
    Can we simplify this to the following?
    ```scala
          buf.flatMap(decodeUnsafeRows).take(n).toArray
    ```
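
    For context on the trade-off this suggestion touches: under the pre-2.13
    Scala collections Spark used at the time, the `view`/`force` form in the
    hunk decodes blocks lazily and stops once `n` rows have been produced,
    while a plain `flatMap` decodes every scanned block before `take(n)`.
    A minimal standalone sketch of that difference (`decode` is a
    hypothetical stand-in for `decodeUnsafeRows`, and the lazy side is
    written against `iterator` so the sketch runs on any Scala version; none
    of this is Spark code):
    ```scala
    object LazyDecodeSketch {
      def main(args: Array[String]): Unit = {
        var decoded = 0
        // Hypothetical decoder: each serialized "block" expands to 10 rows,
        // and we count how many blocks actually get decoded.
        def decode(block: Int): Iterator[Int] = {
          decoded += 1
          Iterator.tabulate(10)(i => block * 10 + i)
        }

        val buf = Seq.range(0, 5) // five scanned blocks, 50 rows in total
        val n = 12

        // Eager form (the suggested one-liner): flatMap decodes all five
        // blocks before take(n) trims the result.
        decoded = 0
        val eager = buf.flatMap(decode).take(n)
        println(s"eager: decoded $decoded blocks") // eager: decoded 5 blocks

        // Lazy form (same effect as view/force in the hunk): decoding stops
        // as soon as n rows have been produced.
        decoded = 0
        val lazyRows = buf.iterator.flatMap(decode).take(n).toList
        println(s"lazy: decoded $decoded blocks")  // lazy: decoded 2 blocks

        assert(eager == lazyRows) // same rows either way; only the work differs
      }
    }
    ```
    Both forms return the same rows, so the simplification is behaviorally
    equivalent; the difference is only how many surplus blocks get decoded
    when the scan overshoots `n`.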


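    For reference, here is the estimation step earlier in the hunk
    (interpolate from the observed rows-per-partition rate, overestimate by
    50%, then cap at `partsScanned * limitScaleUpFactor`) in isolation, as a
    minimal sketch with illustrative numbers that are not from the PR:
    ```scala
    object NumPartsToTrySketch {
      def main(args: Array[String]): Unit = {
        val n = 1000                // rows requested by take(n)
        val limitScaleUpFactor = 4  // Math.max(conf value, 2) in the hunk
        val partsScanned = 4        // partitions scanned so far
        val scannedRowCount = 100L  // rows found so far

        val numPartsToTry =
          if (scannedRowCount == 0) {
            // Nothing found yet: scale up by the configured factor.
            partsScanned * limitScaleUpFactor
          } else {
            val left = n - scannedRowCount
            // Interpolate from the observed rate, overestimated by 50%:
            // ceil(1.5 * 900 * 4 / 100) = 54 partitions ...
            val estimated =
              math.ceil(1.5 * left * partsScanned / scannedRowCount).toInt
            // ... capped at partsScanned * limitScaleUpFactor = 16.
            math.min(estimated, partsScanned * limitScaleUpFactor)
          }

        println(numPartsToTry) // 16
      }
    }
    ```
    The cap keeps a single iteration from fanning out too widely when the
    first few partitions happen to be nearly empty.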