This is an automated email from the ASF dual-hosted git repository.
chengchengjin pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
     new 41073d5b17 [GLUTEN-11330][VL] Make PartialProject support array and map with null values (#11331)
41073d5b17 is described below
commit 41073d5b1773dfe389defebaeede9f45a9692269
Author: jiangjiangtian <[email protected]>
AuthorDate: Fri Jan 2 13:05:33 2026 +0800
    [GLUTEN-11330][VL] Make PartialProject support array and map with null values (#11331)
---------
Co-authored-by: jiangtian <[email protected]>
---
.../java/org/apache/gluten/udf/DuplicateArray.java | 85 ++++++++
.../spark/sql/execution/GlutenHiveUDFSuite.scala | 55 ++++-
.../gluten/vectorized/ArrowColumnarArray.java | 36 +++
.../apache/gluten/vectorized/ArrowColumnarMap.java | 55 +++++
.../vectorized/ArrowWritableColumnVector.java | 30 ++-
.../gluten/vectorized/ArrowColumnarRow.scala | 10 +-
.../execution/vectorized/ColumnarArrayShim.java | 234 ++++++++++++++++++++
.../execution/vectorized/ColumnarArrayShim.java | 234 ++++++++++++++++++++
.../execution/vectorized/ColumnarArrayShim.java | 234 ++++++++++++++++++++
.../execution/vectorized/ColumnarArrayShim.java | 234 ++++++++++++++++++++
.../execution/vectorized/ColumnarArrayShim.java | 241 +++++++++++++++++++++
11 files changed, 1434 insertions(+), 14 deletions(-)
diff --git a/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java b/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java
new file mode 100644
index 0000000000..dc7905859a
--- /dev/null
+++ b/backends-velox/src/test/java/org/apache/gluten/udf/DuplicateArray.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.udf;
+
+import org.apache.hadoop.hive.ql.exec.Description;
+import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
+import org.apache.hadoop.hive.ql.metadata.HiveException;
+import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
+import org.apache.hadoop.hive.serde2.objectinspector.ListObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/** UDF that duplicates every element of an array. */
+@Description(
+ name = "array_duplicate",
+ value =
+        "_FUNC_(array(obj1, obj2,...)) - "
+            + "The function returns an array of the same element type "
+            + "in which every element of the input array is duplicated.",
+ extended =
+ "Example:\n"
+ + " > SELECT _FUNC_(array('b', 'd')) FROM src LIMIT 1;\n"
+ + " ['b', 'b', 'd', 'd']")
+public class DuplicateArray extends GenericUDF {
+
+ ListObjectInspector arrayOI;
+
+ public DuplicateArray() {}
+
+ @Override
+  public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
+    if (arguments.length != 1) {
+      throw new UDFArgumentException("Argument size of array_duplicate must be 1.");
+    }
+
+ arrayOI = (ListObjectInspector) arguments[0];
+ return ObjectInspectorFactory.getStandardListObjectInspector(
+ arrayOI.getListElementObjectInspector());
+ }
+
+ @Override
+ public Object evaluate(DeferredObject[] arguments) throws HiveException {
+
+ Object array = arguments[0].get();
+
+    // A negative length means the input array itself is null.
+    int length = arrayOI.getListLength(array);
+    if (length < 0) {
+      return null;
+    } else if (length == 0) {
+      // If the array is empty, return an empty list.
+      return Collections.emptyList();
+    }
+
+    List<?> input = arrayOI.getList(array);
+    List<Object> result = new ArrayList<>(input.size() * 2);
+    input.forEach(
+        element -> {
+          result.add(element);
+          result.add(element);
+        });
+ return result;
+ }
+
+ @Override
+ public String getDisplayString(String[] children) {
+ return "array_duplicate";
+ }
+}
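For context, the UDF can be exercised directly against Hive object inspectors, without a
Spark session. A minimal, illustrative sketch (the driver class below is hypothetical and
not part of this commit):

// Hypothetical driver: exercises DuplicateArray directly; nulls are duplicated too.
import org.apache.gluten.udf.DuplicateArray;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;

import java.util.Arrays;

public class DuplicateArraySmokeTest {
  public static void main(String[] args) throws Exception {
    DuplicateArray udf = new DuplicateArray();
    ObjectInspector listOI =
        ObjectInspectorFactory.getStandardListObjectInspector(
            PrimitiveObjectInspectorFactory.javaStringObjectInspector);
    udf.initialize(new ObjectInspector[] {listOI});
    Object result =
        udf.evaluate(
            new GenericUDF.DeferredObject[] {
              new GenericUDF.DeferredJavaObject(Arrays.asList("b", null, "d"))
            });
    System.out.println(result); // prints [b, b, null, null, d, d]
  }
}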
diff --git a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
index c5c6981a08..6919a30390 100644
--- a/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
+++ b/backends-velox/src/test/scala/org/apache/spark/sql/execution/GlutenHiveUDFSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
 import org.apache.gluten.config.GlutenConfig
 import org.apache.gluten.execution.{ColumnarPartialGenerateExec, ColumnarPartialProjectExec, GlutenQueryComparisonTest}
 import org.apache.gluten.expression.UDFMappings
-import org.apache.gluten.udf.CustomerUDF
+import org.apache.gluten.udf.{CustomerUDF, DuplicateArray}
 import org.apache.gluten.udtf.{ConditionalOutputUDTF, CustomerUDTF, NoInputUDTF, SimpleUDTF}
 import org.apache.spark.SparkConf
@@ -326,4 +326,57 @@ class GlutenHiveUDFSuite extends GlutenQueryComparisonTest with SQLTestUtils {
}
}
}
+
+ test("udf with map with null values") {
+ withTempFunction("udf_map_values") {
+ sql("""
+ |CREATE TEMPORARY FUNCTION udf_map_values AS
+ |'org.apache.hadoop.hive.ql.udf.generic.GenericUDFMapValues';
+ |""".stripMargin)
+
+ runQueryAndCompare("""
+ |SELECT
+ | l_partkey,
+ | udf_map_values(map_data)
+ |FROM (
+ | SELECT l_partkey,
+ | map(
+ | concat('hello', l_orderkey % 2),
+              |                 CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END,
+              |                 concat('world', l_orderkey % 2),
+              |                 CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END
+ | ) as map_data
+ | FROM lineitem
+ |)
+ |""".stripMargin) {
+ checkOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+ }
+
+ test("udf with array with null values") {
+    withTempFunction("udf_array_duplicate") {
+      sql(s"""
+              |CREATE TEMPORARY FUNCTION udf_array_duplicate AS '${classOf[DuplicateArray].getName}'
+              |""".stripMargin)
+
+      runQueryAndCompare("""
+              |SELECT
+              |  l_partkey,
+              |  udf_array_duplicate(array_data)
+              |FROM (
+              |  SELECT l_partkey,
+              |         array(
+              |           l_orderkey % 2,
+              |           CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END,
+              |           l_orderkey % 2,
+              |           CASE WHEN l_orderkey % 2 == 0 THEN l_orderkey ELSE null END
+              |         ) as array_data
+              |  FROM lineitem
+              |)
+ |""".stripMargin) {
+ checkOperatorMatch[ColumnarPartialProjectExec]
+ }
+ }
+ }
}
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java
new file mode 100644
index 0000000000..3ea0444ee0
--- /dev/null
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarArray.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.spark.sql.execution.vectorized.ColumnarArrayShim;
+import org.apache.spark.sql.vectorized.ColumnVector;
+
+/**
+ * The `get` method in Spark's `ColumnarArray` does not check whether the value being read is
+ * null, and Arrow vectors throw an exception when a null value is accessed, so we define this
+ * class as a workaround. Its implementation is copied from Spark 4.0, except that the
+ * `handleNull` parameter is set to true when we call `SpecializedGettersReader.read` in `get`,
+ * which means that before reading a value from the array we first check whether it is null.
+ *
+ * <p>The actual implementation lives in {@link ColumnarArrayShim} because the Variant data type
+ * was introduced in Spark 4.0.
+ */
+public class ArrowColumnarArray extends ColumnarArrayShim {
+ public ArrowColumnarArray(ColumnVector data, int offset, int length) {
+ super(data, offset, length);
+ }
+}
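A minimal, illustrative sketch of the behavior this wrapper changes; OnHeapColumnVector
stands in for an Arrow-backed element vector here, and the whole setup is an assumption
for demonstration, not part of this commit:

// Hypothetical sketch: a null slot read through ArrowColumnarArray.get returns null
// instead of reaching into the underlying vector.
import org.apache.gluten.vectorized.ArrowColumnarArray;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;

public class NullSafeGetSketch {
  public static void main(String[] args) {
    OnHeapColumnVector elements = new OnHeapColumnVector(3, DataTypes.LongType);
    elements.putLong(0, 1L);
    elements.putNull(1);
    elements.putLong(2, 3L);

    ArrowColumnarArray array = new ArrowColumnarArray(elements, 0, 3);
    // Spark's ColumnarArray.get would dispatch straight to getLong(1) here; with
    // handleNull = true the null bit is consulted first and null is returned.
    System.out.println(array.get(1, DataTypes.LongType)); // prints null
    elements.close();
  }
}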
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java
new file mode 100644
index 0000000000..b6bfacb835
--- /dev/null
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowColumnarMap.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.gluten.vectorized;
+
+import org.apache.spark.sql.catalyst.util.ArrayBasedMapData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.MapData;
+import org.apache.spark.sql.vectorized.ColumnVector;
+
+/** See {@link ArrowColumnarArray}. */
+public class ArrowColumnarMap extends MapData {
+ private final ArrowColumnarArray keys;
+ private final ArrowColumnarArray values;
+ private final int length;
+
+  public ArrowColumnarMap(ColumnVector keys, ColumnVector values, int offset, int length) {
+ this.length = length;
+ this.keys = new ArrowColumnarArray(keys, offset, length);
+ this.values = new ArrowColumnarArray(values, offset, length);
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ @Override
+ public ArrayData keyArray() {
+ return keys;
+ }
+
+ @Override
+ public ArrayData valueArray() {
+ return values;
+ }
+
+ @Override
+ public MapData copy() {
+ return new ArrayBasedMapData(keys.copy(), values.copy());
+ }
+}
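The same concern applies to maps, whose value vector is typically the one carrying nulls.
An illustrative sketch under the same assumptions as the ArrowColumnarArray example above:

// Hypothetical sketch: reading a null map value through ArrowColumnarMap.valueArray().
import org.apache.gluten.vectorized.ArrowColumnarMap;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;

public class MapNullValuesSketch {
  public static void main(String[] args) {
    OnHeapColumnVector keys = new OnHeapColumnVector(2, DataTypes.LongType);
    OnHeapColumnVector values = new OnHeapColumnVector(2, DataTypes.LongType);
    keys.putLong(0, 0L);
    keys.putLong(1, 1L);
    values.putLong(0, 42L);
    values.putNull(1); // the shape of data the new tests cover

    ArrowColumnarMap map = new ArrowColumnarMap(keys, values, 0, 2);
    // valueArray() is an ArrowColumnarArray, so the null bit is checked before the read.
    System.out.println(map.valueArray().get(1, DataTypes.LongType)); // prints null
    keys.close();
    values.close();
  }
}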
diff --git a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
index 7a4585dce7..1c28e6d577 100644
--- a/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
+++ b/gluten-arrow/src/main/java/org/apache/gluten/vectorized/ArrowWritableColumnVector.java
@@ -38,8 +38,6 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVectorShim;
import org.apache.spark.sql.types.*;
import org.apache.spark.sql.utils.SparkArrowUtil;
import org.apache.spark.sql.utils.SparkSchemaUtil;
-import org.apache.spark.sql.vectorized.ColumnarArray;
-import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.types.UTF8String;
import org.slf4j.Logger;
@@ -411,6 +409,22 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
return "vectorCounter is " + vectorCount.get();
}
+  // The `get` method in Spark's `ColumnarMap` doesn't check null values, and Arrow vectors
+  // throw an exception when we try to access a null value, so here we return
+  // `ArrowColumnarMap` as a workaround.
+  public MapData getMapInternal(int rowId) {
+    return accessor.getMap(rowId);
+  }
+
+  // The `get` method in Spark's `ColumnarArray` doesn't check null values, and Arrow vectors
+  // throw an exception when we try to access a null value, so here we return
+  // `ArrowColumnarArray` as a workaround.
+  public ArrayData getArrayInternal(int rowId) {
+    return accessor.getArray(rowId);
+  }
+
+  // The `get` method in Spark's `ColumnarRow` doesn't check whether the value being read is
+  // null, so we return `ArrowColumnarRow` as a workaround.
public ArrowColumnarRow getStructInternal(int rowId) {
if (isNullAt(rowId)) return null;
ArrowWritableColumnVector[] writableColumns =
@@ -893,7 +907,7 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
throw new UnsupportedOperationException();
}
- ColumnarArray getArray(int rowId) {
+ ArrayData getArray(int rowId) {
throw new UnsupportedOperationException();
}
@@ -905,7 +919,7 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
throw new UnsupportedOperationException();
}
- ColumnarMap getMap(int rowId) {
+ MapData getMap(int rowId) {
throw new UnsupportedOperationException();
}
}
@@ -1239,8 +1253,8 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
}
@Override
-  final ColumnarArray getArray(int rowId) {
-    return new ColumnarArray(elements, getArrayOffset(rowId), getArrayLength(rowId));
+  final ArrayData getArray(int rowId) {
+    return new ArrowColumnarArray(elements, getArrayOffset(rowId), getArrayLength(rowId));
}
@Override
@@ -1264,11 +1278,11 @@ public final class ArrowWritableColumnVector extends WritableColumnVectorShim {
}
@Override
- final ColumnarMap getMap(int rowId) {
+ final MapData getMap(int rowId) {
int index = rowId * MapVector.OFFSET_WIDTH;
int offset = accessor.getOffsetBuffer().getInt(index);
int length = accessor.getInnerValueCountAt(rowId);
- return new ColumnarMap(keys, values, offset, length);
+ return new ArrowColumnarMap(keys, values, offset, length);
}
@Override
diff --git a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
index 2ee3d081c3..121ca01639 100644
--- a/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
+++ b/gluten-arrow/src/main/scala/org/apache/gluten/vectorized/ArrowColumnarRow.scala
@@ -21,8 +21,8 @@ import org.apache.gluten.execution.InternalRowSparkCompatible
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.sql.vectorized.{ColumnarArray, ColumnarMap}
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import java.math.BigDecimal
@@ -111,11 +111,11 @@ final class ArrowColumnarRow(writableColumns: Array[ArrowWritableColumnVector],
override def getStruct(ordinal: Int, numFields: Int): ArrowColumnarRow =
columns(ordinal).getStructInternal(rowId)
- override def getArray(ordinal: Int): ColumnarArray =
- columns(ordinal).getArray(rowId)
+ override def getArray(ordinal: Int): ArrayData =
+ columns(ordinal).getArrayInternal(rowId)
- override def getMap(ordinal: Int): ColumnarMap =
- columns(ordinal).getMap(rowId)
+ override def getMap(ordinal: Int): MapData =
+ columns(ordinal).getMapInternal(rowId)
override def get(ordinal: Int, dataType: DataType): AnyRef = {
if (isNullAt(ordinal)) {
diff --git a/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 0000000000..21594a155a
--- /dev/null
+++ b/shims/spark32/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+      return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+}
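The only functional difference from Spark's ColumnarArray is the fourth argument passed to
SpecializedGettersReader.read in get above. A condensed, illustrative sketch of that
contract (a simplification, not Spark's actual source):

// Hypothetical sketch of the handleNull contract relied on by ColumnarArrayShim.get.
import org.apache.spark.sql.catalyst.expressions.SpecializedGetters;
import org.apache.spark.sql.types.DataType;

final class HandleNullSketch {
  static Object read(SpecializedGetters obj, int ordinal, DataType dataType, boolean handleNull) {
    if (handleNull && obj.isNullAt(ordinal)) {
      // ColumnarArrayShim passes true, so null slots short-circuit here and the
      // underlying Arrow vector is never asked for a value it cannot produce.
      return null;
    }
    // ... otherwise dispatch on dataType: getBoolean, getInt, getUTF8String, ...
    throw new UnsupportedOperationException("type dispatch elided in this sketch");
  }
}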
diff --git a/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 0000000000..21594a155a
--- /dev/null
+++ b/shims/spark33/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+      return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 0000000000..21594a155a
--- /dev/null
+++ b/shims/spark34/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+      return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 0000000000..21594a155a
--- /dev/null
+++ b/shims/spark35/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+      return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+}
diff --git a/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java b/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
new file mode 100644
index 0000000000..25adf5d233
--- /dev/null
+++ b/shims/spark40/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArrayShim.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.execution.vectorized;
+
+import org.apache.spark.SparkUnsupportedOperationException;
+import org.apache.spark.sql.catalyst.expressions.SpecializedGettersReader;
+import org.apache.spark.sql.catalyst.expressions.UnsafeArrayData;
+import org.apache.spark.sql.catalyst.util.ArrayData;
+import org.apache.spark.sql.catalyst.util.GenericArrayData;
+import org.apache.spark.sql.types.*;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
+import org.apache.spark.unsafe.types.CalendarInterval;
+import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.unsafe.types.VariantVal;
+
+public class ColumnarArrayShim extends ArrayData {
+ // The data for this array. This array contains elements from
+ // data[offset] to data[offset + length).
+ private final ColumnVector data;
+ private final int offset;
+ private final int length;
+
+ public ColumnarArrayShim(ColumnVector data, int offset, int length) {
+ this.data = data;
+ this.offset = offset;
+ this.length = length;
+ }
+
+ @Override
+ public int numElements() {
+ return length;
+ }
+
+ /**
+ * Sets all the appropriate null bits in the input UnsafeArrayData.
+ *
+ * @param arrayData The UnsafeArrayData to set the null bits for
+ * @return The UnsafeArrayData with the null bits set
+ */
+ private UnsafeArrayData setNullBits(UnsafeArrayData arrayData) {
+ if (data.hasNull()) {
+ for (int i = 0; i < length; i++) {
+ if (data.isNullAt(offset + i)) {
+ arrayData.setNullAt(i);
+ }
+ }
+ }
+ return arrayData;
+ }
+
+ @Override
+ public ArrayData copy() {
+ DataType dt = data.dataType();
+
+ if (dt instanceof BooleanType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toBooleanArray()));
+ } else if (dt instanceof ByteType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toByteArray()));
+ } else if (dt instanceof ShortType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toShortArray()));
+ } else if (dt instanceof IntegerType
+ || dt instanceof DateType
+ || dt instanceof YearMonthIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toIntArray()));
+ } else if (dt instanceof LongType
+ || dt instanceof TimestampType
+ || dt instanceof DayTimeIntervalType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toLongArray()));
+ } else if (dt instanceof FloatType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toFloatArray()));
+ } else if (dt instanceof DoubleType) {
+ return setNullBits(UnsafeArrayData.fromPrimitiveArray(toDoubleArray()));
+ } else {
+      return new GenericArrayData(toObjectArray(dt)).copy(); // ensure the elements are copied.
+ }
+ }
+
+ @Override
+ public boolean[] toBooleanArray() {
+ return data.getBooleans(offset, length);
+ }
+
+ @Override
+ public byte[] toByteArray() {
+ return data.getBytes(offset, length);
+ }
+
+ @Override
+ public short[] toShortArray() {
+ return data.getShorts(offset, length);
+ }
+
+ @Override
+ public int[] toIntArray() {
+ return data.getInts(offset, length);
+ }
+
+ @Override
+ public long[] toLongArray() {
+ return data.getLongs(offset, length);
+ }
+
+ @Override
+ public float[] toFloatArray() {
+ return data.getFloats(offset, length);
+ }
+
+ @Override
+ public double[] toDoubleArray() {
+ return data.getDoubles(offset, length);
+ }
+
+ // TODO: this is extremely expensive.
+ @Override
+ public Object[] array() {
+ DataType dt = data.dataType();
+ Object[] list = new Object[length];
+ try {
+ for (int i = 0; i < length; i++) {
+ if (!data.isNullAt(offset + i)) {
+ list[i] = get(i, dt);
+ }
+ }
+ return list;
+ } catch (Exception e) {
+ throw new RuntimeException("Could not get the array", e);
+ }
+ }
+
+ @Override
+ public boolean isNullAt(int ordinal) {
+ return data.isNullAt(offset + ordinal);
+ }
+
+ @Override
+ public boolean getBoolean(int ordinal) {
+ return data.getBoolean(offset + ordinal);
+ }
+
+ @Override
+ public byte getByte(int ordinal) {
+ return data.getByte(offset + ordinal);
+ }
+
+ @Override
+ public short getShort(int ordinal) {
+ return data.getShort(offset + ordinal);
+ }
+
+ @Override
+ public int getInt(int ordinal) {
+ return data.getInt(offset + ordinal);
+ }
+
+ @Override
+ public long getLong(int ordinal) {
+ return data.getLong(offset + ordinal);
+ }
+
+ @Override
+ public float getFloat(int ordinal) {
+ return data.getFloat(offset + ordinal);
+ }
+
+ @Override
+ public double getDouble(int ordinal) {
+ return data.getDouble(offset + ordinal);
+ }
+
+ @Override
+ public Decimal getDecimal(int ordinal, int precision, int scale) {
+ return data.getDecimal(offset + ordinal, precision, scale);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int ordinal) {
+ return data.getUTF8String(offset + ordinal);
+ }
+
+ @Override
+ public byte[] getBinary(int ordinal) {
+ return data.getBinary(offset + ordinal);
+ }
+
+ @Override
+ public CalendarInterval getInterval(int ordinal) {
+ return data.getInterval(offset + ordinal);
+ }
+
+ @Override
+ public VariantVal getVariant(int ordinal) {
+ return data.getVariant(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarRow getStruct(int ordinal, int numFields) {
+ return data.getStruct(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarArray getArray(int ordinal) {
+ return data.getArray(offset + ordinal);
+ }
+
+ @Override
+ public ColumnarMap getMap(int ordinal) {
+ return data.getMap(offset + ordinal);
+ }
+
+ @Override
+ public Object get(int ordinal, DataType dataType) {
+ return SpecializedGettersReader.read(this, ordinal, dataType, true, false);
+ }
+
+ @Override
+ public void update(int ordinal, Object value) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+
+ @Override
+ public void setNullAt(int ordinal) {
+ throw SparkUnsupportedOperationException.apply();
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]