This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 1e96a0a76c Fix `concat` simplifier for Utf8View types (#13346)
1e96a0a76c is described below
commit 1e96a0a76ca60364aa74d9a8bd8a4c15efdfb9de
Author: Tim Saucer <[email protected]>
AuthorDate: Fri Nov 15 14:53:20 2024 -0500
Fix `concat` simplifier for Utf8View types (#13346)
* Add string view options to concat; fix the simplifier so that a simplified
concat returns the same schema as the unsimplified expression
* Set coercion ordering
* Extend the simplification unit test to catch changes in the return type of concat
* Update coercion ordering
* Simplify computing merged type for concat
---
datafusion/core/tests/expr_api/simplification.rs | 24 +++++--
datafusion/functions/src/string/concat.rs | 87 +++++++++++++++++++++---
2 files changed, 95 insertions(+), 16 deletions(-)
diff --git a/datafusion/core/tests/expr_api/simplification.rs
b/datafusion/core/tests/expr_api/simplification.rs
index 68785b7a5a..1e6ff8088d 100644
--- a/datafusion/core/tests/expr_api/simplification.rs
+++ b/datafusion/core/tests/expr_api/simplification.rs
@@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef {
Field::new("c2", DataType::Boolean, true),
Field::new("c3", DataType::Int64, true),
Field::new("c4", DataType::UInt32, true),
+ Field::new("c5", DataType::Utf8View, true),
Field::new("c1_non_null", DataType::Utf8, false),
Field::new("c2_non_null", DataType::Boolean, false),
Field::new("c3_non_null", DataType::Int64, false),
Field::new("c4_non_null", DataType::UInt32, false),
+ Field::new("c5_non_null", DataType::Utf8View, false),
])
.to_dfschema_ref()
.unwrap()
@@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() {
}
#[test]
-fn test_simplify_concat() {
+fn test_simplify_concat() -> Result<()> {
+ let schema = expr_test_schema();
let null = lit(ScalarValue::Utf8(None));
let expr = concat(vec![
null.clone(),
- col("c0"),
+ col("c1"),
lit("hello "),
null.clone(),
lit("rust"),
- col("c1"),
+ lit(ScalarValue::Utf8View(Some("!".to_string()))),
+ col("c2"),
lit(""),
null,
+ col("c5"),
]);
- let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
- test_simplify(expr, expected)
+ let expr_datatype = expr.get_type(schema.as_ref())?;
+ let expected = concat(vec![
+ col("c1"),
+ lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))),
+ col("c2"),
+ col("c5"),
+ ]);
+ let expected_datatype = expected.get_type(schema.as_ref())?;
+ assert_eq!(expr_datatype, expected_datatype);
+ test_simplify(expr, expected);
+ Ok(())
}
#[test]
fn test_simplify_cycles() {
diff --git a/datafusion/functions/src/string/concat.rs
b/datafusion/functions/src/string/concat.rs
index f1e60004dd..d49a2777b4 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -48,7 +48,7 @@ impl ConcatFunc {
use DataType::*;
Self {
signature: Signature::variadic(
- vec![Utf8, Utf8View, LargeUtf8],
+ vec![Utf8View, Utf8, LargeUtf8],
Volatility::Immutable,
),
}
@@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc {
if array_len.is_none() {
let mut result = String::new();
for arg in args {
- if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg
{
- result.push_str(v);
+ match arg {
+ ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
+ | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
+ | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v)))
=> {
+ result.push_str(v);
+ }
+ ColumnarValue::Scalar(ScalarValue::Utf8(None))
+ | ColumnarValue::Scalar(ScalarValue::Utf8View(None))
+ | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
+ other => plan_err!(
+ "Concat function does not support scalar type {:?}",
+ other
+ )?,
}
}
@@ -282,15 +293,37 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
let mut new_args = Vec::with_capacity(args.len());
let mut contiguous_scalar = "".to_string();
+ let return_type = {
+ let data_types: Vec<_> = args
+ .iter()
+ .filter_map(|expr| match expr {
+ Expr::Literal(l) => Some(l.data_type()),
+ _ => None,
+ })
+ .collect();
+ ConcatFunc::new().return_type(&data_types)
+ }?;
+
for arg in args.clone() {
match arg {
+ Expr::Literal(ScalarValue::Utf8(None)) => {}
+ Expr::Literal(ScalarValue::LargeUtf8(None)) => {
+ }
+ Expr::Literal(ScalarValue::Utf8View(None)) => { }
+
// filter out `null` args
- Expr::Literal(ScalarValue::Utf8(None) |
ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
// All literals have been converted to Utf8 or LargeUtf8 in
type_coercion.
// Concatenate it with the `contiguous_scalar`.
- Expr::Literal(
- ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) |
ScalarValue::Utf8View(Some(v)),
- ) => contiguous_scalar += &v,
+ Expr::Literal(ScalarValue::Utf8(Some(v))) => {
+ contiguous_scalar += &v;
+ }
+ Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
+ contiguous_scalar += &v;
+ }
+ Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
+ contiguous_scalar += &v;
+ }
+
Expr::Literal(x) => {
return internal_err!(
"The scalar {x} should be casted to string type during the
type coercion."
@@ -301,7 +334,12 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
// Then pushing this arg to the `new_args`.
arg => {
if !contiguous_scalar.is_empty() {
- new_args.push(lit(contiguous_scalar));
+ match return_type {
+ DataType::Utf8 =>
new_args.push(lit(contiguous_scalar)),
+ DataType::LargeUtf8 =>
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
+ DataType::Utf8View =>
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
+ _ => unreachable!(),
+ }
contiguous_scalar = "".to_string();
}
new_args.push(arg);
@@ -310,7 +348,16 @@ pub fn simplify_concat(args: Vec<Expr>) ->
Result<ExprSimplifyResult> {
}
if !contiguous_scalar.is_empty() {
- new_args.push(lit(contiguous_scalar));
+ match return_type {
+ DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
+ DataType::LargeUtf8 => {
+
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
+ }
+ DataType::Utf8View => {
+
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
+ }
+ _ => unreachable!(),
+ }
}
if !args.eq(&new_args) {
@@ -392,6 +439,17 @@ mod tests {
LargeUtf8,
LargeStringArray
);
+ test_function!(
+ ConcatFunc::new(),
+ &[
+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
+
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
+ ],
+ Ok(Some("aacc")),
+ &str,
+ Utf8View,
+ StringViewArray
+ );
Ok(())
}
@@ -406,11 +464,18 @@ mod tests {
None,
Some("z"),
])));
- let args = &[c0, c1, c2];
+ let c3 =
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
+ let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
+ Some("a"),
+ None,
+ Some("b"),
+ ])));
+ let args = &[c0, c1, c2, c3, c4];
let result = ConcatFunc::new().invoke_batch(args, 3)?;
let expected =
- Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as
ArrayRef;
+ Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,",
"baz,z,b"]))
+ as ArrayRef;
match &result {
ColumnarValue::Array(array) => {
assert_eq!(&expected, array);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]