This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a92119a05 support type coercion for ScalarFunction (#3749)
a92119a05 is described below
commit a92119a0584322234a9005dc0897e94026464cad
Author: Kun Liu <[email protected]>
AuthorDate: Mon Oct 10 20:48:21 2022 +0800
support type coercion for ScalarFunction (#3749)
---
datafusion/optimizer/src/type_coercion.rs | 82 ++++++++++++++++++++-----------
datafusion/physical-expr/src/functions.rs | 69 ++++++++++++++++----------
2 files changed, 95 insertions(+), 56 deletions(-)
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index fcfe6eaaa..5438632ba 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -30,8 +30,8 @@ use datafusion_expr::type_coercion::other::{
};
use datafusion_expr::utils::from_plan;
use datafusion_expr::{
- is_false, is_not_false, is_not_true, is_not_unknown, is_true, is_unknown,
- BuiltinScalarFunction, Expr, LogicalPlan, Operator,
+ function, is_false, is_not_false, is_not_true, is_not_unknown, is_true,
is_unknown,
+ Expr, LogicalPlan, Operator,
};
use datafusion_expr::{ExprSchemable, Signature};
use std::sync::Arc;
@@ -311,18 +311,6 @@ impl ExprRewriter for TypeCoercionRewriter {
};
Ok(expr)
}
- Expr::ScalarUDF { fun, args } => {
- let new_expr = coerce_arguments_for_signature(
- args.as_slice(),
- &self.schema,
- &fun.signature,
- )?;
- let expr = Expr::ScalarUDF {
- fun,
- args: new_expr,
- };
- Ok(expr)
- }
Expr::InList {
expr,
list,
@@ -395,20 +383,30 @@ impl ExprRewriter for TypeCoercionRewriter {
}
}
}
- Expr::ScalarFunction { fun, args } => match fun {
- BuiltinScalarFunction::Concat
- | BuiltinScalarFunction::ConcatWithSeparator => {
- let new_args = args
- .iter()
- .map(|e| e.clone().cast_to(&DataType::Utf8,
&self.schema))
- .collect::<Result<Vec<_>>>()?;
- Ok(Expr::ScalarFunction {
- fun,
- args: new_args,
- })
- }
- fun => Ok(Expr::ScalarFunction { fun, args }),
- },
+            // User-defined scalar functions: coerce every argument to the
+            // types required by the UDF's own `signature`.
+            Expr::ScalarUDF { fun, args } => {
+                let new_args = coerce_arguments_for_signature(
+                    args.as_slice(),
+                    &self.schema,
+                    &fun.signature,
+                )?;
+                Ok(Expr::ScalarUDF {
+                    fun,
+                    args: new_args,
+                })
+            }
+            // Built-in scalar functions: same coercion, but the signature is
+            // looked up with `function::signature` for the builtin `fun`.
+            Expr::ScalarFunction { fun, args } => {
+                let new_args = coerce_arguments_for_signature(
+                    args.as_slice(),
+                    &self.schema,
+                    &function::signature(&fun),
+                )?;
+                Ok(Expr::ScalarFunction {
+                    fun,
+                    args: new_args,
+                })
+            }
expr => Ok(expr),
}
}
@@ -457,7 +455,9 @@ mod test {
use arrow::datatypes::DataType;
use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
use datafusion_expr::expr_rewriter::ExprRewritable;
- use datafusion_expr::{cast, col, concat, concat_ws, is_true,
ColumnarValue};
+ use datafusion_expr::{
+ cast, col, concat, concat_ws, is_true, BuiltinScalarFunction,
ColumnarValue,
+ };
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
@@ -572,6 +572,30 @@ mod test {
Ok(())
}
+    // Builds `abs(Int64(10))` and checks that the TypeCoercion rule rewrites
+    // the argument into `CAST(Int64(10) AS Float64)`.
+    #[test]
+    fn scalar_function() -> Result<()> {
+        let abs_of_int = Expr::ScalarFunction {
+            fun: BuiltinScalarFunction::Abs,
+            args: vec![lit(10i64)],
+        };
+        let projection = Projection::try_new(vec![abs_of_int], empty(), None)?;
+        let input_plan = LogicalPlan::Projection(projection);
+
+        let mut config = OptimizerConfig::default();
+        let coerced_plan = TypeCoercion::new().optimize(&input_plan, &mut config)?;
+
+        let expected = "Projection: abs(CAST(Int64(10) AS Float64))\n EmptyRelation";
+        assert_eq!(expected, &format!("{:?}", coerced_plan));
+        Ok(())
+    }
+
#[test]
fn binary_op_date32_add_interval() -> Result<()> {
//CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs
index 5796f8f7d..7d9e89b52 100644
--- a/datafusion/physical-expr/src/functions.rs
+++ b/datafusion/physical-expr/src/functions.rs
@@ -34,9 +34,8 @@ use crate::execution_props::ExecutionProps;
use crate::{
array_expressions, conditional_expressions, datetime_expressions,
expressions::{cast_column, nullif_func, DEFAULT_DATAFUSION_CAST_OPTIONS},
- math_expressions, string_expressions, struct_expressions,
- type_coercion::coerce,
- PhysicalExpr, ScalarFunctionExpr,
+ math_expressions, string_expressions, struct_expressions, PhysicalExpr,
+ ScalarFunctionExpr,
};
use arrow::{
array::ArrayRef,
@@ -58,15 +57,12 @@ pub fn create_physical_expr(
input_schema: &Schema,
execution_props: &ExecutionProps,
) -> Result<Arc<dyn PhysicalExpr>> {
- let coerced_phy_exprs =
- coerce(input_phy_exprs, input_schema, &function::signature(fun))?;
-
- let coerced_expr_types = coerced_phy_exprs
+ let input_expr_types = input_phy_exprs
.iter()
.map(|e| e.data_type(input_schema))
.collect::<Result<Vec<_>>>()?;
- let data_type = function::return_type(fun, &coerced_expr_types)?;
+ let data_type = function::return_type(fun, &input_expr_types)?;
let fun_expr: ScalarFunctionImplementation = match fun {
// These functions need args and input schema to pick an implementation
@@ -74,7 +70,7 @@ pub fn create_physical_expr(
// here we return either a cast fn or string timestamp translation
based on the expression data type
// so we don't have to pay a per-array/batch cost.
BuiltinScalarFunction::ToTimestamp => {
- Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+ Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -89,12 +85,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function to_timestamp",
other,
- )))
+ )));
}
})
}
BuiltinScalarFunction::ToTimestampMillis => {
- Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+ Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -109,12 +105,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function
to_timestamp_millis",
other,
- )))
+ )));
}
})
}
BuiltinScalarFunction::ToTimestampMicros => {
- Arc::new(match coerced_phy_exprs[0].data_type(input_schema) {
+ Arc::new(match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -129,12 +125,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function
to_timestamp_micros",
other,
- )))
+ )));
}
})
}
BuiltinScalarFunction::ToTimestampSeconds => Arc::new({
- match coerced_phy_exprs[0].data_type(input_schema) {
+ match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) | Ok(DataType::Timestamp(_, None)) => {
|col_values: &[ColumnarValue]| {
cast_column(
@@ -149,12 +145,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function
to_timestamp_seconds",
other,
- )))
+ )));
}
}
}),
BuiltinScalarFunction::FromUnixtime => Arc::new({
- match coerced_phy_exprs[0].data_type(input_schema) {
+ match input_phy_exprs[0].data_type(input_schema) {
Ok(DataType::Int64) => |col_values: &[ColumnarValue]| {
cast_column(
&col_values[0],
@@ -166,12 +162,12 @@ pub fn create_physical_expr(
return Err(DataFusionError::Internal(format!(
"Unsupported data type {:?} for function
from_unixtime",
other,
- )))
+ )));
}
}
}),
BuiltinScalarFunction::ArrowTypeof => {
- let input_data_type =
coerced_phy_exprs[0].data_type(input_schema)?;
+ let input_data_type = input_phy_exprs[0].data_type(input_schema)?;
Arc::new(move |_| {
Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(format!(
"{}",
@@ -186,7 +182,7 @@ pub fn create_physical_expr(
Ok(Arc::new(ScalarFunctionExpr::new(
&format!("{}", fun),
fun_expr,
- coerced_phy_exprs,
+ input_phy_exprs.to_vec(),
&data_type,
)))
}
@@ -727,7 +723,7 @@ pub fn create_physical_fun(
return Err(DataFusionError::Internal(format!(
"create_physical_fun: Unsupported scalar function {:?}",
fun
- )))
+ )));
}
})
}
@@ -737,6 +733,7 @@ mod tests {
use super::*;
use crate::expressions::{col, lit};
use crate::from_slice::FromSlice;
+ use crate::type_coercion::coerce;
use arrow::{
array::{
Array, ArrayRef, BinaryArray, BooleanArray, FixedSizeListArray,
Float32Array,
@@ -764,7 +761,7 @@ mod tests {
let columns: Vec<ArrayRef> =
vec![Arc::new(Int32Array::from_slice(&[1]))];
let expr =
- create_physical_expr(&BuiltinScalarFunction::$FUNC, $ARGS,
&schema, &execution_props)?;
+
create_physical_expr_with_type_coercion(&BuiltinScalarFunction::$FUNC, $ARGS,
&schema, &execution_props)?;
// type is correct
assert_eq!(expr.data_type(&schema)?, DataType::$DATA_TYPE);
@@ -2683,7 +2680,12 @@ mod tests {
];
for fun in funs.iter() {
- let expr = create_physical_expr(fun, &[], &schema,
&execution_props);
+ let expr = create_physical_expr_with_type_coercion(
+ fun,
+ &[],
+ &schema,
+ &execution_props,
+ );
match expr {
Ok(..) => {
@@ -2720,7 +2722,7 @@ mod tests {
let funs = [BuiltinScalarFunction::Now, BuiltinScalarFunction::Random];
for fun in funs.iter() {
- create_physical_expr(fun, &[], &schema, &execution_props)?;
+ create_physical_expr_with_type_coercion(fun, &[], &schema,
&execution_props)?;
}
Ok(())
}
@@ -2739,7 +2741,7 @@ mod tests {
let columns: Vec<ArrayRef> = vec![value1, value2];
let execution_props = ExecutionProps::new();
- let expr = create_physical_expr(
+ let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::MakeArray,
&[col("a", &schema)?, col("b", &schema)?],
&schema,
@@ -2805,7 +2807,7 @@ mod tests {
let col_value: ArrayRef =
Arc::new(StringArray::from_slice(&["aaa-555"]));
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> = vec![col_value];
- let expr = create_physical_expr(
+ let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col("a", &schema)?, pattern],
&schema,
@@ -2844,7 +2846,7 @@ mod tests {
let col_value = lit("aaa-555");
let pattern = lit(r".*-(\d*)");
let columns: Vec<ArrayRef> =
vec![Arc::new(Int32Array::from_slice(&[1]))];
- let expr = create_physical_expr(
+ let expr = create_physical_expr_with_type_coercion(
&BuiltinScalarFunction::RegexpMatch,
&[col_value, pattern],
&schema,
@@ -2872,4 +2874,17 @@ mod tests {
Ok(())
}
+
+    /// Test helper: coerces `input_phy_exprs` to the signature of `fun`, then
+    /// builds the physical scalar-function expression.
+    ///
+    /// `create_physical_expr` no longer coerces its arguments, because type
+    /// coercion is now done during logical optimization. Tests that build
+    /// physical expressions directly must therefore perform the coercion here.
+    ///
+    /// # Errors
+    /// Returns the error from `coerce` if the arguments cannot be coerced to
+    /// the function's signature. Previously this panicked via `unwrap`. Also
+    /// returns any error from `create_physical_expr`.
+    fn create_physical_expr_with_type_coercion(
+        fun: &BuiltinScalarFunction,
+        input_phy_exprs: &[Arc<dyn PhysicalExpr>],
+        input_schema: &Schema,
+        execution_props: &ExecutionProps,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        let type_coerced_phy_exprs =
+            coerce(input_phy_exprs, input_schema, &function::signature(fun))?;
+        create_physical_expr(fun, &type_coerced_phy_exprs, input_schema, execution_props)
+    }
}