rluvaton commented on code in PR #19259:
URL: https://github.com/apache/datafusion/pull/19259#discussion_r2606843228


##########
datafusion/functions-nested/src/remove.rs:
##########
@@ -403,16 +433,442 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
     }
 
     let values = if new_values.is_empty() {
-        new_empty_array(&data_type)
+        new_empty_array(data_type)
     } else {
         let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
         arrow::compute::concat(&new_values)?
     };
 
     Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
-        Arc::new(Field::new_list_field(data_type, true)),
+        Arc::clone(list_field),
         OffsetBuffer::new(offsets.into()),
         values,
         list_array.nulls().cloned(),
     )?))
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
+    use arrow::array::{
+        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
+    };
+    use arrow::datatypes::{DataType, Field, Int32Type};
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
+    use datafusion_expr_common::columnar_value::ColumnarValue;
+    use std::ops::Deref;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_array_remove_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                ];
+                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];
+
+                let result = ArrayRemove::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                    Arc::new(Field::new("b", DataType::Int64, false)),
+                ];
+                let scalar_args = vec![
+                    None,
+                    Some(&ScalarValue::Int32(Some(1))),
+                    Some(&ScalarValue::Int64(Some(1))),
+                ];
+
+                let result = ArrayRemoveN::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_all_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let result = ArrayRemoveAll::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &[Arc::clone(&input_field)],
+                        scalar_arguments: &[None],
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    fn ensure_field_nullability<O: OffsetSizeTrait>(
+        field_nullable: bool,
+        list: GenericListArray<O>,
+    ) -> GenericListArray<O> {
+        let (field, offsets, values, nulls) = list.into_parts();
+
+        if field.is_nullable() == field_nullable {
+            return GenericListArray::new(field, offsets, values, nulls);
+        }
+        if !field_nullable {
+            assert_eq!(nulls, None);
+        }
+
+        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));
+
+        GenericListArray::new(field, offsets, values, nulls)
+    }
+
+    #[test]
+    fn test_array_remove_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    #[test]
+    fn test_array_remove_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63), Some(2)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    fn assert_array_remove(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let udf = ArrayRemove::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+    }
+
+    #[test]
+    fn test_array_remove_n_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove_n(input_list, expected_list, element_to_remove, 2);
+    }
+
+    fn assert_array_remove_n(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+        n: i64,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let count_scalar = ScalarValue::Int64(Some(n));
+
+        let udf = ArrayRemoveN::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+            Arc::new(Field::new("count", DataType::Int64, false)),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove), Some(&count_scalar)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                    ColumnarValue::Scalar(count_scalar),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+
+    #[test]
+    fn test_array_remove_all_non_nullable() {

Review Comment:
   This test (`test_array_remove_all_non_nullable`) fails on main.



##########
datafusion/functions-nested/src/remove.rs:
##########
@@ -403,16 +433,442 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
     }
 
     let values = if new_values.is_empty() {
-        new_empty_array(&data_type)
+        new_empty_array(data_type)
     } else {
         let new_values = new_values.iter().map(|x| x.as_ref()).collect::<Vec<_>>();
         arrow::compute::concat(&new_values)?
     };
 
     Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
-        Arc::new(Field::new_list_field(data_type, true)),
+        Arc::clone(list_field),
         OffsetBuffer::new(offsets.into()),
         values,
         list_array.nulls().cloned(),
     )?))
 }
