This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 29b8bbdf8 move the `type coercion` to the beginning of the optimizer 
rule and support type coercion for subquery (#3636)
29b8bbdf8 is described below

commit 29b8bbdf85f550cd9ffafff5ebeaac9a5701fffb
Author: Kun Liu <[email protected]>
AuthorDate: Thu Sep 29 22:57:02 2022 +0800

    move the `type coercion` to the beginning of the optimizer rule and support 
type coercion for subquery (#3636)
    
    * support subquery for type coercion
    
    * support subquery
    
    * move the type coercion to the begine of the rules
    
    * fix all test case
    
    * fix test
    
    * remove useless code
    
    * add subquery in type coercion
    
    * address comments
    
    * fix test
    
    * support case #3565
---
 datafusion/core/src/execution/context.rs       |  10 +-
 datafusion/core/tests/sql/explain_analyze.rs   |  21 +++-
 datafusion/core/tests/sql/predicates.rs        |   2 +
 datafusion/core/tests/sql/subqueries.rs        |  12 +-
 datafusion/expr/src/logical_plan/plan.rs       |   5 +
 datafusion/optimizer/src/type_coercion.rs      | 164 ++++++++++++++++++-------
 datafusion/optimizer/tests/integration-test.rs |   8 +-
 7 files changed, 154 insertions(+), 68 deletions(-)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index f65c849e4..ff0ccf835 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -1466,10 +1466,9 @@ impl SessionState {
         }
 
         let mut rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
-            // Simplify expressions first to maximize the chance
-            // of applying other optimizations
-            Arc::new(SimplifyExpressions::new()),
             Arc::new(PreCastLitInComparisonExpressions::new()),
+            Arc::new(TypeCoercion::new()),
+            Arc::new(SimplifyExpressions::new()),
             Arc::new(DecorrelateWhereExists::new()),
             Arc::new(DecorrelateWhereIn::new()),
             Arc::new(ScalarSubqueryToJoin::new()),
@@ -1490,11 +1489,6 @@ impl SessionState {
             rules.push(Arc::new(FilterNullJoinKeys::default()));
         }
         rules.push(Arc::new(ReduceOuterJoin::new()));
-        // TODO: https://github.com/apache/arrow-datafusion/issues/3557
-        // remove this, after the issue fixed.
-        rules.push(Arc::new(TypeCoercion::new()));
-        // after the type coercion, can do simplify expression again
-        rules.push(Arc::new(SimplifyExpressions::new()));
         rules.push(Arc::new(FilterPushDown::new()));
         rules.push(Arc::new(LimitPushDown::new()));
         rules.push(Arc::new(SingleDistinctToGroupBy::new()));
diff --git a/datafusion/core/tests/sql/explain_analyze.rs 
b/datafusion/core/tests/sql/explain_analyze.rs
index f2069126c..fe51aedc8 100644
--- a/datafusion/core/tests/sql/explain_analyze.rs
+++ b/datafusion/core/tests/sql/explain_analyze.rs
@@ -767,6 +767,8 @@ async fn test_physical_plan_display_indent_multi_children() 
{
 #[tokio::test]
 #[cfg_attr(tarpaulin, ignore)]
 async fn csv_explain() {
+    // TODO: https://github.com/apache/arrow-datafusion/issues/3622 refactor 
the `PreCastLitInComparisonExpressions`
+
     // This test uses the execute function that create full plan cycle: 
logical, optimized logical, and physical,
     // then execute the physical plan and return the final explain results
     let ctx = SessionContext::new();
@@ -777,6 +779,23 @@ async fn csv_explain() {
 
     // Note can't use `assert_batches_eq` as the plan needs to be
     // normalized for filenames and number of cores
+    let expected = vec![
+        vec![
+            "logical_plan",
+            "Projection: #aggregate_test_100.c1\
+             \n  Filter: CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)\
+             \n    TableScan: aggregate_test_100 projection=[c1, c2], 
partial_filters=[CAST(#aggregate_test_100.c2 AS Int32) > Int32(10)]"
+        ],
+        vec!["physical_plan",
+             "ProjectionExec: expr=[c1@0 as c1]\
+              \n  CoalesceBatchesExec: target_batch_size=4096\
+              \n    FilterExec: CAST(c2@1 AS Int32) > 10\
+              \n      RepartitionExec: partitioning=RoundRobinBatch(NUM_CORES)\
+              \n        CsvExec: 
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, 
limit=None, projection=[c1, c2]\
+              \n"
+        ]];
+    assert_eq!(expected, actual);
+
     let expected = vec![
         vec![
             "logical_plan",
@@ -792,9 +811,7 @@ async fn csv_explain() {
               \n        CsvExec: 
files=[ARROW_TEST_DATA/csv/aggregate_test_100.csv], has_header=true, 
limit=None, projection=[c1, c2]\
               \n"
         ]];
-    assert_eq!(expected, actual);
 
-    // Also, expect same result with lowercase explain
     let sql = "explain SELECT c1 FROM aggregate_test_100 where c2 > 10";
     let actual = execute(&ctx, sql).await;
     let actual = normalize_vec_for_explain(actual);
diff --git a/datafusion/core/tests/sql/predicates.rs 
b/datafusion/core/tests/sql/predicates.rs
index 895af7081..15e89f7b3 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -385,6 +385,8 @@ async fn csv_in_set_test() -> Result<()> {
 }
 
 #[tokio::test]
+#[ignore]
+// https://github.com/apache/arrow-datafusion/issues/3635
 async fn multiple_or_predicates() -> Result<()> {
     // TODO https://github.com/apache/arrow-datafusion/issues/3587
     let ctx = SessionContext::new();
diff --git a/datafusion/core/tests/sql/subqueries.rs 
b/datafusion/core/tests/sql/subqueries.rs
index 0ac286d76..4b4f23e13 100644
--- a/datafusion/core/tests/sql/subqueries.rs
+++ b/datafusion/core/tests/sql/subqueries.rs
@@ -336,10 +336,10 @@ order by s_name;
               Projection: #part.p_partkey AS p_partkey, alias=__sq_1
                 Filter: #part.p_name LIKE Utf8("forest%")
                   TableScan: part projection=[p_partkey, p_name], 
partial_filters=[#part.p_name LIKE Utf8("forest%")]
-            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, 
Decimal128(Some(50000000000000000),38,17) * CAST(#SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
+            Projection: #lineitem.l_partkey, #lineitem.l_suppkey, 
CAST(Float64(0.5) AS Decimal128(38, 17)) * CAST(#SUM(lineitem.l_quantity) AS 
Decimal128(38, 17)) AS __value, alias=__sq_3
               Aggregate: groupBy=[[#lineitem.l_partkey, #lineitem.l_suppkey]], 
aggr=[[SUM(#lineitem.l_quantity)]]
-                Filter: #lineitem.l_shipdate >= Date32("8766")
-                  TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
Date32("8766")]"#
+                Filter: #lineitem.l_shipdate >= CAST(Utf8("1994-01-01") AS 
Date32)
+                  TableScan: lineitem projection=[l_partkey, l_suppkey, 
l_quantity, l_shipdate], partial_filters=[#lineitem.l_shipdate >= 
CAST(Utf8("1994-01-01") AS Date32)]"#
         .to_string();
     assert_eq!(actual, expected);
 
@@ -393,8 +393,8 @@ order by cntrycode;"#;
                 TableScan: orders projection=[o_custkey]
               Projection: #AVG(customer.c_acctbal) AS __value, alias=__sq_1
                 Aggregate: groupBy=[[]], aggr=[[AVG(#customer.c_acctbal)]]
-                  Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
Decimal128(Some(0),30,15) AND substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])
-                    TableScan: customer projection=[c_phone, c_acctbal], 
partial_filters=[CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
Decimal128(Some(0),30,15), substr(#customer.c_phone, Int64(1), Int64(2)) IN 
([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), 
Utf8("17")])]"#
+                  Filter: CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
CAST(Float64(0) AS Decimal128(30, 15)) AND substr(#customer.c_phone, Int64(1), 
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), 
Utf8("18"), Utf8("17")])
+                    TableScan: customer projection=[c_phone, c_acctbal], 
partial_filters=[CAST(#customer.c_acctbal AS Decimal128(30, 15)) > 
CAST(Float64(0) AS Decimal128(30, 15)), substr(#customer.c_phone, Int64(1), 
Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), 
Utf8("18"), Utf8("17")])]"#
         .to_string();
     assert_eq!(actual, expected);
 
@@ -453,7 +453,7 @@ order by value desc;
               TableScan: supplier projection=[s_suppkey, s_nationkey]
             Filter: #nation.n_name = Utf8("GERMANY")
               TableScan: nation projection=[n_nationkey, n_name], 
partial_filters=[#nation.n_name = Utf8("GERMANY")]
-        Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) 
AS Decimal128(38, 17)) * Decimal128(Some(10000000000000),38,17) AS __value, 
alias=__sq_1
+        Projection: CAST(#SUM(partsupp.ps_supplycost * partsupp.ps_availqty) 
AS Decimal128(38, 17)) * CAST(Float64(0.0001) AS Decimal128(38, 17)) AS 
__value, alias=__sq_1
           Aggregate: groupBy=[[]], aggr=[[SUM(CAST(#partsupp.ps_supplycost AS 
Decimal128(26, 2)) * CAST(#partsupp.ps_availqty AS Decimal128(26, 2)))]]
             Inner Join: #supplier.s_nationkey = #nation.n_nationkey
               Inner Join: #partsupp.ps_suppkey = #supplier.s_suppkey
diff --git a/datafusion/expr/src/logical_plan/plan.rs 
b/datafusion/expr/src/logical_plan/plan.rs
index 049e6158c..a803f569c 100644
--- a/datafusion/expr/src/logical_plan/plan.rs
+++ b/datafusion/expr/src/logical_plan/plan.rs
@@ -1410,6 +1410,11 @@ pub struct Subquery {
 }
 
 impl Subquery {
+    pub fn new(plan: LogicalPlan) -> Self {
+        Subquery {
+            subquery: Arc::new(plan),
+        }
+    }
     pub fn try_from_expr(plan: &Expr) -> datafusion_common::Result<&Subquery> {
         match plan {
             Expr::ScalarSubquery(it) => Ok(it),
diff --git a/datafusion/optimizer/src/type_coercion.rs 
b/datafusion/optimizer/src/type_coercion.rs
index bf99d61d9..372d09326 100644
--- a/datafusion/optimizer/src/type_coercion.rs
+++ b/datafusion/optimizer/src/type_coercion.rs
@@ -22,6 +22,7 @@ use arrow::datatypes::DataType;
 use datafusion_common::{DFSchema, DFSchemaRef, DataFusionError, Result};
 use datafusion_expr::binary_rule::{coerce_types, comparison_coercion};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, 
RewriteRecursion};
+use datafusion_expr::logical_plan::Subquery;
 use datafusion_expr::type_coercion::data_types;
 use datafusion_expr::utils::from_plan;
 use datafusion_expr::{
@@ -50,56 +51,70 @@ impl OptimizerRule for TypeCoercion {
         plan: &LogicalPlan,
         optimizer_config: &mut OptimizerConfig,
     ) -> Result<LogicalPlan> {
-        // optimize child plans first
-        let new_inputs = plan
-            .inputs()
-            .iter()
-            .map(|p| self.optimize(p, optimizer_config))
-            .collect::<Result<Vec<_>>>()?;
-
-        // get schema representing all available input fields. This is used 
for data type
-        // resolution only, so order does not matter here
-        let schema = new_inputs.iter().map(|input| input.schema()).fold(
-            DFSchema::empty(),
-            |mut lhs, rhs| {
-                lhs.merge(rhs);
-                lhs
-            },
-        );
+        optimize_internal(&DFSchema::empty(), plan, optimizer_config)
+    }
+}
 
-        let mut expr_rewrite = TypeCoercionRewriter {
-            schema: Arc::new(schema),
-        };
+fn optimize_internal(
+    // use the external schema to handle the correlated subqueries case
+    external_schema: &DFSchema,
+    plan: &LogicalPlan,
+    optimizer_config: &mut OptimizerConfig,
+) -> Result<LogicalPlan> {
+    // optimize child plans first
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|p| optimize_internal(external_schema, p, optimizer_config))
+        .collect::<Result<Vec<_>>>()?;
+
+    // get schema representing all available input fields. This is used for 
data type
+    // resolution only, so order does not matter here
+    let mut schema = new_inputs.iter().map(|input| input.schema()).fold(
+        DFSchema::empty(),
+        |mut lhs, rhs| {
+            lhs.merge(rhs);
+            lhs
+        },
+    );
+
+    // merge the outer schema for correlated subqueries
+    // like case:
+    // select t2.c2 from t1 where t1.c1 in (select t2.c1 from t2 where 
t2.c2=t1.c3)
+    schema.merge(external_schema);
+
+    let mut expr_rewrite = TypeCoercionRewriter {
+        schema: Arc::new(schema),
+    };
 
-        let original_expr_names: Vec<Option<String>> = plan
-            .expressions()
-            .iter()
-            .map(|expr| expr.name().ok())
-            .collect();
-
-        let new_expr = plan
-            .expressions()
-            .into_iter()
-            .zip(original_expr_names)
-            .map(|(expr, original_name)| {
-                let expr = expr.rewrite(&mut expr_rewrite)?;
-
-                // ensure aggregate names don't change:
-                // https://github.com/apache/arrow-datafusion/issues/3555
-                if matches!(expr, Expr::AggregateFunction { .. }) {
-                    if let Some((alias, name)) = 
original_name.zip(expr.name().ok()) {
-                        if alias != name {
-                            return Ok(expr.alias(&alias));
-                        }
+    let original_expr_names: Vec<Option<String>> = plan
+        .expressions()
+        .iter()
+        .map(|expr| expr.name().ok())
+        .collect();
+
+    let new_expr = plan
+        .expressions()
+        .into_iter()
+        .zip(original_expr_names)
+        .map(|(expr, original_name)| {
+            let expr = expr.rewrite(&mut expr_rewrite)?;
+
+            // ensure aggregate names don't change:
+            // https://github.com/apache/arrow-datafusion/issues/3555
+            if matches!(expr, Expr::AggregateFunction { .. }) {
+                if let Some((alias, name)) = 
original_name.zip(expr.name().ok()) {
+                    if alias != name {
+                        return Ok(expr.alias(&alias));
                     }
                 }
+            }
 
-                Ok(expr)
-            })
-            .collect::<Result<Vec<_>>>()?;
+            Ok(expr)
+        })
+        .collect::<Result<Vec<_>>>()?;
 
-        from_plan(plan, &new_expr, &new_inputs)
-    }
+    from_plan(plan, &new_expr, &new_inputs)
 }
 
 pub(crate) struct TypeCoercionRewriter {
@@ -119,6 +134,41 @@ impl ExprRewriter for TypeCoercionRewriter {
 
     fn mutate(&mut self, expr: Expr) -> Result<Expr> {
         match expr {
+            Expr::ScalarSubquery(Subquery { subquery }) => {
+                let mut optimizer_config = OptimizerConfig::new();
+                let new_plan =
+                    optimize_internal(&self.schema, &subquery, &mut 
optimizer_config)?;
+                Ok(Expr::ScalarSubquery(Subquery::new(new_plan)))
+            }
+            Expr::Exists { subquery, negated } => {
+                let mut optimizer_config = OptimizerConfig::new();
+                let new_plan = optimize_internal(
+                    &self.schema,
+                    &subquery.subquery,
+                    &mut optimizer_config,
+                )?;
+                Ok(Expr::Exists {
+                    subquery: Subquery::new(new_plan),
+                    negated,
+                })
+            }
+            Expr::InSubquery {
+                expr,
+                subquery,
+                negated,
+            } => {
+                let mut optimizer_config = OptimizerConfig::new();
+                let new_plan = optimize_internal(
+                    &self.schema,
+                    &subquery.subquery,
+                    &mut optimizer_config,
+                )?;
+                Ok(Expr::InSubquery {
+                    expr,
+                    subquery: Subquery::new(new_plan),
+                    negated,
+                })
+            }
             Expr::IsTrue(expr) => {
                 let expr = is_true(get_casted_expr_for_bool_op(&expr, 
&self.schema)?);
                 Ok(expr)
@@ -368,11 +418,12 @@ fn coerce_arguments_for_signature(
 
 #[cfg(test)]
 mod test {
-    use crate::type_coercion::TypeCoercion;
+    use crate::type_coercion::{TypeCoercion, TypeCoercionRewriter};
     use crate::{OptimizerConfig, OptimizerRule};
     use arrow::datatypes::DataType;
     use datafusion_common::{DFField, DFSchema, Result, ScalarValue};
-    use datafusion_expr::{col, ColumnarValue};
+    use datafusion_expr::expr_rewriter::ExprRewritable;
+    use datafusion_expr::{cast, col, is_true, ColumnarValue};
     use datafusion_expr::{
         lit,
         logical_plan::{EmptyRelation, Projection},
@@ -735,4 +786,25 @@ mod test {
             ),
         }))
     }
+
+    #[test]
+    fn test_type_coercion_rewrite() -> Result<()> {
+        let schema = Arc::new(
+            DFSchema::new_with_metadata(
+                vec![DFField::new(None, "a", DataType::Int64, true)],
+                std::collections::HashMap::new(),
+            )
+            .unwrap(),
+        );
+        let mut rewriter = TypeCoercionRewriter::new(schema);
+        let expr = is_true(lit(12i32).eq(lit(13i64)));
+        let expected = is_true(
+            cast(lit(ScalarValue::Int32(Some(12))), DataType::Int64)
+                .eq(lit(ScalarValue::Int64(Some(13)))),
+        );
+        let result = expr.rewrite(&mut rewriter)?;
+        assert_eq!(expected, result);
+        Ok(())
+        // TODO add more test for this
+    }
 }
diff --git a/datafusion/optimizer/tests/integration-test.rs 
b/datafusion/optimizer/tests/integration-test.rs
index 554e3cceb..5f2760316 100644
--- a/datafusion/optimizer/tests/integration-test.rs
+++ b/datafusion/optimizer/tests/integration-test.rs
@@ -109,10 +109,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
     // TODO should make align with rules in the context
     // https://github.com/apache/arrow-datafusion/issues/3524
     let rules: Vec<Arc<dyn OptimizerRule + Sync + Send>> = vec![
-        // Simplify expressions first to maximize the chance
-        // of applying other optimizations
-        Arc::new(SimplifyExpressions::new()),
         Arc::new(PreCastLitInComparisonExpressions::new()),
+        Arc::new(TypeCoercion::new()),
+        Arc::new(SimplifyExpressions::new()),
         Arc::new(DecorrelateWhereExists::new()),
         Arc::new(DecorrelateWhereIn::new()),
         Arc::new(ScalarSubqueryToJoin::new()),
@@ -125,9 +124,6 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
         Arc::new(RewriteDisjunctivePredicate::new()),
         Arc::new(FilterNullJoinKeys::default()),
         Arc::new(ReduceOuterJoin::new()),
-        Arc::new(TypeCoercion::new()),
-        // after the type coercion, can do simplify expression again
-        Arc::new(SimplifyExpressions::new()),
         Arc::new(FilterPushDown::new()),
         Arc::new(LimitPushDown::new()),
         Arc::new(SingleDistinctToGroupBy::new()),

Reply via email to