liukun4515 commented on code in PR #3185:
URL: https://github.com/apache/arrow-datafusion/pull/3185#discussion_r952099943


##########
datafusion/optimizer/src/pre_cast_lit_in_binary_comparison.rs:
##########
@@ -0,0 +1,295 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Pre-cast literal binary comparison rule can be only used to the binary 
comparison expr.
+//! It can reduce adding the `Expr::Cast` to the expr instead of adding the 
`Expr::Cast` to literal expr.
+use crate::{OptimizerConfig, OptimizerRule};
+use arrow::datatypes::DataType;
+use datafusion_common::{DFSchemaRef, Result, ScalarValue};
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan, 
Operator};
+
+/// The rule can be only used to the numeric binary comparison with literal 
expr, like below pattern:
+/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op 
right_expr`.
+/// The data type of two sides must be signed numeric type now, and will 
support more data type later.
+///
+/// If the binary comparison expr match above rules, the optimizer will check 
if the value of `literal`
+/// is in within range(min,max) which is the range(min,max) of the data type 
for `left_expr` or `right_expr`.
+///
+/// If this true, the literal expr will be casted to the data type of expr on 
the other side, and the result of
+/// binary comparison will be `left_expr comparison_op cast(literal_expr, 
left_data_type)` or
+/// `cast(literal_expr, right_data_type) comparison_op right_expr`. For better 
optimization,
+/// the expr of `cast(literal_expr, target_type)` will be precomputed and 
converted to the new expr `new_literal_expr`
+/// which data type is `target_type`.
+/// If this false, do nothing.
+///
+/// This is inspired by the optimizer rule `UnwrapCastInBinaryComparison` of 
Spark.
+/// # Example
+///
+/// `Filter: c1 > INT64(10)` will be optimized to `Filter: c1 > CAST(INT64(10) 
AS INT32),
+/// and continue to be converted to `Filter: c1 > INT32(10)`, if the DataType 
of c1 is INT32.
+///
+#[derive(Default)]
+pub struct PreCastLitInBinaryComparisonExpressions {}
+
+impl PreCastLitInBinaryComparisonExpressions {
+    pub fn new() -> Self {
+        Self::default()
+    }
+}
+
+impl OptimizerRule for PreCastLitInBinaryComparisonExpressions {
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        optimize(plan)
+    }
+
+    fn name(&self) -> &str {
+        "pre_cast_lit_in_binary_comparison"
+    }
+}
+
+fn optimize(plan: &LogicalPlan) -> Result<LogicalPlan> {
+    let new_inputs = plan
+        .inputs()
+        .iter()
+        .map(|input| optimize(input))
+        .collect::<Result<Vec<_>>>()?;
+
+    let schema = plan.schema();
+    let new_exprs = plan
+        .expressions()
+        .into_iter()
+        .map(|expr| visit_expr(expr, schema))
+        .collect::<Vec<_>>();
+
+    from_plan(plan, new_exprs.as_slice(), new_inputs.as_slice())
+}
+
+// Visit all type of expr, if the current has child expr, the child expr 
needed to visit first.
+fn visit_expr(expr: Expr, schema: &DFSchemaRef) -> Expr {
+    // traverse the expr by dfs
+    match &expr {
+        Expr::BinaryExpr { left, op, right } => {
+            // dfs visit the left and right expr
+            let left = visit_expr(*left.clone(), schema);
+            let right = visit_expr(*right.clone(), schema);
+            let left_type = left.get_type(schema);
+            let right_type = right.get_type(schema);
+            // can't get the data type, just return the expr
+            if left_type.is_err() || right_type.is_err() {
+                return expr.clone();
+            }
+            let left_type = left_type.unwrap();
+            let right_type = right_type.unwrap();
+            if !left_type.eq(&right_type)
+                && is_support_data_type(&left_type)
+                && is_support_data_type(&right_type)
+                && is_comparison_op(op)
+            {
+                match (&left, &right) {
+                    (Expr::Literal(_), Expr::Literal(_)) => {
+                        // do nothing
+                    }
+                    (Expr::Literal(left_lit_value), _)
+                        if can_integer_literal_cast_to_type(
+                            left_lit_value,
+                            &right_type,
+                        ) =>
+                    {
+                        // cast the left literal to the right type
+                        return binary_expr(
+                            cast_to_other_scalar_expr(left_lit_value, 
&right_type),
+                            *op,
+                            right,
+                        );
+                    }
+                    (_, Expr::Literal(right_lit_value))
+                        if can_integer_literal_cast_to_type(
+                            right_lit_value,
+                            &left_type,
+                        ) =>
+                    {
+                        // cast the right literal to the left type
+                        return binary_expr(
+                            left,
+                            *op,
+                            cast_to_other_scalar_expr(right_lit_value, 
&left_type),
+                        );
+                    }
+                    (_, _) => {
+                        // do nothing
+                    }
+                };
+            }
+            // return the new binary op
+            binary_expr(left, *op, right)
+        }
+        // TODO: optimize in list
+        // Expr::InList { .. } => {}
+        // TODO: handle other expr type and dfs visit them
+        _ => expr,
+    }
+}
+
+fn cast_to_other_scalar_expr(origin_value: &ScalarValue, target_type: 
&DataType) -> Expr {
+    // null case
+    if origin_value.is_null() {
+        // if the origin value is null, just convert to another type of null 
value
+        // The target type must be satisfied `is_support_data_type` method, we 
can unwrap safely
+        return lit(ScalarValue::try_from(target_type).unwrap());
+    }
+    // no null case
+    let value: i64 = match origin_value {
+        ScalarValue::Int8(Some(v)) => *v as i64,
+        ScalarValue::Int16(Some(v)) => *v as i64,
+        ScalarValue::Int32(Some(v)) => *v as i64,
+        ScalarValue::Int64(Some(v)) => *v as i64,
+        other_type => {
+            panic!("Invalid type and value {:?}", other_type);

Review Comment:
   I check the data type before this method is called. 
   The data type can't reach this panic, which is guaranteed by 
`is_support_data_type`:
   
   ```
   fn is_support_data_type(data_type: &DataType) -> bool {
       // TODO support decimal with other data type
       matches!(
           data_type,
           DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
       )
   }
   ```
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@arrow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to