This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1e96a0a76c Fix `concat` simplifier for Utf8View types (#13346)
1e96a0a76c is described below

commit 1e96a0a76ca60364aa74d9a8bd8a4c15efdfb9de
Author: Tim Saucer <[email protected]>
AuthorDate: Fri Nov 15 14:53:20 2024 -0500

    Fix `concat` simplifier for Utf8View types (#13346)
    
    * Add string view options to concat; fix the simplifier so that a simplified concat returns the same schema as the unsimplified expression
    
    * Set coercion ordering
    
    * Extend the simplification unit test to catch changes in type for concat
    
    * Update coercion ordering
    
    * Simplify computing merged type for concat
---
 datafusion/core/tests/expr_api/simplification.rs | 24 +++++--
 datafusion/functions/src/string/concat.rs        | 87 +++++++++++++++++++++---
 2 files changed, 95 insertions(+), 16 deletions(-)

diff --git a/datafusion/core/tests/expr_api/simplification.rs 
b/datafusion/core/tests/expr_api/simplification.rs
index 68785b7a5a..1e6ff8088d 100644
--- a/datafusion/core/tests/expr_api/simplification.rs
+++ b/datafusion/core/tests/expr_api/simplification.rs
@@ -483,10 +483,12 @@ fn expr_test_schema() -> DFSchemaRef {
         Field::new("c2", DataType::Boolean, true),
         Field::new("c3", DataType::Int64, true),
         Field::new("c4", DataType::UInt32, true),
+        Field::new("c5", DataType::Utf8View, true),
         Field::new("c1_non_null", DataType::Utf8, false),
         Field::new("c2_non_null", DataType::Boolean, false),
         Field::new("c3_non_null", DataType::Int64, false),
         Field::new("c4_non_null", DataType::UInt32, false),
+        Field::new("c5_non_null", DataType::Utf8View, false),
     ])
     .to_dfschema_ref()
     .unwrap()
@@ -665,20 +667,32 @@ fn test_simplify_concat_ws_with_null() {
 }
 
 #[test]
-fn test_simplify_concat() {
+fn test_simplify_concat() -> Result<()> {
+    let schema = expr_test_schema();
     let null = lit(ScalarValue::Utf8(None));
     let expr = concat(vec![
         null.clone(),
-        col("c0"),
+        col("c1"),
         lit("hello "),
         null.clone(),
         lit("rust"),
-        col("c1"),
+        lit(ScalarValue::Utf8View(Some("!".to_string()))),
+        col("c2"),
         lit(""),
         null,
+        col("c5"),
     ]);
-    let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
-    test_simplify(expr, expected)
+    let expr_datatype = expr.get_type(schema.as_ref())?;
+    let expected = concat(vec![
+        col("c1"),
+        lit(ScalarValue::Utf8View(Some("hello rust!".to_string()))),
+        col("c2"),
+        col("c5"),
+    ]);
+    let expected_datatype = expected.get_type(schema.as_ref())?;
+    assert_eq!(expr_datatype, expected_datatype);
+    test_simplify(expr, expected);
+    Ok(())
 }
 #[test]
 fn test_simplify_cycles() {
diff --git a/datafusion/functions/src/string/concat.rs 
b/datafusion/functions/src/string/concat.rs
index f1e60004dd..d49a2777b4 100644
--- a/datafusion/functions/src/string/concat.rs
+++ b/datafusion/functions/src/string/concat.rs
@@ -48,7 +48,7 @@ impl ConcatFunc {
         use DataType::*;
         Self {
             signature: Signature::variadic(
-                vec![Utf8, Utf8View, LargeUtf8],
+                vec![Utf8View, Utf8, LargeUtf8],
                 Volatility::Immutable,
             ),
         }
@@ -110,8 +110,19 @@ impl ScalarUDFImpl for ConcatFunc {
         if array_len.is_none() {
             let mut result = String::new();
             for arg in args {
-                if let ColumnarValue::Scalar(ScalarValue::Utf8(Some(v))) = arg 
{
-                    result.push_str(v);
+                match arg {
+                    ColumnarValue::Scalar(ScalarValue::Utf8(Some(v)))
+                    | ColumnarValue::Scalar(ScalarValue::Utf8View(Some(v)))
+                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some(v))) 
=> {
+                        result.push_str(v);
+                    }
+                    ColumnarValue::Scalar(ScalarValue::Utf8(None))
+                    | ColumnarValue::Scalar(ScalarValue::Utf8View(None))
+                    | ColumnarValue::Scalar(ScalarValue::LargeUtf8(None)) => {}
+                    other => plan_err!(
+                        "Concat function does not support scalar type {:?}",
+                        other
+                    )?,
                 }
             }
 
@@ -282,15 +293,37 @@ pub fn simplify_concat(args: Vec<Expr>) -> 
Result<ExprSimplifyResult> {
     let mut new_args = Vec::with_capacity(args.len());
     let mut contiguous_scalar = "".to_string();
 
+    let return_type = {
+        let data_types: Vec<_> = args
+            .iter()
+            .filter_map(|expr| match expr {
+                Expr::Literal(l) => Some(l.data_type()),
+                _ => None,
+            })
+            .collect();
+        ConcatFunc::new().return_type(&data_types)
+    }?;
+
     for arg in args.clone() {
         match arg {
+            Expr::Literal(ScalarValue::Utf8(None)) => {}
+            Expr::Literal(ScalarValue::LargeUtf8(None)) => {
+            }
+            Expr::Literal(ScalarValue::Utf8View(None)) => { }
+
             // filter out `null` args
-            Expr::Literal(ScalarValue::Utf8(None) | 
ScalarValue::LargeUtf8(None) | ScalarValue::Utf8View(None)) => {}
             // All literals have been converted to Utf8 or LargeUtf8 in 
type_coercion.
             // Concatenate it with the `contiguous_scalar`.
-            Expr::Literal(
-                ScalarValue::Utf8(Some(v)) | ScalarValue::LargeUtf8(Some(v)) | 
ScalarValue::Utf8View(Some(v)),
-            ) => contiguous_scalar += &v,
+            Expr::Literal(ScalarValue::Utf8(Some(v))) => {
+                contiguous_scalar += &v;
+            }
+            Expr::Literal(ScalarValue::LargeUtf8(Some(v))) => {
+                contiguous_scalar += &v;
+            }
+            Expr::Literal(ScalarValue::Utf8View(Some(v))) => {
+                contiguous_scalar += &v;
+            }
+
             Expr::Literal(x) => {
                 return internal_err!(
                     "The scalar {x} should be casted to string type during the 
type coercion."
@@ -301,7 +334,12 @@ pub fn simplify_concat(args: Vec<Expr>) -> 
Result<ExprSimplifyResult> {
             // Then pushing this arg to the `new_args`.
             arg => {
                 if !contiguous_scalar.is_empty() {
-                    new_args.push(lit(contiguous_scalar));
+                    match return_type {
+                        DataType::Utf8 => 
new_args.push(lit(contiguous_scalar)),
+                        DataType::LargeUtf8 => 
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar)))),
+                        DataType::Utf8View => 
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar)))),
+                        _ => unreachable!(),
+                    }
                     contiguous_scalar = "".to_string();
                 }
                 new_args.push(arg);
@@ -310,7 +348,16 @@ pub fn simplify_concat(args: Vec<Expr>) -> 
Result<ExprSimplifyResult> {
     }
 
     if !contiguous_scalar.is_empty() {
-        new_args.push(lit(contiguous_scalar));
+        match return_type {
+            DataType::Utf8 => new_args.push(lit(contiguous_scalar)),
+            DataType::LargeUtf8 => {
+                
new_args.push(lit(ScalarValue::LargeUtf8(Some(contiguous_scalar))))
+            }
+            DataType::Utf8View => {
+                
new_args.push(lit(ScalarValue::Utf8View(Some(contiguous_scalar))))
+            }
+            _ => unreachable!(),
+        }
     }
 
     if !args.eq(&new_args) {
@@ -392,6 +439,17 @@ mod tests {
             LargeUtf8,
             LargeStringArray
         );
+        test_function!(
+            ConcatFunc::new(),
+            &[
+                
ColumnarValue::Scalar(ScalarValue::Utf8View(Some("aa".to_string()))),
+                
ColumnarValue::Scalar(ScalarValue::Utf8(Some("cc".to_string()))),
+            ],
+            Ok(Some("aacc")),
+            &str,
+            Utf8View,
+            StringViewArray
+        );
 
         Ok(())
     }
@@ -406,11 +464,18 @@ mod tests {
             None,
             Some("z"),
         ])));
-        let args = &[c0, c1, c2];
+        let c3 = 
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(",".to_string())));
+        let c4 = ColumnarValue::Array(Arc::new(StringViewArray::from(vec![
+            Some("a"),
+            None,
+            Some("b"),
+        ])));
+        let args = &[c0, c1, c2, c3, c4];
 
         let result = ConcatFunc::new().invoke_batch(args, 3)?;
         let expected =
-            Arc::new(StringArray::from(vec!["foo,x", "bar,", "baz,z"])) as 
ArrayRef;
+            Arc::new(StringViewArray::from(vec!["foo,x,a", "bar,,", 
"baz,z,b"]))
+                as ArrayRef;
         match &result {
             ColumnarValue::Array(array) => {
                 assert_eq!(&expected, array);


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to