This is an automated email from the ASF dual-hosted git repository.

sunchao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 60fe431  feat: Add CometRowToColumnar operator (#206)
60fe431 is described below

commit 60fe4315e793d7efb6a8de9246257c1b3c623766
Author: advancedxy <[email protected]>
AuthorDate: Wed Apr 10 11:59:58 2024 +0800

    feat: Add CometRowToColumnar operator (#206)
---
 .../main/scala/org/apache/comet/CometConf.scala    |  20 +
 .../scala/org/apache/comet/vector/NativeUtil.scala |  17 +
 .../org/apache/comet/vector/StreamReader.scala     |  13 +-
 .../sql/comet/execution/arrow/ArrowWriters.scala   | 472 +++++++++++++++++++++
 .../execution/arrow/CometArrowConverters.scala     | 131 ++++++
 .../org/apache/spark/sql/comet/util/Utils.scala    |  15 +
 .../apache/comet/CometSparkSessionExtensions.scala |  56 ++-
 .../org/apache/comet/serde/QueryPlanSerde.scala    |   3 +-
 .../spark/sql/comet/CometRowToColumnarExec.scala   |  84 ++++
 .../org/apache/spark/sql/comet/operators.scala     |   5 +-
 .../org/apache/comet/exec/CometExecSuite.scala     |  54 ++-
 .../scala/org/apache/spark/sql/CometTestBase.scala |  15 +-
 12 files changed, 856 insertions(+), 29 deletions(-)

diff --git a/common/src/main/scala/org/apache/comet/CometConf.scala 
b/common/src/main/scala/org/apache/comet/CometConf.scala
index bd2e04d..341ec98 100644
--- a/common/src/main/scala/org/apache/comet/CometConf.scala
+++ b/common/src/main/scala/org/apache/comet/CometConf.scala
@@ -337,6 +337,26 @@ object CometConf {
         "enabled when reading from Iceberg tables.")
     .booleanConf
     .createWithDefault(false)
+
+  val COMET_ROW_TO_COLUMNAR_ENABLED: ConfigEntry[Boolean] =
+    conf("spark.comet.rowToColumnar.enabled")
+      .internal()
+      .doc("""
+         |Whether to enable row to columnar conversion in Comet. When this is 
turned on, Comet will
+         |convert row-based operators in 
`spark.comet.rowToColumnar.supportedOperatorList` into
+         |columnar based before processing.""".stripMargin)
+      .booleanConf
+      .createWithDefault(false)
+
+  val COMET_ROW_TO_COLUMNAR_SUPPORTED_OPERATOR_LIST: ConfigEntry[Seq[String]] =
+    conf("spark.comet.rowToColumnar.supportedOperatorList")
+      .doc(
+        "A comma-separated list of row-based operators that will be converted 
to columnar " +
+          "format when 'spark.comet.rowToColumnar.enabled' is true")
+      .stringConf
+      .toSequence
+      .createWithDefault(Seq("Range,InMemoryTableScan"))
+
 }
 
 object ConfigHelpers {
diff --git a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala 
b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
index 3756da9..763ccff 100644
--- a/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
+++ b/common/src/main/scala/org/apache/comet/vector/NativeUtil.scala
@@ -23,6 +23,8 @@ import scala.collection.mutable
 
 import org.apache.arrow.c.{ArrowArray, ArrowImporter, ArrowSchema, 
CDataDictionaryProvider, Data}
 import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.dictionary.DictionaryProvider
 import org.apache.spark.SparkException
 import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.vectorized.ColumnarBatch
@@ -132,3 +134,18 @@ class NativeUtil {
     new ColumnarBatch(arrayVectors.toArray, maxNumRows)
   }
 }
+
+object NativeUtil {
+  def rootAsBatch(arrowRoot: VectorSchemaRoot): ColumnarBatch = {
+    rootAsBatch(arrowRoot, null)
+  }
+
+  def rootAsBatch(arrowRoot: VectorSchemaRoot, provider: DictionaryProvider): 
ColumnarBatch = {
+    val vectors = (0 until arrowRoot.getFieldVectors.size()).map { i =>
+      val vector = arrowRoot.getFieldVectors.get(i)
+      // Native shuffle always uses decimal128.
+      CometVector.getVector(vector, true, provider)
+    }
+    new ColumnarBatch(vectors.toArray, arrowRoot.getRowCount)
+  }
+}
diff --git a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala 
b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
index da72383..61d800b 100644
--- a/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
+++ b/common/src/main/scala/org/apache/comet/vector/StreamReader.scala
@@ -21,13 +21,11 @@ package org.apache.comet.vector
 
 import java.nio.channels.ReadableByteChannel
 
-import scala.collection.JavaConverters.collectionAsScalaIterableConverter
-
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.VectorSchemaRoot
 import org.apache.arrow.vector.ipc.{ArrowStreamReader, ReadChannel}
 import org.apache.arrow.vector.ipc.message.MessageChannelReader
-import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector}
+import org.apache.spark.sql.vectorized.ColumnarBatch
 
 /**
  * A reader that consumes Arrow data from an input channel, and produces Comet 
batches.
@@ -47,14 +45,7 @@ case class StreamReader(channel: ReadableByteChannel) 
extends AutoCloseable {
   }
 
   private def rootAsBatch(root: VectorSchemaRoot): ColumnarBatch = {
-    val columns = root.getFieldVectors.asScala.map { vector =>
-      // Native shuffle always uses decimal128.
-      CometVector.getVector(vector, true, 
arrowReader).asInstanceOf[ColumnVector]
-    }.toArray
-
-    val batch = new ColumnarBatch(columns)
-    batch.setNumRows(root.getRowCount)
-    batch
+    NativeUtil.rootAsBatch(root, arrowReader)
   }
 
   override def close(): Unit = {
diff --git 
a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
new file mode 100644
index 0000000..8d9f373
--- /dev/null
+++ 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/ArrowWriters.scala
@@ -0,0 +1,472 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import scala.collection.JavaConverters._
+
+import org.apache.arrow.vector._
+import org.apache.arrow.vector.complex._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.types._
+
+/**
+ * This file is mostly copied from Spark SQL's
+ * org.apache.spark.sql.execution.arrow.ArrowWriter.scala. Comet shadows Arrow 
classes to avoid
+ * potential conflicts with Spark's Arrow dependencies, hence we cannot reuse 
Spark's ArrowWriter
+ * directly.
+ */
+private[arrow] object ArrowWriter {
+  def create(root: VectorSchemaRoot): ArrowWriter = {
+    val children = root.getFieldVectors().asScala.map { vector =>
+      vector.allocateNew()
+      createFieldWriter(vector)
+    }
+    new ArrowWriter(root, children.toArray)
+  }
+
+  private[sql] def createFieldWriter(vector: ValueVector): ArrowFieldWriter = {
+    val field = vector.getField()
+    (Utils.fromArrowField(field), vector) match {
+      case (BooleanType, vector: BitVector) => new BooleanWriter(vector)
+      case (ByteType, vector: TinyIntVector) => new ByteWriter(vector)
+      case (ShortType, vector: SmallIntVector) => new ShortWriter(vector)
+      case (IntegerType, vector: IntVector) => new IntegerWriter(vector)
+      case (LongType, vector: BigIntVector) => new LongWriter(vector)
+      case (FloatType, vector: Float4Vector) => new FloatWriter(vector)
+      case (DoubleType, vector: Float8Vector) => new DoubleWriter(vector)
+      case (DecimalType.Fixed(precision, scale), vector: DecimalVector) =>
+        new DecimalWriter(vector, precision, scale)
+      case (StringType, vector: VarCharVector) => new StringWriter(vector)
+      case (StringType, vector: LargeVarCharVector) => new 
LargeStringWriter(vector)
+      case (BinaryType, vector: VarBinaryVector) => new BinaryWriter(vector)
+      case (BinaryType, vector: LargeVarBinaryVector) => new 
LargeBinaryWriter(vector)
+      case (DateType, vector: DateDayVector) => new DateWriter(vector)
+      case (TimestampType, vector: TimeStampMicroTZVector) => new 
TimestampWriter(vector)
+      case (TimestampNTZType, vector: TimeStampMicroVector) => new 
TimestampNTZWriter(vector)
+      case (ArrayType(_, _), vector: ListVector) =>
+        val elementVector = createFieldWriter(vector.getDataVector())
+        new ArrayWriter(vector, elementVector)
+      case (MapType(_, _, _), vector: MapVector) =>
+        val structVector = vector.getDataVector.asInstanceOf[StructVector]
+        val keyWriter = 
createFieldWriter(structVector.getChild(MapVector.KEY_NAME))
+        val valueWriter = 
createFieldWriter(structVector.getChild(MapVector.VALUE_NAME))
+        new MapWriter(vector, structVector, keyWriter, valueWriter)
+      case (StructType(_), vector: StructVector) =>
+        val children = (0 until vector.size()).map { ordinal =>
+          createFieldWriter(vector.getChildByOrdinal(ordinal))
+        }
+        new StructWriter(vector, children.toArray)
+      case (NullType, vector: NullVector) => new NullWriter(vector)
+      case (_: YearMonthIntervalType, vector: IntervalYearVector) =>
+        new IntervalYearWriter(vector)
+      case (_: DayTimeIntervalType, vector: DurationVector) => new 
DurationWriter(vector)
+//      case (CalendarIntervalType, vector: IntervalMonthDayNanoVector) =>
+//        new IntervalMonthDayNanoWriter(vector)
+      case (dt, _) =>
+        throw QueryExecutionErrors.notSupportTypeError(dt)
+    }
+  }
+}
+
+class ArrowWriter(val root: VectorSchemaRoot, fields: Array[ArrowFieldWriter]) 
{
+
+  def schema: StructType = Utils.fromArrowSchema(root.getSchema())
+
+  private var count: Int = 0
+
+  def write(row: InternalRow): Unit = {
+    var i = 0
+    while (i < fields.length) {
+      fields(i).write(row, i)
+      i += 1
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    root.setRowCount(count)
+    fields.foreach(_.finish())
+  }
+
+  def reset(): Unit = {
+    root.setRowCount(0)
+    count = 0
+    fields.foreach(_.reset())
+  }
+}
+
+private[arrow] abstract class ArrowFieldWriter {
+
+  def valueVector: ValueVector
+
+  def name: String = valueVector.getField().getName()
+  def dataType: DataType = Utils.fromArrowField(valueVector.getField())
+  def nullable: Boolean = valueVector.getField().isNullable()
+
+  def setNull(): Unit
+  def setValue(input: SpecializedGetters, ordinal: Int): Unit
+
+  private[arrow] var count: Int = 0
+
+  def write(input: SpecializedGetters, ordinal: Int): Unit = {
+    if (input.isNullAt(ordinal)) {
+      setNull()
+    } else {
+      setValue(input, ordinal)
+    }
+    count += 1
+  }
+
+  def finish(): Unit = {
+    valueVector.setValueCount(count)
+  }
+
+  def reset(): Unit = {
+    valueVector.reset()
+    count = 0
+  }
+}
+
+private[arrow] class BooleanWriter(val valueVector: BitVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, if (input.getBoolean(ordinal)) 1 else 0)
+  }
+}
+
+private[arrow] class ByteWriter(val valueVector: TinyIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getByte(ordinal))
+  }
+}
+
+private[arrow] class ShortWriter(val valueVector: SmallIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getShort(ordinal))
+  }
+}
+
+private[arrow] class IntegerWriter(val valueVector: IntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getInt(ordinal))
+  }
+}
+
+private[arrow] class LongWriter(val valueVector: BigIntVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class FloatWriter(val valueVector: Float4Vector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getFloat(ordinal))
+  }
+}
+
+private[arrow] class DoubleWriter(val valueVector: Float8Vector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getDouble(ordinal))
+  }
+}
+
+private[arrow] class DecimalWriter(val valueVector: DecimalVector, precision: 
Int, scale: Int)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val decimal = input.getDecimal(ordinal, precision, scale)
+    if (decimal.changePrecision(precision, scale)) {
+      valueVector.setSafe(count, decimal.toJavaBigDecimal)
+    } else {
+      setNull()
+    }
+  }
+}
+
+private[arrow] class StringWriter(val valueVector: VarCharVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    val utf8ByteBuffer = utf8.getByteBuffer
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+    valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), 
utf8.numBytes())
+  }
+}
+
+private[arrow] class LargeStringWriter(val valueVector: LargeVarCharVector)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val utf8 = input.getUTF8String(ordinal)
+    val utf8ByteBuffer = utf8.getByteBuffer
+    // todo: for off-heap UTF8String, how to pass in to arrow without copy?
+    valueVector.setSafe(count, utf8ByteBuffer, utf8ByteBuffer.position(), 
utf8.numBytes())
+  }
+}
+
+private[arrow] class BinaryWriter(val valueVector: VarBinaryVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val bytes = input.getBinary(ordinal)
+    valueVector.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[arrow] class LargeBinaryWriter(val valueVector: LargeVarBinaryVector)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val bytes = input.getBinary(ordinal)
+    valueVector.setSafe(count, bytes, 0, bytes.length)
+  }
+}
+
+private[arrow] class DateWriter(val valueVector: DateDayVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getInt(ordinal))
+  }
+}
+
+private[arrow] class TimestampWriter(val valueVector: TimeStampMicroTZVector)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class TimestampNTZWriter(val valueVector: TimeStampMicroVector)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class ArrayWriter(val valueVector: ListVector, val 
elementWriter: ArrowFieldWriter)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {}
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val array = input.getArray(ordinal)
+    var i = 0
+    valueVector.startNewValue(count)
+    while (i < array.numElements()) {
+      elementWriter.write(array, i)
+      i += 1
+    }
+    valueVector.endValue(count, array.numElements())
+  }
+
+  override def finish(): Unit = {
+    super.finish()
+    elementWriter.finish()
+  }
+
+  override def reset(): Unit = {
+    super.reset()
+    elementWriter.reset()
+  }
+}
+
+private[arrow] class StructWriter(
+    val valueVector: StructVector,
+    children: Array[ArrowFieldWriter])
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {
+    var i = 0
+    while (i < children.length) {
+      children(i).setNull()
+      children(i).count += 1
+      i += 1
+    }
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val struct = input.getStruct(ordinal, children.length)
+    var i = 0
+    valueVector.setIndexDefined(count)
+    while (i < struct.numFields) {
+      children(i).write(struct, i)
+      i += 1
+    }
+  }
+
+  override def finish(): Unit = {
+    super.finish()
+    children.foreach(_.finish())
+  }
+
+  override def reset(): Unit = {
+    super.reset()
+    children.foreach(_.reset())
+  }
+}
+
+private[arrow] class MapWriter(
+    val valueVector: MapVector,
+    val structVector: StructVector,
+    val keyWriter: ArrowFieldWriter,
+    val valueWriter: ArrowFieldWriter)
+    extends ArrowFieldWriter {
+
+  override def setNull(): Unit = {}
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val map = input.getMap(ordinal)
+    valueVector.startNewValue(count)
+    val keys = map.keyArray()
+    val values = map.valueArray()
+    var i = 0
+    while (i < map.numElements()) {
+      structVector.setIndexDefined(keyWriter.count)
+      keyWriter.write(keys, i)
+      valueWriter.write(values, i)
+      i += 1
+    }
+
+    valueVector.endValue(count, map.numElements())
+  }
+
+  override def finish(): Unit = {
+    super.finish()
+    keyWriter.finish()
+    valueWriter.finish()
+  }
+
+  override def reset(): Unit = {
+    super.reset()
+    keyWriter.reset()
+    valueWriter.reset()
+  }
+}
+
+private[arrow] class NullWriter(val valueVector: NullVector) extends 
ArrowFieldWriter {
+
+  override def setNull(): Unit = {}
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {}
+}
+
+private[arrow] class IntervalYearWriter(val valueVector: IntervalYearVector)
+    extends ArrowFieldWriter {
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getInt(ordinal));
+  }
+}
+
+private[arrow] class DurationWriter(val valueVector: DurationVector) extends 
ArrowFieldWriter {
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    valueVector.setSafe(count, input.getLong(ordinal))
+  }
+}
+
+private[arrow] class IntervalMonthDayNanoWriter(val valueVector: 
IntervalMonthDayNanoVector)
+    extends ArrowFieldWriter {
+  override def setNull(): Unit = {
+    valueVector.setNull(count)
+  }
+
+  override def setValue(input: SpecializedGetters, ordinal: Int): Unit = {
+    val ci = input.getInterval(ordinal)
+    valueVector.setSafe(count, ci.months, ci.days, 
Math.multiplyExact(ci.microseconds, 1000L))
+  }
+}
diff --git 
a/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
new file mode 100644
index 0000000..9dbd8dc
--- /dev/null
+++ 
b/common/src/main/scala/org/apache/spark/sql/comet/execution/arrow/CometArrowConverters.scala
@@ -0,0 +1,131 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet.execution.arrow
+
+import org.apache.arrow.memory.{BufferAllocator, RootAllocator}
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.spark.TaskContext
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.comet.util.Utils
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+import org.apache.comet.vector.NativeUtil
+
+object CometArrowConverters extends Logging {
+  // TODO: we should reuse the same root allocator in the comet code base?
+  val rootAllocator: BufferAllocator = new RootAllocator(Long.MaxValue)
+
+  // This is similar to how Spark converts internal rows to Arrow format, except 
that it transforms
+  // the result batch to Comet's ColumnarBatch instead of serialized bytes.
+  // There's another big difference that Comet may consume the ColumnarBatch 
by exporting it to
+  // the native side. Hence, we need to:
+  // 1. reset the Arrow writer after the ColumnarBatch is consumed
+  // 2. close the allocator when the task is finished but not when the 
iterator is all consumed
+  // The reason for the second point is that when ColumnarBatch is exported to 
the native side, the
+  // export process increases the reference count of the Arrow vectors. The 
reference count is
+  // only decreased when the native plan is done with the vectors, which is 
usually later than
+  // when all the ColumnarBatches are consumed.
+  private[sql] class ArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext)
+      extends Iterator[ColumnarBatch]
+      with AutoCloseable {
+
+    private val arrowSchema = Utils.toArrowSchema(schema, timeZoneId)
+    // Reuse the same root allocator here.
+    private val allocator =
+      rootAllocator.newChildAllocator(s"to${this.getClass.getSimpleName}", 0, 
Long.MaxValue)
+    private val root = VectorSchemaRoot.create(arrowSchema, allocator)
+    private val arrowWriter = ArrowWriter.create(root)
+
+    private var currentBatch: ColumnarBatch = null
+    private var closed: Boolean = false
+
+    Option(context).foreach {
+      _.addTaskCompletionListener[Unit] { _ =>
+        close(true)
+      }
+    }
+
+    override def hasNext: Boolean = rowIter.hasNext || {
+      close(false)
+      false
+    }
+
+    override def next(): ColumnarBatch = {
+      currentBatch = nextBatch()
+      currentBatch
+    }
+
+    override def close(): Unit = {
+      close(false)
+    }
+
+    private def nextBatch(): ColumnarBatch = {
+      if (rowIter.hasNext) {
+        // the arrow writer shall be reset before writing the next batch
+        arrowWriter.reset()
+        var rowCount = 0L
+        while (rowIter.hasNext && (maxRecordsPerBatch <= 0 || rowCount < 
maxRecordsPerBatch)) {
+          val row = rowIter.next()
+          arrowWriter.write(row)
+          rowCount += 1
+        }
+        arrowWriter.finish()
+        NativeUtil.rootAsBatch(root)
+      } else {
+        null
+      }
+    }
+
+    private def close(closeAllocator: Boolean): Unit = {
+      try {
+        if (!closed) {
+          if (currentBatch != null) {
+            arrowWriter.reset()
+            currentBatch.close()
+            currentBatch = null
+          }
+          root.close()
+          closed = true
+        }
+      } finally {
+        // the allocator shall be closed when the task is finished
+        if (closeAllocator) {
+          allocator.close()
+        }
+      }
+    }
+  }
+
+  def toArrowBatchIterator(
+      rowIter: Iterator[InternalRow],
+      schema: StructType,
+      maxRecordsPerBatch: Long,
+      timeZoneId: String,
+      context: TaskContext): Iterator[ColumnarBatch] = {
+    new ArrowBatchIterator(rowIter, schema, maxRecordsPerBatch, timeZoneId, 
context)
+  }
+}
diff --git a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala 
b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
index 684d778..7d920e1 100644
--- a/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
+++ b/common/src/main/scala/org/apache/spark/sql/comet/util/Utils.scala
@@ -54,6 +54,11 @@ object Utils {
     str.split(",").map(_.trim()).filter(_.nonEmpty)
   }
 
