Github user Dooyoung-Hwang commented on a diff in the pull request: https://github.com/apache/spark/pull/22219#discussion_r212866466 --- Diff: sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala --- @@ -329,49 +329,52 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ * * This is modeled after `RDD.take` but never runs any job locally on the driver. */ - def executeTake(n: Int): Array[InternalRow] = { + def executeTake(n: Int): Array[InternalRow] = executeTakeIterator(n)._2.toArray + + private[spark] def executeTakeIterator(n: Int): (Long, Iterator[InternalRow]) = { if (n == 0) { - return new Array[InternalRow](0) + return (0, Iterator.empty) } - val childRDD = getByteArrayRdd(n).map(_._2) - - val buf = new ArrayBuffer[InternalRow] + val childRDD = getByteArrayRdd(n) + val encodedBuf = new ArrayBuffer[Array[Byte]] val totalParts = childRDD.partitions.length + var scannedRowCount = 0L var partsScanned = 0 - while (buf.size < n && partsScanned < totalParts) { + while (scannedRowCount < n && partsScanned < totalParts) { // The number of partitions to try in this iteration. It is ok for this number to be // greater than totalParts because we actually cap it at totalParts in runJob. - var numPartsToTry = 1L - if (partsScanned > 0) { + val numPartsToTry = if (partsScanned > 0) { // If we didn't find any rows after the previous iteration, quadruple and retry. // Otherwise, interpolate the number of partitions we need to try, but overestimate // it by 50%. We also cap the estimation in the end. 
val limitScaleUpFactor = Math.max(sqlContext.conf.limitScaleUpFactor, 2) - if (buf.isEmpty) { - numPartsToTry = partsScanned * limitScaleUpFactor + if (scannedRowCount == 0) { + partsScanned * limitScaleUpFactor } else { - val left = n - buf.size + val left = n - scannedRowCount // As left > 0, numPartsToTry is always >= 1 - numPartsToTry = Math.ceil(1.5 * left * partsScanned / buf.size).toInt - numPartsToTry = Math.min(numPartsToTry, partsScanned * limitScaleUpFactor) + Math.min(Math.ceil(1.5 * left * partsScanned / scannedRowCount).toInt, + partsScanned * limitScaleUpFactor) } + } else { + 1L --- End diff -- @kiszk Yeah, I'll prepare test cases. @viirya The changes above make decodeUnsafeRows execute lazily in order to reduce peak memory usage. Changing the type of numPartsToTry to a val may be a refactoring that can be separated from this patch. If reviewers want to revert [this refactoring]( https://github.com/apache/spark/pull/22219/commits/91617caaab56760ea2f64f3da7486fbf445d7aa9), I can separate it and make another trivial pull request for it.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org