This is an automated email from the ASF dual-hosted git repository.
blue pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iceberg.git
The following commit(s) were added to refs/heads/master by this push:
new 6fab8f5 Spark: Support ORC vectorized reads (#1189)
6fab8f5 is described below
commit 6fab8f57bdb7e5fe7eadc3ff41558581338e1b69
Author: Shardul Mahadik <[email protected]>
AuthorDate: Mon Jul 13 14:27:36 2020 -0700
Spark: Support ORC vectorized reads (#1189)
---
.gitignore | 1 +
.../org/apache/iceberg/io/CloseableIterator.java | 23 ++
orc/src/main/java/org/apache/iceberg/orc/ORC.java | 19 +-
.../org/apache/iceberg/orc/OrcBatchReader.java | 19 +-
.../java/org/apache/iceberg/orc/OrcIterable.java | 37 +-
.../iceberg/orc/VectorizedRowBatchIterator.java | 9 +-
.../apache/iceberg/spark/data/SparkOrcReader.java | 7 +-
.../iceberg/spark/data/SparkOrcValueReaders.java | 19 +-
...ColumnVector.java => ConstantColumnVector.java} | 47 ++-
.../data/vectorized/IcebergArrowColumnVector.java | 3 +-
.../data/vectorized/VectorizedSparkOrcReaders.java | 416 +++++++++++++++++++++
.../iceberg/spark/source/BaseDataReader.java | 36 ++
.../iceberg/spark/source/BatchDataReader.java | 39 ++
.../apache/iceberg/spark/source/RowDataReader.java | 36 --
.../org/apache/iceberg/spark/data/TestHelpers.java | 4 +-
.../iceberg/spark/data/TestSparkOrcReader.java | 22 ++
.../spark/source/TestIdentityPartitionData.java | 20 +-
.../iceberg/spark/source/TestPartitionValues.java | 26 +-
.../spark/source/TestSparkReadProjection.java | 7 +-
.../orc/IcebergSourceFlatORCDataReadBenchmark.java | 31 +-
.../IcebergSourceNestedORCDataReadBenchmark.java | 30 +-
.../org/apache/iceberg/spark/source/Reader.java | 11 +-
.../iceberg/spark/source/TestFilteredScan.java | 44 ++-
.../spark/source/TestIdentityPartitionData24.java | 4 +-
.../spark/source/TestPartitionValues24.java | 4 +-
.../iceberg/spark/source/SparkBatchScan.java | 11 +-
.../iceberg/spark/source/TestFilteredScan.java | 111 +++---
.../spark/source/TestIdentityPartitionData3.java | 4 +-
.../iceberg/spark/source/TestPartitionValues3.java | 4 +-
29 files changed, 841 insertions(+), 203 deletions(-)
diff --git a/.gitignore b/.gitignore
index 7b14372..af76d7c 100644
--- a/.gitignore
+++ b/.gitignore
@@ -34,6 +34,7 @@ sdist/
coverage.xml
.pytest_cache/
spark/tmp/
+spark-warehouse/
spark/spark-warehouse/
spark2/spark-warehouse/
spark3/spark-warehouse/
diff --git a/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java
b/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java
index e2b2e4f..079190b 100644
--- a/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java
+++ b/api/src/main/java/org/apache/iceberg/io/CloseableIterator.java
@@ -23,6 +23,8 @@ import java.io.Closeable;
import java.io.IOException;
import java.util.Collections;
import java.util.Iterator;
+import java.util.function.Function;
+import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
public interface CloseableIterator<T> extends Iterator<T>, Closeable {
@@ -54,4 +56,25 @@ public interface CloseableIterator<T> extends Iterator<T>,
Closeable {
}
};
}
+
+ static <I, O> CloseableIterator<O> transform(CloseableIterator<I> iterator,
Function<I, O> transform) {
+ Preconditions.checkNotNull(transform, "Cannot apply a null transform");
+
+ return new CloseableIterator<O>() {
+ @Override
+ public void close() throws IOException {
+ iterator.close();
+ }
+
+ @Override
+ public boolean hasNext() {
+ return iterator.hasNext();
+ }
+
+ @Override
+ public O next() {
+ return transform.apply(iterator.next());
+ }
+ };
+ }
}
diff --git a/orc/src/main/java/org/apache/iceberg/orc/ORC.java
b/orc/src/main/java/org/apache/iceberg/orc/ORC.java
index 9e03549..38dc522 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/ORC.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/ORC.java
@@ -127,6 +127,8 @@ public class ORC {
private boolean caseSensitive = true;
private Function<TypeDescription, OrcRowReader<?>> readerFunc;
+ private Function<TypeDescription, OrcBatchReader<?>> batchedReaderFunc;
+ private int recordsPerBatch = VectorizedRowBatch.DEFAULT_SIZE;
private ReadBuilder(InputFile file) {
Preconditions.checkNotNull(file, "Input file cannot be null");
@@ -168,6 +170,8 @@ public class ORC {
}
public ReadBuilder createReaderFunc(Function<TypeDescription,
OrcRowReader<?>> readerFunction) {
+ Preconditions.checkArgument(this.batchedReaderFunc == null,
+ "Reader function cannot be set since the batched version is already
set");
this.readerFunc = readerFunction;
return this;
}
@@ -177,9 +181,22 @@ public class ORC {
return this;
}
+ public ReadBuilder createBatchedReaderFunc(Function<TypeDescription,
OrcBatchReader<?>> batchReaderFunction) {
+ Preconditions.checkArgument(this.readerFunc == null,
+ "Batched reader function cannot be set since the non-batched version
is already set");
+ this.batchedReaderFunc = batchReaderFunction;
+ return this;
+ }
+
+ public ReadBuilder recordsPerBatch(int numRecordsPerBatch) {
+ this.recordsPerBatch = numRecordsPerBatch;
+ return this;
+ }
+
public <D> CloseableIterable<D> build() {
Preconditions.checkNotNull(schema, "Schema is required");
- return new OrcIterable<>(file, conf, schema, start, length, readerFunc,
caseSensitive, filter);
+ return new OrcIterable<>(file, conf, schema, start, length, readerFunc,
caseSensitive, filter, batchedReaderFunc,
+ recordsPerBatch);
}
}
diff --git
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
b/orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java
similarity index 74%
copy from
spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
copy to orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java
index f9da71e..86dcc65 100644
---
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/OrcBatchReader.java
@@ -17,10 +17,19 @@
* under the License.
*/
-package org.apache.iceberg.spark.source;
+package org.apache.iceberg.orc;
+
+import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
+
+/**
+ * Used for implementing ORC batch readers.
+ */
+@FunctionalInterface
+public interface OrcBatchReader<T> {
+
+ /**
+ * Reads a row batch.
+ */
+ T read(VectorizedRowBatch batch);
-public class TestPartitionValues24 extends TestPartitionValues {
- public TestPartitionValues24(String format) {
- super(format);
- }
}
diff --git a/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java
b/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java
index 88a421a..4dc29b1 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/OrcIterable.java
@@ -20,7 +20,6 @@
package org.apache.iceberg.orc;
import java.io.IOException;
-import java.util.Iterator;
import java.util.function.Function;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.Schema;
@@ -49,10 +48,13 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
private final Function<TypeDescription, OrcRowReader<?>> readerFunction;
private final Expression filter;
private final boolean caseSensitive;
+ private final Function<TypeDescription, OrcBatchReader<?>>
batchReaderFunction;
+ private final int recordsPerBatch;
OrcIterable(InputFile file, Configuration config, Schema schema,
Long start, Long length,
- Function<TypeDescription, OrcRowReader<?>> readerFunction,
boolean caseSensitive, Expression filter) {
+ Function<TypeDescription, OrcRowReader<?>> readerFunction,
boolean caseSensitive, Expression filter,
+ Function<TypeDescription, OrcBatchReader<?>>
batchReaderFunction, int recordsPerBatch) {
this.schema = schema;
this.readerFunction = readerFunction;
this.file = file;
@@ -61,6 +63,8 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
this.config = config;
this.caseSensitive = caseSensitive;
this.filter = (filter == Expressions.alwaysTrue()) ? null : filter;
+ this.batchReaderFunction = batchReaderFunction;
+ this.recordsPerBatch = recordsPerBatch;
}
@SuppressWarnings("unchecked")
@@ -75,16 +79,22 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
Expression boundFilter = Binder.bind(schema.asStruct(), filter,
caseSensitive);
sarg = ExpressionToSearchArgument.convert(boundFilter, readOrcSchema);
}
- Iterator<T> iterator = new OrcIterator(
- newOrcIterator(file, readOrcSchema, start, length, orcFileReader,
sarg),
- readerFunction.apply(readOrcSchema));
- return CloseableIterator.withClose(iterator);
+
+ VectorizedRowBatchIterator rowBatchIterator = newOrcIterator(file,
readOrcSchema, start, length, orcFileReader,
+ sarg, recordsPerBatch);
+ if (batchReaderFunction != null) {
+ OrcBatchReader<T> batchReader = (OrcBatchReader<T>)
batchReaderFunction.apply(readOrcSchema);
+ return CloseableIterator.transform(rowBatchIterator, batchReader::read);
+ } else {
+ return new OrcRowIterator<>(rowBatchIterator, (OrcRowReader<T>)
readerFunction.apply(readOrcSchema));
+ }
}
private static VectorizedRowBatchIterator newOrcIterator(InputFile file,
TypeDescription
readerSchema,
Long start, Long
length,
- Reader
orcFileReader, SearchArgument sarg) {
+ Reader
orcFileReader, SearchArgument sarg,
+ int
recordsPerBatch) {
final Reader.Options options = orcFileReader.options();
if (start != null) {
options.range(start, length);
@@ -93,13 +103,14 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
options.searchArgument(sarg, new String[]{});
try {
- return new VectorizedRowBatchIterator(file.location(), readerSchema,
orcFileReader.rows(options));
+ return new VectorizedRowBatchIterator(file.location(), readerSchema,
orcFileReader.rows(options),
+ recordsPerBatch);
} catch (IOException ioe) {
throw new RuntimeIOException(ioe, "Failed to get ORC rows for file: %s",
file);
}
}
- private static class OrcIterator<T> implements Iterator<T> {
+ private static class OrcRowIterator<T> implements CloseableIterator<T> {
private int nextRow;
private VectorizedRowBatch current;
@@ -107,7 +118,7 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
private final VectorizedRowBatchIterator batchIter;
private final OrcRowReader<T> reader;
- OrcIterator(VectorizedRowBatchIterator batchIter, OrcRowReader<T> reader) {
+ OrcRowIterator(VectorizedRowBatchIterator batchIter, OrcRowReader<T>
reader) {
this.batchIter = batchIter;
this.reader = reader;
current = null;
@@ -128,6 +139,10 @@ class OrcIterable<T> extends CloseableGroup implements
CloseableIterable<T> {
return this.reader.read(current, nextRow++);
}
- }
+ @Override
+ public void close() throws IOException {
+ batchIter.close();
+ }
+ }
}
diff --git
a/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java
b/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java
index 125a37c..7f3abbf 100644
--- a/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java
+++ b/orc/src/main/java/org/apache/iceberg/orc/VectorizedRowBatchIterator.java
@@ -19,10 +19,9 @@
package org.apache.iceberg.orc;
-import java.io.Closeable;
import java.io.IOException;
-import java.util.Iterator;
import org.apache.iceberg.exceptions.RuntimeIOException;
+import org.apache.iceberg.io.CloseableIterator;
import org.apache.orc.RecordReader;
import org.apache.orc.TypeDescription;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
@@ -32,16 +31,16 @@ import
org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
* Because the same VectorizedRowBatch is reused on each call to next,
* it gets changed when hasNext or next is called.
*/
-public class VectorizedRowBatchIterator implements
Iterator<VectorizedRowBatch>, Closeable {
+public class VectorizedRowBatchIterator implements
CloseableIterator<VectorizedRowBatch> {
private final String fileLocation;
private final RecordReader rows;
private final VectorizedRowBatch batch;
private boolean advanced = false;
- VectorizedRowBatchIterator(String fileLocation, TypeDescription schema,
RecordReader rows) {
+ VectorizedRowBatchIterator(String fileLocation, TypeDescription schema,
RecordReader rows, int recordsPerBatch) {
this.fileLocation = fileLocation;
this.rows = rows;
- this.batch = schema.createRowBatch();
+ this.batch = schema.createRowBatch(recordsPerBatch);
}
@Override
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
index da1ee6e..6e6eb56 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcReader.java
@@ -32,7 +32,6 @@ import org.apache.orc.TypeDescription;
import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
import org.apache.orc.storage.ql.exec.vector.VectorizedRowBatch;
import org.apache.spark.sql.catalyst.InternalRow;
-import org.apache.spark.sql.types.Decimal;
/**
* Converts the OrcIterator, which returns ORC's VectorizedRowBatch to a
@@ -103,11 +102,7 @@ public class SparkOrcReader implements
OrcRowReader<InternalRow> {
case TIMESTAMP_INSTANT:
return SparkOrcValueReaders.timestampTzs();
case DECIMAL:
- if (primitive.getPrecision() <= Decimal.MAX_LONG_DIGITS()) {
- return new
SparkOrcValueReaders.Decimal18Reader(primitive.getPrecision(),
primitive.getScale());
- } else {
- return new
SparkOrcValueReaders.Decimal38Reader(primitive.getPrecision(),
primitive.getScale());
- }
+ return SparkOrcValueReaders.decimals(primitive.getPrecision(),
primitive.getScale());
case CHAR:
case VARCHAR:
case STRING:
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
index 5add499..ab9ee43 100644
---
a/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
+++
b/spark/src/main/java/org/apache/iceberg/spark/data/SparkOrcValueReaders.java
@@ -42,19 +42,26 @@ import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;
-
-class SparkOrcValueReaders {
+public class SparkOrcValueReaders {
private SparkOrcValueReaders() {
}
- static OrcValueReader<UTF8String> utf8String() {
+ public static OrcValueReader<UTF8String> utf8String() {
return StringReader.INSTANCE;
}
- static OrcValueReader<?> timestampTzs() {
+ public static OrcValueReader<Long> timestampTzs() {
return TimestampTzReader.INSTANCE;
}
+ public static OrcValueReader<Decimal> decimals(int precision, int scale) {
+ if (precision <= Decimal.MAX_LONG_DIGITS()) {
+ return new SparkOrcValueReaders.Decimal18Reader(precision, scale);
+ } else {
+ return new SparkOrcValueReaders.Decimal38Reader(precision, scale);
+ }
+ }
+
static OrcValueReader<?> struct(
List<OrcValueReader<?>> readers, Types.StructType struct, Map<Integer,
?> idToConstant) {
return new StructReader(readers, struct, idToConstant);
@@ -164,7 +171,7 @@ class SparkOrcValueReaders {
}
}
- static class Decimal18Reader implements OrcValueReader<Decimal> {
+ private static class Decimal18Reader implements OrcValueReader<Decimal> {
//TODO: these are being unused. check for bug
private final int precision;
private final int scale;
@@ -181,7 +188,7 @@ class SparkOrcValueReaders {
}
}
- static class Decimal38Reader implements OrcValueReader<Decimal> {
+ private static class Decimal38Reader implements OrcValueReader<Decimal> {
private final int precision;
private final int scale;
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java
similarity index 65%
rename from
spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java
rename to
spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java
index 8770d13..c3acbc4 100644
---
a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java
+++
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ConstantColumnVector.java
@@ -21,105 +21,104 @@ package org.apache.iceberg.spark.data.vectorized;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Type;
-import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;
-public class NullValuesColumnVector extends ColumnVector {
+class ConstantColumnVector extends ColumnVector {
- private final int numNulls;
- private static final Type NULL_TYPE = Types.IntegerType.get();
+ private final Object constant;
+ private final int batchSize;
- public NullValuesColumnVector(int nValues) {
- super(SparkSchemaUtil.convert(NULL_TYPE));
- this.numNulls = nValues;
+ ConstantColumnVector(Type type, int batchSize, Object constant) {
+ super(SparkSchemaUtil.convert(type));
+ this.constant = constant;
+ this.batchSize = batchSize;
}
@Override
public void close() {
-
}
@Override
public boolean hasNull() {
- return true;
+ return constant == null;
}
@Override
public int numNulls() {
- return numNulls;
+ return constant == null ? batchSize : 0;
}
@Override
public boolean isNullAt(int rowId) {
- return true;
+ return constant == null;
}
@Override
public boolean getBoolean(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (boolean) constant : false;
}
@Override
public byte getByte(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (byte) constant : 0;
}
@Override
public short getShort(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (short) constant : 0;
}
@Override
public int getInt(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (int) constant : 0;
}
@Override
public long getLong(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (long) constant : 0L;
}
@Override
public float getFloat(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (float) constant : 0.0F;
}
@Override
public double getDouble(int rowId) {
- throw new UnsupportedOperationException();
+ return constant != null ? (double) constant : 0.0;
}
@Override
public ColumnarArray getArray(int rowId) {
- throw new UnsupportedOperationException();
+ throw new UnsupportedOperationException("ConstantColumnVector only
supports primitives");
}
@Override
public ColumnarMap getMap(int ordinal) {
- throw new UnsupportedOperationException();
+ throw new UnsupportedOperationException("ConstantColumnVector only
supports primitives");
}
@Override
public Decimal getDecimal(int rowId, int precision, int scale) {
- throw new UnsupportedOperationException();
+ return (Decimal) constant;
}
@Override
public UTF8String getUTF8String(int rowId) {
- throw new UnsupportedOperationException();
+ return (UTF8String) constant;
}
@Override
public byte[] getBinary(int rowId) {
- throw new UnsupportedOperationException();
+ return (byte[]) constant;
}
@Override
protected ColumnVector getChild(int ordinal) {
- throw new UnsupportedOperationException();
+ throw new UnsupportedOperationException("ConstantColumnVector only
supports primitives");
}
}
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java
index 9d10cd9..60cd17e 100644
---
a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java
+++
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java
@@ -22,6 +22,7 @@ package org.apache.iceberg.spark.data.vectorized;
import org.apache.iceberg.arrow.vectorized.NullabilityHolder;
import org.apache.iceberg.arrow.vectorized.VectorHolder;
import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.types.Types;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.vectorized.ArrowColumnVector;
import org.apache.spark.sql.vectorized.ColumnVector;
@@ -143,7 +144,7 @@ public class IcebergArrowColumnVector extends ColumnVector {
}
static ColumnVector forHolder(VectorHolder holder, int numRows) {
- return holder.isDummy() ? new NullValuesColumnVector(numRows) :
+ return holder.isDummy() ? new
ConstantColumnVector(Types.IntegerType.get(), numRows, null) :
new IcebergArrowColumnVector(holder);
}
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
new file mode 100644
index 0000000..564fcfa
--- /dev/null
+++
b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkOrcReaders.java
@@ -0,0 +1,416 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data.vectorized;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.orc.OrcBatchReader;
+import org.apache.iceberg.orc.OrcSchemaWithTypeVisitor;
+import org.apache.iceberg.orc.OrcValueReader;
+import org.apache.iceberg.orc.OrcValueReaders;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.SparkSchemaUtil;
+import org.apache.iceberg.spark.data.SparkOrcValueReaders;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.types.Types;
+import org.apache.orc.TypeDescription;
+import org.apache.orc.storage.ql.exec.vector.ListColumnVector;
+import org.apache.orc.storage.ql.exec.vector.MapColumnVector;
+import org.apache.orc.storage.ql.exec.vector.StructColumnVector;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.vectorized.ColumnVector;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+public class VectorizedSparkOrcReaders {
+
+ private VectorizedSparkOrcReaders() {
+ }
+
+ public static OrcBatchReader<ColumnarBatch> buildReader(Schema
expectedSchema, TypeDescription fileSchema,
+ Map<Integer, ?>
idToConstant) {
+ Converter converter = OrcSchemaWithTypeVisitor.visit(expectedSchema,
fileSchema, new ReadBuilder(idToConstant));
+
+ return batch -> {
+ BaseOrcColumnVector cv = (BaseOrcColumnVector) converter.convert(new
StructColumnVector(batch.size, batch.cols),
+ batch.size);
+ ColumnarBatch columnarBatch = new ColumnarBatch(IntStream.range(0,
expectedSchema.columns().size())
+ .mapToObj(cv::getChild)
+ .toArray(ColumnVector[]::new));
+ columnarBatch.setNumRows(batch.size);
+ return columnarBatch;
+ };
+ }
+
+ private interface Converter {
+ ColumnVector convert(org.apache.orc.storage.ql.exec.vector.ColumnVector
columnVector, int batchSize);
+ }
+
+ private static class ReadBuilder extends OrcSchemaWithTypeVisitor<Converter>
{
+ private final Map<Integer, ?> idToConstant;
+
+ private ReadBuilder(Map<Integer, ?> idToConstant) {
+ this.idToConstant = idToConstant;
+ }
+
+ @Override
+ public Converter record(Types.StructType iStruct, TypeDescription record,
List<String> names,
+ List<Converter> fields) {
+ return new StructConverter(iStruct, fields, idToConstant);
+ }
+
+ @Override
+ public Converter list(Types.ListType iList, TypeDescription array,
Converter element) {
+ return new ArrayConverter(iList, element);
+ }
+
+ @Override
+ public Converter map(Types.MapType iMap, TypeDescription map, Converter
key, Converter value) {
+ return new MapConverter(iMap, key, value);
+ }
+
+ @Override
+ public Converter primitive(Type.PrimitiveType iPrimitive, TypeDescription
primitive) {
+ final OrcValueReader<?> primitiveValueReader;
+ switch (primitive.getCategory()) {
+ case BOOLEAN:
+ primitiveValueReader = OrcValueReaders.booleans();
+ break;
+ case BYTE:
+ // Iceberg does not have a byte type. Use int
+ case SHORT:
+ // Iceberg does not have a short type. Use int
+ case DATE:
+ case INT:
+ primitiveValueReader = OrcValueReaders.ints();
+ break;
+ case LONG:
+ primitiveValueReader = OrcValueReaders.longs();
+ break;
+ case FLOAT:
+ primitiveValueReader = OrcValueReaders.floats();
+ break;
+ case DOUBLE:
+ primitiveValueReader = OrcValueReaders.doubles();
+ break;
+ case TIMESTAMP_INSTANT:
+ primitiveValueReader = SparkOrcValueReaders.timestampTzs();
+ break;
+ case DECIMAL:
+ primitiveValueReader =
SparkOrcValueReaders.decimals(primitive.getPrecision(), primitive.getScale());
+ break;
+ case CHAR:
+ case VARCHAR:
+ case STRING:
+ primitiveValueReader = SparkOrcValueReaders.utf8String();
+ break;
+ case BINARY:
+ primitiveValueReader = OrcValueReaders.bytes();
+ break;
+ default:
+ throw new IllegalArgumentException("Unhandled type " + primitive);
+ }
+ return (columnVector, batchSize) ->
+ new PrimitiveOrcColumnVector(iPrimitive, batchSize, columnVector,
primitiveValueReader);
+ }
+ }
+
+ private abstract static class BaseOrcColumnVector extends ColumnVector {
+ private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector;
+ private final int batchSize;
+ private Integer numNulls;
+
+ BaseOrcColumnVector(Type type, int batchSize,
org.apache.orc.storage.ql.exec.vector.ColumnVector vector) {
+ super(SparkSchemaUtil.convert(type));
+ this.vector = vector;
+ this.batchSize = batchSize;
+ }
+
+ @Override
+ public void close() {
+ }
+
+ @Override
+ public boolean hasNull() {
+ return !vector.noNulls;
+ }
+
+ @Override
+ public int numNulls() {
+ if (numNulls == null) {
+ numNulls = numNullsHelper();
+ }
+ return numNulls;
+ }
+
+ private int numNullsHelper() {
+ if (vector.isRepeating) {
+ if (vector.isNull[0]) {
+ return batchSize;
+ } else {
+ return 0;
+ }
+ } else if (vector.noNulls) {
+ return 0;
+ } else {
+ int count = 0;
+ for (int i = 0; i < batchSize; i++) {
+ if (vector.isNull[i]) {
+ count++;
+ }
+ }
+ return count;
+ }
+ }
+
+ protected int getRowIndex(int rowId) {
+ return vector.isRepeating ? 0 : rowId;
+ }
+
+ @Override
+ public boolean isNullAt(int rowId) {
+ return vector.isNull[getRowIndex(rowId)];
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte getByte(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public short getShort(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Decimal getDecimal(int rowId, int precision, int scale) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public UTF8String getUTF8String(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public byte[] getBinary(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ColumnarArray getArray(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ColumnarMap getMap(int rowId) {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ColumnVector getChild(int ordinal) {
+ throw new UnsupportedOperationException();
+ }
+ }
+
+ private static class PrimitiveOrcColumnVector extends BaseOrcColumnVector {
+ private final org.apache.orc.storage.ql.exec.vector.ColumnVector vector;
+ private final OrcValueReader<?> primitiveValueReader;
+
+ PrimitiveOrcColumnVector(Type type, int batchSize,
org.apache.orc.storage.ql.exec.vector.ColumnVector vector,
+ OrcValueReader<?> primitiveValueReader) {
+ super(type, batchSize, vector);
+ this.vector = vector;
+ this.primitiveValueReader = primitiveValueReader;
+ }
+
+ @Override
+ public boolean getBoolean(int rowId) {
+ Boolean value = (Boolean) primitiveValueReader.read(vector, rowId);
+ return value != null ? value : false;
+ }
+
+ @Override
+ public int getInt(int rowId) {
+ Integer value = (Integer) primitiveValueReader.read(vector, rowId);
+ return value != null ? value : 0;
+ }
+
+ @Override
+ public long getLong(int rowId) {
+ Long value = (Long) primitiveValueReader.read(vector, rowId);
+ return value != null ? value : 0L;
+ }
+
+ @Override
+ public float getFloat(int rowId) {
+ Float value = (Float) primitiveValueReader.read(vector, rowId);
+ return value != null ? value : 0.0F;
+ }
+
+ @Override
+ public double getDouble(int rowId) {
+ Double value = (Double) primitiveValueReader.read(vector, rowId);
+ return value != null ? value : 0.0;
+ }
+
+ @Override
+ public Decimal getDecimal(int rowId, int precision, int scale) {
+ // TODO: Is it okay to assume that (precision,scale) parameters ==
(precision,scale) of the decimal type
+ // and return a Decimal with (precision,scale) of the decimal type?
+ return (Decimal) primitiveValueReader.read(vector, rowId);
+ }
+
+ @Override
+ public UTF8String getUTF8String(int rowId) {
+ return (UTF8String) primitiveValueReader.read(vector, rowId);
+ }
+
+ @Override
+ public byte[] getBinary(int rowId) {
+ return (byte[]) primitiveValueReader.read(vector, rowId);
+ }
+ }
+
+ private static class ArrayConverter implements Converter {
+ private final Types.ListType listType;
+ private final Converter elementConverter;
+
+ private ArrayConverter(Types.ListType listType, Converter
elementConverter) {
+ this.listType = listType;
+ this.elementConverter = elementConverter;
+ }
+
+ @Override
+ public ColumnVector
convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int
batchSize) {
+ ListColumnVector listVector = (ListColumnVector) vector;
+ ColumnVector elementVector = elementConverter.convert(listVector.child,
batchSize);
+
+ return new BaseOrcColumnVector(listType, batchSize, vector) {
+ @Override
+ public ColumnarArray getArray(int rowId) {
+ if (isNullAt(rowId)) {
+ return null;
+ } else {
+ int index = getRowIndex(rowId);
+ return new ColumnarArray(elementVector, (int)
listVector.offsets[index], (int) listVector.lengths[index]);
+ }
+ }
+ };
+ }
+ }
+
+ private static class MapConverter implements Converter {
+ private final Types.MapType mapType;
+ private final Converter keyConverter;
+ private final Converter valueConverter;
+
+ private MapConverter(Types.MapType mapType, Converter keyConverter,
Converter valueConverter) {
+ this.mapType = mapType;
+ this.keyConverter = keyConverter;
+ this.valueConverter = valueConverter;
+ }
+
+ @Override
+ public ColumnVector
convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int
batchSize) {
+ MapColumnVector mapVector = (MapColumnVector) vector;
+ ColumnVector keyVector = keyConverter.convert(mapVector.keys, batchSize);
+ ColumnVector valueVector = valueConverter.convert(mapVector.values,
batchSize);
+
+ return new BaseOrcColumnVector(mapType, batchSize, vector) {
+ @Override
+ public ColumnarMap getMap(int rowId) {
+ if (isNullAt(rowId)) {
+ return null;
+ } else {
+ int index = getRowIndex(rowId);
+ return new ColumnarMap(keyVector, valueVector, (int)
mapVector.offsets[index],
+ (int) mapVector.lengths[index]);
+ }
+ }
+ };
+ }
+ }
+
+ private static class StructConverter implements Converter {
+ private final Types.StructType structType;
+ private final List<Converter> fieldConverters;
+ private final Map<Integer, ?> idToConstant;
+
+ private StructConverter(Types.StructType structType, List<Converter>
fieldConverters,
+ Map<Integer, ?> idToConstant) {
+ this.structType = structType;
+ this.fieldConverters = fieldConverters;
+ this.idToConstant = idToConstant;
+ }
+
+ @Override
+ public ColumnVector
convert(org.apache.orc.storage.ql.exec.vector.ColumnVector vector, int
batchSize) {
+ StructColumnVector structVector = (StructColumnVector) vector;
+ List<Types.NestedField> fields = structType.fields();
+ List<ColumnVector> fieldVectors =
Lists.newArrayListWithExpectedSize(fields.size());
+ for (int pos = 0, vectorIndex = 0; pos < fields.size(); pos += 1) {
+ Types.NestedField field = fields.get(pos);
+ if (idToConstant.containsKey(field.fieldId())) {
+ fieldVectors.add(new ConstantColumnVector(field.type(), batchSize,
idToConstant.get(field.fieldId())));
+ } else {
+
fieldVectors.add(fieldConverters.get(vectorIndex).convert(structVector.fields[vectorIndex],
batchSize));
+ vectorIndex++;
+ }
+ }
+
+ return new BaseOrcColumnVector(structType, batchSize, vector) {
+ @Override
+ public ColumnVector getChild(int ordinal) {
+ return fieldVectors.get(ordinal);
+ }
+ };
+ }
+ }
+}
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
index 93e03aa..fc87f3f 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/BaseDataReader.java
@@ -21,8 +21,12 @@ package org.apache.iceberg.spark.source;
import java.io.Closeable;
import java.io.IOException;
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.Map;
+import org.apache.avro.generic.GenericData;
+import org.apache.avro.util.Utf8;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.FileScanTask;
import org.apache.iceberg.encryption.EncryptedFiles;
@@ -33,7 +37,11 @@ import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
+import org.apache.iceberg.types.Type;
+import org.apache.iceberg.util.ByteBuffers;
import org.apache.spark.rdd.InputFileBlockHolder;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.unsafe.types.UTF8String;
/**
* Base class of Spark readers.
@@ -99,4 +107,32 @@ abstract class BaseDataReader<T> implements Closeable {
Preconditions.checkArgument(!task.isDataTask(), "Invalid task type");
return inputFiles.get(task.file().path().toString());
}
+
+ protected static Object convertConstant(Type type, Object value) {
+ if (value == null) {
+ return null;
+ }
+
+ switch (type.typeId()) {
+ case DECIMAL:
+ return Decimal.apply((BigDecimal) value);
+ case STRING:
+ if (value instanceof Utf8) {
+ Utf8 utf8 = (Utf8) value;
+ return UTF8String.fromBytes(utf8.getBytes(), 0,
utf8.getByteLength());
+ }
+ return UTF8String.fromString(value.toString());
+ case FIXED:
+ if (value instanceof byte[]) {
+ return value;
+ } else if (value instanceof GenericData.Fixed) {
+ return ((GenericData.Fixed) value).bytes();
+ }
+ return ByteBuffers.toByteArray((ByteBuffer) value);
+ case BINARY:
+ return ByteBuffers.toByteArray((ByteBuffer) value);
+ default:
+ }
+ return value;
+ }
}
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java
b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java
index f784b63..eff18ca 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java
@@ -19,10 +19,14 @@
package org.apache.iceberg.spark.source;
+import java.util.Map;
+import java.util.Set;
import org.apache.arrow.vector.NullCheckingForGet;
import org.apache.iceberg.CombinedScanTask;
+import org.apache.iceberg.DataFile;
import org.apache.iceberg.FileFormat;
import org.apache.iceberg.FileScanTask;
+import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.encryption.EncryptionManager;
import org.apache.iceberg.io.CloseableIterable;
@@ -30,9 +34,15 @@ import org.apache.iceberg.io.CloseableIterator;
import org.apache.iceberg.io.FileIO;
import org.apache.iceberg.io.InputFile;
import org.apache.iceberg.mapping.NameMappingParser;
+import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders;
import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.util.PartitionUtil;
+import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.vectorized.ColumnarBatch;
class BatchDataReader extends BaseDataReader<ColumnarBatch> {
@@ -53,6 +63,24 @@ class BatchDataReader extends BaseDataReader<ColumnarBatch> {
@Override
CloseableIterator<ColumnarBatch> open(FileScanTask task) {
+ DataFile file = task.file();
+
+ // update the current file for Spark's filename() function
+ InputFileBlockHolder.set(file.path().toString(), task.start(),
task.length());
+
+ // schema of rows returned by readers
+ PartitionSpec spec = task.spec();
+ Set<Integer> idColumns = spec.identitySourceIds();
+ Schema partitionSchema = TypeUtil.select(expectedSchema, idColumns);
+ boolean projectsIdentityPartitionColumns =
!partitionSchema.columns().isEmpty();
+
+ Map<Integer, ?> idToConstant;
+ if (projectsIdentityPartitionColumns) {
+ idToConstant = PartitionUtil.constantsMap(task,
BatchDataReader::convertConstant);
+ } else {
+ idToConstant = ImmutableMap.of();
+ }
+
CloseableIterable<ColumnarBatch> iter;
InputFile location = getInputFile(task);
Preconditions.checkNotNull(location, "Could not find InputFile associated
with FileScanTask");
@@ -75,6 +103,17 @@ class BatchDataReader extends BaseDataReader<ColumnarBatch>
{
}
iter = builder.build();
+ } else if (task.file().format() == FileFormat.ORC) {
+ Schema schemaWithoutConstants = TypeUtil.selectNot(expectedSchema,
idToConstant.keySet());
+ iter = ORC.read(location)
+ .project(schemaWithoutConstants)
+ .split(task.start(), task.length())
+ .createBatchedReaderFunc(fileSchema ->
VectorizedSparkOrcReaders.buildReader(expectedSchema, fileSchema,
+ idToConstant))
+ .recordsPerBatch(batchSize)
+ .filter(task.residual())
+ .caseSensitive(caseSensitive)
+ .build();
} else {
throw new UnsupportedOperationException(
"Format: " + task.file().format() + " not supported for batched
reads");
diff --git
a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java
b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java
index 16d5cb9..ff133ed 100644
--- a/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java
+++ b/spark/src/main/java/org/apache/iceberg/spark/source/RowDataReader.java
@@ -19,13 +19,9 @@
package org.apache.iceberg.spark.source;
-import java.math.BigDecimal;
-import java.nio.ByteBuffer;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import org.apache.avro.generic.GenericData;
-import org.apache.avro.util.Utf8;
import org.apache.iceberg.CombinedScanTask;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DataTask;
@@ -49,19 +45,15 @@ import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.spark.data.SparkAvroReader;
import org.apache.iceberg.spark.data.SparkOrcReader;
import org.apache.iceberg.spark.data.SparkParquetReaders;
-import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
-import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.PartitionUtil;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.Attribute;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.UnsafeProjection;
-import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
-import org.apache.spark.unsafe.types.UTF8String;
import scala.collection.JavaConverters;
class RowDataReader extends BaseDataReader<InternalRow> {
@@ -217,32 +209,4 @@ class RowDataReader extends BaseDataReader<InternalRow> {
JavaConverters.asScalaBufferConverter(exprs).asScala().toSeq(),
JavaConverters.asScalaBufferConverter(attrs).asScala().toSeq());
}
-
- private static Object convertConstant(Type type, Object value) {
- if (value == null) {
- return null;
- }
-
- switch (type.typeId()) {
- case DECIMAL:
- return Decimal.apply((BigDecimal) value);
- case STRING:
- if (value instanceof Utf8) {
- Utf8 utf8 = (Utf8) value;
- return UTF8String.fromBytes(utf8.getBytes(), 0,
utf8.getByteLength());
- }
- return UTF8String.fromString(value.toString());
- case FIXED:
- if (value instanceof byte[]) {
- return value;
- } else if (value instanceof GenericData.Fixed) {
- return ((GenericData.Fixed) value).bytes();
- }
- return ByteBuffers.toByteArray((ByteBuffer) value);
- case BINARY:
- return ByteBuffers.toByteArray((ByteBuffer) value);
- default:
- }
- return value;
- }
}
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
index f603757..aa0b247 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java
@@ -638,7 +638,9 @@ public class TestHelpers {
for (int i = 0; i < actual.numFields(); i += 1) {
StructField field = struct.fields()[i];
DataType type = field.dataType();
- assertEquals(context + "." + field.name(), type, expected.get(i, type),
actual.get(i, type));
+ assertEquals(context + "." + field.name(), type,
+ expected.isNullAt(i) ? null : expected.get(i, type),
+ actual.isNullAt(i) ? null : actual.get(i, type));
}
}
diff --git
a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java
b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java
index 6a58850..5042d1c 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkOrcReader.java
@@ -29,8 +29,12 @@ import org.apache.iceberg.Schema;
import org.apache.iceberg.io.CloseableIterable;
import org.apache.iceberg.io.FileAppender;
import org.apache.iceberg.orc.ORC;
+import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
+import org.apache.iceberg.relocated.com.google.common.collect.Iterators;
+import org.apache.iceberg.spark.data.vectorized.VectorizedSparkOrcReaders;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.InternalRow;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.junit.Assert;
import org.junit.Test;
@@ -81,5 +85,23 @@ public class TestSparkOrcReader extends AvroDataTest {
}
Assert.assertFalse("Should not have extra rows", actualRows.hasNext());
}
+
+ try (CloseableIterable<ColumnarBatch> reader =
ORC.read(Files.localInput(testFile))
+ .project(schema)
+ .createBatchedReaderFunc(readOrcSchema ->
+ VectorizedSparkOrcReaders.buildReader(schema, readOrcSchema,
ImmutableMap.of()))
+ .build()) {
+ final Iterator<InternalRow> actualRows =
batchesToRows(reader.iterator());
+ final Iterator<InternalRow> expectedRows = expected.iterator();
+ while (expectedRows.hasNext()) {
+ Assert.assertTrue("Should have expected number of rows",
actualRows.hasNext());
+ assertEquals(schema, expectedRows.next(), actualRows.next());
+ }
+ Assert.assertFalse("Should not have extra rows", actualRows.hasNext());
+ }
+ }
+
+ private Iterator<InternalRow> batchesToRows(Iterator<ColumnarBatch> batches)
{
+ return Iterators.concat(Iterators.transform(batches,
ColumnarBatch::rowIterator));
}
}
diff --git
a/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java
b/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java
index 759fcf3..15ae2f5 100644
---
a/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java
+++
b/spark/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData.java
@@ -53,16 +53,20 @@ public abstract class TestIdentityPartitionData {
@Parameterized.Parameters
public static Object[][] parameters() {
return new Object[][] {
- new Object[] { "parquet" },
- new Object[] { "avro" },
- new Object[] { "orc" }
+ new Object[] { "parquet", false },
+ new Object[] { "parquet", true },
+ new Object[] { "avro", false },
+ new Object[] { "orc", false },
+ new Object[] { "orc", true },
};
}
private final String format;
+ private final boolean vectorized;
- public TestIdentityPartitionData(String format) {
+ public TestIdentityPartitionData(String format, boolean vectorized) {
this.format = format;
+ this.vectorized = vectorized;
}
private static SparkSession spark = null;
@@ -121,7 +125,9 @@ public abstract class TestIdentityPartitionData {
@Test
public void testFullProjection() {
List<Row> expected = logs.orderBy("id").collectAsList();
- List<Row> actual =
spark.read().format("iceberg").load(table.location()).orderBy("id").collectAsList();
+ List<Row> actual = spark.read().format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(table.location()).orderBy("id").collectAsList();
Assert.assertEquals("Rows should match", expected, actual);
}
@@ -152,7 +158,9 @@ public abstract class TestIdentityPartitionData {
for (String[] ordering : cases) {
List<Row> expected = logs.select("id",
ordering).orderBy("id").collectAsList();
List<Row> actual = spark.read()
- .format("iceberg").load(table.location())
+ .format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(table.location())
.select("id", ordering).orderBy("id")
.collectAsList();
Assert.assertEquals("Rows should match for ordering: " +
Arrays.toString(ordering), expected, actual);
diff --git
a/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java
b/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java
index 07b2174..c46b191 100644
---
a/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java
+++
b/spark/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues.java
@@ -62,9 +62,11 @@ public abstract class TestPartitionValues {
@Parameterized.Parameters
public static Object[][] parameters() {
return new Object[][] {
- new Object[] { "parquet" },
- new Object[] { "avro" },
- new Object[] { "orc" }
+ new Object[] { "parquet", false },
+ new Object[] { "parquet", true },
+ new Object[] { "avro", false },
+ new Object[] { "orc", false },
+ new Object[] { "orc", true }
};
}
@@ -111,9 +113,11 @@ public abstract class TestPartitionValues {
public TemporaryFolder temp = new TemporaryFolder();
private final String format;
+ private final boolean vectorized;
- public TestPartitionValues(String format) {
+ public TestPartitionValues(String format, boolean vectorized) {
this.format = format;
+ this.vectorized = vectorized;
}
@Test
@@ -144,6 +148,7 @@ public abstract class TestPartitionValues {
Dataset<Row> result = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(location.toString());
List<SimpleRecord> actual = result
@@ -183,6 +188,7 @@ public abstract class TestPartitionValues {
Dataset<Row> result = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(location.toString());
List<SimpleRecord> actual = result
@@ -223,6 +229,7 @@ public abstract class TestPartitionValues {
Dataset<Row> result = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(location.toString());
List<SimpleRecord> actual = result
@@ -261,7 +268,9 @@ public abstract class TestPartitionValues {
.appendFile(DataFiles.fromInputFile(Files.localInput(avroData), 10))
.commit();
- Dataset<Row> sourceDF =
spark.read().format("iceberg").load(sourceLocation);
+ Dataset<Row> sourceDF = spark.read().format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(sourceLocation);
for (String column : columnNames) {
String desc = "partition_by_" +
SUPPORTED_PRIMITIVES.findType(column).toString();
@@ -283,6 +292,7 @@ public abstract class TestPartitionValues {
List<Row> actual = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(location.toString())
.collectAsList();
@@ -323,7 +333,9 @@ public abstract class TestPartitionValues {
.appendFile(DataFiles.fromInputFile(Files.localInput(avroData), 10))
.commit();
- Dataset<Row> sourceDF =
spark.read().format("iceberg").load(sourceLocation);
+ Dataset<Row> sourceDF = spark.read().format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(sourceLocation);
for (String column : columnNames) {
String desc = "partition_by_" +
SUPPORTED_PRIMITIVES.findType(column).toString();
@@ -345,6 +357,7 @@ public abstract class TestPartitionValues {
List<Row> actual = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(location.toString())
.collectAsList();
@@ -403,6 +416,7 @@ public abstract class TestPartitionValues {
// verify
List<Row> actual = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(baseLocation)
.collectAsList();
diff --git
a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java
b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java
index 6c9b32b..ac64fa9 100644
---
a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java
+++
b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java
@@ -31,7 +31,6 @@ import org.apache.iceberg.FileFormat;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
-import org.apache.iceberg.TableProperties;
import org.apache.iceberg.avro.Avro;
import org.apache.iceberg.data.Record;
import org.apache.iceberg.data.avro.DataWriter;
@@ -70,7 +69,8 @@ public abstract class TestSparkReadProjection extends
TestReadProjection {
new Object[] { "parquet", false },
new Object[] { "parquet", true },
new Object[] { "avro", false },
- new Object[] { "orc", false }
+ new Object[] { "orc", false },
+ new Object[] { "orc", true }
};
}
@@ -148,8 +148,6 @@ public abstract class TestSparkReadProjection extends
TestReadProjection {
table.newAppend().appendFile(file).commit();
-
table.updateProperties().set(TableProperties.PARQUET_VECTORIZATION_ENABLED,
String.valueOf(vectorized)).commit();
-
// rewrite the read schema for the table's reassigned ids
Map<Integer, Integer> idMapping = Maps.newHashMap();
for (int id : allIds(writeSchema)) {
@@ -166,6 +164,7 @@ public abstract class TestSparkReadProjection extends
TestReadProjection {
Dataset<Row> df = spark.read()
.format("org.apache.iceberg.spark.source.TestIcebergSource")
.option("iceberg.table.name", desc)
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load();
return SparkValueConverter.convert(readSchema,
df.collectAsList().get(0));
diff --git
a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java
b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java
index 5463c7f..811d15c 100644
---
a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java
+++
b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceFlatORCDataReadBenchmark.java
@@ -66,7 +66,7 @@ public class IcebergSourceFlatORCDataReadBenchmark extends
IcebergSourceFlatORCD
@Benchmark
@Threads(1)
- public void readIceberg() {
+ public void readIcebergNonVectorized() {
Map<String, String> tableProperties = Maps.newHashMap();
tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
withTableProperties(tableProperties, () -> {
@@ -78,6 +78,19 @@ public class IcebergSourceFlatORCDataReadBenchmark extends
IcebergSourceFlatORCD
@Benchmark
@Threads(1)
+ public void readIcebergVectorized() {
+ Map<String, String> tableProperties = Maps.newHashMap();
+ tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
+ withTableProperties(tableProperties, () -> {
+ String tableLocation = table().location();
+ Dataset<Row> df = spark().read().option("vectorization-enabled", "true")
+ .format("iceberg").load(tableLocation);
+ materialize(df);
+ });
+ }
+
+ @Benchmark
+ @Threads(1)
public void readFileSourceVectorized() {
Map<String, String> conf = Maps.newHashMap();
conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true");
@@ -102,7 +115,7 @@ public class IcebergSourceFlatORCDataReadBenchmark extends
IcebergSourceFlatORCD
@Benchmark
@Threads(1)
- public void readWithProjectionIceberg() {
+ public void readWithProjectionIcebergNonVectorized() {
Map<String, String> tableProperties = Maps.newHashMap();
tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
withTableProperties(tableProperties, () -> {
@@ -114,6 +127,20 @@ public class IcebergSourceFlatORCDataReadBenchmark extends
IcebergSourceFlatORCD
@Benchmark
@Threads(1)
+ public void readWithProjectionIcebergVectorized() {
+ Map<String, String> tableProperties = Maps.newHashMap();
+ tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
+ withTableProperties(tableProperties, () -> {
+ String tableLocation = table().location();
+ Dataset<Row> df = spark().read().option("vectorization-enabled", "true")
+ .format("iceberg").load(tableLocation).select("longCol");
+ materialize(df);
+ });
+ }
+
+
+ @Benchmark
+ @Threads(1)
public void readWithProjectionFileSourceVectorized() {
Map<String, String> conf = Maps.newHashMap();
conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true");
diff --git
a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java
b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java
index a4147e6..a63d4f9 100644
---
a/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java
+++
b/spark2/src/jmh/java/org/apache/iceberg/spark/source/orc/IcebergSourceNestedORCDataReadBenchmark.java
@@ -68,7 +68,7 @@ public class IcebergSourceNestedORCDataReadBenchmark extends
IcebergSourceNested
@Benchmark
@Threads(1)
- public void readIceberg() {
+ public void readIcebergNonVectorized() {
Map<String, String> tableProperties = Maps.newHashMap();
tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
withTableProperties(tableProperties, () -> {
@@ -80,12 +80,13 @@ public class IcebergSourceNestedORCDataReadBenchmark
extends IcebergSourceNested
@Benchmark
@Threads(1)
- public void readFileSourceVectorized() {
- Map<String, String> conf = Maps.newHashMap();
- conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true");
- conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 *
1024 * 1024));
- withSQLConf(conf, () -> {
- Dataset<Row> df = spark().read().orc(dataLocation());
+ public void readIcebergVectorized() {
+ Map<String, String> tableProperties = Maps.newHashMap();
+ tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
+ withTableProperties(tableProperties, () -> {
+ String tableLocation = table().location();
+ Dataset<Row> df = spark().read().option("vectorization-enabled", "true")
+ .format("iceberg").load(tableLocation);
materialize(df);
});
}
@@ -104,7 +105,7 @@ public class IcebergSourceNestedORCDataReadBenchmark
extends IcebergSourceNested
@Benchmark
@Threads(1)
- public void readWithProjectionIceberg() {
+ public void readWithProjectionIcebergNonVectorized() {
Map<String, String> tableProperties = Maps.newHashMap();
tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
withTableProperties(tableProperties, () -> {
@@ -116,12 +117,13 @@ public class IcebergSourceNestedORCDataReadBenchmark
extends IcebergSourceNested
@Benchmark
@Threads(1)
- public void readWithProjectionFileSourceVectorized() {
- Map<String, String> conf = Maps.newHashMap();
- conf.put(SQLConf.ORC_VECTORIZED_READER_ENABLED().key(), "true");
- conf.put(SQLConf.FILES_OPEN_COST_IN_BYTES().key(), Integer.toString(128 *
1024 * 1024));
- withSQLConf(conf, () -> {
- Dataset<Row> df =
spark().read().orc(dataLocation()).selectExpr("nested.col3");
+ public void readWithProjectionIcebergVectorized() {
+ Map<String, String> tableProperties = Maps.newHashMap();
+ tableProperties.put(SPLIT_OPEN_FILE_COST, Integer.toString(128 * 1024 *
1024));
+ withTableProperties(tableProperties, () -> {
+ String tableLocation = table().location();
+ Dataset<Row> df = spark().read().option("vectorization-enabled", "true")
+ .format("iceberg").load(tableLocation).selectExpr("nested.col3");
materialize(df);
});
}
diff --git a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
index 51859de..9fb475b 100644
--- a/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
+++ b/spark2/src/main/java/org/apache/iceberg/spark/source/Reader.java
@@ -299,6 +299,13 @@ class Reader implements DataSourceReader,
SupportsScanColumnarBatch, SupportsPus
.allMatch(fileScanTask ->
fileScanTask.file().format().equals(
FileFormat.PARQUET)));
+ boolean allOrcFileScanTasks =
+ tasks().stream()
+ .allMatch(combinedScanTask -> !combinedScanTask.isDataTask() &&
combinedScanTask.files()
+ .stream()
+ .allMatch(fileScanTask ->
fileScanTask.file().format().equals(
+ FileFormat.ORC)));
+
boolean atLeastOneColumn = lazySchema().columns().size() > 0;
boolean hasNoIdentityProjections = tasks().stream()
@@ -308,8 +315,8 @@ class Reader implements DataSourceReader,
SupportsScanColumnarBatch, SupportsPus
boolean onlyPrimitives = lazySchema().columns().stream().allMatch(c ->
c.type().isPrimitiveType());
- this.readUsingBatch = batchReadsEnabled && allParquetFileScanTasks &&
atLeastOneColumn &&
- hasNoIdentityProjections && onlyPrimitives;
+ this.readUsingBatch = batchReadsEnabled && (allOrcFileScanTasks ||
+ (allParquetFileScanTasks && atLeastOneColumn &&
hasNoIdentityProjections && onlyPrimitives));
}
return readUsingBatch;
}
diff --git
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
index c0d676e..0d45179 100644
--- a/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
+++ b/spark2/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
@@ -152,18 +152,22 @@ public class TestFilteredScan {
public TemporaryFolder temp = new TemporaryFolder();
private final String format;
+ private final boolean vectorized;
@Parameterized.Parameters
public static Object[][] parameters() {
return new Object[][] {
- new Object[] { "parquet" },
- new Object[] { "avro" },
- new Object[] { "orc" }
+ new Object[] { "parquet", false },
+ new Object[] { "parquet", true },
+ new Object[] { "avro", false },
+ new Object[] { "orc", false },
+ new Object[] { "orc", true }
};
}
- public TestFilteredScan(String format) {
+ public TestFilteredScan(String format, boolean vectorized) {
this.format = format;
+ this.vectorized = vectorized;
}
private File parent = null;
@@ -243,7 +247,7 @@ public class TestFilteredScan {
// validate row filtering
assertEqualsSafe(SCHEMA.asStruct(), expected(i),
- read(unpartitioned.toString(), "id = " + i));
+ read(unpartitioned.toString(), vectorized, "id = " + i));
}
}
@@ -270,7 +274,7 @@ public class TestFilteredScan {
// validate row filtering
assertEqualsSafe(SCHEMA.asStruct(), expected(i),
- read(unpartitioned.toString(), "id = " + i));
+ read(unpartitioned.toString(), vectorized, "id = " + i));
}
} finally {
// return global conf to previous state
@@ -294,7 +298,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should only create one task for a small file", 1,
tasks.size());
assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
- read(unpartitioned.toString(), "ts < cast('2017-12-22 00:00:00+00:00'
as timestamp)"));
+ read(unpartitioned.toString(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
@Test
@@ -321,7 +325,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create one task for a single bucket", 1,
tasks.size());
// validate row filtering
- assertEqualsSafe(SCHEMA.asStruct(), expected(i),
read(location.toString(), "id = " + i));
+ assertEqualsSafe(SCHEMA.asStruct(), expected(i),
read(location.toString(), vectorized, "id = " + i));
}
}
@@ -348,7 +352,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create one task for 2017-12-21", 1,
tasks.size());
assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
- read(location.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as
timestamp)"));
+ read(location.toString(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
{
@@ -361,7 +365,7 @@ public class TestFilteredScan {
List<InputPartition<InternalRow>> tasks = reader.planInputPartitions();
Assert.assertEquals("Should create one task for 2017-12-22", 1,
tasks.size());
- assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2),
read(location.toString(),
+ assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2),
read(location.toString(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
}
@@ -390,7 +394,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21,
22", 4, tasks.size());
assertEqualsSafe(SCHEMA.asStruct(), expected(8, 9, 7, 6, 5),
- read(location.toString(), "ts < cast('2017-12-22 00:00:00+00:00' as
timestamp)"));
+ read(location.toString(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
{
@@ -403,7 +407,7 @@ public class TestFilteredScan {
List<InputPartition<InternalRow>> tasks = reader.planInputPartitions();
Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2,
tasks.size());
- assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1),
read(location.toString(),
+ assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1),
read(location.toString(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
}
@@ -420,7 +424,7 @@ public class TestFilteredScan {
}
assertEqualsSafe(actualProjection.asStruct(), expected, read(
- unpartitioned.toString(),
+ unpartitioned.toString(), vectorized,
"ts < cast('2017-12-22 00:00:00+00:00' as timestamp)",
"id", "data"));
}
@@ -435,7 +439,7 @@ public class TestFilteredScan {
}
assertEqualsSafe(actualProjection.asStruct(), expected, read(
- unpartitioned.toString(),
+ unpartitioned.toString(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)",
"id"));
@@ -512,6 +516,7 @@ public class TestFilteredScan {
public void testUnpartitionedStartsWith() {
Dataset<Row> df = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(unpartitioned.toString());
List<String> matchedData = df.select("data")
@@ -578,6 +583,7 @@ public class TestFilteredScan {
// copy the unpartitioned table into the partitioned table to produce the
partitioned data
Dataset<Row> allRows = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(unpartitioned.toString());
allRows
@@ -608,12 +614,14 @@ public class TestFilteredScan {
);
}
- private static List<Row> read(String table, String expr) {
- return read(table, expr, "*");
+ private static List<Row> read(String table, boolean vectorized, String expr)
{
+ return read(table, vectorized, expr, "*");
}
- private static List<Row> read(String table, String expr, String select0,
String... selectN) {
- Dataset<Row> dataset =
spark.read().format("iceberg").load(table).filter(expr)
+ private static List<Row> read(String table, boolean vectorized, String expr,
String select0, String... selectN) {
+ Dataset<Row> dataset = spark.read().format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(table).filter(expr)
.select(select0, selectN);
return dataset.collectAsList();
}
diff --git
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
index fd7db75..9e382bf 100644
---
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
+++
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData24.java
@@ -20,7 +20,7 @@
package org.apache.iceberg.spark.source;
public class TestIdentityPartitionData24 extends TestIdentityPartitionData {
- public TestIdentityPartitionData24(String format) {
- super(format);
+ public TestIdentityPartitionData24(String format, boolean vectorized) {
+ super(format, vectorized);
}
}
diff --git
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
index f9da71e..d5d891f 100644
---
a/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
+++
b/spark2/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues24.java
@@ -20,7 +20,7 @@
package org.apache.iceberg.spark.source;
public class TestPartitionValues24 extends TestPartitionValues {
- public TestPartitionValues24(String format) {
- super(format);
+ public TestPartitionValues24(String format, boolean vectorized) {
+ super(format, vectorized);
}
}
diff --git
a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
index 026f6ba..912d90c 100644
--- a/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
+++ b/spark3/src/main/java/org/apache/iceberg/spark/source/SparkBatchScan.java
@@ -155,6 +155,13 @@ class SparkBatchScan implements Scan, Batch,
SupportsReportStatistics {
.allMatch(fileScanTask -> fileScanTask.file().format().equals(
FileFormat.PARQUET)));
+ boolean allOrcFileScanTasks =
+ tasks().stream()
+ .allMatch(combinedScanTask -> !combinedScanTask.isDataTask() &&
combinedScanTask.files()
+ .stream()
+ .allMatch(fileScanTask -> fileScanTask.file().format().equals(
+ FileFormat.ORC)));
+
boolean atLeastOneColumn = expectedSchema.columns().size() > 0;
boolean hasNoIdentityProjections = tasks().stream()
@@ -164,8 +171,8 @@ class SparkBatchScan implements Scan, Batch,
SupportsReportStatistics {
boolean onlyPrimitives = expectedSchema.columns().stream().allMatch(c ->
c.type().isPrimitiveType());
- boolean readUsingBatch = batchReadsEnabled && allParquetFileScanTasks &&
atLeastOneColumn &&
- hasNoIdentityProjections && onlyPrimitives;
+ boolean readUsingBatch = batchReadsEnabled && (allOrcFileScanTasks ||
+ (allParquetFileScanTasks && atLeastOneColumn &&
hasNoIdentityProjections && onlyPrimitives));
return new ReaderFactory(readUsingBatch ? batchSize : 0);
}
diff --git
a/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
b/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
index 7dd308d..9be9938 100644
--- a/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
+++ b/spark3/src/test/java/org/apache/iceberg/spark/source/TestFilteredScan.java
@@ -22,10 +22,10 @@ package org.apache.iceberg.spark.source;
import java.io.File;
import java.io.IOException;
import java.sql.Timestamp;
+import java.time.OffsetDateTime;
import java.util.List;
import java.util.Locale;
import java.util.UUID;
-import org.apache.avro.generic.GenericData.Record;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.DataFile;
import org.apache.iceberg.DataFiles;
@@ -34,14 +34,18 @@ import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.Schema;
import org.apache.iceberg.Table;
import org.apache.iceberg.avro.Avro;
-import org.apache.iceberg.avro.AvroSchemaUtil;
-import org.apache.iceberg.expressions.Literal;
+import org.apache.iceberg.data.GenericRecord;
+import org.apache.iceberg.data.Record;
+import org.apache.iceberg.data.avro.DataWriter;
+import org.apache.iceberg.data.orc.GenericOrcWriter;
+import org.apache.iceberg.data.parquet.GenericParquetWriter;
import org.apache.iceberg.hadoop.HadoopTables;
import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.orc.ORC;
import org.apache.iceberg.parquet.Parquet;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
-import org.apache.iceberg.spark.data.TestHelpers;
+import org.apache.iceberg.spark.data.GenericsHelpers;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.Transforms;
import org.apache.iceberg.types.Types;
@@ -146,17 +150,22 @@ public class TestFilteredScan {
public TemporaryFolder temp = new TemporaryFolder();
private final String format;
+ private final boolean vectorized;
@Parameterized.Parameters
public static Object[][] parameters() {
return new Object[][] {
- new Object[] { "parquet" },
- new Object[] { "avro" }
+ new Object[] { "parquet", false },
+ new Object[] { "parquet", true },
+ new Object[] { "avro", false },
+ new Object[] { "orc", false },
+ new Object[] { "orc", true }
};
}
- public TestFilteredScan(String format) {
+ public TestFilteredScan(String format, boolean vectorized) {
this.format = format;
+ this.vectorized = vectorized;
}
private File parent = null;
@@ -177,13 +186,12 @@ public class TestFilteredScan {
File testFile = new File(dataFolder,
fileFormat.addExtension(UUID.randomUUID().toString()));
- // create records using the table's schema
- org.apache.avro.Schema avroSchema = AvroSchemaUtil.convert(tableSchema,
"test");
- this.records = testRecords(avroSchema);
+ this.records = testRecords(tableSchema);
switch (fileFormat) {
case AVRO:
try (FileAppender<Record> writer = Avro.write(localOutput(testFile))
+ .createWriterFunc(DataWriter::create)
.schema(tableSchema)
.build()) {
writer.addAll(records);
@@ -192,6 +200,16 @@ public class TestFilteredScan {
case PARQUET:
try (FileAppender<Record> writer = Parquet.write(localOutput(testFile))
+ .createWriterFunc(GenericParquetWriter::buildWriter)
+ .schema(tableSchema)
+ .build()) {
+ writer.addAll(records);
+ }
+ break;
+
+ case ORC:
+ try (FileAppender<Record> writer = ORC.write(localOutput(testFile))
+ .createWriterFunc(GenericOrcWriter::buildWriter)
.schema(tableSchema)
.build()) {
writer.addAll(records);
@@ -224,7 +242,7 @@ public class TestFilteredScan {
// validate row filtering
assertEqualsSafe(SCHEMA.asStruct(), expected(i),
- read(unpartitioned.toString(), "id = " + i));
+ read(unpartitioned.toString(), vectorized, "id = " + i));
}
}
@@ -252,7 +270,7 @@ public class TestFilteredScan {
// validate row filtering
assertEqualsSafe(SCHEMA.asStruct(), expected(i),
- read(unpartitioned.toString(), "id = " + i));
+ read(unpartitioned.toString(), vectorized, "id = " + i));
}
} finally {
// return global conf to previous state
@@ -275,7 +293,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should only create one task for a small file", 1,
tasks.length);
assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
- read(unpartitioned.toString(), "ts < cast('2017-12-22 00:00:00+00:00'
as timestamp)"));
+ read(unpartitioned.toString(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
@Test
@@ -299,7 +317,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create one task for a single bucket", 1,
tasks.length);
// validate row filtering
- assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(table.location(),
"id = " + i));
+ assertEqualsSafe(SCHEMA.asStruct(), expected(i), read(table.location(),
vectorized, "id = " + i));
}
}
@@ -323,7 +341,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create one task for 2017-12-21", 1,
tasks.length);
assertEqualsSafe(SCHEMA.asStruct(), expected(5, 6, 7, 8, 9),
- read(table.location(), "ts < cast('2017-12-22 00:00:00+00:00' as
timestamp)"));
+ read(table.location(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
{
@@ -337,7 +355,7 @@ public class TestFilteredScan {
InputPartition[] tasks = scan.planInputPartitions();
Assert.assertEquals("Should create one task for 2017-12-22", 1,
tasks.length);
- assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2),
read(table.location(),
+ assertEqualsSafe(SCHEMA.asStruct(), expected(1, 2),
read(table.location(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
}
@@ -364,7 +382,7 @@ public class TestFilteredScan {
Assert.assertEquals("Should create 4 tasks for 2017-12-21: 15, 17, 21,
22", 4, tasks.length);
assertEqualsSafe(SCHEMA.asStruct(), expected(8, 9, 7, 6, 5),
- read(table.location(), "ts < cast('2017-12-22 00:00:00+00:00' as
timestamp)"));
+ read(table.location(), vectorized, "ts < cast('2017-12-22
00:00:00+00:00' as timestamp)"));
}
{
@@ -378,7 +396,7 @@ public class TestFilteredScan {
InputPartition[] tasks = scan.planInputPartitions();
Assert.assertEquals("Should create 2 tasks for 2017-12-22: 6, 7", 2,
tasks.length);
- assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1),
read(table.location(),
+ assertEqualsSafe(SCHEMA.asStruct(), expected(2, 1),
read(table.location(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)"));
}
@@ -395,7 +413,7 @@ public class TestFilteredScan {
}
assertEqualsSafe(actualProjection.asStruct(), expected, read(
- unpartitioned.toString(),
+ unpartitioned.toString(), vectorized,
"ts < cast('2017-12-22 00:00:00+00:00' as timestamp)",
"id", "data"));
}
@@ -410,7 +428,7 @@ public class TestFilteredScan {
}
assertEqualsSafe(actualProjection.asStruct(), expected, read(
- unpartitioned.toString(),
+ unpartitioned.toString(), vectorized,
"ts > cast('2017-12-22 06:00:00+00:00' as timestamp) and " +
"ts < cast('2017-12-22 08:00:00+00:00' as timestamp)",
"id"));
@@ -450,6 +468,7 @@ public class TestFilteredScan {
public void testUnpartitionedStartsWith() {
Dataset<Row> df = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(unpartitioned.toString());
List<String> matchedData = df.select("data")
@@ -462,12 +481,11 @@ public class TestFilteredScan {
}
private static Record projectFlat(Schema projection, Record record) {
- org.apache.avro.Schema avroSchema = AvroSchemaUtil.convert(projection,
"test");
- Record result = new Record(avroSchema);
+ Record result = GenericRecord.create(projection);
List<Types.NestedField> fields = projection.asStruct().fields();
for (int i = 0; i < fields.size(); i += 1) {
Types.NestedField field = fields.get(i);
- result.put(i, record.get(field.name()));
+ result.set(i, record.getField(field.name()));
}
return result;
}
@@ -477,7 +495,7 @@ public class TestFilteredScan {
// TODO: match records by ID
int numRecords = Math.min(expected.size(), actual.size());
for (int i = 0; i < numRecords; i += 1) {
- TestHelpers.assertEqualsUnsafe(struct, expected.get(i), actual.get(i));
+ GenericsHelpers.assertEqualsUnsafe(struct, expected.get(i),
actual.get(i));
}
Assert.assertEquals("Number of results should match expected",
expected.size(), actual.size());
}
@@ -487,7 +505,7 @@ public class TestFilteredScan {
// TODO: match records by ID
int numRecords = Math.min(expected.size(), actual.size());
for (int i = 0; i < numRecords; i += 1) {
- TestHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i));
+ GenericsHelpers.assertEqualsSafe(struct, expected.get(i), actual.get(i));
}
Assert.assertEquals("Number of results should match expected",
expected.size(), actual.size());
}
@@ -517,6 +535,7 @@ public class TestFilteredScan {
// copy the unpartitioned table into the partitioned table to produce the
partitioned data
Dataset<Row> allRows = spark.read()
.format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
.load(unpartitioned.toString());
allRows
@@ -534,39 +553,41 @@ public class TestFilteredScan {
return table;
}
- private List<Record> testRecords(org.apache.avro.Schema avroSchema) {
+ private List<Record> testRecords(Schema schema) {
return Lists.newArrayList(
- record(avroSchema, 0L, timestamp("2017-12-22T09:20:44.294658+00:00"),
"junction"),
- record(avroSchema, 1L, timestamp("2017-12-22T07:15:34.582910+00:00"),
"alligator"),
- record(avroSchema, 2L, timestamp("2017-12-22T06:02:09.243857+00:00"),
"forrest"),
- record(avroSchema, 3L, timestamp("2017-12-22T03:10:11.134509+00:00"),
"clapping"),
- record(avroSchema, 4L, timestamp("2017-12-22T00:34:00.184671+00:00"),
"brush"),
- record(avroSchema, 5L, timestamp("2017-12-21T22:20:08.935889+00:00"),
"trap"),
- record(avroSchema, 6L, timestamp("2017-12-21T21:55:30.589712+00:00"),
"element"),
- record(avroSchema, 7L, timestamp("2017-12-21T17:31:14.532797+00:00"),
"limited"),
- record(avroSchema, 8L, timestamp("2017-12-21T15:21:51.237521+00:00"),
"global"),
- record(avroSchema, 9L, timestamp("2017-12-21T15:02:15.230570+00:00"),
"goldfish")
+ record(schema, 0L, parse("2017-12-22T09:20:44.294658+00:00"),
"junction"),
+ record(schema, 1L, parse("2017-12-22T07:15:34.582910+00:00"),
"alligator"),
+ record(schema, 2L, parse("2017-12-22T06:02:09.243857+00:00"),
"forrest"),
+ record(schema, 3L, parse("2017-12-22T03:10:11.134509+00:00"),
"clapping"),
+ record(schema, 4L, parse("2017-12-22T00:34:00.184671+00:00"), "brush"),
+ record(schema, 5L, parse("2017-12-21T22:20:08.935889+00:00"), "trap"),
+ record(schema, 6L, parse("2017-12-21T21:55:30.589712+00:00"),
"element"),
+ record(schema, 7L, parse("2017-12-21T17:31:14.532797+00:00"),
"limited"),
+ record(schema, 8L, parse("2017-12-21T15:21:51.237521+00:00"),
"global"),
+ record(schema, 9L, parse("2017-12-21T15:02:15.230570+00:00"),
"goldfish")
);
}
- private static List<Row> read(String table, String expr) {
- return read(table, expr, "*");
+ private static List<Row> read(String table, boolean vectorized, String expr)
{
+ return read(table, vectorized, expr, "*");
}
- private static List<Row> read(String table, String expr, String select0,
String... selectN) {
- Dataset<Row> dataset =
spark.read().format("iceberg").load(table).filter(expr)
+ private static List<Row> read(String table, boolean vectorized, String expr,
String select0, String... selectN) {
+ Dataset<Row> dataset = spark.read().format("iceberg")
+ .option("vectorization-enabled", String.valueOf(vectorized))
+ .load(table).filter(expr)
.select(select0, selectN);
return dataset.collectAsList();
}
- private static long timestamp(String timestamp) {
- return
Literal.of(timestamp).<Long>to(Types.TimestampType.withZone()).value();
+ private static OffsetDateTime parse(String timestamp) {
+ return OffsetDateTime.parse(timestamp);
}
- private static Record record(org.apache.avro.Schema schema, Object...
values) {
- Record rec = new Record(schema);
+ private static Record record(Schema schema, Object... values) {
+ Record rec = GenericRecord.create(schema);
for (int i = 0; i < values.length; i += 1) {
- rec.put(i, values[i]);
+ rec.set(i, values[i]);
}
return rec;
}
diff --git
a/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java
b/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java
index 83f3f32..3b90f61 100644
---
a/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java
+++
b/spark3/src/test/java/org/apache/iceberg/spark/source/TestIdentityPartitionData3.java
@@ -20,7 +20,7 @@
package org.apache.iceberg.spark.source;
public class TestIdentityPartitionData3 extends TestIdentityPartitionData {
- public TestIdentityPartitionData3(String format) {
- super(format);
+ public TestIdentityPartitionData3(String format, boolean vectorized) {
+ super(format, vectorized);
}
}
diff --git
a/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java
b/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java
index 63db54e..9b42192 100644
---
a/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java
+++
b/spark3/src/test/java/org/apache/iceberg/spark/source/TestPartitionValues3.java
@@ -20,7 +20,7 @@
package org.apache.iceberg.spark.source;
public class TestPartitionValues3 extends TestPartitionValues {
- public TestPartitionValues3(String format) {
- super(format);
+ public TestPartitionValues3(String format, boolean vectorized) {
+ super(format, vectorized);
}
}