This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new fa31c781d2 Improve coerce API so it does not need DFSchema (#10331)
fa31c781d2 is described below
commit fa31c781d2cb0bdfba06dcc07bc75d9a5f9686b2
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu May 2 07:32:27 2024 -0400
Improve coerce API so it does not need DFSchema (#10331)
---
datafusion-examples/examples/expr_api.rs | 2 +-
datafusion/core/src/test_util/parquet.rs | 2 +-
datafusion/optimizer/src/analyzer/type_coercion.rs | 99 +++++++++++-----------
.../src/simplify_expressions/expr_simplifier.rs | 20 +----
4 files changed, 54 insertions(+), 69 deletions(-)
diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs
index 6e9c42480c..2c1470a1d6 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr
ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));
// apply type coercion here to ensure types match
- let expr = simplifier.coerce(expr, df_schema.clone())?;
+ let expr = simplifier.coerce(expr, &df_schema)?;
create_physical_expr(&expr, df_schema.as_ref(), &props)
}
diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs
index f949058769..1d5668c7ec 100644
--- a/datafusion/core/src/test_util/parquet.rs
+++ b/datafusion/core/src/test_util/parquet.rs
@@ -169,7 +169,7 @@ impl TestParquetFile {
let parquet_options = ctx.copied_table_options().parquet;
if let Some(filter) = maybe_filter {
let simplifier = ExprSimplifier::new(context);
- let filter = simplifier.coerce(filter, df_schema.clone()).unwrap();
+ let filter = simplifier.coerce(filter, &df_schema).unwrap();
let physical_filter_expr =
create_physical_expr(&filter, &df_schema,
&ExecutionProps::default())?;
let parquet_exec = Arc::new(ParquetExec::new(
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index b7f95d83e8..9295b08f41 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
use datafusion_common::{
exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err,
DFSchema,
- DFSchemaRef, DataFusionError, Result, ScalarValue,
+ DataFusionError, Result, ScalarValue,
};
use datafusion_expr::expr::{
self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists,
InList,
@@ -99,9 +99,7 @@ fn analyze_internal(
// select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where
t2.c2=t1.c3)
schema.merge(external_schema);
- let mut expr_rewrite = TypeCoercionRewriter {
- schema: Arc::new(schema),
- };
+ let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };
let new_expr = plan
.expressions()
@@ -116,11 +114,11 @@ fn analyze_internal(
plan.with_new_exprs(new_expr, new_inputs)
}
-pub(crate) struct TypeCoercionRewriter {
- pub(crate) schema: DFSchemaRef,
+pub(crate) struct TypeCoercionRewriter<'a> {
+ pub(crate) schema: &'a DFSchema,
}
-impl TreeNodeRewriter for TypeCoercionRewriter {
+impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
type Node = Expr;
fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
@@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
outer_ref_columns,
}) => {
- let new_plan = analyze_internal(&self.schema, &subquery)?;
+ let new_plan = analyze_internal(self.schema, &subquery)?;
Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns,
})))
}
Expr::Exists(Exists { subquery, negated }) => {
- let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
+ let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
Ok(Transformed::yes(Expr::Exists(Exists {
subquery: Subquery {
subquery: Arc::new(new_plan),
@@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
subquery,
negated,
}) => {
- let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
- let expr_type = expr.get_type(&self.schema)?;
+ let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
+ let expr_type = expr.get_type(self.schema)?;
+ let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type,
subquery_type).ok_or(plan_datafusion_err!(
"expr type {expr_type:?} can't cast to
{subquery_type:?} in InSubquery"
@@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
outer_ref_columns: subquery.outer_ref_columns,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
- Box::new(expr.cast_to(&common_type, &self.schema)?),
+ Box::new(expr.cast_to(&common_type, self.schema)?),
cast_subquery(new_subquery, &common_type)?,
negated,
))))
}
Expr::Not(expr) =>
Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
*expr,
- &self.schema,
+ self.schema,
)?))),
Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
- get_casted_expr_for_bool_op(*expr, &self.schema)?,
+ get_casted_expr_for_bool_op(*expr, self.schema)?,
))),
Expr::Like(Like {
negated,
@@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
escape_char,
case_insensitive,
}) => {
- let left_type = expr.get_type(&self.schema)?;
- let right_type = pattern.get_type(&self.schema)?;
+ let left_type = expr.get_type(self.schema)?;
+ let right_type = pattern.get_type(self.schema)?;
let coerced_type = like_coercion(&left_type,
&right_type).ok_or_else(|| {
let op_name = if case_insensitive {
"ILIKE"
@@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
"There isn't a common type to coerce {left_type} and
{right_type} in {op_name} expression"
)
})?;
- let expr = Box::new(expr.cast_to(&coerced_type, &self.schema)?);
- let pattern = Box::new(pattern.cast_to(&coerced_type, &self.schema)?);
+ let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?);
+ let pattern = Box::new(pattern.cast_to(&coerced_type, self.schema)?);
Ok(Transformed::yes(Expr::Like(Like::new(
negated,
expr,
@@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
let (left_type, right_type) = get_input_types(
- &left.get_type(&self.schema)?,
+ &left.get_type(self.schema)?,
&op,
- &right.get_type(&self.schema)?,
+ &right.get_type(self.schema)?,
)?;
Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
- Box::new(left.cast_to(&left_type, &self.schema)?),
+ Box::new(left.cast_to(&left_type, self.schema)?),
op,
- Box::new(right.cast_to(&right_type, &self.schema)?),
+ Box::new(right.cast_to(&right_type, self.schema)?),
))))
}
Expr::Between(Between {
@@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
low,
high,
}) => {
- let expr_type = expr.get_type(&self.schema)?;
- let low_type = low.get_type(&self.schema)?;
+ let expr_type = expr.get_type(self.schema)?;
+ let low_type = low.get_type(self.schema)?;
let low_coerced_type = comparison_coercion(&expr_type,
&low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"Failed to coerce types {expr_type} and {low_type}
in BETWEEN expression"
))
})?;
- let high_type = high.get_type(&self.schema)?;
+ let high_type = high.get_type(self.schema)?;
let high_coerced_type = comparison_coercion(&expr_type,
&low_type)
.ok_or_else(|| {
DataFusionError::Internal(format!(
@@ -262,10 +260,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
))
})?;
Ok(Transformed::yes(Expr::Between(Between::new(
- Box::new(expr.cast_to(&coercion_type, &self.schema)?),
+ Box::new(expr.cast_to(&coercion_type, self.schema)?),
negated,
- Box::new(low.cast_to(&coercion_type, &self.schema)?),
- Box::new(high.cast_to(&coercion_type, &self.schema)?),
+ Box::new(low.cast_to(&coercion_type, self.schema)?),
+ Box::new(high.cast_to(&coercion_type, self.schema)?),
))))
}
Expr::InList(InList {
@@ -273,10 +271,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
list,
negated,
}) => {
- let expr_data_type = expr.get_type(&self.schema)?;
+ let expr_data_type = expr.get_type(self.schema)?;
let list_data_types = list
.iter()
- .map(|list_expr| list_expr.get_type(&self.schema))
+ .map(|list_expr| list_expr.get_type(self.schema))
.collect::<Result<Vec<_>>>()?;
let result_type =
get_coerce_type_for_list(&expr_data_type,
&list_data_types);
@@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
),
Some(coerced_type) => {
// find the coerced type
- let cast_expr = expr.cast_to(&coerced_type, &self.schema)?;
+ let cast_expr = expr.cast_to(&coerced_type, self.schema)?;
let cast_list_expr = list
.into_iter()
.map(|list_expr| {
- list_expr.cast_to(&coerced_type, &self.schema)
+ list_expr.cast_to(&coerced_type, self.schema)
})
.collect::<Result<Vec<_>>>()?;
Ok(Transformed::yes(Expr::InList(InList ::new(
@@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
}
}
Expr::Case(case) => {
- let case = coerce_case_expression(case, &self.schema)?;
+ let case = coerce_case_expression(case, self.schema)?;
Ok(Transformed::yes(Expr::Case(case)))
}
Expr::ScalarFunction(ScalarFunction { func_def, args }) => match
func_def {
ScalarFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
- &self.schema,
+ self.schema,
fun.signature(),
)?;
- let new_expr =
- coerce_arguments_for_fun(new_expr, &self.schema, &fun)?;
+ let new_expr = coerce_arguments_for_fun(new_expr, self.schema, &fun)?;
Ok(Transformed::yes(Expr::ScalarFunction(
ScalarFunction::new_udf(fun, new_expr),
)))
@@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
let new_expr = coerce_agg_exprs_for_signature(
&fun,
args,
- &self.schema,
+ self.schema,
&fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
@@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
AggregateFunctionDefinition::UDF(fun) => {
let new_expr = coerce_arguments_for_signature(
args,
- &self.schema,
+ self.schema,
fun.signature(),
)?;
Ok(Transformed::yes(Expr::AggregateFunction(
@@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
null_treatment,
}) => {
let window_frame =
- coerce_window_frame(window_frame, &self.schema, &order_by)?;
+ coerce_window_frame(window_frame, self.schema, &order_by)?;
let args = match &fun {
expr::WindowFunctionDefinition::AggregateFunction(fun) => {
coerce_agg_exprs_for_signature(
fun,
args,
- &self.schema,
+ self.schema,
&fun.signature(),
)?
}
@@ -495,7 +492,7 @@ fn coerce_frame_bound(
// For example, ROWS and GROUPS frames use `UInt64` during calculations.
fn coerce_window_frame(
window_frame: WindowFrame,
- schema: &DFSchemaRef,
+ schema: &DFSchema,
expressions: &[Expr],
) -> Result<WindowFrame> {
let mut window_frame = window_frame;
@@ -531,7 +528,7 @@ fn coerce_window_frame(
// Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
// The above op will be rewrite to the binary op when creating the physical op.
-fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> Result<Expr> {
+fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
let left_type = expr.get_type(schema)?;
get_input_types(&left_type, &Operator::IsDistinctFrom,
&DataType::Boolean)?;
expr.cast_to(&DataType::Boolean, schema)
@@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
.collect()
}
-fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
+fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
// Given expressions like:
//
// CASE a1
@@ -1238,7 +1235,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
- let mut rewriter = TypeCoercionRewriter { schema };
+ let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).gt(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).gt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
@@ -1249,7 +1246,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
- let mut rewriter = TypeCoercionRewriter { schema };
+ let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).eq(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).eq(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
@@ -1260,7 +1257,7 @@ mod test {
vec![Field::new("a", DataType::Int64, true)].into(),
std::collections::HashMap::new(),
)?);
- let mut rewriter = TypeCoercionRewriter { schema };
+ let mut rewriter = TypeCoercionRewriter { schema: &schema };
let expr = is_true(lit(12i32).lt(lit(13i64)));
let expected = is_true(cast(lit(12i32),
DataType::Int64).lt(lit(13i64)));
let result = expr.rewrite(&mut rewriter).data()?;
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index fb5125f097..4d7a207afb 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -31,9 +31,7 @@ use datafusion_common::{
cast::{as_large_list_array, as_list_array},
tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
};
-use datafusion_common::{
- internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
-};
+use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr::{InList, InSubquery};
use datafusion_expr::simplify::ExprSimplifyResult;
use datafusion_expr::{
@@ -208,14 +206,8 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See the [type coercion module](datafusion_expr::type_coercion)
/// documentation for more details on type coercion
- ///
- // Would be nice if this API could use the SimplifyInfo
- // rather than creating an DFSchemaRef coerces rather than doing
- // it manually.
- // https://github.com/apache/datafusion/issues/3793
- pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
+ pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
let mut expr_rewrite = TypeCoercionRewriter { schema };
-
expr.rewrite(&mut expr_rewrite).data()
}
@@ -1686,7 +1678,7 @@ mod tests {
sync::Arc,
};
- use datafusion_common::{assert_contains, ToDFSchema};
+ use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
use datafusion_expr::{interval_arithmetic::Interval, *};
use crate::simplify_expressions::SimplifyContext;
@@ -1721,11 +1713,7 @@ mod tests {
// should fully simplify to 3 < i (though i has been coerced to i64)
let expected = lit(3i64).lt(col("i"));
- // Would be nice if this API could use the SimplifyInfo
- // rather than creating an DFSchemaRef coerces rather than doing
- // it manually.
- // https://github.com/apache/datafusion/issues/3793
- let expr = simplifier.coerce(expr, schema).unwrap();
+ let expr = simplifier.coerce(expr, &schema).unwrap();
assert_eq!(expected, simplifier.simplify(expr).unwrap());
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]