+  /** Bridges the function call to Spark's Utils. */
+  def getSimpleName(cls: Class[_]): String = {
+    org.apache.spark.util.Utils.getSimpleName(cls)
+  }
+
   def fromArrowField(field: Field): DataType = {
     field.getType match {
       case _: ArrowType.Map =>
@@ -90,6 +95,9 @@ object Utils {
     case _: ArrowType.FixedSizeBinary => BinaryType
     case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
     case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+    case ts: ArrowType.Timestamp
+        if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null =>
+      TimestampNTZType
     case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => 
TimestampType
     case ArrowType.Null.INSTANCE => NullType
     case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH =>
@@ -98,6 +106,13 @@ object Utils {
     case _ => throw new UnsupportedOperationException(s"Unsupported data type: 
${dt.toString}")
   }
 
+  def fromArrowSchema(schema: Schema): StructType = {
+    StructType(schema.getFields.asScala.map { field =>
+      val dt = fromArrowField(field)
+      StructField(field.getName, dt, field.isNullable)
+    }.toArray)
+  }
+
   /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for 
TimestampTypes */
   def toArrowType(dt: DataType, timeZoneId: String): ArrowType =
     dt match {
diff --git 
a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala 
b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
index 7795194..a10ac57 100644
--- a/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
+++ b/spark/src/main/scala/org/apache/comet/CometSparkSessionExtensions.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.comet._
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
+import org.apache.spark.sql.comet.util.Utils
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.adaptive.{BroadcastQueryStageExec, 
ShuffleQueryStageExec}
 import org.apache.spark.sql.execution.aggregate.HashAggregateExec
@@ -43,7 +44,7 @@ import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
 import org.apache.comet.CometConf._
-import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, 
isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, 
isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, 
isCometShuffleEnabled, isSchemaSupported}
+import org.apache.comet.CometSparkSessionExtensions.{isANSIEnabled, 
isCometBroadCastForceEnabled, isCometColumnarShuffleEnabled, isCometEnabled, 
isCometExecEnabled, isCometOperatorEnabled, isCometScan, isCometScanEnabled, 
isCometShuffleEnabled, isSchemaSupported, shouldApplyRowToColumnar}
 import org.apache.comet.parquet.{CometParquetScan, SupportsComet}
 import org.apache.comet.serde.OperatorOuterClass.Operator
 import org.apache.comet.serde.QueryPlanSerde
@@ -68,7 +69,7 @@ class CometSparkSessionExtensions
     override def preColumnarTransitions: Rule[SparkPlan] = 
CometExecRule(session)
 
     override def postColumnarTransitions: Rule[SparkPlan] =
-      EliminateRedundantColumnarToRow(session)
+      EliminateRedundantTransitions(session)
   }
 
   case class CometScanRule(session: SparkSession) extends Rule[SparkPlan] {
@@ -238,6 +239,11 @@ class CometSparkSessionExtensions
           val nativeOp = QueryPlanSerde.operator2Proto(op).get
           CometScanWrapper(nativeOp, op)
 
+        case op if shouldApplyRowToColumnar(conf, op) =>
+          val cometOp = CometRowToColumnarExec(op)
+          val nativeOp = QueryPlanSerde.operator2Proto(cometOp).get
+          CometScanWrapper(nativeOp, cometOp)
+
         case op: ProjectExec =>
           val newOp = transform1(op)
           newOp match {
@@ -592,18 +598,32 @@ class CometSparkSessionExtensions
     }
   }
 
-  // CometExec already wraps a `ColumnarToRowExec` for row-based operators. 
Therefore,
-  // `ColumnarToRowExec` is redundant and can be eliminated.
+  // This rule is responsible for eliminating redundant transitions between 
row-based and
+  // columnar-based operators for Comet. Currently, two potential redundant 
transitions are:
+  // 1. `ColumnarToRowExec` on top of an ending `CometCollectLimitExec` 
operator, which is
+  //    redundant as `CometCollectLimitExec` already wraps a 
`ColumnarToRowExec` for row-based
+  //    output.
+  // 2. Consecutive operators of `CometRowToColumnarExec` and 
`ColumnarToRowExec`.
+  //
+  // Note about the first case: The `ColumnarToRowExec` was added during
+  // ApplyColumnarRulesAndInsertTransitions' insertTransitions phase when 
Spark requests row-based
+  // output such as a `collect` call. It's correct to add a redundant 
`ColumnarToRowExec` for
+  // `CometExec`. However, for certain operators such as 
`CometCollectLimitExec` which overrides
+  // `executeCollect`, the redundant `ColumnarToRowExec` makes the override 
ineffective.
   //
-  // It was added during ApplyColumnarRulesAndInsertTransitions' 
insertTransitions phase when Spark
-  // requests row-based output such as `collect` call. It's correct to add a 
redundant
-  // `ColumnarToRowExec` for `CometExec`. However, for certain operators such 
as
-  // `CometCollectLimitExec` which overrides `executeCollect`, the redundant 
`ColumnarToRowExec`
-  // makes the override ineffective. The purpose of this rule is to eliminate 
the redundant
-  // `ColumnarToRowExec` for such operators.
-  case class EliminateRedundantColumnarToRow(session: SparkSession) extends 
Rule[SparkPlan] {
+  // Note about the second case: When `spark.comet.rowToColumnar.enabled` is 
set, Comet will add
+  // `CometRowToColumnarExec` on top of row-based operators first, but the 
downstream operator
+  // only takes row-based input as it's a vanilla Spark operator (as Comet 
cannot convert it for
+  // various reasons) or Spark requests row-based output such as a `collect` 
call. Spark will then add
+  // another `ColumnarToRowExec` on top of `CometRowToColumnarExec`. In this 
case, the pair could
+  // be removed.
+  case class EliminateRedundantTransitions(session: SparkSession) extends 
Rule[SparkPlan] {
     override def apply(plan: SparkPlan): SparkPlan = {
-      plan match {
+      val eliminatedPlan = plan transformUp {
+        case ColumnarToRowExec(rowToColumnar: CometRowToColumnarExec) => 
rowToColumnar.child
+      }
+
+      eliminatedPlan match {
         case ColumnarToRowExec(child: CometCollectLimitExec) =>
           child
         case other =>
@@ -716,6 +736,18 @@ object CometSparkSessionExtensions extends Logging {
     op.isInstanceOf[CometBatchScanExec] || op.isInstanceOf[CometScanExec]
   }
 
+  private def shouldApplyRowToColumnar(conf: SQLConf, op: SparkPlan): Boolean 
= {
+    // Only consider converting leaf nodes to columnar currently, so that all 
the following
+    // operators can have a chance to be converted to columnar.
+    // TODO: consider converting other intermediate operators to columnar.
+    op.isInstanceOf[LeafExecNode] && !op.supportsColumnar && 
isSchemaSupported(op.schema) &&
+    COMET_ROW_TO_COLUMNAR_ENABLED.get(conf) && {
+      val simpleClassName = Utils.getSimpleName(op.getClass)
+      val nodeName = simpleClassName.replaceAll("Exec$", "")
+      
COMET_ROW_TO_COLUMNAR_SUPPORTED_OPERATOR_LIST.get(conf).contains(nodeName)
+    }
+  }
+
   /** Used for operations that weren't available in Spark 3.2 */
   def isSpark32: Boolean = {
     org.apache.spark.SPARK_VERSION.matches("3\\.2\\..*")
diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala 
b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index b98c438..26fc708 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.{BuildRight, 
NormalizeNaNAndZero}
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, 
Partitioning, SinglePartition}
 import org.apache.spark.sql.catalyst.util.CharVarcharCodegenUtils
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometSinkPlaceHolder, DecimalPrecision}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometRowToColumnarExec, CometSinkPlaceHolder, DecimalPrecision}
 import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution._
@@ -2064,6 +2064,7 @@ object QueryPlanSerde extends Logging with 
ShimQueryPlanSerde {
   private def isCometSink(op: SparkPlan): Boolean = {
     op match {
       case s if isCometScan(s) => true
+      case _: CometRowToColumnarExec => true
       case _: CometSinkPlaceHolder => true
       case _: CoalesceExec => true
       case _: CollectLimitExec => true
diff --git 
a/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala
new file mode 100644
index 0000000..5679e86
--- /dev/null
+++ 
b/spark/src/main/scala/org/apache/spark/sql/comet/CometRowToColumnarExec.scala
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.spark.sql.comet
+
+import org.apache.spark.TaskContext
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, SortOrder}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.comet.execution.arrow.CometArrowConverters
+import org.apache.spark.sql.execution.{RowToColumnarTransition, SparkPlan}
+import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+case class CometRowToColumnarExec(child: SparkPlan)
+    extends RowToColumnarTransition
+    with CometPlan {
+  override def output: Seq[Attribute] = child.output
+
+  override def outputPartitioning: Partitioning = child.outputPartitioning
+
+  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    child.execute()
+  }
+
+  override def doExecuteBroadcast[T](): Broadcast[T] = {
+    child.executeBroadcast()
+  }
+
+  override def supportsColumnar: Boolean = true
+
+  override lazy val metrics: Map[String, SQLMetric] = Map(
+    "numInputRows" -> SQLMetrics.createMetric(sparkContext, "number of input 
rows"),
+    "numOutputBatches" -> SQLMetrics.createMetric(sparkContext, "number of 
output batches"))
+
+  override def doExecuteColumnar(): RDD[ColumnarBatch] = {
+    val numInputRows = longMetric("numInputRows")
+    val numOutputBatches = longMetric("numOutputBatches")
+    val maxRecordsPerBatch = conf.arrowMaxRecordsPerBatch
+    val timeZoneId = conf.sessionLocalTimeZone
+    val schema = child.schema
+
+    child
+      .execute()
+      .mapPartitionsInternal { iter =>
+        val context = TaskContext.get()
+        CometArrowConverters.toArrowBatchIterator(
+          iter,
+          schema,
+          maxRecordsPerBatch,
+          timeZoneId,
+          context)
+      }
+      .map { batch =>
+        numInputRows += batch.numRows()
+        numOutputBatches += 1
+        batch
+      }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): 
CometRowToColumnarExec =
+    copy(child = newChild)
+
+}
diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala 
b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
index 520f239..8545eee 100644
--- a/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
+++ b/spark/src/main/scala/org/apache/spark/sql/comet/operators.scala
@@ -270,7 +270,7 @@ abstract class CometNativeExec extends CometExec {
         }
 
         if (inputs.isEmpty) {
-          throw new CometRuntimeException(s"No input for CometNativeExec: 
$this")
+          throw new CometRuntimeException(s"No input for CometNativeExec:\n 
$this")
         }
 
         ZippedPartitionsRDD(sparkContext, inputs.toSeq)(createCometExecIter(_))
@@ -300,7 +300,8 @@ abstract class CometNativeExec extends CometExec {
       case _: CometScanExec | _: CometBatchScanExec | _: ShuffleQueryStageExec 
|
           _: AQEShuffleReadExec | _: CometShuffleExchangeExec | _: 
CometUnionExec |
           _: CometTakeOrderedAndProjectExec | _: CometCoalesceExec | _: 
ReusedExchangeExec |
-          _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec =>
+          _: CometBroadcastExchangeExec | _: BroadcastQueryStageExec |
+          _: CometRowToColumnarExec =>
         func(plan)
       case _: CometPlan =>
         // Other Comet operators, continue to traverse the tree.
diff --git a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala 
b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
index b2c4fd6..0bb21ab 100644
--- a/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/exec/CometExecSuite.scala
@@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStatistics, 
CatalogTable}
 import org.apache.spark.sql.catalyst.expressions.Hex
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateMode
-import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, 
CometProjectExec, CometScanExec, CometTakeOrderedAndProjectExec}
+import org.apache.spark.sql.comet.{CometBroadcastExchangeExec, 
CometCollectLimitExec, CometFilterExec, CometHashAggregateExec, 
CometProjectExec, CometRowToColumnarExec, CometScanExec, 
CometTakeOrderedAndProjectExec}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{CollectLimitExec, ProjectExec, 
SQLExecution, UnionExec}
 import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec
@@ -1118,6 +1118,58 @@ class CometExecSuite extends CometTestBase {
       }
     })
   }
+
+  test("RowToColumnar over RangeExec") {
+    Seq("true", "false").foreach(aqe => {
+      Seq(500, 900).foreach { batchSize =>
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+          SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
batchSize.toString) {
+          val df = spark.range(1000).selectExpr("id", "id % 8 as 
k").groupBy("k").sum("id")
+          checkSparkAnswerAndOperator(df)
+          // empty record batch should also be handled
+          val df2 = spark.range(0).selectExpr("id", "id % 8 as 
k").groupBy("k").sum("id")
+          checkSparkAnswerAndOperator(df2, includeClasses = 
Seq(classOf[CometRowToColumnarExec]))
+        }
+      }
+    })
+  }
+
+  test("RowToColumnar over RangeExec directly is eliminated for row output") {
+    Seq("true", "false").foreach(aqe => {
+      Seq(500, 900).foreach { batchSize =>
+        withSQLConf(
+          SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+          SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
batchSize.toString) {
+          val df = spark.range(1000)
+          val qe = df.queryExecution
+          qe.executedPlan.collectFirst({ case r: CometRowToColumnarExec => r 
}) match {
+            case Some(_) => fail("CometRowToColumnarExec should be eliminated")
+            case _ =>
+          }
+        }
+      }
+    })
+  }
+
+  test("RowToColumnar over InMemoryTableScanExec") {
+    Seq("true", "false").foreach(aqe => {
+      withSQLConf(
+        SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> aqe,
+        CometConf.COMET_COLUMNAR_SHUFFLE_ENABLED.key -> "true",
+        SQLConf.CACHE_VECTORIZED_READER_ENABLED.key -> "false") {
+        spark
+          .range(1000)
+          .selectExpr("id as key", "id % 8 as value")
+          .toDF("key", "value")
+          .selectExpr("key", "value", "key+1")
+          .createOrReplaceTempView("abc")
+        spark.catalog.cacheTable("abc")
+        val df = spark.sql("SELECT * FROM abc").groupBy("key").count()
+        checkSparkAnswerAndOperator(df, includeClasses = 
Seq(classOf[CometRowToColumnarExec]))
+      }
+    })
+  }
 }
 
 case class BucketedTableTestSpec(
diff --git a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala 
b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
index 6fb81bc..de58665 100644
--- a/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
+++ b/spark/src/test/scala/org/apache/spark/sql/CometTestBase.scala
@@ -34,7 +34,7 @@ import org.apache.parquet.hadoop.example.ExampleParquetWriter
 import org.apache.parquet.schema.{MessageType, MessageTypeParser}
 import org.apache.spark._
 import org.apache.spark.internal.config.{MEMORY_OFFHEAP_ENABLED, 
MEMORY_OFFHEAP_SIZE, SHUFFLE_MANAGER}
-import org.apache.spark.sql.comet.{CometBatchScanExec, 
CometBroadcastExchangeExec, CometExec, CometScanExec, CometScanWrapper, 
CometSinkPlaceHolder}
+import org.apache.spark.sql.comet.{CometBatchScanExec, 
CometBroadcastExchangeExec, CometExec, CometRowToColumnarExec, CometScanExec, 
CometScanWrapper, CometSinkPlaceHolder}
 import org.apache.spark.sql.comet.execution.shuffle.{CometColumnarShuffle, 
CometNativeShuffle, CometShuffleExchangeExec}
 import org.apache.spark.sql.execution.{ColumnarToRowExec, InputAdapter, 
SparkPlan, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
@@ -75,6 +75,7 @@ abstract class CometTestBase
     conf.set(CometConf.COMET_EXEC_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_ALL_OPERATOR_ENABLED.key, "true")
     conf.set(CometConf.COMET_EXEC_ALL_EXPR_ENABLED.key, "true")
+    conf.set(CometConf.COMET_ROW_TO_COLUMNAR_ENABLED.key, "true")
     conf.set(CometConf.COMET_MEMORY_OVERHEAD.key, "2g")
     conf
   }
@@ -155,9 +156,11 @@ abstract class CometTestBase
   }
 
   protected def checkCometOperators(plan: SparkPlan, excludedClasses: 
Class[_]*): Unit = {
-    plan.foreach {
+    val wrapped = wrapCometRowToColumnar(plan)
+    wrapped.foreach {
       case _: CometScanExec | _: CometBatchScanExec => true
       case _: CometSinkPlaceHolder | _: CometScanWrapper => false
+      case _: CometRowToColumnarExec => false
       case _: CometExec | _: CometShuffleExchangeExec => true
       case _: CometBroadcastExchangeExec => true
       case _: WholeStageCodegenExec | _: ColumnarToRowExec | _: InputAdapter 
=> true
@@ -184,6 +187,14 @@ abstract class CometTestBase
     }
   }
 
+  /** Wraps CometRowToColumnarExec in a CometScanWrapper so its children are 
not checked */
+  private def wrapCometRowToColumnar(plan: SparkPlan): SparkPlan = {
+    plan.transformDown {
+      // we don't care about the native operators
+      case p: CometRowToColumnarExec => CometScanWrapper(null, p)
+    }
+  }
+
   /**
    * Check the answer of a Comet SQL query with Spark result using absolute 
tolerance.
    */


Reply via email to