This is an automated email from the ASF dual-hosted git repository.

hvanhovell pushed a commit to branch branch-3.1
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.1 by this push:
     new fe412b6  [SPARK-35898][SQL] Fix arrays and maps in RowToColumnConverter
fe412b6 is described below

commit fe412b666a2d37ef34202eb9cf47aa469c336798
Author: Tom van Bussel <tom.vanbus...@databricks.com>
AuthorDate: Mon Jun 28 16:50:53 2021 +0200

    [SPARK-35898][SQL] Fix arrays and maps in RowToColumnConverter
    
    ### What changes were proposed in this pull request?
    
    This PR fixes support for arrays and maps in `RowToColumnConverter`. In particular this PR fixes two bugs:
    
    1. `appendArray` in `WritableColumnVector` does not reserve any elements in its child arrays, which causes the assertion in `OffHeapColumnVector.putArray` to fail.
    2. The nullability of the child columns is propagated incorrectly when creating the child converters of `ArrayConverter` and `MapConverter` in `RowToColumnConverter`.
    
    This PR fixes these issues.
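
    As an illustrative sketch (not part of the patch), the first bug can be reproduced
    along these lines, assuming JVM assertions are enabled (`-ea`); the capacity and
    length values are arbitrary, picked so the appended length exceeds the child
    vector's initial capacity:

    ```scala
    import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector
    import org.apache.spark.sql.types.{ArrayType, IntegerType}

    val vector = new OffHeapColumnVector(1, ArrayType(IntegerType))
    // Without the child reserve, a length beyond the child's initial capacity
    // trips the assertion in OffHeapColumnVector.putArray; with this fix the
    // children are reserved first and the append succeeds.
    vector.appendArray(8)
    vector.close()
    ```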
    
    ### Why are the changes needed?
    
    Both bugs cause an exception to be thrown.
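
    As a hedged sketch of the second bug (mirroring the new `RowToColumnConverterSuite`
    test for a non-nullable array column with null elements; access to these internal
    classes from the `org.apache.spark.sql.execution` package is assumed, as in the
    suite):

    ```scala
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.util.GenericArrayData
    import org.apache.spark.sql.execution.RowToColumnConverter
    import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
    import org.apache.spark.sql.types._

    // The column itself is non-nullable, but its elements may be null.
    val schema = StructType(Seq(
      StructField("a", ArrayType(IntegerType, containsNull = true), nullable = false)))
    val converter = new RowToColumnConverter(schema)
    val vectors =
      schema.map(f => new OnHeapColumnVector(5, f.dataType)).toArray[WritableColumnVector]
    // Before the fix the element converter inherited the column's nullable = false
    // and threw on the null element; with the fix it honors at.containsNull and
    // records a null entry instead.
    converter.convert(InternalRow(new GenericArrayData(Seq(1, null, 3))), vectors)
    ```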
    
    ### Does this PR introduce _any_ user-facing change?
    
    No
    
    ### How was this patch tested?
    
    I added additional test cases to `ColumnVectorSuite` to catch the first bug, and I added `RowToColumnConverterSuite` to catch both bugs (but specifically the second).
    
    Closes #33108 from tomvanbussel/SPARK-35898.
    
    Authored-by: Tom van Bussel <tom.vanbus...@databricks.com>
    Signed-off-by: herman <her...@databricks.com>
    (cherry picked from commit c6606502a2e338c0e973e5772a8cc44126ae2fde)
    Signed-off-by: herman <her...@databricks.com>
---
 .../execution/vectorized/WritableColumnVector.java |   3 +
 .../org/apache/spark/sql/execution/Columnar.scala  |   6 +-
 .../sql/execution/RowToColumnConverterSuite.scala  | 145 +++++++++++++++++++++
 .../execution/vectorized/ColumnVectorSuite.scala   |  87 +++++++++++++
 4 files changed, 238 insertions(+), 3 deletions(-)

diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 8c0f1e1..97a685a 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -611,6 +611,9 @@ public abstract class WritableColumnVector extends ColumnVector {
 
   public final int appendArray(int length) {
     reserve(elementsAppended + 1);
+    for (WritableColumnVector childColumn : childColumns) {
+      childColumn.reserve(childColumn.elementsAppended + length);
+    }
     putArray(elementsAppended, arrayData().elementsAppended, length);
     return elementsAppended++;
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
index 8d54279..ccb525d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Columnar.scala
@@ -261,12 +261,12 @@ private object RowToColumnConverter {
       case DoubleType => DoubleConverter
       case StringType => StringConverter
       case CalendarIntervalType => CalendarConverter
-      case at: ArrayType => new ArrayConverter(getConverterForType(at.elementType, nullable))
+      case at: ArrayType => ArrayConverter(getConverterForType(at.elementType, at.containsNull))
       case st: StructType => new StructConverter(st.fields.map(
         (f) => getConverterForType(f.dataType, f.nullable)))
       case dt: DecimalType => new DecimalConverter(dt)
-      case mt: MapType => new MapConverter(getConverterForType(mt.keyType, nullable),
-        getConverterForType(mt.valueType, nullable))
+      case mt: MapType => MapConverter(getConverterForType(mt.keyType, nullable = false),
+        getConverterForType(mt.valueType, mt.valueContainsNull))
       case unknown => throw new UnsupportedOperationException(
         s"Type $unknown not supported")
     }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala
new file mode 100644
index 0000000..1afe742
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowToColumnConverterSuite.scala
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData}
+import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+class RowToColumnConverterSuite extends SparkFunSuite {
+  def convertRows(rows: Seq[InternalRow], schema: StructType): Seq[WritableColumnVector] = {
+    val converter = new RowToColumnConverter(schema)
+    val vectors =
+      schema.map(f => new OnHeapColumnVector(5, f.dataType)).toArray[WritableColumnVector]
+    for (row <- rows) {
+      converter.convert(row, vectors)
+    }
+    vectors
+  }
+
+  test("integer column") {
+    val schema = StructType(Seq(StructField("i", IntegerType)))
+    val rows = (0 until 100).map(i => InternalRow(i))
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getInt(i) === row.getInt(0))
+    }
+  }
+
+  test("array column") {
+    val schema = StructType(Seq(StructField("a", ArrayType(IntegerType))))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData(0 until i))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getArray(i).array().array === row.getArray(0).array)
+    }
+  }
+
+  test("non-nullable array column with null elements") {
+    val arrayType = ArrayType(IntegerType, containsNull = true)
+    val schema = StructType(Seq(StructField("a", arrayType, nullable = false)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData((0 until i).map { j =>
+        if (j % 3 == 0) {
+          null
+        } else {
+          j
+        }
+      }))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors.head.getArray(i).array().array === row.getArray(0).array)
+    }
+  }
+
+  test("nested array column") {
+    val arrayType = ArrayType(ArrayType(IntegerType))
+    val schema = StructType(Seq(StructField("a", arrayType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new GenericArrayData((0 until i).map(j => new GenericArrayData(0 until j))))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getArray(i).array().array
+        .map(_.asInstanceOf[ArrayData].array)
+      val expected = row.getArray(0).array
+        .map(_.asInstanceOf[ArrayData].array)
+      assert(result === expected)
+    }
+  }
+
+  test("map column") {
+    val mapType = MapType(IntegerType, StringType)
+    val schema = StructType(Seq(StructField("m", mapType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new ArrayBasedMapData(
+        new GenericArrayData(0 until i),
+        new GenericArrayData((0 until i).map(j => UTF8String.fromString(s"str$j")))))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getMap(i)
+      val expected = row.getMap(0)
+      assert(result.keyArray().array().array === expected.keyArray().array)
+      assert(result.valueArray().array().array === expected.valueArray().array)
+    }
+  }
+
+  test("non-nullable map column with null values") {
+    val mapType = MapType(IntegerType, StringType, valueContainsNull = true)
+    val schema = StructType(Seq(StructField("m", mapType, nullable = false)))
+    val rows = (0 until 100).map { i =>
+      InternalRow(new ArrayBasedMapData(
+        new GenericArrayData(0 until i),
+        new GenericArrayData((0 until i).map { j =>
+          if (j % 3 == 0) {
+            null
+          } else {
+            UTF8String.fromString(s"str$j")
+          }
+        })))
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      val result = vectors.head.getMap(i)
+      val expected = row.getMap(0)
+      assert(result.keyArray().array().array === expected.keyArray().array)
+      assert(result.valueArray().array().array === expected.valueArray().array)
+    }
+  }
+
+  test("multiple columns") {
+    val schema = StructType(
+      Seq(StructField("s", ShortType), StructField("i", IntegerType), StructField("l", LongType)))
+    val rows = (0 until 100).map { i =>
+      InternalRow((3 * i).toShort, 3 * i + 1, (3 * i + 2).toLong)
+    }
+    val vectors = convertRows(rows, schema)
+    rows.zipWithIndex.map { case (row, i) =>
+      assert(vectors(0).getShort(i) === row.getShort(0))
+      assert(vectors(1).getInt(i) === row.getInt(1))
+      assert(vectors(2).getLong(i) === row.getLong(2))
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 247efd5..43f48ab 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -243,6 +243,93 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
     assert(testVector.getArray(3).toIntArray() === Array(3, 4, 5))
   }
 
+  testVectors("SPARK-35898: array append", 1, arrayType) { testVector =>
+    // Populate it with arrays [0], [1, 2], [], [3, 4, 5]
+    val data = testVector.arrayData()
+    testVector.appendArray(1)
+    data.appendInt(0)
+    testVector.appendArray(2)
+    data.appendInt(1)
+    data.appendInt(2)
+    testVector.appendArray(0)
+    testVector.appendArray(3)
+    data.appendInt(3)
+    data.appendInt(4)
+    data.appendInt(5)
+
+    assert(testVector.getArray(0).toIntArray === Array(0))
+    assert(testVector.getArray(1).toIntArray === Array(1, 2))
+    assert(testVector.getArray(2).toIntArray === Array.empty[Int])
+    assert(testVector.getArray(3).toIntArray === Array(3, 4, 5))
+  }
+
+  val mapType: MapType = MapType(IntegerType, StringType)
+  testVectors("SPARK-35898: map", 5, mapType) { testVector =>
+    val keys = testVector.getChild(0)
+    val values = testVector.getChild(1)
+    var i = 0
+    while (i < 6) {
+      keys.appendInt(i)
+      val utf8 = s"str$i".getBytes("utf8")
+      values.appendByteArray(utf8, 0, utf8.length)
+      i += 1
+    }
+
+    testVector.putArray(0, 0, 1)
+    testVector.putArray(1, 1, 2)
+    testVector.putArray(2, 3, 0)
+    testVector.putArray(3, 3, 3)
+
+    assert(testVector.getMap(0).keyArray().toIntArray === Array(0))
+    assert(testVector.getMap(0).valueArray().toArray[UTF8String](StringType) ===
+      Array(UTF8String.fromString(s"str0")))
+    assert(testVector.getMap(1).keyArray().toIntArray === Array(1, 2))
+    assert(testVector.getMap(1).valueArray().toArray[UTF8String](StringType) ===
+      (1 to 2).map(i => UTF8String.fromString(s"str$i")).toArray)
+    assert(testVector.getMap(2).keyArray().toIntArray === Array.empty[Int])
+    assert(testVector.getMap(2).valueArray().toArray[UTF8String](StringType) ===
+      Array.empty[UTF8String])
+    assert(testVector.getMap(3).keyArray().toIntArray === Array(3, 4, 5))
+    assert(testVector.getMap(3).valueArray().toArray[UTF8String](StringType) ===
+      (3 to 5).map(i => UTF8String.fromString(s"str$i")).toArray)
+  }
+
+  testVectors("SPARK-35898: map append", 1, mapType) { testVector =>
+    val keys = testVector.getChild(0)
+    val values = testVector.getChild(1)
+    def appendPair(i: Int): Unit = {
+      keys.appendInt(i)
+      val utf8 = s"str$i".getBytes("utf8")
+      values.appendByteArray(utf8, 0, utf8.length)
+    }
+
+    // Populate it with the maps [0 -> str0], [1 -> str1, 2 -> str2], [],
+    // [3 -> str3, 4 -> str4, 5 -> str5]
+    testVector.appendArray(1)
+    appendPair(0)
+    testVector.appendArray(2)
+    appendPair(1)
+    appendPair(2)
+    testVector.appendArray(0)
+    testVector.appendArray(3)
+    appendPair(3)
+    appendPair(4)
+    appendPair(5)
+
+    assert(testVector.getMap(0).keyArray().toIntArray === Array(0))
+    assert(testVector.getMap(0).valueArray().toArray[UTF8String](StringType) ===
+      Array(UTF8String.fromString(s"str0")))
+    assert(testVector.getMap(1).keyArray().toIntArray === Array(1, 2))
+    assert(testVector.getMap(1).valueArray().toArray[UTF8String](StringType) ===
+      (1 to 2).map(i => UTF8String.fromString(s"str$i")).toArray)
+    assert(testVector.getMap(2).keyArray().toIntArray === Array.empty[Int])
+    assert(testVector.getMap(2).valueArray().toArray[UTF8String](StringType) ===
+      Array.empty[UTF8String])
+    assert(testVector.getMap(3).keyArray().toIntArray === Array(3, 4, 5))
+    assert(testVector.getMap(3).valueArray().toArray[UTF8String](StringType) ===
+      (3 to 5).map(i => UTF8String.fromString(s"str$i")).toArray)
+  }
+
 val structType: StructType = new StructType().add("int", IntegerType).add("double", DoubleType)
   testVectors("struct", 10, structType) { testVector =>
     val c1 = testVector.getChild(0)
