This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 76ce6b00e036 [SPARK-48019] Fix incorrect behavior in 
ColumnVector/ColumnarArray with dictionary and nulls
76ce6b00e036 is described below

commit 76ce6b00e036a699ad172ba4b7d3f2632ab75332
Author: Gene Pang <gene.p...@databricks.com>
AuthorDate: Sun Apr 28 11:07:12 2024 +0800

    [SPARK-48019] Fix incorrect behavior in ColumnVector/ColumnarArray with 
dictionary and nulls
    
    ### What changes were proposed in this pull request?
    
    This fixes how `ColumnVector` handles copying arrays when the vector has a 
dictionary and null values. The possible issues with the previous 
implementation:
    - An `ArrayIndexOutOfBoundsException` may be thrown when the `ColumnVector` 
has nulls and a dictionary. This is because the dictionary id stored for a `null` 
entry may be invalid, so it must not be used to decode `null` entries.
    - Copying a `ColumnarArray` (which is backed by a `ColumnVector`) is incorrect 
if it contains `null` entries. This is because the copy goes through a primitive 
array, which cannot represent `null` entries, so all of the null entries are lost.
    
    ### Why are the changes needed?
    
    These changes are needed to avoid `ArrayIndexOutOfBoundsException` and to 
produce correct results when copying `ColumnarArray`.
    
    ### Does this PR introduce _any_ user-facing change?
    
    The only user-facing changes are fixes for existing errors and incorrect 
results.
    
    ### How was this patch tested?
    
    Added new unit tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No.
    
    Closes #46254 from gene-db/dictionary-nulls.
    
    Authored-by: Gene Pang <gene.p...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/vectorized/ColumnarArray.java |   5 +
 .../execution/vectorized/OffHeapColumnVector.java  |  24 ++-
 .../execution/vectorized/OnHeapColumnVector.java   |  24 ++-
 .../execution/vectorized/ColumnVectorSuite.scala   | 174 +++++++++++++++++++++
 4 files changed, 215 insertions(+), 12 deletions(-)

diff --git 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
index 4163af9bfda5..d92293b91870 100644
--- 
a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
+++ 
b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java
@@ -53,6 +53,11 @@ public final class ColumnarArray extends ArrayData {
   public ArrayData copy() {
     DataType dt = data.dataType();
 
+    if (data.hasNull()) {
+      // UnsafeArrayData cannot be used if there are any nulls.
+      return new GenericArrayData(toObjectArray(dt)).copy();
+    }
+
     if (dt instanceof BooleanType) {
       return UnsafeArrayData.fromPrimitiveArray(toBooleanArray());
     } else if (dt instanceof ByteType) {
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index 2bb0b02d4c9c..1882d990bef5 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -218,7 +218,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
       Platform.copyMemory(null, data + rowId, array, 
Platform.BYTE_ARRAY_OFFSET, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId 
+ i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = (byte) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        }
       }
     }
     return array;
@@ -279,7 +281,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
       Platform.copyMemory(null, data + rowId * 2L, array, 
Platform.SHORT_ARRAY_OFFSET, count * 2L);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = (short) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = (short) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        }
       }
     }
     return array;
@@ -345,7 +349,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
       Platform.copyMemory(null, data + rowId * 4L, array, 
Platform.INT_ARRAY_OFFSET, count * 4L);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -423,7 +429,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
       Platform.copyMemory(null, data + rowId * 8L, array, 
Platform.LONG_ARRAY_OFFSET, count * 8L);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -487,7 +495,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
       Platform.copyMemory(null, data + rowId * 4L, array, 
Platform.FLOAT_ARRAY_OFFSET, count * 4L);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + 
i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -553,7 +563,9 @@ public final class OffHeapColumnVector extends 
WritableColumnVector {
         count * 8L);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + 
i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 2bf2b8d08fce..1908b511269a 100644
--- 
a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ 
b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -216,7 +216,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(byteData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = (byte) dictionary.decodeToInt(dictionaryIds.getDictId(rowId 
+ i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = (byte) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        }
       }
     }
     return array;
@@ -276,7 +278,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(shortData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = (short) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = (short) 
dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        }
       }
     }
     return array;
