Github user henrify commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19943#discussion_r160078679
  
    --- Diff: sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/JavaOrcColumnarBatchReader.java ---
    @@ -0,0 +1,503 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one or more
    + * contributor license agreements.  See the NOTICE file distributed with
    + * this work for additional information regarding copyright ownership.
    + * The ASF licenses this file to You under the Apache License, Version 2.0
    + * (the "License"); you may not use this file except in compliance with
    + * the License.  You may obtain a copy of the License at
    + *
    + *    http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +package org.apache.spark.sql.execution.datasources.orc;
    +
    +import java.io.IOException;
    +
    +import org.apache.hadoop.conf.Configuration;
    +import org.apache.hadoop.mapreduce.InputSplit;
    +import org.apache.hadoop.mapreduce.RecordReader;
    +import org.apache.hadoop.mapreduce.TaskAttemptContext;
    +import org.apache.hadoop.mapreduce.lib.input.FileSplit;
    +import org.apache.orc.OrcConf;
    +import org.apache.orc.OrcFile;
    +import org.apache.orc.Reader;
    +import org.apache.orc.TypeDescription;
    +import org.apache.orc.mapred.OrcInputFormat;
    +import org.apache.orc.storage.common.type.HiveDecimal;
    +import org.apache.orc.storage.ql.exec.vector.*;
    +import org.apache.orc.storage.serde2.io.HiveDecimalWritable;
    +
    +import org.apache.spark.memory.MemoryMode;
    +import org.apache.spark.sql.catalyst.InternalRow;
    +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
    +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
    +import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
    +import org.apache.spark.sql.types.*;
    +import org.apache.spark.sql.vectorized.ColumnarBatch;
    +
    +
    +/**
    + * To support vectorization in WholeStageCodeGen, this reader returns 
ColumnarBatch.
    + * After creating, `initialize` and `setRequiredSchema` should be called 
sequentially.
    + */
    +public class JavaOrcColumnarBatchReader extends RecordReader<Void, 
