Jefffrey commented on code in PR #22387:
URL: https://github.com/apache/datafusion/pull/22387#discussion_r3292149592


##########
datafusion/functions-nested/src/replace.rs:
##########
@@ -200,7 +219,41 @@ impl ScalarUDFImpl for ArrayReplaceN {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(array_replace_n_inner)(&args.args)
+        let [list_arg, from_arg, to_arg, max_arg] =
+            take_function_args(self.name(), &args.args)?;
+        let num_rows = args.number_rows;
+        let list_array = list_arg.to_array(num_rows)?;
+        match (from_arg, to_arg, max_arg) {
+            (
+                ColumnarValue::Scalar(scalar_from),
+                ColumnarValue::Scalar(scalar_to),
+                ColumnarValue::Scalar(scalar_max),
+            ) => {
+                let a = scalar_max.to_array_of_size(1)?;
+                let n = as_int64_array(&a)?.value(0);

Review Comment:
   Could destructure like so:
   
   ```rust
   let ScalarValue::Int64(Some(n)) = scalar_max else {
       return exec_err!("array_replace_n: expected a non-null Int64 value for max");
   };
   ```
   
   - No need to roundtrip through an array
   
   But it does raise the question of how to handle nulls here: as written, `value(0)` ignores any null mask present, so a NULL `max` would be read from the values buffer instead of being treated as NULL.



##########
datafusion/functions-nested/src/replace.rs:
##########
@@ -200,7 +219,41 @@ impl ScalarUDFImpl for ArrayReplaceN {
     }
 
     fn invoke_with_args(&self, args: ScalarFunctionArgs) -> 
Result<ColumnarValue> {
-        make_scalar_function(array_replace_n_inner)(&args.args)
+        let [list_arg, from_arg, to_arg, max_arg] =
+            take_function_args(self.name(), &args.args)?;
+        let num_rows = args.number_rows;
+        let list_array = list_arg.to_array(num_rows)?;
+        match (from_arg, to_arg, max_arg) {
+            (
+                ColumnarValue::Scalar(scalar_from),
+                ColumnarValue::Scalar(scalar_to),
+                ColumnarValue::Scalar(scalar_max),
+            ) => {
+                let a = scalar_max.to_array_of_size(1)?;
+                let n = as_int64_array(&a)?.value(0);
+                let result = array_replace_with_scalar_args(
+                    &list_array,
+                    scalar_from,
+                    scalar_to,
+                    n,
+                )?;
+                Ok(ColumnarValue::Array(result))
+            }
+            (from_arg, to_arg, max_arg) => {
+                let from_array = from_arg.to_array(num_rows)?;
+                let to_array = to_arg.to_array(num_rows)?;
+                let arr_n = match max_arg {

Review Comment:
   Here we could just use 
[`ColumnarValue::to_array`](https://docs.rs/datafusion/latest/datafusion/logical_expr/enum.ColumnarValue.html#method.to_array)
   
   ```rust
   let arr_n = max_arg.to_array(1)?;
   ```
   
   Though again, we may need to consider nulls here



##########
datafusion/functions-nested/src/replace.rs:
##########
@@ -412,63 +490,155 @@ fn general_replace<O: OffsetSizeTrait>(
     )?))
 }
 
-fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let [array, from, to] = take_function_args("array_replace", args)?;
+/// Replaces up to `max_replacements` occurrences of `needle` with the single
+/// element in `to_array` for each row in `list_array`.
+///
+/// This is a specialized fast path for the all-scalar case that uses a single
+/// bulk `not_distinct` comparison over only the visible values range, then
+/// iterates match positions via `set_indices` instead of scanning every bit.
+fn general_replace_with_scalar<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    needle: &Scalar<ArrayRef>,
+    to_array: &ArrayRef,
+    max_replacements: i64,
+) -> Result<ArrayRef> {
+    let first_offset = list_array.offsets()[0].to_usize().unwrap();
+    let last_offset = 
list_array.offsets()[list_array.len()].to_usize().unwrap();
+    let visible_values = list_array
+        .values()
+        .slice(first_offset, last_offset - first_offset);
 
-    // replace at most one occurrence for each element
-    let arr_n = vec![1; array.len()];
-    match array.data_type() {
-        DataType::List(_) => {
-            let list_array = array.as_list::<i32>();
-            general_replace::<i32>(list_array, from, to, &arr_n)
+    let original_data = visible_values.to_data();
+    let to_data = to_array.to_data();
+    let capacity = Capacities::Array(original_data.len());
+
+    let mut mutable = MutableArrayData::with_capacities(
+        vec![&original_data, &to_data],
+        false,
+        capacity,
+    );
+
+    let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());
+
+    // Single bulk comparison over the visible values only.
+    let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?;
+    let match_bits = match_bitmap.values();
+
+    for (row_index, offset_window) in 
list_array.offsets().windows(2).enumerate() {
+        // Offsets relative to visible_values (subtract first_offset).
+        let start = offset_window[0].to_usize().unwrap() - first_offset;
+        let end = offset_window[1].to_usize().unwrap() - first_offset;
+        let row_len = end - start;
+
+        if list_array.is_null(row_index) {
+            offsets.push_length(0);
+            continue;
         }
-        DataType::LargeList(_) => {
-            let list_array = array.as_list::<i64>();
-            general_replace::<i64>(list_array, from, to, &arr_n)
+
+        if max_replacements <= 0 {

Review Comment:
   If `max_replacements` is `<= 0`, then I think we can take a fast path at the start and just return the input array?



##########
datafusion/functions-nested/src/replace.rs:
##########
@@ -412,63 +490,155 @@ fn general_replace<O: OffsetSizeTrait>(
     )?))
 }
 
-fn array_replace_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let [array, from, to] = take_function_args("array_replace", args)?;
+/// Replaces up to `max_replacements` occurrences of `needle` with the single
+/// element in `to_array` for each row in `list_array`.
+///
+/// This is a specialized fast path for the all-scalar case that uses a single
+/// bulk `not_distinct` comparison over only the visible values range, then
+/// iterates match positions via `set_indices` instead of scanning every bit.
+fn general_replace_with_scalar<O: OffsetSizeTrait>(
+    list_array: &GenericListArray<O>,
+    needle: &Scalar<ArrayRef>,
+    to_array: &ArrayRef,
+    max_replacements: i64,
+) -> Result<ArrayRef> {
+    let first_offset = list_array.offsets()[0].to_usize().unwrap();
+    let last_offset = 
list_array.offsets()[list_array.len()].to_usize().unwrap();
+    let visible_values = list_array
+        .values()
+        .slice(first_offset, last_offset - first_offset);
 
-    // replace at most one occurrence for each element
-    let arr_n = vec![1; array.len()];
-    match array.data_type() {
-        DataType::List(_) => {
-            let list_array = array.as_list::<i32>();
-            general_replace::<i32>(list_array, from, to, &arr_n)
+    let original_data = visible_values.to_data();
+    let to_data = to_array.to_data();
+    let capacity = Capacities::Array(original_data.len());
+
+    let mut mutable = MutableArrayData::with_capacities(
+        vec![&original_data, &to_data],
+        false,
+        capacity,
+    );
+
+    let mut offsets = OffsetBufferBuilder::<O>::new(list_array.len());
+
+    // Single bulk comparison over the visible values only.
+    let match_bitmap = arrow_ord::cmp::not_distinct(&visible_values, needle)?;
+    let match_bits = match_bitmap.values();
+
+    for (row_index, offset_window) in 
list_array.offsets().windows(2).enumerate() {
+        // Offsets relative to visible_values (subtract first_offset).
+        let start = offset_window[0].to_usize().unwrap() - first_offset;
+        let end = offset_window[1].to_usize().unwrap() - first_offset;
+        let row_len = end - start;
+
+        if list_array.is_null(row_index) {
+            offsets.push_length(0);
+            continue;
         }
-        DataType::LargeList(_) => {
-            let list_array = array.as_list::<i64>();
-            general_replace::<i64>(list_array, from, to, &arr_n)
+
+        if max_replacements <= 0 {
+            mutable.extend(0, start, end);
+            offsets.push_length(row_len);
+            continue;
         }
-        DataType::Null => Ok(new_null_array(array.data_type(), 1)),
-        array_type => exec_err!("array_replace does not support type 
'{array_type}'."),
+
+        // Slice the match bits to this row and iterate only over true 
positions.
+        let row_bits = match_bits.slice(start, row_len);
+        let mut match_positions = row_bits
+            .set_indices()
+            .take(max_replacements as usize)
+            .peekable();
+        if match_positions.peek().is_none() {
+            mutable.extend(0, start, end);
+            offsets.push_length(row_len);
+            continue;
+        }
+
+        // Iterate only over the positions that match using set_indices,
+        // which is more efficient than scanning every bit because the number
+        // of matches is typically much smaller than the total array size.
+        let mut prev_end = 0usize;
+        for match_pos in match_positions {
+            // Retain elements before this match.
+            if match_pos > prev_end {
+                mutable.extend(0, start + prev_end, start + match_pos);
+            }
+            // Emit the replacement element.
+            mutable.extend(1, 0, 1);
+            prev_end = match_pos + 1;
+        }
+
+        // Copy remaining elements after the last replacement.
+        if prev_end < row_len {
+            mutable.extend(0, start + prev_end, end);
+        }
+
+        offsets.push_length(row_len);
     }
+
+    let data = mutable.freeze();
+
+    Ok(Arc::new(GenericListArray::<O>::try_new(
+        Arc::new(Field::new_list_field(list_array.value_type(), true)),
+        offsets.finish(),
+        arrow::array::make_array(data),
+        list_array.nulls().cloned(),
+    )?))
 }
 
-fn array_replace_n_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
-    let [array, from, to, max] = take_function_args("array_replace_n", args)?;
+/// Fast path for `array_replace` when all arguments are scalars.
+///
+/// Uses a single bulk `not_distinct` comparison instead of per-row 
comparisons.
+fn array_replace_with_scalar_args(
+    list_array: &ArrayRef,
+    scalar_from: &ScalarValue,
+    scalar_to: &ScalarValue,
+    max_replacements: i64,
+) -> Result<ArrayRef> {
+    // `not_distinct` doesn't support nested types, fall back to the generic 
array path.
+    if scalar_from.data_type().is_nested() {
+        let num_rows = list_array.len();
+        let from_array = scalar_from.to_array_of_size(num_rows)?;
+        let to_array = scalar_to.to_array_of_size(num_rows)?;
+        return array_replace_internal(
+            list_array,
+            &from_array,
+            &to_array,
+            &vec![max_replacements; num_rows],
+        );
+    }
 
-    // replace the specified number of occurrences
-    let arr_n = as_int64_array(max)?.values().to_vec();
-    match array.data_type() {
+    let needle = Scalar::new(scalar_from.to_array_of_size(1)?);
+    let to_array = scalar_to.to_array_of_size(1)?;

Review Comment:
   Personally, I'd do the conversion of `scalar_to` into `to_array` inside `general_replace_with_scalar`, to make it clear that the only reason we do it is to convert the value to an `ArrayData`.



##########
datafusion/functions-nested/src/replace.rs:
##########
@@ -343,7 +417,11 @@ fn general_replace<O: OffsetSizeTrait>(
 
         let original_idx = O::usize_as(0);
         let replace_idx = O::usize_as(1);
-        let n = arr_n[row_index];
+        let n = if arr_n.len() == 1 {

Review Comment:
   If we're going to keep this logic, consider just passing `&[1]` and `&[i64::MAX]` for `array_replace` and `array_replace_all` respectively, instead of creating `vec![1; num_rows]` and `vec![i64::MAX; num_rows]`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to