This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch 2.0.1-rc04-patch
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/2.0.1-rc04-patch by this push:
new c1aae5e984 [enhancement](nereids)remove useless cast for floatlike
type (#23621)
c1aae5e984 is described below
commit c1aae5e984990dba5494a5e79ff831122b229568
Author: starocean999 <[email protected]>
AuthorDate: Wed Aug 30 19:00:16 2023 +0800
[enhancement](nereids)remove useless cast for floatlike type (#23621)
convert cast(c1 AS double) > 2.0 to c1 > 2 (c1 is integer like type)
---
.../rules/SimplifyComparisonPredicate.java | 155 ++++++++++---
.../test_simplify_comparison.groovy | 248 +++++++++++++++++++++
2 files changed, 376 insertions(+), 27 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
index c66e27e8b2..19574f8f16 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
@@ -31,13 +31,21 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
@@ -46,9 +54,15 @@ import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.coercion.DateLikeType;
+import com.google.common.base.Preconditions;
+
+import java.math.BigDecimal;
+import java.math.RoundingMode;
+
/**
* simplify comparison
* such as: cast(c1 as DateV2) >= DateV2Literal --> c1 >= DateLiteral
+ * cast(c1 AS double) > 2.0 --> c1 > 2 (c1 is integer like type)
*/
public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule
{
@@ -65,6 +79,11 @@ public class SimplifyComparisonPredicate extends
AbstractExpressionRewriteRule {
Expression left = rewrite(cp.left(), context);
Expression right = rewrite(cp.right(), context);
+ // float like type: float, double
+ if (left.getDataType().isFloatLikeType() &&
right.getDataType().isFloatLikeType()) {
+ return processFloatLikeTypeCoercion(cp, left, right);
+ }
+
// decimalv3 type
if (left.getDataType() instanceof DecimalV3Type
&& right.getDataType() instanceof DecimalV3Type) {
@@ -194,6 +213,26 @@ public class SimplifyComparisonPredicate extends
AbstractExpressionRewriteRule {
}
}
+ /**
+ * Simplifies {@code cast(intCol AS float/double) <op> floatLiteral} into a comparison on the
+ * integer-like expression itself, e.g. {@code cast(c1 AS double) > 2.5 --> c1 > 2}.
+ *
+ * <p>NOTE(review): a bigint cast to double (or an int cast to float) can lose precision for
+ * large magnitudes, so for such values the rewritten predicate may not agree with the original.
+ *
+ * @param comparisonPredicate the predicate being rewritten
+ * @param left the already-rewritten left child
+ * @param right the already-rewritten right child
+ * @return the simplified expression, or {@code comparisonPredicate} (commuted when the literal
+ *     was on the left) when no simplification applies
+ */
+ private Expression processFloatLikeTypeCoercion(ComparisonPredicate comparisonPredicate,
+ Expression left, Expression right) {
+ // normalize so that a literal operand, if any, ends up on the right-hand side
+ if (left instanceof Literal) {
+ comparisonPredicate = comparisonPredicate.commute();
+ Expression temp = left;
+ left = right;
+ right = temp;
+ }
+
+ if (left instanceof Cast && left.child(0).getDataType().isIntegerLikeType()
+ && (right instanceof DoubleLiteral || right instanceof FloatLiteral)) {
+ BigDecimal literal;
+ try {
+ literal = new BigDecimal(((Literal) right).getStringValue());
+ } catch (NumberFormatException e) {
+ // the literal's text is not a plain number (presumably NaN / Infinity), so there is
+ // no exact integer bound to compare against: keep the predicate unchanged
+ return comparisonPredicate;
+ }
+ return processIntegerDecimalLiteralComparison(comparisonPredicate,
+ ((Cast) left).child(), literal);
+ }
+ return comparisonPredicate;
+ }
+
private Expression processDecimalV3TypeCoercion(ComparisonPredicate
comparisonPredicate,
Expression left, Expression right) {
if (left instanceof DecimalV3Literal) {
@@ -203,51 +242,113 @@ public class SimplifyComparisonPredicate extends
AbstractExpressionRewriteRule {
right = temp;
}
- if (left instanceof Cast &&
left.child(0).getDataType().isDecimalV3Type()
- && right instanceof DecimalV3Literal) {
+ if (left instanceof Cast && right instanceof DecimalV3Literal) {
Cast cast = (Cast) left;
left = cast.child();
DecimalV3Literal literal = (DecimalV3Literal) right;
- if (((DecimalV3Type) left.getDataType())
- .getScale() < ((DecimalV3Type)
literal.getDataType()).getScale()) {
- int toScale = ((DecimalV3Type) left.getDataType()).getScale();
- if (comparisonPredicate instanceof EqualTo) {
- try {
- return comparisonPredicate.withChildren(left, new
DecimalV3Literal(
- (DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
- } catch (ArithmeticException e) {
- if (left.nullable()) {
- // TODO: the ideal way is to return an If expr
like:
- // return new If(new IsNull(left), new
NullLiteral(BooleanType.INSTANCE),
- // BooleanLiteral.of(false));
- // but current fold constant rule can't handle
such complex expr with null literal
- // before supporting complex conjuncts with null
literal folding rules,
- // we use a trick way like this:
- return new And(new IsNull(left), new
NullLiteral(BooleanType.INSTANCE));
- } else {
+ if (left.getDataType().isDecimalV3Type()) {
+ if (((DecimalV3Type) left.getDataType())
+ .getScale() < ((DecimalV3Type)
literal.getDataType()).getScale()) {
+ int toScale = ((DecimalV3Type)
left.getDataType()).getScale();
+ if (comparisonPredicate instanceof EqualTo) {
+ try {
+ return comparisonPredicate.withChildren(left,
+ new DecimalV3Literal((DecimalV3Type)
left.getDataType(),
+
literal.getValue().setScale(toScale)));
+ } catch (ArithmeticException e) {
+ if (left.nullable()) {
+ // TODO: the ideal way is to return an If expr
like:
+ // return new If(new IsNull(left), new
NullLiteral(BooleanType.INSTANCE),
+ // BooleanLiteral.of(false));
+ // but current fold constant rule can't handle
such complex expr with null literal
+ // before supporting complex conjuncts with
null literal folding rules,
+ // we use a trick way like this:
+ return new And(new IsNull(left),
+ new NullLiteral(BooleanType.INSTANCE));
+ } else {
+ return BooleanLiteral.of(false);
+ }
+ }
+ } else if (comparisonPredicate instanceof NullSafeEqual) {
+ try {
+ return comparisonPredicate.withChildren(left,
+ new DecimalV3Literal((DecimalV3Type)
left.getDataType(),
+
literal.getValue().setScale(toScale)));
+ } catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}
+ } else if (comparisonPredicate instanceof GreaterThan
+ || comparisonPredicate instanceof LessThanEqual) {
+ return comparisonPredicate.withChildren(left,
literal.roundFloor(toScale));
+ } else if (comparisonPredicate instanceof LessThan
+ || comparisonPredicate instanceof
GreaterThanEqual) {
+ return comparisonPredicate.withChildren(left,
+ literal.roundCeiling(toScale));
}
- } else if (comparisonPredicate instanceof NullSafeEqual) {
- try {
- return comparisonPredicate.withChildren(left, new
DecimalV3Literal(
- (DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale)));
- } catch (ArithmeticException e) {
+ }
+ } else if (left.getDataType().isIntegerLikeType()) {
+ return
processIntegerDecimalLiteralComparison(comparisonPredicate, left,
+ literal.getValue());
+ }
+ }
+
+ return comparisonPredicate;
+ }
+
+ /**
+ * Rewrites a comparison between an integer-like expression (tinyint, smallint, int, bigint)
+ * and an exact decimal value into a comparison against an integer-like literal.
+ *
+ * <p>A value with a genuine fractional part is rounded toward the side that keeps the
+ * predicate equivalent: floor for {@code >} and {@code <=}, ceiling for {@code <} and
+ * {@code >=}. {@code =} against such a value folds to false (or to NULL for a nullable
+ * operand) and {@code <=>} folds to false. Values outside the {@code long} range are left alone.
+ *
+ * @param comparisonPredicate the predicate being rewritten, with the cast operand on the left
+ * @param left the integer-like expression (the caller has already removed the cast)
+ * @param literal the exact value of the literal operand
+ * @return the rewritten expression, or {@code comparisonPredicate} when no rewrite applies
+ */
+ private Expression processIntegerDecimalLiteralComparison(
+ ComparisonPredicate comparisonPredicate, Expression left, BigDecimal literal) {
+ // we only process isIntegerLikeType, which are tinyint, smallint, int, bigint
+ // both bounds matter: a value below Long.MIN_VALUE cannot become an integer literal either
+ if (literal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
+ && literal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0) {
+ // an integral value written with trailing zeros (1.0 from a double literal, 1.00 from a
+ // decimal) must behave like the integer itself, not like a fractional value
+ literal = literal.stripTrailingZeros();
+ if (literal.scale() > 0) {
+ if (comparisonPredicate instanceof EqualTo) {
+ if (left.nullable()) {
+ // TODO: the ideal way is to return an If expr like:
+ // return new If(new IsNull(left), new NullLiteral(BooleanType.INSTANCE),
+ // BooleanLiteral.of(false));
+ // but current fold constant rule can't handle such complex expr with null literal
+ // before supporting complex conjuncts with null literal folding rules,
+ // we use a trick way like this:
+ return new And(new IsNull(left), new NullLiteral(BooleanType.INSTANCE));
+ } else {
return BooleanLiteral.of(false);
}
+ } else if (comparisonPredicate instanceof NullSafeEqual) {
+ return BooleanLiteral.of(false);
} else if (comparisonPredicate instanceof GreaterThan
|| comparisonPredicate instanceof LessThanEqual) {
- return comparisonPredicate.withChildren(left,
literal.roundFloor(toScale));
+ // for integer x: x > 1.5 <=> x > 1, and x <= 1.5 <=> x <= 1
+ return comparisonPredicate.withChildren(left,
+ convertDecimalToIntegerLikeLiteral(literal.setScale(0, RoundingMode.FLOOR)));
} else if (comparisonPredicate instanceof LessThan
|| comparisonPredicate instanceof GreaterThanEqual) {
- return comparisonPredicate.withChildren(left,
literal.roundCeiling(toScale));
+ // for integer x: x < 1.5 <=> x < 2, and x >= 1.5 <=> x >= 2
+ return comparisonPredicate.withChildren(left,
+ convertDecimalToIntegerLikeLiteral(literal.setScale(0, RoundingMode.CEILING)));
}
+ } else {
+ // stripTrailingZeros may leave a negative scale (100 -> 1E+2); setScale(0) is exact here
+ // because the value is integral, and gives the converter the scale-0 value it expects
+ return comparisonPredicate.withChildren(left,
+ convertDecimalToIntegerLikeLiteral(literal.setScale(0)));
}
}
-
return comparisonPredicate;
}
+ /**
+ * Converts an integral decimal value into the narrowest integer-like literal that holds it.
+ *
+ * @param decimal an integral value (scale <= 0) within [Long.MIN_VALUE, Long.MAX_VALUE]
+ * @return a tinyint, smallint, int or bigint literal with exactly the same value
+ * @throws IllegalArgumentException if {@code decimal} has a fractional scale or is outside
+ *     the range of {@code long}
+ */
+ private IntegerLikeLiteral convertDecimalToIntegerLikeLiteral(BigDecimal decimal) {
+ Preconditions.checkArgument(
+ decimal.scale() <= 0
+ && decimal.compareTo(new BigDecimal(Long.MIN_VALUE)) >= 0
+ && decimal.compareTo(new BigDecimal(Long.MAX_VALUE)) <= 0,
+ "decimal literal must be integral and in range [Long.MIN_VALUE, Long.MAX_VALUE]");
+ long val = decimal.longValue();
+ // check both bounds: a negative value such as -1000 is <= Byte.MAX_VALUE, yet the
+ // narrowing (byte) cast would silently truncate it to a different number
+ if (val >= Byte.MIN_VALUE && val <= Byte.MAX_VALUE) {
+ return new TinyIntLiteral((byte) val);
+ } else if (val >= Short.MIN_VALUE && val <= Short.MAX_VALUE) {
+ return new SmallIntLiteral((short) val);
+ } else if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) {
+ return new IntegerLiteral((int) val);
+ } else {
+ return new BigIntLiteral(val);
+ }
+ }
+
private Expression migrateCastToDateTime(Cast cast) {
//cast( cast(v as date) as datetime) if v is datetime, set left = v
if (cast.child() instanceof Cast
diff --git
a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
index 53c0ff9a12..4b3cd3bdca 100644
--- a/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
+++ b/regression-test/suites/nereids_syntax_p0/test_simplify_comparison.groovy
@@ -72,4 +72,252 @@ suite("test_simplify_comparison") {
}
sql "select cast('1234' as decimalv3(18,4)) > 2000;"
+
+ sql 'drop table if exists simple_test_table_t;'
+ sql """CREATE TABLE IF NOT EXISTS `simple_test_table_t` (
+ a tinyint,
+ b smallint,
+ c int,
+ d bigint,
+ e largeint
+ ) ENGINE=OLAP
+ UNIQUE KEY (`a`)
+ DISTRIBUTED BY HASH(`a`) BUCKETS 120
+ PROPERTIES (
+ "replication_num" = "1",
+ "in_memory" = "false",
+ "compression" = "LZ4"
+ );"""
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a = cast(1.0 as
double) and b = cast(1.0 as double) and c = cast(1.0 as double) and d =
cast(1.0 as double);"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e = cast(1.0 as
double);"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a > cast(1.0 as
double) and b > cast(1.0 as double) and c > cast(1.0 as double) and d >
cast(1.0 as double);"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e > cast(1.0 as
double);"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a < cast(1.0 as
double) and b < cast(1.0 as double) and c < cast(1.0 as double) and d <
cast(1.0 as double);"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e < cast(1.0 as
double);"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a >= cast(1.0 as
double) and b >= cast(1.0 as double) and c >= cast(1.0 as double) and d >=
cast(1.0 as double);"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e >= cast(1.0 as
double);"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a <= cast(1.0 as
double) and b <= cast(1.0 as double) and c <= cast(1.0 as double) and d <=
cast(1.0 as double);"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e <= cast(1.0 as
double);"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a = cast(1.1 as
double) and b = cast(1.1 as double) and c = cast(1.1 as double) and d =
cast(1.1 as double);"
+ contains "a[#0] IS NULL"
+ contains "b[#1] IS NULL"
+ contains "c[#2] IS NULL"
+ contains "d[#3] IS NULL"
+ contains "AND NULL"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e = cast(1.1 as
double);"
+ contains "CAST(e[#4] AS DOUBLE) = 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a > cast(1.1 as
double) and b > cast(1.1 as double) and c > cast(1.1 as double) and d >
cast(1.1 as double);"
+ contains "a[#0] > 1"
+ contains "b[#1] > 1"
+ contains "c[#2] > 1"
+ contains "d[#3] > 1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e > cast(1.1 as
double);"
+ contains "CAST(e[#4] AS DOUBLE) > 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a < cast(1.1 as
double) and b < cast(1.1 as double) and c < cast(1.1 as double) and d <
cast(1.1 as double);"
+ contains "a[#0] < 2"
+ contains "b[#1] < 2"
+ contains "c[#2] < 2"
+ contains "d[#3] < 2"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e < cast(1.1 as
double);"
+ contains "CAST(e[#4] AS DOUBLE) < 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a >= cast(1.1 as
double) and b >= cast(1.1 as double) and c >= cast(1.1 as double) and d >=
cast(1.1 as double);"
+ contains "a[#0] >= 2"
+ contains "b[#1] >= 2"
+ contains "c[#2] >= 2"
+ contains "d[#3] >= 2"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e >= cast(1.1 as
double);"
+ contains "CAST(e[#4] AS DOUBLE) >= 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a <= cast(1.1 as
double) and b <= cast(1.1 as double) and c <= cast(1.1 as double) and d <=
cast(1.1 as double);"
+ contains "a[#0] <= 1"
+ contains "b[#1] <= 1"
+ contains "c[#2] <= 1"
+ contains "d[#3] <= 1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e <= cast(1.1 as
double);"
+ contains "CAST(e[#4] AS DOUBLE) <= 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a = 1.0 and b =
1.0 and c = 1.0 and d = 1.0;"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e = 1.0;"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a > 1.0 and b >
1.0 and c > 1.0 and d > 1.0;"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e > 1.0;"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a < 1.0 and b <
1.0 and c < 1.0 and d < 1.0;"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e < 1.0;"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a >= 1.0 and b >=
1.0 and c >= 1.0 and d >= 1.0;"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e >= 1.0;"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a <= 1.0 and b <=
1.0 and c <= 1.0 and d <= 1.0;"
+ notContains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e <= 1.0;"
+ contains "CAST"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a = 1.1 and b =
1.1 and c = 1.1 and d = 1.1;"
+ contains "a[#0] IS NULL"
+ contains "b[#1] IS NULL"
+ contains "c[#2] IS NULL"
+ contains "d[#3] IS NULL"
+ contains "AND NULL"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e = 1.1;"
+ contains "CAST(e[#4] AS DOUBLE) = 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a > 1.1 and b >
1.1 and c > 1.1 and d > 1.1;"
+ contains "a[#0] > 1"
+ contains "b[#1] > 1"
+ contains "c[#2] > 1"
+ contains "d[#3] > 1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e > 1.1;"
+ contains "CAST(e[#4] AS DOUBLE) > 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a < 1.1 and b <
1.1 and c < 1.1 and d < 1.1;"
+ contains "a[#0] < 2"
+ contains "b[#1] < 2"
+ contains "c[#2] < 2"
+ contains "d[#3] < 2"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e < 1.1;"
+ contains "CAST(e[#4] AS DOUBLE) < 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a >= 1.1 and b >=
1.1 and c >= 1.1 and d >= 1.1;"
+ contains "a[#0] >= 2"
+ contains "b[#1] >= 2"
+ contains "c[#2] >= 2"
+ contains "d[#3] >= 2"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e >= 1.1;"
+ contains "CAST(e[#4] AS DOUBLE) >= 1.1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where a <= 1.1 and b <=
1.1 and c <= 1.1 and d <= 1.1;"
+ contains "a[#0] <= 1"
+ contains "b[#1] <= 1"
+ contains "c[#2] <= 1"
+ contains "d[#3] <= 1"
+ }
+
+ explain {
+ sql "verbose select * from simple_test_table_t where e <= 1.1;"
+ contains "CAST(e[#4] AS DOUBLE) <= 1.1"
+ }
}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]