This is an automated email from the ASF dual-hosted git repository. ruifengz pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new fe47edece059 [SPARK-47883][SQL] Make `CollectTailExec.doExecute` lazy with RowQueue fe47edece059 is described below commit fe47edece059e9189d8500b3c9b3881b44678785 Author: Ruifeng Zheng <ruife...@apache.org> AuthorDate: Fri Apr 19 12:16:58 2024 +0800 [SPARK-47883][SQL] Make `CollectTailExec.doExecute` lazy with RowQueue ### What changes were proposed in this pull request? Make `CollectTailExec.doExecute` execute lazily. Instead of eagerly collecting the tail to the driver and re-parallelizing it, each partition keeps only its last `limit` rows in a `HybridRowQueue`. These partial results are shuffled into a single partition, where the final tail is taken. `CollectLimitExec.doExecute` now also short-circuits `limit == 0` to an empty RDD. ### Why are the changes needed? 1. In Spark Connect, `dataframe.tail` is implemented as `Tail(...).collect()`. 2. It makes `Tail` usable on its own. The previous implementation launched a job eagerly inside `doExecute`, which was only safe because that code path was reached solely via `Dataset.tail`. ### Does this PR introduce any user-facing change? No. ### How was this patch tested? Existing unit tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46101 from zhengruifeng/sql_tail_row_queue. Lead-authored-by: Ruifeng Zheng <ruife...@apache.org> Co-authored-by: Ruifeng Zheng <ruife...@foxmail.com> Signed-off-by: Ruifeng Zheng <ruife...@apache.org> --- .../org/apache/spark/sql/execution/limit.scala | 62 ++++++++++++++++++---- .../spark/sql/execution/python/RowQueue.scala | 7 ++- 2 files changed, 57 insertions(+), 12 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala index db5728d669ef..c0fb1c37b210 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution +import org.apache.spark.TaskContext import org.apache.spark.rdd.{ParallelCollectionRDD, RDD} import org.apache.spark.serializer.Serializer import org.apache.spark.sql.catalyst.InternalRow @@ -26,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import 
org.apache.spark.sql.execution.metric.{SQLShuffleReadMetricsReporter, SQLShuffleWriteMetricsReporter} -import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.sql.execution.python.HybridRowQueue import org.apache.spark.util.collection.Utils /** @@ -68,13 +69,13 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0) override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { val childRDD = child.execute() - if (childRDD.getNumPartitions == 0) { + if (childRDD.getNumPartitions == 0 || limit == 0) { new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty) } else { val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { childRDD } else { - val locallyLimited = if (limit >= 0) { + val locallyLimited = if (limit > 0) { childRDD.mapPartitionsInternal(_.take(limit)) } else { childRDD @@ -118,18 +119,57 @@ case class CollectLimitExec(limit: Int = -1, child: SparkPlan, offset: Int = 0) * logical plan, which happens when the user is collecting results back to the driver. */ case class CollectTailExec(limit: Int, child: SparkPlan) extends LimitExec { + assert(limit >= 0) + override def output: Seq[Attribute] = child.output override def outputPartitioning: Partitioning = SinglePartition override def executeCollect(): Array[InternalRow] = child.executeTail(limit) + private val serializer: Serializer = new UnsafeRowSerializer(child.output.size) + private lazy val writeMetrics = + SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext) + private lazy val readMetrics = + SQLShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext) + override lazy val metrics = readMetrics ++ writeMetrics protected override def doExecute(): RDD[InternalRow] = { - // This is a bit hacky way to avoid a shuffle and scanning all data when it performs - // at `Dataset.tail`. 
- // Since this execution plan and `execute` are currently called only when - // `Dataset.tail` is invoked, the jobs are always executed when they are supposed to be. - - // If we use this execution plan separately like `Dataset.limit` without an actual - // job launch, we might just have to mimic the implementation of `CollectLimitExec`. - sparkContext.parallelize(executeCollect().toImmutableArraySeq, numSlices = 1) + val childRDD = child.execute() + if (childRDD.getNumPartitions == 0 || limit == 0) { + new ParallelCollectionRDD(sparkContext, Seq.empty[InternalRow], 1, Map.empty) + } else { + val singlePartitionRDD = if (childRDD.getNumPartitions == 1) { + childRDD + } else { + val locallyLimited = childRDD.mapPartitionsInternal(takeRight) + new ShuffledRowRDD( + ShuffleExchangeExec.prepareShuffleDependency( + locallyLimited, + child.output, + SinglePartition, + serializer, + writeMetrics), + readMetrics) + } + singlePartitionRDD.mapPartitionsInternal(takeRight) + } + } + + private def takeRight(iter: Iterator[InternalRow]): Iterator[InternalRow] = { + if (iter.isEmpty) { + Iterator.empty[InternalRow] + } else { + val context = TaskContext.get() + val queue = HybridRowQueue.apply(context.taskMemoryManager(), output.size) + context.addTaskCompletionListener[Unit](_ => queue.close()) + var count = 0 + while (iter.hasNext) { + queue.add(iter.next().copy().asInstanceOf[UnsafeRow]) + if (count < limit) { + count += 1 + } else { + queue.remove() + } + } + Iterator.range(0, count).map(_ => queue.remove()) + } } override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala index 5e0c5ff92fda..ce30a54c8d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/RowQueue.scala @@ -29,6 +29,7 
@@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.memory.MemoryBlock +import org.apache.spark.util.Utils /** * A RowQueue is an FIFO queue for UnsafeRow. @@ -288,8 +289,12 @@ private[python] case class HybridRowQueue( } } -private[python] object HybridRowQueue { +private[sql] object HybridRowQueue { def apply(taskMemoryMgr: TaskMemoryManager, file: File, fields: Int): HybridRowQueue = { HybridRowQueue(taskMemoryMgr, file, fields, SparkEnv.get.serializerManager) } + + def apply(taskMemoryMgr: TaskMemoryManager, fields: Int): HybridRowQueue = { + apply(taskMemoryMgr, new File(Utils.getLocalDir(SparkEnv.get.conf)), fields) + } } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org