This is an automated email from the ASF dual-hosted git repository.

lgbo pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git


The following commit(s) were added to refs/heads/main by this push:
     new eff142c5e8 [GLUTEN-9752][Flink] Support array/map in row/vector conversion (#9757)
eff142c5e8 is described below

commit eff142c5e8bacac476f43669dc27a4ea565467ba
Author: lgbo <[email protected]>
AuthorDate: Fri May 30 08:32:49 2025 +0800

    [GLUTEN-9752][Flink] Support array/map in row/vector conversion (#9757)
    
    Co-authored-by: PHILO-HE <[email protected]>
---
 .../apache/gluten/util/LogicalTypeConverter.java   |  99 +++--
 .../gluten/vectorized/ArrowVectorAccessor.java     |  96 ++++-
 .../gluten/vectorized/ArrowVectorWriter.java       | 440 ++++++++++++++++-----
 .../table/runtime/stream/custom/ScanTest.java      |  61 +++
 4 files changed, 547 insertions(+), 149 deletions(-)

diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
index b39138ceab..30c3299990 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/util/LogicalTypeConverter.java
@@ -16,9 +16,9 @@
  */
 package org.apache.gluten.util;
 
-import io.github.zhztheplayer.velox4j.type.IntegerType;
 import io.github.zhztheplayer.velox4j.type.Type;
 
+import org.apache.flink.table.types.logical.ArrayType;
 import org.apache.flink.table.types.logical.BigIntType;
 import org.apache.flink.table.types.logical.BooleanType;
 import org.apache.flink.table.types.logical.DayTimeIntervalType;
@@ -26,47 +26,84 @@ import org.apache.flink.table.types.logical.DecimalType;
 import org.apache.flink.table.types.logical.DoubleType;
 import org.apache.flink.table.types.logical.IntType;
 import org.apache.flink.table.types.logical.LogicalType;
+import org.apache.flink.table.types.logical.MapType;
 import org.apache.flink.table.types.logical.RowType;
 import org.apache.flink.table.types.logical.TimestampType;
 import org.apache.flink.table.types.logical.VarCharType;
 
 import java.util.List;
+import java.util.Map;
 import java.util.stream.Collectors;
 
-/** Convertor to convert Flink LogicalType to velox data Type */
+// Convertor to convert Flink LogicalType to velox data Type
 public class LogicalTypeConverter {
+  private interface VLTypeConverter {
+    Type build(LogicalType logicalType);
+  }
+
+  // Exact class matches
+  private static Map<Class<?>, VLTypeConverter> converters =
+      Map.ofEntries(
+          Map.entry(
+              BooleanType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.BooleanType()),
+          Map.entry(
+              IntType.class, logicalType -> new io.github.zhztheplayer.velox4j.type.IntegerType()),
+          Map.entry(
+              BigIntType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.BigIntType()),
+          Map.entry(
+              DoubleType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.DoubleType()),
+          Map.entry(
+              VarCharType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.VarCharType()),
+          // TODO: may need precision
+          Map.entry(
+              TimestampType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.TimestampType()),
+          Map.entry(
+              DecimalType.class,
+              logicalType -> {
+                DecimalType decimalType = (DecimalType) logicalType;
+                return new io.github.zhztheplayer.velox4j.type.DecimalType(
+                    decimalType.getPrecision(), decimalType.getScale());
+              }),
+          Map.entry(
+              DayTimeIntervalType.class,
+              logicalType -> new io.github.zhztheplayer.velox4j.type.BigIntType()),
+          Map.entry(
+              RowType.class,
+              logicalType -> {
+                RowType flinkRowType = (RowType) logicalType;
+                List<Type> fieldTypes =
+                    flinkRowType.getChildren().stream()
+                        .map(LogicalTypeConverter::toVLType)
+                        .collect(Collectors.toList());
+                return new io.github.zhztheplayer.velox4j.type.RowType(
+                    flinkRowType.getFieldNames(), fieldTypes);
+              }),
+          Map.entry(
+              ArrayType.class,
+              logicalType -> {
+                ArrayType arrayType = (ArrayType) logicalType;
+                Type elementType = toVLType(arrayType.getElementType());
+                return io.github.zhztheplayer.velox4j.type.ArrayType.create(elementType);
+              }),
+          Map.entry(
+              MapType.class,
+              logicalType -> {
+                MapType mapType = (MapType) logicalType;
+                Type keyType = toVLType(mapType.getKeyType());
+                Type valueType = toVLType(mapType.getValueType());
+                return io.github.zhztheplayer.velox4j.type.MapType.create(keyType, valueType);
+              }));
 
   public static Type toVLType(LogicalType logicalType) {
-    if (logicalType instanceof RowType) {
-      RowType flinkRowType = (RowType) logicalType;
-      List<Type> fieldTypes =
-          flinkRowType.getChildren().stream()
-              .map(LogicalTypeConverter::toVLType)
-              .collect(Collectors.toList());
-      return new io.github.zhztheplayer.velox4j.type.RowType(
-          flinkRowType.getFieldNames(), fieldTypes);
-    } else if (logicalType instanceof BooleanType) {
-      return new io.github.zhztheplayer.velox4j.type.BooleanType();
-    } else if (logicalType instanceof IntType) {
-      return new IntegerType();
-    } else if (logicalType instanceof BigIntType) {
-      return new io.github.zhztheplayer.velox4j.type.BigIntType();
-    } else if (logicalType instanceof DoubleType) {
-      return new io.github.zhztheplayer.velox4j.type.DoubleType();
-    } else if (logicalType instanceof VarCharType) {
-      return new io.github.zhztheplayer.velox4j.type.VarCharType();
-    } else if (logicalType instanceof TimestampType) {
-      // TODO: may need precision
-      return new io.github.zhztheplayer.velox4j.type.TimestampType();
-    } else if (logicalType instanceof DecimalType) {
-      DecimalType decimalType = (DecimalType) logicalType;
-      return new io.github.zhztheplayer.velox4j.type.DecimalType(
-          decimalType.getPrecision(), decimalType.getScale());
-    } else if (logicalType instanceof DayTimeIntervalType) {
-      // TODO: it seems interval now can be used as bigint for nexmark.
-      return new io.github.zhztheplayer.velox4j.type.BigIntType();
-    } else {
+    VLTypeConverter converter = converters.get(logicalType.getClass());
+    if (converter == null) {
       throw new RuntimeException("Unsupported logical type: " + logicalType);
     }
+    return converter.build(logicalType);
   }
 }
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
index 5138172a7d..5c4975b302 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorAccessor.java
@@ -18,6 +18,8 @@ package org.apache.gluten.vectorized;
 
 import io.github.zhztheplayer.velox4j.type.*;
 
+import org.apache.flink.table.data.GenericArrayData;
+import org.apache.flink.table.data.GenericMapData;
 import org.apache.flink.table.data.GenericRowData;
 import org.apache.flink.table.data.binary.BinaryStringData;
 
@@ -27,33 +29,47 @@ import org.apache.arrow.vector.FieldVector;
 import org.apache.arrow.vector.Float8Vector;
 import org.apache.arrow.vector.IntVector;
 import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.complex.StructVector;
 
 import java.util.ArrayList;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
 
 /*
  * This module is used to convert column vector to flink generic rows.
  * BinaryRowData is not supported here.
  */
 public abstract class ArrowVectorAccessor {
+  private interface AccessorBuilder {
+    ArrowVectorAccessor build(FieldVector vector);
+  };
+
+  // Exact class matches
+  private static final Map<Class<? extends FieldVector>, AccessorBuilder> accessorBuilders =
+      Map.ofEntries(
+          Map.entry(BitVector.class, vector -> new BooleanVectorAccessor(vector)),
+          Map.entry(IntVector.class, vector -> new IntVectorAccessor(vector)),
+          Map.entry(BigIntVector.class, vector -> new BigIntVectorAccessor(vector)),
+          Map.entry(Float8Vector.class, vector -> new DoubleVectorAccessor(vector)),
+          Map.entry(VarCharVector.class, vector -> new VarCharVectorAccessor(vector)),
+          Map.entry(StructVector.class, vector -> new StructVectorAccessor(vector)),
+          Map.entry(ListVector.class, vector -> new ListVectorAccessor(vector)),
+          Map.entry(MapVector.class, vector -> new MapVectorAccessor(vector)));
+
   public static ArrowVectorAccessor create(FieldVector vector) {
-    if (vector instanceof BitVector) {
-      return new BooleanVectorAccessor(vector);
-    } else if (vector instanceof IntVector) {
-      return new IntVectorAccessor(vector);
-    } else if (vector instanceof BigIntVector) {
-      return new BigIntVectorAccessor(vector);
-    } else if (vector instanceof Float8Vector) {
-      return new DoubleVectorAccessor(vector);
-    } else if (vector instanceof VarCharVector) {
-      return new VarCharVectorAccessor(vector);
-    } else if (vector instanceof StructVector) {
-      return new StructVectorAccessor(vector);
-    } else {
+    if (vector == null) {
+      throw new IllegalArgumentException(
+          "ArrowVectorAccessor. Cannot create accessor for null vector.");
+    }
+    AccessorBuilder builder = accessorBuilders.get(vector.getClass());
+    if (builder == null) {
       throw new UnsupportedOperationException(
-          "ArrowVectorAccessor. Unsupported type: " + vector.getClass().getName());
+          "ArrowVectorAccessor. Unsupported vector type: " + vector.getClass().getName());
     }
+    return builder.build(vector);
   }
 
   // A general method to extract values from the vector.
@@ -153,3 +169,55 @@ class StructVectorAccessor extends ArrowVectorAccessor {
     return GenericRowData.of(fieldValues);
   }
 }
+
+class ListVectorAccessor extends ArrowVectorAccessor {
+  private ListVector vector;
+  private ArrowVectorAccessor elementAccessor;
+
+  public ListVectorAccessor(FieldVector vector) {
+    this.vector = (ListVector) vector;
+    FieldVector elementVector = this.vector.getDataVector();
+    this.elementAccessor = ArrowVectorAccessor.create(elementVector);
+  }
+
+  @Override
+  public Object get(int rowIndex) {
+    int startIndex = vector.getElementStartIndex(rowIndex);
+    int endIndex = vector.getElementEndIndex(rowIndex);
+    Object[] elements = new Object[endIndex - startIndex];
+    for (int i = startIndex; i < endIndex; i++) {
+      elements[i - startIndex] = elementAccessor.get(i);
+    }
+    return new GenericArrayData(elements);
+  }
+}
+
+// In Arrow, the internal implementation of a map vector is an array vector.
+class MapVectorAccessor extends ArrowVectorAccessor {
+  private final MapVector vector;
+  private StructVector entriesVector;
+  private ArrowVectorAccessor keyAccessor;
+  private ArrowVectorAccessor valueAccessor;
+
+  public MapVectorAccessor(FieldVector vector) {
+    this.vector = (MapVector) vector;
+    this.entriesVector = (StructVector) this.vector.getDataVector();
+    FieldVector keyVector = this.entriesVector.getChild(MapVector.KEY_NAME);
+    FieldVector valueVector = this.entriesVector.getChild(MapVector.VALUE_NAME);
+    this.keyAccessor = ArrowVectorAccessor.create(keyVector);
+    this.valueAccessor = ArrowVectorAccessor.create(valueVector);
+  }
+
+  @Override
+  public Object get(int rowIndex) {
+    int startIndex = vector.getElementStartIndex(rowIndex);
+    int endIndex = vector.getElementEndIndex(rowIndex);
+    Map<Object, Object> mapEntries = new LinkedHashMap<>();
+    for (int i = startIndex; i < endIndex; i++) {
+      Object key = keyAccessor.get(i);
+      Object value = valueAccessor.get(i);
+      mapEntries.put(key, value);
+    }
+    return new GenericMapData(mapEntries);
+  }
+}
diff --git a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
index 9c6c845697..52c700c255 100644
--- a/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
+++ b/gluten-flink/runtime/src/main/java/org/apache/gluten/vectorized/ArrowVectorWriter.java
@@ -18,19 +18,67 @@ package org.apache.gluten.vectorized;
 
 import io.github.zhztheplayer.velox4j.type.*;
 
+import org.apache.flink.table.data.ArrayData;
+import org.apache.flink.table.data.MapData;
 import org.apache.flink.table.data.RowData;
 
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.*;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
 import org.apache.arrow.vector.complex.StructVector;
 import org.apache.arrow.vector.types.FloatingPointPrecision;
 import org.apache.arrow.vector.types.TimeUnit;
 import org.apache.arrow.vector.types.pojo.*;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 public abstract class ArrowVectorWriter {
+  private interface WriterBuilder {
+    ArrowVectorWriter build(Type fieldType, BufferAllocator allocator, FieldVector vector);
+  };
+
+  // Exact class matches
+  private static Map<Class<? extends Type>, WriterBuilder> writerBuilders =
+      Map.ofEntries(
+          Map.entry(
+              IntegerType.class,
+              (fieldType, allocator, vector) -> new IntVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              BooleanType.class,
+              (fieldType, allocator, vector) ->
+                  new BooleanVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              BigIntType.class,
+              (fieldType, allocator, vector) ->
+                  new BigIntVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              DoubleType.class,
+              (fieldType, allocator, vector) ->
+                  new Float8VectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              VarCharType.class,
+              (fieldType, allocator, vector) ->
+                  new VarCharVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              TimestampType.class,
+              (fieldType, allocator, vector) ->
+                  new TimestampVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              RowType.class,
+              (fieldType, allocator, vector) ->
+                  new StructVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              ArrayType.class,
+              (fieldType, allocator, vector) ->
+                  new ArrayVectorWriter(fieldType, allocator, vector)),
+          Map.entry(
+              MapType.class,
+              (fieldType, allocator, vector) -> new MapVectorWriter(fieldType, allocator, vector)));
+
   public static ArrowVectorWriter create(
       String fieldName, Type fieldType, BufferAllocator allocator) {
     return create(fieldName, fieldType, allocator, null);
@@ -42,40 +90,31 @@ public abstract class ArrowVectorWriter {
       // Build an empty vector
       vector = FieldVectorCreator.create(fieldName, fieldType, false, allocator, null);
     }
-    if (fieldType instanceof IntegerType) {
-      return new IntVectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof BooleanType) {
-      return new BooleanVectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof BigIntType) {
-      return new BigIntVectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof DoubleType) {
-      return new Float8VectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof VarCharType) {
-      return new VarCharVectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof TimestampType) {
-      return new TimestampVectorWriter(fieldType, allocator, vector);
-    } else if (fieldType instanceof RowType) {
-      return new StructVectorWriter(fieldType, allocator, vector);
-    } else {
-      throw new UnsupportedOperationException("ArrowVectorWriter. Unsupported type: " + fieldType);
+    WriterBuilder builder = writerBuilders.get(fieldType.getClass());
+    if (builder == null) {
+      throw new UnsupportedOperationException(
+          "ArrowVectorWriter. Unsupported type: " + fieldType.getClass().getName());
     }
+    return builder.build(fieldType, allocator, vector);
+  }
+
+  protected FieldVector vector = null;
+  protected int valueCount = 0;
+
+  ArrowVectorWriter(FieldVector vector) {
+    this.vector = vector;
   }
 
   public void write(int fieldIndex, RowData rowData) {
     throw new UnsupportedOperationException("assign is not supported");
   }
 
-  public void write(int fieldIndex, List<RowData> rowData) {
-    for (RowData row : rowData) {
-      write(fieldIndex, row);
-    }
+  public void writeArray(ArrayData arrayData) {
+    throw new UnsupportedOperationException("writeArray is not supported");
   }
 
-  protected FieldVector vector = null;
-  protected int valueCount = 0;
-
-  ArrowVectorWriter(FieldVector vector) {
-    this.vector = vector;
+  int getValueCount() {
+    return valueCount;
   }
 
   FieldVector getVector() {
@@ -86,6 +125,7 @@ public abstract class ArrowVectorWriter {
     vector.setValueCount(valueCount);
   }
 }
+
 // Build FieldVector from Type.
 class FieldVectorCreator {
   public static FieldVector create(
@@ -94,34 +134,60 @@ class FieldVectorCreator {
     return field.createVector(allocator);
   }
 
+  private interface ArrowTypeConverter {
+    ArrowType convert(Type dataType, String timeZoneId);
+  }
+
+  // Exact class matches
+  private static Map<Class<? extends Type>, ArrowTypeConverter> arrowTypeConverters =
+      Map.ofEntries(
+          Map.entry(BooleanType.class, (dataType, timeZoneId) -> ArrowType.Bool.INSTANCE),
+          Map.entry(IntegerType.class, (dataType, timeZoneId) -> new ArrowType.Int(8 * 4, true)),
+          Map.entry(BigIntType.class, (dataType, timeZoneId) -> new ArrowType.Int(8 * 8, true)),
+          Map.entry(
+              DoubleType.class,
+              (dataType, timeZoneId) -> new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)),
+          Map.entry(VarCharType.class, (dataType, timeZoneId) -> ArrowType.Utf8.INSTANCE),
+          Map.entry(
+              TimestampType.class,
+              (dataType, timeZoneId) ->
+                  new ArrowType.Timestamp(
+                      TimeUnit.MILLISECOND, timeZoneId == null ? "UTC" : timeZoneId)));
+
   private static ArrowType toArrowType(Type dataType, String timeZoneId) {
-    if (dataType instanceof BooleanType) {
-      return ArrowType.Bool.INSTANCE;
-    } else if (dataType instanceof IntegerType) {
-      return new ArrowType.Int(8 * 4, true);
-    } else if (dataType instanceof BigIntType) {
-      return new ArrowType.Int(8 * 8, true);
-    } else if (dataType instanceof DoubleType) {
-      return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
-    } else if (dataType instanceof VarCharType) {
-      return ArrowType.Utf8.INSTANCE;
-    } else if (dataType instanceof TimestampType) {
-      if (timeZoneId == null) {
-        return new ArrowType.Timestamp(TimeUnit.MILLISECOND, "UTC");
-      } else {
-        return new ArrowType.Timestamp(TimeUnit.MILLISECOND, timeZoneId);
-      }
-    } else {
-      throw new UnsupportedOperationException("Unsupported type: " + dataType);
+    ArrowTypeConverter converter = arrowTypeConverters.get(dataType.getClass());
+    if (converter == null) {
+      throw new UnsupportedOperationException("Unsupported type: " + dataType.getClass().getName());
     }
+    return converter.convert(dataType, timeZoneId);
   }
 
   private static Field toArrowField(
       String name, Type dataType, boolean nullable, String timeZoneId) {
     if (dataType instanceof ArrayType) {
-      throw new UnsupportedOperationException("ArrayType is not supported");
+      List<Type> elementTypes = ((ArrayType) dataType).getChildren();
+      if (elementTypes.size() != 1) {
+        throw new UnsupportedOperationException("ArrayType should have exactly one element type");
+      }
+
+      FieldType fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null);
+      List<Field> elementFields = new ArrayList<>();
+      elementFields.add(toArrowField("element", elementTypes.get(0), nullable, timeZoneId));
+
+      return new Field(name, fieldType, elementFields);
+
     } else if (dataType instanceof MapType) {
-      throw new UnsupportedOperationException("MapType is not supported");
+      MapType mapType = (MapType) dataType;
+      FieldType mapFieldType = new FieldType(nullable, new ArrowType.Map(false), null);
+
+      List<String> fieldNames = Arrays.asList(MapVector.KEY_NAME, MapVector.VALUE_NAME);
+      List<Type> fieldTypes = mapType.getChildren();
+      RowType structType = new RowType(fieldNames, fieldTypes);
+      Field structField =
+          toArrowField(MapVector.DATA_VECTOR_NAME, structType, nullable, timeZoneId);
+
+      return new Field(name, mapFieldType, Arrays.asList(structField));
+
     } else if (dataType instanceof RowType) {
       RowType structType = (RowType) dataType;
       List<String> fieldNames = structType.getNames();
@@ -141,136 +207,302 @@ class FieldVectorCreator {
   }
 }
 
-class IntVectorWriter extends ArrowVectorWriter {
-  private final IntVector intVector;
+abstract class BaseVectorWriter<T extends FieldVector, V> extends ArrowVectorWriter {
+  protected final T typedVector;
 
-  public IntVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
+  protected BaseVectorWriter(FieldVector vector) {
     super(vector);
-    this.intVector = (IntVector) vector;
+    this.typedVector = (T) vector;
   }
 
+  protected abstract V getValue(RowData rowData, int fieldIndex);
+
+  protected abstract V getValue(ArrayData arrayData, int index);
+
+  protected abstract void setValue(int index, V value);
+
   @Override
   public void write(int fieldIndex, RowData rowData) {
-    intVector.setSafe(valueCount, rowData.getInt(fieldIndex));
+    setValue(valueCount, getValue(rowData, fieldIndex));
     valueCount++;
   }
+
+  @Override
+  public void writeArray(ArrayData arrayData) {
+    for (int i = 0; i < arrayData.size(); i++) {
+      setValue(valueCount, getValue(arrayData, i));
+      valueCount++;
+    }
+  }
 }
 
-class BooleanVectorWriter extends ArrowVectorWriter {
-  private final BitVector bitVector;
+class IntVectorWriter extends BaseVectorWriter<IntVector, Integer> {
+  public IntVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
+    super(vector);
+  }
+
+  @Override
+  protected Integer getValue(RowData rowData, int fieldIndex) {
+    return rowData.getInt(fieldIndex);
+  }
+
+  @Override
+  protected Integer getValue(ArrayData arrayData, int index) {
+    return arrayData.getInt(index);
+  }
+
+  @Override
+  protected void setValue(int index, Integer value) {
+    this.typedVector.setSafe(index, value);
+  }
+}
 
+class BooleanVectorWriter extends BaseVectorWriter<BitVector, Boolean> {
   public BooleanVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.bitVector = (BitVector) vector;
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    bitVector.setSafe(valueCount, rowData.getBoolean(fieldIndex) ? 1 : 0);
-    valueCount++;
+  protected Boolean getValue(RowData rowData, int fieldIndex) {
+    return rowData.getBoolean(fieldIndex);
+  }
+
+  @Override
+  protected Boolean getValue(ArrayData arrayData, int index) {
+    return arrayData.getBoolean(index);
+  }
+
+  @Override
+  protected void setValue(int index, Boolean value) {
+    this.typedVector.setSafe(index, value ? 1 : 0);
   }
 }
 
-class BigIntVectorWriter extends ArrowVectorWriter {
-  private final BigIntVector bigIntvector;
+class BigIntVectorWriter extends BaseVectorWriter<BigIntVector, Long> {
 
   public BigIntVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.bigIntvector = (BigIntVector) vector;
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    bigIntvector.setSafe(valueCount, rowData.getLong(fieldIndex));
-    valueCount++;
+  protected Long getValue(RowData rowData, int fieldIndex) {
+    return rowData.getLong(fieldIndex);
+  }
+
+  @Override
+  protected Long getValue(ArrayData arrayData, int index) {
+    return arrayData.getLong(index);
+  }
+
+  @Override
+  protected void setValue(int index, Long value) {
+    this.typedVector.setSafe(index, value);
   }
 }
 
-class Float8VectorWriter extends ArrowVectorWriter {
-  private final Float8Vector float8Vector;
+class Float8VectorWriter extends BaseVectorWriter<Float8Vector, Double> {
 
   public Float8VectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.float8Vector = (Float8Vector) vector;
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    float8Vector.setSafe(valueCount, rowData.getDouble(fieldIndex));
-    valueCount++;
+  protected Double getValue(RowData rowData, int fieldIndex) {
+    return rowData.getDouble(fieldIndex);
+  }
+
+  @Override
+  protected Double getValue(ArrayData arrayData, int index) {
+    return arrayData.getDouble(index);
+  }
+
+  @Override
+  protected void setValue(int index, Double value) {
+    this.typedVector.setSafe(index, value);
   }
 }
 
-class VarCharVectorWriter extends ArrowVectorWriter {
-  private final VarCharVector varCharVector;
+class VarCharVectorWriter extends BaseVectorWriter<VarCharVector, byte[]> {
 
   public VarCharVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.varCharVector = (VarCharVector) vector;
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    varCharVector.setSafe(valueCount, rowData.getString(fieldIndex).toBytes());
-    valueCount++;
+  protected byte[] getValue(RowData rowData, int fieldIndex) {
+    return rowData.getString(fieldIndex).toBytes();
+  }
+
+  @Override
+  protected byte[] getValue(ArrayData arrayData, int index) {
+    return arrayData.getString(index).toBytes();
+  }
+
+  @Override
+  protected void setValue(int index, byte[] value) {
+    this.typedVector.setSafe(index, value);
   }
 }
 
-class TimestampVectorWriter extends ArrowVectorWriter {
-  private final TimeStampMilliVector tsVector;
+class TimestampVectorWriter extends BaseVectorWriter<TimeStampMilliVector, Long> {
+  private final int precision = 3; // Millisecond precision
 
   public TimestampVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.tsVector = (TimeStampMilliVector) vector;
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    // TODO: support precision
-    tsVector.setSafe(valueCount, rowData.getTimestamp(fieldIndex, 3).getMillisecond());
-    valueCount++;
+  protected Long getValue(RowData rowData, int fieldIndex) {
+    return rowData.getTimestamp(fieldIndex, precision).getMillisecond();
+  }
+
+  @Override
+  protected Long getValue(ArrayData arrayData, int index) {
+    return arrayData.getTimestamp(index, precision).getMillisecond();
+  }
+
+  @Override
+  protected void setValue(int index, Long value) {
+    this.typedVector.setSafe(index, value);
   }
 }
 
-class StructVectorWriter extends ArrowVectorWriter {
-  private int fieldCounts = 0;
-  BufferAllocator allocator;
-  private List<ArrowVectorWriter> subFieldWriters;
-  private StructVector strctVector;
+class StructVectorWriter extends BaseVectorWriter<StructVector, RowData> {
+  private final int fieldCount;
+  private BufferAllocator allocator;
+  private final List<ArrowVectorWriter> fieldWriters;
 
   public StructVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
     super(vector);
-    this.strctVector = (StructVector) vector;
     RowType rowType = (RowType) fieldType;
-    List<String> subFieldNames = rowType.getNames();
-    subFieldWriters = new ArrayList<>();
-    for (int i = 0; i < subFieldNames.size(); ++i) {
-      subFieldWriters.add(
+    List<String> fieldNames = rowType.getNames();
+    fieldCount = fieldNames.size();
+    fieldWriters = new ArrayList<>();
+    for (int i = 0; i < fieldCount; ++i) {
+      fieldWriters.add(
           ArrowVectorWriter.create(
-              subFieldNames.get(i),
+              fieldNames.get(i),
               rowType.getChildren().get(i),
               allocator,
-              (FieldVector) (this.strctVector.getChildByOrdinal(i))));
+              (FieldVector) (this.typedVector.getChildByOrdinal(i))));
     }
-    fieldCounts = subFieldNames.size();
   }
 
   @Override
-  public void write(int fieldIndex, RowData rowData) {
-    // TODO: support nullable
-    RowData subRowData = rowData.getRow(fieldIndex, fieldCounts);
-    strctVector.setIndexDefined(valueCount);
-    for (int i = 0; i < fieldCounts; i++) {
-      subFieldWriters.get(i).write(i, subRowData);
+  protected RowData getValue(RowData rowData, int fieldIndex) {
+    return rowData.getRow(fieldIndex, fieldCount);
+  }
+
+  @Override
+  protected RowData getValue(ArrayData arrayData, int index) {
+    return arrayData.getRow(index, fieldCount);
+  }
+
+  @Override
+  protected void setValue(int index, RowData value) {
+    this.typedVector.setIndexDefined(index);
+    for (int i = 0; i < fieldCount; ++i) {
+      fieldWriters.get(i).write(i, value);
     }
-    valueCount++;
   }
 
   @Override
   public void finish() {
-    strctVector.setValueCount(valueCount);
-    for (int i = 0; i < fieldCounts; i++) {
-      subFieldWriters.get(i).finish();
+    this.typedVector.setValueCount(valueCount);
+    for (int i = 0; i < fieldCount; ++i) {
+      fieldWriters.get(i).finish();
     }
   }
 }
+
+class ArrayVectorWriter extends BaseVectorWriter<ListVector, ArrayData> {
+  private final ArrowVectorWriter elementWriter;
+
+  public ArrayVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
+    super(vector);
+
+    FieldVector elementVector = (FieldVector) this.typedVector.getDataVector();
+    List<Type> elementTypes = ((ArrayType) fieldType).getChildren();
+    if (elementTypes.size() != 1) {
+      throw new UnsupportedOperationException("ArrayType should have exactly one element type");
+    }
+    Type elementType = elementTypes.get(0);
+    this.elementWriter = ArrowVectorWriter.create("element", elementType, allocator, elementVector);
+  }
+
+  @Override
+  protected ArrayData getValue(RowData rowData, int fieldIndex) {
+    return rowData.getArray(fieldIndex);
+  }
+
+  @Override
+  protected ArrayData getValue(ArrayData arrayData, int index) {
+    return arrayData.getArray(index);
+  }
+
+  @Override
+  protected void setValue(int index, ArrayData value) {
+    this.typedVector.startNewValue(valueCount);
+    elementWriter.writeArray(value);
+    this.typedVector.endValue(valueCount, value.size());
+  }
+
+  @Override
+  public void finish() {
+    this.typedVector.setValueCount(valueCount);
+    elementWriter.finish();
+  }
+}
+
+class MapVectorWriter extends BaseVectorWriter<MapVector, MapData> {
+  private final ArrowVectorWriter keyWriter;
+  private final ArrowVectorWriter valueWriter;
+  private final StructVector entriesVector;
+
+  public MapVectorWriter(Type fieldType, BufferAllocator allocator, FieldVector vector) {
+    super(vector);
+
+    entriesVector = (StructVector) this.typedVector.getDataVector();
+
+    FieldVector keyVector = (FieldVector) entriesVector.getChild(MapVector.KEY_NAME);
+    FieldVector valueVector = (FieldVector) entriesVector.getChild(MapVector.VALUE_NAME);
+
+    MapType mapType = (MapType) fieldType;
+    this.keyWriter =
+        ArrowVectorWriter.create(
+            MapVector.KEY_NAME, mapType.getChildren().get(0), allocator, keyVector);
+    this.valueWriter =
+        ArrowVectorWriter.create(
+            MapVector.VALUE_NAME, mapType.getChildren().get(1), allocator, valueVector);
+  }
+
+  @Override
+  protected MapData getValue(RowData rowData, int fieldIndex) {
+    return rowData.getMap(fieldIndex);
+  }
+
+  @Override
+  protected MapData getValue(ArrayData arrayData, int index) {
+    return arrayData.getMap(index);
+  }
+
+  @Override
+  protected void setValue(int index, MapData value) {
+    this.typedVector.startNewValue(valueCount);
+    int arrayValueCount = keyWriter.getValueCount();
+    for (int i = 0; i < value.size(); i++) {
+      entriesVector.setIndexDefined(arrayValueCount + i);
+    }
+    keyWriter.writeArray(value.keyArray());
+    valueWriter.writeArray(value.valueArray());
+    this.typedVector.endValue(valueCount, value.size());
+  }
+
+  @Override
+  public void finish() {
+    this.typedVector.setValueCount(valueCount);
+    entriesVector.setValueCount(keyWriter.getValueCount());
+    keyWriter.finish();
+    valueWriter.finish();
+  }
+}
diff --git a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
index 4a713ef1be..47eefd5fae 100644
--- a/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
+++ b/gluten-flink/ut/src/test/java/org/apache/gluten/table/runtime/stream/custom/ScanTest.java
@@ -27,6 +27,7 @@ import org.slf4j.LoggerFactory;
 
 import java.util.Arrays;
 import java.util.List;
+import java.util.Map;
 
 class ScanTest extends GlutenStreamingTestBase {
   private static final Logger LOG = LoggerFactory.getLogger(ScanTest.class);
@@ -69,4 +70,64 @@ class ScanTest extends GlutenStreamingTestBase {
     String query = "select a, b from floatTbl where a > 0";
     runAndCheck(query, Arrays.asList("+I[1, 1.0]", "+I[2, 2.0]"));
   }
+
+  /** Checks that ARRAY columns of several element types are read back unchanged by a scan. */
+  @Test
+  void testArrayScan() {
+    // Case 1: array<int>.
+    List<Row> intArrays =
+        Arrays.asList(
+            Row.of(1, new Integer[] {1, 2, 3}),
+            Row.of(2, new Integer[] {4, 5, 6}),
+            Row.of(3, new Integer[] {7, 8, 9}));
+    createSimpleBoundedValuesTable("arrayTbl1", "a int, b array<int>", intArrays);
+    runAndCheck(
+        "select a, b from arrayTbl1 where a > 0",
+        Arrays.asList("+I[1, [1, 2, 3]]", "+I[2, [4, 5, 6]]", "+I[3, [7, 8, 9]]"));
+
+    // Case 2: array<string>.
+    List<Row> stringArrays =
+        Arrays.asList(
+            Row.of(1, new String[] {"a", "b", "c"}),
+            Row.of(2, new String[] {"d", "e", "f"}),
+            Row.of(3, new String[] {"g", "h", "i"}));
+    createSimpleBoundedValuesTable("arrayTbl2", "a int, b array<string>", stringArrays);
+    runAndCheck(
+        "select a, b from arrayTbl2 where a > 0",
+        Arrays.asList("+I[1, [a, b, c]]", "+I[2, [d, e, f]]", "+I[3, [g, h, i]]"));
+
+    // Case 3: array of ROW values.
+    List<Row> rowArrays =
+        Arrays.asList(
+            Row.of(1, new Row[] {Row.of(1, 2), Row.of(3, 4)}),
+            Row.of(3, new Row[] {Row.of(5, 6)}));
+    createSimpleBoundedValuesTable("arrayTbl3", "a int, b array<ROW<x int, y int>>", rowArrays);
+    runAndCheck(
+        "select a, b from arrayTbl3 where a > 0",
+        Arrays.asList("+I[1, [+I[1, 2], +I[3, 4]]]", "+I[3, [+I[5, 6]]]"));
+
+    // Case 4: array<array<int>>.
+    List<Row> nestedArrays =
+        Arrays.asList(Row.of(1, new Integer[][] {{1, 3}}), Row.of(3, new Integer[][] {{4, 5}}));
+    createSimpleBoundedValuesTable("arrayTbl4", "a int, b array<array<int>>", nestedArrays);
+    runAndCheck(
+        "select a, b from arrayTbl4 where a > 0",
+        Arrays.asList("+I[1, [[1, 3]]]", "+I[3, [[4, 5]]]"));
+  }
+
+  /** Checks that MAP columns, top-level and nested in an ARRAY, are read back by a scan. */
+  @Test
+  void testMapScan() {
+    // NOTE(review): Map.of with more than one entry has an unspecified iteration order; the
+    // expected strings below assume the rendered map order does not depend on it — confirm.
+    List<Row> intToStringMaps =
+        Arrays.asList(
+            Row.of(1, Map.of(1, "a")),
+            Row.of(2, Map.of(2, "b", 3, "c")),
+            Row.of(3, Map.of(4, "d", 5, "e", 6, "f")));
+    createSimpleBoundedValuesTable("mapTbl1", "a int, b map<int, string>", intToStringMaps);
+    runAndCheck(
+        "select a, b from mapTbl1 where a > 0",
+        Arrays.asList("+I[1, {1=a}]", "+I[2, {2=b, 3=c}]", "+I[3, {4=d, 5=e, 6=f}]"));
+
+    // Maps nested inside an array column.
+    List<Row> arraysOfMaps =
+        Arrays.asList(
+            Row.of(1, new Map[] {Map.of("a", 1), Map.of("b", 2)}),
+            Row.of(2, new Map[] {Map.of("b", 2, "c", 3)}),
+            Row.of(3, new Map[] {Map.of("d", 4, "e", 5, "f", 6)}));
+    createSimpleBoundedValuesTable("mapTbl2", "a int, b array<map<string, int>>", arraysOfMaps);
+    runAndCheck(
+        "select a, b from mapTbl2 where a > 0",
+        Arrays.asList(
+            "+I[1, [{a=1}, {b=2}]]", "+I[2, [{b=2, c=3}]]", "+I[3, [{d=4, e=5, f=6}]]"));
+  }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to