This is an automated email from the ASF dual-hosted git repository.

liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a4e74c0bc support inlist for pre cast literal expression (#3270)
a4e74c0bc is described below

commit a4e74c0bc9b6e46dd151d40e5c881b7961fecccc
Author: Kun Liu <[email protected]>
AuthorDate: Tue Aug 30 17:48:15 2022 +0800

    support inlist for pre cast literal expression (#3270)
    
    * support decimal for the PreCastLitInComparisonExpressions rule
    
    * address comments
    
    * support list
---
 .../optimizer/src/pre_cast_lit_in_comparison.rs    | 181 ++++++++++++++++++++-
 1 file changed, 178 insertions(+), 3 deletions(-)

diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
index 6e89afd60..793eca2f3 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -24,7 +24,9 @@ use arrow::datatypes::{
 use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
 use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, 
RewriteRecursion};
 use datafusion_expr::utils::from_plan;
-use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, 
Operator};
+use datafusion_expr::{
+    binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
+};
 
 /// The rule can be only used to the numeric binary comparison with literal expr, like below pattern:
 /// `left_expr comparison_op literal_expr` or `literal_expr comparison_op right_expr`.
@@ -144,8 +146,57 @@ impl ExprRewriter for PreCastLitExprRewriter {
                 // return the new binary op
                 Ok(binary_expr(left, *op, right))
             }
-            // TODO: optimize in list
-            // Expr::InList { .. } => {}
+            Expr::InList {
+                expr: left_expr,
+                list,
+                negated,
+            } => {
+                // Rewrite `left_expr IN (lit, ...)` so that every literal in the list is
+                // cast to the data type of `left_expr`. The rewrite is all or nothing:
+                // if the left type is unknown or unsupported, or any list item is not a
+                // literal of a supported type that can be cast to that type, the
+                // original expression is returned unchanged.
+                let left_type = match left_expr.get_type(&self.schema) {
+                    Ok(data_type) if is_support_data_type(&data_type) => data_type,
+                    // unknown or unsupported data type: leave the expression as is
+                    _ => return Ok(expr),
+                };
+                // `None` marks an item that blocks the rewrite; collecting into
+                // `Option<Vec<_>>` stops at the first such item.
+                let casted_list = list
+                    .iter()
+                    .map(|right| -> Option<Expr> {
+                        // each list item must itself have a supported data type
+                        let right_type = right.get_type(&self.schema).ok()?;
+                        if !is_support_data_type(&right_type) {
+                            return None;
+                        }
+                        match right {
+                            // both `Err` and `Ok(None)` (the value cannot be cast)
+                            // from `try_cast_literal_to_type` block the rewrite
+                            Expr::Literal(value) => {
+                                try_cast_literal_to_type(value, &left_type)
+                                    .ok()
+                                    .flatten()
+                                    .map(lit)
+                            }
+                            // only literal list items can be pre-cast
+                            _ => None,
+                        }
+                    })
+                    .collect::<Option<Vec<_>>>();
+                match casted_list {
+                    // clone the tested expression only once the rewrite is known to
+                    // succeed; the `negated` flag is passed through unchanged
+                    Some(casted_list) => Ok(in_list(
+                        left_expr.as_ref().clone(),
+                        casted_list,
+                        *negated,
+                    )),
+                    // at least one list item blocked the rewrite
+                    None => Ok(expr),
+                }
+            }
             // TODO: handle other expr type and dfs visit them
             _ => Ok(expr),
         }
@@ -384,6 +435,129 @@ mod tests {
         assert_eq!(optimize_test(expr_lt, &schema), expected);
     }
 
+    #[test]
+    fn test_not_list_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // Every case below must be returned unchanged by the rule.
+        // FLOAT32(C5) IN (INT64(12), INT32(12)): left type FLOAT32 is not supported
+        let expr_lt = col("c5").in_list(
+            vec![
+                lit(ScalarValue::Int64(Some(12))),
+                lit(ScalarValue::Int32(Some(12))),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // INT32(C1) IN (INT32(12), INT64(12), FLOAT32(1.23)): FLOAT32 item is not supported
+        let expr_lt = col("c1").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(12))),
+                lit(ScalarValue::Float32(Some(1.23))),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // INT32(C1) IN (INT32(12), INT64(99999999999)): 99999999999 does not fit in INT32
+        let expr_lt = col("c1").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(99999999999))),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+        // DECIMAL(18,2)(C3) IN (INT32(12), INT64(12), DECIMAL(128,12,3)): 0.128 needs scale 3
+        let expr_lt = col("c3").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(12))),
+                lit(ScalarValue::Decimal128(Some(128), 12, 3)),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+    }
+
+    #[test]
+    fn test_pre_list_cast_lit_comparison() {
+        let schema = expr_test_schema();
+        // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN (INT32(12),INT32(24))
+        let expr_lt = col("c1").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(24))),
+            ],
+            false,
+        );
+        let expected = col("c1").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int32(Some(24))),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+        // INT64(C2) IN (INT64(NULL),INT32(14)) -> INT64(C2) IN (INT64(NULL),INT64(14))
+        let expr_lt = col("c2").in_list(
+            vec![
+                lit(ScalarValue::Int64(None)),
+                lit(ScalarValue::Int32(Some(14))),
+            ],
+            false,
+        );
+        let expected = col("c2").in_list(
+            vec![
+                lit(ScalarValue::Int64(None)),
+                lit(ScalarValue::Int64(Some(14))),
+            ],
+            false,
+        );
+
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+        // DECIMAL(18,2)(C3) IN (INT32, INT64, DECIMAL(10,2), DECIMAL(10,3) items) -> all DECIMAL(18,2)
+        let expr_lt = col("c3").in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(24))),
+                lit(ScalarValue::Decimal128(Some(128), 10, 2)),
+                lit(ScalarValue::Decimal128(Some(1280), 10, 3)),
+            ],
+            false,
+        );
+        let expected = col("c3").in_list(
+            vec![
+                lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
+                lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
+                lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+                lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+        // INT32(12) IN (INT32(12),INT64(12)) -> INT32(12) IN (INT32(12),INT32(12))
+        let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int64(Some(12))),
+            ],
+            false,
+        );
+        let expected = lit(ScalarValue::Int32(Some(12))).in_list(
+            vec![
+                lit(ScalarValue::Int32(Some(12))),
+                lit(ScalarValue::Int32(Some(12))),
+            ],
+            false,
+        );
+        assert_eq!(optimize_test(expr_lt, &schema), expected);
+    }
+
     #[test]
     fn aliased() {
         let schema = expr_test_schema();
@@ -423,6 +597,7 @@ mod tests {
                     DFField::new(None, "c2", DataType::Int64, false),
                     DFField::new(None, "c3", DataType::Decimal128(18, 2), 
false),
                     DFField::new(None, "c4", DataType::Decimal128(38, 37), 
false),
+                    DFField::new(None, "c5", DataType::Float32, false),
                 ],
                 HashMap::new(),
             )

Reply via email to