http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
deleted file mode 100644
index 9322b77..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/CompressionScheme.scala
+++ /dev/null
@@ -1,81 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.columnar.compression
-
-import java.nio.{ByteBuffer, ByteOrder}
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.MutableRow
-import org.apache.spark.sql.columnar.{ColumnType, NativeColumnType}
-import org.apache.spark.sql.types.AtomicType
-
-private[sql] trait Encoder[T <: AtomicType] {
-  def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {}
-
-  def compressedSize: Int
-
-  def uncompressedSize: Int
-
-  def compressionRatio: Double = {
-    if (uncompressedSize > 0) compressedSize.toDouble / uncompressedSize else 1.0
-  }
-
-  def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer
-}
-
-private[sql] trait Decoder[T <: AtomicType] {
-  def next(row: MutableRow, ordinal: Int): Unit
-
-  def hasNext: Boolean
-}
-
-private[sql] trait CompressionScheme {
-  def typeId: Int
-
-  def supports(columnType: ColumnType[_]): Boolean
-
-  def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T]
-
-  def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T]
-}
-
-private[sql] trait WithCompressionSchemes {
-  def schemes: Seq[CompressionScheme]
-}
-
-private[sql] trait AllCompressionSchemes extends WithCompressionSchemes {
-  override val schemes: Seq[CompressionScheme] = CompressionScheme.all
-}
-
-private[sql] object CompressionScheme {
-  val all: Seq[CompressionScheme] =
-    Seq(PassThrough, RunLengthEncoding, DictionaryEncoding, BooleanBitSet, IntDelta, LongDelta)
-
-  private val typeIdToScheme = all.map(scheme => scheme.typeId -> scheme).toMap
-
-  def apply(typeId: Int): CompressionScheme = {
-    typeIdToScheme.getOrElse(typeId, throw new UnsupportedOperationException(
-      s"Unrecognized compression scheme type ID: $typeId"))
-  }
-
-  def columnHeaderSize(columnBuffer: ByteBuffer): Int = {
-    val header = columnBuffer.duplicate().order(ByteOrder.nativeOrder)
-    val nullCount = header.getInt()
-    // null count + null positions
-    4 + 4 * nullCount
-  }
-}
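For orientation, this is roughly how a reader locates the compression metadata the schemes below write: the column's null section comes first (which `columnHeaderSize` measures), followed by the scheme's 4-byte type ID. A minimal sketch against the API being removed here; `decoderFor` is a hypothetical helper, its `buffer` and `columnType` are hypothetical inputs, and the nullable-layout handling is simplified:

    import java.nio.ByteOrder

    def decoderFor[T <: AtomicType](
        buffer: java.nio.ByteBuffer,
        columnType: NativeColumnType[T]): Decoder[T] = {
      val buf = buffer.duplicate().order(ByteOrder.nativeOrder)
      // Skip the null header: 4 bytes of null count plus 4 bytes per null position.
      buf.position(CompressionScheme.columnHeaderSize(buf))
      // Every encoder writes its type ID first; map it back to a scheme.
      CompressionScheme(buf.getInt()).decoder(buf, columnType)
    }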
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
deleted file mode 100644
index 41c9a28..0000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/compression/compressionSchemes.scala
+++ /dev/null
@@ -1,532 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.columnar.compression
-
-import java.nio.ByteBuffer
-
-import scala.collection.mutable
-
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{MutableRow, SpecificMutableRow}
-import org.apache.spark.sql.columnar._
-import org.apache.spark.sql.types._
-
-
-private[sql] case object PassThrough extends CompressionScheme {
-  override val typeId = 0
-
-  override def supports(columnType: ColumnType[_]): Boolean = true
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
-    new this.Encoder[T](columnType)
-  }
-
-  override def decoder[T <: AtomicType](
-      buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = {
-    new this.Decoder(buffer, columnType)
-  }
-
-  class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
-    override def uncompressedSize: Int = 0
-
-    override def compressedSize: Int = 0
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      // Writes compression type ID and copies raw contents
-      to.putInt(PassThrough.typeId).put(from).rewind()
-      to
-    }
-  }
-
-  class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    extends compression.Decoder[T] {
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      columnType.extract(buffer, row, ordinal)
-    }
-
-    override def hasNext: Boolean = buffer.hasRemaining
-  }
-}
-
-private[sql] case object RunLengthEncoding extends CompressionScheme {
-  override val typeId = 1
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
-    new this.Encoder[T](columnType)
-  }
-
-  override def decoder[T <: AtomicType](
-      buffer: ByteBuffer, columnType: NativeColumnType[T]): Decoder[T] = {
-    new this.Decoder(buffer, columnType)
-  }
-
-  override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | SHORT | BYTE | STRING | BOOLEAN => true
-    case _ => false
-  }
-
-  class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
-    private var _uncompressedSize = 0
-    private var _compressedSize = 0
-
-    // Using `MutableRow` to store the last value to avoid boxing/unboxing cost.
-    private val lastValue = new SpecificMutableRow(Seq(columnType.dataType))
-    private var lastRun = 0
-
-    override def uncompressedSize: Int = _uncompressedSize
-
-    override def compressedSize: Int = _compressedSize
-
-    override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
-      val value = columnType.getField(row, ordinal)
-      val actualSize = columnType.actualSize(row, ordinal)
-      _uncompressedSize += actualSize
-
-      if (lastValue.isNullAt(0)) {
-        columnType.copyField(row, ordinal, lastValue, 0)
-        lastRun = 1
-        _compressedSize += actualSize + 4
-      } else {
-        if (columnType.getField(lastValue, 0) == value) {
-          lastRun += 1
-        } else {
-          _compressedSize += actualSize + 4
-          columnType.copyField(row, ordinal, lastValue, 0)
-          lastRun = 1
-        }
-      }
-    }
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      to.putInt(RunLengthEncoding.typeId)
-
-      if (from.hasRemaining) {
-        val currentValue = new SpecificMutableRow(Seq(columnType.dataType))
-        var currentRun = 1
-        val value = new SpecificMutableRow(Seq(columnType.dataType))
-
-        columnType.extract(from, currentValue, 0)
-
-        while (from.hasRemaining) {
-          columnType.extract(from, value, 0)
-
-          if (value.get(0, columnType.dataType) == currentValue.get(0, columnType.dataType)) {
-            currentRun += 1
-          } else {
-            // Writes current run
-            columnType.append(currentValue, 0, to)
-            to.putInt(currentRun)
-
-            // Resets current run
-            columnType.copyField(value, 0, currentValue, 0)
-            currentRun = 1
-          }
-        }
-
-        // Writes the last run
-        columnType.append(currentValue, 0, to)
-        to.putInt(currentRun)
-      }
-
-      to.rewind()
-      to
-    }
-  }
-
-  class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    extends compression.Decoder[T] {
-
-    private var run = 0
-    private var valueCount = 0
-    private var currentValue: T#InternalType = _
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      if (valueCount == run) {
-        currentValue = columnType.extract(buffer)
-        run = ByteBufferHelper.getInt(buffer)
-        valueCount = 1
-      } else {
-        valueCount += 1
-      }
-
-      columnType.setField(row, ordinal, currentValue)
-    }
-
-    override def hasNext: Boolean = valueCount < run || buffer.hasRemaining
-  }
-}
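To make the size accounting above concrete: each run costs the value's `actualSize` plus a 4-byte run length. A hedged sketch, assuming it runs inside this package where `INT`, `IntegerType` and `SpecificMutableRow` are visible:

    val encoder = RunLengthEncoding.encoder(INT)
    val row = new SpecificMutableRow(Seq(IntegerType))
    Seq(7, 7, 7, 2, 2).foreach { v =>
      row.setInt(0, v)
      encoder.gatherCompressibilityStats(row, 0)
    }
    // Five ints are 20 bytes raw; two runs compress to 2 * (4 + 4) = 16 bytes.
    assert(encoder.uncompressedSize == 20 && encoder.compressedSize == 16)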
-
-private[sql] case object DictionaryEncoding extends CompressionScheme {
-  override val typeId = 2
-
-  // 32K unique values allowed
-  val MAX_DICT_SIZE = Short.MaxValue
-
-  override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    : Decoder[T] = {
-    new this.Decoder(buffer, columnType)
-  }
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): Encoder[T] = {
-    new this.Encoder[T](columnType)
-  }
-
-  override def supports(columnType: ColumnType[_]): Boolean = columnType match {
-    case INT | LONG | STRING => true
-    case _ => false
-  }
-
-  class Encoder[T <: AtomicType](columnType: NativeColumnType[T]) extends compression.Encoder[T] {
-    // Size of the input, uncompressed, in bytes. Note that we only count until the dictionary
-    // overflows.
-    private var _uncompressedSize = 0
-
-    // If the number of distinct elements is too large, we discard the use of dictionary encoding
-    // and set the overflow flag to true.
-    private var overflow = false
-
-    // Total number of elements.
-    private var count = 0
-
-    // The reverse mapping of _dictionary, i.e. mapping encoded integer to the value itself.
-    private var values = new mutable.ArrayBuffer[T#InternalType](1024)
-
-    // The dictionary that maps a value to the encoded short integer.
-    private val dictionary = mutable.HashMap.empty[Any, Short]
-
-    // Size of the serialized dictionary in bytes. Initialized to 4 since we need at least an `Int`
-    // to store dictionary element count.
-    private var dictionarySize = 4
-
-    override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
-      val value = columnType.getField(row, ordinal)
-
-      if (!overflow) {
-        val actualSize = columnType.actualSize(row, ordinal)
-        count += 1
-        _uncompressedSize += actualSize
-
-        if (!dictionary.contains(value)) {
-          if (dictionary.size < MAX_DICT_SIZE) {
-            val clone = columnType.clone(value)
-            values += clone
-            dictionarySize += actualSize
-            dictionary(clone) = dictionary.size.toShort
-          } else {
-            overflow = true
-            values.clear()
-            dictionary.clear()
-          }
-        }
-      }
-    }
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      if (overflow) {
-        throw new IllegalStateException(
-          "Dictionary encoding should not be used because of dictionary overflow.")
-      }
-
-      to.putInt(DictionaryEncoding.typeId)
-        .putInt(dictionary.size)
-
-      var i = 0
-      while (i < values.length) {
-        columnType.append(values(i), to)
-        i += 1
-      }
-
-      while (from.hasRemaining) {
-        to.putShort(dictionary(columnType.extract(from)))
-      }
-
-      to.rewind()
-      to
-    }
-
-    override def uncompressedSize: Int = _uncompressedSize
-
-    override def compressedSize: Int = if (overflow) Int.MaxValue else dictionarySize + count * 2
-  }
-
-  class Decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    extends compression.Decoder[T] {
-
-    private val dictionary: Array[Any] = {
-      val elementNum = ByteBufferHelper.getInt(buffer)
-      Array.fill[Any](elementNum)(columnType.extract(buffer).asInstanceOf[Any])
-    }
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      columnType.setField(row, ordinal, dictionary(buffer.getShort()).asInstanceOf[T#InternalType])
-    }
-
-    override def hasNext: Boolean = buffer.hasRemaining
-  }
-}
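The compressed layout is the type ID, the dictionary element count, the dictionary entries, then one 2-byte code per value, so `compressedSize = dictionarySize + count * 2`. A sketch of that arithmetic for a string column, assuming `STRING`, `StringType`, `SpecificMutableRow` and `UTF8String` are in scope:

    val encoder = DictionaryEncoding.encoder(STRING)
    val row = new SpecificMutableRow(Seq(StringType))
    (0 until 8).foreach { i =>
      row.update(0, UTF8String.fromString(if (i % 2 == 0) "foo" else "bar"))
      encoder.gatherCompressibilityStats(row, 0)
    }
    // Raw: 8 * (4-byte length + 3 bytes) = 56 bytes.
    // Compressed: dictionary (4 + 2 * 7) + 8 * 2-byte codes = 34 bytes.
    assert(encoder.uncompressedSize == 56 && encoder.compressedSize == 34)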
-
-private[sql] case object BooleanBitSet extends CompressionScheme {
-  override val typeId = 3
-
-  val BITS_PER_LONG = 64
-
-  override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    : compression.Decoder[T] = {
-    new this.Decoder(buffer).asInstanceOf[compression.Decoder[T]]
-  }
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
-    (new this.Encoder).asInstanceOf[compression.Encoder[T]]
-  }
-
-  override def supports(columnType: ColumnType[_]): Boolean = columnType == BOOLEAN
-
-  class Encoder extends compression.Encoder[BooleanType.type] {
-    private var _uncompressedSize = 0
-
-    override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
-      _uncompressedSize += BOOLEAN.defaultSize
-    }
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      to.putInt(BooleanBitSet.typeId)
-        // Total element count (1 byte per Boolean value)
-        .putInt(from.remaining)
-
-      while (from.remaining >= BITS_PER_LONG) {
-        var word = 0: Long
-        var i = 0
-
-        while (i < BITS_PER_LONG) {
-          if (BOOLEAN.extract(from)) {
-            word |= (1: Long) << i
-          }
-          i += 1
-        }
-
-        to.putLong(word)
-      }
-
-      if (from.hasRemaining) {
-        var word = 0: Long
-        var i = 0
-
-        while (from.hasRemaining) {
-          if (BOOLEAN.extract(from)) {
-            word |= (1: Long) << i
-          }
-          i += 1
-        }
-
-        to.putLong(word)
-      }
-
-      to.rewind()
-      to
-    }
-
-    override def uncompressedSize: Int = _uncompressedSize
-
-    override def compressedSize: Int = {
-      val extra = if (_uncompressedSize % BITS_PER_LONG == 0) 0 else 1
-      (_uncompressedSize / BITS_PER_LONG + extra) * 8 + 4
-    }
-  }
-
-  class Decoder(buffer: ByteBuffer) extends compression.Decoder[BooleanType.type] {
-    private val count = ByteBufferHelper.getInt(buffer)
-
-    private var currentWord = 0: Long
-
-    private var visited: Int = 0
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      val bit = visited % BITS_PER_LONG
-
-      visited += 1
-      if (bit == 0) {
-        currentWord = ByteBufferHelper.getLong(buffer)
-      }
-
-      row.setBoolean(ordinal, ((currentWord >> bit) & 1) != 0)
-    }
-
-    override def hasNext: Boolean = visited < count
-  }
-}
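The `compressedSize` above reduces to a 4-byte element count plus one `Long` per 64 values. A small helper expressing the same arithmetic (`booleanBitSetSize` is a hypothetical name, not part of the scheme):

    // n Booleans (1 byte each raw) pack into ceil(n / 64) Longs plus a 4-byte count.
    def booleanBitSetSize(n: Int): Int = {
      val words = (n + BooleanBitSet.BITS_PER_LONG - 1) / BooleanBitSet.BITS_PER_LONG
      words * 8 + 4
    }
    // e.g. 100 Booleans: 100 bytes raw vs. 2 * 8 + 4 = 20 bytes packed.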
-
-private[sql] case object IntDelta extends CompressionScheme {
-  override def typeId: Int = 4
-
-  override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    : compression.Decoder[T] = {
-    new Decoder(buffer, INT).asInstanceOf[compression.Decoder[T]]
-  }
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
-    (new Encoder).asInstanceOf[compression.Encoder[T]]
-  }
-
-  override def supports(columnType: ColumnType[_]): Boolean = columnType == INT
-
-  class Encoder extends compression.Encoder[IntegerType.type] {
-    protected var _compressedSize: Int = 0
-    protected var _uncompressedSize: Int = 0
-
-    override def compressedSize: Int = _compressedSize
-    override def uncompressedSize: Int = _uncompressedSize
-
-    private var prevValue: Int = _
-
-    override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
-      val value = row.getInt(ordinal)
-      val delta = value - prevValue
-
-      _compressedSize += 1
-
-      // If this is the first integer to be compressed, or the delta is out of byte range, then give
-      // up compressing this integer.
-      if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
-        _compressedSize += INT.defaultSize
-      }
-
-      _uncompressedSize += INT.defaultSize
-      prevValue = value
-    }
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      to.putInt(typeId)
-
-      if (from.hasRemaining) {
-        var prev = from.getInt()
-        to.put(Byte.MinValue)
-        to.putInt(prev)
-
-        while (from.hasRemaining) {
-          val current = from.getInt()
-          val delta = current - prev
-          prev = current
-
-          if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
-            to.put(delta.toByte)
-          } else {
-            to.put(Byte.MinValue)
-            to.putInt(current)
-          }
-        }
-      }
-
-      to.rewind().asInstanceOf[ByteBuffer]
-    }
-  }
-
-  class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[IntegerType.type])
-    extends compression.Decoder[IntegerType.type] {
-
-    private var prev: Int = _
-
-    override def hasNext: Boolean = buffer.hasRemaining
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      val delta = buffer.get()
-      prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getInt(buffer)
-      row.setInt(ordinal, prev)
-    }
-  }
-}
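The wire format above is: type ID, then for each value either a single in-range delta byte, or the escape byte `Byte.MinValue` followed by the full 4-byte value (the first value is always escaped). A hedged round-trip sketch with native-order heap buffers, assuming this package's imports plus `java.nio.ByteOrder`:

    val input = Seq(100, 101, 103, 1000000)
    val from = ByteBuffer.allocate(input.size * 4).order(ByteOrder.nativeOrder)
    input.foreach(from.putInt)
    from.rewind()
    // Worst case is 5 bytes per value plus the 4-byte type ID.
    val to = ByteBuffer.allocate(4 + input.size * 5).order(ByteOrder.nativeOrder)
    IntDelta.encoder(INT).compress(from, to)
    to.getInt()                              // skip the type ID written by compress
    val decoder = IntDelta.decoder(to, INT)
    val row = new SpecificMutableRow(Seq(IntegerType))
    val decoded = input.map { _ => decoder.next(row, 0); row.getInt(0) }
    assert(decoded == input)                 // encoded as [esc, 100] [1] [2] [esc, 1000000]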
-
-private[sql] case object LongDelta extends CompressionScheme {
-  override def typeId: Int = 5
-
-  override def decoder[T <: AtomicType](buffer: ByteBuffer, columnType: NativeColumnType[T])
-    : compression.Decoder[T] = {
-    new Decoder(buffer, LONG).asInstanceOf[compression.Decoder[T]]
-  }
-
-  override def encoder[T <: AtomicType](columnType: NativeColumnType[T]): compression.Encoder[T] = {
-    (new Encoder).asInstanceOf[compression.Encoder[T]]
-  }
-
-  override def supports(columnType: ColumnType[_]): Boolean = columnType == LONG
-
-  class Encoder extends compression.Encoder[LongType.type] {
-    protected var _compressedSize: Int = 0
-    protected var _uncompressedSize: Int = 0
-
-    override def compressedSize: Int = _compressedSize
-    override def uncompressedSize: Int = _uncompressedSize
-
-    private var prevValue: Long = _
-
-    override def gatherCompressibilityStats(row: InternalRow, ordinal: Int): Unit = {
-      val value = row.getLong(ordinal)
-      val delta = value - prevValue
-
-      _compressedSize += 1
-
-      // If this is the first long integer to be compressed, or the delta is out of byte range, then
-      // give up compressing this long integer.
-      if (_uncompressedSize == 0 || delta <= Byte.MinValue || delta > Byte.MaxValue) {
-        _compressedSize += LONG.defaultSize
-      }
-
-      _uncompressedSize += LONG.defaultSize
-      prevValue = value
-    }
-
-    override def compress(from: ByteBuffer, to: ByteBuffer): ByteBuffer = {
-      to.putInt(typeId)
-
-      if (from.hasRemaining) {
-        var prev = from.getLong()
-        to.put(Byte.MinValue)
-        to.putLong(prev)
-
-        while (from.hasRemaining) {
-          val current = from.getLong()
-          val delta = current - prev
-          prev = current
-
-          if (Byte.MinValue < delta && delta <= Byte.MaxValue) {
-            to.put(delta.toByte)
-          } else {
-            to.put(Byte.MinValue)
-            to.putLong(current)
-          }
-        }
-      }
-
-      to.rewind().asInstanceOf[ByteBuffer]
-    }
-  }
-
-  class Decoder(buffer: ByteBuffer, columnType: NativeColumnType[LongType.type])
-    extends compression.Decoder[LongType.type] {
-
-    private var prev: Long = _
-
-    override def hasNext: Boolean = buffer.hasRemaining
-
-    override def next(row: MutableRow, ordinal: Int): Unit = {
-      val delta = buffer.get()
-      prev = if (delta > Byte.MinValue) prev + delta else ByteBufferHelper.getLong(buffer)
-      row.setLong(ordinal, prev)
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
index f85aeb1..293fcfe 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.locks.ReentrantReadWriteLock
 import org.apache.spark.Logging
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.storage.StorageLevel.MEMORY_AND_DISK

http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 3d4ce63..f67c951 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.planning._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
 import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
 import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
 import org.apache.spark.sql.{Strategy, execution}
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
new file mode 100644
index 0000000..fee36f6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala
@@ -0,0 +1,148 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.expressions.{MutableRow, UnsafeArrayData, UnsafeMapData, UnsafeRow}
+import org.apache.spark.sql.execution.columnar.compression.CompressibleColumnAccessor
+import org.apache.spark.sql.types._
+
+/**
+ * An `Iterator`-like trait used to extract values from a columnar byte buffer. When a value is
+ * extracted from the buffer, instead of directly returning it, the value is set into some field of
+ * a [[MutableRow]]. In this way, boxing cost can be avoided by leveraging the setter methods
+ * for primitive values provided by [[MutableRow]].
+ */
+private[columnar] trait ColumnAccessor {
+  initialize()
+
+  protected def initialize()
+
+  def hasNext: Boolean
+
+  def extractTo(row: MutableRow, ordinal: Int)
+
+  protected def underlyingBuffer: ByteBuffer
+}
+
+private[columnar] abstract class BasicColumnAccessor[JvmType](
+    protected val buffer: ByteBuffer,
+    protected val columnType: ColumnType[JvmType])
+  extends ColumnAccessor {
+
+  protected def initialize() {}
+
+  override def hasNext: Boolean = buffer.hasRemaining
+
+  override def extractTo(row: MutableRow, ordinal: Int): Unit = {
+    extractSingle(row, ordinal)
+  }
+
+  def extractSingle(row: MutableRow, ordinal: Int): Unit = {
+    columnType.extract(buffer, row, ordinal)
+  }
+
+  protected def underlyingBuffer = buffer
+}
+
+private[columnar] class NullColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[Any](buffer, NULL)
+  with NullableColumnAccessor
+
+private[columnar] abstract class NativeColumnAccessor[T <: AtomicType](
+    override protected val buffer: ByteBuffer,
+    override protected val columnType: NativeColumnType[T])
+  extends BasicColumnAccessor(buffer, columnType)
+  with NullableColumnAccessor
+  with CompressibleColumnAccessor[T]
+
+private[columnar] class BooleanColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, BOOLEAN)
+
+private[columnar] class ByteColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, BYTE)
+
+private[columnar] class ShortColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, SHORT)
+
+private[columnar] class IntColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, INT)
+
+private[columnar] class LongColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, LONG)
+
+private[columnar] class FloatColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, FLOAT)
+
+private[columnar] class DoubleColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, DOUBLE)
+
+private[columnar] class StringColumnAccessor(buffer: ByteBuffer)
+  extends NativeColumnAccessor(buffer, STRING)
+
+private[columnar] class BinaryColumnAccessor(buffer: ByteBuffer)
+  extends BasicColumnAccessor[Array[Byte]](buffer, BINARY)
+  with NullableColumnAccessor
+
+private[columnar] class CompactDecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+  extends NativeColumnAccessor(buffer, COMPACT_DECIMAL(dataType))
+
+private[columnar] class DecimalColumnAccessor(buffer: ByteBuffer, dataType: DecimalType)
+  extends BasicColumnAccessor[Decimal](buffer, LARGE_DECIMAL(dataType))
+  with NullableColumnAccessor
+
+private[columnar] class StructColumnAccessor(buffer: ByteBuffer, dataType: StructType)
+  extends BasicColumnAccessor[UnsafeRow](buffer, STRUCT(dataType))
+  with NullableColumnAccessor
+
+private[columnar] class ArrayColumnAccessor(buffer: ByteBuffer, dataType: ArrayType)
+  extends BasicColumnAccessor[UnsafeArrayData](buffer, ARRAY(dataType))
+  with NullableColumnAccessor
+
+private[columnar] class MapColumnAccessor(buffer: ByteBuffer, dataType: MapType)
+  extends BasicColumnAccessor[UnsafeMapData](buffer, MAP(dataType))
+  with NullableColumnAccessor
+
+private[columnar] object ColumnAccessor {
+  def apply(dataType: DataType, buffer: ByteBuffer): ColumnAccessor = {
+    val buf = buffer.order(ByteOrder.nativeOrder)
+
+    dataType match {
+      case NullType => new NullColumnAccessor(buf)
+      case BooleanType => new BooleanColumnAccessor(buf)
+      case ByteType => new ByteColumnAccessor(buf)
+      case ShortType => new ShortColumnAccessor(buf)
+      case IntegerType | DateType => new IntColumnAccessor(buf)
+      case LongType | TimestampType => new LongColumnAccessor(buf)
+      case FloatType => new FloatColumnAccessor(buf)
+      case DoubleType => new DoubleColumnAccessor(buf)
+      case StringType => new StringColumnAccessor(buf)
+      case BinaryType => new BinaryColumnAccessor(buf)
+      case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+        new CompactDecimalColumnAccessor(buf, dt)
+      case dt: DecimalType => new DecimalColumnAccessor(buf, dt)
+      case struct: StructType => new StructColumnAccessor(buf, struct)
+      case array: ArrayType => new ArrayColumnAccessor(buf, array)
+      case map: MapType => new MapColumnAccessor(buf, map)
+      case udt: UserDefinedType[_] => ColumnAccessor(udt.sqlType, buffer)
+      case other =>
+        throw new Exception(s"not supported type: $other")
+    }
+  }
+}
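The intended read pattern for these accessors is a tight loop that writes each value into a reused mutable row. A sketch, where `builtBuffer` stands for a buffer produced by the companion ColumnBuilder (next file) and `process` is a hypothetical consumer:

    val accessor = ColumnAccessor(IntegerType, builtBuffer)
    val row = new SpecificMutableRow(Seq(IntegerType))
    while (accessor.hasNext) {
      accessor.extractTo(row, 0) // sets row(0) in place; no boxing of the Int
      process(row.getInt(0))     // hypothetical consumer
    }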
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
new file mode 100644
index 0000000..7e26f19
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.nio.{ByteBuffer, ByteOrder}
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.columnar.ColumnBuilder._
+import org.apache.spark.sql.execution.columnar.compression.{AllCompressionSchemes, CompressibleColumnBuilder}
+import org.apache.spark.sql.types._
+
+private[columnar] trait ColumnBuilder {
+  /**
+   * Initializes with an approximate lower bound on the expected number of elements in this column.
+   */
+  def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false)
+
+  /**
+   * Appends `row(ordinal)` to the column builder.
+   */
+  def appendFrom(row: InternalRow, ordinal: Int)
+
+  /**
+   * Column statistics information
+   */
+  def columnStats: ColumnStats
+
+  /**
+   * Returns the final columnar byte buffer.
+   */
+  def build(): ByteBuffer
+}
+
+private[columnar] class BasicColumnBuilder[JvmType](
+    val columnStats: ColumnStats,
+    val columnType: ColumnType[JvmType])
+  extends ColumnBuilder {
+
+  protected var columnName: String = _
+
+  protected var buffer: ByteBuffer = _
+
+  override def initialize(
+      initialSize: Int,
+      columnName: String = "",
+      useCompression: Boolean = false): Unit = {
+
+    val size = if (initialSize == 0) DEFAULT_INITIAL_BUFFER_SIZE else initialSize
+    this.columnName = columnName
+
+    buffer = ByteBuffer.allocate(size * columnType.defaultSize)
+    buffer.order(ByteOrder.nativeOrder())
+  }
+
+  override def appendFrom(row: InternalRow, ordinal: Int): Unit = {
+    buffer = ensureFreeSpace(buffer, columnType.actualSize(row, ordinal))
+    columnType.append(row, ordinal, buffer)
+  }
+
+  override def build(): ByteBuffer = {
+    if (buffer.capacity() > buffer.position() * 1.1) {
+      // trim the buffer
+      buffer = ByteBuffer
+        .allocate(buffer.position())
+        .order(ByteOrder.nativeOrder())
+        .put(buffer.array(), 0, buffer.position())
+    }
+    buffer.flip().asInstanceOf[ByteBuffer]
+  }
+}
+
+private[columnar] class NullColumnBuilder
+  extends BasicColumnBuilder[Any](new ObjectColumnStats(NullType), NULL)
+  with NullableColumnBuilder
+
+private[columnar] abstract class ComplexColumnBuilder[JvmType](
+    columnStats: ColumnStats,
+    columnType: ColumnType[JvmType])
+  extends BasicColumnBuilder[JvmType](columnStats, columnType)
+  with NullableColumnBuilder
+
+private[columnar] abstract class NativeColumnBuilder[T <: AtomicType](
+    override val columnStats: ColumnStats,
+    override val columnType: NativeColumnType[T])
+  extends BasicColumnBuilder[T#InternalType](columnStats, columnType)
+  with NullableColumnBuilder
+  with AllCompressionSchemes
+  with CompressibleColumnBuilder[T]
+
+private[columnar]
+class BooleanColumnBuilder extends NativeColumnBuilder(new BooleanColumnStats, BOOLEAN)
+
+private[columnar]
+class ByteColumnBuilder extends NativeColumnBuilder(new ByteColumnStats, BYTE)
+
+private[columnar] class ShortColumnBuilder extends NativeColumnBuilder(new ShortColumnStats, SHORT)
+
+private[columnar] class IntColumnBuilder extends NativeColumnBuilder(new IntColumnStats, INT)
+
+private[columnar] class LongColumnBuilder extends NativeColumnBuilder(new LongColumnStats, LONG)
+
+private[columnar] class FloatColumnBuilder extends NativeColumnBuilder(new FloatColumnStats, FLOAT)
+
+private[columnar]
+class DoubleColumnBuilder extends NativeColumnBuilder(new DoubleColumnStats, DOUBLE)
+
+private[columnar]
+class StringColumnBuilder extends NativeColumnBuilder(new StringColumnStats, STRING)
+
+private[columnar]
+class BinaryColumnBuilder extends ComplexColumnBuilder(new BinaryColumnStats, BINARY)
+
+private[columnar] class CompactDecimalColumnBuilder(dataType: DecimalType)
+  extends NativeColumnBuilder(new DecimalColumnStats(dataType), COMPACT_DECIMAL(dataType))
+
+private[columnar] class DecimalColumnBuilder(dataType: DecimalType)
+  extends ComplexColumnBuilder(new DecimalColumnStats(dataType), LARGE_DECIMAL(dataType))
+
+private[columnar] class StructColumnBuilder(dataType: StructType)
+  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), STRUCT(dataType))
+
+private[columnar] class ArrayColumnBuilder(dataType: ArrayType)
+  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), ARRAY(dataType))
+
+private[columnar] class MapColumnBuilder(dataType: MapType)
+  extends ComplexColumnBuilder(new ObjectColumnStats(dataType), MAP(dataType))
+
+private[columnar] object ColumnBuilder {
+  val DEFAULT_INITIAL_BUFFER_SIZE = 128 * 1024
+  val MAX_BATCH_SIZE_IN_BYTE = 4 * 1024 * 1024L
+
+  private[columnar] def ensureFreeSpace(orig: ByteBuffer, size: Int) = {
+    if (orig.remaining >= size) {
+      orig
+    } else {
+      // grow in steps of initial size
+      val capacity = orig.capacity()
+      val newSize = capacity + size.max(capacity)
+      val pos = orig.position()
+
+      ByteBuffer
+        .allocate(newSize)
+        .order(ByteOrder.nativeOrder())
+        .put(orig.array(), 0, pos)
+    }
+  }
+
+  def apply(
+      dataType: DataType,
+      initialSize: Int = 0,
+      columnName: String = "",
+      useCompression: Boolean = false): ColumnBuilder = {
+    val builder: ColumnBuilder = dataType match {
+      case NullType => new NullColumnBuilder
+      case BooleanType => new BooleanColumnBuilder
+      case ByteType => new ByteColumnBuilder
+      case ShortType => new ShortColumnBuilder
+      case IntegerType | DateType => new IntColumnBuilder
+      case LongType | TimestampType => new LongColumnBuilder
+      case FloatType => new FloatColumnBuilder
+      case DoubleType => new DoubleColumnBuilder
+      case StringType => new StringColumnBuilder
+      case BinaryType => new BinaryColumnBuilder
+      case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS =>
+        new CompactDecimalColumnBuilder(dt)
+      case dt: DecimalType => new DecimalColumnBuilder(dt)
+      case struct: StructType => new StructColumnBuilder(struct)
+      case array: ArrayType => new ArrayColumnBuilder(array)
+      case map: MapType => new MapColumnBuilder(map)
+      case udt: UserDefinedType[_] =>
+        return apply(udt.sqlType, initialSize, columnName, useCompression)
+      case other =>
+        throw new Exception(s"not supported type: $other")
+    }
+
+    builder.initialize(initialSize, columnName, useCompression)
+    builder
+  }
+}
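A minimal build-then-read round trip tying this file to ColumnAccessor.scala above (a sketch, assuming a single Int column and `SpecificMutableRow` in scope):

    val builder = ColumnBuilder(IntegerType, initialSize = 4, useCompression = true)
    val row = new SpecificMutableRow(Seq(IntegerType))
    Seq(1, 2, 3).foreach { v => row.setInt(0, v); builder.appendFrom(row, 0) }
    val accessor = ColumnAccessor(IntegerType, builder.build())
    accessor.extractTo(row, 0)
    assert(row.getInt(0) == 1)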
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
new file mode 100644
index 0000000..c52ee9f
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnStats.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, Attribute, AttributeMap, AttributeReference}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+private[columnar] class ColumnStatisticsSchema(a: Attribute) extends Serializable {
+  val upperBound = AttributeReference(a.name + ".upperBound", a.dataType, nullable = true)()
+  val lowerBound = AttributeReference(a.name + ".lowerBound", a.dataType, nullable = true)()
+  val nullCount = AttributeReference(a.name + ".nullCount", IntegerType, nullable = false)()
+  val count = AttributeReference(a.name + ".count", IntegerType, nullable = false)()
+  val sizeInBytes = AttributeReference(a.name + ".sizeInBytes", LongType, nullable = false)()
+
+  val schema = Seq(lowerBound, upperBound, nullCount, count, sizeInBytes)
+}
+
+private[columnar] class PartitionStatistics(tableSchema: Seq[Attribute]) extends Serializable {
+  val (forAttribute, schema) = {
+    val allStats = tableSchema.map(a => a -> new ColumnStatisticsSchema(a))
+    (AttributeMap(allStats), allStats.map(_._2.schema).foldLeft(Seq.empty[Attribute])(_ ++ _))
+  }
+}
+
+/**
+ * Used to collect statistical information when building in-memory columns.
+ *
+ * NOTE: we intentionally avoid using `Ordering[T]` to compare values here because `Ordering[T]`
+ * brings a significant performance penalty.
+ */
+private[columnar] sealed trait ColumnStats extends Serializable {
+  protected var count = 0
+  protected var nullCount = 0
+  private[columnar] var sizeInBytes = 0L
+
+  /**
+   * Gathers statistics information from `row(ordinal)`.
+   */
+  def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    if (row.isNullAt(ordinal)) {
+      nullCount += 1
+      // 4 bytes for null position
+      sizeInBytes += 4
+    }
+    count += 1
+  }
+
+  /**
+   * Column statistics represented as a single row, currently including closed lower bound, closed
+   * upper bound and null count.
+   */
+  def collectedStatistics: GenericInternalRow
+}
+
+/**
+ * A no-op ColumnStats only used for testing purposes.
+ */
+private[columnar] class NoopColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = super.gatherStats(row, ordinal)
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](null, null, nullCount, count, 0L))
+}
+
+private[columnar] class BooleanColumnStats extends ColumnStats {
+  protected var upper = false
+  protected var lower = true
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getBoolean(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += BOOLEAN.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class ByteColumnStats extends ColumnStats {
+  protected var upper = Byte.MinValue
+  protected var lower = Byte.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getByte(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += BYTE.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class ShortColumnStats extends ColumnStats {
+  protected var upper = Short.MinValue
+  protected var lower = Short.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getShort(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += SHORT.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class IntColumnStats extends ColumnStats {
+  protected var upper = Int.MinValue
+  protected var lower = Int.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getInt(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += INT.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class LongColumnStats extends ColumnStats {
+  protected var upper = Long.MinValue
+  protected var lower = Long.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getLong(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += LONG.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class FloatColumnStats extends ColumnStats {
+  protected var upper = Float.MinValue
+  protected var lower = Float.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getFloat(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += FLOAT.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class DoubleColumnStats extends ColumnStats {
+  protected var upper = Double.MinValue
+  protected var lower = Double.MaxValue
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getDouble(ordinal)
+      if (value > upper) upper = value
+      if (value < lower) lower = value
+      sizeInBytes += DOUBLE.defaultSize
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class StringColumnStats extends ColumnStats {
+  protected var upper: UTF8String = null
+  protected var lower: UTF8String = null
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getUTF8String(ordinal)
+      if (upper == null || value.compareTo(upper) > 0) upper = value.clone()
+      if (lower == null || value.compareTo(lower) < 0) lower = value.clone()
+      sizeInBytes += STRING.actualSize(row, ordinal)
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class BinaryColumnStats extends ColumnStats {
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      sizeInBytes += BINARY.actualSize(row, ordinal)
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class DecimalColumnStats(precision: Int, scale: Int) extends ColumnStats {
+  def this(dt: DecimalType) = this(dt.precision, dt.scale)
+
+  protected var upper: Decimal = null
+  protected var lower: Decimal = null
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      val value = row.getDecimal(ordinal, precision, scale)
+      if (upper == null || value.compareTo(upper) > 0) upper = value
+      if (lower == null || value.compareTo(lower) < 0) lower = value
+      // TODO: this is not right for DecimalType with precision > 18
+      sizeInBytes += 8
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](lower, upper, nullCount, count, sizeInBytes))
+}
+
+private[columnar] class ObjectColumnStats(dataType: DataType) extends ColumnStats {
+  val columnType = ColumnType(dataType)
+
+  override def gatherStats(row: InternalRow, ordinal: Int): Unit = {
+    super.gatherStats(row, ordinal)
+    if (!row.isNullAt(ordinal)) {
+      sizeInBytes += columnType.actualSize(row, ordinal)
+    }
+  }
+
+  override def collectedStatistics: GenericInternalRow =
+    new GenericInternalRow(Array[Any](null, null, nullCount, count, sizeInBytes))
+}
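Usage sketch for the stats classes above, interleaving non-null and null values (assumes `SpecificMutableRow` is in scope):

    val stats = new IntColumnStats
    val row = new SpecificMutableRow(Seq(IntegerType))
    Seq(3, 1, 2).foreach { v => row.setInt(0, v); stats.gatherStats(row, 0) }
    row.setNullAt(0)
    stats.gatherStats(row, 0)
    // collectedStatistics now holds: lower = 1, upper = 3, nullCount = 1,
    // count = 4, sizeInBytes = 3 * 4 + 4 (each null is charged 4 bytes).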
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
new file mode 100644
index 0000000..c9f2329
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala
@@ -0,0 +1,689 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.columnar
+
+import java.math.{BigDecimal, BigInteger}
+import java.nio.ByteBuffer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.types.UTF8String
+
+
+/**
+ * A helper class for fast reading Int/Long/Float/Double from ByteBuffer in native order.
+ *
+ * Note: There is not much difference between ByteBuffer.getByte/getShort and
+ * Unsafe.getByte/getShort, so we do not have helper methods for them.
+ *
+ * The unrolling (building columnar cache) is already slow, putLong/putDouble will not help much,
+ * so we do not have helper methods for them.
+ *
+ * WARNING: This only works with HeapByteBuffer
+ */
+private[columnar] object ByteBufferHelper {
+  def getInt(buffer: ByteBuffer): Int = {
+    val pos = buffer.position()
+    buffer.position(pos + 4)
+    Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+  }
+
+  def getLong(buffer: ByteBuffer): Long = {
+    val pos = buffer.position()
+    buffer.position(pos + 8)
+    Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+  }
+
+  def getFloat(buffer: ByteBuffer): Float = {
+    val pos = buffer.position()
+    buffer.position(pos + 4)
+    Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+  }
+
+  def getDouble(buffer: ByteBuffer): Double = {
+    val pos = buffer.position()
+    buffer.position(pos + 8)
+    Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
+  }
+}
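Usage sketch: the helper advances the buffer's position manually and reads through `Platform`, skipping `ByteBuffer`'s per-call bounds and order checks; it is only correct for heap buffers that are already in native byte order:

    val buf = java.nio.ByteBuffer.allocate(8).order(java.nio.ByteOrder.nativeOrder)
    buf.putInt(42).putInt(7).rewind()
    assert(ByteBufferHelper.getInt(buf) == 42 && ByteBufferHelper.getInt(buf) == 7)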
+
+/**
+ * An abstract class that represents the type of a column. Used to append/extract Java objects
+ * into/from the underlying [[ByteBuffer]] of a column.
+ *
+ * @tparam JvmType Underlying Java type to represent the elements.
+ */
+private[columnar] sealed abstract class ColumnType[JvmType] {
+
+  // The catalyst data type of this column.
+  def dataType: DataType
+
+  // Default size in bytes for one element of type T (e.g. 4 for `Int`).
+  def defaultSize: Int
+
+  /**
+   * Extracts a value out of the buffer at the buffer's current position.
+   */
+  def extract(buffer: ByteBuffer): JvmType
+
+  /**
+   * Extracts a value out of the buffer at the buffer's current position and stores in
+   * `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs whenever
+   * possible.
+   */
+  def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    setField(row, ordinal, extract(buffer))
+  }
+
+  /**
+   * Appends the given value v of type T into the given ByteBuffer.
+   */
+  def append(v: JvmType, buffer: ByteBuffer): Unit
+
+  /**
+   * Appends `row(ordinal)` of type T into the given ByteBuffer. Subclasses should override this
+   * method to avoid boxing/unboxing costs whenever possible.
+   */
+  def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    append(getField(row, ordinal), buffer)
+  }
+
+  /**
+   * Returns the size of the value `row(ordinal)`. This is used to calculate the size of variable
+   * length types such as byte arrays and strings.
+   */
+  def actualSize(row: InternalRow, ordinal: Int): Int = defaultSize
+
+  /**
+   * Returns `row(ordinal)`. Subclasses should override this method to avoid boxing/unboxing costs
+   * whenever possible.
+   */
+  def getField(row: InternalRow, ordinal: Int): JvmType
+
+  /**
+   * Sets `row(ordinal)` to `field`. Subclasses should override this method to avoid boxing/unboxing
+   * costs whenever possible.
+   */
+  def setField(row: MutableRow, ordinal: Int, value: JvmType): Unit
+
+  /**
+   * Copies `from(fromOrdinal)` to `to(toOrdinal)`. Subclasses should override this method to avoid
+   * boxing/unboxing costs whenever possible.
+   */
+  def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int): Unit = {
+    setField(to, toOrdinal, getField(from, fromOrdinal))
+  }
+
+  /**
+   * Creates a duplicated copy of the value.
+   */
+  def clone(v: JvmType): JvmType = v
+
+  override def toString: String = getClass.getSimpleName.stripSuffix("$")
+}
+
+private[columnar] object NULL extends ColumnType[Any] {
+
+  override def dataType: DataType = NullType
+  override def defaultSize: Int = 0
+  override def append(v: Any, buffer: ByteBuffer): Unit = {}
+  override def extract(buffer: ByteBuffer): Any = null
+  override def setField(row: MutableRow, ordinal: Int, value: Any): Unit = row.setNullAt(ordinal)
+  override def getField(row: InternalRow, ordinal: Int): Any = null
+}
+
+private[columnar] abstract class NativeColumnType[T <: AtomicType](
+    val dataType: T,
+    val defaultSize: Int)
+  extends ColumnType[T#InternalType] {
+
+  /**
+   * Scala TypeTag. Can be used to create primitive arrays and hash tables.
+   */
+  def scalaTag: TypeTag[dataType.InternalType] = dataType.tag
+}
+
+private[columnar] object INT extends NativeColumnType(IntegerType, 4) {
+  override def append(v: Int, buffer: ByteBuffer): Unit = {
+    buffer.putInt(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.putInt(row.getInt(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Int = {
+    ByteBufferHelper.getInt(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
+    row.setInt(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Int = row.getInt(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setInt(toOrdinal, from.getInt(fromOrdinal))
+  }
+}
+
+private[columnar] object LONG extends NativeColumnType(LongType, 8) {
+  override def append(v: Long, buffer: ByteBuffer): Unit = {
+    buffer.putLong(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.putLong(row.getLong(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Long = {
+    ByteBufferHelper.getLong(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
+    row.setLong(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Long = row.getLong(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setLong(toOrdinal, from.getLong(fromOrdinal))
+  }
+}
+
+private[columnar] object FLOAT extends NativeColumnType(FloatType, 4) {
+  override def append(v: Float, buffer: ByteBuffer): Unit = {
+    buffer.putFloat(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.putFloat(row.getFloat(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Float = {
+    ByteBufferHelper.getFloat(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer))
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
+    row.setFloat(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Float = row.getFloat(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setFloat(toOrdinal, from.getFloat(fromOrdinal))
+  }
+}
+
+private[columnar] object DOUBLE extends NativeColumnType(DoubleType, 8) {
+  override def append(v: Double, buffer: ByteBuffer): Unit = {
+    buffer.putDouble(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.putDouble(row.getDouble(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Double = {
+    ByteBufferHelper.getDouble(buffer)
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer))
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
+    row.setDouble(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Double = row.getDouble(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setDouble(toOrdinal, from.getDouble(fromOrdinal))
+  }
+}
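Round-trip sketch for the primitive column types above (INT shown; LONG/FLOAT/DOUBLE follow the same pattern), assuming `SpecificMutableRow` is in scope:

    val buffer = ByteBuffer.allocate(INT.defaultSize).order(java.nio.ByteOrder.nativeOrder)
    val row = new SpecificMutableRow(Seq(IntegerType))
    row.setInt(0, 42)
    INT.append(row, 0, buffer)
    buffer.rewind()
    INT.extract(buffer, row, 0) // writes straight into row(0); no boxing
    assert(row.getInt(0) == 42)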
+
+private[columnar] object BOOLEAN extends NativeColumnType(BooleanType, 1) {
+  override def append(v: Boolean, buffer: ByteBuffer): Unit = {
+    buffer.put(if (v) 1: Byte else 0: Byte)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.put(if (row.getBoolean(ordinal)) 1: Byte else 0: Byte)
+  }
+
+  override def extract(buffer: ByteBuffer): Boolean = buffer.get() == 1
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setBoolean(ordinal, buffer.get() == 1)
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Boolean): Unit = {
+    row.setBoolean(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Boolean = row.getBoolean(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setBoolean(toOrdinal, from.getBoolean(fromOrdinal))
+  }
+}
+
+private[columnar] object BYTE extends NativeColumnType(ByteType, 1) {
+  override def append(v: Byte, buffer: ByteBuffer): Unit = {
+    buffer.put(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.put(row.getByte(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Byte = {
+    buffer.get()
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setByte(ordinal, buffer.get())
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Byte): Unit = {
+    row.setByte(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Byte = row.getByte(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setByte(toOrdinal, from.getByte(fromOrdinal))
+  }
+}
+
+private[columnar] object SHORT extends NativeColumnType(ShortType, 2) {
+  override def append(v: Short, buffer: ByteBuffer): Unit = {
+    buffer.putShort(v)
+  }
+
+  override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = {
+    buffer.putShort(row.getShort(ordinal))
+  }
+
+  override def extract(buffer: ByteBuffer): Short = {
+    buffer.getShort()
+  }
+
+  override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
+    row.setShort(ordinal, buffer.getShort())
+  }
+
+  override def setField(row: MutableRow, ordinal: Int, value: Short): Unit = {
+    row.setShort(ordinal, value)
+  }
+
+  override def getField(row: InternalRow, ordinal: Int): Short = row.getShort(ordinal)
+
+  override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) {
+    to.setShort(toOrdinal, from.getShort(fromOrdinal))
+  }
+}
+ */ +private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType] { + + // copy the bytes from ByteBuffer to UnsafeRow + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, buffer.array(), + buffer.arrayOffset() + cursor, numBytes) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + // copy the bytes from UnsafeRow to ByteBuffer + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + row.asInstanceOf[UnsafeRow].writeFieldTo(ordinal, buffer) + } else { + super.append(row, ordinal, buffer) + } + } +} + +private[columnar] object STRING + extends NativeColumnType(StringType, 8) with DirectCopyColumnType[UTF8String] { + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getUTF8String(ordinal).numBytes() + 4 + } + + override def append(v: UTF8String, buffer: ByteBuffer): Unit = { + buffer.putInt(v.numBytes()) + v.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UTF8String = { + val length = buffer.getInt() + val cursor = buffer.position() + buffer.position(cursor + length) + UTF8String.fromBytes(buffer.array(), buffer.arrayOffset() + cursor, length) + } + + override def setField(row: MutableRow, ordinal: Int, value: UTF8String): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + row.asInstanceOf[MutableUnsafeRow].writer.write(ordinal, value) + } else { + row.update(ordinal, value.clone()) + } + } + + override def getField(row: InternalRow, ordinal: Int): UTF8String = { + row.getUTF8String(ordinal) + } + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } + + override def clone(v: UTF8String): UTF8String = v.clone() +} + +private[columnar] case class COMPACT_DECIMAL(precision: Int, scale: Int) + extends NativeColumnType(DecimalType(precision, scale), 8) { + + override def extract(buffer: ByteBuffer): Decimal = { + Decimal(ByteBufferHelper.getLong(buffer), precision, scale) + } + + override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = { + if (row.isInstanceOf[MutableUnsafeRow]) { + // copy it as Long + row.setLong(ordinal, ByteBufferHelper.getLong(buffer)) + } else { + setField(row, ordinal, extract(buffer)) + } + } + + override def append(v: Decimal, buffer: ByteBuffer): Unit = { + buffer.putLong(v.toUnscaledLong) + } + + override def append(row: InternalRow, ordinal: Int, buffer: ByteBuffer): Unit = { + if (row.isInstanceOf[UnsafeRow]) { + // copy it as Long + buffer.putLong(row.getLong(ordinal)) + } else { + append(getField(row, ordinal), buffer) + } + } + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def copyField(from: InternalRow, fromOrdinal: Int, to: MutableRow, toOrdinal: Int) { + setField(to, toOrdinal, getField(from, fromOrdinal)) + } +} + +private[columnar] object COMPACT_DECIMAL { + def apply(dt: DecimalType): COMPACT_DECIMAL = { + COMPACT_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize: Int) + 
extends ColumnType[JvmType] with DirectCopyColumnType[JvmType] { + + def serialize(value: JvmType): Array[Byte] + def deserialize(bytes: Array[Byte]): JvmType + + override def append(v: JvmType, buffer: ByteBuffer): Unit = { + val bytes = serialize(v) + buffer.putInt(bytes.length).put(bytes, 0, bytes.length) + } + + override def extract(buffer: ByteBuffer): JvmType = { + val length = buffer.getInt() + val bytes = new Array[Byte](length) + buffer.get(bytes, 0, length) + deserialize(bytes) + } +} + +private[columnar] object BINARY extends ByteArrayColumnType[Array[Byte]](16) { + + def dataType: DataType = BinaryType + + override def setField(row: MutableRow, ordinal: Int, value: Array[Byte]): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): Array[Byte] = { + row.getBinary(ordinal) + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + row.getBinary(ordinal).length + 4 + } + + def serialize(value: Array[Byte]): Array[Byte] = value + def deserialize(bytes: Array[Byte]): Array[Byte] = bytes +} + +private[columnar] case class LARGE_DECIMAL(precision: Int, scale: Int) + extends ByteArrayColumnType[Decimal](12) { + + override val dataType: DataType = DecimalType(precision, scale) + + override def getField(row: InternalRow, ordinal: Int): Decimal = { + row.getDecimal(ordinal, precision, scale) + } + + override def setField(row: MutableRow, ordinal: Int, value: Decimal): Unit = { + row.setDecimal(ordinal, value, precision) + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).toJavaBigDecimal.unscaledValue().bitLength() / 8 + 1 + } + + override def serialize(value: Decimal): Array[Byte] = { + value.toJavaBigDecimal.unscaledValue().toByteArray + } + + override def deserialize(bytes: Array[Byte]): Decimal = { + val javaDecimal = new BigDecimal(new BigInteger(bytes), scale) + Decimal.apply(javaDecimal, precision, scale) + } +} + +private[columnar] object LARGE_DECIMAL { + def apply(dt: DecimalType): LARGE_DECIMAL = { + LARGE_DECIMAL(dt.precision, dt.scale) + } +} + +private[columnar] case class STRUCT(dataType: StructType) + extends ColumnType[UnsafeRow] with DirectCopyColumnType[UnsafeRow] { + + private val numOfFields: Int = dataType.fields.size + + override def defaultSize: Int = 20 + + override def setField(row: MutableRow, ordinal: Int, value: UnsafeRow): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeRow = { + row.getStruct(ordinal, numOfFields).asInstanceOf[UnsafeRow] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + 4 + getField(row, ordinal).getSizeInBytes + } + + override def append(value: UnsafeRow, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeRow = { + val sizeInBytes = ByteBufferHelper.getInt(buffer) + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + sizeInBytes) + val unsafeRow = new UnsafeRow + unsafeRow.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numOfFields, + sizeInBytes) + unsafeRow + } + + override def clone(v: UnsafeRow): UnsafeRow = v.copy() +} + +private[columnar] case class ARRAY(dataType: ArrayType) + extends ColumnType[UnsafeArrayData] with DirectCopyColumnType[UnsafeArrayData] { + + override def defaultSize: Int = 16 + + override def setField(row: MutableRow, ordinal: Int, value: 
UnsafeArrayData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeArrayData = { + row.getArray(ordinal).asInstanceOf[UnsafeArrayData] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeArray = getField(row, ordinal) + 4 + unsafeArray.getSizeInBytes + } + + override def append(value: UnsafeArrayData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeArrayData = { + val numBytes = buffer.getInt + assert(buffer.hasArray) + val cursor = buffer.position() + buffer.position(cursor + numBytes) + val array = new UnsafeArrayData + array.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + array + } + + override def clone(v: UnsafeArrayData): UnsafeArrayData = v.copy() +} + +private[columnar] case class MAP(dataType: MapType) + extends ColumnType[UnsafeMapData] with DirectCopyColumnType[UnsafeMapData] { + + override def defaultSize: Int = 32 + + override def setField(row: MutableRow, ordinal: Int, value: UnsafeMapData): Unit = { + row.update(ordinal, value) + } + + override def getField(row: InternalRow, ordinal: Int): UnsafeMapData = { + row.getMap(ordinal).asInstanceOf[UnsafeMapData] + } + + override def actualSize(row: InternalRow, ordinal: Int): Int = { + val unsafeMap = getField(row, ordinal) + 4 + unsafeMap.getSizeInBytes + } + + override def append(value: UnsafeMapData, buffer: ByteBuffer): Unit = { + buffer.putInt(value.getSizeInBytes) + value.writeTo(buffer) + } + + override def extract(buffer: ByteBuffer): UnsafeMapData = { + val numBytes = buffer.getInt + val cursor = buffer.position() + buffer.position(cursor + numBytes) + val map = new UnsafeMapData + map.pointTo( + buffer.array(), + Platform.BYTE_ARRAY_OFFSET + buffer.arrayOffset() + cursor, + numBytes) + map + } + + override def clone(v: UnsafeMapData): UnsafeMapData = v.copy() +} + +private[columnar] object ColumnType { + def apply(dataType: DataType): ColumnType[_] = { + dataType match { + case NullType => NULL + case BooleanType => BOOLEAN + case ByteType => BYTE + case ShortType => SHORT + case IntegerType | DateType => INT + case LongType | TimestampType => LONG + case FloatType => FLOAT + case DoubleType => DOUBLE + case StringType => STRING + case BinaryType => BINARY + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => COMPACT_DECIMAL(dt) + case dt: DecimalType => LARGE_DECIMAL(dt) + case arr: ArrayType => ARRAY(arr) + case map: MapType => MAP(map) + case struct: StructType => STRUCT(struct) + case udt: UserDefinedType[_] => apply(udt.sqlType) + case other => + throw new UnsupportedOperationException(s"Unsupported type: $other") + } + } +}
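Note the decimal split in ColumnType.apply above: a DecimalType whose precision fits in a Long (precision <= Decimal.MAX_LONG_DIGITS, i.e. 18) gets the fixed 8-byte COMPACT_DECIMAL, while anything wider falls back to LARGE_DECIMAL, which stores the unscaled value as a length-prefixed byte array. A standalone sketch of the LARGE_DECIMAL round trip using plain java.math types (illustrative only, since Spark's Decimal and the column type objects are internal):

    import java.math.{BigDecimal, BigInteger}

    // serialize: keep only the unscaled value's two's-complement bytes
    // (the scale lives in the LARGE_DECIMAL column type itself)
    def serialize(value: BigDecimal): Array[Byte] =
      value.unscaledValue().toByteArray

    // deserialize: rebuild the decimal by reattaching the scale
    def deserialize(bytes: Array[Byte], scale: Int): BigDecimal =
      new BigDecimal(new BigInteger(bytes), scale)

    val original = new BigDecimal("12345678901234567890.12345")  // precision 25 > 18
    assert(deserialize(serialize(original), original.scale()) == original)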
http://git-wip-us.apache.org/repos/asf/spark/blob/ea1a51fc/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala new file mode 100644 index 0000000..eaafc96 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/GenerateColumnAccessor.scala @@ -0,0 +1,195 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import org.apache.spark.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.{UnsafeRowWriter, CodeFormatter, CodeGenerator} +import org.apache.spark.sql.types._ + +/** + * An iterator that walks through the InternalRows of a CachedBatch. + */ +abstract class ColumnarIterator extends Iterator[InternalRow] { + def initialize(input: Iterator[CachedBatch], columnTypes: Array[DataType], + columnIndexes: Array[Int]): Unit +} + +/** + * A helper class to update the fields of an UnsafeRow, used by ColumnAccessor. + * + * WARNING: These setters MUST be called in increasing order of ordinals. + */ +class MutableUnsafeRow(val writer: UnsafeRowWriter) extends GenericMutableRow(null) { + + override def isNullAt(i: Int): Boolean = writer.isNullAt(i) + override def setNullAt(i: Int): Unit = writer.setNullAt(i) + + override def setBoolean(i: Int, v: Boolean): Unit = writer.write(i, v) + override def setByte(i: Int, v: Byte): Unit = writer.write(i, v) + override def setShort(i: Int, v: Short): Unit = writer.write(i, v) + override def setInt(i: Int, v: Int): Unit = writer.write(i, v) + override def setLong(i: Int, v: Long): Unit = writer.write(i, v) + override def setFloat(i: Int, v: Float): Unit = writer.write(i, v) + override def setDouble(i: Int, v: Double): Unit = writer.write(i, v) + + // the writer will be used directly to avoid creating wrapper objects + override def setDecimal(i: Int, v: Decimal, precision: Int): Unit = + throw new UnsupportedOperationException + override def update(i: Int, v: Any): Unit = throw new UnsupportedOperationException + + // all other methods inherited from GenericMutableRow are not needed +}
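To make the ordering contract concrete, this is roughly how the generated iterator below drives MutableUnsafeRow (a sketch mirroring the generated Java further down; BufferHolder and UnsafeRowWriter are internal codegen classes, and ordinal 0 must be written before ordinal 1 because the writer appends into a single buffer):

    import org.apache.spark.sql.catalyst.expressions.UnsafeRow
    import org.apache.spark.sql.catalyst.expressions.codegen.{BufferHolder, UnsafeRowWriter}
    import org.apache.spark.sql.execution.columnar.MutableUnsafeRow

    val bufferHolder = new BufferHolder()
    val rowWriter = new UnsafeRowWriter()
    val mutableRow = new MutableUnsafeRow(rowWriter)
    val unsafeRow = new UnsafeRow()

    bufferHolder.reset()
    rowWriter.initialize(bufferHolder, 2)   // a row with two fields
    mutableRow.setInt(0, 42)                // ordinal 0 first ...
    mutableRow.setLong(1, 99L)              // ... then ordinal 1, never out of order
    unsafeRow.pointTo(bufferHolder.buffer, 2, bufferHolder.totalSize())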
+ +/** + * Generates bytecode for a [[ColumnarIterator]] for the columnar cache. + */ +object GenerateColumnAccessor extends CodeGenerator[Seq[DataType], ColumnarIterator] with Logging { + + protected def canonicalize(in: Seq[DataType]): Seq[DataType] = in + protected def bind(in: Seq[DataType], inputSchema: Seq[Attribute]): Seq[DataType] = in + + protected def create(columnTypes: Seq[DataType]): ColumnarIterator = { + val ctx = newCodeGenContext() + val numFields = columnTypes.size + val (initializeAccessors, extractors) = columnTypes.zipWithIndex.map { case (dt, index) => + val accessorName = ctx.freshName("accessor") + val accessorCls = dt match { + case NullType => classOf[NullColumnAccessor].getName + case BooleanType => classOf[BooleanColumnAccessor].getName + case ByteType => classOf[ByteColumnAccessor].getName + case ShortType => classOf[ShortColumnAccessor].getName + case IntegerType | DateType => classOf[IntColumnAccessor].getName + case LongType | TimestampType => classOf[LongColumnAccessor].getName + case FloatType => classOf[FloatColumnAccessor].getName + case DoubleType => classOf[DoubleColumnAccessor].getName + case StringType => classOf[StringColumnAccessor].getName + case BinaryType => classOf[BinaryColumnAccessor].getName + case dt: DecimalType if dt.precision <= Decimal.MAX_LONG_DIGITS => + classOf[CompactDecimalColumnAccessor].getName + case dt: DecimalType => classOf[DecimalColumnAccessor].getName + case struct: StructType => classOf[StructColumnAccessor].getName + case array: ArrayType => classOf[ArrayColumnAccessor].getName + case t: MapType => classOf[MapColumnAccessor].getName + } + ctx.addMutableState(accessorCls, accessorName, s"$accessorName = null;") + + val createCode = dt match { + case t if ctx.isPrimitiveType(dt) => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case NullType | StringType | BinaryType => + s"$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder));" + case other => + s"""$accessorName = new $accessorCls(ByteBuffer.wrap(buffers[$index]).order(nativeOrder), + (${dt.getClass.getName}) columnTypes[$index]);""" + } + + val extract = s"$accessorName.extractTo(mutableRow, $index);" + val patch = dt match { + case DecimalType.Fixed(p, s) if p > Decimal.MAX_LONG_DIGITS => + // For large Decimal, it should have 16 bytes for future updates even if it's null now.
+ s""" + if (mutableRow.isNullAt($index)) { + rowWriter.write($index, (Decimal) null, $p, $s); + } + """ + case other => "" + } + (createCode, extract + patch) + }.unzip + + val code = s""" + import java.nio.ByteBuffer; + import java.nio.ByteOrder; + import scala.collection.Iterator; + import org.apache.spark.sql.types.DataType; + import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder; + import org.apache.spark.sql.catalyst.expressions.codegen.UnsafeRowWriter; + import org.apache.spark.sql.execution.columnar.MutableUnsafeRow; + + public SpecificColumnarIterator generate($exprType[] expr) { + return new SpecificColumnarIterator(); + } + + class SpecificColumnarIterator extends ${classOf[ColumnarIterator].getName} { + + private ByteOrder nativeOrder = null; + private byte[][] buffers = null; + private UnsafeRow unsafeRow = new UnsafeRow(); + private BufferHolder bufferHolder = new BufferHolder(); + private UnsafeRowWriter rowWriter = new UnsafeRowWriter(); + private MutableUnsafeRow mutableRow = null; + + private int currentRow = 0; + private int numRowsInBatch = 0; + + private scala.collection.Iterator input = null; + private DataType[] columnTypes = null; + private int[] columnIndexes = null; + + ${declareMutableStates(ctx)} + + public SpecificColumnarIterator() { + this.nativeOrder = ByteOrder.nativeOrder(); + this.buffers = new byte[${columnTypes.length}][]; + this.mutableRow = new MutableUnsafeRow(rowWriter); + + ${initMutableStates(ctx)} + } + + public void initialize(Iterator input, DataType[] columnTypes, int[] columnIndexes) { + this.input = input; + this.columnTypes = columnTypes; + this.columnIndexes = columnIndexes; + } + + public boolean hasNext() { + if (currentRow < numRowsInBatch) { + return true; + } + if (!input.hasNext()) { + return false; + } + + ${classOf[CachedBatch].getName} batch = (${classOf[CachedBatch].getName}) input.next(); + currentRow = 0; + numRowsInBatch = batch.numRows(); + for (int i = 0; i < columnIndexes.length; i ++) { + buffers[i] = batch.buffers()[columnIndexes[i]]; + } + ${initializeAccessors.mkString("\n")} + + return hasNext(); + } + + public InternalRow next() { + currentRow += 1; + bufferHolder.reset(); + rowWriter.initialize(bufferHolder, $numFields); + ${extractors.mkString("\n")} + unsafeRow.pointTo(bufferHolder.buffer, $numFields, bufferHolder.totalSize()); + return unsafeRow; + } + }""" + + logDebug(s"Generated ColumnarIterator: ${CodeFormatter.format(code)}") + + compile(code).generate(ctx.references.toArray).asInstanceOf[ColumnarIterator] + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org