This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 321401ce62 fix: `array_remove`/`array_remove_n`/`array_remove_all` not 
using the same nullability as the input (#19259)
321401ce62 is described below

commit 321401ce627c586703f3b8902831bbe4578c5a9b
Author: Raz Luvaton <[email protected]>
AuthorDate: Wed Dec 10 21:13:34 2025 +0200

    fix: `array_remove`/`array_remove_n`/`array_remove_all` not using the same 
nullability as the input (#19259)
    
    ## Which issue does this PR close?
    
    - Closes #19260
    
    ## Rationale for this change
    
    Removing items from a list should not affect the nullability of the list.
    
    ## What changes are included in this PR?
    
    Reused the input field as the output field, so the output keeps the
    input's nullability.
    
    ## Are these changes tested?
    
    Yes, some of the tests fail on main and pass with this fix (I also added
    tests for nullable input for completeness).
    
    ## Are there any user-facing changes?
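    Yes, a nullability change: the output of `array_remove`, `array_remove_n`
    and `array_remove_all` now has the same nullability as the input list.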
---
 datafusion/functions-nested/src/remove.rs | 478 +++++++++++++++++++++++++++++-
 1 file changed, 467 insertions(+), 11 deletions(-)

diff --git a/datafusion/functions-nested/src/remove.rs 
b/datafusion/functions-nested/src/remove.rs
index 6cb4e28415..41c06cb9c4 100644
--- a/datafusion/functions-nested/src/remove.rs
+++ b/datafusion/functions-nested/src/remove.rs
@@ -24,10 +24,10 @@ use arrow::array::{
     new_empty_array,
 };
 use arrow::buffer::OffsetBuffer;
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::{DataType, FieldRef};
 use datafusion_common::cast::as_int64_array;
 use datafusion_common::utils::ListCoercion;
-use datafusion_common::{Result, exec_err, utils::take_function_args};
+use datafusion_common::{Result, exec_err, internal_err, 
utils::take_function_args};
 use datafusion_expr::{
     ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue, 
Documentation,
     ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -99,8 +99,15 @@ impl ScalarUDFImpl for ArrayRemove {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(arg_types[0].clone())
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    fn return_field_from_args(
+        &self,
+        args: datafusion_expr::ReturnFieldArgs,
+    ) -> Result<FieldRef> {
+        Ok(Arc::clone(&args.arg_fields[0]))
     }
 
     fn invoke_with_args(
@@ -187,8 +194,15 @@ impl ScalarUDFImpl for ArrayRemoveN {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(arg_types[0].clone())
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    fn return_field_from_args(
+        &self,
+        args: datafusion_expr::ReturnFieldArgs,
+    ) -> Result<FieldRef> {
+        Ok(Arc::clone(&args.arg_fields[0]))
     }
 
     fn invoke_with_args(
@@ -264,8 +278,15 @@ impl ScalarUDFImpl for ArrayRemoveAll {
         &self.signature
     }
 
-    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
-        Ok(arg_types[0].clone())
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        internal_err!("return_field_from_args should be used instead")
+    }
+
+    fn return_field_from_args(
+        &self,
+        args: datafusion_expr::ReturnFieldArgs,
+    ) -> Result<FieldRef> {
+        Ok(Arc::clone(&args.arg_fields[0]))
     }
 
     fn invoke_with_args(
@@ -347,7 +368,16 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
     element_array: &ArrayRef,
     arr_n: &[i64],
 ) -> Result<ArrayRef> {
-    let data_type = list_array.value_type();
+    let list_field = match list_array.data_type() {
+        DataType::List(field) | DataType::LargeList(field) => field,
+        _ => {
+            return exec_err!(
+                "Expected List or LargeList data type, got {:?}",
+                list_array.data_type()
+            );
+        }
+    };
+    let data_type = list_field.data_type();
     let mut new_values = vec![];
     // Build up the offsets for the final output array
     let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
@@ -403,16 +433,442 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
     }
 
     let values = if new_values.is_empty() {
-        new_empty_array(&data_type)
+        new_empty_array(data_type)
     } else {
         let new_values = new_values.iter().map(|x| 
x.as_ref()).collect::<Vec<_>>();
         arrow::compute::concat(&new_values)?
     };
 
     Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
-        Arc::new(Field::new_list_field(data_type, true)),
+        Arc::clone(list_field),
         OffsetBuffer::new(offsets.into()),
         values,
         list_array.nulls().cloned(),
     )?))
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
+    use arrow::array::{
+        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
+    };
+    use arrow::datatypes::{DataType, Field, Int32Type};
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
+    use datafusion_expr_common::columnar_value::ColumnarValue;
+    use std::ops::Deref;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_array_remove_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                ];
+                let scalar_args = vec![None, 
Some(&ScalarValue::Int32(Some(1)))];
+
+                let result = ArrayRemove::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                    Arc::new(Field::new("b", DataType::Int64, false)),
+                ];
+                let scalar_args = vec![
+                    None,
+                    Some(&ScalarValue::Int32(Some(1))),
+                    Some(&ScalarValue::Int64(Some(1))),
+                ];
+
+                let result = ArrayRemoveN::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_all_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let result = ArrayRemoveAll::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &[Arc::clone(&input_field)],
+                        scalar_arguments: &[None],
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    fn ensure_field_nullability<O: OffsetSizeTrait>(
+        field_nullable: bool,
+        list: GenericListArray<O>,
+    ) -> GenericListArray<O> {
+        let (field, offsets, values, nulls) = list.into_parts();
+
+        if field.is_nullable() == field_nullable {
+            return GenericListArray::new(field, offsets, values, nulls);
+        }
+        if !field_nullable {
+            assert_eq!(nulls, None);
+        }
+
+        let field = 
Arc::new(field.deref().clone().with_nullable(field_nullable));
+
+        GenericListArray::new(field, offsets, values, nulls)
+    }
+
+    #[test]
+    fn test_array_remove_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    #[test]
+    fn test_array_remove_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63), Some(2)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    fn assert_array_remove(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let udf = ArrayRemove::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+    }
+
+    #[test]
+    fn test_array_remove_n_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+    }
+
+    fn assert_array_remove_n(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+        n: i64,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let count_scalar = ScalarValue::Int64(Some(n));
+
+        let udf = ArrayRemoveN::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+            Arc::new(Field::new("count", DataType::Int64, false)),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove), 
Some(&count_scalar)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                    ColumnarValue::Scalar(count_scalar),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+
+    #[test]
+    fn test_array_remove_all_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 3, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_all(input_list, expected_list, element_to_remove);
+    }
+
+    #[test]
+    fn test_array_remove_all_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_all(input_list, expected_list, element_to_remove);
+    }
+
+    fn assert_array_remove_all(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let udf = ArrayRemoveAll::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to