This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 55707dc0e8 Specialize ASCII case for substr() (#12444)
55707dc0e8 is described below

commit 55707dc0e83e00078423f35da4232111888319eb
Author: Yongting You <[email protected]>
AuthorDate: Wed Sep 18 00:40:34 2024 +0800

    Specialize ASCII case for substr() (#12444)
    
    * Specialize ASCII case for substr()
    
    * cleanup + don't validate ASCII for short prefix
---
 datafusion/functions/src/unicode/substr.rs | 146 ++++++++++++++++++++++++-----
 1 file changed, 122 insertions(+), 24 deletions(-)

diff --git a/datafusion/functions/src/unicode/substr.rs b/datafusion/functions/src/unicode/substr.rs
index 40d3a4d13e..5e311f1e18 100644
--- a/datafusion/functions/src/unicode/substr.rs
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -16,18 +16,18 @@
 // under the License.
 
 use std::any::Any;
-use std::cmp::max;
 use std::sync::Arc;
 
+use crate::string::common::StringArrayType;
 use crate::utils::{make_scalar_function, utf8_to_str_type};
 use arrow::array::{
-    make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
-    GenericStringArray, OffsetSizeTrait, StringViewArray,
+    make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView, 
GenericStringArray,
+    Int64Array, OffsetSizeTrait, StringViewArray,
 };
 use arrow::datatypes::DataType;
 use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
 use datafusion_common::cast::as_int64_array;
-use datafusion_common::{exec_datafusion_err, exec_err, Result};
+use datafusion_common::{exec_err, Result};
 use datafusion_expr::TypeSignature::Exact;
 use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
 
@@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
 }
 
 // Convert the given `start` and `count` to valid byte indices within `input` 
string
+//
 // Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start, 
count)`
 // `start` is 1-based, if `count` is not provided count to the end of the 
string
 // Input indices are character-based, and return values are byte indices
 // The input bounds can be outside string bounds, this function will return
 // the intersection between input bounds and valid string bounds
+// `is_input_ascii_only` is used to optimize this function if `input` is ASCII-only
 //
 // * Example
 // 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
 // `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
 // `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
 // `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
-fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize, 
usize) {
-    let start = start - 1;
+fn get_true_start_end(
+    input: &str,
+    start: i64,
+    count: Option<u64>,
+    is_input_ascii_only: bool,
+) -> (usize, usize) {
+    let start = start.checked_sub(1).unwrap_or(start);
+
     let end = match count {
         Some(count) => start + count as i64,
         None => input.len() as i64,
@@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count: 
Option<u64>) -> (usize, us
     let end = end.clamp(0, input.len() as i64) as usize;
     let count = end - start;
 
+    // If input is ASCII-only, byte-based indices are equal to char-based indices
+    if is_input_ascii_only {
+        return (start, end);
+    }
+
+    // Otherwise, calculate byte indices from char indices.
+    // Note this decoding is relatively expensive for this simple `substr` function,
+    // so the implementation attempts to decode in one pass (hence the complexity).
     let (mut st, mut ed) = (input.len(), input.len());
     let mut start_counting = false;
     let mut cnt = 0;
@@ -186,6 +202,53 @@ fn make_and_append_view(
     null_builder.append_non_null();
 }
 
+// String characters are variable-length encoded in UTF-8, and `substr()`'s
+// arguments are character-based, so converting them into byte-based indices
+// requires expensive decoding.
+// However, checking whether a string is ASCII-only is relatively cheap.
+// If the strings are ASCII-only, use byte-based indices instead.
+//
+// A common pattern is calling `substr()` to take a small prefix of a long
+// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
+// In such cases the overhead of ASCII validation may not be worth it, so
+// skip the validation for short prefixes for now.
+// "Short" is estimated from the first `n_sample` rows: the sum of `start + count`
+// divided by `n_sample` (even if fewer rows exist) must not exceed 32.
+fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
+    string_array: &V,
+    start: &Int64Array,
+    count: Option<&Int64Array>,
+) -> bool {
+    let is_short_prefix = match count {
+        Some(count) => {
+            let short_prefix_threshold = 32.0;
+            let n_sample = 10;
+
+            // HACK: can be simplified if the function gets a specialized
+            // implementation for `ScalarValue` (implemented without `make_scalar_function()`)
+            let avg_prefix_len = start
+                .iter()
+                .zip(count.iter())
+                .take(n_sample)
+                .map(|(start, count)| {
+                    let start = start.unwrap_or(0);
+                    let count = count.unwrap_or(0);
+                    // To get the substring, decoding must run from 0 to start+count,
+                    // not just from start to start+count
+                    start + count
+                })
+                .sum::<i64>();
+
+            avg_prefix_len as f64 / n_sample as f64 <= short_prefix_threshold
+        }
+        None => false,
+    };
+
+    if is_short_prefix {
+        // Skip ASCII validation for short prefix
+        false
+    } else {
+        string_array.is_ascii()
+    }
+}
+
 // The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
 // From<u128> for ByteView
 fn string_view_substr(
@@ -196,6 +259,14 @@ fn string_view_substr(
     let mut null_builder = NullBufferBuilder::new(string_view_array.len());
 
     let start_array = as_int64_array(&args[0])?;
+    let count_array_opt = if args.len() == 2 {
+        Some(as_int64_array(&args[1])?)
+    } else {
+        None
+    };
+
+    let enable_ascii_fast_path =
+        enable_ascii_fast_path(&string_view_array, start_array, 
count_array_opt);
 
     // In either case of `substr(s, i)` or `substr(s, i, cnt)`
     // If any of input argument is `NULL`, the result is `NULL`
@@ -207,7 +278,8 @@ fn string_view_substr(
                 .zip(start_array.iter())
             {
                 if let (Some(str), Some(start)) = (str_opt, start_opt) {
-                    let (start, end) = get_true_start_end(str, start, None);
+                    let (start, end) =
+                        get_true_start_end(str, start, None, 
enable_ascii_fast_path);
                     let substr = &str[start..end];
 
                     make_and_append_view(
@@ -224,7 +296,7 @@ fn string_view_substr(
             }
         }
         2 => {
-            let count_array = as_int64_array(&args[1])?;
+            let count_array = count_array_opt.unwrap();
             for (((str_opt, raw_view), start_opt), count_opt) in 
string_view_array
                 .iter()
                 .zip(string_view_array.views().iter())
@@ -239,8 +311,17 @@ fn string_view_substr(
                             "negative substring length not allowed: 
substr(<str>, {start}, {count})"
                         );
                     } else {
-                        let (start, end) =
-                            get_true_start_end(str, start, Some(count as u64));
+                        if start == i64::MIN {
+                            return exec_err!(
+                                "negative overflow when calculating skip value"
+                            );
+                        }
+                        let (start, end) = get_true_start_end(
+                            str,
+                            start,
+                            Some(count as u64),
+                            enable_ascii_fast_path,
+                        );
                         let substr = &str[start..end];
 
                         make_and_append_view(
@@ -283,23 +364,35 @@ fn string_view_substr(
 
 fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> 
Result<ArrayRef>
 where
-    V: ArrayAccessor<Item = &'a str>,
+    V: StringArrayType<'a>,
     T: OffsetSizeTrait,
 {
+    let start_array = as_int64_array(&args[0])?;
+    let count_array_opt = if args.len() == 2 {
+        Some(as_int64_array(&args[1])?)
+    } else {
+        None
+    };
+
+    let enable_ascii_fast_path =
+        enable_ascii_fast_path(&string_array, start_array, count_array_opt);
+
     match args.len() {
         1 => {
             let iter = ArrayIter::new(string_array);
-            let start_array = as_int64_array(&args[0])?;
 
             let result = iter
                 .zip(start_array.iter())
                 .map(|(string, start)| match (string, start) {
                     (Some(string), Some(start)) => {
-                        if start <= 0 {
-                            Some(string.to_string())
-                        } else {
-                            Some(string.chars().skip(start as usize - 
1).collect())
-                        }
+                        let (start, end) = get_true_start_end(
+                            string,
+                            start,
+                            None,
+                            enable_ascii_fast_path,
+                        ); // start, end is byte-based
+                        let substr = &string[start..end];
+                        Some(substr.to_string())
                     }
                     _ => None,
                 })
@@ -308,8 +401,7 @@ where
         }
         2 => {
             let iter = ArrayIter::new(string_array);
-            let start_array = as_int64_array(&args[0])?;
-            let count_array = as_int64_array(&args[1])?;
+            let count_array = count_array_opt.unwrap();
 
             let result = iter
                 .zip(start_array.iter())
@@ -322,11 +414,17 @@ where
                                 "negative substring length not allowed: 
substr(<str>, {start}, {count})"
                             )
                             } else {
-                                let skip = max(0, 
start.checked_sub(1).ok_or_else(
-                                    || exec_datafusion_err!("negative overflow 
when calculating skip value")
-                                )?);
-                                let count = max(0, count + (if start < 1 { 
start - 1 } else { 0 }));
-                                Ok(Some(string.chars().skip(skip as 
usize).take(count as usize).collect::<String>()))
+                                if start == i64::MIN {
+                                    return exec_err!("negative overflow when 
calculating skip value")
+                                }
+                                let (start, end) = get_true_start_end(
+                                    string,
+                                    start,
+                                    Some(count as u64),
+                                    enable_ascii_fast_path,
+                                ); // start, end is byte-based
+                                let substr = &string[start..end];
+                                Ok(Some(substr.to_string()))
                             }
                         }
                         _ => Ok(None),


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to