This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.0 by this push:
     new 3877a7a4adc [fix](Nereids) simplify decimal comparison wrong when cast to smaller scale (#41151) (#42871)
3877a7a4adc is described below

commit 3877a7a4adcf593eb29364f79b2d558bbdd149b1
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Thu Oct 31 14:08:01 2024 +0800

    [fix](Nereids) simplify decimal comparison wrong when cast to smaller scale (#41151) (#42871)
    
    pick from master #41151
---
 .../rules/expression/rules/SimplifyCastRule.java   |  11 +-
 .../rules/SimplifyComparisonPredicate.java         |  26 +++--
 .../rules/SimplifyDecimalV3Comparison.java         |  26 +++--
 .../expressions/literal/DecimalV3Literal.java      |   8 +-
 .../org/apache/doris/nereids/types/DataType.java   |   2 +-
 .../expression/rules/SimplifyCastRuleTest.java     |  51 ++++----
 .../rules/SimplifyComparisonPredicateTest.java     | 128 +++++++++++++++++++++
 .../rules/SimplifyDecimalV3ComparisonTest.java     |  47 +++++---
 .../test_simplify_decimal_comparison.groovy        |  28 +++++
 9 files changed, 253 insertions(+), 74 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java
index 34143043a07..2f23412c3fc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java
@@ -38,6 +38,7 @@ import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.nereids.types.VarcharType;
 
 import java.math.BigDecimal;
+import java.math.RoundingMode;
 
 /**
  * Rewrite rule of simplify CAST expression.
@@ -107,8 +108,14 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule {
                         return new DecimalV3Literal(decimalV3Type,
                                 new BigDecimal(((BigIntLiteral) 
child).getValue()));
                     } else if (child instanceof DecimalV3Literal) {
-                        return new DecimalV3Literal(decimalV3Type,
-                                ((DecimalV3Literal) child).getValue());
+                        DecimalV3Type childType = (DecimalV3Type) 
child.getDataType();
+                        if (childType.getRange() <= decimalV3Type.getRange()) {
+                            return new DecimalV3Literal(decimalV3Type,
+                                    ((DecimalV3Literal) child).getValue()
+                                            
.setScale(decimalV3Type.getScale(), RoundingMode.HALF_UP));
+                        } else {
+                            return cast;
+                        }
                     }
                 }
             } catch (Throwable t) {
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
index 9f719c73772..488f7bddfc6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java
@@ -221,9 +221,10 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
                     int toScale = ((DecimalV3Type) 
left.getDataType()).getScale();
                     if (comparisonPredicate instanceof EqualTo) {
                         try {
-                            return comparisonPredicate.withChildren(left,
-                                    new DecimalV3Literal((DecimalV3Type) 
left.getDataType(),
-                                            
literal.getValue().setScale(toScale)));
+                            Expression decimal = new 
DecimalV3Literal((DecimalV3Type) left.getDataType(),
+                                    literal.getValue().setScale(toScale, 
RoundingMode.UNNECESSARY));
+                            return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
+                                    comparisonPredicate.withChildren(left, 
decimal), left, decimal);
                         } catch (ArithmeticException e) {
                             if (left.nullable()) {
                                 // TODO: the ideal way is to return an If expr 
like:
@@ -240,24 +241,27 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule {
                         }
                     } else if (comparisonPredicate instanceof NullSafeEqual) {
                         try {
-                            return comparisonPredicate.withChildren(left,
-                                    new DecimalV3Literal((DecimalV3Type) 
left.getDataType(),
-                                            
literal.getValue().setScale(toScale)));
+                            Expression decimal = new 
DecimalV3Literal((DecimalV3Type) left.getDataType(),
+                                    literal.getValue().setScale(toScale, 
RoundingMode.UNNECESSARY));
+                            return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
+                                    comparisonPredicate.withChildren(left, 
decimal), left, decimal);
                         } catch (ArithmeticException e) {
                             return BooleanLiteral.of(false);
                         }
                     } else if (comparisonPredicate instanceof GreaterThan
                             || comparisonPredicate instanceof LessThanEqual) {
-                        return comparisonPredicate.withChildren(left, 
literal.roundFloor(toScale));
+                        literal = literal.roundFloor(toScale);
+                        return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
+                                comparisonPredicate.withChildren(left, 
literal), left, literal);
                     } else if (comparisonPredicate instanceof LessThan
                             || comparisonPredicate instanceof 
GreaterThanEqual) {
-                        return comparisonPredicate.withChildren(left,
-                                literal.roundCeiling(toScale));
+                        literal = literal.roundCeiling(toScale);
+                        return 
TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
+                                comparisonPredicate.withChildren(left, 
literal), left, literal);
                     }
                 }
             } else if (left.getDataType().isIntegerLikeType()) {
-                return 
processIntegerDecimalLiteralComparison(comparisonPredicate, left,
-                        literal.getValue());
+                return 
processIntegerDecimalLiteralComparison(comparisonPredicate, left, 
literal.getValue());
             }
         }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
index b821d7a4d19..98a6a9112f8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java
@@ -17,6 +17,7 @@
 
 package org.apache.doris.nereids.rules.expression.rules;
 
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
 import org.apache.doris.nereids.trees.expressions.Cast;
@@ -25,8 +26,6 @@ import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.types.DecimalV3Type;
 
-import com.google.common.base.Preconditions;
-
 import java.math.BigDecimal;
 import java.math.RoundingMode;
 
@@ -50,15 +49,17 @@ public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
         if (left.getDataType() instanceof DecimalV3Type
                 && left instanceof Cast
                 && ((Cast) left).child().getDataType() instanceof DecimalV3Type
+                && ((DecimalV3Type) left.getDataType()).getScale()
+                >= ((DecimalV3Type) ((Cast) 
left).child().getDataType()).getScale()
                 && right instanceof DecimalV3Literal) {
-            return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
+            try {
+                return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
+            } catch (ArithmeticException e) {
+                return cp;
+            }
         }
 
-        if (left != cp.left() || right != cp.right()) {
-            return cp.withChildren(left, right);
-        } else {
-            return cp;
-        }
+        return cp;
     }
 
     private Expression doProcess(ComparisonPredicate cp, Cast left, 
DecimalV3Literal right) {
@@ -72,13 +73,16 @@ public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
         }
 
         Expression castChild = left.child();
-        Preconditions.checkState(castChild.getDataType() instanceof 
DecimalV3Type);
+        if (!(castChild.getDataType() instanceof DecimalV3Type)) {
+            throw new AnalysisException("cast child's type should be 
DecimalV3Type, but its type is "
+                    + castChild.getDataType().toSql());
+        }
         DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
-        if (scale <= leftType.getScale() && precision - scale <= 
leftType.getPrecision() - leftType.getScale()) {
+        if (scale <= leftType.getScale() && precision - scale <= 
leftType.getRange()) {
             // precision and scale of literal all smaller than left, we don't 
need the cast
             DecimalV3Literal newRight = new DecimalV3Literal(
                     DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), 
leftType.getScale()),
-                    trailingZerosValue);
+                    trailingZerosValue.setScale(leftType.getScale(), 
RoundingMode.UNNECESSARY));
             return cp.withChildren(castChild, newRight);
         } else {
             return cp;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
index c797e93cb6d..d80dd7a4cc3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java
@@ -73,15 +73,11 @@ public class DecimalV3Literal extends FractionalLiteral {
     }
 
     public DecimalV3Literal roundCeiling(int newScale) {
-        return new DecimalV3Literal(DecimalV3Type
-                .createDecimalV3Type(((DecimalV3Type) 
dataType).getPrecision(), newScale),
-                value.setScale(newScale, RoundingMode.CEILING));
+        return new DecimalV3Literal(value.setScale(newScale, 
RoundingMode.CEILING));
     }
 
     public DecimalV3Literal roundFloor(int newScale) {
-        return new DecimalV3Literal(DecimalV3Type
-                .createDecimalV3Type(((DecimalV3Type) 
dataType).getPrecision(), newScale),
-                value.setScale(newScale, RoundingMode.FLOOR));
+        return new DecimalV3Literal(value.setScale(newScale, 
RoundingMode.FLOOR));
     }
 
     /**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
index ba5d2b70eba..0c15e39bc44 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java
@@ -174,7 +174,7 @@ public abstract class DataType implements AbstractDataType {
             case "decimalv3":
                 switch (types.size()) {
                     case 1:
-                        return DecimalV3Type.CATALOG_DEFAULT;
+                        return DecimalV3Type.createDecimalV3Type(38, 9);
                     case 2:
                         return 
DecimalV3Type.createDecimalV3Type(Integer.parseInt(types.get(1)));
                     case 3:
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java
index 658775cedad..4799f70fbcc 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRuleTest.java
@@ -19,42 +19,45 @@ package org.apache.doris.nereids.rules.expression.rules;
 
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
 import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
+import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.StringType;
 import org.apache.doris.nereids.types.VarcharType;
 
 import com.google.common.collect.ImmutableList;
-import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
+
 class SimplifyCastRuleTest extends ExpressionRewriteTestHelper {
 
     @Test
     public void testSimplify() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyCastRule.INSTANCE));
-        assertRewriteAfterSimplify("CAST('1' AS STRING)", "'1'",
-                StringType.INSTANCE);
-        assertRewriteAfterSimplify("CAST('1' AS VARCHAR)", "'1'",
-                VarcharType.createVarcharType(-1));
-        assertRewriteAfterSimplify("CAST(1 AS DECIMAL)", "1.000000000",
-                DecimalV3Type.createDecimalV3Type(38, 9));
-        assertRewriteAfterSimplify("CAST(1000 AS DECIMAL)", "1000.000000000",
-                DecimalV3Type.createDecimalV3Type(38, 9));
-        assertRewriteAfterSimplify("CAST(1 AS DECIMALV3)", "1",
-                DecimalV3Type.createDecimalV3Type(9, 0));
-        assertRewriteAfterSimplify("CAST(1000 AS DECIMALV3)", "1000",
-                DecimalV3Type.createDecimalV3Type(9, 0));
+        assertRewrite(new Cast(new VarcharLiteral("1"), StringType.INSTANCE),
+                new StringLiteral("1"));
+        assertRewrite(new Cast(new VarcharLiteral("1"), 
VarcharType.SYSTEM_DEFAULT),
+                new VarcharLiteral("1", -1));
+        assertRewrite(new Cast(new TinyIntLiteral((byte) 1), 
DecimalV3Type.SYSTEM_DEFAULT),
+                new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new 
BigDecimal("1.000000000")));
+        assertRewrite(new Cast(new SmallIntLiteral((short) 1000), 
DecimalV3Type.SYSTEM_DEFAULT),
+                new DecimalV3Literal(DecimalV3Type.SYSTEM_DEFAULT, new 
BigDecimal("1000.000000000")));
+        assertRewrite(new Cast(new VarcharLiteral("1"), 
VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));
+        assertRewrite(new Cast(new VarcharLiteral("1"), 
VarcharType.SYSTEM_DEFAULT), new VarcharLiteral("1", -1));
+
+        Expression decimalV3Literal = new 
DecimalV3Literal(DecimalV3Type.createDecimalV3Type(5, 3),
+                new BigDecimal("12.000"));
+        assertRewrite(new Cast(decimalV3Literal, 
DecimalV3Type.createDecimalV3Type(7, 3)),
+                new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 3),
+                        new BigDecimal("12.000")));
+        assertRewrite(new Cast(decimalV3Literal, 
DecimalV3Type.createDecimalV3Type(3, 1)),
+                new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(3, 1),
+                        new BigDecimal("12.0")));
     }
-
-    private void assertRewriteAfterSimplify(String expr, String expected, 
DataType expectedType) {
-        Expression needRewriteExpression = PARSER.parseExpression(expr);
-        Expression rewritten = 
SimplifyCastRule.INSTANCE.rewrite(needRewriteExpression, context);
-        Expression expectedExpression = PARSER.parseExpression(expected);
-        Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
-        Assertions.assertEquals(expectedType, rewritten.getDataType());
-
-    }
-
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java
index 224fa652386..122e0b444e7 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java
@@ -19,25 +19,37 @@ package org.apache.doris.nereids.rules.expression.rules;
 
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
 import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.And;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
+import org.apache.doris.nereids.trees.expressions.IsNull;
 import org.apache.doris.nereids.trees.expressions.LessThan;
+import org.apache.doris.nereids.trees.expressions.LessThanEqual;
+import org.apache.doris.nereids.trees.expressions.NullSafeEqual;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
+import org.apache.doris.nereids.types.BooleanType;
 import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DecimalV3Type;
 import org.apache.doris.nereids.types.DoubleType;
 
 import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.math.BigDecimal;
+
 class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
     @Test
     void testSimplifyComparisonPredicateRule() {
@@ -137,4 +149,120 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
         Assertions.assertEquals(left.child(0).getDataType(), 
rewrittenExpression.child(1).getDataType());
         Assertions.assertEquals(rewrittenExpression.child(0).getDataType(), 
rewrittenExpression.child(1).getDataType());
     }
+
+    @Test
+    void testDecimalV3Literal() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyComparisonPredicate.INSTANCE));
+
+        // should not simplify
+        Expression leftChild = new DecimalV3Literal(new BigDecimal("1.24"));
+        Expression left = new Cast(leftChild, 
DecimalV3Type.createDecimalV3Type(2, 1));
+        Expression right = new DecimalV3Literal(new BigDecimal("1.2"));
+        Expression expression = new EqualTo(left, right);
+        Expression rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(2, 1),
+                rewrittenExpression.child(0).getDataType());
+
+        // = round UNNECESSARY
+        leftChild = new DecimalV3Literal(new BigDecimal("11.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.340"));
+        expression = new EqualTo(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(1).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+
+        // = always not equals not null
+        leftChild = new DecimalV3Literal(new BigDecimal("11.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new EqualTo(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertEquals(BooleanLiteral.FALSE, rewrittenExpression);
+
+        // = always not equals nullable
+        leftChild = new SlotReference("slot", 
DecimalV3Type.createDecimalV3Type(4, 2), true);
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new EqualTo(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertEquals(new And(new IsNull(leftChild), new 
NullLiteral(BooleanType.INSTANCE)),
+                rewrittenExpression);
+
+        // <=> round UNNECESSARY
+        leftChild = new DecimalV3Literal(new BigDecimal("11.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.340"));
+        expression = new NullSafeEqual(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(1).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+
+        // <=> always not equals
+        leftChild = new DecimalV3Literal(new BigDecimal("11.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new NullSafeEqual(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertEquals(BooleanLiteral.FALSE, rewrittenExpression);
+
+        // > right literal should round floor
+        leftChild = new DecimalV3Literal(new BigDecimal("1.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new GreaterThan(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+
+        // <= right literal should round floor
+        leftChild = new DecimalV3Literal(new BigDecimal("1.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new LessThanEqual(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.34"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+
+        // >= right literal should round ceiling
+        leftChild = new DecimalV3Literal(new BigDecimal("1.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new GreaterThanEqual(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+
+        // < right literal should round ceiling
+        leftChild = new DecimalV3Literal(new BigDecimal("1.24"));
+        left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(5, 3));
+        right = new DecimalV3Literal(new BigDecimal("12.345"));
+        expression = new LessThan(left, right);
+        rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(4, 2),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) 
rewrittenExpression.child(1)).getValue());
+    }
 }
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java
index ff424e49711..edbbd872bac 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3ComparisonTest.java
@@ -17,40 +17,49 @@
 
 package org.apache.doris.nereids.rules.expression.rules;
 
-import org.apache.doris.common.Config;
 import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
 import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.Cast;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.expressions.Slot;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
 import org.apache.doris.nereids.types.DecimalV3Type;
 
 import com.google.common.collect.ImmutableList;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
-import java.util.HashMap;
-import java.util.Map;
+import java.math.BigDecimal;
 
 class SimplifyDecimalV3ComparisonTest extends ExpressionRewriteTestHelper {
 
     @Test
-    public void testSimplifyDecimalV3Comparison() {
-        Config.enable_decimal_conversion = false;
-        Map<String, Slot> nameToSlot = new HashMap<>();
-        nameToSlot.put("col1", new SlotReference("col1", 
DecimalV3Type.createDecimalV3Type(15, 2)));
+    void testChildScaleLargerThanCast() {
         executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
-        assertRewriteAfterSimplify("cast(col1 as decimalv3(27, 9)) > 0.6", 
"cast(col1 as decimalv3(27, 9)) > 0.6", nameToSlot);
+        Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456"));
+        Expression left = new Cast(leftChild, 
DecimalV3Type.createDecimalV3Type(3, 2));
+        Expression right = new DecimalV3Literal(new BigDecimal("1.20"));
+        Expression expression = new EqualTo(left, right);
+        Expression rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(3, 2),
+                rewrittenExpression.child(0).getDataType());
     }
 
-    private void assertRewriteAfterSimplify(String expr, String expected, 
Map<String, Slot> slotNameToSlot) {
-        Expression needRewriteExpression = PARSER.parseExpression(expr);
-        if (slotNameToSlot != null) {
-            needRewriteExpression = replaceUnboundSlot(needRewriteExpression, 
slotNameToSlot);
-        }
-        Expression rewritten = 
SimplifyDecimalV3Comparison.INSTANCE.rewrite(needRewriteExpression, context);
-        Expression expectedExpression = PARSER.parseExpression(expected);
-        Assertions.assertEquals(expectedExpression.toSql(), rewritten.toSql());
-    }
+    @Test
+    void testChildScaleSmallerThanCast() {
+        executor = new 
ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE));
 
+        Expression leftChild = new DecimalV3Literal(new BigDecimal("1.23456"));
+        Expression left = new Cast(leftChild, 
DecimalV3Type.createDecimalV3Type(10, 9));
+        Expression right = new DecimalV3Literal(new BigDecimal("1.200000000"));
+        Expression expression = new EqualTo(left, right);
+        Expression rewrittenExpression = executor.rewrite(expression, context);
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(0));
+        Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(6, 5),
+                rewrittenExpression.child(0).getDataType());
+        Assertions.assertInstanceOf(DecimalV3Literal.class, 
rewrittenExpression.child(1));
+        Assertions.assertEquals(new BigDecimal("1.20000"),
+                ((DecimalV3Literal) rewrittenExpression.child(1)).getValue());
+    }
 }
diff --git a/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy b/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy
new file mode 100644
index 00000000000..103a66836c6
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/expression/test_simplify_decimal_comparison.groovy
@@ -0,0 +1,28 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+suite("test_simplify_decimal_comparison") {
+    test {
+        sql """SELECT 1 FROM DUAL WHERE CAST(2.2222 AS DECIMAL(26, 2)) = 
2.22"""
+        result ([[1]])
+    }
+
+    test {
+        sql """ SELECT 1 FROM DUAL WHERE CAST(2.2222 AS DECIMAL(26, 2)) != 
2.22 """
+        result ([])
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to