@@ -337,7 +341,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(intData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToInt(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -409,7 +415,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(longData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToLong(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -466,7 +474,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(floatData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + 
i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToFloat(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
@@ -525,7 +535,9 @@ public final class OnHeapColumnVector extends 
WritableColumnVector {
       System.arraycopy(doubleData, rowId, array, 0, count);
     } else {
       for (int i = 0; i < count; i++) {
-        array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + 
i));
+        if (!isNullAt(rowId + i)) {
+          array[i] = dictionary.decodeToDouble(dictionaryIds.getDictId(rowId + 
i));
+        }
       }
     }
     return array;
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 7cce6086c6fd..aca968745d19 100644
--- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -476,6 +476,180 @@ class ColumnVectorSuite extends SparkFunSuite with 
SQLHelper {
     assert(testVector.getDoubles(0, 3)(2) == 1342.17729d)
   }
 
+  def check(expected: Seq[Any], testVector: WritableColumnVector): Unit = {
+    expected.zipWithIndex.foreach {
+      case (v: Integer, idx) =>
+        assert(testVector.getInt(idx) == v)
+        assert(testVector.getInts(0, testVector.capacity)(idx) == v)
+      case (v: Short, idx) =>
+        assert(testVector.getShort(idx) == v)
+        assert(testVector.getShorts(0, testVector.capacity)(idx) == v)
+      case (v: Byte, idx) =>
+        assert(testVector.getByte(idx) == v)
+        assert(testVector.getBytes(0, testVector.capacity)(idx) == v)
+      case (v: Long, idx) =>
+        assert(testVector.getLong(idx) == v)
+        assert(testVector.getLongs(0, testVector.capacity)(idx) == v)
+      case (v: Float, idx) =>
+        assert(testVector.getFloat(idx) == v)
+        assert(testVector.getFloats(0, testVector.capacity)(idx) == v)
+      case (v: Double, idx) =>
+        assert(testVector.getDouble(idx) == v)
+        assert(testVector.getDoubles(0, testVector.capacity)(idx) == v)
+      case (null, idx) => assert(testVector.isNullAt(idx), s"Expected null at $idx")
+      case (_, idx) => assert(false, s"Unexpected value at $idx")
+    }
+
+    // Verify ColumnarArray.toSeq() and ColumnarArray.copy() both preserve values and nulls
+    val arr = new ColumnarArray(testVector, 0, testVector.capacity)
+    assert(arr.toSeq(testVector.dataType) == expected)
+    assert(arr.copy().toSeq(testVector.dataType) == expected)
+  }
+
+  testVectors("getInts with dictionary and nulls", 3, IntegerType) { 
testVector =>
+    // Validate without dictionary
+    val expected = Seq(1, null, 3)
+    expected.foreach {
+      case i: Integer => testVector.appendInt(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(7, null, 9)
+    val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
+      case i: Integer => i.toInt
+      case _ => -1
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
+  testVectors("getShorts with dictionary and nulls", 3, ShortType) { 
testVector =>
+    // Validate without dictionary
+    val expected = Seq(1.toShort, null, 3.toShort)
+    expected.foreach {
+      case i: Short => testVector.appendShort(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(7.toShort, null, 9.toShort)
+    val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
+      case i: Short => i.toInt
+      case _ => -1
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
+  testVectors("getBytes with dictionary and nulls", 3, ByteType) { testVector 
=>
+    // Validate without dictionary
+    val expected = Seq(1.toByte, null, 3.toByte)
+    expected.foreach {
+      case i: Byte => testVector.appendByte(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(7.toByte, null, 9.toByte)
+    val dictArray = (Seq(-1, -1) ++ expectedDictionary.map {
+      case i: Byte => i.toInt
+      case _ => -1
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
+  testVectors("getLongs with dictionary and nulls", 3, LongType) { testVector 
=>
+    // Validate without dictionary
+    val expected = Seq(2147483L, null, 2147485L)
+    expected.foreach {
+      case i: Long => testVector.appendLong(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(2147483648L, null, 2147483650L)
+    val dictArray = (Seq(-1L, -1L) ++ expectedDictionary.map {
+      case i: Long => i
+      case _ => -1L
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
+  testVectors("getFloats with dictionary and nulls", 3, FloatType) { 
testVector =>
+    // Validate without dictionary
+    val expected = Seq(1.1f, null, 3.3f)
+    expected.foreach {
+      case i: Float => testVector.appendFloat(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(0.1f, null, 0.3f)
+    val dictArray = (Seq(-1f, -1f) ++ expectedDictionary.map {
+      case i: Float => i
+      case _ => -1f
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
+  testVectors("getDoubles with dictionary and nulls", 3, DoubleType) { 
testVector =>
+    // Validate without dictionary
+    val expected = Seq(1.1d, null, 3.3d)
+    expected.foreach {
+      case i: Double => testVector.appendDouble(i)
+      case _ => testVector.appendNull()
+    }
+    check(expected, testVector)
+
+    // Validate with dictionary
+    val expectedDictionary = Seq(1342.17727d, null, 1342.17729d)
+    val dictArray = (Seq(-1d, -1d) ++ expectedDictionary.map {
+      case i: Double => i
+      case _ => -1d
+    }).toArray
+    val dict = new ColumnDictionary(dictArray)
+    testVector.setDictionary(dict)
+    testVector.reserveDictionaryIds(3)
+    testVector.getDictionaryIds.putInt(0, 2)
+    testVector.getDictionaryIds.putInt(1, -1) // This is a null, so the entry 
should be ignored
+    testVector.getDictionaryIds.putInt(2, 4)
+    check(expectedDictionary, testVector)
+  }
+
   test("[SPARK-22092] off-heap column vector reallocation corrupts array 
data") {
     withVector(new OffHeapColumnVector(8, arrayType)) { testVector =>
       val data = testVector.arrayData()


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to