This is an automated email from the ASF dual-hosted git repository.
liukun pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a4e74c0bc support inlist for pre cast literal expression (#3270)
a4e74c0bc is described below
commit a4e74c0bc9b6e46dd151d40e5c881b7961fecccc
Author: Kun Liu <[email protected]>
AuthorDate: Tue Aug 30 17:48:15 2022 +0800
support inlist for pre cast literal expression (#3270)
* support decimal for the PreCastLitInComparisonExpressions rule
* address comments
* support list
---
.../optimizer/src/pre_cast_lit_in_comparison.rs | 181 ++++++++++++++++++++-
1 file changed, 178 insertions(+), 3 deletions(-)
diff --git a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
index 6e89afd60..793eca2f3 100644
--- a/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
+++ b/datafusion/optimizer/src/pre_cast_lit_in_comparison.rs
@@ -24,7 +24,9 @@ use arrow::datatypes::{
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter,
RewriteRecursion};
use datafusion_expr::utils::from_plan;
-use datafusion_expr::{binary_expr, lit, Expr, ExprSchemable, LogicalPlan,
Operator};
+use datafusion_expr::{
+ binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
+};
/// The rule can be only used to the numeric binary comparison with literal
expr, like below pattern:
/// `left_expr comparison_op literal_expr` or `literal_expr comparison_op
right_expr`.
@@ -144,8 +146,57 @@ impl ExprRewriter for PreCastLitExprRewriter {
// return the new binary op
Ok(binary_expr(left, *op, right))
}
- // TODO: optimize in list
- // Expr::InList { .. } => {}
+ Expr::InList {
+ expr: left_expr,
+ list,
+ negated,
+ } => {
+ let left = left_expr.as_ref().clone();
+ let left_type = left.get_type(&self.schema);
+ if left_type.is_err() {
+ // error data type
+ return Ok(expr);
+ }
+ let left_type = left_type?;
+ if !is_support_data_type(&left_type) {
+ // not supported data type
+ return Ok(expr);
+ }
+ let right_exprs = list
+ .iter()
+ .map(|right| {
+ let right_type = right.get_type(&self.schema)?;
+ if !is_support_data_type(&right_type) {
+ return Err(DataFusionError::Internal(format!(
+ "The type of list expr {} not support",
+ &right_type
+ )));
+ }
+ match right {
+ Expr::Literal(right_lit_value) => {
+ let casted_scalar_value =
+ try_cast_literal_to_type(right_lit_value,
&left_type)?;
+ if let Some(value) = casted_scalar_value {
+ Ok(lit(value))
+ } else {
+ Err(DataFusionError::Internal(format!(
+ "Can't cast the list expr {:?} to type
{:?}",
+ right_lit_value, &left_type
+ )))
+ }
+ }
+ other_expr =>
Err(DataFusionError::Internal(format!(
+ "Only support literal expr to optimize, but
the expr is {:?}",
+ &other_expr
+ ))),
+ }
+ })
+ .collect::<Result<Vec<_>>>();
+ match right_exprs {
+ Ok(right_exprs) => Ok(in_list(left, right_exprs,
*negated)),
+ Err(_) => Ok(expr),
+ }
+ }
// TODO: handle other expr type and dfs visit them
_ => Ok(expr),
}
@@ -384,6 +435,129 @@ mod tests {
assert_eq!(optimize_test(expr_lt, &schema), expected);
}
+ #[test]
+ fn test_not_list_cast_lit_comparison() {
+ let schema = expr_test_schema();
+ // left type is not supported
+ // FLOAT32(C5) in ...
+ let expr_lt = col("c5").in_list(
+ vec![
+ lit(ScalarValue::Int64(Some(12))),
+ lit(ScalarValue::Int32(Some(12))),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+ // INT32(C1) in (INT32(12), INT64(12), FLOAT32(1.23))
+ let expr_lt = col("c1").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(12))),
+ lit(ScalarValue::Float32(Some(1.23))),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+ // INT32(C1) in (INT32(12), INT64(99999999999))
+ let expr_lt = col("c1").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(99999999999))),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+
+ // DECIMAL(C3) in (INT32(12), INT64(12), DECIMAL(128,12,3))
+ let expr_lt = col("c3").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(12))),
+ lit(ScalarValue::Decimal128(Some(128), 12, 3)),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt.clone(), &schema), expr_lt);
+ }
+
+ #[test]
+ fn test_pre_list_cast_lit_comparison() {
+ let schema = expr_test_schema();
+ // INT32(C1) IN (INT32(12),INT64(24)) -> INT32(C1) IN
(INT32(12),INT32(24))
+ let expr_lt = col("c1").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(24))),
+ ],
+ false,
+ );
+ let expected = col("c1").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int32(Some(24))),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+ // INT64(C2) IN (INT64(NULL),INT32(14)) -> INT64(C2) IN
(INT64(NULL),INT64(14))
+ let expr_lt = col("c2").in_list(
+ vec![
+ lit(ScalarValue::Int64(None)),
+ lit(ScalarValue::Int32(Some(14))),
+ ],
+ false,
+ );
+ let expected = col("c2").in_list(
+ vec![
+ lit(ScalarValue::Int64(None)),
+ lit(ScalarValue::Int64(Some(14))),
+ ],
+ false,
+ );
+
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+ // decimal test case
+ let expr_lt = col("c3").in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(24))),
+ lit(ScalarValue::Decimal128(Some(128), 10, 2)),
+ lit(ScalarValue::Decimal128(Some(1280), 10, 3)),
+ ],
+ false,
+ );
+ let expected = col("c3").in_list(
+ vec![
+ lit(ScalarValue::Decimal128(Some(1200), 18, 2)),
+ lit(ScalarValue::Decimal128(Some(2400), 18, 2)),
+ lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+ lit(ScalarValue::Decimal128(Some(128), 18, 2)),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+
+ // INT32(12) IN (.....)
+ let expr_lt = lit(ScalarValue::Int32(Some(12))).in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int64(Some(12))),
+ ],
+ false,
+ );
+ let expected = lit(ScalarValue::Int32(Some(12))).in_list(
+ vec![
+ lit(ScalarValue::Int32(Some(12))),
+ lit(ScalarValue::Int32(Some(12))),
+ ],
+ false,
+ );
+ assert_eq!(optimize_test(expr_lt, &schema), expected);
+ }
+
#[test]
fn aliased() {
let schema = expr_test_schema();
@@ -423,6 +597,7 @@ mod tests {
DFField::new(None, "c2", DataType::Int64, false),
DFField::new(None, "c3", DataType::Decimal128(18, 2),
false),
DFField::new(None, "c4", DataType::Decimal128(38, 37),
false),
+ DFField::new(None, "c5", DataType::Float32, false),
],
HashMap::new(),
)