This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new e89f5537c Fix: Fix null handling in CometVector implementations (#2643)
e89f5537c is described below

commit e89f5537c93508122217843d8eaa3b789435fe85
Author: Fu Chen <[email protected]>
AuthorDate: Fri Nov 21 03:34:15 2025 +0800

    Fix: Fix null handling in CometVector implementations (#2643)
---
 .../org/apache/comet/vector/CometListVector.java   |   1 +
 .../org/apache/comet/vector/CometMapVector.java    |   1 +
 .../org/apache/comet/vector/CometPlainVector.java  |   2 +
 .../java/org/apache/comet/vector/CometVector.java  |   1 +
 native/spark-expr/src/array_funcs/array_insert.rs  | 191 +++++++++++++--------
 .../apache/comet/CometArrayExpressionSuite.scala   |  27 +++
 6 files changed, 152 insertions(+), 71 deletions(-)
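
The Java changes below all apply the same guard: each accessor consults the validity bitmap via isNullAt() and returns null for a null slot, instead of decoding whatever bytes happen to sit in the value buffer at that position. As a rough standalone illustration of the failure mode (a minimal sketch in Rust against arrow-rs, not code from this commit), reading a value at a null slot does not fail, it just yields a meaningless placeholder, so the null check has to come first:

    use arrow::array::{Array, StringArray};

    fn main() {
        // A string column with a null in the middle.
        let col = StringArray::from(vec![Some("a"), None, Some("c")]);

        // Reading the raw value at a null slot does not error; it returns whatever
        // the value buffer holds there (typically an empty string), which is why
        // accessors must check validity before decoding.
        assert!(col.is_null(1));
        println!("raw value at the null slot: {:?}", col.value(1));

        // The pattern the patch adds to the Java getters: null check first, then read.
        let safe: Option<&str> = if col.is_null(1) { None } else { Some(col.value(1)) };
        assert_eq!(safe, None);
    }

The same reasoning applies to getArray, getMap, getUTF8String, getBinary, and getDecimal in the diffs below.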

diff --git a/common/src/main/java/org/apache/comet/vector/CometListVector.java b/common/src/main/java/org/apache/comet/vector/CometListVector.java
index 752495c0d..93e8e8bf9 100644
--- a/common/src/main/java/org/apache/comet/vector/CometListVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometListVector.java
@@ -45,6 +45,7 @@ public class CometListVector extends CometDecodedVector {
 
   @Override
   public ColumnarArray getArray(int i) {
+    if (isNullAt(i)) return null;
     int start = listVector.getOffsetBuffer().getInt(i * ListVector.OFFSET_WIDTH);
     int end = listVector.getOffsetBuffer().getInt((i + 1) * ListVector.OFFSET_WIDTH);
 
diff --git a/common/src/main/java/org/apache/comet/vector/CometMapVector.java b/common/src/main/java/org/apache/comet/vector/CometMapVector.java
index 1d531ca90..c5984a4dc 100644
--- a/common/src/main/java/org/apache/comet/vector/CometMapVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometMapVector.java
@@ -65,6 +65,7 @@ public class CometMapVector extends CometDecodedVector {
 
   @Override
   public ColumnarMap getMap(int i) {
+    if (isNullAt(i)) return null;
     int start = mapVector.getOffsetBuffer().getInt(i * MapVector.OFFSET_WIDTH);
     int end = mapVector.getOffsetBuffer().getInt((i + 1) * MapVector.OFFSET_WIDTH);
 
diff --git a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
index f3803d53a..2a30be1b1 100644
--- a/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometPlainVector.java
@@ -123,6 +123,7 @@ public class CometPlainVector extends CometDecodedVector {
 
   @Override
   public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
     if (!isBaseFixedWidthVector) {
       BaseVariableWidthVector varWidthVector = (BaseVariableWidthVector) valueVector;
       long offsetBufferAddress = varWidthVector.getOffsetBuffer().memoryAddress();
@@ -147,6 +148,7 @@ public class CometPlainVector extends CometDecodedVector {
 
   @Override
   public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
     int offset;
     int length;
     if (valueVector instanceof BaseVariableWidthVector) {
diff --git a/common/src/main/java/org/apache/comet/vector/CometVector.java b/common/src/main/java/org/apache/comet/vector/CometVector.java
index 0c6fa8f12..a1f75696f 100644
--- a/common/src/main/java/org/apache/comet/vector/CometVector.java
+++ b/common/src/main/java/org/apache/comet/vector/CometVector.java
@@ -85,6 +85,7 @@ public abstract class CometVector extends ColumnVector {
 
   @Override
   public Decimal getDecimal(int i, int precision, int scale) {
+    if (isNullAt(i)) return null;
     if (!useDecimal128 && precision <= Decimal.MAX_INT_DIGITS() && type instanceof IntegerType) {
       return createDecimal(getInt(i), precision, scale);
     } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
diff --git a/native/spark-expr/src/array_funcs/array_insert.rs b/native/spark-expr/src/array_funcs/array_insert.rs
index eb96fec12..dcee441ce 100644
--- a/native/spark-expr/src/array_funcs/array_insert.rs
+++ b/native/spark-expr/src/array_funcs/array_insert.rs
@@ -16,11 +16,10 @@
 // under the License.
 
 use arrow::array::{make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait};
-use arrow::datatypes::{DataType, Field, Schema};
+use arrow::datatypes::{DataType, Schema};
 use arrow::{
     array::{as_primitive_array, Capacities, MutableArrayData},
     buffer::{NullBuffer, OffsetBuffer},
-    datatypes::ArrowNativeType,
     record_batch::RecordBatch,
 };
 use datafusion::common::{
@@ -198,114 +197,131 @@ fn array_insert<O: OffsetSizeTrait>(
     pos_array: &ArrayRef,
     legacy_mode: bool,
 ) -> DataFusionResult<ColumnarValue> {
-    // The code is based on the implementation of the array_append from the Apache DataFusion
-    // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513
-    //
-    // This code is also based on the implementation of the array_insert from the Apache Spark
-    // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713
+    // Implementation aligned with Arrow's half-open offset ranges and Spark semantics.
 
     let values = list_array.values();
     let offsets = list_array.offsets();
     let values_data = values.to_data();
     let item_data = items_array.to_data();
+
+    // Estimate capacity (original values + inserted items upper bound)
     let new_capacity = Capacities::Array(values_data.len() + item_data.len());
 
     let mut mutable_values =
         MutableArrayData::with_capacities(vec![&values_data, &item_data], true, new_capacity);
 
-    let mut new_offsets = vec![O::usize_as(0)];
-    let mut new_nulls = Vec::<bool>::with_capacity(list_array.len());
+    // New offsets and top-level list validity bitmap
+    let mut new_offsets = Vec::with_capacity(list_array.len() + 1);
+    new_offsets.push(O::usize_as(0));
+    let mut list_valid = Vec::<bool>::with_capacity(list_array.len());
 
-    let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions
+    // Spark supports only Int32 position indices
+    let pos_data: &Int32Array = as_primitive_array(&pos_array);
 
-    for (row_index, offset_window) in offsets.windows(2).enumerate() {
-        let pos = pos_data.values()[row_index];
-        let start = offset_window[0].as_usize();
-        let end = offset_window[1].as_usize();
-        let is_item_null = items_array.is_null(row_index);
+    for (row_index, window) in offsets.windows(2).enumerate() {
+        let start = window[0].as_usize();
+        let end = window[1].as_usize();
+        let len = end - start;
+
+        // Return null for the entire row when pos is null (consistent with Spark's behavior)
+        if pos_data.is_null(row_index) {
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
+            continue;
+        }
+        let pos = pos_data.value(row_index);
 
         if list_array.is_null(row_index) {
-            // In Spark if value of the array is NULL than nothing happens
-            mutable_values.extend_nulls(1);
-            new_offsets.push(new_offsets[row_index] + O::one());
-            new_nulls.push(false);
+            // Top-level list row is NULL: do not write any child values and do not advance offset
+            new_offsets.push(new_offsets[row_index]);
+            list_valid.push(false);
             continue;
         }
 
         if pos == 0 {
             return Err(DataFusionError::Internal(
-                "Position for array_insert should be greter or less than zero".to_string(),
+                "Position for array_insert should be greater or less than zero".to_string(),
             ));
         }
 
-        if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) {
-            let corrected_pos = if pos > 0 {
-                (pos - 1).as_usize()
-            } else {
-                end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 }
-            };
-            let new_array_len = std::cmp::max(end - start + 1, corrected_pos);
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
+        let final_len: usize;
 
-            if (start + corrected_pos) <= end {
-                mutable_values.extend(0, start, start + corrected_pos);
+        if pos > 0 {
+            // Positive index (1-based)
+            let pos1 = pos as usize;
+            if pos1 <= len + 1 {
+                // In-range insertion (including appending to end)
+                let corrected = pos1 - 1; // 0-based insertion point
+                mutable_values.extend(0, start, start + corrected);
                 mutable_values.extend(1, row_index, row_index + 1);
-                mutable_values.extend(0, start + corrected_pos, end);
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
+                // Beyond end: pad with nulls then insert
+                let corrected = pos1 - 1;
+                let padding = corrected - len;
                 mutable_values.extend(0, start, end);
-                mutable_values.extend_nulls(new_array_len - (end - start));
+                mutable_values.extend_nulls(padding);
                 mutable_values.extend(1, row_index, row_index + 1);
-                // In that case spark actualy makes array longer than expected;
-                // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5
-                new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one());
+                final_len = corrected + 1; // equals pos1
             }
         } else {
-            // This comment is takes from the Apache Spark source code as is:
-            // special case- if the new position is negative but larger than the current array size
-            // place the new item at start of array, place the current array contents at the end
-            // and fill the newly created array elements inbetween with a null
-            let base_offset = if legacy_mode { 1 } else { 0 };
-            let new_array_len = (-pos + base_offset).as_usize();
-            if new_array_len > MAX_ROUNDED_ARRAY_LENGTH {
-                return Err(DataFusionError::Internal(format!(
-                    "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH:?}, but got {new_array_len:?}"
-                )));
-            }
-            mutable_values.extend(1, row_index, row_index + 1);
-            mutable_values.extend_nulls(new_array_len - (end - start + 1));
-            mutable_values.extend(0, start, end);
-            new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len));
-        }
-        if is_item_null {
-            if (start == end) || (values.is_null(row_index)) {
-                new_nulls.push(false)
+            // Negative index (1-based from the end)
+            let k = (-pos) as usize;
+
+            if k <= len {
+                // In-range negative insertion
+                // Non-legacy: -1 behaves like append to end (corrected = len - k + 1)
+                // Legacy:     -1 behaves like insert before the last element (corrected = len - k)
+                let base_offset = if legacy_mode { 0 } else { 1 };
+                let corrected = len - k + base_offset;
+                mutable_values.extend(0, start, start + corrected);
+                mutable_values.extend(1, row_index, row_index + 1);
+                mutable_values.extend(0, start + corrected, end);
+                final_len = len + 1;
             } else {
-                new_nulls.push(true)
+                // Negative index beyond the start (Spark-specific behavior):
+                // Place item first, then pad with nulls, then append the original array.
+                // Final length = k + base_offset, where base_offset = 1 in legacy mode, otherwise 0.
+                let base_offset = if legacy_mode { 1 } else { 0 };
+                let target_len = k + base_offset;
+                let padding = target_len.saturating_sub(len + 1);
+                mutable_values.extend(1, row_index, row_index + 1); // insert item first
+                mutable_values.extend_nulls(padding); // pad nulls
+                mutable_values.extend(0, start, end); // append original values
+                final_len = target_len;
             }
-        } else {
-            new_nulls.push(true)
         }
+
+        if final_len > MAX_ROUNDED_ARRAY_LENGTH {
+            return Err(DataFusionError::Internal(format!(
+                "Max array length in Spark is {MAX_ROUNDED_ARRAY_LENGTH}, but got {final_len}"
+            )));
+        }
+
+        let prev = new_offsets[row_index].as_usize();
+        new_offsets.push(O::usize_as(prev + final_len));
+        list_valid.push(true);
     }
 
-    let data = make_array(mutable_values.freeze());
-    let data_type = match list_array.data_type() {
-        DataType::List(field) => field.data_type(),
-        DataType::LargeList(field) => field.data_type(),
+    let child = make_array(mutable_values.freeze());
+
+    // Reuse the original list element field (name/type/nullability)
+    let elem_field = match list_array.data_type() {
+        DataType::List(field) => Arc::clone(field),
+        DataType::LargeList(field) => Arc::clone(field),
         _ => unreachable!(),
     };
-    let new_array = GenericListArray::<O>::try_new(
-        Arc::new(Field::new("item", data_type.clone(), true)),
+
+    // Build the resulting list array
+    let new_list = GenericListArray::<O>::try_new(
+        elem_field,
         OffsetBuffer::new(new_offsets.into()),
-        data,
-        Some(NullBuffer::new(new_nulls.into())),
+        child,
+        Some(NullBuffer::new(list_valid.into())),
     )?;
 
-    Ok(ColumnarValue::Array(Arc::new(new_array)))
+    Ok(ColumnarValue::Array(Arc::new(new_list)))
 }
 
 impl Display for ArrayInsert {
@@ -442,4 +458,37 @@ mod test {
 
         Ok(())
     }
+
+    #[test]
+    fn test_array_insert_bug_repro_null_item_pos1_fixed() -> Result<()> {
+        use arrow::array::{Array, ArrayRef, Int32Array, ListArray};
+        use arrow::datatypes::Int32Type;
+
+        // row0 = [0, null, 0]
+        // row1 = [1, null, 1]
+        let list = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![Some(0), None, Some(0)]),
+            Some(vec![Some(1), None, Some(1)]),
+        ]);
+
+        let positions = Int32Array::from(vec![1, 1]);
+        let items = Int32Array::from(vec![None, None]);
+
+        let ColumnarValue::Array(result) = array_insert(
+            &list,
+            &(Arc::new(items) as ArrayRef),
+            &(Arc::new(positions) as ArrayRef),
+            false, // legacy_mode = false
+        )?
+        else {
+            unreachable!()
+        };
+
+        let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+            Some(vec![None, Some(0), None, Some(0)]),
+            Some(vec![None, Some(1), None, Some(1)]),
+        ]);
+        assert_eq!(&result.to_data(), &expected.to_data());
+        Ok(())
+    }
 }
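
The heart of the Rust rewrite above is the position arithmetic: positions are 1-based and non-zero, may point past either end of the array, and the legacy flag (Spark's spark.sql.legacy.negativeIndexInArrayInsert) changes how negative positions are interpreted. A standalone sketch of those semantics on plain vectors, with a hypothetical helper name and without the Arrow offset/MutableArrayData bookkeeping (illustrative only, not part of the patch):

    // Element-level semantics of Spark's array_insert, sketched on Vec<Option<i32>>.
    // `spark_array_insert` is a hypothetical name used only for this illustration.
    fn spark_array_insert(
        arr: &[Option<i32>],
        pos: i32,          // 1-based, non-zero; negative counts from the end
        item: Option<i32>,
        legacy_mode: bool, // legacy negative-index behaviour
    ) -> Vec<Option<i32>> {
        assert!(pos != 0, "position must be non-zero");
        let len = arr.len();
        let mut out = Vec::new();
        if pos > 0 {
            let corrected = (pos as usize) - 1; // 0-based insertion point
            if corrected <= len {
                // In range (including appending right after the last element).
                out.extend_from_slice(&arr[..corrected]);
                out.push(item);
                out.extend_from_slice(&arr[corrected..]);
            } else {
                // Beyond the end: keep the original values, pad with nulls, item goes last.
                out.extend_from_slice(arr);
                out.resize(corrected, None);
                out.push(item);
            }
        } else {
            let k = pos.unsigned_abs() as usize; // avoids overflow for i32::MIN
            if k <= len {
                // In range from the end; legacy mode inserts one slot earlier.
                let corrected = len - k + if legacy_mode { 0 } else { 1 };
                out.extend_from_slice(&arr[..corrected]);
                out.push(item);
                out.extend_from_slice(&arr[corrected..]);
            } else {
                // Beyond the start: item first, null padding, then the original values.
                let target_len = k + if legacy_mode { 1 } else { 0 };
                out.push(item);
                out.resize(target_len.saturating_sub(len), None);
                out.extend_from_slice(arr);
            }
        }
        out
    }

    fn main() {
        // Insert past the end: pad with a null, then place the item.
        assert_eq!(
            spark_array_insert(&[Some(1), Some(2), Some(3)], 5, Some(9), false),
            vec![Some(1), Some(2), Some(3), None, Some(9)]
        );
        // Negative in-range index: -1 appends after the last element (non-legacy).
        assert_eq!(
            spark_array_insert(&[Some(1), Some(2), Some(3)], -1, Some(9), false),
            vec![Some(1), Some(2), Some(3), Some(9)]
        );
        // A null item at position 1 stays null instead of leaking stale buffer data.
        assert_eq!(
            spark_array_insert(&[Some(0), None, Some(0)], 1, None, false),
            vec![None, Some(0), None, Some(0)]
        );
    }

In the real kernel the same decisions determine how many child values are copied, how many nulls are appended, and by how much each row's offset advances; rows with a null list or a null position simply repeat the previous offset and are marked invalid in the top-level validity buffer.
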
diff --git a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
index c5060382e..4d06baaa8 100644
--- a/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometArrayExpressionSuite.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.CometTestBase
 import org.apache.spark.sql.catalyst.expressions.{ArrayAppend, ArrayDistinct, ArrayExcept, ArrayInsert, ArrayIntersect, ArrayJoin, ArrayRepeat, ArraysOverlap, ArrayUnion}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.ArrayType
 
 import org.apache.comet.CometSparkSessionExtensions.{isSpark35Plus, isSpark40Plus}
 import org.apache.comet.DataTypeSupport.isComplexType
@@ -210,11 +211,13 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
             .withColumn("arrInsertResult", expr("array_insert(arr, 1, 1)"))
            .withColumn("arrInsertNegativeIndexResult", expr("array_insert(arr, -1, 1)"))
            .withColumn("arrPosGreaterThanSize", expr("array_insert(arr, 8, 1)"))
+            .withColumn("arrPosIsNull", expr("array_insert(arr, cast(null as int), 1)"))
            .withColumn("arrNegPosGreaterThanSize", expr("array_insert(arr, -8, 1)"))
             .withColumn("arrInsertNone", expr("array_insert(arr, 1, null)"))
          checkSparkAnswerAndOperator(df.select("arrInsertResult"))
          checkSparkAnswerAndOperator(df.select("arrInsertNegativeIndexResult"))
           checkSparkAnswerAndOperator(df.select("arrPosGreaterThanSize"))
+          checkSparkAnswerAndOperator(df.select("arrPosIsNull"))
           checkSparkAnswerAndOperator(df.select("arrNegPosGreaterThanSize"))
           checkSparkAnswerAndOperator(df.select("arrInsertNone"))
         })
@@ -802,4 +805,28 @@ class CometArrayExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelp
         fallbackReason)
     }
   }
+
+  test("array_reverse 2") {
+    // This test validates data correctness for array<binary> columns with nullable elements.
+    // See https://github.com/apache/datafusion-comet/issues/2612
+    withTempDir { dir =>
+      val path = new Path(dir.toURI.toString, "test.parquet")
+      val filename = path.toString
+      val random = new Random(42)
+      withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
+        val schemaOptions =
+          SchemaGenOptions(generateArray = true, generateStruct = false, generateMap = false)
+        val dataOptions = DataGenOptions(allowNull = true, generateNegativeZero = false)
+        ParquetGenerator.makeParquetFile(random, spark, filename, 100, schemaOptions, dataOptions)
+      }
+      withTempView("t1") {
+        val table = spark.read.parquet(filename)
+        table.createOrReplaceTempView("t1")
+        for (field <- table.schema.fields.filter(_.dataType.isInstanceOf[ArrayType])) {
+          val sql = s"SELECT ${field.name}, reverse(${field.name}) FROM t1 ORDER BY ${field.name}"
+          checkSparkAnswer(sql)
+        }
+      }
+    }
+  }
 }

