[SPARK-9480][SQL] add MapData and cleanup internal row stuff

This PR adds a `MapData` as the internal representation of the map type in Spark SQL, and provides a default implementation backed by just two `ArrayData`s (one for keys, one for values).
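For illustration only (not part of the patch), here is a minimal sketch of how the pieces added in this diff fit together; the `foreach` call mirrors its usage in `generators.scala` below:

```scala
import org.apache.spark.sql.types._

// A MapData is just two parallel ArrayDatas: one for keys, one for values.
val map: MapData = ArrayBasedMapData(Array[Any](1, 2), Array[Any](1.0, 2.0))
assert(map.numElements() == 2)

// Both sides are plain ArrayData, so the specialized extractors apply.
val keys: Array[Int] = map.keyArray().toIntArray()
val values: Array[Double] = map.valueArray().toDoubleArray()

// Iteration is type-directed: callers pass the key/value types explicitly,
// since the generic getter is gone.
map.foreach(IntegerType, DoubleType, (k, v) => println(s"$k -> $v"))
```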
With that in place, we have specialized getters for all internal types, so I removed the generic getter from `ArrayData` and added specialized `toArray` variants to it. I also did some refactoring and cleanup of `InternalRow` and its subclasses.

Author: Wenchen Fan <cloud0...@outlook.com>

Closes #7799 from cloud-fan/map-data and squashes the following commits:

77d482f [Wenchen Fan] fix python
e8f6682 [Wenchen Fan] skip MapData equality check in HiveInspectorSuite
40cc9db [Wenchen Fan] add toString
6e06ec9 [Wenchen Fan] some more cleanup
a90aca1 [Wenchen Fan] add MapData

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/1d59a416
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/1d59a416
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/1d59a416

Branch: refs/heads/master
Commit: 1d59a4162bf5142af270ed7f4b3eab42870c87b7
Parents: d90f2cf
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Sat Aug 1 00:17:15 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Sat Aug 1 00:17:15 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/mllib/linalg/Matrices.scala    |   6 +-
 .../org/apache/spark/mllib/linalg/Vectors.scala |   6 +-
 .../expressions/SpecializedGetters.java         |   6 +
 .../sql/catalyst/expressions/UnsafeRow.java     |  12 +-
 .../sql/catalyst/CatalystTypeConverters.scala   |  79 ++++++----
 .../apache/spark/sql/catalyst/InternalRow.scala | 117 +++++---------
 .../catalyst/expressions/BoundAttribute.scala   |   3 +
 .../spark/sql/catalyst/expressions/Cast.scala   | 101 +++++-------
 .../expressions/GenericSpecializedGetters.scala |  61 ++++++++
 .../sql/catalyst/expressions/Projection.scala   |  69 +++++----
 .../expressions/SpecificMutableRow.scala        |  13 +-
 .../expressions/codegen/CodeGenerator.scala     |  11 +-
 .../codegen/GenerateProjection.scala            |   2 +-
 .../expressions/collectionOperations.scala      |   8 +-
 .../expressions/complexTypeExtractors.scala     |  81 +++++++---
 .../sql/catalyst/expressions/generators.scala   |  29 +++-
 .../spark/sql/catalyst/expressions/rows.scala   |  44 ++----
 .../catalyst/expressions/stringOperations.scala |   2 +-
 .../spark/sql/types/ArrayBasedMapData.scala     |  51 ++++++
 .../org/apache/spark/sql/types/ArrayData.scala  | 155 ++++++++++---------
 .../spark/sql/types/GenericArrayData.scala      | 116 ++++++++++----
 .../org/apache/spark/sql/types/MapData.scala    |  38 +++++
 .../catalyst/expressions/ComplexTypeSuite.scala |   4 -
 .../expressions/UnsafeRowConverterSuite.scala   |   2 +-
 .../spark/sql/execution/debug/package.scala     |  14 +-
 .../apache/spark/sql/execution/pythonUDFs.scala |  31 ++--
 .../apache/spark/sql/json/JacksonParser.scala   |  21 ++-
 .../sql/parquet/CatalystRowConverter.scala      |  16 +-
 .../spark/sql/parquet/ParquetConverter.scala    |   8 +-
 .../spark/sql/parquet/ParquetTableSupport.scala |  23 +--
 .../org/apache/spark/sql/DataFrameSuite.scala   |   1 +
 .../scala/org/apache/spark/sql/RowSuite.scala   |   4 +-
 .../org/apache/spark/sql/SQLQuerySuite.scala    |  34 ++--
 .../scala/org/apache/spark/sql/TestData.scala   |   4 +-
 .../apache/spark/sql/UserDefinedTypeSuite.scala |   2 +-
 .../apache/spark/sql/hive/HiveInspectors.scala  |  89 ++++++-----
 .../hive/execution/InsertIntoHiveTable.scala    |   4 +-
 .../org/apache/spark/sql/hive/hiveUDFs.scala    |   4 +-
 .../apache/spark/sql/hive/orc/OrcRelation.scala |   7 +-
 .../spark/sql/hive/HiveInspectorSuite.scala     |   4 +-
 40 files changed, 750 insertions(+), 532 deletions(-)
----------------------------------------------------------------------
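As a caller-side sketch of the getter changes (using only APIs that appear in this diff): the mllib hunks that open the diff replace boxing round-trips like `toArray.map(_.asInstanceOf[Double])` with the new primitive extractors, and non-primitive element types now go through a type-directed `toArray`/`foreach`:

```scala
import org.apache.spark.sql.types._

val arr: ArrayData = new GenericArrayData(Array[Any](1, 2, 3))

// Primitive element types: specialized extractors, no per-element boxing.
val ints: Array[Int] = arr.toIntArray()

// Other element types: the caller supplies the DataType instead of relying
// on the removed generic get(ordinal).
val boxed: Array[Any] = arr.toArray[Any](IntegerType)
arr.foreach(IntegerType, (i, e) => println(s"element $i = $e"))
```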
http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 88914fa..1c85834 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -179,12 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) + val values = row.getArray(5).toDoubleArray() val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) - val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) + val colPtrs = row.getArray(3).toIntArray() + val rowIndices = row.getArray(4).toIntArray() new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 89a1818..96d1f48 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -209,11 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) - val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) + val indices = row.getArray(2).toIntArray() + val values = row.getArray(3).toDoubleArray() new SparseVector(size, indices, values) case 1 => - val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) + val values = row.getArray(3).toDoubleArray() new DenseVector(values) } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java index e3d3ba7..8f1027f 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/SpecializedGetters.java @@ -19,7 +19,9 @@ package org.apache.spark.sql.catalyst.expressions; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.types.ArrayData; +import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.types.MapData; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -52,4 +54,8 @@ public interface SpecializedGetters { InternalRow getStruct(int ordinal, int numFields); ArrayData getArray(int ordinal); + + MapData getMap(int ordinal); + + Object get(int ordinal, 
DataType dataType); } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 24dc80b..5a19aa8 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -118,6 +118,11 @@ public final class UnsafeRow extends MutableRow { return baseOffset + bitSetWidthInBytes + ordinal * 8L; } + private void assertIndexIsValid(int index) { + assert index >= 0 : "index (" + index + ") should >= 0"; + assert index < numFields : "index (" + index + ") should < " + numFields; + } + ////////////////////////////////////////////////////////////////////////////// // Public methods ////////////////////////////////////////////////////////////////////////////// @@ -163,11 +168,6 @@ public final class UnsafeRow extends MutableRow { pointTo(buf, PlatformDependent.BYTE_ARRAY_OFFSET, numFields, sizeInBytes); } - private void assertIndexIsValid(int index) { - assert index >= 0 : "index (" + index + ") should >= 0"; - assert index < numFields : "index (" + index + ") should < " + numFields; - } - @Override public void setNullAt(int i) { assertIndexIsValid(i); @@ -254,7 +254,7 @@ public final class UnsafeRow extends MutableRow { } @Override - public Object get(int ordinal) { + public Object genericGet(int ordinal) { throw new UnsupportedOperationException(); } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 7ca20fe..c666864 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -23,7 +23,6 @@ import java.sql.{Date, Timestamp} import java.util.{Map => JavaMap} import javax.annotation.Nullable -import scala.collection.mutable.HashMap import scala.language.existentials import org.apache.spark.sql.Row @@ -53,12 +52,6 @@ object CatalystTypeConverters { } } - private def isWholePrimitive(dt: DataType): Boolean = dt match { - case dt if isPrimitive(dt) => true - case MapType(keyType, valueType, _) => isWholePrimitive(keyType) && isWholePrimitive(valueType) - case _ => false - } - private def getConverterForType(dataType: DataType): CatalystTypeConverter[Any, Any, Any] = { val converter = dataType match { case udt: UserDefinedType[_] => UDTConverter(udt) @@ -157,8 +150,6 @@ object CatalystTypeConverters { private[this] val elementConverter = getConverterForType(elementType) - private[this] val isNoChange = isWholePrimitive(elementType) - override def toCatalystImpl(scalaValue: Any): ArrayData = { scalaValue match { case a: Array[_] => @@ -179,10 +170,14 @@ object CatalystTypeConverters { override def toScala(catalystValue: ArrayData): Seq[Any] = { if (catalystValue == null) { null - } else if (isNoChange) { - catalystValue.toArray() + } else if 
(isPrimitive(elementType)) { + catalystValue.toArray[Any](elementType) } else { - catalystValue.toArray().map(elementConverter.toScala) + val result = new Array[Any](catalystValue.numElements()) + catalystValue.foreach(elementType, (i, e) => { + result(i) = elementConverter.toScala(e) + }) + result } } @@ -193,44 +188,58 @@ object CatalystTypeConverters { private case class MapConverter( keyType: DataType, valueType: DataType) - extends CatalystTypeConverter[Any, Map[Any, Any], Map[Any, Any]] { + extends CatalystTypeConverter[Any, Map[Any, Any], MapData] { private[this] val keyConverter = getConverterForType(keyType) private[this] val valueConverter = getConverterForType(valueType) - private[this] val isNoChange = isWholePrimitive(keyType) && isWholePrimitive(valueType) - - override def toCatalystImpl(scalaValue: Any): Map[Any, Any] = scalaValue match { + override def toCatalystImpl(scalaValue: Any): MapData = scalaValue match { case m: Map[_, _] => - m.map { case (k, v) => - keyConverter.toCatalyst(k) -> valueConverter.toCatalyst(v) + val length = m.size + val convertedKeys = new Array[Any](length) + val convertedValues = new Array[Any](length) + + var i = 0 + for ((key, value) <- m) { + convertedKeys(i) = keyConverter.toCatalyst(key) + convertedValues(i) = valueConverter.toCatalyst(value) + i += 1 } + ArrayBasedMapData(convertedKeys, convertedValues) case jmap: JavaMap[_, _] => + val length = jmap.size() + val convertedKeys = new Array[Any](length) + val convertedValues = new Array[Any](length) + + var i = 0 val iter = jmap.entrySet.iterator - val convertedMap: HashMap[Any, Any] = HashMap() while (iter.hasNext) { val entry = iter.next() - val key = keyConverter.toCatalyst(entry.getKey) - convertedMap(key) = valueConverter.toCatalyst(entry.getValue) + convertedKeys(i) = keyConverter.toCatalyst(entry.getKey) + convertedValues(i) = valueConverter.toCatalyst(entry.getValue) + i += 1 } - convertedMap + ArrayBasedMapData(convertedKeys, convertedValues) } - override def toScala(catalystValue: Map[Any, Any]): Map[Any, Any] = { + override def toScala(catalystValue: MapData): Map[Any, Any] = { if (catalystValue == null) { null - } else if (isNoChange) { - catalystValue } else { - catalystValue.map { case (k, v) => - keyConverter.toScala(k) -> valueConverter.toScala(v) - } + val keys = catalystValue.keyArray().toArray[Any](keyType) + val values = catalystValue.valueArray().toArray[Any](valueType) + val convertedKeys = + if (isPrimitive(keyType)) keys else keys.map(keyConverter.toScala) + val convertedValues = + if (isPrimitive(valueType)) values else values.map(valueConverter.toScala) + + convertedKeys.zip(convertedValues).toMap } } override def toScalaImpl(row: InternalRow, column: Int): Map[Any, Any] = - toScala(row.get(column, MapType(keyType, valueType)).asInstanceOf[Map[Any, Any]]) + toScala(row.getMap(column)) } private case class StructConverter( @@ -410,7 +419,17 @@ object CatalystTypeConverters { case r: Row => InternalRow(r.toSeq.map(convertToCatalyst): _*) case arr: Array[Any] => new GenericArrayData(arr.map(convertToCatalyst)) case m: Map[_, _] => - m.map { case (k, v) => (convertToCatalyst(k), convertToCatalyst(v)) }.toMap + val length = m.size + val convertedKeys = new Array[Any](length) + val convertedValues = new Array[Any](length) + + var i = 0 + for ((key, value) <- m) { + convertedKeys(i) = convertToCatalyst(key) + convertedValues(i) = convertToCatalyst(value) + i += 1 + } + ArrayBasedMapData(convertedKeys, convertedValues) case other => other } 
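To make the converter rewrite above concrete, a small usage sketch (assuming the pre-existing `CatalystTypeConverters.createToCatalystConverter` entry point, which this patch does not change): a Scala `Map` now converts to an `ArrayBasedMapData` built from parallel key/value arrays, rather than to another Scala `Map`.

```scala
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.types._

val toCatalyst = CatalystTypeConverters.createToCatalystConverter(
  MapType(IntegerType, DoubleType))

// The converter now returns a MapData instead of a Scala Map.
val catalystMap = toCatalyst(Map(1 -> 1.0, 2 -> 2.0)).asInstanceOf[MapData]
assert(catalystMap.numElements() == 2)
```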
http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index b19bf43..7656d05 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -19,71 +19,25 @@ package org.apache.spark.sql.catalyst import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * An abstract class for row used internal in Spark SQL, which only contain the columns as * internal types. */ -abstract class InternalRow extends Serializable with SpecializedGetters { +// todo: make InternalRow just extends SpecializedGetters, remove generic getter +abstract class InternalRow extends GenericSpecializedGetters with Serializable { def numFields: Int - def get(ordinal: Int): Any = get(ordinal, null) - - def genericGet(ordinal: Int): Any = get(ordinal, null) - - def get(ordinal: Int, dataType: DataType): Any - - def getAs[T](ordinal: Int, dataType: DataType): T = get(ordinal, dataType).asInstanceOf[T] - - override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null - - override def getBoolean(ordinal: Int): Boolean = getAs[Boolean](ordinal, BooleanType) - - override def getByte(ordinal: Int): Byte = getAs[Byte](ordinal, ByteType) - - override def getShort(ordinal: Int): Short = getAs[Short](ordinal, ShortType) - - override def getInt(ordinal: Int): Int = getAs[Int](ordinal, IntegerType) - - override def getLong(ordinal: Int): Long = getAs[Long](ordinal, LongType) - - override def getFloat(ordinal: Int): Float = getAs[Float](ordinal, FloatType) - - override def getDouble(ordinal: Int): Double = getAs[Double](ordinal, DoubleType) - - override def getUTF8String(ordinal: Int): UTF8String = getAs[UTF8String](ordinal, StringType) - - override def getBinary(ordinal: Int): Array[Byte] = getAs[Array[Byte]](ordinal, BinaryType) - - override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = - getAs[Decimal](ordinal, DecimalType(precision, scale)) - - override def getInterval(ordinal: Int): CalendarInterval = - getAs[CalendarInterval](ordinal, CalendarIntervalType) - // This is only use for test and will throw a null pointer exception if the position is null. def getString(ordinal: Int): String = getUTF8String(ordinal).toString - /** - * Returns a struct from ordinal position. - * - * @param ordinal position to get the struct from. - * @param numFields number of fields the struct type has - */ - override def getStruct(ordinal: Int, numFields: Int): InternalRow = - getAs[InternalRow](ordinal, null) - - override def getArray(ordinal: Int): ArrayData = getAs(ordinal, null) - - override def toString: String = s"[${this.mkString(",")}]" + override def toString: String = mkString("[", ",", "]") /** * Make a copy of the current [[InternalRow]] object. */ - def copy(): InternalRow = this + def copy(): InternalRow /** Returns true if there are any NULL values in this row. 
*/ def anyNull: Boolean = { @@ -117,8 +71,8 @@ abstract class InternalRow extends Serializable with SpecializedGetters { return false } if (!isNullAt(i)) { - val o1 = get(i) - val o2 = other.get(i) + val o1 = genericGet(i) + val o2 = other.genericGet(i) o1 match { case b1: Array[Byte] => if (!o2.isInstanceOf[Array[Byte]] || @@ -143,34 +97,6 @@ abstract class InternalRow extends Serializable with SpecializedGetters { true } - /* ---------------------- utility methods for Scala ---------------------- */ - - /** - * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. - */ - def toSeq: Seq[Any] = { - val n = numFields - val values = new Array[Any](n) - var i = 0 - while (i < n) { - values.update(i, get(i)) - i += 1 - } - values.toSeq - } - - /** Displays all elements of this sequence in a string (without a separator). */ - def mkString: String = toSeq.mkString - - /** Displays all elements of this sequence in a string using a separator string. */ - def mkString(sep: String): String = toSeq.mkString(sep) - - /** - * Displays all elements of this traversable or iterator in a string using - * start, end, and separator strings. - */ - def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) - // Custom hashCode function that matches the efficient code generated version. override def hashCode: Int = { var result: Int = 37 @@ -181,7 +107,7 @@ abstract class InternalRow extends Serializable with SpecializedGetters { if (isNullAt(i)) { 0 } else { - get(i) match { + genericGet(i) match { case b: Boolean => if (b) 0 else 1 case b: Byte => b.toInt case s: Short => s.toInt @@ -200,6 +126,35 @@ abstract class InternalRow extends Serializable with SpecializedGetters { } result } + + /* ---------------------- utility methods for Scala ---------------------- */ + + /** + * Return a Scala Seq representing the row. Elements are placed in the same order in the Seq. + */ + // todo: remove this as it needs the generic getter + def toSeq: Seq[Any] = { + val n = numFields + val values = new Array[Any](n) + var i = 0 + while (i < n) { + values.update(i, genericGet(i)) + i += 1 + } + values + } + + /** Displays all elements of this sequence in a string (without a separator). */ + def mkString: String = toSeq.mkString + + /** Displays all elements of this sequence in a string using a separator string. */ + def mkString(sep: String): String = toSeq.mkString(sep) + + /** + * Displays all elements of this traversable or iterator in a string using + * start, end, and separator strings. 
+ */ + def mkString(start: String, sep: String, end: String): String = toSeq.mkString(start, sep, end) } object InternalRow { http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 45709c1..473b9b7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -49,7 +49,10 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) case StringType => input.getUTF8String(ordinal) case BinaryType => input.getBinary(ordinal) case CalendarIntervalType => input.getInterval(ordinal) + case t: DecimalType => input.getDecimal(ordinal, t.precision, t.scale) case t: StructType => input.getStruct(ordinal, t.size) + case _: ArrayType => input.getArray(ordinal) + case _: MapType => input.getMap(ordinal) case _ => input.get(ordinal, dataType) } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 43be11c..88429bb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -361,30 +361,29 @@ case class Cast(child: Expression, dataType: DataType) b => x.numeric.asInstanceOf[Numeric[Any]].toFloat(b) } - private[this] def castArray(from: ArrayType, to: ArrayType): Any => Any = { - val elementCast = cast(from.elementType, to.elementType) + private[this] def castArray(fromType: DataType, toType: DataType): Any => Any = { + val elementCast = cast(fromType, toType) // TODO: Could be faster? 
buildCast[ArrayData](_, array => { - val length = array.numElements() - val values = new Array[Any](length) - var i = 0 - while (i < length) { - if (array.isNullAt(i)) { + val values = new Array[Any](array.numElements()) + array.foreach(fromType, (i, e) => { + if (e == null) { values(i) = null } else { - values(i) = elementCast(array.get(i)) + values(i) = elementCast(e) } - i += 1 - } + }) new GenericArrayData(values) }) } private[this] def castMap(from: MapType, to: MapType): Any => Any = { - val keyCast = cast(from.keyType, to.keyType) - val valueCast = cast(from.valueType, to.valueType) - buildCast[Map[Any, Any]](_, _.map { - case (key, value) => (keyCast(key), if (value == null) null else valueCast(value)) + val keyCast = castArray(from.keyType, to.keyType) + val valueCast = castArray(from.valueType, to.valueType) + buildCast[MapData](_, map => { + val keys = keyCast(map.keyArray()).asInstanceOf[ArrayData] + val values = valueCast(map.valueArray()).asInstanceOf[ArrayData] + new ArrayBasedMapData(keys, values) }) } @@ -420,7 +419,7 @@ case class Cast(child: Expression, dataType: DataType) case FloatType => castToFloat(from) case LongType => castToLong(from) case DoubleType => castToDouble(from) - case array: ArrayType => castArray(from.asInstanceOf[ArrayType], array) + case array: ArrayType => castArray(from.asInstanceOf[ArrayType].elementType, array.elementType) case map: MapType => castMap(from.asInstanceOf[MapType], map) case struct: StructType => castStruct(from.asInstanceOf[StructType], struct) } @@ -461,7 +460,8 @@ case class Cast(child: Expression, dataType: DataType) case LongType => castToLongCode(from) case DoubleType => castToDoubleCode(from) - case array: ArrayType => castArrayCode(from.asInstanceOf[ArrayType], array, ctx) + case array: ArrayType => + castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx) case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx) case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx) } @@ -801,8 +801,8 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castArrayCode( - from: ArrayType, to: ArrayType, ctx: CodeGenContext): CastFunction = { - val elementCast = nullSafeCastFunction(from.elementType, to.elementType, ctx) + fromType: DataType, toType: DataType, ctx: CodeGenContext): CastFunction = { + val elementCast = nullSafeCastFunction(fromType, toType, ctx) val arrayClass = classOf[GenericArrayData].getName val fromElementNull = ctx.freshName("feNull") val fromElementPrim = ctx.freshName("fePrim") @@ -821,10 +821,10 @@ case class Cast(child: Expression, dataType: DataType) $values[$j] = null; } else { boolean $fromElementNull = false; - ${ctx.javaType(from.elementType)} $fromElementPrim = - ${ctx.getValue(c, from.elementType, j)}; + ${ctx.javaType(fromType)} $fromElementPrim = + ${ctx.getValue(c, fromType, j)}; ${castCode(ctx, fromElementPrim, - fromElementNull, toElementPrim, toElementNull, to.elementType, elementCast)} + fromElementNull, toElementPrim, toElementNull, toType, elementCast)} if ($toElementNull) { $values[$j] = null; } else { @@ -837,48 +837,29 @@ case class Cast(child: Expression, dataType: DataType) } private[this] def castMapCode(from: MapType, to: MapType, ctx: CodeGenContext): CastFunction = { - val keyCast = nullSafeCastFunction(from.keyType, to.keyType, ctx) - val valueCast = nullSafeCastFunction(from.valueType, to.valueType, ctx) - - val hashMapClass = classOf[mutable.HashMap[Any, Any]].getName - val fromKeyPrim = 
ctx.freshName("fkp") - val fromKeyNull = ctx.freshName("fkn") - val fromValuePrim = ctx.freshName("fvp") - val fromValueNull = ctx.freshName("fvn") - val toKeyPrim = ctx.freshName("tkp") - val toKeyNull = ctx.freshName("tkn") - val toValuePrim = ctx.freshName("tvp") - val toValueNull = ctx.freshName("tvn") - val result = ctx.freshName("result") + val keysCast = castArrayCode(from.keyType, to.keyType, ctx) + val valuesCast = castArrayCode(from.valueType, to.valueType, ctx) + + val mapClass = classOf[ArrayBasedMapData].getName + + val keys = ctx.freshName("keys") + val convertedKeys = ctx.freshName("convertedKeys") + val convertedKeysNull = ctx.freshName("convertedKeysNull") + + val values = ctx.freshName("values") + val convertedValues = ctx.freshName("convertedValues") + val convertedValuesNull = ctx.freshName("convertedValuesNull") (c, evPrim, evNull) => s""" - final $hashMapClass $result = new $hashMapClass(); - scala.collection.Iterator iter = $c.iterator(); - while (iter.hasNext()) { - scala.Tuple2 kv = (scala.Tuple2) iter.next(); - boolean $fromKeyNull = false; - ${ctx.javaType(from.keyType)} $fromKeyPrim = - (${ctx.boxedType(from.keyType)}) kv._1(); - ${castCode(ctx, fromKeyPrim, - fromKeyNull, toKeyPrim, toKeyNull, to.keyType, keyCast)} - - boolean $fromValueNull = kv._2() == null; - if ($fromValueNull) { - $result.put($toKeyPrim, null); - } else { - ${ctx.javaType(from.valueType)} $fromValuePrim = - (${ctx.boxedType(from.valueType)}) kv._2(); - ${castCode(ctx, fromValuePrim, - fromValueNull, toValuePrim, toValueNull, to.valueType, valueCast)} - if ($toValueNull) { - $result.put($toKeyPrim, null); - } else { - $result.put($toKeyPrim, $toValuePrim); - } - } - } - $evPrim = $result; + final ArrayData $keys = $c.keyArray(); + final ArrayData $values = $c.valueArray(); + ${castCode(ctx, keys, "false", + convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)} + ${castCode(ctx, values, "false", + convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)} + + $evPrim = new $mapClass($convertedKeys, $convertedValues); """ } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala new file mode 100644 index 0000000..6e95792 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/GenericSpecializedGetters.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types.{DataType, MapData, ArrayData, Decimal} +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} + +trait GenericSpecializedGetters extends SpecializedGetters { + + def genericGet(ordinal: Int): Any + + private def getAs[T](ordinal: Int) = genericGet(ordinal).asInstanceOf[T] + + override def isNullAt(ordinal: Int): Boolean = getAs[AnyRef](ordinal) eq null + + override def get(ordinal: Int, elementType: DataType): AnyRef = getAs(ordinal) + + override def getBoolean(ordinal: Int): Boolean = getAs(ordinal) + + override def getByte(ordinal: Int): Byte = getAs(ordinal) + + override def getShort(ordinal: Int): Short = getAs(ordinal) + + override def getInt(ordinal: Int): Int = getAs(ordinal) + + override def getLong(ordinal: Int): Long = getAs(ordinal) + + override def getFloat(ordinal: Int): Float = getAs(ordinal) + + override def getDouble(ordinal: Int): Double = getAs(ordinal) + + override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal) + + override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal) + + override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal) + + override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal) + + override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal) + + override def getArray(ordinal: Int): ArrayData = getAs(ordinal) + + override def getMap(ordinal: Int): MapData = getAs(ordinal) +} http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 7c7664e..d79325a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateUnsafeProjection, GenerateMutableProjection} -import org.apache.spark.sql.types.{Decimal, StructType, DataType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. 
@@ -190,45 +190,55 @@ class JoinedRow extends InternalRow { override def numFields: Int = row1.numFields + row2.numFields - override def getUTF8String(i: Int): UTF8String = { - if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - } - - override def getBinary(i: Int): Array[Byte] = { - if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - } - - override def get(i: Int, dataType: DataType): Any = - if (i < row1.numFields) row1.get(i) else row2.get(i - row1.numFields) + override def genericGet(i: Int): Any = + if (i < row1.numFields) row1.genericGet(i) else row2.genericGet(i - row1.numFields) override def isNullAt(i: Int): Boolean = if (i < row1.numFields) row1.isNullAt(i) else row2.isNullAt(i - row1.numFields) + override def getBoolean(i: Int): Boolean = + if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) + + override def getByte(i: Int): Byte = + if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) + + override def getShort(i: Int): Short = + if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) + override def getInt(i: Int): Int = if (i < row1.numFields) row1.getInt(i) else row2.getInt(i - row1.numFields) override def getLong(i: Int): Long = if (i < row1.numFields) row1.getLong(i) else row2.getLong(i - row1.numFields) + override def getFloat(i: Int): Float = + if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getDouble(i: Int): Double = if (i < row1.numFields) row1.getDouble(i) else row2.getDouble(i - row1.numFields) - override def getBoolean(i: Int): Boolean = - if (i < row1.numFields) row1.getBoolean(i) else row2.getBoolean(i - row1.numFields) + override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { + if (i < row1.numFields) { + row1.getDecimal(i, precision, scale) + } else { + row2.getDecimal(i - row1.numFields, precision, scale) + } + } - override def getShort(i: Int): Short = - if (i < row1.numFields) row1.getShort(i) else row2.getShort(i - row1.numFields) + override def getUTF8String(i: Int): UTF8String = + if (i < row1.numFields) row1.getUTF8String(i) else row2.getUTF8String(i - row1.numFields) - override def getByte(i: Int): Byte = - if (i < row1.numFields) row1.getByte(i) else row2.getByte(i - row1.numFields) + override def getBinary(i: Int): Array[Byte] = + if (i < row1.numFields) row1.getBinary(i) else row2.getBinary(i - row1.numFields) - override def getFloat(i: Int): Float = - if (i < row1.numFields) row1.getFloat(i) else row2.getFloat(i - row1.numFields) + override def getArray(i: Int): ArrayData = + if (i < row1.numFields) row1.getArray(i) else row2.getArray(i - row1.numFields) - override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = { - if (i < row1.numFields) row1.getDecimal(i, precision, scale) - else row2.getDecimal(i - row1.numFields, precision, scale) - } + override def getInterval(i: Int): CalendarInterval = + if (i < row1.numFields) row1.getInterval(i) else row2.getInterval(i - row1.numFields) + + override def getMap(i: Int): MapData = + if (i < row1.numFields) row1.getMap(i) else row2.getMap(i - row1.numFields) override def getStruct(i: Int, numFields: Int): InternalRow = { if (i < row1.numFields) { @@ -239,14 +249,9 @@ class JoinedRow extends InternalRow { } override def copy(): InternalRow = { - val totalSize = row1.numFields + row2.numFields - val copiedValues = new Array[Any](totalSize) - var i = 0 - while(i < totalSize) { - 
copiedValues(i) = get(i) - i += 1 - } - new GenericInternalRow(copiedValues) + val copy1 = row1.copy() + val copy2 = row2.copy() + new JoinedRow(copy1, copy2) } override def toString: String = { http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index b877ce4..d149a5b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -213,18 +213,12 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR override def numFields: Int = values.length - override def toSeq: Seq[Any] = values.map(_.boxed).toSeq + override def toSeq: Seq[Any] = values.map(_.boxed) override def setNullAt(i: Int): Unit = { values(i).isNull = true } - override def get(i: Int, dataType: DataType): Any = values(i).boxed - - override def getStruct(ordinal: Int, numFields: Int): InternalRow = { - values(ordinal).boxed.asInstanceOf[InternalRow] - } - override def isNullAt(i: Int): Boolean = values(i).isNull override def copy(): InternalRow = { @@ -238,6 +232,8 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR new GenericInternalRow(newValues) } + override def genericGet(i: Int): Any = values(i).boxed + override def update(ordinal: Int, value: Any) { if (value == null) { setNullAt(ordinal) @@ -246,9 +242,6 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR } } - override def setString(ordinal: Int, value: String): Unit = - update(ordinal, UTF8String.fromString(value)) - override def setInt(ordinal: Int, value: Int): Unit = { val currentValue = values(ordinal).asInstanceOf[MutableInt] currentValue.isNull = false http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 36f4e9c..fc7cfee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -112,8 +112,10 @@ class CodeGenContext { case BinaryType => s"$getter.getBinary($ordinal)" case CalendarIntervalType => s"$getter.getInterval($ordinal)" case t: StructType => s"$getter.getStruct($ordinal, ${t.size})" - case a: ArrayType => s"$getter.getArray($ordinal)" - case _ => s"($jt)$getter.get($ordinal)" // todo: remove generic getter. 
+ case _: ArrayType => s"$getter.getArray($ordinal)" + case _: MapType => s"$getter.getMap($ordinal)" + case NullType => "null" + case _ => s"($jt)$getter.get($ordinal, null)" } } @@ -156,7 +158,7 @@ class CodeGenContext { case CalendarIntervalType => "CalendarInterval" case _: StructType => "InternalRow" case _: ArrayType => "ArrayData" - case _: MapType => "scala.collection.Map" + case _: MapType => "MapData" case dt: OpenHashSetUDT if dt.elementType == IntegerType => classOf[IntegerHashSet].getName case dt: OpenHashSetUDT if dt.elementType == LongType => classOf[LongHashSet].getName case _ => "Object" @@ -300,7 +302,8 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin classOf[UTF8String].getName, classOf[Decimal].getName, classOf[CalendarInterval].getName, - classOf[ArrayData].getName + classOf[ArrayData].getName, + classOf[MapData].getName )) evaluator.setExtendedClass(classOf[GeneratedClass]) try { http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala index 3592014..6f9acda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateProjection.scala @@ -183,7 +183,7 @@ object GenerateProjection extends CodeGenerator[Seq[Expression], Projection] { public void setNullAt(int i) { nullBits[i] = true; } public boolean isNullAt(int i) { return nullBits[i]; } - public Object get(int i, ${classOf[DataType].getName} dataType) { + public Object genericGet(int i) { if (isNullAt(i)) return null; switch (i) { $getCases http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 0a53059..1156797 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -31,15 +31,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType override def nullSafeEval(value: Any): Int = child.dataType match { case _: ArrayType => value.asInstanceOf[ArrayData].numElements() - case _: MapType => value.asInstanceOf[Map[Any, Any]].size + case _: MapType => value.asInstanceOf[MapData].numElements() } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val sizeCall = child.dataType match { - case _: ArrayType => "numElements()" - case _: MapType => "size()" - } - nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).$sizeCall;") + nullSafeCodeGen(ctx, ev, c => s"${ev.primitive} = ($c).numElements();") } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index 99393c9..9927da2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import scala.collection.Map - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis._ @@ -41,7 +39,7 @@ object ExtractValue { * Struct | Literal String | GetStructField * Array[Struct] | Literal String | GetArrayStructFields * Array | Integral type | GetArrayItem - * Map | Any type | GetMapValue + * Map | map key type | GetMapValue */ def apply( child: Expression, @@ -60,18 +58,14 @@ object ExtractValue { GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, fields.length, containsNull) - case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => - GetArrayItem(child, extraction) + case (_: ArrayType, _) => GetArrayItem(child, extraction) - case (_: MapType, _) => - GetMapValue(child, extraction) + case (MapType(kt, _, _), _) => GetMapValue(child, extraction) case (otherType, _) => val errorMsg = otherType match { - case StructType(_) | ArrayType(StructType(_), _) => + case StructType(_) => s"Field name should be String Literal, but it's $extraction" - case _: ArrayType => - s"Array index should be integral type, but it's ${extraction.dataType}" case other => s"Can't extract value from $child" } @@ -190,9 +184,13 @@ case class GetArrayStructFields( /** * Returns the field at `ordinal` in the Array `child`. * - * No need to do type checking since it is handled by [[ExtractValue]]. + * We need to do type checking here as `ordinal` expression maybe unresolved. */ -case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryExpression { +case class GetArrayItem(child: Expression, ordinal: Expression) + extends BinaryExpression with ExpectsInputTypes { + + // We have done type checking for child in `ExtractValue`, so only need to check the `ordinal`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegralType) override def toString: String = s"$child[$ordinal]" @@ -205,14 +203,12 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx override def dataType: DataType = child.dataType.asInstanceOf[ArrayType].elementType protected override def nullSafeEval(value: Any, ordinal: Any): Any = { - // TODO: consider using Array[_] for ArrayType child to avoid - // boxing of primitives val baseValue = value.asInstanceOf[ArrayData] val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.numElements() || index < 0) { null } else { - baseValue.get(index) + baseValue.get(index, dataType) } } @@ -233,9 +229,15 @@ case class GetArrayItem(child: Expression, ordinal: Expression) extends BinaryEx /** * Returns the value of key `key` in Map `child`. * - * No need to do type checking since it is handled by [[ExtractValue]]. + * We need to do type checking here as `key` expression maybe unresolved. 
*/ -case class GetMapValue(child: Expression, key: Expression) extends BinaryExpression { +case class GetMapValue(child: Expression, key: Expression) + extends BinaryExpression with ExpectsInputTypes { + + private def keyType = child.dataType.asInstanceOf[MapType].keyType + + // We have done type checking for child in `ExtractValue`, so only need to check the `key`. + override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, keyType) override def toString: String = s"$child[$key]" @@ -247,16 +249,53 @@ case class GetMapValue(child: Expression, key: Expression) extends BinaryExpress override def dataType: DataType = child.dataType.asInstanceOf[MapType].valueType + // todo: current search is O(n), improve it. protected override def nullSafeEval(value: Any, ordinal: Any): Any = { - val baseValue = value.asInstanceOf[Map[Any, _]] - baseValue.get(ordinal).orNull + val map = value.asInstanceOf[MapData] + val length = map.numElements() + val keys = map.keyArray() + + var i = 0 + var found = false + while (i < length && !found) { + if (keys.get(i, keyType) == ordinal) { + found = true + } else { + i += 1 + } + } + + if (!found) { + null + } else { + map.valueArray().get(i, dataType) + } } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val index = ctx.freshName("index") + val length = ctx.freshName("length") + val keys = ctx.freshName("keys") + val found = ctx.freshName("found") + val key = ctx.freshName("key") nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" - if ($eval1.contains($eval2)) { - ${ev.primitive} = (${ctx.boxedType(dataType)})$eval1.apply($eval2); + final int $length = $eval1.numElements(); + final ArrayData $keys = $eval1.keyArray(); + + int $index = 0; + boolean $found = false; + while ($index < $length && !$found) { + final ${ctx.javaType(keyType)} $key = ${ctx.getValue(keys, keyType, index)}; + if (${ctx.genEqual(keyType, key, eval2)}) { + $found = true; + } else { + $index++; + } + } + + if ($found) { + ${ev.primitive} = ${ctx.getValue(eval1 + ".valueArray()", dataType, index)}; } else { ${ev.isNull} = true; } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 8064235..d474853 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -120,13 +120,30 @@ case class Explode(child: Expression) extends UnaryExpression with Generator wit override def eval(input: InternalRow): TraversableOnce[InternalRow] = { child.dataType match { - case ArrayType(_, _) => + case ArrayType(et, _) => val inputArray = child.eval(input).asInstanceOf[ArrayData] - if (inputArray == null) Nil else inputArray.toArray().map(v => InternalRow(v)) - case MapType(_, _, _) => - val inputMap = child.eval(input).asInstanceOf[Map[Any, Any]] - if (inputMap == null) Nil - else inputMap.map { case (k, v) => InternalRow(k, v) } + if (inputArray == null) { + Nil + } else { + val rows = new Array[InternalRow](inputArray.numElements()) + inputArray.foreach(et, (i, e) => { + rows(i) = InternalRow(e) + }) + rows + } + case MapType(kt, vt, _) => + val inputMap = 
child.eval(input).asInstanceOf[MapData] + if (inputMap == null) { + Nil + } else { + val rows = new Array[InternalRow](inputMap.numElements()) + var i = 0 + inputMap.foreach(kt, vt, (k, v) => { + rows(i) = InternalRow(k, v) + i += 1 + }) + rows + } } } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index df6ea58..73f6b7a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -32,28 +32,14 @@ abstract class MutableRow extends InternalRow { def update(i: Int, value: Any) // default implementation (slow) - def setInt(i: Int, value: Int): Unit = { update(i, value) } - def setLong(i: Int, value: Long): Unit = { update(i, value) } - def setDouble(i: Int, value: Double): Unit = { update(i, value) } def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } - def setShort(i: Int, value: Short): Unit = { update(i, value) } def setByte(i: Int, value: Byte): Unit = { update(i, value) } + def setShort(i: Int, value: Short): Unit = { update(i, value) } + def setInt(i: Int, value: Int): Unit = { update(i, value) } + def setLong(i: Int, value: Long): Unit = { update(i, value) } def setFloat(i: Int, value: Float): Unit = { update(i, value) } + def setDouble(i: Int, value: Double): Unit = { update(i, value) } def setDecimal(i: Int, value: Decimal, precision: Int) { update(i, value) } - def setString(i: Int, value: String): Unit = { - update(i, UTF8String.fromString(value)) - } - - override def copy(): InternalRow = { - val n = numFields - val arr = new Array[Any](n) - var i = 0 - while (i < n) { - arr(i) = get(i) - i += 1 - } - new GenericInternalRow(arr) - } } /** @@ -96,17 +82,13 @@ class GenericInternalRow(protected[sql] val values: Array[Any]) extends Internal def this(size: Int) = this(new Array[Any](size)) - override def toSeq: Seq[Any] = values.toSeq - - override def numFields: Int = values.length + override def genericGet(ordinal: Int): Any = values(ordinal) - override def get(i: Int, dataType: DataType): Any = values(i) + override def toSeq: Seq[Any] = values - override def getStruct(ordinal: Int, numFields: Int): InternalRow = { - values(ordinal).asInstanceOf[InternalRow] - } + override def numFields: Int = values.length - override def copy(): InternalRow = this + override def copy(): InternalRow = new GenericInternalRow(values.clone()) } /** @@ -127,15 +109,11 @@ class GenericMutableRow(val values: Array[Any]) extends MutableRow { def this(size: Int) = this(new Array[Any](size)) - override def toSeq: Seq[Any] = values.toSeq - - override def numFields: Int = values.length + override def genericGet(ordinal: Int): Any = values(ordinal) - override def get(i: Int, dataType: DataType): Any = values(i) + override def toSeq: Seq[Any] = values - override def getStruct(ordinal: Int, numFields: Int): InternalRow = { - values(ordinal).asInstanceOf[InternalRow] - } + override def numFields: Int = values.length override def setNullAt(i: Int): Unit = { values(i) = null} http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala 
---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5dd387a..3ce5d6a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -95,7 +95,7 @@ case class ConcatWs(children: Seq[Expression]) val flatInputs = children.flatMap { child => child.eval(input) match { case s: UTF8String => Iterator(s) - case arr: ArrayData => arr.toArray().map(_.asInstanceOf[UTF8String]) + case arr: ArrayData => arr.toArray[UTF8String](StringType) case null => Iterator(null.asInstanceOf[UTF8String]) } } http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala new file mode 100644 index 0000000..db48763 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayBasedMapData.scala @@ -0,0 +1,51 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.types + +class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { + require(keyArray.numElements() == valueArray.numElements()) + + override def numElements(): Int = keyArray.numElements() + + override def equals(o: Any): Boolean = { + if (!o.isInstanceOf[ArrayBasedMapData]) { + return false + } + + val other = o.asInstanceOf[ArrayBasedMapData] + if (other eq null) { + return false + } + + this.keyArray == other.keyArray && this.valueArray == other.valueArray + } + + override def hashCode: Int = { + keyArray.hashCode() * 37 + valueArray.hashCode() + } + + override def toString(): String = { + s"keys: $keyArray\nvalues: $valueArray" + } +} + +object ArrayBasedMapData { + def apply(keys: Array[Any], values: Array[Any]): ArrayBasedMapData = { + new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) + } +} http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala index 14a7285..c99fc23 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala @@ -20,102 +20,111 @@ package org.apache.spark.sql.types import org.apache.spark.sql.catalyst.expressions.SpecializedGetters abstract class ArrayData extends SpecializedGetters with Serializable { - // todo: remove this after we handle all types.(map type need special getter) - def get(ordinal: Int): Any - def numElements(): Int - // todo: need a more efficient way to iterate array type. 

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
index 14a7285..c99fc23 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/ArrayData.scala
@@ -20,102 +20,111 @@ package org.apache.spark.sql.types
 import org.apache.spark.sql.catalyst.expressions.SpecializedGetters
 
 abstract class ArrayData extends SpecializedGetters with Serializable {
-  // todo: remove this after we handle all types.(map type need special getter)
-  def get(ordinal: Int): Any
-
   def numElements(): Int
 
-  // todo: need a more efficient way to iterate array type.
-  def toArray(): Array[Any] = {
-    val n = numElements()
-    val values = new Array[Any](n)
+  def toBooleanArray(): Array[Boolean] = {
+    val size = numElements()
+    val values = new Array[Boolean](size)
     var i = 0
-    while (i < n) {
-      if (isNullAt(i)) {
-        values(i) = null
-      } else {
-        values(i) = get(i)
-      }
+    while (i < size) {
+      values(i) = getBoolean(i)
       i += 1
     }
     values
   }
 
-  override def toString(): String = toArray.mkString("[", ",", "]")
+  def toByteArray(): Array[Byte] = {
+    val size = numElements()
+    val values = new Array[Byte](size)
+    var i = 0
+    while (i < size) {
+      values(i) = getByte(i)
+      i += 1
+    }
+    values
+  }
 
-  override def equals(o: Any): Boolean = {
-    if (!o.isInstanceOf[ArrayData]) {
-      return false
+  def toShortArray(): Array[Short] = {
+    val size = numElements()
+    val values = new Array[Short](size)
+    var i = 0
+    while (i < size) {
+      values(i) = getShort(i)
+      i += 1
     }
+    values
+  }
 
-    val other = o.asInstanceOf[ArrayData]
-    if (other eq null) {
-      return false
+  def toIntArray(): Array[Int] = {
+    val size = numElements()
+    val values = new Array[Int](size)
+    var i = 0
+    while (i < size) {
+      values(i) = getInt(i)
+      i += 1
     }
+    values
+  }
 
-    val len = numElements()
-    if (len != other.numElements()) {
-      return false
+  def toLongArray(): Array[Long] = {
+    val size = numElements()
+    val values = new Array[Long](size)
+    var i = 0
+    while (i < size) {
+      values(i) = getLong(i)
+      i += 1
    }
+    values
+  }
 
+  def toFloatArray(): Array[Float] = {
+    val size = numElements()
+    val values = new Array[Float](size)
     var i = 0
-    while (i < len) {
-      if (isNullAt(i) != other.isNullAt(i)) {
-        return false
-      }
-      if (!isNullAt(i)) {
-        val o1 = get(i)
-        val o2 = other.get(i)
-        o1 match {
-          case b1: Array[Byte] =>
-            if (!o2.isInstanceOf[Array[Byte]] ||
-              !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
-              return false
-            }
-          case f1: Float if java.lang.Float.isNaN(f1) =>
-            if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
-              return false
-            }
-          case d1: Double if java.lang.Double.isNaN(d1) =>
-            if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
-              return false
-            }
-          case _ => if (o1 != o2) {
-            return false
-          }
-        }
+    while (i < size) {
+      values(i) = getFloat(i)
+      i += 1
+    }
+    values
+  }
+
+  def toDoubleArray(): Array[Double] = {
+    val size = numElements()
+    val values = new Array[Double](size)
+    var i = 0
+    while (i < size) {
+      values(i) = getDouble(i)
+      i += 1
+    }
+    values
+  }
+
+  def toArray[T](elementType: DataType): Array[T] = {
+    val size = numElements()
+    val values = new Array[Any](size)
+    var i = 0
+    while (i < size) {
+      if (isNullAt(i)) {
+        values(i) = null
+      } else {
+        values(i) = get(i, elementType)
       }
       i += 1
     }
-    true
+    values.asInstanceOf[Array[T]]
  }
 
-  override def hashCode: Int = {
-    var result: Int = 37
+  // todo: specialize this.
+  def foreach(elementType: DataType, f: (Int, Any) => Unit): Unit = {
+    val size = numElements()
     var i = 0
-    val len = numElements()
-    while (i < len) {
-      val update: Int =
-        if (isNullAt(i)) {
-          0
-        } else {
-          get(i) match {
-            case b: Boolean => if (b) 0 else 1
-            case b: Byte => b.toInt
-            case s: Short => s.toInt
-            case i: Int => i
-            case l: Long => (l ^ (l >>> 32)).toInt
-            case f: Float => java.lang.Float.floatToIntBits(f)
-            case d: Double =>
-              val b = java.lang.Double.doubleToLongBits(d)
-              (b ^ (b >>> 32)).toInt
-            case a: Array[Byte] => java.util.Arrays.hashCode(a)
-            case other => other.hashCode()
-          }
-        }
-      result = 37 * result + update
+    while (i < size) {
+      if (isNullAt(i)) {
+        f(i, null)
+      } else {
+        f(i, get(i, elementType))
+      }
       i += 1
     }
-    result
  }
 }
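
The specialized `toXxxArray` variants fill a primitive array straight through the typed
getters, and `foreach` hands each (index, element) pair to the closure. A usage sketch
(hypothetical values, not part of the commit):

    import org.apache.spark.sql.types.{GenericArrayData, IntegerType}

    val arr = new GenericArrayData(Array[Any](1, 2, 3))
    val ints: Array[Int] = arr.toIntArray()   // primitive result array, no Array[Any] detour
    arr.foreach(IntegerType, (i, e) => println(s"$i -> $e"))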

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
index 35ace67..b3e75f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/GenericArrayData.scala
@@ -17,43 +17,91 @@
 
 package org.apache.spark.sql.types
 
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.unsafe.types.{UTF8String, CalendarInterval}
+import org.apache.spark.sql.catalyst.expressions.GenericSpecializedGetters
 
-class GenericArrayData(array: Array[Any]) extends ArrayData {
-  private def getAs[T](ordinal: Int) = get(ordinal).asInstanceOf[T]
+class GenericArrayData(array: Array[Any]) extends ArrayData with GenericSpecializedGetters {
 
-  override def toArray(): Array[Any] = array
+  override def genericGet(ordinal: Int): Any = array(ordinal)
 
-  override def get(ordinal: Int): Any = array(ordinal)
-
-  override def isNullAt(ordinal: Int): Boolean = get(ordinal) == null
-
-  override def getBoolean(ordinal: Int): Boolean = getAs(ordinal)
-
-  override def getByte(ordinal: Int): Byte = getAs(ordinal)
-
-  override def getShort(ordinal: Int): Short = getAs(ordinal)
-
-  override def getInt(ordinal: Int): Int = getAs(ordinal)
-
-  override def getLong(ordinal: Int): Long = getAs(ordinal)
-
-  override def getFloat(ordinal: Int): Float = getAs(ordinal)
-
-  override def getDouble(ordinal: Int): Double = getAs(ordinal)
-
-  override def getDecimal(ordinal: Int, precision: Int, scale: Int): Decimal = getAs(ordinal)
-
-  override def getUTF8String(ordinal: Int): UTF8String = getAs(ordinal)
-
-  override def getBinary(ordinal: Int): Array[Byte] = getAs(ordinal)
-
-  override def getInterval(ordinal: Int): CalendarInterval = getAs(ordinal)
-
-  override def getStruct(ordinal: Int, numFields: Int): InternalRow = getAs(ordinal)
-
-  override def getArray(ordinal: Int): ArrayData = getAs(ordinal)
+  override def toArray[T](elementType: DataType): Array[T] = array.asInstanceOf[Array[T]]
 
   override def numElements(): Int = array.length
+
+  override def toString(): String = array.mkString("[", ",", "]")
+
+  override def equals(o: Any): Boolean = {
+    if (!o.isInstanceOf[GenericArrayData]) {
+      return false
+    }
+
+    val other = o.asInstanceOf[GenericArrayData]
+    if (other eq null) {
+      return false
+    }
+
+    val len = numElements()
+    if (len != other.numElements()) {
+      return false
+    }
+
+    var i = 0
+    while (i < len) {
+      if (isNullAt(i) != other.isNullAt(i)) {
+        return false
+      }
+      if (!isNullAt(i)) {
+        val o1 = genericGet(i)
+        val o2 = other.genericGet(i)
+        o1 match {
+          case b1: Array[Byte] =>
+            if (!o2.isInstanceOf[Array[Byte]] ||
+              !java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])) {
+              return false
+            }
+          case f1: Float if java.lang.Float.isNaN(f1) =>
+            if (!o2.isInstanceOf[Float] || ! java.lang.Float.isNaN(o2.asInstanceOf[Float])) {
+              return false
+            }
+          case d1: Double if java.lang.Double.isNaN(d1) =>
+            if (!o2.isInstanceOf[Double] || ! java.lang.Double.isNaN(o2.asInstanceOf[Double])) {
+              return false
+            }
+          case _ => if (o1 != o2) {
+            return false
+          }
+        }
+      }
+      i += 1
+    }
+    true
+  }
+
+  override def hashCode: Int = {
+    var result: Int = 37
+    var i = 0
+    val len = numElements()
+    while (i < len) {
+      val update: Int =
+        if (isNullAt(i)) {
+          0
+        } else {
+          genericGet(i) match {
+            case b: Boolean => if (b) 0 else 1
+            case b: Byte => b.toInt
+            case s: Short => s.toInt
+            case i: Int => i
+            case l: Long => (l ^ (l >>> 32)).toInt
+            case f: Float => java.lang.Float.floatToIntBits(f)
+            case d: Double =>
+              val b = java.lang.Double.doubleToLongBits(d)
+              (b ^ (b >>> 32)).toInt
+            case a: Array[Byte] => java.util.Arrays.hashCode(a)
+            case other => other.hashCode()
+          }
+        }
+      result = 37 * result + update
+      i += 1
+    }
+    result
+  }
 }
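
A sketch of the equality semantics that moved here from `ArrayData` (hypothetical
values): binary elements compare by content, and NaN compares equal to NaN:

    import org.apache.spark.sql.types.GenericArrayData

    val a = new GenericArrayData(Array[Any](Double.NaN, Array[Byte](1, 2)))
    val b = new GenericArrayData(Array[Any](Double.NaN, Array[Byte](1, 2)))
    assert(a == b)
    assert(a.hashCode == b.hashCode)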

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
new file mode 100644
index 0000000..5514c3c
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapData.scala
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.types
+
+abstract class MapData extends Serializable {
+
+  def numElements(): Int
+
+  def keyArray(): ArrayData
+
+  def valueArray(): ArrayData
+
+  def foreach(keyType: DataType, valueType: DataType, f: (Any, Any) => Unit): Unit = {
+    val length = numElements()
+    val keys = keyArray()
+    val values = valueArray()
+    var i = 0
+    while (i < length) {
+      f(keys.get(i, keyType), values.get(i, valueType))
+      i += 1
+    }
+  }
+}
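
`foreach` walks keys and values in lockstep, so consumers no longer need a Scala `Map`.
A minimal sketch (hypothetical values) using the `ArrayBasedMapData` implementation
above:

    import org.apache.spark.sql.types.{ArrayBasedMapData, IntegerType, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    val map = ArrayBasedMapData(Array[Any](UTF8String.fromString("a")), Array[Any](1))
    map.foreach(StringType, IntegerType, (k, v) => println(s"$k -> $v"))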

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 3fa246b..e60990a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -171,8 +171,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
 
   test("error message of ExtractValue") {
     val structType = StructType(StructField("a", StringType, true) :: Nil)
-    val arrayStructType = ArrayType(structType)
-    val arrayType = ArrayType(StringType)
     val otherType = StringType
 
     def checkErrorMessage(
@@ -189,8 +187,6 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
 
     checkErrorMessage(structType, IntegerType, "Field name should be String Literal")
-    checkErrorMessage(arrayStructType, BooleanType, "Field name should be String Literal")
-    checkErrorMessage(arrayType, StringType, "Array index should be integral type")
     checkErrorMessage(otherType, StringType, "Can't extract value from")
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index a0e1701..44f8456 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -87,7 +87,7 @@ class UnsafeRowConverterSuite extends SparkFunSuite with Matchers {
 
     val row = new SpecificMutableRow(fieldTypes)
     row.setLong(0, 0)
-    row.setString(1, "Hello")
+    row.update(1, UTF8String.fromString("Hello"))
     row.update(2, DateTimeUtils.fromJavaDate(Date.valueOf("1970-01-01")))
     row.update(3, DateTimeUtils.fromJavaTimestamp(Timestamp.valueOf("2015-05-08 08:10:25")))

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index f26f41f..c37007f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -159,10 +159,16 @@ package object debug {
       case (row: InternalRow, StructType(fields)) =>
         row.toSeq.zip(fields.map(_.dataType)).foreach { case(d, t) => typeCheck(d, t) }
       case (a: ArrayData, ArrayType(elemType, _)) =>
-        a.toArray().foreach(typeCheck(_, elemType))
-      case (m: Map[_, _], MapType(keyType, valueType, _)) =>
-        m.keys.foreach(typeCheck(_, keyType))
-        m.values.foreach(typeCheck(_, valueType))
+        a.foreach(elemType, (_, e) => {
+          typeCheck(e, elemType)
+        })
+      case (m: MapData, MapType(keyType, valueType, _)) =>
+        m.keyArray().foreach(keyType, (_, e) => {
+          typeCheck(e, keyType)
+        })
+        m.valueArray().foreach(valueType, (_, e) => {
+          typeCheck(e, valueType)
+        })
 
       case (_: Long, LongType) =>
       case (_: Int, IntegerType) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index ef1c6e5..aade2e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -135,22 +135,18 @@ object EvaluatePython {
       new GenericInternalRowWithSchema(values, struct)
 
     case (a: ArrayData, array: ArrayType) =>
-      val length = a.numElements()
-      val values = new java.util.ArrayList[Any](length)
-      var i = 0
-      while (i < length) {
-        if (a.isNullAt(i)) {
-          values.add(null)
-        } else {
-          values.add(toJava(a.get(i), array.elementType))
-        }
-        i += 1
-      }
+      val values = new java.util.ArrayList[Any](a.numElements())
+      a.foreach(array.elementType, (_, e) => {
+        values.add(toJava(e, array.elementType))
+      })
       values
 
-    case (obj: Map[_, _], mt: MapType) => obj.map {
-      case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
-    }.asJava
+    case (map: MapData, mt: MapType) =>
+      val jmap = new java.util.HashMap[Any, Any](map.numElements())
+      map.foreach(mt.keyType, mt.valueType, (k, v) => {
+        jmap.put(toJava(k, mt.keyType), toJava(v, mt.valueType))
+      })
+      jmap
 
     case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
 
@@ -206,9 +202,10 @@ object EvaluatePython {
     case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
       new GenericArrayData(c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)))
 
-    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
-      case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
-    }.toMap
+    case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) =>
+      val keys = c.keysIterator.map(fromJava(_, keyType)).toArray
+      val values = c.valuesIterator.map(fromJava(_, valueType)).toArray
+      ArrayBasedMapData(keys, values)
 
     case (c, StructType(fields)) if c.getClass.isArray =>
      new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
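
The Python round-trip now targets `MapData` on the way in and a `java.util.HashMap` on
the way out. A sketch of the incoming direction (hypothetical inputs; the real code uses
the map's key and value iterators, as shown above):

    import scala.collection.JavaConverters._
    import org.apache.spark.sql.types.ArrayBasedMapData

    val jmap = new java.util.HashMap[String, Integer]()
    jmap.put("a", 1)
    // One pass over entrySet keeps keys and values aligned by construction.
    val entries = jmap.entrySet().asScala.toSeq
    val keys = entries.map(e => e.getKey: Any).toArray
    val values = entries.map(e => e.getValue: Any).toArray
    val map = ArrayBasedMapData(keys, values)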

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
index 1c309f8..bf0448e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JacksonParser.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.json
 
 import java.io.ByteArrayOutputStream
 
-import scala.collection.Map
-
 import com.fasterxml.jackson.core._
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -31,7 +31,6 @@ import org.apache.spark.sql.json.JacksonUtils.nextUntil
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-
 private[sql] object JacksonParser {
   def apply(
       json: RDD[String],
@@ -160,21 +159,21 @@ private[sql] object JacksonParser {
   private def convertMap(
       factory: JsonFactory,
       parser: JsonParser,
-      valueType: DataType): Map[UTF8String, Any] = {
-    val builder = Map.newBuilder[UTF8String, Any]
+      valueType: DataType): MapData = {
+    val keys = ArrayBuffer.empty[UTF8String]
+    val values = ArrayBuffer.empty[Any]
     while (nextUntil(parser, JsonToken.END_OBJECT)) {
-      builder +=
-        UTF8String.fromString(parser.getCurrentName) -> convertField(factory, parser, valueType)
+      keys += UTF8String.fromString(parser.getCurrentName)
+      values += convertField(factory, parser, valueType)
     }
-
-    builder.result()
+    ArrayBasedMapData(keys.toArray, values.toArray)
   }
 
   private def convertArray(
       factory: JsonFactory,
       parser: JsonParser,
       elementType: DataType): ArrayData = {
-    val values = scala.collection.mutable.ArrayBuffer.empty[Any]
+    val values = ArrayBuffer.empty[Any]
     while (nextUntil(parser, JsonToken.END_ARRAY)) {
       values += convertField(factory, parser, elementType)
     }
@@ -213,7 +212,7 @@ private[sql] object JacksonParser {
           if (array.numElements() == 0) {
             Nil
           } else {
-            array.toArray().map(_.asInstanceOf[InternalRow])
+            array.toArray[InternalRow](schema)
           }
         case _ =>
          sys.error(
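
The parser accumulates keys and values in two parallel buffers and materializes the
`MapData` only once the JSON object closes. A condensed sketch (hypothetical field, not
from the patch):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.types.ArrayBasedMapData
    import org.apache.spark.unsafe.types.UTF8String

    val keys = ArrayBuffer.empty[UTF8String]
    val values = ArrayBuffer.empty[Any]
    // One field of a JSON object, appended to both buffers in step.
    keys += UTF8String.fromString("count")
    values += 42L
    val map = ArrayBasedMapData(keys.toArray, values.toArray)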

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
index 172db83..6938b07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/CatalystRowConverter.scala
@@ -385,7 +385,8 @@ private[parquet] class CatalystRowConverter(
       updater: ParentContainerUpdater)
     extends GroupConverter {
 
-    private var currentMap: mutable.Map[Any, Any] = _
+    private var currentKeys: ArrayBuffer[Any] = _
+    private var currentValues: ArrayBuffer[Any] = _
 
     private val keyValueConverter = {
       val repeatedType = parquetType.getType(0).asGroupType()
@@ -398,12 +399,16 @@ private[parquet] class CatalystRowConverter(
 
     override def getConverter(fieldIndex: Int): Converter = keyValueConverter
 
-    override def end(): Unit = updater.set(currentMap)
+    override def end(): Unit =
+      updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray))
 
     // NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next
     // value.  `Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row
     // cells.
-    override def start(): Unit = currentMap = mutable.Map.empty[Any, Any]
+    override def start(): Unit = {
+      currentKeys = ArrayBuffer.empty[Any]
+      currentValues = ArrayBuffer.empty[Any]
+    }
 
     /** Parquet converter for key-value pairs within the map. */
     private final class KeyValueConverter(
@@ -430,7 +435,10 @@ private[parquet] class CatalystRowConverter(
 
       override def getConverter(fieldIndex: Int): Converter = converters(fieldIndex)
 
-      override def end(): Unit = currentMap(currentKey) = currentValue
+      override def end(): Unit = {
+        currentKeys += currentKey
+        currentValues += currentValue
+      }
 
       override def start(): Unit = {
        currentKey = null

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
index 2332a36..6ed3580 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetConverter.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.parquet
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.types.ArrayData
+import org.apache.spark.sql.types.{MapData, ArrayData}
 
 // TODO Removes this while fixing SPARK-8848
 private[sql] object CatalystConverter {
@@ -33,7 +33,7 @@ private[sql] object CatalystConverter {
   val MAP_SCHEMA_NAME = "map"
 
   // TODO: consider using Array[T] for arrays to avoid boxing of primitive types
-  type ArrayScalaType[T] = ArrayData
-  type StructScalaType[T] = InternalRow
-  type MapScalaType[K, V] = Map[K, V]
+  type ArrayScalaType = ArrayData
+  type StructScalaType = InternalRow
+  type MapScalaType = MapData
 }
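
The same parallel-buffer pattern drives the Parquet map converter: `start()` resets the
buffers, the inner converter appends one pair per `end()`, and the outer `end()` emits a
fresh `ArrayBasedMapData` so `Row.copy()` never aliases a reused buffer. A self-contained
sketch of that lifecycle (the class and method names here are invented for illustration):

    import scala.collection.mutable.ArrayBuffer
    import org.apache.spark.sql.types.ArrayBasedMapData

    class MapAccumulator {
      private val keys = ArrayBuffer.empty[Any]
      private val values = ArrayBuffer.empty[Any]

      def start(): Unit = { keys.clear(); values.clear() }

      def put(key: Any, value: Any): Unit = { keys += key; values += value }

      // A fresh MapData per value, built from snapshots of both buffers.
      def end(): ArrayBasedMapData = ArrayBasedMapData(keys.toArray, values.toArray)
    }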

http://git-wip-us.apache.org/repos/asf/spark/blob/1d59a416/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
index ec8da38..9cd0250 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableSupport.scala
@@ -88,13 +88,13 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
       case t: UserDefinedType[_] => writeValue(t.sqlType, value)
       case t @ ArrayType(_, _) => writeArray(
         t,
-        value.asInstanceOf[CatalystConverter.ArrayScalaType[_]])
+        value.asInstanceOf[CatalystConverter.ArrayScalaType])
       case t @ MapType(_, _, _) => writeMap(
         t,
-        value.asInstanceOf[CatalystConverter.MapScalaType[_, _]])
+        value.asInstanceOf[CatalystConverter.MapScalaType])
       case t @ StructType(_) => writeStruct(
         t,
-        value.asInstanceOf[CatalystConverter.StructScalaType[_]])
+        value.asInstanceOf[CatalystConverter.StructScalaType])
       case _ => writePrimitive(schema.asInstanceOf[AtomicType], value)
     }
   }
@@ -124,7 +124,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
   private[parquet] def writeStruct(
       schema: StructType,
-      struct: CatalystConverter.StructScalaType[_]): Unit = {
+      struct: CatalystConverter.StructScalaType): Unit = {
     if (struct != null) {
       val fields = schema.fields.toArray
       writer.startGroup()
@@ -143,7 +143,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
   private[parquet] def writeArray(
       schema: ArrayType,
-      array: CatalystConverter.ArrayScalaType[_]): Unit = {
+      array: CatalystConverter.ArrayScalaType): Unit = {
     val elementType = schema.elementType
     writer.startGroup()
     if (array.numElements() > 0) {
@@ -154,7 +154,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
         writer.startGroup()
         if (!array.isNullAt(i)) {
           writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
-          writeValue(elementType, array.get(i))
+          writeValue(elementType, array.get(i, elementType))
           writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
         }
         writer.endGroup()
@@ -165,7 +165,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
       writer.startField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
       var i = 0
       while (i < array.numElements()) {
-        writeValue(elementType, array.get(i))
+        writeValue(elementType, array.get(i, elementType))
         i = i + 1
       }
       writer.endField(CatalystConverter.ARRAY_ELEMENTS_SCHEMA_NAME, 0)
@@ -176,11 +176,12 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
 
   private[parquet] def writeMap(
       schema: MapType,
-      map: CatalystConverter.MapScalaType[_, _]): Unit = {
+      map: CatalystConverter.MapScalaType): Unit = {
     writer.startGroup()
-    if (map.size > 0) {
+    val length = map.numElements()
+    if (length > 0) {
       writer.startField(CatalystConverter.MAP_SCHEMA_NAME, 0)
-      for ((key, value) <- map) {
+      map.foreach(schema.keyType, schema.valueType, (key, value) => {
         writer.startGroup()
         writer.startField(CatalystConverter.MAP_KEY_SCHEMA_NAME, 0)
         writeValue(schema.keyType, key)
@@ -191,7 +192,7 @@ private[parquet] class RowWriteSupport extends WriteSupport[InternalRow] with Lo
           writer.endField(CatalystConverter.MAP_VALUE_SCHEMA_NAME, 1)
         }
         writer.endGroup()
-      }
+      })
       writer.endField(CatalystConverter.MAP_SCHEMA_NAME, 0)
     }
     writer.endGroup()
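
Write paths follow the same shape: `writeMap` above streams pairs with `MapData.foreach`
instead of materializing a Scala collection. A final sketch (hypothetical types and
values, not part of the commit):

    import org.apache.spark.sql.types.{ArrayBasedMapData, IntegerType, MapData, StringType}
    import org.apache.spark.unsafe.types.UTF8String

    // Streams pairs straight off the two backing arrays; no Scala Map in between.
    def dump(map: MapData): Unit = {
      map.foreach(StringType, IntegerType, (key, value) => {
        println(s"key=$key value=$value")
      })
    }

    dump(ArrayBasedMapData(Array[Any](UTF8String.fromString("k")), Array[Any](42)))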