+
+#[cfg(test)]
+mod tests {
+    use crate::remove::{ArrayRemove, ArrayRemoveAll, ArrayRemoveN};
+    use arrow::array::{
+        Array, ArrayRef, AsArray, GenericListArray, ListArray, OffsetSizeTrait,
+    };
+    use arrow::datatypes::{DataType, Field, Int32Type};
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::{ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl};
+    use datafusion_expr_common::columnar_value::ColumnarValue;
+    use std::ops::Deref;
+    use std::sync::Arc;
+
+    #[test]
+    fn test_array_remove_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                ];
+                let scalar_args = vec![None, Some(&ScalarValue::Int32(Some(1)))];
+
+                let result = ArrayRemove::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let args_fields = vec![
+                    Arc::clone(&input_field),
+                    Arc::new(Field::new("a", DataType::Int32, false)),
+                    Arc::new(Field::new("b", DataType::Int64, false)),
+                ];
+                let scalar_args = vec![
+                    None,
+                    Some(&ScalarValue::Int32(Some(1))),
+                    Some(&ScalarValue::Int64(Some(1))),
+                ];
+
+                let result = ArrayRemoveN::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &args_fields,
+                        scalar_arguments: &scalar_args,
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    #[test]
+    fn test_array_remove_all_nullability() {
+        for nullability in [true, false] {
+            for item_nullability in [true, false] {
+                let input_field = Arc::new(Field::new(
+                    "num",
+                    DataType::new_list(DataType::Int32, item_nullability),
+                    nullability,
+                ));
+                let result = ArrayRemoveAll::new()
+                    .return_field_from_args(ReturnFieldArgs {
+                        arg_fields: &[Arc::clone(&input_field)],
+                        scalar_arguments: &[None],
+                    })
+                    .unwrap();
+
+                assert_eq!(result, input_field);
+            }
+        }
+    }
+
+    fn ensure_field_nullability<O: OffsetSizeTrait>(
+        field_nullable: bool,
+        list: GenericListArray<O>,
+    ) -> GenericListArray<O> {
+        let (field, offsets, values, nulls) = list.into_parts();
+
+        if field.is_nullable() == field_nullable {
+            return GenericListArray::new(field, offsets, values, nulls);
+        }
+        if !field_nullable {
+            assert_eq!(nulls, None);
+        }
+
+        let field = Arc::new(field.deref().clone().with_nullable(field_nullable));
+
+        GenericListArray::new(field, offsets, values, nulls)
+    }
+
+    #[test]
+    fn test_array_remove_non_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 2, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            false,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(([1, 2, 3, 2, 1, 4]).iter().copied().map(Some)),
+                Some(([42, 55, 63, 2]).iter().copied().map(Some)),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    #[test]
+    fn test_array_remove_nullable() {
+        let input_list = Arc::new(ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![
+                    Some(1),
+                    Some(2),
+                    Some(2),
+                    Some(3),
+                    None,
+                    Some(1),
+                    Some(4),
+                ]),
+                Some(vec![Some(42), Some(2), None, Some(63), Some(2)]),
+            ]),
+        ));
+        let expected_list = ensure_field_nullability(
+            true,
+            ListArray::from_iter_primitive::<Int32Type, _, _>(vec![
+                Some(vec![Some(1), Some(2), Some(3), None, Some(1), Some(4)]),
+                Some(vec![Some(42), None, Some(63), Some(2)]),
+            ]),
+        );
+
+        let element_to_remove = ScalarValue::Int32(Some(2));
+
+        assert_array_remove(input_list, expected_list, element_to_remove);
+    }
+
+    fn assert_array_remove(
+        input_list: ArrayRef,
+        expected_list: GenericListArray<i32>,
+        element_to_remove: ScalarValue,
+    ) {
+        assert_eq!(input_list.data_type(), expected_list.data_type());
+        assert_eq!(expected_list.value_type(), element_to_remove.data_type());
+        let input_list_len = input_list.len();
+        let input_list_data_type = input_list.data_type().clone();
+
+        let udf = ArrayRemove::new();
+        let args_fields = vec![
+            Arc::new(Field::new("num", input_list.data_type().clone(), false)),
+            Arc::new(Field::new(
+                "el",
+                element_to_remove.data_type(),
+                element_to_remove.is_null(),
+            )),
+        ];
+        let scalar_args = vec![None, Some(&element_to_remove)];
+
+        let return_field = udf
+            .return_field_from_args(ReturnFieldArgs {
+                arg_fields: &args_fields,
+                scalar_arguments: &scalar_args,
+            })
+            .unwrap();
+
+        let result = udf
+            .invoke_with_args(ScalarFunctionArgs {
+                args: vec![
+                    ColumnarValue::Array(input_list),
+                    ColumnarValue::Scalar(element_to_remove),
+                ],
+                arg_fields: args_fields,
+                number_rows: input_list_len,
+                return_field,
+                config_options: Arc::new(Default::default()),
+            })
+            .unwrap();
+
+        assert_eq!(result.data_type(), input_list_data_type);
+        match result {
+            ColumnarValue::Array(array) => {
+                let result_list = array.as_list::<i32>();
+                assert_eq!(result_list, &expected_list);
+            }
+            _ => panic!("Expected ColumnarValue::Array"),
+        }
+    }
+
+    #[test]
+    fn test_array_remove_n_non_nullable() {

Review Comment:
   This test (`test_array_remove_n_non_nullable`) fails on main.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to