This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 321401ce62 fix: `array_remove`/`array_remove_n`/`array_remove_all` not
using the same nullability as the input (#19259)
321401ce62 is described below
commit 321401ce627c586703f3b8902831bbe4578c5a9b
Author: Raz Luvaton <[email protected]>
AuthorDate: Wed Dec 10 21:13:34 2025 +0200
fix: `array_remove`/`array_remove_n`/`array_remove_all` not using the same
nullability as the input (#19259)
## Which issue does this PR close?
- Closes #19260
## Rationale for this change
Removing items from a list should not affect the nullability of the list.
## What changes are included in this PR?
Reuse the input list field as the output field, so the output keeps the input's nullability.
## Are these changes tested?
Yes. Some of the new tests fail on main and pass with this fix (tests for nullable input
were also added for completeness).
## Are there any user-facing changes?
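Yes: the output nullability of `array_remove`, `array_remove_n` and `array_remove_all` now matches that of the input list.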
---
datafusion/functions-nested/src/remove.rs | 478 +++++++++++++++++++++++++++++-
1 file changed, 467 insertions(+), 11 deletions(-)
diff --git a/datafusion/functions-nested/src/remove.rs
b/datafusion/functions-nested/src/remove.rs
index 6cb4e28415..41c06cb9c4 100644
--- a/datafusion/functions-nested/src/remove.rs
+++ b/datafusion/functions-nested/src/remove.rs
@@ -24,10 +24,10 @@ use arrow::array::{
new_empty_array,
};
use arrow::buffer::OffsetBuffer;
-use arrow::datatypes::{DataType, Field};
+use arrow::datatypes::{DataType, FieldRef};
use datafusion_common::cast::as_int64_array;
use datafusion_common::utils::ListCoercion;
-use datafusion_common::{Result, exec_err, utils::take_function_args};
+use datafusion_common::{Result, exec_err, internal_err,
utils::take_function_args};
use datafusion_expr::{
ArrayFunctionArgument, ArrayFunctionSignature, ColumnarValue,
Documentation,
ScalarUDFImpl, Signature, TypeSignature, Volatility,
@@ -99,8 +99,15 @@ impl ScalarUDFImpl for ArrayRemove {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(arg_types[0].clone())
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ internal_err!("return_field_from_args should be used instead")
+ }
+
+ fn return_field_from_args(
+ &self,
+ args: datafusion_expr::ReturnFieldArgs,
+ ) -> Result<FieldRef> {
+ Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(
@@ -187,8 +194,15 @@ impl ScalarUDFImpl for ArrayRemoveN {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(arg_types[0].clone())
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ internal_err!("return_field_from_args should be used instead")
+ }
+
+ fn return_field_from_args(
+ &self,
+ args: datafusion_expr::ReturnFieldArgs,
+ ) -> Result<FieldRef> {
+ Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(
@@ -264,8 +278,15 @@ impl ScalarUDFImpl for ArrayRemoveAll {
&self.signature
}
- fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
- Ok(arg_types[0].clone())
+ fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+ internal_err!("return_field_from_args should be used instead")
+ }
+
+ fn return_field_from_args(
+ &self,
+ args: datafusion_expr::ReturnFieldArgs,
+ ) -> Result<FieldRef> {
+ Ok(Arc::clone(&args.arg_fields[0]))
}
fn invoke_with_args(
@@ -347,7 +368,16 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
element_array: &ArrayRef,
arr_n: &[i64],
) -> Result<ArrayRef> {
- let data_type = list_array.value_type();
+ let list_field = match list_array.data_type() {
+ DataType::List(field) | DataType::LargeList(field) => field,
+ _ => {
+ return exec_err!(
+ "Expected List or LargeList data type, got {:?}",
+ list_array.data_type()
+ );
+ }
+ };
+ let data_type = list_field.data_type();
let mut new_values = vec![];
// Build up the offsets for the final output array
let mut offsets = Vec::<OffsetSize>::with_capacity(arr_n.len() + 1);
@@ -403,16 +433,442 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
}
let values = if new_values.is_empty() {
- new_empty_array(&data_type)
+ new_empty_array(data_type)
} else {
let new_values = new_values.iter().map(|x|
x.as_ref()).collect::<Vec<_>>();
arrow::compute::concat(&new_values)?
};
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
- Arc::new(Field::new_list_field(data_type, true)),
+ Arc::clone(list_field),
OffsetBuffer::new(offsets.into()),
values,
list_array.nulls().cloned(),
)?))
}
+
+#[cfg(test)]
+mod tests {
+ use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
+ use arrow::array::{
+ Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
+ };
+ use arrow::datatypes::{DataType, Field, Int32Type};
+ use datafusion_common::ScalarValue;
+ use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
+ use datafusion_expr_common::columnar_value::ColumnarValue;
+ use std::ops::Deref;
+ use std::sync::Arc;
+
+ #[test]
+ fn test_array_remove_nullability() {
+ for nullability in [true, false] {
+ for item_nullability in [true, false] {
+ let input_field = Arc::new(Field::new(
+ "num",
+ DataType::new_list(DataType::Int32, item_nullability),
+ nullability,
+ ));
+ let args_fields = vec![
+ Arc::clone(&input_field),
+ Arc::new(Field::new("a", DataType::Int32, false)),
+ ];
+ let scalar_args = vec![None,
Some(&ScalarValue::Int32(Some(1)))];
+
+ let result = ArrayRemove::new()
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &args_fields,
+ scalar_arguments: &scalar_args,
+ })
+ .unwrap();
+
+ assert_eq!(result, input_field);
+ }
+ }
+ }
+
+ #[test]
+ fn test_array_remove_n_nullability() {
+ for nullability in [true, false] {
+ for item_nullability in [true, false] {
+ let input_field = Arc::new(Field::new(
+ "num",
+ DataType::new_list(DataType::Int32, item_nullability),
+ nullability,
+ ));
+ let args_fields = vec![
+ Arc::clone(&input_field),
+ Arc::new(Field::new("a", DataType::Int32, false)),
+ Arc::new(Field::new("b", DataType::Int64, false)),
+ ];
+ let scalar_args = vec![
+ None,
+ Some(&ScalarValue::Int32(Some(1))),
+ Some(&ScalarValue::Int64(Some(1))),
+ ];
+
+ let result = ArrayRemoveN::new()
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &args_fields,
+ scalar_arguments: &scalar_args,
+ })
+ .unwrap();
+
+ assert_eq!(result, input_field);
+ }
+ }
+ }
+
+ #[test]
+ fn test_array_remove_all_nullability() {
+ for nullability in [true, false] {
+ for item_nullability in [true, false] {
+ let input_field = Arc::new(Field::new(
+ "num",
+ DataType::new_list(DataType::Int32, item_nullability),
+ nullability,
+ ));
+ let result = ArrayRemoveAll::new()
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &[Arc::clone(&input_field)],
+ scalar_arguments: &[None],
+ })
+ .unwrap();
+
+ assert_eq!(result, input_field);
+ }
+ }
+ }
+
+ fn ensure_field_nullability<O: OffsetSizeTrait>(
+ field_nullable: bool,
+ list: GenericListArray<O>,
+ ) -> GenericListArray<O> {
+ let (field, offsets, values, nulls) = list.into_parts();
+
+ if field.is_nullable() == field_nullable {
+ return GenericListArray::new(field, offsets, values, nulls);
+ }
+ if !field_nullable {
+ assert_eq!(nulls, None);
+ }
+
+ let field =
Arc::new(field.deref().clone().with_nullable(field_nullable));
+
+ GenericListArray::new(field, offsets, values, nulls)
+ }
+
+ #[test]
+ fn test_array_remove_non_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 55, 63, 2]).iter().copied().map(Some)),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove(input_list, expected_list, element_to_remove);
+ }
+
+ #[test]
+ fn test_array_remove_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![
+ Some(1),
+ Some(2),
+ Some(2),
+ Some(3),
+ None,
+ Some(1),
+ Some(4),
+ ]),
+ Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
+ Some(vec![Some(42), None, Some(63), Some(2)]),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove(input_list, expected_list, element_to_remove);
+ }
+
+ fn assert_array_remove(
+ input_list: ArrayRef,
+ expected_list: GenericListArray<i32>,
+ element_to_remove: ScalarValue,
+ ) {
+ assert_eq!(input_list.data_type(), expected_list.data_type());
+ assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+ let input_list_len = input_list.len();
+ let input_list_data_type = input_list.data_type().clone();
+
+ let udf = ArrayRemove::new();
+ let args_fields = vec![
+ Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+ Arc::new(Field::new(
+ "el",
+ element_to_remove.data_type(),
+ element_to_remove.is_null(),
+ )),
+ ];
+ let scalar_args = vec![None, Some(&element_to_remove)];
+
+ let return_field = udf
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &args_fields,
+ scalar_arguments: &scalar_args,
+ })
+ .unwrap();
+
+ let result = udf
+ .invoke_with_args(ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(input_list),
+ ColumnarValue::Scalar(element_to_remove),
+ ],
+ arg_fields: args_fields,
+ number_rows: input_list_len,
+ return_field,
+ config_options: Arc::new(Default::default()),
+ })
+ .unwrap();
+
+ assert_eq!(result.data_type(), input_list_data_type);
+ match result {
+ ColumnarValue::Array(array) => {
+ let result_list = array.as_list::<i32>();
+ assert_eq!(result_list, &expected_list);
+ }
+ _ => panic!("Expected ColumnarValue::Array"),
+ }
+ }
+
+ #[test]
+ fn test_array_remove_n_non_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 55, 63]).iter().copied().map(Some)),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+ }
+
+ #[test]
+ fn test_array_remove_n_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![
+ Some(1),
+ Some(2),
+ Some(2),
+ Some(3),
+ None,
+ Some(1),
+ Some(4),
+ ]),
+ Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
+ Some(vec![Some(42), None, Some(63)]),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+ }
+
+ fn assert_array_remove_n(
+ input_list: ArrayRef,
+ expected_list: GenericListArray<i32>,
+ element_to_remove: ScalarValue,
+ n: i64,
+ ) {
+ assert_eq!(input_list.data_type(), expected_list.data_type());
+ assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+ let input_list_len = input_list.len();
+ let input_list_data_type = input_list.data_type().clone();
+
+ let count_scalar = ScalarValue::Int64(Some(n));
+
+ let udf = ArrayRemoveN::new();
+ let args_fields = vec![
+ Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+ Arc::new(Field::new(
+ "el",
+ element_to_remove.data_type(),
+ element_to_remove.is_null(),
+ )),
+ Arc::new(Field::new("count", DataType::Int64, false)),
+ ];
+ let scalar_args = vec![None, Some(&element_to_remove),
Some(&count_scalar)];
+
+ let return_field = udf
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &args_fields,
+ scalar_arguments: &scalar_args,
+ })
+ .unwrap();
+
+ let result = udf
+ .invoke_with_args(ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(input_list),
+ ColumnarValue::Scalar(element_to_remove),
+ ColumnarValue::Scalar(count_scalar),
+ ],
+ arg_fields: args_fields,
+ number_rows: input_list_len,
+ return_field,
+ config_options: Arc::new(Default::default()),
+ })
+ .unwrap();
+
+ assert_eq!(result.data_type(), input_list_data_type);
+ match result {
+ ColumnarValue::Array(array) => {
+ let result_list = array.as_list::<i32>();
+ assert_eq!(result_list, &expected_list);
+ }
+ _ => panic!("Expected ColumnarValue::Array"),
+ }
+ }
+
+ #[test]
+ fn test_array_remove_all_non_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ false,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(([1, 3, 1, 4]).iter().copied().map(Some)),
+ Some(([42, 55, 63]).iter().copied().map(Some)),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove_all(input_list, expected_list, element_to_remove);
+ }
+
+ #[test]
+ fn test_array_remove_all_nullable() {
+ let input_list = Arc::new(ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![
+ Some(1),
+ Some(2),
+ Some(2),
+ Some(3),
+ None,
+ Some(1),
+ Some(4),
+ ]),
+ Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+ ]),
+ ));
+ let expected_list = ensure_field_nullability(
+ true,
+ ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+ Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
+ Some(vec![Some(42), None, Some(63)]),
+ ]),
+ );
+
+ let element_to_remove = ScalarValue::Int32(Some(2));
+
+ assert_array_remove_all(input_list, expected_list, element_to_remove);
+ }
+
+ fn assert_array_remove_all(
+ input_list: ArrayRef,
+ expected_list: GenericListArray<i32>,
+ element_to_remove: ScalarValue,
+ ) {
+ assert_eq!(input_list.data_type(), expected_list.data_type());
+ assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+ let input_list_len = input_list.len();
+ let input_list_data_type = input_list.data_type().clone();
+
+ let udf = ArrayRemoveAll::new();
+ let args_fields = vec![
+ Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+ Arc::new(Field::new(
+ "el",
+ element_to_remove.data_type(),
+ element_to_remove.is_null(),
+ )),
+ ];
+ let scalar_args = vec![None, Some(&element_to_remove)];
+
+ let return_field = udf
+ .return_field_from_args(ReturnFieldArgs {
+ arg_fields: &args_fields,
+ scalar_arguments: &scalar_args,
+ })
+ .unwrap();
+
+ let result = udf
+ .invoke_with_args(ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(input_list),
+ ColumnarValue::Scalar(element_to_remove),
+ ],
+ arg_fields: args_fields,
+ number_rows: input_list_len,
+ return_field,
+ config_options: Arc::new(Default::default()),
+ })
+ .unwrap();
+
+ assert_eq!(result.data_type(), input_list_data_type);
+ match result {
+ ColumnarValue::Array(array) => {
+ let result_list = array.as_list::<i32>();
+ assert_eq!(result_list, &expected_list);
+ }
+ _ => panic!("Expected ColumnarValue::Array"),
+ }
+ }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]