This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 55707dc0e8 Specialize ASCII case for substr() (#12444)
55707dc0e8 is described below
commit 55707dc0e83e00078423f35da4232111888319eb
Author: Yongting You <[email protected]>
AuthorDate: Wed Sep 18 00:40:34 2024 +0800
Specialize ASCII case for substr() (#12444)
* Specialize ASCII case for substr()
* cleanup + don't validate ASCII for short prefix
---
datafusion/functions/src/unicode/substr.rs | 146 ++++++++++++++++++++++++-----
1 file changed, 122 insertions(+), 24 deletions(-)
diff --git a/datafusion/functions/src/unicode/substr.rs
b/datafusion/functions/src/unicode/substr.rs
index 40d3a4d13e..5e311f1e18 100644
--- a/datafusion/functions/src/unicode/substr.rs
+++ b/datafusion/functions/src/unicode/substr.rs
@@ -16,18 +16,18 @@
// under the License.
use std::any::Any;
-use std::cmp::max;
use std::sync::Arc;
+use crate::string::common::StringArrayType;
use crate::utils::{make_scalar_function, utf8_to_str_type};
use arrow::array::{
- make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
- GenericStringArray, OffsetSizeTrait, StringViewArray,
+ make_view, Array, ArrayIter, ArrayRef, AsArray, ByteView,
GenericStringArray,
+ Int64Array, OffsetSizeTrait, StringViewArray,
};
use arrow::datatypes::DataType;
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
use datafusion_common::cast::as_int64_array;
-use datafusion_common::{exec_datafusion_err, exec_err, Result};
+use datafusion_common::{exec_err, Result};
use datafusion_expr::TypeSignature::Exact;
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -119,19 +119,27 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
}
// Convert the given `start` and `count` to valid byte indices within `input`
string
+//
// Input `start` and `count` are equivalent to PostgreSQL's `substr(s, start,
count)`
// `start` is 1-based, if `count` is not provided count to the end of the
string
// Input indices are character-based, and return values are byte indices
// The input bounds can be outside string bounds, this function will return
// the intersection between input bounds and valid string bounds
+// `is_input_ascii_only` is used to optimize this function if `input` is ASCII-only
//
// * Example
// 'Hi🌏' in-mem (`[]` for one char, `x` for one byte): [x][x][xxxx]
// `get_true_start_end('Hi🌏', 1, None) -> (0, 6)`
// `get_true_start_end('Hi🌏', 1, 1) -> (0, 1)`
// `get_true_start_end('Hi🌏', -10, 2) -> (0, 0)`
-fn get_true_start_end(input: &str, start: i64, count: Option<u64>) -> (usize,
usize) {
- let start = start - 1;
+fn get_true_start_end(
+ input: &str,
+ start: i64,
+ count: Option<u64>,
+ is_input_ascii_only: bool,
+) -> (usize, usize) {
+ let start = start.checked_sub(1).unwrap_or(start);
+
let end = match count {
Some(count) => start + count as i64,
None => input.len() as i64,
@@ -142,6 +150,14 @@ fn get_true_start_end(input: &str, start: i64, count:
Option<u64>) -> (usize, us
let end = end.clamp(0, input.len() as i64) as usize;
let count = end - start;
+ // If input is ASCII-only, byte-based indices are equal to char-based indices
+ if is_input_ascii_only {
+ return (start, end);
+ }
+
+ // Otherwise, calculate byte indices from char indices
+ // Note this decoding is relatively expensive for this simple `substr` function,
+ // so the implementation attempts to decode in one pass (which is the source of the complexity)
let (mut st, mut ed) = (input.len(), input.len());
let mut start_counting = false;
let mut cnt = 0;
@@ -186,6 +202,53 @@ fn make_and_append_view(
null_builder.append_non_null();
}
+// String characters are variable length encoded in UTF-8, `substr()`
function's
+// arguments are character-based, converting them into byte-based indices
+// requires expensive decoding.
+// However, checking if a string is ASCII-only is relatively cheap.
+// If strings are ASCII only, use byte-based indices instead.
+//
+// A common pattern to call `substr()` is taking a small prefix of a long
+// string, such as `substr(long_str_with_1k_chars, 1, 32)`.
+// In such case the overhead of ASCII-validation may not be worth it, so
+// skip the validation for short prefix for now.
+fn enable_ascii_fast_path<'a, V: StringArrayType<'a>>(
+    string_array: &V,
+    start: &Int64Array,
+    count: Option<&Int64Array>,
+) -> bool {
+    // Only `substr(s, start, count)` can be a "short prefix"; without `count`
+    // the substring runs to the end of the string, so validation always pays off.
+    let is_short_prefix = match count {
+        Some(count) => {
+            let short_prefix_threshold = 32.0;
+            let n_sample = 10;
+
+            // HACK: can be simplified if function has specialized
+            // implementation for `ScalarValue` (implement without `make_scalar_function()`)
+            let (total_prefix_len, n_sampled) = start
+                .iter()
+                .zip(count.iter())
+                .take(n_sample)
+                .map(|(start, count)| {
+                    let start = start.unwrap_or(0);
+                    let count = count.unwrap_or(0);
+                    // To get the substring, decoding must cover 0..start+count,
+                    // not just start..start+count. Saturate so extreme
+                    // user-supplied values cannot overflow, and clamp at 0
+                    // because a negative end needs no decoding at all.
+                    start.saturating_add(count).max(0)
+                })
+                .fold((0i64, 0usize), |(sum, n), len| {
+                    (sum.saturating_add(len), n + 1)
+                });
+
+            // Average over the rows actually sampled: the batch may contain
+            // fewer than `n_sample` rows. An empty batch needs no validation.
+            n_sampled == 0
+                || total_prefix_len as f64 / n_sampled as f64 <= short_prefix_threshold
+        }
+        None => false,
+    };
+
+    // Skip ASCII validation for short prefixes; otherwise use the fast path
+    // only when every string in the array is ASCII.
+    !is_short_prefix && string_array.is_ascii()
+}
+
// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
// From<u128> for ByteView
fn string_view_substr(
@@ -196,6 +259,14 @@ fn string_view_substr(
let mut null_builder = NullBufferBuilder::new(string_view_array.len());
let start_array = as_int64_array(&args[0])?;
+ let count_array_opt = if args.len() == 2 {
+ Some(as_int64_array(&args[1])?)
+ } else {
+ None
+ };
+
+ let enable_ascii_fast_path =
+ enable_ascii_fast_path(&string_view_array, start_array,
count_array_opt);
// In either case of `substr(s, i)` or `substr(s, i, cnt)`
// If any of input argument is `NULL`, the result is `NULL`
@@ -207,7 +278,8 @@ fn string_view_substr(
.zip(start_array.iter())
{
if let (Some(str), Some(start)) = (str_opt, start_opt) {
- let (start, end) = get_true_start_end(str, start, None);
+ let (start, end) =
+ get_true_start_end(str, start, None,
enable_ascii_fast_path);
let substr = &str[start..end];
make_and_append_view(
@@ -224,7 +296,7 @@ fn string_view_substr(
}
}
2 => {
- let count_array = as_int64_array(&args[1])?;
+ let count_array = count_array_opt.unwrap();
for (((str_opt, raw_view), start_opt), count_opt) in
string_view_array
.iter()
.zip(string_view_array.views().iter())
@@ -239,8 +311,17 @@ fn string_view_substr(
"negative substring length not allowed:
substr(<str>, {start}, {count})"
);
} else {
- let (start, end) =
- get_true_start_end(str, start, Some(count as u64));
+ if start == i64::MIN {
+ return exec_err!(
+ "negative overflow when calculating skip value"
+ );
+ }
+ let (start, end) = get_true_start_end(
+ str,
+ start,
+ Some(count as u64),
+ enable_ascii_fast_path,
+ );
let substr = &str[start..end];
make_and_append_view(
@@ -283,23 +364,35 @@ fn string_view_substr(
fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) ->
Result<ArrayRef>
where
- V: ArrayAccessor<Item = &'a str>,
+ V: StringArrayType<'a>,
T: OffsetSizeTrait,
{
+ let start_array = as_int64_array(&args[0])?;
+ let count_array_opt = if args.len() == 2 {
+ Some(as_int64_array(&args[1])?)
+ } else {
+ None
+ };
+
+ let enable_ascii_fast_path =
+ enable_ascii_fast_path(&string_array, start_array, count_array_opt);
+
match args.len() {
1 => {
let iter = ArrayIter::new(string_array);
- let start_array = as_int64_array(&args[0])?;
let result = iter
.zip(start_array.iter())
.map(|(string, start)| match (string, start) {
(Some(string), Some(start)) => {
- if start <= 0 {
- Some(string.to_string())
- } else {
- Some(string.chars().skip(start as usize -
1).collect())
- }
+ let (start, end) = get_true_start_end(
+ string,
+ start,
+ None,
+ enable_ascii_fast_path,
+ ); // start, end is byte-based
+ let substr = &string[start..end];
+ Some(substr.to_string())
}
_ => None,
})
@@ -308,8 +401,7 @@ where
}
2 => {
let iter = ArrayIter::new(string_array);
- let start_array = as_int64_array(&args[0])?;
- let count_array = as_int64_array(&args[1])?;
+ let count_array = count_array_opt.unwrap();
let result = iter
.zip(start_array.iter())
@@ -322,11 +414,17 @@ where
"negative substring length not allowed:
substr(<str>, {start}, {count})"
)
} else {
- let skip = max(0,
start.checked_sub(1).ok_or_else(
- || exec_datafusion_err!("negative overflow
when calculating skip value")
- )?);
- let count = max(0, count + (if start < 1 {
start - 1 } else { 0 }));
- Ok(Some(string.chars().skip(skip as
usize).take(count as usize).collect::<String>()))
+ if start == i64::MIN {
+ return exec_err!("negative overflow when
calculating skip value")
+ }
+ let (start, end) = get_true_start_end(
+ string,
+ start,
+ Some(count as u64),
+ enable_ascii_fast_path,
+ ); // start, end is byte-based
+ let substr = &string[start..end];
+ Ok(Some(substr.to_string()))
}
}
_ => Ok(None),
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]