Repository: spark
Updated Branches:
  refs/heads/master 05ae74778 -> 30ffb53ca


[SPARK-23875][SQL] Add IndexedSeq wrapper for ArrayData

## What changes were proposed in this pull request?

We don't have a good way to sequentially access `UnsafeArrayData` through a common 
interface such as `Seq`. An example is `MapObjects`, where we need to access several 
sequence collection types together. However, `UnsafeArrayData` doesn't implement 
`ArrayData.array` (calling it throws `UnsupportedOperationException`), and calling 
`toArray` copies the entire array. This patch adds an `IndexedSeq` wrapper for 
`ArrayData`, so we can access the elements in place and avoid copying the entire array.
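
For illustration, a minimal usage sketch (a hedged example, not part of the patch; it 
assumes integer elements and builds the input via `ArrayData.toArrayData`):

    import org.apache.spark.sql.catalyst.util.ArrayData
    import org.apache.spark.sql.types.IntegerType

    // Build an ArrayData from a primitive array (becomes an UnsafeArrayData).
    val arrayData = ArrayData.toArrayData(Array(1, 2, 3))
    // Wrap it without copying; elements are fetched from the backing data on demand.
    val seq: IndexedSeq[Any] = arrayData.toSeq[Any](IntegerType)
    assert(seq(1) == 2 && seq.length == 3)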

## How was this patch tested?

Added test.

Author: Liang-Chi Hsieh <vii...@gmail.com>

Closes #20984 from viirya/SPARK-23875.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/30ffb53c
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/30ffb53c
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/30ffb53c

Branch: refs/heads/master
Commit: 30ffb53cad84283b4f7694bfd60bdd7e1101b04e
Parents: 05ae747
Author: Liang-Chi Hsieh <vii...@gmail.com>
Authored: Tue Apr 17 15:09:36 2018 +0200
Committer: Herman van Hovell <hvanhov...@databricks.com>
Committed: Tue Apr 17 15:09:36 2018 +0200

----------------------------------------------------------------------
 .../catalyst/expressions/objects/objects.scala  |   2 +-
 .../spark/sql/catalyst/util/ArrayData.scala     |  30 +++++-
 .../util/ArrayDataIndexedSeqSuite.scala         | 100 +++++++++++++++++++
 3 files changed, 130 insertions(+), 2 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 77802e8..72b202b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -708,7 +708,7 @@ case class MapObjects private(
         }
       }
     case ArrayType(et, _) =>
-      _.asInstanceOf[ArrayData].array
+      _.asInstanceOf[ArrayData].toSeq[Any](et)
   }
 
   private lazy val mapElements: Seq[_] => Any = customCollectionCls match {
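
In effect, `MapObjects` now takes a lazy `IndexedSeq` view over the incoming `ArrayData` instead of asking for a materialized array. A hedged sketch of the difference (`arrayData` and `et` stand for the values in the hunk above):

    // Before: ArrayData.array, which UnsafeArrayData does not implement
    // (it throws UnsupportedOperationException).
    // val elements: Seq[Any] = arrayData.array

    // After: an IndexedSeq wrapper over the same backing data, with no copy.
    val elements: Seq[Any] = arrayData.toSeq[Any](et)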

http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
index 9beef41..2cf59d5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.util
 
 import scala.reflect.ClassTag
 
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, UnsafeArrayData}
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.types._
 
 object ArrayData {
   def toArrayData(input: Any): ArrayData = input match {
@@ -42,6 +43,9 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
 
   def array: Array[Any]
 
+  def toSeq[T](dataType: DataType): IndexedSeq[T] =
+    new ArrayDataIndexedSeq[T](this, dataType)
+
   def setNullAt(i: Int): Unit
 
   def update(i: Int, value: Any): Unit
@@ -164,3 +168,27 @@ abstract class ArrayData extends SpecializedGetters with Serializable {
     }
   }
 }
+
+/**
+ * Implements an `IndexedSeq` interface for `ArrayData`. Notice that if the original `ArrayData`
+ * is a primitive array and contains null elements, it is better to ask for `IndexedSeq[Any]`,
+ * instead of `IndexedSeq[Int]`, in order to keep the null elements.
+ */
+class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends IndexedSeq[T] {
+
+  private val accessor: (SpecializedGetters, Int) => Any = InternalRow.getAccessor(dataType)
+
+  override def apply(idx: Int): T =
+    if (0 <= idx && idx < arrayData.numElements()) {
+      if (arrayData.isNullAt(idx)) {
+        null.asInstanceOf[T]
+      } else {
+        accessor(arrayData, idx).asInstanceOf[T]
+      }
+    } else {
+      throw new IndexOutOfBoundsException(
+        s"Index $idx must be between 0 and the length of the ArrayData.")
+    }
+
+  override def length: Int = arrayData.numElements()
+}
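
A short, hedged sketch of the null handling described in the scaladoc above (values are illustrative; `GenericArrayData` comes from the same package):

    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.types.IntegerType

    val data = new GenericArrayData(Array[Any](1, null, 3))
    // Asking for IndexedSeq[Any] rather than IndexedSeq[Int] keeps the null element.
    val seq = data.toSeq[Any](IntegerType)
    assert(seq.length == 3 && seq(1) == null)
    // seq(-1) or seq(3) throws IndexOutOfBoundsException.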

http://git-wip-us.apache.org/repos/asf/spark/blob/30ffb53c/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
new file mode 100644
index 0000000..6400898
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayDataIndexedSeqSuite.scala
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.util
+
+import scala.util.Random
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.RandomDataGenerator
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.catalyst.expressions.{FromUnsafeProjection, UnsafeArrayData, UnsafeProjection}
+import org.apache.spark.sql.types._
+
+class ArrayDataIndexedSeqSuite extends SparkFunSuite {
+  private def compArray(arrayData: ArrayData, elementDt: DataType, array: Array[Any]): Unit = {
+    assert(arrayData.numElements == array.length)
+    array.zipWithIndex.map { case (e, i) =>
+      if (e != null) {
+        elementDt match {
+          // For NaN, etc.
+          case FloatType | DoubleType => assert(arrayData.get(i, elementDt).equals(e))
+          case _ => assert(arrayData.get(i, elementDt) === e)
+        }
+      } else {
+        assert(arrayData.isNullAt(i))
+      }
+    }
+
+    val seq = arrayData.toSeq[Any](elementDt)
+    array.zipWithIndex.map { case (e, i) =>
+      if (e != null) {
+        elementDt match {
+          // For NaN, etc.
+          case FloatType | DoubleType => assert(seq(i).equals(e))
+          case _ => assert(seq(i) === e)
+        }
+      } else {
+        assert(seq(i) == null)
+      }
+    }
+
+    assert(intercept[IndexOutOfBoundsException] {
+      seq(-1)
+    }.getMessage().contains("must be between 0 and the length of the ArrayData."))
+
+    assert(intercept[IndexOutOfBoundsException] {
+      seq(seq.length)
+    }.getMessage().contains("must be between 0 and the length of the ArrayData."))
+  }
+
+  private def testArrayData(): Unit = {
+    val elementTypes = Seq(BooleanType, ByteType, ShortType, IntegerType, LongType, FloatType,
+      DoubleType, DecimalType.USER_DEFAULT, StringType, BinaryType, DateType, TimestampType,
+      CalendarIntervalType, new ExamplePointUDT())
+    val arrayTypes = elementTypes.flatMap { elementType =>
+      Seq(ArrayType(elementType, containsNull = false), ArrayType(elementType, containsNull = true))
+    }
+    val random = new Random(100)
+    arrayTypes.foreach { dt =>
+      val schema = StructType(StructField("col_1", dt, nullable = false) :: Nil)
+      val row = RandomDataGenerator.randomRow(random, schema)
+      val rowConverter = RowEncoder(schema)
+      val internalRow = rowConverter.toRow(row)
+
+      val unsafeRowConverter = UnsafeProjection.create(schema)
+      val safeRowConverter = FromUnsafeProjection(schema)
+
+      val unsafeRow = unsafeRowConverter(internalRow)
+      val safeRow = safeRowConverter(unsafeRow)
+
+      val genericArrayData = safeRow.getArray(0).asInstanceOf[GenericArrayData]
+      val unsafeArrayData = unsafeRow.getArray(0).asInstanceOf[UnsafeArrayData]
+
+      val elementType = dt.elementType
+      test("ArrayDataIndexedSeq - UnsafeArrayData - " + dt.toString) {
+        compArray(unsafeArrayData, elementType, unsafeArrayData.toArray[Any](elementType))
+      }
+
+      test("ArrayDataIndexedSeq - GenericArrayData - " + dt.toString) {
+        compArray(genericArrayData, elementType, genericArrayData.toArray[Any](elementType))
+      }
+    }
+  }
+
+  testArrayData()
+}

