This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 7427a801a Fix some bugs in TypeCoercion rule (#3407)
7427a801a is described below
commit 7427a801a752223c453c919ef88399cb40820f0c
Author: Andy Grove <[email protected]>
AuthorDate: Mon Sep 12 11:03:21 2022 -0600
Fix some bugs in TypeCoercion rule (#3407)
* Fix schema bug in TypeCoercion rule
* Add type coercion for between
* add workaround for INTERVAL
* fix regression
* add support for coercion between Date32/Date64 and fix regressions caused
by recent merges to master
* fix error message
* update comments and link to follow-on issue
---
datafusion/expr/src/binary_rule.rs | 8 ++
datafusion/optimizer/src/type_coercion.rs | 117 ++++++++++++++++++++-----
datafusion/optimizer/tests/integration-test.rs | 50 +++++++++++
3 files changed, 152 insertions(+), 23 deletions(-)
diff --git a/datafusion/expr/src/binary_rule.rs
b/datafusion/expr/src/binary_rule.rs
index 24d1cb32d..8f4cf3356 100644
--- a/datafusion/expr/src/binary_rule.rs
+++ b/datafusion/expr/src/binary_rule.rs
@@ -73,6 +73,12 @@ pub fn binary_operator_data_type(
/// Coercion rules for all binary operators. Returns the output type
/// of applying `op` to an argument of `lhs_type` and `rhs_type`.
+///
+/// TODO: this function is trying to serve two purposes at once; it determines the result type
+/// of the binary operation and also determines how the inputs can be coerced, but this
+/// results in inconsistencies in some cases (particularly around date + interval).
+///
+/// Tracking issue is https://github.com/apache/arrow-datafusion/issues/3419
pub fn coerce_types(
lhs_type: &DataType,
op: &Operator,
@@ -518,6 +524,8 @@ fn temporal_coercion(lhs_type: &DataType, rhs_type:
&DataType) -> Option<DataTyp
use arrow::datatypes::DataType::*;
use arrow::datatypes::TimeUnit;
match (lhs_type, rhs_type) {
+ (Date64, Date32) => Some(Date64),
+ (Date32, Date64) => Some(Date64),
(Utf8, Date32) => Some(Date32),
(Date32, Utf8) => Some(Date32),
(Utf8, Date64) => Some(Date64),
diff --git a/datafusion/optimizer/src/type_coercion.rs
b/datafusion/optimizer/src/type_coercion.rs
index d267684b7..77580c063 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -18,15 +18,15 @@
//! Optimizer rule for type validation and coercion
use crate::{OptimizerConfig, OptimizerRule};
-use datafusion_common::{DFSchema, DFSchemaRef, Result};
-use datafusion_expr::binary_rule::coerce_types;
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
+use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter,
RewriteRecursion};
-use datafusion_expr::logical_plan::builder::build_join_schema;
-use datafusion_expr::logical_plan::JoinType;
use datafusion_expr::type_coercion::data_types;
use datafusion_expr::utils::from_plan;
use datafusion_expr::{Expr, LogicalPlan};
use datafusion_expr::{ExprSchemable, Signature};
+use std::sync::Arc;
#[derive(Default)]
pub struct TypeCoercion {}
@@ -54,17 +54,19 @@ impl OptimizerRule for TypeCoercion {
.map(|p| self.optimize(p, optimizer_config))
.collect::<Result<Vec<_>>>()?;
- let schema = match new_inputs.len() {
- 1 => new_inputs[0].schema().clone(),
- 2 => DFSchemaRef::new(build_join_schema(
- new_inputs[0].schema(),
- new_inputs[1].schema(),
- &JoinType::Inner,
- )?),
- _ => DFSchemaRef::new(DFSchema::empty()),
- };
+ // get schema representing all available input fields. This is used for data type
+ // resolution only, so order does not matter here
+ let schema = new_inputs.iter().map(|input| input.schema()).fold(
+ DFSchema::empty(),
+ |mut lhs, rhs| {
+ lhs.merge(rhs);
+ lhs
+ },
+ );
- let mut expr_rewrite = TypeCoercionRewriter { schema };
+ let mut expr_rewrite = TypeCoercionRewriter {
+ schema: Arc::new(schema),
+ };
let new_expr = plan
.expressions()
@@ -87,14 +89,55 @@ impl ExprRewriter for TypeCoercionRewriter {
fn mutate(&mut self, expr: Expr) -> Result<Expr> {
match expr {
- Expr::BinaryExpr { left, op, right } => {
+ Expr::BinaryExpr {
+ ref left,
+ op,
+ ref right,
+ } => {
let left_type = left.get_type(&self.schema)?;
let right_type = right.get_type(&self.schema)?;
- let coerced_type = coerce_types(&left_type, &op, &right_type)?;
- Ok(Expr::BinaryExpr {
- left: Box::new(left.cast_to(&coerced_type, &self.schema)?),
- op,
- right: Box::new(right.cast_to(&coerced_type,
&self.schema)?),
+ match (&left_type, &right_type) {
+ (
+ DataType::Date32 | DataType::Date64 |
DataType::Timestamp(_, _),
+ &DataType::Interval(_),
+ ) => {
+ // this is a workaround for https://github.com/apache/arrow-datafusion/issues/3419
+ Ok(expr.clone())
+ }
+ _ => {
+ let coerced_type = coerce_types(&left_type, &op,
&right_type)?;
+ Ok(Expr::BinaryExpr {
+ left: Box::new(
+ left.clone().cast_to(&coerced_type,
&self.schema)?,
+ ),
+ op,
+ right: Box::new(
+ right.clone().cast_to(&coerced_type,
&self.schema)?,
+ ),
+ })
+ }
+ }
+ }
+ Expr::Between {
+ expr,
+ negated,
+ low,
+ high,
+ } => {
+ let expr_type = expr.get_type(&self.schema)?;
+ let low_type = low.get_type(&self.schema)?;
+ let coerced_type = comparison_coercion(&expr_type, &low_type)
+ .ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Failed to coerce types {} and {} in BETWEEN
expression",
+ expr_type, low_type
+ ))
+ })?;
+ Ok(Expr::Between {
+ expr: Box::new(expr.cast_to(&coerced_type, &self.schema)?),
+ negated,
+ low: Box::new(low.cast_to(&coerced_type, &self.schema)?),
+ high: Box::new(high.cast_to(&coerced_type, &self.schema)?),
})
}
Expr::ScalarUDF { fun, args } => {
@@ -145,12 +188,12 @@ mod test {
use crate::type_coercion::TypeCoercion;
use crate::{OptimizerConfig, OptimizerRule};
use arrow::datatypes::DataType;
- use datafusion_common::{DFSchema, Result};
+ use datafusion_common::{DFSchema, Result, ScalarValue};
use datafusion_expr::{
lit,
logical_plan::{EmptyRelation, Projection},
- Expr, LogicalPlan, ReturnTypeFunction, ScalarFunctionImplementation,
ScalarUDF,
- Signature, Volatility,
+ Expr, LogicalPlan, Operator, ReturnTypeFunction,
ScalarFunctionImplementation,
+ ScalarUDF, Signature, Volatility,
};
use std::sync::Arc;
@@ -244,6 +287,34 @@ mod test {
Ok(())
}
+ #[test]
+ fn binary_op_date32_add_interval() -> Result<()> {
+ //CAST(Utf8("1998-03-18") AS Date32) + IntervalDayTime("386547056640")
+ let expr = Expr::BinaryExpr {
+ left: Box::new(Expr::Cast {
+ expr: Box::new(lit("1998-03-18")),
+ data_type: DataType::Date32,
+ }),
+ op: Operator::Plus,
+ right: Box::new(Expr::Literal(ScalarValue::IntervalDayTime(Some(
+ 386547056640,
+ )))),
+ };
+ let empty = Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
+ produce_one_row: false,
+ schema: Arc::new(DFSchema::empty()),
+ }));
+ let plan = LogicalPlan::Projection(Projection::try_new(vec![expr],
empty, None)?);
+ let rule = TypeCoercion::new();
+ let mut config = OptimizerConfig::default();
+ let plan = rule.optimize(&plan, &mut config)?;
+ assert_eq!(
+ "Projection: CAST(Utf8(\"1998-03-18\") AS Date32) +
IntervalDayTime(\"386547056640\")\n EmptyRelation",
+ &format!("{:?}", plan)
+ );
+ Ok(())
+ }
+
fn empty() -> Arc<LogicalPlan> {
Arc::new(LogicalPlan::EmptyRelation(EmptyRelation {
produce_one_row: false,
diff --git a/datafusion/optimizer/tests/integration-test.rs
b/datafusion/optimizer/tests/integration-test.rs
index 55c38689b..87a0bab68 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -27,6 +27,7 @@ use
datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
use datafusion_optimizer::filter_push_down::FilterPushDown;
use datafusion_optimizer::limit_push_down::LimitPushDown;
use datafusion_optimizer::optimizer::Optimizer;
+use
datafusion_optimizer::pre_cast_lit_in_comparison::PreCastLitInComparisonExpressions;
use datafusion_optimizer::projection_push_down::ProjectionPushDown;
use datafusion_optimizer::reduce_outer_join::ReduceOuterJoin;
use
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
@@ -34,6 +35,7 @@ use
datafusion_optimizer::scalar_subquery_to_join::ScalarSubqueryToJoin;
use datafusion_optimizer::simplify_expressions::SimplifyExpressions;
use datafusion_optimizer::single_distinct_to_groupby::SingleDistinctToGroupBy;
use datafusion_optimizer::subquery_filter_to_join::SubqueryFilterToJoin;
+use datafusion_optimizer::type_coercion::TypeCoercion;
use datafusion_optimizer::{OptimizerConfig, OptimizerRule};
use datafusion_sql::planner::{ContextProvider, SqlToRel};
use datafusion_sql::sqlparser::ast::Statement;
@@ -56,11 +58,56 @@ fn distribute_by() -> Result<()> {
Ok(())
}
+#[test]
+fn intersect() -> Result<()> {
+ let sql = "SELECT col_int32, col_utf8 FROM test \
+ INTERSECT SELECT col_int32, col_utf8 FROM test \
+ INTERSECT SELECT col_int32, col_utf8 FROM test";
+ let plan = test_sql(sql)?;
+ let expected =
+ "Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 =
#test.col_utf8\
+ \n Distinct:\
+ \n Semi Join: #test.col_int32 = #test.col_int32, #test.col_utf8 =
#test.col_utf8\
+ \n Distinct:\
+ \n TableScan: test projection=[col_int32, col_utf8]\
+ \n TableScan: test projection=[col_int32, col_utf8]\
+ \n TableScan: test projection=[col_int32, col_utf8]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
+#[test]
+fn between_date32_plus_interval() -> Result<()> {
+ let sql = "SELECT count(1) FROM test \
+ WHERE col_date32 between '1998-03-18' AND cast('1998-03-18' as date) +
INTERVAL '90 days'";
+ let plan = test_sql(sql)?;
+ let expected =
+ "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]],
aggr=[[COUNT(UInt8(1))]]\
+ \n Filter: #test.col_date32 >= CAST(Utf8(\"1998-03-18\") AS Date32)
AND #test.col_date32 <= Date32(\"10393\")\
+ \n TableScan: test projection=[col_date32]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
+#[test]
+fn between_date64_plus_interval() -> Result<()> {
+ let sql = "SELECT count(1) FROM test \
+ WHERE col_date64 between '1998-03-18' AND cast('1998-03-18' as date) +
INTERVAL '90 days'";
+ let plan = test_sql(sql)?;
+ let expected =
+ "Projection: #COUNT(UInt8(1))\n Aggregate: groupBy=[[]],
aggr=[[COUNT(UInt8(1))]]\
+ \n Filter: #test.col_date64 >= CAST(Utf8(\"1998-03-18\") AS Date64)
AND #test.col_date64 <= CAST(Date32(\"10393\") AS Date64)\
+ \n TableScan: test projection=[col_date64]";
+ assert_eq!(expected, format!("{:?}", plan));
+ Ok(())
+}
+
fn test_sql(sql: &str) -> Result<LogicalPlan> {
let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
// Simplify expressions first to maximize the chance
// of applying other optimizations
Arc::new(SimplifyExpressions::new()),
+ Arc::new(PreCastLitInComparisonExpressions::new()),
Arc::new(DecorrelateWhereExists::new()),
Arc::new(DecorrelateWhereIn::new()),
Arc::new(ScalarSubqueryToJoin::new()),
@@ -73,6 +120,7 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
Arc::new(FilterNullJoinKeys::default()),
Arc::new(ReduceOuterJoin::new()),
Arc::new(FilterPushDown::new()),
+ Arc::new(TypeCoercion::new()),
Arc::new(LimitPushDown::new()),
Arc::new(SingleDistinctToGroupBy::new()),
];
@@ -107,6 +155,8 @@ impl ContextProvider for MySchemaProvider {
vec![
Field::new("col_int32", DataType::Int32, true),
Field::new("col_utf8", DataType::Utf8, true),
+ Field::new("col_date32", DataType::Date32, true),
+ Field::new("col_date64", DataType::Date64, true),
],
HashMap::new(),
);