viirya commented on code in PR #892:
URL: https://github.com/apache/datafusion-comet/pull/892#discussion_r1742294372
##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -44,46 +45,77 @@ object CometArrowConverters extends Logging {
// exported process increases the reference count of the Arrow vectors. The reference count is
// only decreased when the native plan is done with the vectors, which is usually longer than
// all the ColumnarBatches are consumed.
- private[sql] class ArrowBatchIterator(
- rowIter: Iterator[InternalRow],
+
+ abstract private[sql] class SparkToArrowConverter(
Review Comment:
```suggestion
abstract private[sql] class ArrowBatchIterBase(
```
##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -99,33 +131,78 @@ object CometArrowConverters extends Logging {
null
}
}
+ }
- private def close(closeAllocator: Boolean): Unit = {
- try {
- if (!closed) {
- if (currentBatch != null) {
- arrowWriter.reset()
- currentBatch.close()
- currentBatch = null
+ def toArrowBatchIteratorFromInternalRow(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ timeZoneId: String,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+ new ArrowBatchIteratorFromInternalRow(
+ rowIter,
+ schema,
+ maxRecordsPerBatch,
+ timeZoneId,
+ context)
+ }
+
+ private[sql] class ArrowBatchIteratorFromColumnBatch(
Review Comment:
```suggestion
private[sql] class ColumnBatchToArrowBatchIter(
```
##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -44,46 +45,77 @@ object CometArrowConverters extends Logging {
// exported process increases the reference count of the Arrow vectors. The reference count is
// only decreased when the native plan is done with the vectors, which is usually longer than
// all the ColumnarBatches are consumed.
- private[sql] class ArrowBatchIterator(
- rowIter: Iterator[InternalRow],
+
+ abstract private[sql] class SparkToArrowConverter(
schema: StructType,
- maxRecordsPerBatch: Long,
timeZoneId: String,
context: TaskContext)
extends Iterator[ColumnarBatch]
with AutoCloseable {
- private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId)
+ protected val arrowSchema: Schema = Utils.toArrowSchema(schema, timeZoneId)
// Reuse the same root allocator here.
- private val allocator =
+ protected val allocator: BufferAllocator =
rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0,
Long.MaxValue)
- private val root = VectorSchemaRoot.create(arrowSchema, allocator)
- private val arrowWriter = ArrowWriter.create(root)
+ protected val root: VectorSchemaRoot =
VectorSchemaRoot.create(arrowSchema, allocator)
+ protected val arrowWriter: ArrowWriter = ArrowWriter.create(root)
- private var currentBatch: ColumnarBatch = null
- private var closed: Boolean = false
+ protected var currentBatch: ColumnarBatch = null
+ protected var closed: Boolean = false
Option(context).foreach {
_.addTaskCompletionListener[Unit] { _ =>
close(true)
}
}
- override def hasNext: Boolean = rowIter.hasNext || {
+ override def close(): Unit = {
close(false)
- false
+ }
+
+ protected def close(closeAllocator: Boolean): Unit = {
+ try {
+ if (!closed) {
+ if (currentBatch != null) {
+ arrowWriter.reset()
+ currentBatch.close()
+ currentBatch = null
+ }
+ root.close()
+ closed = true
+ }
+ } finally {
+ // the allocator shall be closed when the task is finished
+ if (closeAllocator) {
+ allocator.close()
+ }
+ }
}
override def next(): ColumnarBatch = {
currentBatch = nextBatch()
currentBatch
}
- override def close(): Unit = {
+ protected def nextBatch(): ColumnarBatch
+
+ }
+
+ private[sql] class ArrowBatchIteratorFromInternalRow(
Review Comment:
```suggestion
private[sql] class RowToArrowBatchIter(
```
##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -99,33 +131,78 @@ object CometArrowConverters extends Logging {
null
}
}
+ }
- private def close(closeAllocator: Boolean): Unit = {
- try {
- if (!closed) {
- if (currentBatch != null) {
- arrowWriter.reset()
- currentBatch.close()
- currentBatch = null
+ def toArrowBatchIteratorFromInternalRow(
+ rowIter: Iterator[InternalRow],
+ schema: StructType,
+ maxRecordsPerBatch: Long,
+ timeZoneId: String,
+ context: TaskContext): Iterator[ColumnarBatch] = {
+ new ArrowBatchIteratorFromInternalRow(
+ rowIter,
+ schema,
+ maxRecordsPerBatch,
+ timeZoneId,
+ context)
+ }
+
+ private[sql] class ArrowBatchIteratorFromColumnBatch(
+ colBatch: ColumnarBatch,
+ schema: StructType,
+ maxRecordsPerBatch: Int,
+ timeZoneId: String,
+ context: TaskContext)
+ extends SparkToArrowConverter(schema, timeZoneId, context)
+ with AutoCloseable {
+
+ private var rowsProduced: Int = 0
+
+ override def hasNext: Boolean = rowsProduced < colBatch.numRows() || {
+ close(false)
+ false
+ }
+
+ override protected def nextBatch(): ColumnarBatch = {
+ val rowsInBatch = colBatch.numRows()
+ if (rowsProduced < rowsInBatch) {
+ // the arrow writer shall be reset before writing the next batch
+ arrowWriter.reset()
+ val rowsToProduce =
+ if (maxRecordsPerBatch <= 0) rowsInBatch - rowsProduced
+ else Math.min(maxRecordsPerBatch, rowsInBatch - rowsProduced)
+
+ for (columnIndex <- 0 until colBatch.numCols()) {
+ val column = colBatch.column(columnIndex)
+ val columnArray = new ColumnarArray(column, rowsProduced,
rowsToProduce)
+ if (column.hasNull) {
+ arrowWriter.writeCol(columnArray, columnIndex)
+ } else {
+ arrowWriter.writeColNoNull(columnArray, columnIndex)
}
- root.close()
- closed = true
- }
- } finally {
- // the allocator shall be closed when the task is finished
- if (closeAllocator) {
- allocator.close()
}
+
+ rowsProduced += rowsToProduce
+
+ arrowWriter.finish()
+ NativeUtil.rootAsBatch(root)
+ } else {
+ null
}
}
}
- def toArrowBatchIterator(
- rowIter: Iterator[InternalRow],
+ def toArrowBatchIteratorFromColumnBatch(
Review Comment:
```suggestion
def columnarBatchToArrowBatchIter(
```
##########
common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala:
##########
@@ -99,33 +131,78 @@ object CometArrowConverters extends Logging {
null
}
}
+ }
- private def close(closeAllocator: Boolean): Unit = {
- try {
- if (!closed) {
- if (currentBatch != null) {
- arrowWriter.reset()
- currentBatch.close()
- currentBatch = null
+ def toArrowBatchIteratorFromInternalRow(
Review Comment:
```suggestion
def rowToArrowBatchIter(
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]