This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch add-expre-serilize-test
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git

commit 983dd2ed7c96675b01377781b737521b1a3607f9
Author: Andrew Lamb <[email protected]>
AuthorDate: Mon Jan 15 06:21:40 2024 -0500

    Fix other functions, improve test logic
---
 datafusion/proto/src/logical_plan/from_proto.rs | 17 ++++++++++++----
 datafusion/proto/tests/cases/serialize.rs       | 26 +++++++++++++++++--------
 2 files changed, 31 insertions(+), 12 deletions(-)

diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs
index d15cf1db92..b498889527 100644
--- a/datafusion/proto/src/logical_plan/from_proto.rs
+++ b/datafusion/proto/src/logical_plan/from_proto.rs
@@ -53,8 +53,8 @@ use datafusion_expr::{
     coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, 
current_time,
     date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
     expr::{self, InList, Sort, WindowFunction},
-    factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, 
isnan, iszero,
-    lcm, left, levenshtein, ln, log, log10, log2,
+    factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, 
initcap,
+    isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2,
     logical_plan::{PlanType, StringifiedPlan},
     lower, lpad, ltrim, md5, nanvl, now, nullif, octet_length, overlay, pi, 
power,
     radians, random, regexp_match, regexp_replace, repeat, replace, reverse, 
right,
@@ -1588,7 +1588,7 @@ pub fn parse_expr(
                     Ok(character_length(parse_expr(&args[0], registry)?))
                 }
                 ScalarFunction::Chr => Ok(chr(parse_expr(&args[0], 
registry)?)),
-                ScalarFunction::InitCap => Ok(ascii(parse_expr(&args[0], 
registry)?)),
+                ScalarFunction::InitCap => Ok(initcap(parse_expr(&args[0], 
registry)?)),
                 ScalarFunction::Gcd => Ok(gcd(
                     parse_expr(&args[0], registry)?,
                     parse_expr(&args[1], registry)?,
@@ -1745,7 +1745,16 @@ pub fn parse_expr(
                     Ok(arrow_typeof(parse_expr(&args[0], registry)?))
                 }
                 ScalarFunction::ToTimestamp => {
-                    Ok(to_timestamp_seconds(parse_expr(&args[0], registry)?))
+                    let args: Vec<_> = args
+                        .iter()
+                        .map(|expr| parse_expr(expr, registry))
+                        .collect::<Result<_, _>>()?;
+                    Ok(Expr::ScalarFunction(
+                        datafusion_expr::expr::ScalarFunction::new(
+                            BuiltinScalarFunction::ToTimestamp,
+                            args,
+                        ),
+                    ))
                 }
                 ScalarFunction::Flatten => Ok(flatten(parse_expr(&args[0], 
registry)?)),
                 ScalarFunction::StringToArray => Ok(string_to_array(
diff --git a/datafusion/proto/tests/cases/serialize.rs b/datafusion/proto/tests/cases/serialize.rs
index 58fa171838..aa6df66f24 100644
--- a/datafusion/proto/tests/cases/serialize.rs
+++ b/datafusion/proto/tests/cases/serialize.rs
@@ -256,18 +256,28 @@ fn test_expression_serialization_roundtrip() {
     let ctx = SessionContext::new();
     let lit = Expr::Literal(ScalarValue::Utf8(None));
     for builtin_fun in BuiltinScalarFunction::iter() {
-        let expr =
-            Expr::ScalarFunction(ScalarFunction::new(builtin_fun, 
vec![lit.clone(); 4]));
+        println!("Checking function: {}", builtin_fun.name());
+        // default to 4 args (though some exprs like substr have error 
checking)
+        let num_args = match builtin_fun {
+            BuiltinScalarFunction::Substr => 3,
+            _ => 4,
+        };
+        let args: Vec<_> = 
std::iter::repeat(&lit).take(num_args).cloned().collect();
+        let expr = Expr::ScalarFunction(ScalarFunction::new(builtin_fun, 
args));
 
         let proto = LogicalExprNode::try_from(&expr).unwrap();
+        let deserialize = parse_expr(&proto, &ctx).unwrap();
 
-        let desirilize = parse_expr(&proto, &ctx).unwrap();
+        let serialize_name = extract_function_name(&expr);
+        let deserialize_name = extract_function_name(&deserialize);
 
-        let serialize_name = expr.display_name().unwrap();
-        let serialize_name = 
serialize_name.split('(').collect::<Vec<&str>>()[0];
-
-        let deserialize_name = desirilize.display_name().unwrap();
-        let deserialize_name = 
deserialize_name.split('(').collect::<Vec<&str>>()[0];
         assert_eq!(serialize_name, deserialize_name);
     }
+
+    /// Extracts the first part of a function name
+    /// 'foo(bar)' -> 'foo'
+    fn extract_function_name(expr: &Expr) -> String {
+        let name = expr.display_name().unwrap();
+        name.split('(').next().unwrap().to_string()
+    }
 }

Reply via email to