Repository: spark
Updated Branches:
  refs/heads/master 894a7561d -> 3b049abf1


[SPARK-22003][SQL] support array column in vectorized reader with UDF

## What changes were proposed in this pull request?

A UDF has to deserialize the `UnsafeRow` handed to it. When the column type is an array 
and the data comes from the vectorized reader, this ends up calling the `get` method on 
`ColumnVector.Array`, which previously just threw `UnsupportedOperationException`. This 
patch implements `get` for all supported data types and simplifies `array()` to delegate to it.
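For context, a minimal repro sketch (hypothetical path and data, assuming a `spark-shell` 
session with the vectorized Parquet reader enabled, which is the default):

```scala
// Hypothetical repro sketch: path and data are made up for illustration.
// Assumes spark-shell, so `spark` and its implicits are in scope.
// Before this patch, the UDF call below failed with an
// UnsupportedOperationException thrown from ColumnVector.Array.get.
import org.apache.spark.sql.functions.udf

val path = "/tmp/array_col.parquet"  // hypothetical location
Seq(Seq(1, 2, 3), Seq.empty[Int], Seq(4, 5)).toDF("xs")
  .write.mode("overwrite").parquet(path)

val sumUdf = udf((xs: Seq[Int]) => xs.sum)
spark.read.parquet(path).select(sumUdf($"xs")).show()
```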

## How was this patch tested?

Added a new unit test suite, `ColumnVectorSuite`, which exercises `ColumnVector.Array.get` 
for every supported element type: boolean, byte, short, int, long, float, double, string, 
binary, array, and struct.

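In condensed form, the suite follows this pattern (a sketch distilled from the tests added 
in the diff below):

```scala
// Condensed sketch of the new tests' pattern: fill an on-heap vector,
// wrap it in ColumnVector.Array, and read the values back through the
// newly implemented get(ordinal, dataType).
import org.apache.spark.sql.execution.vectorized.{ColumnVector, OnHeapColumnVector}
import org.apache.spark.sql.types.IntegerType

val vector = new OnHeapColumnVector(10, IntegerType)
(0 until 10).foreach(vector.appendInt)

val array = new ColumnVector.Array(vector)
(0 until 10).foreach { i =>
  assert(array.get(i, IntegerType) == i)
}
vector.close()
```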

Author: Feng Liu <feng...@databricks.com>

Closes #19230 from liufengdb/fix_array_open.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/3b049abf
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/3b049abf
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/3b049abf

Branch: refs/heads/master
Commit: 3b049abf102908ca72674139367e3b8d9ffcc283
Parents: 894a756
Author: Feng Liu <feng...@databricks.com>
Authored: Mon Sep 18 08:49:32 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Sep 18 08:49:32 2017 -0700

----------------------------------------------------------------------
 .../sql/execution/vectorized/ColumnVector.java  | 103 ++++------
 .../vectorized/ColumnVectorSuite.scala          | 201 +++++++++++++++++++
 2 files changed, 242 insertions(+), 62 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/3b049abf/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
index a69dd97..c4b519f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java
@@ -100,72 +100,16 @@ public abstract class ColumnVector implements AutoCloseable {
     public Object[] array() {
       DataType dt = data.dataType();
       Object[] list = new Object[length];
-
-      if (dt instanceof BooleanType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getBoolean(offset + i);
-          }
-        }
-      } else if (dt instanceof ByteType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getByte(offset + i);
-          }
-        }
-      } else if (dt instanceof ShortType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getShort(offset + i);
-          }
-        }
-      } else if (dt instanceof IntegerType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getInt(offset + i);
-          }
-        }
-      } else if (dt instanceof FloatType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getFloat(offset + i);
-          }
-        }
-      } else if (dt instanceof DoubleType) {
+      try {
         for (int i = 0; i < length; i++) {
           if (!data.isNullAt(offset + i)) {
-            list[i] = data.getDouble(offset + i);
+            list[i] = get(i, dt);
           }
         }
-      } else if (dt instanceof LongType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = data.getLong(offset + i);
-          }
-        }
-      } else if (dt instanceof DecimalType) {
-        DecimalType decType = (DecimalType)dt;
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = getDecimal(i, decType.precision(), decType.scale());
-          }
-        }
-      } else if (dt instanceof StringType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = getUTF8String(i).toString();
-          }
-        }
-      } else if (dt instanceof CalendarIntervalType) {
-        for (int i = 0; i < length; i++) {
-          if (!data.isNullAt(offset + i)) {
-            list[i] = getInterval(i);
-          }
-        }
-      } else {
-        throw new UnsupportedOperationException("Type " + dt);
+        return list;
+      } catch(Exception e) {
+        throw new RuntimeException("Could not get the array", e);
       }
-      return list;
     }
 
     @Override
@@ -237,7 +181,42 @@ public abstract class ColumnVector implements AutoCloseable {
 
     @Override
     public Object get(int ordinal, DataType dataType) {
-      throw new UnsupportedOperationException();
+      if (dataType instanceof BooleanType) {
+        return getBoolean(ordinal);
+      } else if (dataType instanceof ByteType) {
+        return getByte(ordinal);
+      } else if (dataType instanceof ShortType) {
+        return getShort(ordinal);
+      } else if (dataType instanceof IntegerType) {
+        return getInt(ordinal);
+      } else if (dataType instanceof LongType) {
+        return getLong(ordinal);
+      } else if (dataType instanceof FloatType) {
+        return getFloat(ordinal);
+      } else if (dataType instanceof DoubleType) {
+        return getDouble(ordinal);
+      } else if (dataType instanceof StringType) {
+        return getUTF8String(ordinal);
+      } else if (dataType instanceof BinaryType) {
+        return getBinary(ordinal);
+      } else if (dataType instanceof DecimalType) {
+        DecimalType t = (DecimalType) dataType;
+        return getDecimal(ordinal, t.precision(), t.scale());
+      } else if (dataType instanceof DateType) {
+        return getInt(ordinal);
+      } else if (dataType instanceof TimestampType) {
+        return getLong(ordinal);
+      } else if (dataType instanceof ArrayType) {
+        return getArray(ordinal);
+      } else if (dataType instanceof StructType) {
+        return getStruct(ordinal, ((StructType)dataType).fields().length);
+      } else if (dataType instanceof MapType) {
+        return getMap(ordinal);
+      } else if (dataType instanceof CalendarIntervalType) {
+        return getInterval(ordinal);
+      } else {
+        throw new UnsupportedOperationException("Datatype not supported " + 
dataType);
+      }
     }
 
     @Override

http://git-wip-us.apache.org/repos/asf/spark/blob/3b049abf/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
new file mode 100644
index 0000000..998067a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.vectorized
+
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+  var testVector: WritableColumnVector = _
+
+  private def allocate(capacity: Int, dt: DataType): WritableColumnVector = {
+    new OnHeapColumnVector(capacity, dt)
+  }
+
+  override def afterEach(): Unit = {
+    testVector.close()
+  }
+
+  test("boolean") {
+    testVector = allocate(10, BooleanType)
+    (0 until 10).foreach { i =>
+      testVector.appendBoolean(i % 2 == 0)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, BooleanType) === (i % 2 == 0))
+    }
+  }
+
+  test("byte") {
+    testVector = allocate(10, ByteType)
+    (0 until 10).foreach { i =>
+      testVector.appendByte(i.toByte)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, ByteType) === (i.toByte))
+    }
+  }
+
+  test("short") {
+    testVector = allocate(10, ShortType)
+    (0 until 10).foreach { i =>
+      testVector.appendShort(i.toShort)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, ShortType) === (i.toShort))
+    }
+  }
+
+  test("int") {
+    testVector = allocate(10, IntegerType)
+    (0 until 10).foreach { i =>
+      testVector.appendInt(i)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, IntegerType) === i)
+    }
+  }
+
+  test("long") {
+    testVector = allocate(10, LongType)
+    (0 until 10).foreach { i =>
+      testVector.appendLong(i)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, LongType) === i)
+    }
+  }
+
+  test("float") {
+    testVector = allocate(10, FloatType)
+    (0 until 10).foreach { i =>
+      testVector.appendFloat(i.toFloat)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, FloatType) === i.toFloat)
+    }
+  }
+
+  test("double") {
+    testVector = allocate(10, DoubleType)
+    (0 until 10).foreach { i =>
+      testVector.appendDouble(i.toDouble)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, DoubleType) === i.toDouble)
+    }
+  }
+
+  test("string") {
+    testVector = allocate(10, StringType)
+    (0 until 10).map { i =>
+      val utf8 = s"str$i".getBytes("utf8")
+      testVector.appendByteArray(utf8, 0, utf8.length)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      assert(array.get(i, StringType) === UTF8String.fromString(s"str$i"))
+    }
+  }
+
+  test("binary") {
+    testVector = allocate(10, BinaryType)
+    (0 until 10).map { i =>
+      val utf8 = s"str$i".getBytes("utf8")
+      testVector.appendByteArray(utf8, 0, utf8.length)
+    }
+
+    val array = new ColumnVector.Array(testVector)
+
+    (0 until 10).foreach { i =>
+      val utf8 = s"str$i".getBytes("utf8")
+      assert(array.get(i, BinaryType) === utf8)
+    }
+  }
+
+  test("array") {
+    val arrayType = ArrayType(IntegerType, true)
+    testVector = allocate(10, arrayType)
+
+    val data = testVector.arrayData()
+    var i = 0
+    while (i < 6) {
+      data.putInt(i, i)
+      i += 1
+    }
+
+    // Populate it with arrays [0], [1, 2], [], [3, 4, 5]
+    testVector.putArray(0, 0, 1)
+    testVector.putArray(1, 1, 2)
+    testVector.putArray(2, 3, 0)
+    testVector.putArray(3, 3, 3)
+
+    val array = new ColumnVector.Array(testVector)
+
+    assert(array.get(0, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(0))
+    assert(array.get(1, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(1, 2))
+    assert(array.get(2, arrayType).asInstanceOf[ArrayData].toIntArray() === Array.empty[Int])
+    assert(array.get(3, arrayType).asInstanceOf[ArrayData].toIntArray() === Array(3, 4, 5))
+  }
+
+  test("struct") {
+    val schema = new StructType().add("int", IntegerType).add("double", DoubleType)
+    testVector = allocate(10, schema)
+    val c1 = testVector.getChildColumn(0)
+    val c2 = testVector.getChildColumn(1)
+    c1.putInt(0, 123)
+    c2.putDouble(0, 3.45)
+    c1.putInt(1, 456)
+    c2.putDouble(1, 5.67)
+
+    val array = new ColumnVector.Array(testVector)
+
+    assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 123)
+    assert(array.get(0, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 3.45)
+    assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(0, IntegerType) === 456)
+    assert(array.get(1, schema).asInstanceOf[ColumnarBatch.Row].get(1, DoubleType) === 5.67)
+  }
+}

