This is an automated email from the ASF dual-hosted git repository.
viirya pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new 4fec40e refactor: Skipping slicing on shuffle arrays in shuffle reader (#189)
4fec40e is described below
commit 4fec40e5b81a6ef04e35be1ae8332bcfdf8597fe
Author: Liang-Chi Hsieh <[email protected]>
AuthorDate: Mon Mar 11 14:19:51 2024 -0700
refactor: Skipping slicing on shuffle arrays in shuffle reader (#189)
* refactor: Skipping slicing on shuffle arrays
* Add note for columnar shuffle batch size.
---
.../main/scala/org/apache/comet/CometConf.scala | 4 ++-
.../execution/shuffle/ArrowReaderIterator.scala | 37 +++-------------------
2 files changed, 8 insertions(+), 33 deletions(-)
diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala b/common/src/main/scala/org/apache/comet/CometConf.scala
index 1153b55..de49fdf 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -227,7 +227,9 @@ object CometConf {
val COMET_COLUMNAR_SHUFFLE_BATCH_SIZE: ConfigEntry[Int] =
conf("spark.comet.columnar.shuffle.batch.size")
.internal()
- .doc("Batch size when writing out sorted spill files on the native side.")
+ .doc("Batch size when writing out sorted spill files on the native side. Note that " +
+ "this should not be larger than batch size (i.e., `spark.comet.batchSize`). Otherwise " +
+ "it will produce larger batches than expected in the native operator after shuffle.")
.intConf
.createWithDefault(8192)
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
index c17c5bc..e8dba93 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/execution/shuffle/ArrowReaderIterator.scala
@@ -21,22 +21,14 @@ package org.apache.spark.sql.comet.execution.shuffle
import java.nio.channels.ReadableByteChannel
-import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.vectorized.ColumnarBatch
-import org.apache.comet.CometConf
-import org.apache.comet.vector.{NativeUtil, StreamReader}
+import org.apache.comet.vector.StreamReader
class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[ColumnarBatch] {
- private val nativeUtil = new NativeUtil
-
- private val maxBatchSize = CometConf.COMET_BATCH_SIZE.get(SQLConf.get)
-
private val reader = StreamReader(channel)
- private var currentIdx = -1
private var batch = nextBatch()
- private var previousBatch: ColumnarBatch = null
private var currentBatch: ColumnarBatch = null
override def hasNext: Boolean = {
@@ -57,40 +49,20 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
}
val nextBatch = batch.get
- val batchRows = nextBatch.numRows()
- val numRows = Math.min(batchRows - currentIdx, maxBatchSize)
- // Release the previous sliced batch.
+ // Release the previous batch.
// If it is not released, when closing the reader, arrow library will complain about
// memory leak.
if (currentBatch != null) {
- // Close plain arrays in the previous sliced batch.
- // The dictionary arrays will be closed when closing the entire batch.
currentBatch.close()
}
- currentBatch = nativeUtil.takeRows(nextBatch, currentIdx, numRows)
- currentIdx += numRows
-
- if (currentIdx == batchRows) {
- // We cannot close the batch here, because if there is dictionary array in the batch,
- // the dictionary array will be closed immediately, and the returned sliced batch will
- // be invalid.
- previousBatch = batch.get
-
- batch = None
- currentIdx = -1
- }
-
+ currentBatch = nextBatch
+ batch = None
currentBatch
}
private def nextBatch(): Option[ColumnarBatch] = {
- if (previousBatch != null) {
- previousBatch.close()
- previousBatch = null
- }
- currentIdx = 0
reader.nextBatch()
}
@@ -98,6 +70,7 @@ class ArrowReaderIterator(channel: ReadableByteChannel) extends Iterator[Columna
synchronized {
if (currentBatch != null) {
currentBatch.close()
+ currentBatch = null
}
reader.close()
}