This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a92119a05 support type coercion for ScalarFunction (#3749)
a92119a05 is described below

commit a92119a0584322234a9005dc0897e94026464cad
Author: Kun Liu <[email protected]>
AuthorDate: Mon Oct 10 20:48:21 2022 +0800

    support type coercion for ScalarFunction (#3749)
---
 datafusion/optimizer/src/type_coercion.rs | 82 ++++++++++++++++++++-----------
 datafusion/physical-expr/src/functions.rs | 69 ++++++++++++++++----------
 2 files changed, 95 insertions(+), 56 deletions(-)

diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index fcfe6eaaa..5438632ba 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
 };
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
-    is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
-    BuiltinScalarFunction, Expr, LogicalPlan, Operator,
+    function, is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
+    Expr, LogicalPlan, Operator,
 };
 use datafusion_expr::{ExprSchemable, Signature};
 use std::sync::Arc;
@@ -311,18 +311,6 @@ impl ExprRewriter for TypeCoercionRewriter {
                 };
                 Ok(expr)
             }
-            Expr::ScalarUDF { fun, args } => {
-                let new_expr = coerce_arguments_for_signature(
-                    args.as_slice(),
-                    &self.schema,
-                    &fun.signature,
-                )?;
-                let expr = Expr::ScalarUDF {
-                    fun,
-                    args: new_expr,
-                };
-                Ok(expr)
-            }
             Expr::InList {
                 expr,
                 list,
@@ -395,20 +383,30 @@ impl ExprRewriter for TypeCoercionRewriter {
                     }
                 }
             }
-            Expr::ScalarFunction { fun, args } => match fun {
-                BuiltinScalarFunction::Concat
-                | BuiltinScalarFunction::ConcatWithSeparator => {
-                    let new_args = args
-                        .iter()
-                        .map(|e| e.clone().cast_to(&DataType::Utf8, &self.schema))
-                        .collect::<Result<Vec<_>>>()?;
-                    Ok(Expr::ScalarFunction {
-                        fun,
-                        args: new_args,
-                    })
-                }
-                fun => Ok(Expr::ScalarFunction { fun, args }),
-            },
+            Expr::ScalarUDF { fun, args } => {
+                let new_expr = coerce_arguments_for_signature(
+                    args.as_slice(),
+                    &self.schema,
+                    &fun.signature,
+                )?;
+                let expr = Expr::ScalarUDF {
+                    fun,
+                    args: new_expr,
+                };
+                Ok(expr)
+            }
+            Expr::ScalarFunction { fun, args } => {
+                let nex_expr = coerce_arguments_for_signature(
+                    args.as_slice(),
+                    &self.schema,
+                    &function::signature(&fun),
+                )?;
+                let expr = Expr::ScalarFunction {
+                    fun,
+                    args: nex_expr,
+                };
+                Ok(expr)
+            }
             expr => Ok(expr),
         }
     }
