This is an automated email from the ASF dual-hosted git repository.

jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 31fcd72e96 Extend argument types for udf `return_type_from_exprs` (#9522)
31fcd72e96 is described below

commit 31fcd72e96fd6a8612d61ed3efc675f39e7198fa
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Mar 10 20:14:37 2024 +0800

    Extend argument types for udf `return_type_from_exprs` (#9522)
    
    * Extend argument types for udf return type function
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * rm incorrect assumption
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    * possible empty types
    
    Signed-off-by: jayzhan211 <[email protected]>
    
    ---------
    
    Signed-off-by: jayzhan211 <[email protected]>
---
 .../user_defined/user_defined_scalar_functions.rs  |  1 +
 datafusion/expr/src/expr_schema.rs                 |  2 +-
 datafusion/expr/src/udf.rs                         | 15 +++++------
 datafusion/physical-expr/src/planner.rs            | 17 ++++++-------
 datafusion/physical-expr/src/udf.rs                | 29 +++++++++++++++++-----
 5 files changed, 39 insertions(+), 25 deletions(-)

diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index ca61c61db1..b525e4fc63 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -652,6 +652,7 @@ impl ScalarUDFImpl for TakeUDF {
         &self,
         arg_exprs: &[Expr],
         schema: &dyn ExprSchema,
+        _arg_data_types: &[DataType],
     ) -> Result<DataType> {
         if arg_exprs.len() != 3 {
             return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs
index 026627a05e..70ffa5064a 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -153,7 +153,7 @@ impl ExprSchemable for Expr {
 
                         // perform additional function arguments validation (due to limited
                         // expressiveness of `TypeSignature`), then infer return type
-                        Ok(fun.return_type_from_exprs(args, schema)?)
+                        Ok(fun.return_type_from_exprs(args, schema, &arg_data_types)?)
                     }
                     ScalarFunctionDefinition::Name(_) => {
                         internal_err!("Function `Expr` with name should be resolved.")
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 5ad420b2f3..3002a74505 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -18,7 +18,6 @@
 //! [`ScalarUDF`]: Scalar User Defined Functions
 
 use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
-use crate::ExprSchemable;
 use crate::{
     ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
     ScalarFunctionImplementation, Signature,
@@ -157,9 +156,10 @@ impl ScalarUDF {
         &self,
         args: &[Expr],
         schema: &dyn ExprSchema,
+        arg_types: &[DataType],
     ) -> Result<DataType> {
         // If the implementation provides a return_type_from_exprs, use it
-        self.inner.return_type_from_exprs(args, schema)
+        self.inner.return_type_from_exprs(args, schema, arg_types)
     }
 
     /// Do the function rewrite
@@ -305,14 +305,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
     /// value for `('foo' | 'bar')` as it does for ('foobar').
     fn return_type_from_exprs(
         &self,
-        args: &[Expr],
-        schema: &dyn ExprSchema,
+        _args: &[Expr],
+        _schema: &dyn ExprSchema,
+        arg_types: &[DataType],
     ) -> Result<DataType> {
-        let arg_types = args
-            .iter()
-            .map(|arg| arg.get_type(schema))
-            .collect::<Result<Vec<_>>>()?;
-        self.return_type(&arg_types)
+        self.return_type(arg_types)
     }
 
     /// Invoke the function on `args`, returning the appropriate result
diff --git a/datafusion/physical-expr/src/planner.rs b/datafusion/physical-expr/src/planner.rs
index 8001f989a2..e6022d383e 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -255,6 +255,7 @@ pub fn create_physical_expr(
                 .iter()
                 .map(|e| create_physical_expr(e, input_dfschema, execution_props))
                 .collect::<Result<Vec<_>>>()?;
+
             match func_def {
                 ScalarFunctionDefinition::BuiltIn(fun) => {
                     functions::create_physical_expr(
@@ -264,15 +265,13 @@ pub fn create_physical_expr(
                         execution_props,
                     )
                 }
-                ScalarFunctionDefinition::UDF(fun) => {
-                    let return_type = fun.return_type_from_exprs(args, input_dfschema)?;
-
-                    udf::create_physical_expr(
-                        fun.clone().as_ref(),
-                        &physical_args,
-                        return_type,
-                    )
-                }
+                ScalarFunctionDefinition::UDF(fun) => udf::create_physical_expr(
+                    fun.clone().as_ref(),
+                    &physical_args,
+                    input_schema,
+                    args,
+                    input_dfschema,
+                ),
                 ScalarFunctionDefinition::Name(_) => {
                     internal_err!("Function `Expr` with name should be resolved.")
                 }
diff --git a/datafusion/physical-expr/src/udf.rs b/datafusion/physical-expr/src/udf.rs
index d9c7c9e5c2..ede3e5badb 100644
--- a/datafusion/physical-expr/src/udf.rs
+++ b/datafusion/physical-expr/src/udf.rs
@@ -17,9 +17,10 @@
 
 //! UDF support
 use crate::{PhysicalExpr, ScalarFunctionExpr};
-use arrow_schema::DataType;
-use datafusion_common::Result;
+use arrow_schema::Schema;
+use datafusion_common::{DFSchema, Result};
 pub use datafusion_expr::ScalarUDF;
+use datafusion_expr::{type_coercion::functions::data_types, Expr};
 use std::sync::Arc;
 
 /// Create a physical expression of the UDF.
@@ -28,8 +29,22 @@ use std::sync::Arc;
 pub fn create_physical_expr(
     fun: &ScalarUDF,
     input_phy_exprs: &[Arc<dyn PhysicalExpr>],
-    return_type: DataType,
+    input_schema: &Schema,
+    args: &[Expr],
+    input_dfschema: &DFSchema,
 ) -> Result<Arc<dyn PhysicalExpr>> {
+    let input_expr_types = input_phy_exprs
+        .iter()
+        .map(|e| e.data_type(input_schema))
+        .collect::<Result<Vec<_>>>()?;
+
+    // verify that input data types is consistent with function's `TypeSignature`
+    data_types(&input_expr_types, fun.signature())?;
+
+    // Since we have arg_types, we dont need args and schema.
+    let return_type =
+        fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
+
     Ok(Arc::new(ScalarFunctionExpr::new(
         fun.name(),
         fun.fun(),
@@ -42,8 +57,8 @@ pub fn create_physical_expr(
 
 #[cfg(test)]
 mod tests {
-    use arrow_schema::DataType;
-    use datafusion_common::Result;
+    use arrow_schema::{DataType, Schema};
+    use datafusion_common::{DFSchema, Result};
     use datafusion_expr::{
         ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature, Volatility,
     };
@@ -97,7 +112,9 @@ mod tests {
         // create and register the udf
         let udf = ScalarUDF::from(TestScalarUDF::new());
 
-        let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?;
+        let e = crate::expressions::lit(1.1);
+        let p_expr =
+            create_physical_expr(&udf, &[e], &Schema::empty(), &[], &DFSchema::empty())?;
 
         assert_eq!(
             p_expr

Reply via email to