This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 7c2c2f0297 Replace macro with function for `array_position` and
`array_positions` (#8170)
7c2c2f0297 is described below
commit 7c2c2f029730756d433602a3cc501f695792e58d
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Nov 16 01:52:53 2023 +0800
Replace macro with function for `array_position` and `array_positions`
(#8170)
* basic one
Signed-off-by: jayzhan211 <[email protected]>
* complete n
Signed-off-by: jayzhan211 <[email protected]>
* positions done
Signed-off-by: jayzhan211 <[email protected]>
* compare_element_to_list
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
* resolve rebase
Signed-off-by: jayzhan211 <[email protected]>
* fmt
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
datafusion/physical-expr/src/array_expressions.rs | 309 ++++++++++++----------
datafusion/sqllogictest/test_files/array.slt | 12 +-
2 files changed, 168 insertions(+), 153 deletions(-)
diff --git a/datafusion/physical-expr/src/array_expressions.rs
b/datafusion/physical-expr/src/array_expressions.rs
index 01d495ee7f..515df2a970 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -131,6 +131,78 @@ macro_rules! array {
}};
}
+/// Compares every element of one list row against a single element and returns
+/// the per-element result as a `BooleanArray`.
+///
+/// # Arguments
+///
+/// * `list_array_row` - A reference to a trait object implementing the Arrow
+///   `Array` trait. It holds the values of one row of a list array; each of
+///   its elements is compared against the selected element.
+///
+/// * `element_array` - A reference to a trait object implementing the Arrow
+///   `Array` trait. It holds one candidate element per row.
+///
+/// * `row_index` - The index of the row in `element_array` whose value is
+///   compared against every element of `list_array_row`.
+///
+/// * `eq` - A boolean flag. If `true`, the function computes equality; if
+///   `false`, it computes inequality.
+///
+/// # Returns
+///
+/// Returns a `Result<BooleanArray>` with one entry per element of
+/// `list_array_row`. For list-typed elements, a NULL entry in
+/// `list_array_row` produces a NULL entry in the result.
+///
+/// # Errors
+///
+/// Returns an error if `row_index` is out of bounds for `element_array`, if an
+/// input cannot be downcast to a list array in the list-typed case, or if the
+/// Arrow comparison kernel fails.
+///
+/// # Example
+///
+/// ```text
+/// compare_element_to_list(
+///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
+///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, true, false]
+///
+///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 0, true => [true, false, false]
+///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 1, false => [true, false, false]
+/// )
+/// ```
+fn compare_element_to_list(
+    list_array_row: &dyn Array,
+    element_array: &dyn Array,
+    row_index: usize,
+    eq: bool,
+) -> Result<BooleanArray> {
+    // Guard explicitly so an invalid index surfaces as an error instead of a
+    // panic inside `slice`.
+    if row_index >= element_array.len() {
+        return internal_err!(
+            "row_index {} is out of bounds for element_array of length {}",
+            row_index,
+            element_array.len()
+        );
+    }
+
+    // Zero-copy, length-1 view of the element for this row. This avoids the
+    // lossy `row_index as u32` cast and the temporary index array that a
+    // `take` call would need.
+    let element_array_row = element_array.slice(row_index, 1);
+
+    // Compute, for every element of `list_array_row`, whether it is equal
+    // (or not equal, when `eq` is false) to `element_array_row`
+    let res = match element_array_row.data_type() {
+        // arrow_ord::cmp does not support ListArray, so we need to compare it by loop
+        DataType::List(_) => {
+            // the single list value held by the selected element row
+            let element_array_row_inner = as_list_array(&element_array_row)?.value(0);
+            let list_array_row_inner = as_list_array(list_array_row)?;
+
+            list_array_row_inner
+                .iter()
+                // compare element by element the current row of list_array
+                .map(|row| {
+                    row.map(|row| {
+                        if eq {
+                            row.eq(&element_array_row_inner)
+                        } else {
+                            row.ne(&element_array_row_inner)
+                        }
+                    })
+                })
+                .collect::<BooleanArray>()
+        }
+        _ => {
+            let element_arr = Scalar::new(element_array_row);
+            // use not_distinct / distinct so we can compare NULL
+            if eq {
+                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
+            } else {
+                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
+            }
+        }
+    };
+
+    Ok(res)
+}
+
/// Returns the length of a concrete array dimension
fn compute_array_length(
arr: Option<ArrayRef>,
@@ -1005,114 +1077,68 @@ fn general_list_repeat(
)?))
}
-macro_rules! position {
- ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{
- let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
- $ARRAY
- .iter()
- .zip(element.iter())
- .zip($INDEX.iter())
- .map(|((arr, el), i)| {
- let index = match i {
- Some(i) => {
- if i <= 0 {
- 0
- } else {
- i - 1
- }
- }
- None => return exec_err!("initial position must not be
null"),
- };
-
- match arr {
- Some(arr) => {
- let child_array = downcast_arg!(arr, $ARRAY_TYPE);
-
- match child_array
- .iter()
- .skip(index as usize)
- .position(|x| x == el)
- {
- Some(value) => Ok(Some(value as u64 + index as u64
+ 1u64)),
- None => Ok(None),
- }
- }
- None => Ok(None),
- }
- })
- .collect::<Result<UInt64Array>>()?
- }};
-}
-
/// Array_position SQL function
pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
- let arr = as_list_array(&args[0])?;
- let element = &args[1];
+ let list_array = as_list_array(&args[0])?;
+ let element_array = &args[1];
- let index = if args.len() == 3 {
- as_int64_array(&args[2])?.clone()
+ check_datatypes("array_position", &[list_array.values(), element_array])?;
+
+ let arr_from = if args.len() == 3 {
+ as_int64_array(&args[2])?
+ .values()
+ .to_vec()
+ .iter()
+ .map(|&x| x - 1)
+ .collect::<Vec<_>>()
} else {
- Int64Array::from_value(0, arr.len())
+ vec![0; list_array.len()]
};
- check_datatypes("array_position", &[arr.values(), element])?;
- macro_rules! array_function {
- ($ARRAY_TYPE:ident) => {
- position!(arr, element, index, $ARRAY_TYPE)
- };
+ // if `start_from` index is out of bounds, return error
+ for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
+ if let Some(arr) = arr {
+ if from < 0 || from as usize >= arr.len() {
+ return internal_err!("start_from index out of bounds");
+ }
+ } else {
+ // We will get null if we got null in the array, so we don't need
to check
+ }
}
- let res = call_array_function!(arr.value_type(), true);
- Ok(Arc::new(res))
+ general_position::<i32>(list_array, element_array, arr_from)
}
-macro_rules! positions {
- ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{
- let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
- let mut offsets: Vec<i32> = vec![0];
- let mut values =
- downcast_arg!(new_empty_array(&DataType::UInt64),
UInt64Array).clone();
- for comp in $ARRAY
- .iter()
- .zip(element.iter())
- .map(|(arr, el)| match arr {
- Some(arr) => {
- let child_array = downcast_arg!(arr, $ARRAY_TYPE);
- let res = child_array
- .iter()
- .enumerate()
- .filter(|(_, x)| *x == el)
- .flat_map(|(i, _)| Some((i + 1) as u64))
- .collect::<UInt64Array>();
+fn general_position<OffsetSize: OffsetSizeTrait>(
+ list_array: &GenericListArray<OffsetSize>,
+ element_array: &ArrayRef,
+ arr_from: Vec<i64>, // 0-indexed
+) -> Result<ArrayRef> {
+ let mut data = Vec::with_capacity(list_array.len());
- Ok(res)
- }
- None => Ok(downcast_arg!(
- new_empty_array(&DataType::UInt64),
- UInt64Array
- )
- .clone()),
- })
- .collect::<Result<Vec<UInt64Array>>>()?
- {
- let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
- DataFusionError::Internal(format!("offsets should not be
empty",))
- })?;
- values =
- downcast_arg!(compute::concat(&[&values, &comp,])?.clone(),
UInt64Array)
- .clone();
- offsets.push(last_offset + comp.len() as i32);
- }
+ for (row_index, (list_array_row, &from)) in
+ list_array.iter().zip(arr_from.iter()).enumerate()
+ {
+ let from = from as usize;
- let field = Arc::new(Field::new("item", DataType::UInt64, true));
+ if let Some(list_array_row) = list_array_row {
+ let eq_array =
+ compare_element_to_list(&list_array_row, element_array,
row_index, true)?;
- Arc::new(ListArray::try_new(
- field,
- OffsetBuffer::new(offsets.into()),
- Arc::new(values),
- None,
- )?)
- }};
+ // Collect `true`s in 1-indexed positions
+ let index = eq_array
+ .iter()
+ .skip(from)
+ .position(|e| e == Some(true))
+ .map(|index| (from + index + 1) as u64);
+
+ data.push(index);
+ } else {
+ data.push(None);
+ }
+ }
+
+ Ok(Arc::new(UInt64Array::from(data)))
}
/// Array_positions SQL function
@@ -1121,14 +1147,37 @@ pub fn array_positions(args: &[ArrayRef]) ->
Result<ArrayRef> {
let element = &args[1];
check_datatypes("array_positions", &[arr.values(), element])?;
- macro_rules! array_function {
- ($ARRAY_TYPE:ident) => {
- positions!(arr, element, $ARRAY_TYPE)
- };
+
+ general_positions::<i32>(arr, element)
+}
+
+fn general_positions<OffsetSize: OffsetSizeTrait>(
+ list_array: &GenericListArray<OffsetSize>,
+ element_array: &ArrayRef,
+) -> Result<ArrayRef> {
+ let mut data = Vec::with_capacity(list_array.len());
+
+ for (row_index, list_array_row) in list_array.iter().enumerate() {
+ if let Some(list_array_row) = list_array_row {
+ let eq_array =
+ compare_element_to_list(&list_array_row, element_array,
row_index, true)?;
+
+ // Collect `true`s in 1-indexed positions
+ let indexes = eq_array
+ .iter()
+ .positions(|e| e == Some(true))
+ .map(|index| Some(index as u64 + 1))
+ .collect::<Vec<_>>();
+
+ data.push(Some(indexes));
+ } else {
+ data.push(None);
+ }
}
- let res = call_array_function!(arr.value_type(), true);
- Ok(res)
+ Ok(Arc::new(
+ ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
+ ))
}
/// For each element of `list_array[i]`, removed up to `arr_n[i]` occurrences
@@ -1165,30 +1214,12 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
{
match list_array_row {
Some(list_array_row) => {
- let indices = UInt32Array::from(vec![row_index as u32]);
- let element_array_row =
- arrow::compute::take(element_array, &indices, None)?;
-
- let eq_array = match element_array_row.data_type() {
- // arrow_ord::cmp::distinct does not support ListArray, so
we need to compare it by loop
- DataType::List(_) => {
- // compare each element of the from array
- let element_array_row_inner =
- as_list_array(&element_array_row)?.value(0);
- let list_array_row_inner =
as_list_array(&list_array_row)?;
-
- list_array_row_inner
- .iter()
- // compare element by element the current row of
list_array
- .map(|row| row.map(|row|
row.ne(&element_array_row_inner)))
- .collect::<BooleanArray>()
- }
- _ => {
- let from_arr = Scalar::new(element_array_row);
- // use distinct so Null = Null is false
- arrow_ord::cmp::distinct(&list_array_row, &from_arr)?
- }
- };
+ let eq_array = compare_element_to_list(
+ &list_array_row,
+ element_array,
+ row_index,
+ false,
+ )?;
// We need to keep at most first n elements as `false`, which
represent the elements to remove.
let eq_array = if eq_array.false_count() < *n as usize {
@@ -1313,30 +1344,14 @@ fn general_replace(
match list_array_row {
Some(list_array_row) => {
- let indices = UInt32Array::from(vec![row_index as u32]);
- let from_array_row = arrow::compute::take(from_array,
&indices, None)?;
// Compute all positions in list_row_array (that is itself an
// array) that are equal to `from_array_row`
- let eq_array = match from_array_row.data_type() {
- // arrow_ord::cmp::eq does not support ListArray, so we
need to compare it by loop
- DataType::List(_) => {
- // compare each element of the from array
- let from_array_row_inner =
- as_list_array(&from_array_row)?.value(0);
- let list_array_row_inner =
as_list_array(&list_array_row)?;
-
- list_array_row_inner
- .iter()
- // compare element by element the current row of
list_array
- .map(|row| row.map(|row|
row.eq(&from_array_row_inner)))
- .collect::<BooleanArray>()
- }
- _ => {
- let from_arr = Scalar::new(from_array_row);
- // use not_distinct so NULL = NULL
- arrow_ord::cmp::not_distinct(&list_array_row,
&from_arr)?
- }
- };
+ let eq_array = compare_element_to_list(
+ &list_array_row,
+ &from_array,
+ row_index,
+ true,
+ )?;
// Use MutableArrayData to build the replaced array
let original_data = list_array_row.to_data();
diff --git a/datafusion/sqllogictest/test_files/array.slt
b/datafusion/sqllogictest/test_files/array.slt
index 92013f37d3..67cabb0988 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -702,7 +702,7 @@ select array_element(make_array(1, 2, 3, 4, 5), 0),
array_element(make_array('h'
NULL NULL
# array_element scalar function #4 (with NULL)
-query error
+query error
select array_element(make_array(1, 2, 3, 4, 5), NULL),
array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);
# array_element scalar function #5 (with negative index)
@@ -871,11 +871,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4),
array_slice(make_array('h',
[1, 2, 3, 4] [h, e, l]
# array_slice scalar function #8 (with NULL and positive number)
-query error
+query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4),
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3);
# array_slice scalar function #9 (with positive number and NULL)
-query error
+query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL),
array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
# array_slice scalar function #10 (with zero-zero)
@@ -885,7 +885,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0),
array_slice(make_array('h',
[] []
# array_slice scalar function #11 (with NULL-NULL)
-query error
+query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL),
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL);
# array_slice scalar function #12 (with zero and negative number)
@@ -895,11 +895,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4),
array_slice(make_array('h'
[1] [h, e]
# array_slice scalar function #13 (with negative number and NULL)
-query error
+query error
select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL),
array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
# array_slice scalar function #14 (with NULL and negative number)
-query error
+query error
select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4),
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3);
# array_slice scalar function #15 (with negative indexes)