[SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request?
move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the documentation. ## How was this patch tested? existing tests. Author: Wenchen Fan <wenc...@databricks.com> Closes #20116 from cloud-fan/column-vector. (cherry picked from commit b297029130735316e1ac1144dee44761a12bfba7) Signed-off-by: gatorsmile <gatorsm...@gmail.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a51212b6 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a51212b6 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a51212b6 Branch: refs/heads/branch-2.3 Commit: a51212b642f05f28447b80aa29f5482de2c27f58 Parents: 79f7263 Author: Wenchen Fan <wenc...@databricks.com> Authored: Thu Jan 4 07:28:53 2018 +0800 Committer: gatorsmile <gatorsm...@gmail.com> Committed: Thu Jan 4 07:29:33 2018 +0800 ---------------------------------------------------------------------- .../parquet/VectorizedParquetRecordReader.java | 7 +- .../execution/vectorized/ArrowColumnVector.java | 620 ------------------- .../sql/execution/vectorized/ColumnVector.java | 208 ------- .../execution/vectorized/ColumnVectorUtils.java | 2 + .../sql/execution/vectorized/ColumnarArray.java | 202 ------ .../sql/execution/vectorized/ColumnarBatch.java | 149 ----- .../sql/execution/vectorized/ColumnarRow.java | 206 ------ .../vectorized/MutableColumnarRow.java | 4 + .../vectorized/WritableColumnVector.java | 7 +- .../spark/sql/vectorized/ArrowColumnVector.java | 562 +++++++++++++++++ .../spark/sql/vectorized/ColumnVector.java | 215 +++++++ .../spark/sql/vectorized/ColumnarArray.java | 201 ++++++ .../spark/sql/vectorized/ColumnarBatch.java | 129 ++++ .../spark/sql/vectorized/ColumnarRow.java | 205 ++++++ .../spark/sql/execution/ColumnarBatchScan.scala | 4 +- .../execution/aggregate/HashAggregateExec.scala | 2 +- .../aggregate/VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- 
.../columnar/InMemoryTableScanExec.scala | 1 + .../sql/execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../sql/execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 25 files changed, 1341 insertions(+), 1403 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e8..cd745b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.parquet.schema.Type; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa * Advances to the next batch of rows. 
Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java deleted file mode 100644 index af5673e..0000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ /dev/null @@ -1,620 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.vectorized; - -import org.apache.arrow.vector.*; -import org.apache.arrow.vector.complex.*; -import org.apache.arrow.vector.holders.NullableVarCharHolder; - -import org.apache.spark.sql.execution.arrow.ArrowUtils; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * A column vector backed by Apache Arrow. - */ -public final class ArrowColumnVector extends ColumnVector { - - private final ArrowVectorAccessor accessor; - private ArrowColumnVector[] childColumns; - - private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } - } - - private void ensureAccessible(int index, int count) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index + count > valueCount) { - throw new IndexOutOfBoundsException( - String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); - } - } - - @Override - public int numNulls() { - return accessor.getNullCount(); - } - - @Override - public void close() { - if (childColumns != null) { - for (int i = 0; i < childColumns.length; i++) { - childColumns[i].close(); - } - } - accessor.close(); - } - - // - // APIs dealing with nulls - // - - @Override - public boolean isNullAt(int rowId) { - ensureAccessible(rowId); - return accessor.isNullAt(rowId); - } - - // - // APIs dealing with Booleans - // - - @Override - public boolean getBoolean(int rowId) { - ensureAccessible(rowId); - return accessor.getBoolean(rowId); - } - - @Override - public boolean[] getBooleans(int rowId, int count) { - ensureAccessible(rowId, count); - boolean[] array = new boolean[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getBoolean(rowId + i); - } - return array; - } - - // - // APIs dealing with Bytes - // - - @Override - public byte 
getByte(int rowId) { - ensureAccessible(rowId); - return accessor.getByte(rowId); - } - - @Override - public byte[] getBytes(int rowId, int count) { - ensureAccessible(rowId, count); - byte[] array = new byte[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getByte(rowId + i); - } - return array; - } - - // - // APIs dealing with Shorts - // - - @Override - public short getShort(int rowId) { - ensureAccessible(rowId); - return accessor.getShort(rowId); - } - - @Override - public short[] getShorts(int rowId, int count) { - ensureAccessible(rowId, count); - short[] array = new short[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getShort(rowId + i); - } - return array; - } - - // - // APIs dealing with Ints - // - - @Override - public int getInt(int rowId) { - ensureAccessible(rowId); - return accessor.getInt(rowId); - } - - @Override - public int[] getInts(int rowId, int count) { - ensureAccessible(rowId, count); - int[] array = new int[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getInt(rowId + i); - } - return array; - } - - // - // APIs dealing with Longs - // - - @Override - public long getLong(int rowId) { - ensureAccessible(rowId); - return accessor.getLong(rowId); - } - - @Override - public long[] getLongs(int rowId, int count) { - ensureAccessible(rowId, count); - long[] array = new long[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getLong(rowId + i); - } - return array; - } - - // - // APIs dealing with floats - // - - @Override - public float getFloat(int rowId) { - ensureAccessible(rowId); - return accessor.getFloat(rowId); - } - - @Override - public float[] getFloats(int rowId, int count) { - ensureAccessible(rowId, count); - float[] array = new float[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getFloat(rowId + i); - } - return array; - } - - // - // APIs dealing with doubles - // - - @Override - public double getDouble(int rowId) { - 
ensureAccessible(rowId); - return accessor.getDouble(rowId); - } - - @Override - public double[] getDoubles(int rowId, int count) { - ensureAccessible(rowId, count); - double[] array = new double[count]; - for (int i = 0; i < count; ++i) { - array[i] = accessor.getDouble(rowId + i); - } - return array; - } - - // - // APIs dealing with Arrays - // - - @Override - public int getArrayLength(int rowId) { - ensureAccessible(rowId); - return accessor.getArrayLength(rowId); - } - - @Override - public int getArrayOffset(int rowId) { - ensureAccessible(rowId); - return accessor.getArrayOffset(rowId); - } - - // - // APIs dealing with Decimals - // - - @Override - public Decimal getDecimal(int rowId, int precision, int scale) { - ensureAccessible(rowId); - return accessor.getDecimal(rowId, precision, scale); - } - - // - // APIs dealing with UTF8Strings - // - - @Override - public UTF8String getUTF8String(int rowId) { - ensureAccessible(rowId); - return accessor.getUTF8String(rowId); - } - - // - // APIs dealing with Binaries - // - - @Override - public byte[] getBinary(int rowId) { - ensureAccessible(rowId); - return accessor.getBinary(rowId); - } - - /** - * Returns the data for the underlying array. - */ - @Override - public ArrowColumnVector arrayData() { return childColumns[0]; } - - /** - * Returns the ordinal's child data column. 
- */ - @Override - public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } - - public ArrowColumnVector(ValueVector vector) { - super(ArrowUtils.fromArrowField(vector.getField())); - - if (vector instanceof BitVector) { - accessor = new BooleanAccessor((BitVector) vector); - } else if (vector instanceof TinyIntVector) { - accessor = new ByteAccessor((TinyIntVector) vector); - } else if (vector instanceof SmallIntVector) { - accessor = new ShortAccessor((SmallIntVector) vector); - } else if (vector instanceof IntVector) { - accessor = new IntAccessor((IntVector) vector); - } else if (vector instanceof BigIntVector) { - accessor = new LongAccessor((BigIntVector) vector); - } else if (vector instanceof Float4Vector) { - accessor = new FloatAccessor((Float4Vector) vector); - } else if (vector instanceof Float8Vector) { - accessor = new DoubleAccessor((Float8Vector) vector); - } else if (vector instanceof DecimalVector) { - accessor = new DecimalAccessor((DecimalVector) vector); - } else if (vector instanceof VarCharVector) { - accessor = new StringAccessor((VarCharVector) vector); - } else if (vector instanceof VarBinaryVector) { - accessor = new BinaryAccessor((VarBinaryVector) vector); - } else if (vector instanceof DateDayVector) { - accessor = new DateAccessor((DateDayVector) vector); - } else if (vector instanceof TimeStampMicroTZVector) { - accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); - } else if (vector instanceof ListVector) { - ListVector listVector = (ListVector) vector; - accessor = new ArrayAccessor(listVector); - - childColumns = new ArrowColumnVector[1]; - childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); - } else if (vector instanceof MapVector) { - MapVector mapVector = (MapVector) vector; - accessor = new StructAccessor(mapVector); - - childColumns = new ArrowColumnVector[mapVector.size()]; - for (int i = 0; i < childColumns.length; ++i) { - childColumns[i] = new 
ArrowColumnVector(mapVector.getVectorById(i)); - } - } else { - throw new UnsupportedOperationException(); - } - } - - private abstract static class ArrowVectorAccessor { - - private final ValueVector vector; - - ArrowVectorAccessor(ValueVector vector) { - this.vector = vector; - } - - // TODO: should be final after removing ArrayAccessor workaround - boolean isNullAt(int rowId) { - return vector.isNull(rowId); - } - - final int getValueCount() { - return vector.getValueCount(); - } - - final int getNullCount() { - return vector.getNullCount(); - } - - final void close() { - vector.close(); - } - - boolean getBoolean(int rowId) { - throw new UnsupportedOperationException(); - } - - byte getByte(int rowId) { - throw new UnsupportedOperationException(); - } - - short getShort(int rowId) { - throw new UnsupportedOperationException(); - } - - int getInt(int rowId) { - throw new UnsupportedOperationException(); - } - - long getLong(int rowId) { - throw new UnsupportedOperationException(); - } - - float getFloat(int rowId) { - throw new UnsupportedOperationException(); - } - - double getDouble(int rowId) { - throw new UnsupportedOperationException(); - } - - Decimal getDecimal(int rowId, int precision, int scale) { - throw new UnsupportedOperationException(); - } - - UTF8String getUTF8String(int rowId) { - throw new UnsupportedOperationException(); - } - - byte[] getBinary(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayLength(int rowId) { - throw new UnsupportedOperationException(); - } - - int getArrayOffset(int rowId) { - throw new UnsupportedOperationException(); - } - } - - private static class BooleanAccessor extends ArrowVectorAccessor { - - private final BitVector accessor; - - BooleanAccessor(BitVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final boolean getBoolean(int rowId) { - return accessor.get(rowId) == 1; - } - } - - private static class ByteAccessor extends ArrowVectorAccessor { - - private 
final TinyIntVector accessor; - - ByteAccessor(TinyIntVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final byte getByte(int rowId) { - return accessor.get(rowId); - } - } - - private static class ShortAccessor extends ArrowVectorAccessor { - - private final SmallIntVector accessor; - - ShortAccessor(SmallIntVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final short getShort(int rowId) { - return accessor.get(rowId); - } - } - - private static class IntAccessor extends ArrowVectorAccessor { - - private final IntVector accessor; - - IntAccessor(IntVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final int getInt(int rowId) { - return accessor.get(rowId); - } - } - - private static class LongAccessor extends ArrowVectorAccessor { - - private final BigIntVector accessor; - - LongAccessor(BigIntVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final long getLong(int rowId) { - return accessor.get(rowId); - } - } - - private static class FloatAccessor extends ArrowVectorAccessor { - - private final Float4Vector accessor; - - FloatAccessor(Float4Vector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final float getFloat(int rowId) { - return accessor.get(rowId); - } - } - - private static class DoubleAccessor extends ArrowVectorAccessor { - - private final Float8Vector accessor; - - DoubleAccessor(Float8Vector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final double getDouble(int rowId) { - return accessor.get(rowId); - } - } - - private static class DecimalAccessor extends ArrowVectorAccessor { - - private final DecimalVector accessor; - - DecimalAccessor(DecimalVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final Decimal getDecimal(int rowId, int precision, int scale) { - if (isNullAt(rowId)) return null; - return Decimal.apply(accessor.getObject(rowId), 
precision, scale); - } - } - - private static class StringAccessor extends ArrowVectorAccessor { - - private final VarCharVector accessor; - private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); - - StringAccessor(VarCharVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final UTF8String getUTF8String(int rowId) { - accessor.get(rowId, stringResult); - if (stringResult.isSet == 0) { - return null; - } else { - return UTF8String.fromAddress(null, - stringResult.buffer.memoryAddress() + stringResult.start, - stringResult.end - stringResult.start); - } - } - } - - private static class BinaryAccessor extends ArrowVectorAccessor { - - private final VarBinaryVector accessor; - - BinaryAccessor(VarBinaryVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final byte[] getBinary(int rowId) { - return accessor.getObject(rowId); - } - } - - private static class DateAccessor extends ArrowVectorAccessor { - - private final DateDayVector accessor; - - DateAccessor(DateDayVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final int getInt(int rowId) { - return accessor.get(rowId); - } - } - - private static class TimestampAccessor extends ArrowVectorAccessor { - - private final TimeStampMicroTZVector accessor; - - TimestampAccessor(TimeStampMicroTZVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final long getLong(int rowId) { - return accessor.get(rowId); - } - } - - private static class ArrayAccessor extends ArrowVectorAccessor { - - private final ListVector accessor; - - ArrayAccessor(ListVector vector) { - super(vector); - this.accessor = vector; - } - - @Override - final boolean isNullAt(int rowId) { - // TODO: Workaround if vector has all non-null values, see ARROW-1948 - if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { - return false; - } else { - return super.isNullAt(rowId); - } - } - - @Override 
- final int getArrayLength(int rowId) { - return accessor.getInnerValueCountAt(rowId); - } - - @Override - final int getArrayOffset(int rowId) { - return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); - } - } - - private static class StructAccessor extends ArrowVectorAccessor { - - StructAccessor(MapVector vector) { - super(vector); - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java deleted file mode 100644 index dc7c126..0000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.DataType; -import org.apache.spark.sql.types.Decimal; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. - * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. - * - * Maps are just a special case of a two field struct. - * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. - */ -public abstract class ColumnVector implements AutoCloseable { - - /** - * Returns the data type of this column. - */ - public final DataType dataType() { return type; } - - /** - * Cleans up memory for this column. The column is not usable after this. - */ - public abstract void close(); - - /** - * Returns the number of nulls in this column. - */ - public abstract int numNulls(); - - /** - * Returns whether the value at rowId is NULL. - */ - public abstract boolean isNullAt(int rowId); - - /** - * Returns the value for rowId. - */ - public abstract boolean getBoolean(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract boolean[] getBooleans(int rowId, int count); - - /** - * Returns the value for rowId. 
- */ - public abstract byte getByte(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract byte[] getBytes(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract short getShort(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract short[] getShorts(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract int getInt(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract int[] getInts(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract long getLong(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract long[] getLongs(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract float getFloat(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract float[] getFloats(int rowId, int count); - - /** - * Returns the value for rowId. - */ - public abstract double getDouble(int rowId); - - /** - * Gets values from [rowId, rowId + count) - */ - public abstract double[] getDoubles(int rowId, int count); - - /** - * Returns the length of the array for rowId. - */ - public abstract int getArrayLength(int rowId); - - /** - * Returns the offset of the array for rowId. - */ - public abstract int getArrayOffset(int rowId); - - /** - * Returns the struct for rowId. - */ - public final ColumnarRow getStruct(int rowId) { - return new ColumnarRow(this, rowId); - } - - /** - * A special version of {@link #getStruct(int)}, which is only used as an adapter for Spark - * codegen framework, the second parameter is totally ignored. - */ - public final ColumnarRow getStruct(int rowId, int size) { - return getStruct(rowId); - } - - /** - * Returns the array for rowId. 
- */ - public final ColumnarArray getArray(int rowId) { - return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); - } - - /** - * Returns the map for rowId. - */ - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - /** - * Returns the decimal for rowId. - */ - public abstract Decimal getDecimal(int rowId, int precision, int scale); - - /** - * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of - * this column vector, please copy it if you want to keep it after this column vector is freed. - */ - public abstract UTF8String getUTF8String(int rowId); - - /** - * Returns the byte array for rowId. - */ - public abstract byte[] getBinary(int rowId); - - /** - * Returns the data for the underlying array. - */ - public abstract ColumnVector arrayData(); - - /** - * Returns the ordinal's child data column. - */ - public abstract ColumnVector getChildColumn(int ordinal); - - /** - * Data type for this column. - */ - protected DataType type; - - /** - * Sets up the common state and also handles creating the child columns if this is a nested - * type. 
- */ - protected ColumnVector(DataType type) { - this.type = type; - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc4..b5cbe8e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.Row; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java deleted file mode 100644 index cbc39d1..0000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ /dev/null @@ -1,202 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.sql.catalyst.util.ArrayData; -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. - */ -public final class ColumnarArray extends ArrayData { - // The data for this array. This array contains elements from - // data[offset] to data[offset + length). 
- private final ColumnVector data; - private final int offset; - private final int length; - - ColumnarArray(ColumnVector data, int offset, int length) { - this.data = data; - this.offset = offset; - this.length = length; - } - - @Override - public int numElements() { - return length; - } - - @Override - public ArrayData copy() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean[] toBooleanArray() { return data.getBooleans(offset, length); } - - @Override - public byte[] toByteArray() { return data.getBytes(offset, length); } - - @Override - public short[] toShortArray() { return data.getShorts(offset, length); } - - @Override - public int[] toIntArray() { return data.getInts(offset, length); } - - @Override - public long[] toLongArray() { return data.getLongs(offset, length); } - - @Override - public float[] toFloatArray() { return data.getFloats(offset, length); } - - @Override - public double[] toDoubleArray() { return data.getDoubles(offset, length); } - - // TODO: this is extremely expensive. 
- @Override - public Object[] array() { - DataType dt = data.dataType(); - Object[] list = new Object[length]; - try { - for (int i = 0; i < length; i++) { - if (!data.isNullAt(offset + i)) { - list[i] = get(i, dt); - } - } - return list; - } catch(Exception e) { - throw new RuntimeException("Could not get the array", e); - } - } - - @Override - public boolean isNullAt(int ordinal) { return data.isNullAt(offset + ordinal); } - - @Override - public boolean getBoolean(int ordinal) { - return data.getBoolean(offset + ordinal); - } - - @Override - public byte getByte(int ordinal) { return data.getByte(offset + ordinal); } - - @Override - public short getShort(int ordinal) { - return data.getShort(offset + ordinal); - } - - @Override - public int getInt(int ordinal) { return data.getInt(offset + ordinal); } - - @Override - public long getLong(int ordinal) { return data.getLong(offset + ordinal); } - - @Override - public float getFloat(int ordinal) { - return data.getFloat(offset + ordinal); - } - - @Override - public double getDouble(int ordinal) { return data.getDouble(offset + ordinal); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - return data.getDecimal(offset + ordinal, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - return data.getUTF8String(offset + ordinal); - } - - @Override - public byte[] getBinary(int ordinal) { - return data.getBinary(offset + ordinal); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - int month = data.getChildColumn(0).getInt(offset + ordinal); - long microseconds = data.getChildColumn(1).getLong(offset + ordinal); - return new CalendarInterval(month, microseconds); - } - - @Override - public ColumnarRow getStruct(int ordinal, int numFields) { - return data.getStruct(offset + ordinal); - } - - @Override - public ColumnarArray getArray(int ordinal) { - return data.getArray(offset + ordinal); - } - - @Override - public MapData getMap(int 
ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else if (dataType instanceof CalendarIntervalType) { - return getInterval(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } - - @Override - public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java deleted file mode 100644 index a9d09aa..0000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.vectorized; - -import java.util.*; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.types.StructType; - -/** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. 
They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. - */ -public final class ColumnarBatch { - public static final int DEFAULT_BATCH_SIZE = 4 * 1024; - - private final StructType schema; - private final int capacity; - private int numRows; - private final ColumnVector[] columns; - - // Staging row returned from `getRow`. - private final MutableColumnarRow row; - - /** - * Called to close all the columns in this batch. It is not valid to access the data after - * calling this. This must be called at the end to clean up memory allocations. - */ - public void close() { - for (ColumnVector c: columns) { - c.close(); - } - } - - /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. - */ - public Iterator<InternalRow> rowIterator() { - final int maxRows = numRows; - final MutableColumnarRow row = new MutableColumnarRow(columns); - return new Iterator<InternalRow>() { - int rowId = 0; - - @Override - public boolean hasNext() { - return rowId < maxRows; - } - - @Override - public InternalRow next() { - if (rowId >= maxRows) { - throw new NoSuchElementException(); - } - row.rowId = rowId++; - return row; - } - - @Override - public void remove() { - throw new UnsupportedOperationException(); - } - }; - } - - /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. - */ - public void setNumRows(int numRows) { - assert(numRows <= this.capacity); - this.numRows = numRows; - } - - /** - * Returns the number of columns that make up this batch. - */ - public int numCols() { return columns.length; } - - /** - * Returns the number of rows for read, including filtered rows. 
- */ - public int numRows() { return numRows; } - - /** - * Returns the schema that makes up this batch. - */ - public StructType schema() { return schema; } - - /** - * Returns the max capacity (in number of rows) for this batch. - */ - public int capacity() { return capacity; } - - /** - * Returns the column at `ordinal`. - */ - public ColumnVector column(int ordinal) { return columns[ordinal]; } - - /** - * Returns the row in this batch at `rowId`. Returned row is reused across calls. - */ - public InternalRow getRow(int rowId) { - assert(rowId >= 0 && rowId < numRows); - row.rowId = rowId; - return row; - } - - public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { - this.schema = schema; - this.columns = columns; - this.capacity = capacity; - this.row = new MutableColumnarRow(columns); - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java deleted file mode 100644 index 8bb33ed..0000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ /dev/null @@ -1,206 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.vectorized; - -import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.catalyst.util.MapData; -import org.apache.spark.sql.types.*; -import org.apache.spark.unsafe.types.CalendarInterval; -import org.apache.spark.unsafe.types.UTF8String; - -/** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. - */ -public final class ColumnarRow extends InternalRow { - // The data for this row. - // E.g. the value of 3rd int field is `data.getChildColumn(3).getInt(rowId)`. - private final ColumnVector data; - private final int rowId; - private final int numFields; - - ColumnarRow(ColumnVector data, int rowId) { - assert (data.dataType() instanceof StructType); - this.data = data; - this.rowId = rowId; - this.numFields = ((StructType) data.dataType()).size(); - } - - @Override - public int numFields() { return numFields; } - - /** - * Revisit this. This is expensive. This is currently only used in test paths. 
- */ - @Override - public InternalRow copy() { - GenericInternalRow row = new GenericInternalRow(numFields); - for (int i = 0; i < numFields(); i++) { - if (isNullAt(i)) { - row.setNullAt(i); - } else { - DataType dt = data.getChildColumn(i).dataType(); - if (dt instanceof BooleanType) { - row.setBoolean(i, getBoolean(i)); - } else if (dt instanceof ByteType) { - row.setByte(i, getByte(i)); - } else if (dt instanceof ShortType) { - row.setShort(i, getShort(i)); - } else if (dt instanceof IntegerType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof LongType) { - row.setLong(i, getLong(i)); - } else if (dt instanceof FloatType) { - row.setFloat(i, getFloat(i)); - } else if (dt instanceof DoubleType) { - row.setDouble(i, getDouble(i)); - } else if (dt instanceof StringType) { - row.update(i, getUTF8String(i).copy()); - } else if (dt instanceof BinaryType) { - row.update(i, getBinary(i)); - } else if (dt instanceof DecimalType) { - DecimalType t = (DecimalType)dt; - row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); - } else if (dt instanceof DateType) { - row.setInt(i, getInt(i)); - } else if (dt instanceof TimestampType) { - row.setLong(i, getLong(i)); - } else { - throw new RuntimeException("Not implemented. 
" + dt); - } - } - } - return row; - } - - @Override - public boolean anyNull() { - throw new UnsupportedOperationException(); - } - - @Override - public boolean isNullAt(int ordinal) { return data.getChildColumn(ordinal).isNullAt(rowId); } - - @Override - public boolean getBoolean(int ordinal) { return data.getChildColumn(ordinal).getBoolean(rowId); } - - @Override - public byte getByte(int ordinal) { return data.getChildColumn(ordinal).getByte(rowId); } - - @Override - public short getShort(int ordinal) { return data.getChildColumn(ordinal).getShort(rowId); } - - @Override - public int getInt(int ordinal) { return data.getChildColumn(ordinal).getInt(rowId); } - - @Override - public long getLong(int ordinal) { return data.getChildColumn(ordinal).getLong(rowId); } - - @Override - public float getFloat(int ordinal) { return data.getChildColumn(ordinal).getFloat(rowId); } - - @Override - public double getDouble(int ordinal) { return data.getChildColumn(ordinal).getDouble(rowId); } - - @Override - public Decimal getDecimal(int ordinal, int precision, int scale) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getDecimal(rowId, precision, scale); - } - - @Override - public UTF8String getUTF8String(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getUTF8String(rowId); - } - - @Override - public byte[] getBinary(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getBinary(rowId); - } - - @Override - public CalendarInterval getInterval(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - final int months = data.getChildColumn(ordinal).getChildColumn(0).getInt(rowId); - final long microseconds = data.getChildColumn(ordinal).getChildColumn(1).getLong(rowId); - return new CalendarInterval(months, microseconds); - } - - @Override - public ColumnarRow 
getStruct(int ordinal, int numFields) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getStruct(rowId); - } - - @Override - public ColumnarArray getArray(int ordinal) { - if (data.getChildColumn(ordinal).isNullAt(rowId)) return null; - return data.getChildColumn(ordinal).getArray(rowId); - } - - @Override - public MapData getMap(int ordinal) { - throw new UnsupportedOperationException(); - } - - @Override - public Object get(int ordinal, DataType dataType) { - if (dataType instanceof BooleanType) { - return getBoolean(ordinal); - } else if (dataType instanceof ByteType) { - return getByte(ordinal); - } else if (dataType instanceof ShortType) { - return getShort(ordinal); - } else if (dataType instanceof IntegerType) { - return getInt(ordinal); - } else if (dataType instanceof LongType) { - return getLong(ordinal); - } else if (dataType instanceof FloatType) { - return getFloat(ordinal); - } else if (dataType instanceof DoubleType) { - return getDouble(ordinal); - } else if (dataType instanceof StringType) { - return getUTF8String(ordinal); - } else if (dataType instanceof BinaryType) { - return getBinary(ordinal); - } else if (dataType instanceof DecimalType) { - DecimalType t = (DecimalType) dataType; - return getDecimal(ordinal, t.precision(), t.scale()); - } else if (dataType instanceof DateType) { - return getInt(ordinal); - } else if (dataType instanceof TimestampType) { - return getLong(ordinal); - } else if (dataType instanceof ArrayType) { - return getArray(ordinal); - } else if (dataType instanceof StructType) { - return getStruct(ordinal, ((StructType)dataType).fields().length); - } else if (dataType instanceof MapType) { - return getMap(ordinal); - } else { - throw new UnsupportedOperationException("Datatype not supported " + dataType); - } - } - - @Override - public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } - - @Override - public void setNullAt(int 
ordinal) { throw new UnsupportedOperationException(); } -} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c1..70057a9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125..d2ae32b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import com.google.common.annotations.VisibleForTesting; import org.apache.spark.sql.internal.SQLConf; 
import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public abstract class WritableColumnVector extends ColumnVector { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java new file mode 100644 index 0000000..7083332 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -0,0 +1,562 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.vectorized; + +import org.apache.arrow.vector.*; +import org.apache.arrow.vector.complex.*; +import org.apache.arrow.vector.holders.NullableVarCharHolder; + +import org.apache.spark.sql.execution.arrow.ArrowUtils; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A column vector backed by Apache Arrow. + */ +public final class ArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private ArrowColumnVector[] childColumns; + + private void ensureAccessible(int index) { + ensureAccessible(index, 1); + } + + private void ensureAccessible(int index, int count) { + int valueCount = accessor.getValueCount(); + if (index < 0 || index + count > valueCount) { + throw new IndexOutOfBoundsException( + String.format("index range: [%d, %d), valueCount: %d", index, index + count, valueCount)); + } + } + + @Override + public int numNulls() { + return accessor.getNullCount(); + } + + @Override + public void close() { + if (childColumns != null) { + for (int i = 0; i < childColumns.length; i++) { + childColumns[i].close(); + } + } + accessor.close(); + } + + @Override + public boolean isNullAt(int rowId) { + ensureAccessible(rowId); + return accessor.isNullAt(rowId); + } + + @Override + public boolean getBoolean(int rowId) { + ensureAccessible(rowId); + return accessor.getBoolean(rowId); + } + + @Override + public boolean[] getBooleans(int rowId, int count) { + ensureAccessible(rowId, count); + boolean[] array = new boolean[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getBoolean(rowId + i); + } + return array; + } + + @Override + public byte getByte(int rowId) { + ensureAccessible(rowId); + return accessor.getByte(rowId); + } + + @Override + public byte[] getBytes(int rowId, int count) { + ensureAccessible(rowId, count); + byte[] array = new byte[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getByte(rowId + i); + } + 
return array; + } + + @Override + public short getShort(int rowId) { + ensureAccessible(rowId); + return accessor.getShort(rowId); + } + + @Override + public short[] getShorts(int rowId, int count) { + ensureAccessible(rowId, count); + short[] array = new short[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getShort(rowId + i); + } + return array; + } + + @Override + public int getInt(int rowId) { + ensureAccessible(rowId); + return accessor.getInt(rowId); + } + + @Override + public int[] getInts(int rowId, int count) { + ensureAccessible(rowId, count); + int[] array = new int[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getInt(rowId + i); + } + return array; + } + + @Override + public long getLong(int rowId) { + ensureAccessible(rowId); + return accessor.getLong(rowId); + } + + @Override + public long[] getLongs(int rowId, int count) { + ensureAccessible(rowId, count); + long[] array = new long[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getLong(rowId + i); + } + return array; + } + + @Override + public float getFloat(int rowId) { + ensureAccessible(rowId); + return accessor.getFloat(rowId); + } + + @Override + public float[] getFloats(int rowId, int count) { + ensureAccessible(rowId, count); + float[] array = new float[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getFloat(rowId + i); + } + return array; + } + + @Override + public double getDouble(int rowId) { + ensureAccessible(rowId); + return accessor.getDouble(rowId); + } + + @Override + public double[] getDoubles(int rowId, int count) { + ensureAccessible(rowId, count); + double[] array = new double[count]; + for (int i = 0; i < count; ++i) { + array[i] = accessor.getDouble(rowId + i); + } + return array; + } + + @Override + public int getArrayLength(int rowId) { + ensureAccessible(rowId); + return accessor.getArrayLength(rowId); + } + + @Override + public int getArrayOffset(int rowId) { + ensureAccessible(rowId); + return 
accessor.getArrayOffset(rowId); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + ensureAccessible(rowId); + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + ensureAccessible(rowId); + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + ensureAccessible(rowId); + return accessor.getBinary(rowId); + } + + @Override + public ArrowColumnVector arrayData() { return childColumns[0]; } + + @Override + public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } + + public ArrowColumnVector(ValueVector vector) { + super(ArrowUtils.fromArrowField(vector.getField())); + + if (vector instanceof BitVector) { + accessor = new BooleanAccessor((BitVector) vector); + } else if (vector instanceof TinyIntVector) { + accessor = new ByteAccessor((TinyIntVector) vector); + } else if (vector instanceof SmallIntVector) { + accessor = new ShortAccessor((SmallIntVector) vector); + } else if (vector instanceof IntVector) { + accessor = new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + accessor = new LongAccessor((BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + accessor = new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + accessor = new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof DecimalVector) { + accessor = new DecimalAccessor((DecimalVector) vector); + } else if (vector instanceof VarCharVector) { + accessor = new StringAccessor((VarCharVector) vector); + } else if (vector instanceof VarBinaryVector) { + accessor = new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + accessor = new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + accessor = new TimestampAccessor((TimeStampMicroTZVector) vector); + } 
else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + accessor = new ArrayAccessor(listVector); + + childColumns = new ArrowColumnVector[1]; + childColumns[0] = new ArrowColumnVector(listVector.getDataVector()); + } else if (vector instanceof MapVector) { + MapVector mapVector = (MapVector) vector; + accessor = new StructAccessor(mapVector); + + childColumns = new ArrowColumnVector[mapVector.size()]; + for (int i = 0; i < childColumns.length; ++i) { + childColumns[i] = new ArrowColumnVector(mapVector.getVectorById(i)); + } + } else { + throw new UnsupportedOperationException(); + } + } + + private abstract static class ArrowVectorAccessor { + + private final ValueVector vector; + + ArrowVectorAccessor(ValueVector vector) { + this.vector = vector; + } + + // TODO: should be final after removing ArrayAccessor workaround + boolean isNullAt(int rowId) { + return vector.isNull(rowId); + } + + final int getValueCount() { + return vector.getValueCount(); + } + + final int getNullCount() { + return vector.getNullCount(); + } + + final void close() { + vector.close(); + } + + boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + byte[] getBinary(int rowId) { + throw new UnsupportedOperationException(); + } + + int getArrayLength(int rowId) 
{ + throw new UnsupportedOperationException(); + } + + int getArrayOffset(int rowId) { + throw new UnsupportedOperationException(); + } + } + + private static class BooleanAccessor extends ArrowVectorAccessor { + + private final BitVector accessor; + + BooleanAccessor(BitVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final boolean getBoolean(int rowId) { + return accessor.get(rowId) == 1; + } + } + + private static class ByteAccessor extends ArrowVectorAccessor { + + private final TinyIntVector accessor; + + ByteAccessor(TinyIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final byte getByte(int rowId) { + return accessor.get(rowId); + } + } + + private static class ShortAccessor extends ArrowVectorAccessor { + + private final SmallIntVector accessor; + + ShortAccessor(SmallIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final short getShort(int rowId) { + return accessor.get(rowId); + } + } + + private static class IntAccessor extends ArrowVectorAccessor { + + private final IntVector accessor; + + IntAccessor(IntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class LongAccessor extends ArrowVectorAccessor { + + private final BigIntVector accessor; + + LongAccessor(BigIntVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return accessor.get(rowId); + } + } + + private static class FloatAccessor extends ArrowVectorAccessor { + + private final Float4Vector accessor; + + FloatAccessor(Float4Vector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final float getFloat(int rowId) { + return accessor.get(rowId); + } + } + + private static class DoubleAccessor extends ArrowVectorAccessor { + + private final Float8Vector accessor; + + DoubleAccessor(Float8Vector vector) { 
+ super(vector); + this.accessor = vector; + } + + @Override + final double getDouble(int rowId) { + return accessor.get(rowId); + } + } + + private static class DecimalAccessor extends ArrowVectorAccessor { + + private final DecimalVector accessor; + + DecimalAccessor(DecimalVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) return null; + return Decimal.apply(accessor.getObject(rowId), precision, scale); + } + } + + private static class StringAccessor extends ArrowVectorAccessor { + + private final VarCharVector accessor; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + StringAccessor(VarCharVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final UTF8String getUTF8String(int rowId) { + accessor.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress(null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + } + + private static class BinaryAccessor extends ArrowVectorAccessor { + + private final VarBinaryVector accessor; + + BinaryAccessor(VarBinaryVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final byte[] getBinary(int rowId) { + return accessor.getObject(rowId); + } + } + + private static class DateAccessor extends ArrowVectorAccessor { + + private final DateDayVector accessor; + + DateAccessor(DateDayVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final int getInt(int rowId) { + return accessor.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final TimeStampMicroTZVector accessor; + + TimestampAccessor(TimeStampMicroTZVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final long getLong(int rowId) { + return 
accessor.get(rowId); + } + } + + private static class ArrayAccessor extends ArrowVectorAccessor { + + private final ListVector accessor; + + ArrayAccessor(ListVector vector) { + super(vector); + this.accessor = vector; + } + + @Override + final boolean isNullAt(int rowId) { + // TODO: Workaround if vector has all non-null values, see ARROW-1948 + if (accessor.getValueCount() > 0 && accessor.getValidityBuffer().capacity() == 0) { + return false; + } else { + return super.isNullAt(rowId); + } + } + + @Override + final int getArrayLength(int rowId) { + return accessor.getInnerValueCountAt(rowId); + } + + @Override + final int getArrayOffset(int rowId) { + return accessor.getOffsetBuffer().getInt(rowId * accessor.OFFSET_WIDTH); + } + } + + private static class StructAccessor extends ArrowVectorAccessor { + + StructAccessor(MapVector vector) { + super(vector); + } + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java new file mode 100644 index 0000000..d1196e1 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. + * + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. + * + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. + * + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when designing the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. 
+ */ +public abstract class ColumnVector implements AutoCloseable { + + /** + * Returns the data type of this column vector. + */ + public final DataType dataType() { return type; } + + /** + * Cleans up memory for this column. The column is not usable after this. Unlike + * {@link AutoCloseable#close()}, this declares no checked exception. + */ + public abstract void close(); + + /** + * Returns the number of nulls in this column. + */ + public abstract int numNulls(); + + /** + * Returns whether the value at rowId is NULL. + */ + public abstract boolean isNullAt(int rowId); + + /** + * Returns the boolean value for rowId. + */ + public abstract boolean getBoolean(int rowId); + + /** + * Returns the boolean values in [rowId, rowId + count) as an array. + */ + public abstract boolean[] getBooleans(int rowId, int count); + + /** + * Returns the byte value for rowId. + */ + public abstract byte getByte(int rowId); + + /** + * Returns the byte values in [rowId, rowId + count) as an array. + */ + public abstract byte[] getBytes(int rowId, int count); + + /** + * Returns the short value for rowId. + */ + public abstract short getShort(int rowId); + + /** + * Returns the short values in [rowId, rowId + count) as an array. + */ + public abstract short[] getShorts(int rowId, int count); + + /** + * Returns the int value for rowId. + */ + public abstract int getInt(int rowId); + + /** + * Returns the int values in [rowId, rowId + count) as an array. + */ + public abstract int[] getInts(int rowId, int count); + + /** + * Returns the long value for rowId. + */ + public abstract long getLong(int rowId); + + /** + * Returns the long values in [rowId, rowId + count) as an array. + */ + public abstract long[] getLongs(int rowId, int count); + + /** + * Returns the float value for rowId. + */ + public abstract float getFloat(int rowId); + + /** + * Returns the float values in [rowId, rowId + count) as an array. + */ + public abstract float[] getFloats(int rowId, int count); + + /** + * Returns the double value for rowId. 
+ */ + public abstract double getDouble(int rowId); + + /** + * Returns the double values in [rowId, rowId + count) as an array. + */ + public abstract double[] getDoubles(int rowId, int count); + + /** + * Returns the length of the array for rowId. + */ + public abstract int getArrayLength(int rowId); + + /** + * Returns the offset of the array for rowId. + */ + public abstract int getArrayOffset(int rowId); + + /** + * Returns the struct for rowId, as a new {@link ColumnarRow} over this vector and rowId. + */ + public final ColumnarRow getStruct(int rowId) { + return new ColumnarRow(this, rowId); + } + + /** + * A special version of {@link #getStruct(int)}, which is only used as an adapter for the Spark + * codegen framework; the second parameter is ignored. + */ + public final ColumnarRow getStruct(int rowId, int size) { + return getStruct(rowId); + } + + /** + * Returns the array for rowId, as a new {@link ColumnarArray} over {@link #arrayData()} that + * starts at {@link #getArrayOffset(int)} and spans {@link #getArrayLength(int)} elements. + */ + public final ColumnarArray getArray(int rowId) { + return new ColumnarArray(arrayData(), getArrayOffset(rowId), getArrayLength(rowId)); + } + + /** + * Returns the map for the row identified by {@code ordinal} (the role rowId plays in the other + * getters). This default implementation always throws {@link UnsupportedOperationException}. + */ + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + /** + * Returns the decimal for rowId, using the given precision and scale. + */ + public abstract Decimal getDecimal(int rowId, int precision, int scale); + + /** + * Returns the UTF8String for rowId. Note that the returned UTF8String may point to the data of + * this column vector; please copy it if you want to keep it after this column vector is freed. + */ + public abstract UTF8String getUTF8String(int rowId); + + /** + * Returns the byte array for rowId. + */ + public abstract byte[] getBinary(int rowId); + + /** + * Returns the column vector holding the array elements; {@link #getArray(int)} reads from it. + */ + public abstract ColumnVector arrayData(); + + /** + * Returns the child data column at the given ordinal. + */ + public abstract ColumnVector getChildColumn(int ordinal); + + /** + * Data type for this column. + */ + protected DataType type; + + /** + * Sets up the data type of this column vector. 
+ */ + protected ColumnVector(DataType type) { + this.type = type; + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java ---------------------------------------------------------------------- diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java new file mode 100644 index 0000000..0d89a52 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import org.apache.spark.sql.catalyst.util.ArrayData; +import org.apache.spark.sql.catalyst.util.MapData; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Array abstraction in {@link ColumnVector}: a view of {@code length} elements starting at + * {@code offset}. Copying and mutation are unsupported. + */ +public final class ColumnarArray extends ArrayData { + // Backing vector. This array covers the elements from + // data[offset] to data[offset + length). 
+ private final ColumnVector data; + private final int offset; + private final int length; + + public ColumnarArray(ColumnVector data, int offset, int length) { + this.data = data; + this.offset = offset; + this.length = length; + } + + // Translates a position in this array into the row id of the backing vector. + private int rowIdOf(int ordinal) { + return offset + ordinal; + } + + @Override + public int numElements() { + return length; + } + + @Override + public ArrayData copy() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean[] toBooleanArray() { + return data.getBooleans(offset, length); + } + + @Override + public byte[] toByteArray() { + return data.getBytes(offset, length); + } + + @Override + public short[] toShortArray() { + return data.getShorts(offset, length); + } + + @Override + public int[] toIntArray() { + return data.getInts(offset, length); + } + + @Override + public long[] toLongArray() { + return data.getLongs(offset, length); + } + + @Override + public float[] toFloatArray() { + return data.getFloats(offset, length); + } + + @Override + public double[] toDoubleArray() { + return data.getDoubles(offset, length); + } + + // TODO: this is extremely expensive. 
+ @Override + public Object[] array() { + final DataType elementType = data.dataType(); + final Object[] values = new Object[length]; + try { + for (int i = 0; i < length; i++) { + if (data.isNullAt(offset + i)) { + // Null slots keep the array's default null. + continue; + } + values[i] = get(i, elementType); + } + return values; + } catch (Exception e) { + throw new RuntimeException("Could not get the array", e); + } + } + + @Override + public boolean isNullAt(int ordinal) { + return data.isNullAt(rowIdOf(ordinal)); + } + + @Override + public boolean getBoolean(int ordinal) { + return data.getBoolean(rowIdOf(ordinal)); + } + + @Override + public byte getByte(int ordinal) { + return data.getByte(rowIdOf(ordinal)); + } + + @Override + public short getShort(int ordinal) { + return data.getShort(rowIdOf(ordinal)); + } + + @Override + public int getInt(int ordinal) { + return data.getInt(rowIdOf(ordinal)); + } + + @Override + public long getLong(int ordinal) { + return data.getLong(rowIdOf(ordinal)); + } + + @Override + public float getFloat(int ordinal) { + return data.getFloat(rowIdOf(ordinal)); + } + + @Override + public double getDouble(int ordinal) { + return data.getDouble(rowIdOf(ordinal)); + } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return data.getDecimal(rowIdOf(ordinal), precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return data.getUTF8String(rowIdOf(ordinal)); + } + + @Override + public byte[] getBinary(int ordinal) { + return data.getBinary(rowIdOf(ordinal)); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + // Child column 0 supplies the int component and child column 1 the long component. + final int rowId = rowIdOf(ordinal); + final int months = data.getChildColumn(0).getInt(rowId); + final long micros = data.getChildColumn(1).getLong(rowId); + return new CalendarInterval(months, micros); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + // numFields is not needed to build the row view. + return data.getStruct(rowIdOf(ordinal)); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return data.getArray(rowIdOf(ordinal)); + } + 
+ @Override + public MapData getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + // Dispatches on the runtime class of dataType to the matching typed getter. + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } + if (dataType instanceof ByteType) { + return getByte(ordinal); + } + if (dataType instanceof ShortType) { + return getShort(ordinal); + } + if (dataType instanceof IntegerType) { + return getInt(ordinal); + } + if (dataType instanceof LongType) { + return getLong(ordinal); + } + if (dataType instanceof FloatType) { + return getFloat(ordinal); + } + if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } + if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } + if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } + if (dataType instanceof DecimalType) { + final DecimalType decimalType = (DecimalType) dataType; + return getDecimal(ordinal, decimalType.precision(), decimalType.scale()); + } + if (dataType instanceof DateType) { + return getInt(ordinal); + } + if (dataType instanceof TimestampType) { + return getLong(ordinal); + } + if (dataType instanceof ArrayType) { + return getArray(ordinal); + } + if (dataType instanceof StructType) { + final StructType structType = (StructType) dataType; + return getStruct(ordinal, structType.fields().length); + } + if (dataType instanceof MapType) { + return getMap(ordinal); + } + if (dataType instanceof CalendarIntervalType) { + return getInterval(ordinal); + } + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + + @Override + public void update(int ordinal, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public void setNullAt(int ordinal) { + throw new UnsupportedOperationException(); + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/a51212b6/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java new file mode 100644 index 0000000..9ae1c6d --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -0,0 +1,129 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.vectorized; + +import java.util.*; + +import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; +import org.apache.spark.sql.types.StructType; + +/** + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. An instance of it is meant to be reused + * during the entire data loading process. + * + * The class is {@link AutoCloseable}, so a batch can be managed with try-with-resources; + * {@link #close()} closes every column of the batch. + */ +public final class ColumnarBatch implements AutoCloseable { + public static final int DEFAULT_BATCH_SIZE = 4 * 1024; + + private final StructType schema; + private final int capacity; + private int numRows; + private final ColumnVector[] columns; + + // Staging row returned from `getRow`. + private final MutableColumnarRow row; + + /** + * Called to close all the columns in this batch. 
It is not valid to access the data after + * calling this. This must be called at the end to clean up memory allocations. + */ + @Override + public void close() { + for (ColumnVector c: columns) { + c.close(); + } + } + + /** + * Returns an iterator over the rows in this batch. The row count is captured when the iterator + * is created, and every call to {@code next()} returns the same row object with only its row id + * advanced, so callers must not hold on to a returned row. + */ + public Iterator<InternalRow> rowIterator() { + final int maxRows = numRows; + final MutableColumnarRow row = new MutableColumnarRow(columns); + return new Iterator<InternalRow>() { + int rowId = 0; + + @Override + public boolean hasNext() { + return rowId < maxRows; + } + + @Override + public InternalRow next() { + if (rowId >= maxRows) { + throw new NoSuchElementException(); + } + row.rowId = rowId++; + return row; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + /** + * Sets the number of rows in this batch. The value must not exceed {@link #capacity()}; this is + * only checked by an assertion. + */ + public void setNumRows(int numRows) { + assert(numRows <= this.capacity); + this.numRows = numRows; + } + + /** + * Returns the number of columns that make up this batch. + */ + public int numCols() { return columns.length; } + + /** + * Returns the number of rows for read, including filtered rows. + */ + public int numRows() { return numRows; } + + /** + * Returns the schema that makes up this batch. + */ + public StructType schema() { return schema; } + + /** + * Returns the max capacity (in number of rows) for this batch. + */ + public int capacity() { return capacity; } + + /** + * Returns the column at `ordinal`. + */ + public ColumnVector column(int ordinal) { return columns[ordinal]; } + + /** + * Returns the row in this batch at `rowId`. Returned row is reused across calls. 
+ */ + public InternalRow getRow(int rowId) { + assert(rowId >= 0 && rowId < numRows); + row.rowId = rowId; + return row; + } + + /** + * Creates a batch over the given columns. + * + * @param schema the schema returned by {@link #schema()} + * @param columns the column vectors of this batch; the array is stored as-is, not copied + * @param capacity the maximum number of rows, returned by {@link #capacity()} + */ + public ColumnarBatch(StructType schema, ColumnVector[] columns, int capacity) { + this.schema = schema; + this.columns = columns; + this.capacity = capacity; + this.row = new MutableColumnarRow(columns); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org