@@ -457,7 +455,9 @@ mod test {
     use arrow::datatypes::DataType;
     use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
     use datafusion_expr::expr_rewriter::ExprRewritable;
-    use datafusion_expr::{cast, col, concat, concat_ws, is_true, ColumnarValue};
+    use datafusion_expr::{
+        cast, col, concat, concat_ws, is_true, BuiltinScalarFunction, ColumnarValue,
+    };
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
@@ -572,6 +572,30 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn scalar_function() -> Result<()> {
+        let empty = empty();
+        let lit_expr = lit(10i64);
+        let fun: BuiltinScalarFunction = BuiltinScalarFunction::Abs;
+        let scalar_function_expr = Expr::ScalarFunction {
+            fun,
+            args: vec![lit_expr],
+        };
+        let plan = LogicalPlan::Projection(Projection::try_new(
+            vec![scalar_function_expr],
+            empty,
+            None,
+        )?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: abs(CAST(Int64(10) AS Float64))\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
     #[test]
     fn binary_op_date32_add_interval() -> Result<()> {
         //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 5796f8f7d..7d9e89b52 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -34,9 +34,8 @@ use crate::execution_props::ExecutionProps;
 use crate::{
     array_expressions, conditional_expressions, datetime_expressions,
     expressions::{cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS},
-    math_expressions, string_expressions, struct_expressions,
-    type_coercion::coerce,
-    PhysicalExpr, ScalarFunctionExpr,
+    math_expressions, string_expressions, struct_expressions, PhysicalExpr,
+    ScalarFunctionExpr,
 };
 use arrow::{
     array::ArrayRef,
@@ -58,15 +57,12 @@ pub fn create_physical_expr(
     input_schema: &Schema,
     execution_props: &ExecutionProps,
 ) -> Result<Arc<dyn PhysicalExpr>> {
-    let coerced_phy_exprs =
-        coerce(input_phy_exprs, input_schema, &function::signature(fun))?;
-
-    let coerced_expr_types = coerced_phy_exprs
+    let input_expr_types = input_phy_exprs
         .iter()
         .map(|e| e.data_type(input_schema))
         .collect::<Result<Vec<_>>>()?;
 
-    let data_type = function::return_type(fun, &coerced_expr_types)?;
+    let data_type = function::return_type(fun, &input_expr_types)?;
 
     let fun_expr: ScalarFunctionImplementation = match fun {
         // These functions need args and input schema to pick an implementation
@@ -74,7 +70,7 @@ pub fn create_physical_expr(
         // here we return either a cast fn or string timestamp translation based on the expression data type
         // so we don't have to pay a per-array/batch cost.
         BuiltinScalarFunction::ToTimestamp => {
-            Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+            Arc::new(match input_phy_exprs[0].data_type(input_schema) {
                 Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
                     |col_values: &[ColumnarValue]| {
                         cast_column(
@@ -89,12 +85,12 @@ pub fn create_physical_expr(
                     return Err(DataFusionError::Internal(format!(
                         "Unsupported data type {:?} for function to_timestamp",
                         other,
-                    )))
+                    )));
                 }
             })
         }
         BuiltinScalarFunction::ToTimestampMillis => {
-            Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+            Arc::new(match input_phy_exprs[0].data_type(input_schema) {
                 Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
                     |col_values: &[ColumnarValue]| {
                         cast_column(
@@ -109,12 +105,12 @@ pub fn create_physical_expr(
                     return Err(DataFusionError::Internal(format!(
                         "Unsupported data type {:?} for function to_timestamp_millis",
                         other,
-                    )))
+                    )));
                 }
             })
         }
         BuiltinScalarFunction::ToTimestampMicros => {
-            Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+            Arc::new(match input_phy_exprs[0].data_type(input_schema) {
                 Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
                     |col_values: &[ColumnarValue]| {
                         cast_column(
@@ -129,12 +125,12 @@ pub fn create_physical_expr(
                     return Err(DataFusionError::Internal(format!(
                         "Unsupported data type {:?} for function to_timestamp_micros",
                         other,
-                    )))
+                    )));
                 }
             })
         }
         BuiltinScalarFunction::ToTimestampSeconds => Arc::new({
-            match coerced_phy_exprs[0].data_type(input_schema) {
+            match input_phy_exprs[0].data_type(input_schema) {
                 Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
                     |col_values: &[ColumnarValue]| {
                         cast_column(
@@ -149,12 +145,12 @@ pub fn create_physical_expr(
                     return Err(DataFusionError::Internal(format!(
                         "Unsupported data type {:?} for function to_timestamp_seconds",
                         other,
-                    )))
+                    )));
                 }
             }
         }),
         BuiltinScalarFunction::FromUnixtime => Arc::new({
-            match coerced_phy_exprs[0].data_type(input_schema) {
+            match input_phy_exprs[0].data_type(input_schema) {
                 Ok(DataType::Int64) => |col_values: &[ColumnarValue]| {
                     cast_column(
                         &col_values[0],
@@ -166,12 +162,12 @@ pub fn create_physical_expr(
                     return Err(DataFusionError::Internal(format!(
                         "Unsupported data type {:?} for function from_unixtime",
                         other,
-                    )))
+                    )));
                 }
             }
         }),
         BuiltinScalarFunction::ArrowTypeof => {
-            let input_data_type = coerced_phy_exprs[0].data_type(input_schema)?;
+            let input_data_type = input_phy_exprs[0].data_type(input_schema)?;
             Arc::new(move |_| {
                 Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
                     "{}",
@@ -186,7 +182,7 @@ pub fn create_physical_expr(
     Ok(Arc::new(ScalarFunctionExpr::new(
         &format!("{}", fun),
         fun_expr,
-        coerced_phy_exprs,
+        input_phy_exprs.to_vec(),
         &data_type,
     )))
 }
@@ -727,7 +723,7 @@ pub fn create_physical_fun(
             return Err(DataFusionError::Internal(format!(
                 "create_physical_fun: Unsupported scalar function {:?}",
                 fun
-            )))
+            )));
         }
     })
 }
@@ -737,6 +733,7 @@ mod tests {
     use super::*;
     use crate::expressions::{col, lit};
     use crate::from_slice::FromSlice;
+    use crate::type_coercion::coerce;
     use arrow::{
         array::{
             Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray, Float32Array,
@@ -764,7 +761,7 @@ mod tests {
             let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];
 
             let expr =
-                create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;
+                create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS, &schema, &execution_props)?;
 
             // type is correct
             assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE);
@@ -2683,7 +2680,12 @@ mod tests {
         ];
 
         for fun in funs.iter() {
-            let expr = create_physical_expr(fun, &[], &schema, &execution_props);
+            let expr = create_physical_expr_with_type_coercion(
+                fun,
+                &[],
+                &schema,
+                &execution_props,
+            );
 
             match expr {
                 Ok(..) => {
@@ -2720,7 +2722,7 @@ mod tests {
         let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random];
 
         for fun in funs.iter() {
-            create_physical_expr(fun, &[], &schema, &execution_props)?;
+            create_physical_expr_with_type_coercion(fun, &[], &schema, &execution_props)?;
         }
         Ok(())
     }
@@ -2739,7 +2741,7 @@ mod tests {
         let columns: Vec<ArrayRef> = vec![value1, value2];
         let execution_props = ExecutionProps::new();
 
-        let expr = create_physical_expr(
+        let expr = create_physical_expr_with_type_coercion(
             &BuiltinScalarFunction::MakeArray,
             &[col("a", &schema)?, col("b", &schema)?],
             &schema,
@@ -2805,7 +2807,7 @@ mod tests {
         let col_value: ArrayRef = Arc::new(StringArray::from_slice(&["aaa-555"]));
         let pattern = lit(r".*-(\d*)");
         let columns: Vec<ArrayRef> = vec![col_value];
-        let expr = create_physical_expr(
+        let expr = create_physical_expr_with_type_coercion(
             &BuiltinScalarFunction::RegexpMatch,
             &[col("a", &schema)?, pattern],
             &schema,
@@ -2844,7 +2846,7 @@ mod tests {
         let col_value = lit("aaa-555");
         let pattern = lit(r".*-(\d*)");
         let columns: Vec<ArrayRef> = vec![Arc::new(Int32Array::from_slice(&[1]))];
-        let expr = create_physical_expr(
+        let expr = create_physical_expr_with_type_coercion(
             &BuiltinScalarFunction::RegexpMatch,
             &[col_value, pattern],
             &schema,
@@ -2872,4 +2874,17 @@ mod tests {
 
         Ok(())
     }
+
+    // Helper function
+    // The type coercion will be done in the logical phase, should do the type coercion for the test
+    fn create_physical_expr_with_type_coercion(
+        fun: &BuiltinScalarFunction,
+        input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+        input_schema: &Schema,
+        execution_props: &ExecutionProps,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let type_coerced_phy_exprs =
+            coerce(input_phy_exprs, input_schema, &function::signature(fun)).unwrap();
+        create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props)
+    }
 }

Reply via email to