ColumnarBatch> {
    +
    +  /**
    +   * ORC File Reader.
    +   */
    +  private Reader reader;
    +
    +  /**
    +   * Vectorized Row Batch.
    +   */
    +  private VectorizedRowBatch batch;
    +
    +  /**
    +   * Requested Column IDs.
    +   */
    +  private int[] requestedColIds;
    +
    +  /**
    +   * Record reader from row batch.
    +   */
    +  private org.apache.orc.RecordReader recordReader;
    +
    +  /**
    +   * Required Schema.
    +   */
    +  private StructType requiredSchema;
    +
    +  /**
    +   * ColumnarBatch for vectorized execution by whole-stage codegen.
    +   */
    +  private ColumnarBatch columnarBatch;
    +
    +  /**
    +   * Writable column vectors of ColumnarBatch.
    +   */
    +  private WritableColumnVector[] columnVectors;
    +
    +  /**
    +   * The number of rows read and considered to be returned.
    +   */
    +  private long rowsReturned = 0L;
    +
    +  /**
    +   * Total number of rows.
    +   */
    +  private long totalRowCount = 0L;
    +
    +  @Override
    +  public Void getCurrentKey() throws IOException, InterruptedException {
    +    return null;
    +  }
    +
    +  @Override
    +  public ColumnarBatch getCurrentValue() throws IOException, 
InterruptedException {
    +    return columnarBatch;
    +  }
    +
    +  @Override
    +  public float getProgress() throws IOException, InterruptedException {
    +    return (float) rowsReturned / totalRowCount;
    +  }
    +
    +  @Override
    +  public boolean nextKeyValue() throws IOException, InterruptedException {
    +    return nextBatch();
    +  }
    +
    +  @Override
    +  public void close() throws IOException {
    +    if (columnarBatch != null) {
    +      columnarBatch.close();
    +      columnarBatch = null;
    +    }
    +    if (recordReader != null) {
    +      recordReader.close();
    +      recordReader = null;
    +    }
    +  }
    +
    +  /**
    +   * Initialize ORC file reader and batch record reader.
    +   * Please note that `setRequiredSchema` is needed to be called after 
this.
    +   */
    +  @Override
    +  public void initialize(InputSplit inputSplit, TaskAttemptContext 
taskAttemptContext)
    +      throws IOException, InterruptedException {
    +    FileSplit fileSplit = (FileSplit)inputSplit;
    +    Configuration conf = taskAttemptContext.getConfiguration();
    +    reader = OrcFile.createReader(
    +      fileSplit.getPath(),
    +      OrcFile.readerOptions(conf)
    +        .maxLength(OrcConf.MAX_FILE_LENGTH.getLong(conf))
    +        .filesystem(fileSplit.getPath().getFileSystem(conf)));
    +
    +    Reader.Options options =
    +            OrcInputFormat.buildOptions(conf, reader, 
fileSplit.getStart(), fileSplit.getLength());
    +    recordReader = reader.rows(options);
    +    totalRowCount = reader.getNumberOfRows();
    +  }
    +
    +  /**
    +   * Set required schema and partition information.
    +   * With this information, this creates ColumnarBatch with the full 
schema.
    +   */
    +  public void setRequiredSchema(
    +    TypeDescription orcSchema,
    +    int[] requestedColIds,
    +    StructType requiredSchema,
    +    StructType partitionSchema,
    +    InternalRow partitionValues) {
    +    batch = orcSchema.createRowBatch(DEFAULT_SIZE);
    +    assert(!batch.selectedInUse); // `selectedInUse` should be initialized 
with `false`.
    +
    +    StructType resultSchema = new StructType(requiredSchema.fields());
    +    for (StructField f : partitionSchema.fields())
    +      resultSchema = resultSchema.add(f);
    +    this.requiredSchema = requiredSchema;
    +    this.requestedColIds = requestedColIds;
    +
    +    int capacity = DEFAULT_SIZE;
    +    if (DEFAULT_MEMORY_MODE == MemoryMode.OFF_HEAP) {
    +      columnVectors = OffHeapColumnVector.allocateColumns(capacity, 
resultSchema);
    +    } else {
    +      columnVectors = OnHeapColumnVector.allocateColumns(capacity, 
resultSchema);
    +    }
    +    columnarBatch = new ColumnarBatch(resultSchema, columnVectors, 
capacity);
    +
    +    if (partitionValues.numFields() > 0) {
    +      int partitionIdx = requiredSchema.fields().length;
    +      for (int i = 0; i < partitionValues.numFields(); i++) {
    +        ColumnVectorUtils.populate(columnVectors[i + partitionIdx], 
partitionValues, i);
    +        columnVectors[i + partitionIdx].setIsConstant();
    +      }
    +    }
    +
    +    // Initialize the missing columns once.
    +    for (int i = 0; i < requiredSchema.length(); i++) {
    +      if (requestedColIds[i] < 0) {
    +        columnVectors[i].putNulls(0, columnarBatch.capacity());
    +        columnVectors[i].setIsConstant();
    +      }
    +    }
    +  }
    +
    +  /**
    +   * Return true if there exists more data in the next batch. If exists, 
prepare the next batch
    +   * by copying from ORC VectorizedRowBatch columns to Spark ColumnarBatch 
columns.
    +   */
    +  private boolean nextBatch() throws IOException {
    +    if (rowsReturned >= totalRowCount) {
    +      return false;
    +    }
    +
    +    recordReader.nextBatch(batch);
    +    int batchSize = batch.size;
    +    if (batchSize == 0) {
    +      return false;
    +    }
    +    rowsReturned += batchSize;
    +    for (WritableColumnVector vector : columnVectors) {
    +      vector.reset();
    +    }
    +    columnarBatch.setNumRows(batchSize);
    +    int i = 0;
    +    while (i < requiredSchema.length()) {
    +      StructField field = requiredSchema.fields()[i];
    +      WritableColumnVector toColumn = columnVectors[i];
    +
    +      if (requestedColIds[i] < 0) {
    +        toColumn.appendNulls(batchSize);
    +      } else {
    +        ColumnVector fromColumn = batch.cols[requestedColIds[i]];
    +
    +        if (fromColumn.isRepeating) {
    +          if (fromColumn.isNull[0]) {
    +            toColumn.appendNulls(batchSize);
    +          } else {
    +            DataType type = field.dataType();
    +            if (type instanceof BooleanType) {
    +              toColumn.appendBooleans(batchSize, 
((LongColumnVector)fromColumn).vector[0] == 1);
    +            } else if (type instanceof ByteType) {
    +              toColumn.appendBytes(batchSize, 
(byte)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof ShortType) {
    +              toColumn.appendShorts(batchSize, 
(short)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof IntegerType || type instanceof 
DateType) {
    +              toColumn.appendInts(batchSize, 
(int)((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof LongType) {
    +              toColumn.appendLongs(batchSize, 
((LongColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof TimestampType) {
    +              toColumn.appendLongs(batchSize, 
fromTimestampColumnVector((TimestampColumnVector)fromColumn, 0));
    +            } else if (type instanceof FloatType) {
    +              toColumn.appendFloats(batchSize, 
(float)((DoubleColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof DoubleType) {
    +              toColumn.appendDoubles(batchSize, 
((DoubleColumnVector)fromColumn).vector[0]);
    +            } else if (type instanceof StringType || type instanceof 
BinaryType) {
    +              BytesColumnVector data = (BytesColumnVector)fromColumn;
    +              int index = 0;
    +              while (index < batchSize) {
    +                toColumn.appendByteArray(data.vector[0], data.start[0], 
data.length[0]);
    +                index += 1;
    +              }
    +            } else if (type instanceof DecimalType) {
    +              DecimalType decimalType = (DecimalType)type;
    +              appendDecimalWritable(
    +                toColumn,
    +                decimalType.precision(),
    +                decimalType.scale(),
    +                ((DecimalColumnVector)fromColumn).vector[0]);
    +            } else {
    +              throw new UnsupportedOperationException("Unsupported Data 
Type: " + type);
    +            }
    +          }
    +        } else if (fromColumn.noNulls) {
    +          DataType type = field.dataType();
    +          if (type instanceof BooleanType) {
    +            long[] data = ((LongColumnVector)fromColumn).vector;
    +            int index = 0;
    +            while (index < batchSize) {
    +              toColumn.appendBoolean(data[index] == 1);
    --- End diff --
    
    Actually, you know the number of rows (`batchSize`) in advance. Wouldn't it be possible
    to call reserve() once, and then use the putX() API instead of the appendX() API
    inside the loops? That should be significantly faster.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to