This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 7427a801a Fix some bugs in TypeCoercion rule (#3407)
7427a801a is described below

commit 7427a801a752223c453c919ef88399cb40820f0c
Author: Andy Grove <[email protected]>
AuthorDate: Mon Sep 12 11:03:21 2022 -0600

    Fix some bugs in TypeCoercion rule (#3407)
    
    * Fix schema bug in TypeCoercion rule
    
    * Add type coercion for between
    
    * add workaround for INTERVAL
    
    * fix regression
    
    * add support for coercion between Date32/Date64 and fix regressions caused by recent merges to master
    
    * fix error message
    
    * update comments and link to follow-on issue
---
 datafusion/expr/src/binary_rule.rs             |   8 ++
 datafusion/optimizer/src/type_coercion.rs      | 117 ++++++++++++++++++++-----
 datafusion/optimizer/tests/integration-test.rs |  50 +++++++++++
 3 files changed, 152 insertions(+), 23 deletions(-)

diff --git a/datafusion/expr/src/binary_rule.rs b/datafusion/expr/src/binary_rule.rs
index 24d1cb32d..8f4cf3356 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -73,6 +73,12 @@ pub fn binary_operator_data_type(
 
 /// Coercion rules for all binary operators. Returns the output type
 /// of applying `op` to an argument of `lhs_type` and `rhs_type`.
+///
+/// TODO this function is trying to serve two purposes at once; it determines the result type
+/// of the binary operation and also determines how the inputs can be coerced but this
+/// results in inconsistencies in some cases (particularly around date + interval)
+///
+/// Tracking issue is https://github.com/apache/arrow-datafusion/issues/3419
 pub fn coerce_types(
     lhs_type: &DataType,
     op: &Operator,
@@ -518,6 +524,8 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type: &DataType) -> Option<DataTyp
     use arrow::datatypes::DataType::*;
     use arrow::datatypes::TimeUnit;
     match (lhs_type, rhs_type) {
+        (Date64, Date32) => Some(Date64),
+        (Date32, Date64) => Some(Date64),
         (Utf8, Date32) => Some(Date32),
         (Date32, Utf8) => Some(Date32),
         (Utf8, Date64) => Some(Date64),
diff --git a/datafusion/optimizer/src/type_coercion.rs b/datafusion/optimizer/src/type_coercion.rs
index d267684b7..77580c063 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -18,15 +18,15 @@
 //! Optimizer rule for type validation and coercion
 
 use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{DFSchema, DFSchemaRef, Result};
-use datafusion_expr::binary_rule::coerce_types;
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion};
-use datafusion_expr::logical_plan::builder::build_join_schema;
-use datafusion_expr::logical_plan::JoinType;
 use datafusion_expr::type_coercion::data_types;
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{Expr, LogicalPlan};
 use datafusion_expr::{ExprSchemable, Signature};
+use std::sync::Arc;
 
 #[derive(Default)]
 pub struct TypeCoercion {}
@@ -54,17 +54,19 @@ impl OptimizerRule for TypeCoercion {
             .map(|p| self.optimize(p, optimizer_config))
             .collect::<Result<Vec<_>>>()?;
 
-        let schema = match new_inputs.len() {
-            1 => new_inputs[0].schema().clone(),
-            2 => DFSchemaRef::new(build_join_schema(
-                new_inputs[0].schema(),
-                new_inputs[1].schema(),
-                &JoinType::Inner,
-            )?),
-            _ => DFSchemaRef::new(DFSchema::empty()),
-        };
+        // get schema representing all available input fields. This is used for data type
+        // resolution only, so order does not matter here
+        let schema = new_inputs.iter().map(|input| input.schema()).fold(
+            DFSchema::empty(),
+            |mut lhs, rhs| {
+                lhs.merge(rhs);
+                lhs
+            },
+        );
 
-        let mut expr_rewrite = TypeCoercionRewriter { schema };
+        let mut expr_rewrite = TypeCoercionRewriter {
+            schema: Arc::new(schema),
+        };
 
         let new_expr = plan
             .expressions()
@@ -87,14 +89,55 @@ impl ExprRewriter for TypeCoercionRewriter {
 
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         match expr {
-            Expr::BinaryExpr { left, op, right } => {
+            Expr::BinaryExpr {
+                ref left,
+                op,
+                ref right,
+            } => {
                 let left_type = left.get_type(&self.schema)?;
                 let right_type = right.get_type(&self.schema)?;
-                let coerced_type = coerce_types(&left_type, &op, &right_type)?;
-                Ok(Expr::BinaryExpr {
-                    left: Box::new(left.cast_to(&coerced_type, &self.schema)?),
-                    op,
-                    right: Box::new(right.cast_to(&coerced_type, &self.schema)?),
+                match (&left_type, &right_type) {
+                    (
+                        DataType::Date32 | DataType::Date64 | DataType::Timestamp(_, _),
+                        &DataType::Interval(_),
+                    ) => {
+                        // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
+                        Ok(expr.clone())
+                    }
+                    _ => {
+                        let coerced_type = coerce_types(&left_type, &op, &right_type)?;
+                        Ok(Expr::BinaryExpr {
+                            left: Box::new(
+                                left.clone().cast_to(&coerced_type, &self.schema)?,
+                            ),
+                            op,
+                            right: Box::new(
+                                right.clone().cast_to(&coerced_type, &self.schema)?,
+                            ),
+                        })
+                    }
+                }
+            }
+            Expr::Between {
+                expr,
+                negated,
+                low,
+                high,
+            } => {
+                let expr_type = expr.get_type(&self.schema)?;
+                let low_type = low.get_type(&self.schema)?;
+                let coerced_type = comparison_coercion(&expr_type, &low_type)
+                    .ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to coerce types {} and {} in BETWEEN expression",
+                            expr_type, low_type
+                        ))
+                    })?;
+                Ok(Expr::Between {
+                    expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
+                    negated,
+                    low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
+                    high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
                 })
             }
             Expr::ScalarUDF { fun, args } => {
@@ -145,12 +188,12 @@ mod test {
     use crate::type_coercion::TypeCoercion;
     use crate::{OptimizerConfig, OptimizerRule};
     use arrow::datatypes::DataType;
-    use datafusion_common::{DFSchema, Result};
+    use datafusion_common::{DFSchema, Result, ScalarValue};
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
-        Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
-        Signature, Volatility,
+        Expr, LogicalPlan, Operator, ReturnTypeFunction, ScalarFunctionImplementation,
+        ScalarUDF, Signature, Volatility,
     };
     use std::sync::Arc;
 
@@ -244,6 +287,34 @@ mod test {
         Ok(())
     }
 
+    #[test]
+    fn binary_op_date32_add_interval() -> Result<()> {
+        //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
+        let expr = Expr::BinaryExpr {
+            left: Box::new(Expr::Cast {
+                expr: Box::new(lit("1998-03-18")),
+                data_type: DataType::Date32,
+            }),
+            op: Operator::Plus,
+            right: Box::new(Expr::Literal(ScalarValue::IntervalDayTime(Some(
+                386547056640,
+            )))),
+        };
+        let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+            produce_one_row: false,
+            schema: Arc::new(DFSchema::empty()),
+        }));
+        let plan = LogicalPlan::Projection(Projection::try_new(vec![expr], empty, None)?);
+        let rule = TypeCoercion::new();
+        let mut config = OptimizerConfig::default();
+        let plan = rule.optimize(&plan, &mut config)?;
+        assert_eq!(
+            "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) + IntervalDayTime(\"386547056640\")\n  EmptyRelation",
+            &format!("{:?}", plan)
+        );
+        Ok(())
+    }
+
     fn empty() -> Arc<LogicalPlan> {
         Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
             produce_one_row: false,
diff --git a/datafusion/optimizer/tests/integration-test.rs b/datafusion/optimizer/tests/integration-test.rs
index 55c38689b..87a0bab68 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -27,6 +27,7 @@ use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
 use datafusion_optimizer::filter_push_down::FilterPushDown;
 use datafusion_optimizer::limit_push_down::LimitPushDown;
 use datafusion_optimizer::optimizer::Optimizer;
+use datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
 use datafusion_optimizer::projection_push_down::ProjectionPushDown;
 use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
 use datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
@@ -34,6 +35,7 @@ use datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
 use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
 use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
 use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
+use datafusion_optimizer::type_coercion::TypeCoercion;
 use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
 use datafusion_sql::planner::{ContextProvider, SqlToRel};
 use datafusion_sql::sqlparser::ast::Statement;
@@ -56,11 +58,56 @@ fn distribute_by() -> Result<()> {
     Ok(())
 }
 
+#[test]
+fn intersect() -> Result<()> {
+    let sql = "SELECT col_int32, col_utf8 FROM test \
+    INTERSECT SELECT col_int32, col_utf8 FROM test \
+    INTERSECT SELECT col_int32, col_utf8 FROM test";
+    let plan = test_sql(sql)?;
+    let expected =
+        "Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\
+    \n  Distinct:\
+    \n    Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 = #test.col_utf8\
+    \n      Distinct:\
+    \n        TableScan: test projection=[col_int32, col_utf8]\
+    \n      TableScan: test projection=[col_int32, col_utf8]\
+    \n  TableScan: test projection=[col_int32, col_utf8]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
+#[test]
+fn between_date32_plus_interval() -> Result<()> {
+    let sql = "SELECT count(1) FROM test \
+    WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
+    let plan = test_sql(sql)?;
+    let expected =
+        "Projection: #COUNT(UInt8(1))\n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+        \n    Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32) AND #test.col_date32 <= Date32(\"10393\")\
+        \n      TableScan: test projection=[col_date32]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
+#[test]
+fn between_date64_plus_interval() -> Result<()> {
+    let sql = "SELECT count(1) FROM test \
+    WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) + INTERVAL '90 days'";
+    let plan = test_sql(sql)?;
+    let expected =
+        "Projection: #COUNT(UInt8(1))\n  Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
+        \n    Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64) AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\
+        \n      TableScan: test projection=[col_date64]";
+    assert_eq!(expected, format!("{:?}", plan));
+    Ok(())
+}
+
 fn test_sql(sql: &str) -> Result<LogicalPlan> {
     let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
         // Simplify expressions first to maximize the chance
         // of applying other optimizations
         Arc::new(SimplifyExpressions::new()),
+        Arc::new(PreCastLitInComparisonExpressions::new()),
         Arc::new(DecorrelateWhereExists::new()),
         Arc::new(DecorrelateWhereIn::new()),
         Arc::new(ScalarSubqueryToJoin::new()),
@@ -73,6 +120,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
         Arc::new(FilterNullJoinKeys::default()),
         Arc::new(ReduceOuterJoin::new()),
         Arc::new(FilterPushDown::new()),
+        Arc::new(TypeCoercion::new()),
         Arc::new(LimitPushDown::new()),
         Arc::new(SingleDistinctToGroupBy::new()),
     ];
@@ -107,6 +155,8 @@ impl ContextProvider for MySchemaProvider {
                 vec![
                     Field::new("col_int32", DataType::Int32, true),
                     Field::new("col_utf8", DataType::Utf8, true),
+                    Field::new("col_date32", DataType::Date32, true),
+                    Field::new("col_date64", DataType::Date64, true),
                 ],
                 HashMap::new(),
             );

Reply via email to