This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new bb4e0eca22 perf: Optimize `starts_with` and `ends_with` for scalar
arguments (#19516)
bb4e0eca22 is described below
commit bb4e0eca22da3da209ebf0410214a1ed1e48b115
Author: Andy Grove <[email protected]>
AuthorDate: Sun Dec 28 12:04:04 2025 -0700
perf: Optimize `starts_with` and `ends_with` for scalar arguments (#19516)
## Which issue does this PR close?
<!--
We generally require a GitHub issue to be filed for all bug fixes and
enhancements and this helps us generate change logs for our releases.
You can link an issue to this PR using the GitHub syntax. For example
`Closes #123` indicates that this PR will close issue #123.
-->
- Closes #.
## Rationale for this change
<!--
Why are you proposing this change? If this is already explained clearly
in the issue then this section is not needed.
Explaining clearly why changes are proposed helps reviewers understand
your changes and offer better suggestions for fixes.
-->
- Passing the pattern as a scalar makes the common case of
starts_with(column, 'literal') or ends_with(column, 'literal') 3.6x-8.0x
faster (see the benchmark tables below)
- StringViewArray benefits more (6.4x-8.0x) than StringArray
(3.6x-3.8x)
- The optimization wraps the scalar in Arrow's `Scalar` instead of
broadcasting it to a full-length array
### starts_with
| Benchmark | Before | After | Speedup |
|--------------------------|----------|----------|---------|
| StringArray + scalar | 32.38 µs | 8.49 µs | 3.8x |
| StringViewArray + scalar | 78.15 µs | 9.82 µs | 8.0x |
### ends_with
| Benchmark | Before | After | Speedup |
|--------------------------|----------|----------|---------|
| StringArray + scalar | 32.76 µs | 9.06 µs | 3.6x |
| StringViewArray + scalar | 76.44 µs | 12.04 µs | 6.4x |
## What changes are included in this PR?
<!--
There is no need to duplicate the description in the issue here but it
is sometimes worth providing a summary of the individual changes in this
PR.
-->
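Handle all four combinations of array and scalar arguments without
broadcasting scalars to full-length arrays: a scalar argument becomes a
one-element array wrapped in Arrow's `Scalar`, and the scalar/scalar case
compares two one-element arrays and returns a scalar result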
## Are these changes tested?
<!--
We typically require tests for all PRs in order to:
1. Prevent the code from being accidentally broken by subsequent changes
2. Serve as another way to document the expected behavior of the code
If tests are not included in your PR, please explain why (for example,
are they covered by existing tests)?
-->
Yes, new unit tests added in this PR.
## Are there any user-facing changes?
<!--
If there are user-facing changes then we may require documentation to be
updated before approving the PR.
-->
<!--
If there are any breaking changes to public APIs, please add the `api
change` label.
-->
No, just faster performance.
---
datafusion/functions/Cargo.toml | 10 +
datafusion/functions/benches/ends_with.rs | 185 +++++++++++++++
datafusion/functions/benches/starts_with.rs | 185 +++++++++++++++
datafusion/functions/src/string/ends_with.rs | 298 +++++++++++++++++++++----
datafusion/functions/src/string/starts_with.rs | 269 ++++++++++++++++++----
5 files changed, 868 insertions(+), 79 deletions(-)
diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml
index 3e832691f9..765f5d865a 100644
--- a/datafusion/functions/Cargo.toml
+++ b/datafusion/functions/Cargo.toml
@@ -254,3 +254,13 @@ required-features = ["unicode_expressions"]
harness = false
name = "find_in_set"
required-features = ["unicode_expressions"]
+
+[[bench]]
+harness = false
+name = "starts_with"
+required-features = ["string_expressions"]
+
+[[bench]]
+harness = false
+name = "ends_with"
+required-features = ["string_expressions"]
diff --git a/datafusion/functions/benches/ends_with.rs b/datafusion/functions/benches/ends_with.rs
new file mode 100644
index 0000000000..926fd9ff72
--- /dev/null
+++ b/datafusion/functions/benches/ends_with.rs
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distr::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+/// Generate a StringArray/StringViewArray with random ASCII strings
+fn gen_string_array(
+ n_rows: usize,
+ str_len: usize,
+ is_string_view: bool,
+) -> ColumnarValue {
+ let mut rng = StdRng::seed_from_u64(42);
+ let strings: Vec<Option<String>> = (0..n_rows)
+ .map(|_| {
+ let s: String = (&mut rng)
+ .sample_iter(&Alphanumeric)
+ .take(str_len)
+ .map(char::from)
+ .collect();
+ Some(s)
+ })
+ .collect();
+
+ if is_string_view {
+ ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+ } else {
+ ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+ }
+}
+
+/// Generate a scalar suffix string
+fn gen_scalar_suffix(suffix_str: &str, is_string_view: bool) -> ColumnarValue {
+ if is_string_view {
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(suffix_str.to_string())))
+ } else {
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(suffix_str.to_string())))
+ }
+}
+
+/// Generate an array of suffix strings (same string repeated)
+fn gen_array_suffix(
+ suffix_str: &str,
+ n_rows: usize,
+ is_string_view: bool,
+) -> ColumnarValue {
+ let strings: Vec<Option<String>> =
+ (0..n_rows).map(|_| Some(suffix_str.to_string())).collect();
+
+ if is_string_view {
+ ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+ } else {
+ ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+ }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let ends_with = datafusion_functions::string::ends_with();
+ let n_rows = 8192;
+ let str_len = 128;
+ let suffix_str = "xyz"; // A pattern that likely won't match
+
+ // Benchmark: StringArray with scalar suffix (the optimized path)
+ let str_array = gen_string_array(n_rows, str_len, false);
+ let scalar_suffix = gen_scalar_suffix(suffix_str, false);
+ let arg_fields = vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ];
+ let return_field = Field::new("f", DataType::Boolean, true).into();
+ let config_options = Arc::new(ConfigOptions::default());
+
+ c.bench_function("ends_with_StringArray_scalar_suffix", |b| {
+ b.iter(|| {
+ black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), scalar_suffix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringArray with array suffix (for comparison)
+ let array_suffix = gen_array_suffix(suffix_str, n_rows, false);
+ c.bench_function("ends_with_StringArray_array_suffix", |b| {
+ b.iter(|| {
+ black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), array_suffix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringViewArray with scalar suffix (the optimized path)
+ let str_view_array = gen_string_array(n_rows, str_len, true);
+ let scalar_suffix_view = gen_scalar_suffix(suffix_str, true);
+ let arg_fields_view = vec![
+ Field::new("a", DataType::Utf8View, true).into(),
+ Field::new("b", DataType::Utf8View, true).into(),
+ ];
+
+ c.bench_function("ends_with_StringViewArray_scalar_suffix", |b| {
+ b.iter(|| {
+ black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_view_array.clone(), scalar_suffix_view.clone()],
+ arg_fields: arg_fields_view.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringViewArray with array suffix (for comparison)
+ let array_suffix_view = gen_array_suffix(suffix_str, n_rows, true);
+ c.bench_function("ends_with_StringViewArray_array_suffix", |b| {
+ b.iter(|| {
+ black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_view_array.clone(), array_suffix_view.clone()],
+ arg_fields: arg_fields_view.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark different string lengths with scalar suffix
+ for str_len in [8, 32, 128, 512] {
+ let str_array = gen_string_array(n_rows, str_len, true);
+ let scalar_suffix = gen_scalar_suffix(suffix_str, true);
+ let arg_fields = vec![
+ Field::new("a", DataType::Utf8View, true).into(),
+ Field::new("b", DataType::Utf8View, true).into(),
+ ];
+
+ c.bench_function(
+ &format!("ends_with_StringViewArray_scalar_strlen_{str_len}"),
+ |b| {
+ b.iter(|| {
+ black_box(ends_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), scalar_suffix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ },
+ );
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/benches/starts_with.rs b/datafusion/functions/benches/starts_with.rs
new file mode 100644
index 0000000000..9ee39b6945
--- /dev/null
+++ b/datafusion/functions/benches/starts_with.rs
@@ -0,0 +1,185 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+extern crate criterion;
+
+use arrow::array::{StringArray, StringViewArray};
+use arrow::datatypes::{DataType, Field};
+use criterion::{Criterion, criterion_group, criterion_main};
+use datafusion_common::ScalarValue;
+use datafusion_common::config::ConfigOptions;
+use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
+use rand::distr::Alphanumeric;
+use rand::prelude::StdRng;
+use rand::{Rng, SeedableRng};
+use std::hint::black_box;
+use std::sync::Arc;
+
+/// Generate a StringArray/StringViewArray with random ASCII strings
+fn gen_string_array(
+ n_rows: usize,
+ str_len: usize,
+ is_string_view: bool,
+) -> ColumnarValue {
+ let mut rng = StdRng::seed_from_u64(42);
+ let strings: Vec<Option<String>> = (0..n_rows)
+ .map(|_| {
+ let s: String = (&mut rng)
+ .sample_iter(&Alphanumeric)
+ .take(str_len)
+ .map(char::from)
+ .collect();
+ Some(s)
+ })
+ .collect();
+
+ if is_string_view {
+ ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+ } else {
+ ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+ }
+}
+
+/// Generate a scalar prefix string
+fn gen_scalar_prefix(prefix_str: &str, is_string_view: bool) -> ColumnarValue {
+ if is_string_view {
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(prefix_str.to_string())))
+ } else {
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(prefix_str.to_string())))
+ }
+}
+
+/// Generate an array of prefix strings (same string repeated)
+fn gen_array_prefix(
+ prefix_str: &str,
+ n_rows: usize,
+ is_string_view: bool,
+) -> ColumnarValue {
+ let strings: Vec<Option<String>> =
+ (0..n_rows).map(|_| Some(prefix_str.to_string())).collect();
+
+ if is_string_view {
+ ColumnarValue::Array(Arc::new(StringViewArray::from(strings)))
+ } else {
+ ColumnarValue::Array(Arc::new(StringArray::from(strings)))
+ }
+}
+
+fn criterion_benchmark(c: &mut Criterion) {
+ let starts_with = datafusion_functions::string::starts_with();
+ let n_rows = 8192;
+ let str_len = 128;
+ let prefix_str = "xyz"; // A pattern that likely won't match
+
+ // Benchmark: StringArray with scalar prefix (the optimized path)
+ let str_array = gen_string_array(n_rows, str_len, false);
+ let scalar_prefix = gen_scalar_prefix(prefix_str, false);
+ let arg_fields = vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ];
+ let return_field = Field::new("f", DataType::Boolean, true).into();
+ let config_options = Arc::new(ConfigOptions::default());
+
+ c.bench_function("starts_with_StringArray_scalar_prefix", |b| {
+ b.iter(|| {
+ black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), scalar_prefix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringArray with array prefix (for comparison)
+ let array_prefix = gen_array_prefix(prefix_str, n_rows, false);
+ c.bench_function("starts_with_StringArray_array_prefix", |b| {
+ b.iter(|| {
+ black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), array_prefix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringViewArray with scalar prefix (the optimized path)
+ let str_view_array = gen_string_array(n_rows, str_len, true);
+ let scalar_prefix_view = gen_scalar_prefix(prefix_str, true);
+ let arg_fields_view = vec![
+ Field::new("a", DataType::Utf8View, true).into(),
+ Field::new("b", DataType::Utf8View, true).into(),
+ ];
+
+ c.bench_function("starts_with_StringViewArray_scalar_prefix", |b| {
+ b.iter(|| {
+ black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_view_array.clone(), scalar_prefix_view.clone()],
+ arg_fields: arg_fields_view.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark: StringViewArray with array prefix (for comparison)
+ let array_prefix_view = gen_array_prefix(prefix_str, n_rows, true);
+ c.bench_function("starts_with_StringViewArray_array_prefix", |b| {
+ b.iter(|| {
+ black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_view_array.clone(), array_prefix_view.clone()],
+ arg_fields: arg_fields_view.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ });
+
+ // Benchmark different string lengths with scalar prefix
+ for str_len in [8, 32, 128, 512] {
+ let str_array = gen_string_array(n_rows, str_len, true);
+ let scalar_prefix = gen_scalar_prefix(prefix_str, true);
+ let arg_fields = vec![
+ Field::new("a", DataType::Utf8View, true).into(),
+ Field::new("b", DataType::Utf8View, true).into(),
+ ];
+
+ c.bench_function(
+ &format!("starts_with_StringViewArray_scalar_strlen_{str_len}"),
+ |b| {
+ b.iter(|| {
+ black_box(starts_with.invoke_with_args(ScalarFunctionArgs {
+ args: vec![str_array.clone(), scalar_prefix.clone()],
+ arg_fields: arg_fields.clone(),
+ number_rows: n_rows,
+ return_field: Arc::clone(&return_field),
+ config_options: Arc::clone(&config_options),
+ }))
+ })
+ },
+ );
+ }
+}
+
+criterion_group!(benches, criterion_benchmark);
+criterion_main!(benches);
diff --git a/datafusion/functions/src/string/ends_with.rs b/datafusion/functions/src/string/ends_with.rs
index e3fa7c92ca..a1fa124548 100644
--- a/datafusion/functions/src/string/ends_with.rs
+++ b/datafusion/functions/src/string/ends_with.rs
@@ -18,12 +18,12 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, Scalar};
+use arrow::compute::kernels::comparison::ends_with as arrow_ends_with;
use arrow::datatypes::DataType;
-use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
-use datafusion_common::{Result, internal_err};
+use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::binary::{binary_to_string_coercion, string_coercion};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature,
@@ -95,13 +95,76 @@ impl ScalarUDFImpl for EndsWithFunc {
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
- match args.args[0].data_type() {
- DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
- make_scalar_function(ends_with, vec![])(&args.args)
+ let [str_arg, suffix_arg] = args.args.as_slice() else {
+ return exec_err!(
+ "ends_with was called with {} arguments, expected 2",
+ args.args.len()
+ );
+ };
+
+ // Determine the common type for coercion
+ let coercion_type = string_coercion(
+ &str_arg.data_type(),
+ &suffix_arg.data_type(),
+ )
+ .or_else(|| {
+ binary_to_string_coercion(&str_arg.data_type(), &suffix_arg.data_type())
+ });
+
+ let Some(coercion_type) = coercion_type else {
+ return exec_err!(
+ "Unsupported data types {:?}, {:?} for function `ends_with`.",
+ str_arg.data_type(),
+ suffix_arg.data_type()
+ );
+ };
+
+ // Helper to cast an array if needed
+ let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
+ if arr.data_type() == target {
+ Ok(Arc::clone(arr))
+ } else {
+ Ok(arrow::compute::kernels::cast::cast(arr, target)?)
+ }
+ };
+
+ match (str_arg, suffix_arg) {
+ // Both scalars - just compute directly
+ (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(suffix_scalar)) => {
+ let str_arr = str_scalar.to_array_of_size(1)?;
+ let suffix_arr = suffix_scalar.to_array_of_size(1)?;
+ let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+ let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
+ let result = arrow_ends_with(&str_arr, &suffix_arr)?;
+ Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+ &result, 0,
+ )?))
+ }
+ // String is array, suffix is scalar - use Scalar wrapper for optimization
+ (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(suffix_scalar)) => {
+ let str_arr = maybe_cast(str_arr, &coercion_type)?;
+ let suffix_arr = suffix_scalar.to_array_of_size(1)?;
+ let suffix_arr = maybe_cast(&suffix_arr, &coercion_type)?;
+ let suffix_scalar = Scalar::new(suffix_arr);
+ let result = arrow_ends_with(&str_arr, &suffix_scalar)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ // String is scalar, suffix is array - use Scalar wrapper for string
+ (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(suffix_arr)) => {
+ let str_arr = str_scalar.to_array_of_size(1)?;
+ let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+ let str_scalar = Scalar::new(str_arr);
+ let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
+ let result = arrow_ends_with(&str_scalar, &suffix_arr)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ // Both arrays - pass directly
+ (ColumnarValue::Array(str_arr), ColumnarValue::Array(suffix_arr)) => {
+ let str_arr = maybe_cast(str_arr, &coercion_type)?;
+ let suffix_arr = maybe_cast(suffix_arr, &coercion_type)?;
+ let result = arrow_ends_with(&str_arr, &suffix_arr)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
}
- other => internal_err!(
- "Unsupported data type {other:?} for function ends_with. Expected Utf8, LargeUtf8 or Utf8View"
- )?,
}
}
@@ -110,47 +173,24 @@ impl ScalarUDFImpl for EndsWithFunc {
}
}
-/// Returns true if string ends with suffix.
-/// ends_with('alphabet', 'abet') = 't'
-fn ends_with(args: &[ArrayRef]) -> Result<ArrayRef> {
- if let Some(coercion_data_type) =
- string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
- binary_to_string_coercion(args[0].data_type(), args[1].data_type())
- })
- {
- let arg0 = if args[0].data_type() == &coercion_data_type {
- Arc::clone(&args[0])
- } else {
- arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
- };
- let arg1 = if args[1].data_type() == &coercion_data_type {
- Arc::clone(&args[1])
- } else {
- arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
- };
- let result = arrow::compute::kernels::comparison::ends_with(&arg0, &arg1)?;
- Ok(Arc::new(result) as ArrayRef)
- } else {
- internal_err!(
- "Unsupported data types for ends_with. Expected Utf8, LargeUtf8 or Utf8View"
- )
- }
-}
-
#[cfg(test)]
mod tests {
- use arrow::array::{Array, BooleanArray};
+ use arrow::array::{Array, BooleanArray, StringArray};
use arrow::datatypes::DataType::Boolean;
+ use arrow::datatypes::{DataType, Field};
+ use std::sync::Arc;
use datafusion_common::Result;
use datafusion_common::ScalarValue;
- use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+ use datafusion_common::config::ConfigOptions;
+ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
use crate::string::ends_with::EndsWithFunc;
use crate::utils::test::test_function;
#[test]
- fn test_functions() -> Result<()> {
+ fn test_scalar_scalar() -> Result<()> {
+ // Test Scalar + Scalar combinations
test_function!(
EndsWithFunc::new(),
vec![
@@ -196,6 +236,186 @@ mod tests {
BooleanArray
);
+ // Test with LargeUtf8
+ test_function!(
+ EndsWithFunc::new(),
+ vec![
+ ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(
+ "alphabet".to_string()
+ ))),
+ ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("bet".to_string()))),
+ ],
+ Ok(Some(true)),
+ bool,
+ Boolean,
+ BooleanArray
+ );
+
+ // Test with Utf8View
+ test_function!(
+ EndsWithFunc::new(),
+ vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some(
+ "alphabet".to_string()
+ ))),
+ ColumnarValue::Scalar(ScalarValue::Utf8View(Some("bet".to_string()))),
+ ],
+ Ok(Some(true)),
+ bool,
+ Boolean,
+ BooleanArray
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_array_scalar() -> Result<()> {
+ // Test Array + Scalar (the optimized path)
+ let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("alphabet"),
+ Some("beta"),
+ None,
+ ])));
+ let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string())));
+
+ let args = vec![array, scalar];
+ test_function!(
+ EndsWithFunc::new(),
+ args,
+ Ok(Some(true)), // First element result: "alphabet" ends with "bet"
+ bool,
+ Boolean,
+ BooleanArray
+ );
+
Ok(())
}
+
+ #[test]
+ fn test_array_scalar_full_result() {
+ // Test Array + Scalar and verify all results
+ let func = EndsWithFunc::new();
+ let array = Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("alphabet"),
+ Some("beta"),
+ None,
+ ]));
+ let args = vec![
+ ColumnarValue::Array(array),
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some("bet".to_string()))),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+ assert!(bool_array.value(1)); // "alphabet" ends with "bet"
+ assert!(!bool_array.value(2)); // "beta" does not end with "bet"
+ assert!(bool_array.is_null(3)); // null input -> null output
+ }
+
+ #[test]
+ fn test_scalar_array() {
+ // Test Scalar + Array
+ let func = EndsWithFunc::new();
+ let suffixes = Arc::new(StringArray::from(vec![
+ Some("bet"),
+ Some("alph"),
+ Some("phabet"),
+ None,
+ ]));
+ let args = vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
+ ColumnarValue::Array(suffixes),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+ assert!(!bool_array.value(1)); // "alphabet" does not end with "alph"
+ assert!(bool_array.value(2)); // "alphabet" ends with "phabet"
+ assert!(bool_array.is_null(3)); // null suffix -> null output
+ }
+
+ #[test]
+ fn test_array_array() {
+ // Test Array + Array
+ let func = EndsWithFunc::new();
+ let strings = Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("rust"),
+ Some("datafusion"),
+ None,
+ ]));
+ let suffixes = Arc::new(StringArray::from(vec![
+ Some("bet"),
+ Some("st"),
+ Some("hello"),
+ Some("test"),
+ ]));
+ let args = vec![
+ ColumnarValue::Array(strings),
+ ColumnarValue::Array(suffixes),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" ends with "bet"
+ assert!(bool_array.value(1)); // "rust" ends with "st"
+ assert!(!bool_array.value(2)); // "datafusion" does not end with "hello"
+ assert!(bool_array.is_null(3)); // null string -> null output
+ }
}
diff --git a/datafusion/functions/src/string/starts_with.rs b/datafusion/functions/src/string/starts_with.rs
index 1a60eb91aa..259612c429 100644
--- a/datafusion/functions/src/string/starts_with.rs
+++ b/datafusion/functions/src/string/starts_with.rs
@@ -18,49 +18,22 @@
use std::any::Any;
use std::sync::Arc;
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, Scalar};
+use arrow::compute::kernels::comparison::starts_with as arrow_starts_with;
use arrow::datatypes::DataType;
use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo};
use datafusion_expr::type_coercion::binary::{
binary_to_string_coercion, string_coercion,
};
-use crate::utils::make_scalar_function;
use datafusion_common::types::logical_string;
-use datafusion_common::{Result, ScalarValue, internal_err};
+use datafusion_common::{Result, ScalarValue, exec_err};
use datafusion_expr::{
Coercion, ColumnarValue, Documentation, Expr, Like, ScalarFunctionArgs,
ScalarUDFImpl, Signature, TypeSignatureClass, Volatility, cast,
};
use datafusion_macros::user_doc;
-/// Returns true if string starts with prefix.
-/// starts_with('alphabet', 'alph') = 't'
-fn starts_with(args: &[ArrayRef]) -> Result<ArrayRef> {
- if let Some(coercion_data_type) =
- string_coercion(args[0].data_type(), args[1].data_type()).or_else(|| {
- binary_to_string_coercion(args[0].data_type(), args[1].data_type())
- })
- {
- let arg0 = if args[0].data_type() == &coercion_data_type {
- Arc::clone(&args[0])
- } else {
- arrow::compute::kernels::cast::cast(&args[0], &coercion_data_type)?
- };
- let arg1 = if args[1].data_type() == &coercion_data_type {
- Arc::clone(&args[1])
- } else {
- arrow::compute::kernels::cast::cast(&args[1], &coercion_data_type)?
- };
- let result = arrow::compute::kernels::comparison::starts_with(&arg0, &arg1)?;
- Ok(Arc::new(result) as ArrayRef)
- } else {
- internal_err!(
- "Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View"
- )
- }
-}
-
#[user_doc(
doc_section(label = "String Functions"),
description = "Tests if a string starts with a substring.",
@@ -119,13 +92,76 @@ impl ScalarUDFImpl for StartsWithFunc {
}
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
- match args.args[0].data_type() {
- DataType::Utf8View | DataType::Utf8 | DataType::LargeUtf8 => {
- make_scalar_function(starts_with, vec![])(&args.args)
+ let [str_arg, prefix_arg] = args.args.as_slice() else {
+ return exec_err!(
+ "starts_with was called with {} arguments, expected 2",
+ args.args.len()
+ );
+ };
+
+ // Determine the common type for coercion
+ let coercion_type = string_coercion(
+ &str_arg.data_type(),
+ &prefix_arg.data_type(),
+ )
+ .or_else(|| {
+ binary_to_string_coercion(&str_arg.data_type(), &prefix_arg.data_type())
+ });
+
+ let Some(coercion_type) = coercion_type else {
+ return exec_err!(
+ "Unsupported data types {:?}, {:?} for function `starts_with`.",
+ str_arg.data_type(),
+ prefix_arg.data_type()
+ );
+ };
+
+ // Helper to cast an array if needed
+ let maybe_cast = |arr: &ArrayRef, target: &DataType| -> Result<ArrayRef> {
+ if arr.data_type() == target {
+ Ok(Arc::clone(arr))
+ } else {
+ Ok(arrow::compute::kernels::cast::cast(arr, target)?)
+ }
+ };
+
+ match (str_arg, prefix_arg) {
+ // Both scalars - just compute directly
+ (ColumnarValue::Scalar(str_scalar), ColumnarValue::Scalar(prefix_scalar)) => {
+ let str_arr = str_scalar.to_array_of_size(1)?;
+ let prefix_arr = prefix_scalar.to_array_of_size(1)?;
+ let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+ let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
+ let result = arrow_starts_with(&str_arr, &prefix_arr)?;
+ Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+ &result, 0,
+ )?))
+ }
+ // String is array, prefix is scalar - use Scalar wrapper for optimization
+ (ColumnarValue::Array(str_arr), ColumnarValue::Scalar(prefix_scalar)) => {
+ let str_arr = maybe_cast(str_arr, &coercion_type)?;
+ let prefix_arr = prefix_scalar.to_array_of_size(1)?;
+ let prefix_arr = maybe_cast(&prefix_arr, &coercion_type)?;
+ let prefix_scalar = Scalar::new(prefix_arr);
+ let result = arrow_starts_with(&str_arr, &prefix_scalar)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ // String is scalar, prefix is array - use Scalar wrapper for string
+ (ColumnarValue::Scalar(str_scalar), ColumnarValue::Array(prefix_arr)) => {
+ let str_arr = str_scalar.to_array_of_size(1)?;
+ let str_arr = maybe_cast(&str_arr, &coercion_type)?;
+ let str_scalar = Scalar::new(str_arr);
+ let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
+ let result = arrow_starts_with(&str_scalar, &prefix_arr)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
+ }
+ // Both arrays - pass directly
+ (ColumnarValue::Array(str_arr), ColumnarValue::Array(prefix_arr)) => {
+ let str_arr = maybe_cast(str_arr, &coercion_type)?;
+ let prefix_arr = maybe_cast(prefix_arr, &coercion_type)?;
+ let result = arrow_starts_with(&str_arr, &prefix_arr)?;
+ Ok(ColumnarValue::Array(Arc::new(result)))
}
- _ => internal_err!(
- "Unsupported data types for starts_with. Expected Utf8, LargeUtf8 or Utf8View"
- )?,
}
}
@@ -195,16 +231,19 @@ impl ScalarUDFImpl for StartsWithFunc {
#[cfg(test)]
mod tests {
use crate::utils::test::test_function;
- use arrow::array::{Array, BooleanArray};
+ use arrow::array::{Array, BooleanArray, StringArray};
use arrow::datatypes::DataType::Boolean;
+ use arrow::datatypes::{DataType, Field};
+ use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, ScalarValue};
- use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
+ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
+ use std::sync::Arc;
use super::*;
#[test]
- fn test_functions() -> Result<()> {
- // Generate test cases for starts_with
+ fn test_scalar_scalar() -> Result<()> {
+ // Test Scalar + Scalar combinations
let test_cases = vec![
(Some("alphabet"), Some("alph"), Some(true)),
(Some("alphabet"), Some("bet"), Some(false)),
@@ -248,4 +287,154 @@ mod tests {
Ok(())
}
+
+ #[test]
+ fn test_array_scalar() -> Result<()> {
+ // Test Array + Scalar (the optimized path)
+ let array = ColumnarValue::Array(Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("alphabet"),
+ Some("beta"),
+ None,
+ ])));
+ let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string())));
+
+ let args = vec![array, scalar];
+ test_function!(
+ StartsWithFunc::new(),
+ args,
+ Ok(Some(true)), // First element result
+ bool,
+ Boolean,
+ BooleanArray
+ );
+
+ Ok(())
+ }
+
+ #[test]
+ fn test_array_scalar_full_result() {
+ // Test Array + Scalar and verify all results
+ let func = StartsWithFunc::new();
+ let array = Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("alphabet"),
+ Some("beta"),
+ None,
+ ]));
+ let args = vec![
+ ColumnarValue::Array(array),
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some("alph".to_string()))),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+ assert!(bool_array.value(1)); // "alphabet" starts with "alph"
+ assert!(!bool_array.value(2)); // "beta" does not start with "alph"
+ assert!(bool_array.is_null(3)); // null input -> null output
+ }
+
+ #[test]
+ fn test_scalar_array() {
+ // Test Scalar + Array
+ let func = StartsWithFunc::new();
+ let prefixes = Arc::new(StringArray::from(vec![
+ Some("alph"),
+ Some("bet"),
+ Some("alpha"),
+ None,
+ ]));
+ let args = vec![
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some("alphabet".to_string()))),
+ ColumnarValue::Array(prefixes),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+ assert!(!bool_array.value(1)); // "alphabet" does not start with "bet"
+ assert!(bool_array.value(2)); // "alphabet" starts with "alpha"
+ assert!(bool_array.is_null(3)); // null prefix -> null output
+ }
+
+ #[test]
+ fn test_array_array() {
+ // Test Array + Array
+ let func = StartsWithFunc::new();
+ let strings = Arc::new(StringArray::from(vec![
+ Some("alphabet"),
+ Some("rust"),
+ Some("datafusion"),
+ None,
+ ]));
+ let prefixes = Arc::new(StringArray::from(vec![
+ Some("alph"),
+ Some("ru"),
+ Some("hello"),
+ Some("test"),
+ ]));
+ let args = vec![
+ ColumnarValue::Array(strings),
+ ColumnarValue::Array(prefixes),
+ ];
+
+ let result = func
+ .invoke_with_args(ScalarFunctionArgs {
+ args,
+ arg_fields: vec![
+ Field::new("a", DataType::Utf8, true).into(),
+ Field::new("b", DataType::Utf8, true).into(),
+ ],
+ number_rows: 4,
+ return_field: Field::new("f", Boolean, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ })
+ .unwrap();
+
+ let result_array = result.into_array(4).unwrap();
+ let bool_array = result_array
+ .as_any()
+ .downcast_ref::<BooleanArray>()
+ .unwrap();
+
+ assert!(bool_array.value(0)); // "alphabet" starts with "alph"
+ assert!(bool_array.value(1)); // "rust" starts with "ru"
+ assert!(!bool_array.value(2)); // "datafusion" does not start with "hello"
+ assert!(bool_array.is_null(3)); // null string -> null output
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]