This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new fa31c781d2 Improve coerce API so it does not need DFSchema (#10331)
fa31c781d2 is described below

commit fa31c781d2cb0bdfba06dcc07bc75d9a5f9686b2
Author: Andrew Lamb <[email protected]>
AuthorDate: Thu May 2 07:32:27 2024 -0400

    Improve coerce API so it does not need DFSchema (#10331)
---
 datafusion-examples/examples/expr_api.rs           |  2 +-
 datafusion/core/src/test_util/parquet.rs           |  2 +-
 datafusion/optimizer/src/analyzer/type_coercion.rs | 99 +++++++++++-----------
 .../src/simplify_expressions/expr_simplifier.rs    | 20 +----
 4 files changed, 54 insertions(+), 69 deletions(-)

diff --git a/datafusion-examples/examples/expr_api.rs b/datafusion-examples/examples/expr_api.rs
index 6e9c42480c..2c1470a1d6 100644
--- a/datafusion-examples/examples/expr_api.rs
+++ b/datafusion-examples/examples/expr_api.rs
@@ -258,7 +258,7 @@ pub fn physical_expr(schema: &Schema, expr: Expr) -> Result<Arc<dyn PhysicalExpr
         ExprSimplifier::new(SimplifyContext::new(&props).with_schema(df_schema.clone()));
 
     // apply type coercion here to ensure types match
-    let expr = simplifier.coerce(expr, df_schema.clone())?;
+    let expr = simplifier.coerce(expr, &df_schema)?;
 
     create_physical_expr(&expr, df_schema.as_ref(), &props)
 }
diff --git a/datafusion/core/src/test_util/parquet.rs b/datafusion/core/src/test_util/parquet.rs
index f949058769..1d5668c7ec 100644
--- a/datafusion/core/src/test_util/parquet.rs
+++ b/datafusion/core/src/test_util/parquet.rs
@@ -169,7 +169,7 @@ impl TestParquetFile {
         let parquet_options = ctx.copied_table_options().parquet;
         if let Some(filter) = maybe_filter {
             let simplifier = ExprSimplifier::new(context);
-            let filter = simplifier.coerce(filter, df_schema.clone()).unwrap();
+            let filter = simplifier.coerce(filter, &df_schema).unwrap();
             let physical_filter_expr =
                 create_physical_expr(&filter, &df_schema, &ExecutionProps::default())?;
             let parquet_exec = Arc::new(ParquetExec::new(
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index b7f95d83e8..9295b08f41 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -25,7 +25,7 @@ use datafusion_common::config::ConfigOptions;
 use datafusion_common::tree_node::{Transformed, TreeNodeRewriter};
 use datafusion_common::{
     exec_err, internal_err, not_impl_err, plan_datafusion_err, plan_err, DFSchema,
-    DFSchemaRef, DataFusionError, Result, ScalarValue,
+    DataFusionError, Result, ScalarValue,
 };
 use datafusion_expr::expr::{
     self, AggregateFunctionDefinition, Between, BinaryExpr, Case, Exists, 
InList,
@@ -99,9 +99,7 @@ fn analyze_internal(
     // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where 
t2.c2=t1.c3)
     schema.merge(external_schema);
 
-    let mut expr_rewrite = TypeCoercionRewriter {
-        schema: Arc::new(schema),
-    };
+    let mut expr_rewrite = TypeCoercionRewriter { schema: &schema };
 
     let new_expr = plan
         .expressions()
@@ -116,11 +114,11 @@ fn analyze_internal(
     plan.with_new_exprs(new_expr, new_inputs)
 }
 
-pub(crate) struct TypeCoercionRewriter {
-    pub(crate) schema: DFSchemaRef,
+pub(crate) struct TypeCoercionRewriter<'a> {
+    pub(crate) schema: &'a DFSchema,
 }
 
-impl TreeNodeRewriter for TypeCoercionRewriter {
+impl<'a> TreeNodeRewriter for TypeCoercionRewriter<'a> {
     type Node = Expr;
 
     fn f_up(&mut self, expr: Expr) -> Result<Transformed<Expr>> {
@@ -132,14 +130,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 subquery,
                 outer_ref_columns,
             }) => {
-                let new_plan = analyze_internal(&self.schema, &subquery)?;
+                let new_plan = analyze_internal(self.schema, &subquery)?;
                 Ok(Transformed::yes(Expr::ScalarSubquery(Subquery {
                     subquery: Arc::new(new_plan),
                     outer_ref_columns,
                 })))
             }
             Expr::Exists(Exists { subquery, negated }) => {
-                let new_plan = analyze_internal(&self.schema, &subquery.subquery)?;
+                let new_plan = analyze_internal(self.schema, &subquery.subquery)?;
                 Ok(Transformed::yes(Expr::Exists(Exists {
                     subquery: Subquery {
                         subquery: Arc::new(new_plan),
@@ -153,8 +151,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 subquery,
                 negated,
             }) => {
-                let new_plan = analyze_internal(&self.schema, 
&subquery.subquery)?;
-                let expr_type = expr.get_type(&self.schema)?;
+                let new_plan = analyze_internal(self.schema, 
&subquery.subquery)?;
+                let expr_type = expr.get_type(self.schema)?;
                 let subquery_type = new_plan.schema().field(0).data_type();
                 let common_type = comparison_coercion(&expr_type, 
subquery_type).ok_or(plan_datafusion_err!(
                         "expr type {expr_type:?} can't cast to 
{subquery_type:?} in InSubquery"
@@ -165,32 +163,32 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                     outer_ref_columns: subquery.outer_ref_columns,
                 };
                 Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
-                    Box::new(expr.cast_to(&common_type, &self.schema)?),
+                    Box::new(expr.cast_to(&common_type, self.schema)?),
                     cast_subquery(new_subquery, &common_type)?,
                     negated,
                 ))))
             }
             Expr::Not(expr) => 
Ok(Transformed::yes(not(get_casted_expr_for_bool_op(
                 *expr,
-                &self.schema,
+                self.schema,
             )?))),
             Expr::IsTrue(expr) => Ok(Transformed::yes(is_true(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::IsNotTrue(expr) => Ok(Transformed::yes(is_not_true(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::IsFalse(expr) => Ok(Transformed::yes(is_false(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::IsNotFalse(expr) => Ok(Transformed::yes(is_not_false(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::IsUnknown(expr) => Ok(Transformed::yes(is_unknown(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::IsNotUnknown(expr) => Ok(Transformed::yes(is_not_unknown(
-                get_casted_expr_for_bool_op(*expr, &self.schema)?,
+                get_casted_expr_for_bool_op(*expr, self.schema)?,
             ))),
             Expr::Like(Like {
                 negated,
@@ -199,8 +197,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 escape_char,
                 case_insensitive,
             }) => {
-                let left_type = expr.get_type(&self.schema)?;
-                let right_type = pattern.get_type(&self.schema)?;
+                let left_type = expr.get_type(self.schema)?;
+                let right_type = pattern.get_type(self.schema)?;
                 let coerced_type = like_coercion(&left_type,  
&right_type).ok_or_else(|| {
                     let op_name = if case_insensitive {
                         "ILIKE"
@@ -211,8 +209,8 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                         "There isn't a common type to coerce {left_type} and 
{right_type} in {op_name} expression"
                     )
                 })?;
-                let expr = Box::new(expr.cast_to(&coerced_type, 
&self.schema)?);
-                let pattern = Box::new(pattern.cast_to(&coerced_type, 
&self.schema)?);
+                let expr = Box::new(expr.cast_to(&coerced_type, self.schema)?);
+                let pattern = Box::new(pattern.cast_to(&coerced_type, 
self.schema)?);
                 Ok(Transformed::yes(Expr::Like(Like::new(
                     negated,
                     expr,
@@ -223,14 +221,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
             }
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                 let (left_type, right_type) = get_input_types(
-                    &left.get_type(&self.schema)?,
+                    &left.get_type(self.schema)?,
                     &op,
-                    &right.get_type(&self.schema)?,
+                    &right.get_type(self.schema)?,
                 )?;
                 Ok(Transformed::yes(Expr::BinaryExpr(BinaryExpr::new(
-                    Box::new(left.cast_to(&left_type, &self.schema)?),
+                    Box::new(left.cast_to(&left_type, self.schema)?),
                     op,
-                    Box::new(right.cast_to(&right_type, &self.schema)?),
+                    Box::new(right.cast_to(&right_type, self.schema)?),
                 ))))
             }
             Expr::Between(Between {
@@ -239,15 +237,15 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 low,
                 high,
             }) => {
-                let expr_type = expr.get_type(&self.schema)?;
-                let low_type = low.get_type(&self.schema)?;
+                let expr_type = expr.get_type(self.schema)?;
+                let low_type = low.get_type(self.schema)?;
                 let low_coerced_type = comparison_coercion(&expr_type, 
&low_type)
                     .ok_or_else(|| {
                         DataFusionError::Internal(format!(
                             "Failed to coerce types {expr_type} and {low_type} 
in BETWEEN expression"
                         ))
                     })?;
-                let high_type = high.get_type(&self.schema)?;
+                let high_type = high.get_type(self.schema)?;
                 let high_coerced_type = comparison_coercion(&expr_type, 
&low_type)
                     .ok_or_else(|| {
                         DataFusionError::Internal(format!(
@@ -262,10 +260,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                             ))
                         })?;
                 Ok(Transformed::yes(Expr::Between(Between::new(
-                    Box::new(expr.cast_to(&coercion_type, &self.schema)?),
+                    Box::new(expr.cast_to(&coercion_type, self.schema)?),
                     negated,
-                    Box::new(low.cast_to(&coercion_type, &self.schema)?),
-                    Box::new(high.cast_to(&coercion_type, &self.schema)?),
+                    Box::new(low.cast_to(&coercion_type, self.schema)?),
+                    Box::new(high.cast_to(&coercion_type, self.schema)?),
                 ))))
             }
             Expr::InList(InList {
@@ -273,10 +271,10 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 list,
                 negated,
             }) => {
-                let expr_data_type = expr.get_type(&self.schema)?;
+                let expr_data_type = expr.get_type(self.schema)?;
                 let list_data_types = list
                     .iter()
-                    .map(|list_expr| list_expr.get_type(&self.schema))
+                    .map(|list_expr| list_expr.get_type(self.schema))
                     .collect::<Result<Vec<_>>>()?;
                 let result_type =
                     get_coerce_type_for_list(&expr_data_type, 
&list_data_types);
@@ -286,11 +284,11 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                     ),
                     Some(coerced_type) => {
                         // find the coerced type
-                        let cast_expr = expr.cast_to(&coerced_type, 
&self.schema)?;
+                        let cast_expr = expr.cast_to(&coerced_type, 
self.schema)?;
                         let cast_list_expr = list
                             .into_iter()
                             .map(|list_expr| {
-                                list_expr.cast_to(&coerced_type, &self.schema)
+                                list_expr.cast_to(&coerced_type, self.schema)
                             })
                             .collect::<Result<Vec<_>>>()?;
                         Ok(Transformed::yes(Expr::InList(InList ::new(
@@ -302,18 +300,17 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 }
             }
             Expr::Case(case) => {
-                let case = coerce_case_expression(case, &self.schema)?;
+                let case = coerce_case_expression(case, self.schema)?;
                 Ok(Transformed::yes(Expr::Case(case)))
             }
             Expr::ScalarFunction(ScalarFunction { func_def, args }) => match 
func_def {
                 ScalarFunctionDefinition::UDF(fun) => {
                     let new_expr = coerce_arguments_for_signature(
                         args,
-                        &self.schema,
+                        self.schema,
                         fun.signature(),
                     )?;
-                    let new_expr =
-                        coerce_arguments_for_fun(new_expr, &self.schema, 
&fun)?;
+                    let new_expr = coerce_arguments_for_fun(new_expr, 
self.schema, &fun)?;
                     Ok(Transformed::yes(Expr::ScalarFunction(
                         ScalarFunction::new_udf(fun, new_expr),
                     )))
@@ -331,7 +328,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                     let new_expr = coerce_agg_exprs_for_signature(
                         &fun,
                         args,
-                        &self.schema,
+                        self.schema,
                         &fun.signature(),
                     )?;
                     Ok(Transformed::yes(Expr::AggregateFunction(
@@ -348,7 +345,7 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 AggregateFunctionDefinition::UDF(fun) => {
                     let new_expr = coerce_arguments_for_signature(
                         args,
-                        &self.schema,
+                        self.schema,
                         fun.signature(),
                     )?;
                     Ok(Transformed::yes(Expr::AggregateFunction(
@@ -375,14 +372,14 @@ impl TreeNodeRewriter for TypeCoercionRewriter {
                 null_treatment,
             }) => {
                 let window_frame =
-                    coerce_window_frame(window_frame, &self.schema, 
&order_by)?;
+                    coerce_window_frame(window_frame, self.schema, &order_by)?;
 
                 let args = match &fun {
                     expr::WindowFunctionDefinition::AggregateFunction(fun) => {
                         coerce_agg_exprs_for_signature(
                             fun,
                             args,
-                            &self.schema,
+                            self.schema,
                             &fun.signature(),
                         )?
                     }
@@ -495,7 +492,7 @@ fn coerce_frame_bound(
 // For example, ROWS and GROUPS frames use `UInt64` during calculations.
 fn coerce_window_frame(
     window_frame: WindowFrame,
-    schema: &DFSchemaRef,
+    schema: &DFSchema,
     expressions: &[Expr],
 ) -> Result<WindowFrame> {
     let mut window_frame = window_frame;
@@ -531,7 +528,7 @@ fn coerce_window_frame(
 
 // Support the `IsTrue` `IsNotTrue` `IsFalse` `IsNotFalse` type coercion.
 // The above op will be rewrite to the binary op when creating the physical op.
-fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchemaRef) -> 
Result<Expr> {
+fn get_casted_expr_for_bool_op(expr: Expr, schema: &DFSchema) -> Result<Expr> {
     let left_type = expr.get_type(schema)?;
     get_input_types(&left_type, &Operator::IsDistinctFrom, 
&DataType::Boolean)?;
     expr.cast_to(&DataType::Boolean, schema)
@@ -615,7 +612,7 @@ fn coerce_agg_exprs_for_signature(
         .collect()
 }
 
-fn coerce_case_expression(case: Case, schema: &DFSchemaRef) -> Result<Case> {
+fn coerce_case_expression(case: Case, schema: &DFSchema) -> Result<Case> {
     // Given expressions like:
     //
     // CASE a1
@@ -1238,7 +1235,7 @@ mod test {
             vec![Field::new("a", DataType::Int64, true)].into(),
             std::collections::HashMap::new(),
         )?);
-        let mut rewriter = TypeCoercionRewriter { schema };
+        let mut rewriter = TypeCoercionRewriter { schema: &schema };
         let expr = is_true(lit(12i32).gt(lit(13i64)));
         let expected = is_true(cast(lit(12i32), 
DataType::Int64).gt(lit(13i64)));
         let result = expr.rewrite(&mut rewriter).data()?;
@@ -1249,7 +1246,7 @@ mod test {
             vec![Field::new("a", DataType::Int64, true)].into(),
             std::collections::HashMap::new(),
         )?);
-        let mut rewriter = TypeCoercionRewriter { schema };
+        let mut rewriter = TypeCoercionRewriter { schema: &schema };
         let expr = is_true(lit(12i32).eq(lit(13i64)));
         let expected = is_true(cast(lit(12i32), 
DataType::Int64).eq(lit(13i64)));
         let result = expr.rewrite(&mut rewriter).data()?;
@@ -1260,7 +1257,7 @@ mod test {
             vec![Field::new("a", DataType::Int64, true)].into(),
             std::collections::HashMap::new(),
         )?);
-        let mut rewriter = TypeCoercionRewriter { schema };
+        let mut rewriter = TypeCoercionRewriter { schema: &schema };
         let expr = is_true(lit(12i32).lt(lit(13i64)));
         let expected = is_true(cast(lit(12i32), 
DataType::Int64).lt(lit(13i64)));
         let result = expr.rewrite(&mut rewriter).data()?;
diff --git a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
index fb5125f097..4d7a207afb 100644
--- a/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
+++ b/datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
@@ -31,9 +31,7 @@ use datafusion_common::{
     cast::{as_large_list_array, as_list_array},
     tree_node::{Transformed, TransformedResult, TreeNode, TreeNodeRewriter},
 };
-use datafusion_common::{
-    internal_err, DFSchema, DFSchemaRef, DataFusionError, Result, ScalarValue,
-};
+use datafusion_common::{internal_err, DFSchema, DataFusionError, Result, ScalarValue};
 use datafusion_expr::expr::{InList, InSubquery};
 use datafusion_expr::simplify::ExprSimplifyResult;
 use datafusion_expr::{
@@ -208,14 +206,8 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
     ///
     /// See the [type coercion module](datafusion_expr::type_coercion)
     /// documentation for more details on type coercion
-    ///
-    // Would be nice if this API could use the SimplifyInfo
-    // rather than creating an DFSchemaRef coerces rather than doing
-    // it manually.
-    // https://github.com/apache/datafusion/issues/3793
-    pub fn coerce(&self, expr: Expr, schema: DFSchemaRef) -> Result<Expr> {
+    pub fn coerce(&self, expr: Expr, schema: &DFSchema) -> Result<Expr> {
         let mut expr_rewrite = TypeCoercionRewriter { schema };
-
         expr.rewrite(&mut expr_rewrite).data()
     }
 
@@ -1686,7 +1678,7 @@ mod tests {
         sync::Arc,
     };
 
-    use datafusion_common::{assert_contains, ToDFSchema};
+    use datafusion_common::{assert_contains, DFSchemaRef, ToDFSchema};
     use datafusion_expr::{interval_arithmetic::Interval, *};
 
     use crate::simplify_expressions::SimplifyContext;
@@ -1721,11 +1713,7 @@ mod tests {
         // should fully simplify to 3 < i (though i has been coerced to i64)
         let expected = lit(3i64).lt(col("i"));
 
-        // Would be nice if this API could use the SimplifyInfo
-        // rather than creating an DFSchemaRef coerces rather than doing
-        // it manually.
-        // https://github.com/apache/datafusion/issues/3793
-        let expr = simplifier.coerce(expr, schema).unwrap();
+        let expr = simplifier.coerce(expr, &schema).unwrap();
 
         assert_eq!(expected, simplifier.simplify(expr).unwrap());
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to