This is an automated email from the ASF dual-hosted git repository.
wangguangxin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new cc55a28787 [VL] Make PartialProject support struct with null fields
(#10706)
cc55a28787 is described below
commit cc55a28787bfb033d925c4fe6d4447131e833967
Author: jiangjiangtian <[email protected]>
AuthorDate: Tue Oct 21 16:15:41 2025 +0800
[VL] Make PartialProject support struct with null fields (#10706)
* Make PartialProject support struct with null fields
* fix compilation error
* fix
---------
Co-authored-by: 蒋添 <[email protected]>
---
.../execution/ColumnarPartialGenerateExec.scala | 10 ++-
.../execution/ColumnarPartialProjectExec.scala | 6 +-
.../gluten/expression/UDFPartialProjectSuite.scala | 25 +++++++
.../gluten/columnarbatch/ColumnarBatches.java | 10 +++
.../vectorized/ArrowWritableColumnVector.java | 10 +++
.../gluten/vectorized/ArrowColumnarBatch.scala | 83 ++++++++++++++++++++++
.../gluten/vectorized/ArrowColumnarRow.scala | 48 +++++++------
7 files changed, 166 insertions(+), 26 deletions(-)
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
index ea1be35995..4e447df064 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialGenerateExec.scala
@@ -23,7 +23,7 @@ import
org.apache.gluten.extension.columnar.transition.Convention
import org.apache.gluten.iterator.Iterators
import org.apache.gluten.memory.arrow.alloc.ArrowBufferAllocators
import org.apache.gluten.sql.shims.SparkShimLoader
-import org.apache.gluten.vectorized.{ArrowColumnarRow,
ArrowWritableColumnVector}
+import org.apache.gluten.vectorized.{ArrowColumnarBatch, ArrowColumnarRow,
ArrowWritableColumnVector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -179,12 +179,16 @@ case class ColumnarPartialGenerateExec(generateExec:
GenerateExec, child: SparkP
}
}
- private def loadArrowBatch(inputData: ColumnarBatch): ColumnarBatch = {
- if (inputData.numCols() == 0) {
+ private def loadArrowBatch(inputData: ColumnarBatch): ArrowColumnarBatch = {
+ val sparkColumnarBatch = if (inputData.numCols() == 0) {
inputData
} else {
ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), inputData)
}
+    // In Spark versions below 4.0, the `ColumnarRow`'s get method
doesn't check whether the
+ // column to get is null, so we change it to `ArrowColumnarBatch`
manually. `ArrowColumnarBatch`
+ // returns `ArrowColumnarRow`, which fixes the bug.
+ ColumnarBatches.convertToArrowColumnarBatch(sparkColumnarBatch)
}
private def isVariableWidthType(dt: DataType): Boolean = dt match {
diff --git
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
index 232c535fda..8fa33d97a1 100644
---
a/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
+++
b/backends-velox/src/main/scala/org/apache/gluten/execution/ColumnarPartialProjectExec.scala
@@ -214,11 +214,15 @@ case class ColumnarPartialProjectExec(projectList:
Seq[NamedExpression], child:
val proj = ArrowProjection.create(replacedAlias, projectAttributes.toSeq)
val numRows = childData.numRows()
val start = System.currentTimeMillis()
- val arrowBatch = if (childData.numCols() == 0) {
+ val sparkColumnarBatch = if (childData.numCols() == 0) {
childData
} else {
ColumnarBatches.load(ArrowBufferAllocators.contextInstance(), childData)
}
+    // In Spark versions below 4.0, the `ColumnarRow`'s get method
doesn't check whether the
+ // column to get is null, so we change it to `ArrowColumnarBatch`
manually. `ArrowColumnarBatch`
+ // returns `ArrowColumnarRow`, which fixes the bug.
+ val arrowBatch =
ColumnarBatches.convertToArrowColumnarBatch(sparkColumnarBatch)
c2a += System.currentTimeMillis() - start
val schema =
diff --git
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
index 5152cbc457..ab6d111214 100644
---
a/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
+++
b/backends-velox/src/test/scala/org/apache/gluten/expression/UDFPartialProjectSuite.scala
@@ -28,6 +28,8 @@ import java.io.File
case class MyStruct(a: Long, b: Array[Long])
+case class MyStructWithNullValue(a: Option[Long], b: Array[Long])
+
class UDFPartialProjectSuiteRasOff extends UDFPartialProjectSuite {
override protected def sparkConf: SparkConf = {
super.sparkConf
@@ -247,4 +249,27 @@ abstract class UDFPartialProjectSuite extends
WholeStageTransformerSuite {
}
}
}
+
+ test("test struct data with null fields") {
+ spark.udf.register(
+ "struct_plus_one",
+ udf(
+ (m: MyStructWithNullValue) =>
+ MyStructWithNullValue(if (m.a.isEmpty) None else Some(m.a.get + 1),
m.b.map(_ + 1))))
+ runQueryAndCompare("""
+ |SELECT
+ | l_partkey,
+ | struct_plus_one(struct_data)
+ |FROM (
+ | SELECT l_partkey,
+ | struct(
+ | CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey
ELSE null END as a,
+ | array(l_orderkey % 2, l_orderkey % 2 + 1,
l_orderkey % 2 + 2) as b
+ | ) as struct_data
+ | FROM lineitem
+ |)
+ |""".stripMargin) {
+ checkGlutenOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
}
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
index 156de4e0d8..01ceb7d20c 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/columnarbatch/ColumnarBatches.java
@@ -22,6 +22,7 @@ import org.apache.gluten.runtime.Runtimes;
import org.apache.gluten.utils.ArrowAbiUtil;
import org.apache.gluten.utils.ArrowUtil;
import org.apache.gluten.utils.InternalRowUtl;
+import org.apache.gluten.vectorized.ArrowColumnarBatch;
import org.apache.gluten.vectorized.ArrowWritableColumnVector;
import com.google.common.annotations.VisibleForTesting;
@@ -171,6 +172,15 @@ public final class ColumnarBatches {
}
}
+ public static ArrowColumnarBatch convertToArrowColumnarBatch(ColumnarBatch
sparkColumnarBatch) {
+ int numCols = sparkColumnarBatch.numCols();
+ ArrowWritableColumnVector[] writableColumns = new
ArrowWritableColumnVector[numCols];
+ for (int i = 0; i < numCols; i++) {
+ writableColumns[i] = (ArrowWritableColumnVector)
sparkColumnarBatch.column(i);
+ }
+ return new ArrowColumnarBatch(writableColumns,
sparkColumnarBatch.numRows());
+ }
+
public static ColumnarBatch load(BufferAllocator allocator, ColumnarBatch
input) {
if (isZeroColumnBatch(input)) {
return input;
diff --git
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
index 5491b19ca3..d00786f3f4 100644
---
a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
+++
b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
@@ -411,6 +411,16 @@ public final class ArrowWritableColumnVector extends
WritableColumnVectorShim {
return "vectorCounter is " + vectorCount.get();
}
+ public ArrowColumnarRow getStructInternal(int rowId) {
+ if (isNullAt(rowId)) return null;
+ ArrowWritableColumnVector[] writableColumns =
+ new ArrowWritableColumnVector[childColumns.length];
+ for (int i = 0; i < writableColumns.length; i++) {
+ writableColumns[i] = (ArrowWritableColumnVector) childColumns[i];
+ }
+ return new ArrowColumnarRow(writableColumns, rowId);
+ }
+
@Override
public boolean hasNull() {
return accessor.getNullCount() > 0;
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala
new file mode 100644
index 0000000000..decd87e78f
--- /dev/null
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarBatch.scala
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.vectorized.ColumnVector
+
+/**
+ * Because Spark-3.2 declares ColumnarBatch as final, so `ArrowColumnarBatch`
can't extend
+ * `ColumnarBatch`. The code is mainly copied from Spark-3.2
+ *
+ * @param writableColumns
+ * the columns this class wraps
+ * @param rowNumbers
+ * the number of rows this batch contains
+ */
+class ArrowColumnarBatch(writableColumns: Array[ArrowWritableColumnVector],
var rowNumbers: Int) {
+ private val arrowColumnarRow = new ArrowColumnarRow(writableColumns)
+
+ /**
+ * Called to close all the columns in this batch. It is not valid to access
the data after calling
+ * this. This must be called at the end to clean up memory allocations.
+ */
+ def close(): Unit = {
+ for (c <- writableColumns) {
+ c.close()
+ }
+ }
+
+ /** Returns an iterator over the rows in this batch. */
+ def rowIterator: Iterator[InternalRow] = {
+ val maxRows = numRows
+ val row = new ArrowColumnarRow(writableColumns)
+ new Iterator[InternalRow]() {
+ var rowId = 0
+
+ override def hasNext: Boolean = rowId < maxRows
+
+ override def next: InternalRow = {
+ if (rowId >= maxRows) {
+ throw new NoSuchElementException()
+ }
+ row.rowId = rowId
+ rowId = rowId + 1
+ row
+ }
+ }
+ }
+
+ /** Sets the number of rows in this batch. */
+ def setNumRows(numRows: Int): Unit = {
+ this.rowNumbers = numRows
+ }
+
+ /** Returns the number of columns that make up this batch. */
+ def numCols: Int = writableColumns.length
+
+ /** Returns the number of rows for read, including filtered rows. */
+ def numRows: Int = this.rowNumbers
+
+ /** Returns the column at `ordinal`. */
+ def column(ordinal: Int): ColumnVector = writableColumns(ordinal)
+
+ def getRow(rowId: Int): InternalRow = {
+ assert(rowId >= 0 && rowId < this.numRows)
+ arrowColumnarRow.rowId = rowId
+ arrowColumnarRow
+ }
+}
diff --git
a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
index e5452e4ae5..f0e2c4dabf 100644
---
a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
+++
b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
@@ -22,7 +22,7 @@ import
org.apache.gluten.execution.InternalRowGetVariantCompatible
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap,
ColumnarRow}
+import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import java.math.BigDecimal
@@ -30,10 +30,9 @@ import java.math.BigDecimal
// Copy from Spark MutableColumnarRow mostly but class member columns' type is
// ArrowWritableColumnVector. And support string and binary type to write,
// Arrow writer does not need to setNotNull before writing a value.
-final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector])
+final class ArrowColumnarRow(writableColumns:
Array[ArrowWritableColumnVector], var rowId: Int = 0)
extends InternalRowGetVariantCompatible {
- var rowId: Int = 0
private val columns: Array[ArrowWritableColumnVector] = writableColumns
override def numFields(): Int = columns.length
@@ -109,8 +108,8 @@ final class ArrowColumnarRow(writableColumns:
Array[ArrowWritableColumnVector])
override def getInterval(ordinal: Int): CalendarInterval =
columns(ordinal).getInterval(rowId)
- override def getStruct(ordinal: Int, numFields: Int): ColumnarRow =
- columns(ordinal).getStruct(rowId)
+ override def getStruct(ordinal: Int, numFields: Int): ArrowColumnarRow =
+ columns(ordinal).getStructInternal(rowId)
override def getArray(ordinal: Int): ColumnarArray =
columns(ordinal).getArray(rowId)
@@ -118,23 +117,28 @@ final class ArrowColumnarRow(writableColumns:
Array[ArrowWritableColumnVector])
override def getMap(ordinal: Int): ColumnarMap =
columns(ordinal).getMap(rowId)
- override def get(ordinal: Int, dataType: DataType): AnyRef = dataType match {
- case _: BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal))
- case _: ByteType => java.lang.Byte.valueOf(getByte(ordinal))
- case _: ShortType => java.lang.Short.valueOf(getShort(ordinal))
- case _: IntegerType => java.lang.Integer.valueOf(getInt(ordinal))
- case _: LongType => java.lang.Long.valueOf(getLong(ordinal))
- case _: FloatType => java.lang.Float.valueOf(getFloat(ordinal))
- case _: DoubleType => java.lang.Double.valueOf(getDouble(ordinal))
- case _: StringType => getUTF8String(ordinal)
- case _: BinaryType => getBinary(ordinal)
- case t: DecimalType => getDecimal(ordinal, t.precision, t.scale)
- case _: DateType => java.lang.Integer.valueOf(getInt(ordinal))
- case _: TimestampType => java.lang.Long.valueOf(getLong(ordinal))
- case _: ArrayType => getArray(ordinal)
- case s: StructType => getStruct(ordinal, s.fields.length)
- case _: MapType => getMap(ordinal)
- case _ => throw new UnsupportedOperationException(s"Datatype not supported
$dataType")
+ override def get(ordinal: Int, dataType: DataType): AnyRef = {
+ if (isNullAt(ordinal)) {
+ return null
+ }
+ dataType match {
+ case _: BooleanType => java.lang.Boolean.valueOf(getBoolean(ordinal))
+ case _: ByteType => java.lang.Byte.valueOf(getByte(ordinal))
+ case _: ShortType => java.lang.Short.valueOf(getShort(ordinal))
+ case _: IntegerType => java.lang.Integer.valueOf(getInt(ordinal))
+ case _: LongType => java.lang.Long.valueOf(getLong(ordinal))
+ case _: FloatType => java.lang.Float.valueOf(getFloat(ordinal))
+ case _: DoubleType => java.lang.Double.valueOf(getDouble(ordinal))
+ case _: StringType => getUTF8String(ordinal)
+ case _: BinaryType => getBinary(ordinal)
+ case t: DecimalType => getDecimal(ordinal, t.precision, t.scale)
+ case _: DateType => java.lang.Integer.valueOf(getInt(ordinal))
+ case _: TimestampType => java.lang.Long.valueOf(getLong(ordinal))
+ case _: ArrayType => getArray(ordinal)
+ case s: StructType => getStruct(ordinal, s.fields.length)
+ case _: MapType => getMap(ordinal)
+ case _ => throw new UnsupportedOperationException(s"Datatype not
supported $dataType")
+ }
}
override def update(ordinal: Int, value: Any): Unit = {
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]