LuciferYang commented on code in PR #46101:
URL: https://github.com/apache/spark/pull/46101#discussion_r1568624495


##########
sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala:
##########
@@ -118,18 +118,52 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
  * logical plan, which happens when the user is collecting results back to the driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without an actual
-    // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1)
+    val childRDD = child.execute().map(_.copy())
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val queue = HybridRowQueue.apply(output.size)
+      while (iter.hasNext) {
+        queue.add(iter.next().asInstanceOf[UnsafeRow])
+        while (queue.size() > limit) {
+          queue.remove()
+        }
+      }
+      queue.destructiveIterator()

Review Comment:
   Once `queue.destructiveIterator` has been opened, could the task be killed before the iteration completes? And if it is killed, could the disk handle held by the `HybridRowQueue` be leaked?
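   
   One way to guard against that (just a sketch of the idea, not code from this PR): register a `TaskContext` completion listener right after the queue is created, so the queue's memory pages and spill file are released even when the task dies before the destructive iterator is drained. This assumes the single-argument `HybridRowQueue.apply(numFields)` helper used above, and a `close()` method that releases the queue's disk-backed state, as the `RowQueue` implementations in `RowQueue.scala` provide.
   ```scala
   import org.apache.spark.TaskContext
   
   // Sketch only: same takeRight logic as the PR, plus a completion listener
   // that closes the queue when the task finishes, whether it completes
   // normally or is killed mid-iteration. `output` and `limit` come from the
   // enclosing CollectTailExec.
   private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
     if (iter.isEmpty) {
       Iterator.empty[InternalRow]
     } else {
       val queue = HybridRowQueue.apply(output.size)
       // Release the queue's pages and delete its spill file at task end.
       Option(TaskContext.get()).foreach { tc =>
         tc.addTaskCompletionListener[Unit](_ => queue.close())
       }
       while (iter.hasNext) {
         queue.add(iter.next().asInstanceOf[UnsafeRow])
         while (queue.size() > limit) {
           queue.remove()
         }
       }
       queue.destructiveIterator()
     }
   }
   ```
   A `CompletionIterator` alone would not cover this case, since its completion callback only runs once the wrapped iterator has been fully consumed.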


