This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new fe47edece059 [SPARK-47883][SQL] Make `CollectTailExec.doExecute` lazy with RowQueue
fe47edece059 is described below

commit fe47edece059e9189d8500b3c9b3881b44678785
Author: Ruifeng Zheng <ruife...@apache.org>
AuthorDate: Fri Apr 19 12:16:58 2024 +0800

    [SPARK-47883][SQL] Make `CollectTailExec.doExecute` lazy with RowQueue
    
    ### What changes were proposed in this pull request?
    Make `CollectTailExec.doExecute` execute lazily
    
    ### Why are the changes needed?
    1. In Spark Connect, `dataframe.tail` is based on `Tail(...).collect()`;
    2. this change makes `Tail` usable on its own, without a driver-side collect inside `doExecute` (see the sketch below).
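    
    The core of the change is a per-partition "keep only the last `limit` rows" pass
    backed by a spillable `HybridRowQueue`. A minimal, Spark-free sketch of that idea
    (the names below are illustrative and not taken from the patch):
    
        import scala.collection.mutable
    
        // Stream through an iterator while holding at most `n` elements: once the
        // buffer is full, each new element evicts the oldest one, so only the last
        // `n` elements survive a single pass over the input.
        def takeLast[T](iter: Iterator[T], n: Int): Iterator[T] = {
          val buffer = mutable.Queue.empty[T]
          iter.foreach { elem =>
            buffer.enqueue(elem)
            if (buffer.size > n) buffer.dequeue()
          }
          buffer.iterator
        }
    
    The patch applies this idea twice, once per input partition before shuffling to a
    single partition and once again on the shuffled partition, but with `HybridRowQueue`
    instead of an in-memory buffer so the retained rows can spill to disk.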
    
    ### Does this PR introduce any user-facing change?
    no
    
    ### How was this patch tested?
    existing unit tests
    
    ### Was this patch authored or co-authored using generative AI tooling?
    no
    
    Closes #46101 from zhengruifeng/sql_tail_row_queue.
    
    Lead-authored-by: Ruifeng Zheng <ruife...@apache.org>
    Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com>
    Signed-off-by: Ruifeng Zheng <ruife...@apache.org>
---
 .../org/apache/spark/sql/execution/limit.scala     | 62 ++++++++++++++++++----
 .../spark/sql/execution/python/RowQueue.scala      |  7 ++-
 2 files changed, 57 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index db5728d669ef..c0fb1c37b210 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -17,6 +17,7 @@
 
 package org.apache.spark.sql.execution
 
+import org.apache.spark.TaskContext
 import org.apache.spark.rdd.{ParallelCollectionRDD, RDD}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.sql.catalyst.InternalRow
@@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.catalyst.util.truncatedString
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter}
-import org.apache.spark.util.ArrayImplicits._
+import org.apache.spark.sql.execution.python.HybridRowQueue
 import org.apache.spark.util.collection.Utils
 
 /**
@@ -68,13 +69,13 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
   override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
     val childRDD = child.execute()
-    if (childRDD.getNumPartitions == 0) {
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
       new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
     } else {
       val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
         childRDD
       } else {
-        val locallyLimited = if (limit >= 0) {
+        val locallyLimited = if (limit > 0) {
           childRDD.mapPartitionsInternal(_.take(limit))
         } else {
           childRDD
@@ -118,18 +119,57 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0)
 * logical plan, which happens when the user is collecting results back to the driver.
  */
 case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec {
+  assert(limit >= 0)
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = SinglePartition
   override def executeCollect(): Array[InternalRow] = child.executeTail(limit)
+  private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+  private lazy val readMetrics =
+    SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+  override lazy val metrics = readMetrics ++ writeMetrics
   protected override def doExecute(): RDD[InternalRow] = {
-    // This is a bit hacky way to avoid a shuffle and scanning all data when it performs
-    // at `Dataset.tail`.
-    // Since this execution plan and `execute` are currently called only when
-    // `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be.
-
-    // If we use this execution plan separately like `Dataset.limit` without an actual
-    // job launch, we might just have to mimic the implementation of `CollectLimitExec`.
-    sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1)
+    val childRDD = child.execute()
+    if (childRDD.getNumPartitions == 0 || limit == 0) {
+      new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty)
+    } else {
+      val singlePartitionRDD = if (childRDD.getNumPartitions == 1) {
+        childRDD
+      } else {
+        val locallyLimited = childRDD.mapPartitionsInternal(takeRight)
+        new ShuffledRowRDD(
+          ShuffleExchangeExec.prepareShuffleDependency(
+            locallyLimited,
+            child.output,
+            SinglePartition,
+            serializer,
+            writeMetrics),
+          readMetrics)
+      }
+      singlePartitionRDD.mapPartitionsInternal(takeRight)
+    }
+  }
+
+  private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = {
+    if (iter.isEmpty) {
+      Iterator.empty[InternalRow]
+    } else {
+      val context = TaskContext.get()
+      val queue = HybridRowQueue.apply(context.taskMemoryManager(), output.size)
+      context.addTaskCompletionListener[Unit](_ => queue.close())
+      var count = 0
+      while (iter.hasNext) {
+        queue.add(iter.next().copy().asInstanceOf[UnsafeRow])
+        if (count < limit) {
+          count += 1
+        } else {
+          queue.remove()
+        }
+      }
+      Iterator.range(0, count).map(_ => queue.remove())
+    }
   }
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
index 5e0c5ff92fda..ce30a54c8d4e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala
@@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.unsafe.Platform
 import org.apache.spark.unsafe.memory.MemoryBlock
+import org.apache.spark.util.Utils
 
 /**
  * A RowQueue is an FIFO queue for UnsafeRow.
@@ -288,8 +289,12 @@ private[python] case class HybridRowQueue(
   }
 }
 
-private[python] object HybridRowQueue {
+private[sql] object HybridRowQueue {
   def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = {
     HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager)
   }
+
+  def apply(taskMemoryMgr: TaskMemoryManager, fields: Int): HybridRowQueue = {
+    apply(taskMemoryMgr, new File(Utils.getLocalDir(SparkEnv.get.conf)), fields)
+  }
 }
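
For reference, a hedged sketch of the user-facing path this plan backs (assuming a
running `SparkSession` named `spark`; results are unchanged by the patch, only the
`doExecute` path of `Tail` becomes lazy instead of collecting the tail to the driver
and re-parallelizing it):

    // Dataset.tail(n) returns the last n rows; with this change CollectTailExec
    // produces its doExecute output lazily through a RowQueue, e.g. when reached
    // via Spark Connect's Tail(...).collect().
    val df = spark.range(0, 1000, 1, numPartitions = 8).toDF("id")
    val lastFive = df.tail(5)   // Array[Row] with ids 995 through 999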


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
