Repository: spark Updated Branches: refs/heads/master 05ae74778 -> 30ffb53ca
[SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData ## What changes were proposed in this pull request? We don't have a good way to sequentially access `UnsafeArrayData` through a common interface such as `Seq`. An example is `MapObjects`, where we need to access several sequence collection types together. But `UnsafeArrayData` doesn't implement `ArrayData.array`, and calling `toArray` copies the entire array. We can provide an `IndexedSeq` wrapper for `ArrayData` so that we avoid copying the entire array. ## How was this patch tested? Added a new test suite, `ArrayDataIndexedSeqSuite`. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #20984 from viirya/SPARK-23875. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30ffb53c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30ffb53c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30ffb53c Branch: refs/heads/master Commit: 30ffb53cad84283b4f7694bfd60bdd7e1101b04e Parents: 05ae747 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Tue Apr 17 15:09:36 2018 +0200 Committer: Herman van Hovell <hvanhov...@databricks.com> Committed: Tue Apr 17 15:09:36 2018 +0200 ---------------------------------------------------------------------- .../catalyst/expressions/objects/objects.scala | 2 +- .../spark/sql/catalyst/util/ArrayData.scala | 30 +++++- .../util/ArrayDataIndexedSeqSuite.scala | 100 +++++++++++++++++++ 3 files changed, 130 insertions(+), 2 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 77802e8..72b202b 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -708,7 +708,7 @@ case class MapObjects private( } } case ArrayType(et, _) => - _.asInstanceOf[ArrayData].array + _.asInstanceOf[ArrayData].toSeq[Any](et) } private lazy val mapElements: Seq[_] => Any = customCollectionCls match { http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 9beef41..2cf59d5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util import scala.reflect.ClassTag +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types._ object ArrayData { def toArrayData(input: Any): ArrayData = input match { @@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable { def array: Array[Any] + def toSeq[T](dataType: DataType): IndexedSeq[T] = + new ArrayDataIndexedSeq[T](this, dataType) + def setNullAt(i: Int): Unit def update(i: Int, value: Any): Unit @@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable { } } } + +/** + * Implements an `IndexedSeq` interface for `ArrayData`. 
Note that if the original `ArrayData` + * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`, + * instead of `IndexedSeq[Int]`, so that null elements are returned as `null`. + */ +class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] { + + private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType) + + override def apply(idx: Int): T = + if (0 <= idx && idx < arrayData.numElements()) { + if (arrayData.isNullAt(idx)) { + null.asInstanceOf[T] + } else { + accessor(arrayData, idx).asInstanceOf[T] + } + } else { + throw new IndexOutOfBoundsException( + s"Index $idx must be between 0 and the length of the ArrayData.") + } + + override def length: Int = arrayData.numElements() +} http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala new file mode 100644 index 0000000..6400898 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala @@ -0,0 +1,100 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.util.Random + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.RandomDataGenerator +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection} +import org.apache.spark.sql.types._ + +class ArrayDataIndexedSeqSuite extends SparkFunSuite { + private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = { + assert(arrayData.numElements == array.length) + array.zipWithIndex.foreach { case (e, i) => + if (e != null) { + elementDt match { + // Compare with `equals`, which treats NaN as equal to NaN. + case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e)) + case _ => assert(arrayData.get(i, elementDt) === e) + } + } else { + assert(arrayData.isNullAt(i)) + } + } + + val seq = arrayData.toSeq[Any](elementDt) + array.zipWithIndex.foreach { case (e, i) => + if (e != null) { + elementDt match { + // Compare with `equals`, which treats NaN as equal to NaN.
+ case FloatType | DoubleType => assert(seq(i).equals(e)) + case _ => assert(seq(i) === e) + } + } else { + assert(seq(i) == null) + } + } + + assert(intercept[IndexOutOfBoundsException] { + seq(-1) + }.getMessage().contains("must be between 0 and the length of the ArrayData.")) + + assert(intercept[IndexOutOfBoundsException] { + seq(seq.length) + }.getMessage().contains("must be between 0 and the length of the ArrayData.")) + } + + private def testArrayData(): Unit = { + val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType, + DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType, + CalendarIntervalType, new ExamplePointUDT()) + val arrayTypes = elementTypes.flatMap { elementType => + Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true)) + } + val random = new Random(100) + arrayTypes.foreach { dt => + val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil) + val row = RandomDataGenerator.randomRow(random, schema) + val rowConverter = RowEncoder(schema) + val internalRow = rowConverter.toRow(row) + + val unsafeRowConverter = UnsafeProjection.create(schema) + val safeRowConverter = FromUnsafeProjection(schema) + + val unsafeRow = unsafeRowConverter(internalRow) + val safeRow = safeRowConverter(unsafeRow) + + val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData] + val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData] + + val elementType = dt.elementType + test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) { + compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType)) + } + + test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) { + compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType)) + } + } + } + + testArrayData() +} --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org