This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 7c2c2f0297 Replace macro with function for `array_position` and 
`array_positions` (#8170)
7c2c2f0297 is described below

commit 7c2c2f029730756d433602a3cc501f695792e58d
Author: Jay Zhan <[email protected]>
AuthorDate: Thu Nov 16 01:52:53 2023 +0800

    Replace macro with function for `array_position` and `array_positions` 
(#8170)
    
    * basic one
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * complete n
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * positions done
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * compare_element_to_list
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * resolve rebase
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * fmt
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 datafusion/physical-expr/src/array_expressions.rs | 309 ++++++++++++----------
 datafusion/sqllogictest/test_files/array.slt      |  12 +-
 2 files changed, 168 insertions(+), 153 deletions(-)

diff --git a/datafusion/physical-expr/src/array_expressions.rs 
b/datafusion/physical-expr/src/array_expressions.rs
index 01d495ee7f..515df2a970 100644
--- a/datafusion/physical-expr/src/array_expressions.rs
+++ b/datafusion/physical-expr/src/array_expressions.rs
@@ -131,6 +131,78 @@ macro_rules! array {
     }};
 }
 
+/// Computes a BooleanArray indicating equality or inequality between elements 
in a list array and a specified element array.
+///
+/// # Arguments
+///
+/// * `list_array_row` - A reference to a trait object implementing the Arrow 
`Array` trait. It represents the list array for which the equality or 
inequality will be compared.
+///
+/// * `element_array` - A reference to a trait object implementing the Arrow 
`Array` trait. It represents the array with which each element in the 
`list_array_row` will be compared.
+///
+/// * `row_index` - The index of the row in the `element_array` and 
`list_array` to use for the comparison.
+///
+/// * `eq` - A boolean flag. If `true`, the function computes equality; if 
`false`, it computes inequality.
+///
+/// # Returns
+///
+/// Returns a `Result<BooleanArray>` representing the comparison results. The 
result may contain an error if there are issues with the computation.
+///
+/// # Example
+///
+/// ```text
+/// compare_element_to_list(
+///     [1, 2, 3], [1, 2, 3], 0, true => [true, false, false]
+///     [1, 2, 3, 3, 2, 1], [1, 2, 3], 1, true => [false, true, false, false, 
true, false]
+///
+///     [[1, 2, 3], [2, 3, 4], [3, 4, 5]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 
0, true => [true, false, false]
+///     [[1, 2, 3], [2, 3, 4], [2, 3, 4]], [[1, 2, 3], [2, 3, 4], [3, 4, 5]], 
1, false => [true, false, false]
+/// )
+/// ```
+fn compare_element_to_list(
+    list_array_row: &dyn Array,
+    element_array: &dyn Array,
+    row_index: usize,
+    eq: bool,
+) -> Result<BooleanArray> {
+    let indices = UInt32Array::from(vec![row_index as u32]);
+    let element_array_row = arrow::compute::take(element_array, &indices, 
None)?;
+    // Compute all positions in `list_array_row` (that is itself an
+    // array) that are equal to `element_array_row`
+    let res = match element_array_row.data_type() {
+        // arrow_ord::cmp::eq does not support ListArray, so we need to 
compare it by loop
+        DataType::List(_) => {
+            // compare each entry of the list array row with the element
+            let element_array_row_inner = 
as_list_array(&element_array_row)?.value(0);
+            let list_array_row_inner = as_list_array(list_array_row)?;
+
+            list_array_row_inner
+                .iter()
+                // compare element by element the current row of list_array
+                .map(|row| {
+                    row.map(|row| {
+                        if eq {
+                            row.eq(&element_array_row_inner)
+                        } else {
+                            row.ne(&element_array_row_inner)
+                        }
+                    })
+                })
+                .collect::<BooleanArray>()
+        }
+        _ => {
+            let element_arr = Scalar::new(element_array_row);
+            // use not_distinct so we can compare NULL
+            if eq {
+                arrow_ord::cmp::not_distinct(&list_array_row, &element_arr)?
+            } else {
+                arrow_ord::cmp::distinct(&list_array_row, &element_arr)?
+            }
+        }
+    };
+
+    Ok(res)
+}
+
 /// Returns the length of a concrete array dimension
 fn compute_array_length(
     arr: Option<ArrayRef>,
@@ -1005,114 +1077,68 @@ fn general_list_repeat(
     )?))
 }
 
