LuciferYang commented on code in PR #46101:
URL: https://github.com/apache/spark/pull/46101#discussion_r1568624495
##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +118,52 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
  * logical plan, which happens when the user is collecting results back to the driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without an actual
-    // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1)
+    val childRDD = child.execute().map(_.copy())
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val queue = HybridRowQueue.apply(output.size)
+      while (iter.hasNext) {
+        queue.add(iter.next().asInstanceOf[UnsafeRow])
+        while (queue.size() > limit) {
+          queue.remove()
+        }
+      }
+      queue.destructiveIterator()

Review Comment:
   Once `queue.destructiveIterator` has been opened, can the task be killed before the iteration completes? And if it is killed, could the disk handle held by the `HybridRowQueue` be leaked?
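   For illustration, here is a minimal sketch of one way the leak could be guarded against, mirroring how `EvalPythonExec` protects its own `HybridRowQueue`: register a task completion listener that closes the queue, since completion listeners also run when a task is killed. This is not the PR's code; it assumes the existing three-argument `HybridRowQueue(taskMemoryManager, file, numFields)` factory from `org.apache.spark.sql.execution.python` rather than the one-argument `HybridRowQueue.apply(output.size)` the PR uses, and it assumes `takeRight` always runs inside a task, so `TaskContext.get()` is non-null.

```scala
// Sketch of a kill-safe takeRight inside CollectTailExec; `output` and `limit`
// are the enclosing plan's members, as in the PR, and `queue.size()` /
// `queue.destructiveIterator()` are the methods the PR's diff relies on.
import java.io.File

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.sql.execution.python.HybridRowQueue
import org.apache.spark.util.Utils

private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
  if (iter.isEmpty) {
    Iterator.empty[InternalRow]
  } else {
    val context = TaskContext.get()
    val queue = HybridRowQueue(
      context.taskMemoryManager(),
      new File(Utils.getLocalDir(SparkEnv.get.conf)),
      output.size)
    // Completion listeners fire on success, failure, and kill alike, so the
    // queue's in-memory pages and any spill file are released even when
    // destructiveIterator() is abandoned partway through.
    context.addTaskCompletionListener[Unit] { _ => queue.close() }
    while (iter.hasNext) {
      queue.add(iter.next().asInstanceOf[UnsafeRow])
      while (queue.size() > limit) {
        queue.remove()
      }
    }
    queue.destructiveIterator()
  }
}
```

   Wrapping the result in `org.apache.spark.util.CompletionIterator` would not be enough on its own: its completion callback only runs once the iterator is fully consumed, which is exactly what a killed task never does.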