ArnavBalyan commented on code in PR #8914:
URL: https://github.com/apache/incubator-gluten/pull/8914#discussion_r2009916832


##########
backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarCollectLimitExec.scala:
##########
@@ -32,88 +32,94 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 
 case class ColumnarCollectLimitExec(
     limit: Int,
-    child: SparkPlan
-) extends ColumnarCollectLimitBaseExec(limit, child) {
+    child: SparkPlan,
+    offset: Int = 0
+) extends ColumnarCollectLimitBaseExec(limit, child, offset) {
+
+  assert(limit >= 0 || (limit == -1 && offset > 0))
 
   override def batchType(): Convention.BatchType =
     BackendsApiManager.getSettings.primaryBatchType
 
+  private lazy val writeMetrics =
+    SQLShuffleWriteMetricsReporter.createShuffleWriteMetrics(sparkContext)
+
+  private lazy val readMetrics =
+    
SQLColumnarShuffleReadMetricsReporter.createShuffleReadMetrics(sparkContext)
+
+  private lazy val useSortBasedShuffle: Boolean =
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .useSortBasedShuffle(outputPartitioning, child.output)
+
+  @transient private lazy val serializer: Serializer =
+    BackendsApiManager.getSparkPlanExecApiInstance
+      .createColumnarBatchSerializer(child.schema, metrics, 
useSortBasedShuffle)
+
+  @transient override lazy val metrics: Map[String, SQLMetric] =
+    BackendsApiManager.getMetricsApiInstance
+      .genColumnarShuffleExchangeMetrics(sparkContext, useSortBasedShuffle) ++
+      readMetrics ++ writeMetrics
+
   /**
-   * Returns an iterator that yields up to `limit` rows in total from the 
input partitionIter.
+   * Returns an iterator that gives offset to limit rows in total from the 
input partitionIter.
    * Either retain the entire batch if it fits within the remaining limit, or 
prune it if it
-   * partially exceeds the remaining limit.
+   * partially exceeds the remaining limit/offset.
    */
-  private def collectLimitedRows(
-      partitionIter: Iterator[ColumnarBatch],
-      limit: Int
-  ): Iterator[ColumnarBatch] = {
-    if (partitionIter.isEmpty) {
-      return Iterator.empty
-    }
-    new Iterator[ColumnarBatch] {
+  private def collectWithOffsetAndLimit(
+                                   inputIter: Iterator[ColumnarBatch],
+                                   offset: Int,
+                                   limit: Int): Iterator[ColumnarBatch] = {
+
+    val unlimited = limit < 0
+    var rowsToSkip = math.max(offset, 0)
+    var rowsToCollect = if (unlimited) Int.MaxValue else limit
 
-      private var rowsCollected = 0
+    new Iterator[ColumnarBatch] {
       private var nextBatch: Option[ColumnarBatch] = None
 
       override def hasNext: Boolean = {
-        nextBatch.isDefined || fetchNext()
+        nextBatch.isDefined || fetchNextBatch()
       }
 
       override def next(): ColumnarBatch = {
-        if (!hasNext) {
-          throw new NoSuchElementException("No more batches available.")
-        }
+        if (!hasNext) throw new NoSuchElementException("No more batches 
available.")
         val batch = nextBatch.get
         nextBatch = None
         batch
       }
 
       /**
-       * Attempt to fetch the next batch from the underlying iterator if we 
haven't yet hit the
-       * limit. Returns true if we found a new batch, false otherwise.
+       * Advance the iterator until we find a batch (possibly sliced)
+       * that we can return, or exhaust the input.
        */
-      private def fetchNext(): Boolean = {
-        if (rowsCollected >= limit || !partitionIter.hasNext) {
-          return false
-        }
-
-        val currentBatch = partitionIter.next()
-        val currentBatchRowCount = currentBatch.numRows()
-        val remaining = limit - rowsCollected
-
-        if (currentBatchRowCount <= remaining) {
-          rowsCollected += currentBatchRowCount
-          ColumnarBatches.retain(currentBatch)
-          nextBatch = Some(currentBatch)
-        } else {
-          val prunedBatch = VeloxColumnarBatches.slice(currentBatch, 0, 
remaining)
-          rowsCollected += remaining
-          nextBatch = Some(prunedBatch)
+      private def fetchNextBatch(): Boolean = {
+        if (rowsToCollect <= 0) return false
+
+        while (inputIter.hasNext) {
+          val batch = inputIter.next()
+          val batchSize = batch.numRows()
+
+          if (rowsToSkip >= batchSize) {
+            rowsToSkip -= batchSize
+          } else {
+            val startIndex = rowsToSkip
+            val leftoverAfterSkip = batchSize - startIndex
+            rowsToSkip = 0
+
+            val needed = math.min(rowsToCollect, leftoverAfterSkip)

Review Comment:
   I see — you mean handling this case separately so that it doesn't need to slice the batch. Let me do the refactor.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to