-macro_rules! position {
-    ($ARRAY:expr, $ELEMENT:expr, $INDEX:expr, $ARRAY_TYPE:ident) => {{
-        let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
-        $ARRAY
-            .iter()
-            .zip(element.iter())
-            .zip($INDEX.iter())
-            .map(|((arr, el), i)| {
-                let index = match i {
-                    Some(i) => {
-                        if i <= 0 {
-                            0
-                        } else {
-                            i - 1
-                        }
-                    }
-                    None => return exec_err!("initial position must not be 
null"),
-                };
-
-                match arr {
-                    Some(arr) => {
-                        let child_array = downcast_arg!(arr, $ARRAY_TYPE);
-
-                        match child_array
-                            .iter()
-                            .skip(index as usize)
-                            .position(|x| x == el)
-                        {
-                            Some(value) => Ok(Some(value as u64 + index as u64 
+ 1u64)),
-                            None => Ok(None),
-                        }
-                    }
-                    None => Ok(None),
-                }
-            })
-            .collect::<Result<UInt64Array>>()?
-    }};
-}
-
 /// Array_position SQL function
 pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let arr = as_list_array(&args[0])?;
-    let element = &args[1];
+    let list_array = as_list_array(&args[0])?;
+    let element_array = &args[1];
 
-    let index = if args.len() == 3 {
-        as_int64_array(&args[2])?.clone()
+    check_datatypes("array_position", &[list_array.values(), element_array])?;
+
+    let arr_from = if args.len() == 3 {
+        as_int64_array(&args[2])?
+            .values()
+            .to_vec()
+            .iter()
+            .map(|&x| x - 1)
+            .collect::<Vec<_>>()
     } else {
-        Int64Array::from_value(0, arr.len())
+        vec![0; list_array.len()]
     };
 
-    check_datatypes("array_position", &[arr.values(), element])?;
-    macro_rules! array_function {
-        ($ARRAY_TYPE:ident) => {
-            position!(arr, element, index, $ARRAY_TYPE)
-        };
+    // if `start_from` index is out of bounds, return error
+    for (arr, &from) in list_array.iter().zip(arr_from.iter()) {
+        if let Some(arr) = arr {
+            if from < 0 || from as usize >= arr.len() {
+                return internal_err!("start_from index out of bounds");
+            }
+        } else {
+            // We will get null if we got null in the array, so we don't need 
to check
+        }
     }
-    let res = call_array_function!(arr.value_type(), true);
 
-    Ok(Arc::new(res))
+    general_position::<i32>(list_array, element_array, arr_from)
 }
 
-macro_rules! positions {
-    ($ARRAY:expr, $ELEMENT:expr, $ARRAY_TYPE:ident) => {{
-        let element = downcast_arg!($ELEMENT, $ARRAY_TYPE);
-        let mut offsets: Vec<i32> = vec![0];
-        let mut values =
-            downcast_arg!(new_empty_array(&DataType::UInt64), 
UInt64Array).clone();
-        for comp in $ARRAY
-            .iter()
-            .zip(element.iter())
-            .map(|(arr, el)| match arr {
-                Some(arr) => {
-                    let child_array = downcast_arg!(arr, $ARRAY_TYPE);
-                    let res = child_array
-                        .iter()
-                        .enumerate()
-                        .filter(|(_, x)| *x == el)
-                        .flat_map(|(i, _)| Some((i + 1) as u64))
-                        .collect::<UInt64Array>();
+fn general_position<OffsetSize: OffsetSizeTrait>(
+    list_array: &GenericListArray<OffsetSize>,
+    element_array: &ArrayRef,
+    arr_from: Vec<i64>, // 0-indexed
+) -> Result<ArrayRef> {
+    let mut data = Vec::with_capacity(list_array.len());
 
-                    Ok(res)
-                }
-                None => Ok(downcast_arg!(
-                    new_empty_array(&DataType::UInt64),
-                    UInt64Array
-                )
-                .clone()),
-            })
-            .collect::<Result<Vec<UInt64Array>>>()?
-        {
-            let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
-                DataFusionError::Internal(format!("offsets should not be 
empty",))
-            })?;
-            values =
-                downcast_arg!(compute::concat(&[&values, &comp,])?.clone(), 
UInt64Array)
-                    .clone();
-            offsets.push(last_offset + comp.len() as i32);
-        }
+    for (row_index, (list_array_row, &from)) in
+        list_array.iter().zip(arr_from.iter()).enumerate()
+    {
+        let from = from as usize;
 
-        let field = Arc::new(Field::new("item", DataType::UInt64, true));
+        if let Some(list_array_row) = list_array_row {
+            let eq_array =
+                compare_element_to_list(&list_array_row, element_array, 
row_index, true)?;
 
-        Arc::new(ListArray::try_new(
-            field,
-            OffsetBuffer::new(offsets.into()),
-            Arc::new(values),
-            None,
-        )?)
-    }};
+            // Find the first `true` at or after `from`, as a 1-indexed position
+            let index = eq_array
+                .iter()
+                .skip(from)
+                .position(|e| e == Some(true))
+                .map(|index| (from + index + 1) as u64);
+
+            data.push(index);
+        } else {
+            data.push(None);
+        }
+    }
+
+    Ok(Arc::new(UInt64Array::from(data)))
 }
 
 /// Array_positions SQL function
@@ -1121,14 +1147,37 @@ pub fn array_positions(args: &[ArrayRef]) -> 
Result<ArrayRef> {
     let element = &args[1];
 
     check_datatypes("array_positions", &[arr.values(), element])?;
-    macro_rules! array_function {
-        ($ARRAY_TYPE:ident) => {
-            positions!(arr, element, $ARRAY_TYPE)
-        };
+
+    general_positions::<i32>(arr, element)
+}
+
+fn general_positions<OffsetSize: OffsetSizeTrait>(
+    list_array: &GenericListArray<OffsetSize>,
+    element_array: &ArrayRef,
+) -> Result<ArrayRef> {
+    let mut data = Vec::with_capacity(list_array.len());
+
+    for (row_index, list_array_row) in list_array.iter().enumerate() {
+        if let Some(list_array_row) = list_array_row {
+            let eq_array =
+                compare_element_to_list(&list_array_row, element_array, 
row_index, true)?;
+
+            // Collect `true`s in 1-indexed positions
+            let indexes = eq_array
+                .iter()
+                .positions(|e| e == Some(true))
+                .map(|index| Some(index as u64 + 1))
+                .collect::<Vec<_>>();
+
+            data.push(Some(indexes));
+        } else {
+            data.push(None);
+        }
     }
-    let res = call_array_function!(arr.value_type(), true);
 
-    Ok(res)
+    Ok(Arc::new(
+        ListArray::from_iter_primitive::<UInt64Type, _, _>(data),
+    ))
 }
 
 /// For each element of `list_array[i]`, removed up to `arr_n[i]`  occurences
@@ -1165,30 +1214,12 @@ fn general_remove<OffsetSize: OffsetSizeTrait>(
     {
         match list_array_row {
             Some(list_array_row) => {
-                let indices = UInt32Array::from(vec![row_index as u32]);
-                let element_array_row =
-                    arrow::compute::take(element_array, &indices, None)?;
-
-                let eq_array = match element_array_row.data_type() {
-                    // arrow_ord::cmp::distinct does not support ListArray, so 
we need to compare it by loop
-                    DataType::List(_) => {
-                        // compare each element of the from array
-                        let element_array_row_inner =
-                            as_list_array(&element_array_row)?.value(0);
-                        let list_array_row_inner = 
as_list_array(&list_array_row)?;
-
-                        list_array_row_inner
-                            .iter()
-                            // compare element by element the current row of 
list_array
-                            .map(|row| row.map(|row| 
row.ne(&element_array_row_inner)))
-                            .collect::<BooleanArray>()
-                    }
-                    _ => {
-                        let from_arr = Scalar::new(element_array_row);
-                        // use distinct so Null = Null is false
-                        arrow_ord::cmp::distinct(&list_array_row, &from_arr)?
-                    }
-                };
+                let eq_array = compare_element_to_list(
+                    &list_array_row,
+                    element_array,
+                    row_index,
+                    false,
+                )?;
 
                 // We need to keep at most first n elements as `false`, which 
represent the elements to remove.
                 let eq_array = if eq_array.false_count() < *n as usize {
@@ -1313,30 +1344,14 @@ fn general_replace(
 
         match list_array_row {
             Some(list_array_row) => {
-                let indices = UInt32Array::from(vec![row_index as u32]);
-                let from_array_row = arrow::compute::take(from_array, 
&indices, None)?;
                 // Compute all positions in list_row_array (that is itself an
                 // array) that are equal to `from_array_row`
-                let eq_array = match from_array_row.data_type() {
-                    // arrow_ord::cmp::eq does not support ListArray, so we 
need to compare it by loop
-                    DataType::List(_) => {
-                        // compare each element of the from array
-                        let from_array_row_inner =
-                            as_list_array(&from_array_row)?.value(0);
-                        let list_array_row_inner = 
as_list_array(&list_array_row)?;
-
-                        list_array_row_inner
-                            .iter()
-                            // compare element by element the current row of 
list_array
-                            .map(|row| row.map(|row| 
row.eq(&from_array_row_inner)))
-                            .collect::<BooleanArray>()
-                    }
-                    _ => {
-                        let from_arr = Scalar::new(from_array_row);
-                        // use not_distinct so NULL = NULL
-                        arrow_ord::cmp::not_distinct(&list_array_row, 
&from_arr)?
-                    }
-                };
+                let eq_array = compare_element_to_list(
+                    &list_array_row,
+                    &from_array,
+                    row_index,
+                    true,
+                )?;
 
                 // Use MutableArrayData to build the replaced array
                 let original_data = list_array_row.to_data();
diff --git a/datafusion/sqllogictest/test_files/array.slt 
b/datafusion/sqllogictest/test_files/array.slt
index 92013f37d3..67cabb0988 100644
--- a/datafusion/sqllogictest/test_files/array.slt
+++ b/datafusion/sqllogictest/test_files/array.slt
@@ -702,7 +702,7 @@ select array_element(make_array(1, 2, 3, 4, 5), 0), 
array_element(make_array('h'
 NULL NULL
 
 # array_element scalar function #4 (with NULL)
-query error 
+query error
 select array_element(make_array(1, 2, 3, 4, 5), NULL), 
array_element(make_array('h', 'e', 'l', 'l', 'o'), NULL);
 
 # array_element scalar function #5 (with negative index)
@@ -871,11 +871,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 4), 
array_slice(make_array('h',
 [1, 2, 3, 4] [h, e, l]
 
 # array_slice scalar function #8 (with NULL and positive number)
-query error 
+query error
 select array_slice(make_array(1, 2, 3, 4, 5), NULL, 4), 
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, 3);
 
 # array_slice scalar function #9 (with positive number and NULL)
-query error 
+query error
 select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), 
array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
 
 # array_slice scalar function #10 (with zero-zero)
@@ -885,7 +885,7 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, 0), 
array_slice(make_array('h',
 [] []
 
 # array_slice scalar function #11 (with NULL-NULL)
-query error 
+query error
 select array_slice(make_array(1, 2, 3, 4, 5), NULL), 
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL);
 
 # array_slice scalar function #12 (with zero and negative number)
@@ -895,11 +895,11 @@ select array_slice(make_array(1, 2, 3, 4, 5), 0, -4), 
array_slice(make_array('h'
 [1] [h, e]
 
 # array_slice scalar function #13 (with negative number and NULL)
-query error 
+query error
 select array_slice(make_array(1, 2, 3, 4, 5), 2, NULL), 
array_slice(make_array('h', 'e', 'l', 'l', 'o'), 3, NULL);
 
 # array_slice scalar function #14 (with NULL and negative number)
-query error 
+query error
 select array_slice(make_array(1, 2, 3, 4, 5), NULL, -4), 
array_slice(make_array('h', 'e', 'l', 'l', 'o'), NULL, -3);
 
 # array_slice scalar function #15 (with negative indexes)

Reply via email to