This is an automated email from the ASF dual-hosted git repository.
jayzhan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 31fcd72e96 Extend argument types for udf `return_type_from_exprs`
(#9522)
31fcd72e96 is described below
commit 31fcd72e96fd6a8612d61ed3efc675f39e7198fa
Author: Jay Zhan <[email protected]>
AuthorDate: Sun Mar 10 20:14:37 2024 +0800
Extend argument types for udf `return_type_from_exprs` (#9522)
* Extend argument types for udf return type function
Signed-off-by: jayzhan211 <[email protected]>
* rm incorrect assumption
Signed-off-by: jayzhan211 <[email protected]>
* possible empty types
Signed-off-by: jayzhan211 <[email protected]>
---------
Signed-off-by: jayzhan211 <[email protected]>
---
.../user_defined/user_defined_scalar_functions.rs | 1 +
datafusion/expr/src/expr_schema.rs | 2 +-
datafusion/expr/src/udf.rs | 15 +++++------
datafusion/physical-expr/src/planner.rs | 17 ++++++-------
datafusion/physical-expr/src/udf.rs | 29 +++++++++++++++++-----
5 files changed, 39 insertions(+), 25 deletions(-)
diff --git
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
index ca61c61db1..b525e4fc63 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
+++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs
@@ -652,6 +652,7 @@ impl ScalarUDFImpl for TakeUDF {
&self,
arg_exprs: &[Expr],
schema: &dyn ExprSchema,
+ _arg_data_types: &[DataType],
) -> Result<DataType> {
if arg_exprs.len() != 3 {
return plan_err!("Expected 3 arguments, got {}.", arg_exprs.len());
diff --git a/datafusion/expr/src/expr_schema.rs
b/datafusion/expr/src/expr_schema.rs
index 026627a05e..70ffa5064a 100644
--- a/datafusion/expr/src/expr_schema.rs
+++ b/datafusion/expr/src/expr_schema.rs
@@ -153,7 +153,7 @@ impl ExprSchemable for Expr {
// perform additional function arguments validation
(due to limited
// expressiveness of `TypeSignature`), then infer
return type
- Ok(fun.return_type_from_exprs(args, schema)?)
+ Ok(fun.return_type_from_exprs(args, schema,
&arg_data_types)?)
}
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs
index 5ad420b2f3..3002a74505 100644
--- a/datafusion/expr/src/udf.rs
+++ b/datafusion/expr/src/udf.rs
@@ -18,7 +18,6 @@
//! [`ScalarUDF`]: Scalar User Defined Functions
use crate::simplify::{ExprSimplifyResult, SimplifyInfo};
-use crate::ExprSchemable;
use crate::{
ColumnarValue, Expr, FuncMonotonicity, ReturnTypeFunction,
ScalarFunctionImplementation, Signature,
@@ -157,9 +156,10 @@ impl ScalarUDF {
&self,
args: &[Expr],
schema: &dyn ExprSchema,
+ arg_types: &[DataType],
) -> Result<DataType> {
// If the implementation provides a return_type_from_exprs, use it
- self.inner.return_type_from_exprs(args, schema)
+ self.inner.return_type_from_exprs(args, schema, arg_types)
}
/// Do the function rewrite
@@ -305,14 +305,11 @@ pub trait ScalarUDFImpl: Debug + Send + Sync {
/// value for `('foo' | 'bar')` as it does for ('foobar').
fn return_type_from_exprs(
&self,
- args: &[Expr],
- schema: &dyn ExprSchema,
+ _args: &[Expr],
+ _schema: &dyn ExprSchema,
+ arg_types: &[DataType],
) -> Result<DataType> {
- let arg_types = args
- .iter()
- .map(|arg| arg.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- self.return_type(&arg_types)
+ self.return_type(arg_types)
}
/// Invoke the function on `args`, returning the appropriate result
diff --git a/datafusion/physical-expr/src/planner.rs
b/datafusion/physical-expr/src/planner.rs
index 8001f989a2..e6022d383e 100644
--- a/datafusion/physical-expr/src/planner.rs
+++ b/datafusion/physical-expr/src/planner.rs
@@ -255,6 +255,7 @@ pub fn create_physical_expr(
.iter()
.map(|e| create_physical_expr(e, input_dfschema,
execution_props))
.collect::<Result<Vec<_>>>()?;
+
match func_def {
ScalarFunctionDefinition::BuiltIn(fun) => {
functions::create_physical_expr(
@@ -264,15 +265,13 @@ pub fn create_physical_expr(
execution_props,
)
}
- ScalarFunctionDefinition::UDF(fun) => {
- let return_type = fun.return_type_from_exprs(args,
input_dfschema)?;
-
- udf::create_physical_expr(
- fun.clone().as_ref(),
- &physical_args,
- return_type,
- )
- }
+ ScalarFunctionDefinition::UDF(fun) =>
udf::create_physical_expr(
+ fun.clone().as_ref(),
+ &physical_args,
+ input_schema,
+ args,
+ input_dfschema,
+ ),
ScalarFunctionDefinition::Name(_) => {
internal_err!("Function `Expr` with name should be
resolved.")
}
diff --git a/datafusion/physical-expr/src/udf.rs
b/datafusion/physical-expr/src/udf.rs
index d9c7c9e5c2..ede3e5badb 100644
--- a/datafusion/physical-expr/src/udf.rs
+++ b/datafusion/physical-expr/src/udf.rs
@@ -17,9 +17,10 @@
//! UDF support
use crate::{PhysicalExpr, ScalarFunctionExpr};
-use arrow_schema::DataType;
-use datafusion_common::Result;
+use arrow_schema::Schema;
+use datafusion_common::{DFSchema, Result};
pub use datafusion_expr::ScalarUDF;
+use datafusion_expr::{type_coercion::functions::data_types, Expr};
use std::sync::Arc;
/// Create a physical expression of the UDF.
@@ -28,8 +29,22 @@ use std::sync::Arc;
pub fn create_physical_expr(
fun: &ScalarUDF,
input_phy_exprs: &[Arc<dyn PhysicalExpr>],
- return_type: DataType,
+ input_schema: &Schema,
+ args: &[Expr],
+ input_dfschema: &DFSchema,
) -> Result<Arc<dyn PhysicalExpr>> {
+ let input_expr_types = input_phy_exprs
+ .iter()
+ .map(|e| e.data_type(input_schema))
+ .collect::<Result<Vec<_>>>()?;
+
+ // verify that the input data types are consistent with the function's
`TypeSignature`
+ data_types(&input_expr_types, fun.signature())?;
+
+ // Since we have arg_types, we don't need args and schema.
+ let return_type =
+ fun.return_type_from_exprs(args, input_dfschema, &input_expr_types)?;
+
Ok(Arc::new(ScalarFunctionExpr::new(
fun.name(),
fun.fun(),
@@ -42,8 +57,8 @@ pub fn create_physical_expr(
#[cfg(test)]
mod tests {
- use arrow_schema::DataType;
- use datafusion_common::Result;
+ use arrow_schema::{DataType, Schema};
+ use datafusion_common::{DFSchema, Result};
use datafusion_expr::{
ColumnarValue, FuncMonotonicity, ScalarUDF, ScalarUDFImpl, Signature,
Volatility,
};
@@ -97,7 +112,9 @@ mod tests {
// create and register the udf
let udf = ScalarUDF::from(TestScalarUDF::new());
- let p_expr = create_physical_expr(&udf, &[], DataType::Float64)?;
+ let e = crate::expressions::lit(1.1);
+ let p_expr =
+ create_physical_expr(&udf, &[e], &Schema::empty(), &[],
&DFSchema::empty())?;
assert_eq!(
p_expr