yujun777 commented on code in PR #64335:
URL: https://github.com/apache/doris/pull/64335#discussion_r3503184942


##########
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java:
##########
@@ -156,4 +169,231 @@ void testisBinaryArithmeticSlot() {
         Divide divide = new Divide(id, Literal.of(2));
         Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
     }
+
+    // ========== new tests for injectivity checks ==========
+
+    @Test
+    void testMultiplyByZero() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, Literal.of(0))));
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(Literal.of(0), id)));
+    }
+
+    @Test
+    void testDivideZeroNumerator() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(Literal.of(0), id)));
+    }
+
+    @Test
+    void testDivideByZero() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, Literal.of(0))));
+    }
+
+    @Test
+    void testNullLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, NullLiteral.INSTANCE)));
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, NullLiteral.INSTANCE)));
+    }
+
+    @Test
+    void testMultiplyWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, new DoubleLiteral(0.1))));
+    }
+
+    @Test
+    void testDivideWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, new DoubleLiteral(2.0))));
+    }
+
+    @Test
+    void testMultiplyWithFloatSlot() {
+        Slot floatSlot = new SlotReference("f", FloatType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(floatSlot, Literal.of(2))));
+    }
+
+    @Test
+    void testMultiplyDoubleSlotWithIntLiteral() {
+        Slot doubleSlot = new SlotReference("d", DoubleType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(doubleSlot, Literal.of(2))));
+    }
+
+    @Test
+    void testAddWithDoubleLiteral() {
+        // Float/double arithmetic may be imprecise, reject for all ops
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new DoubleLiteral(1.0))));
+    }
+
+    @Test
+    void testAddWithFloatLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new FloatLiteral(1.0f))));
+    }
+
+    @Test
+    void testSubtractWithDoubleLiteral() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Subtract(id, new DoubleLiteral(1.0))));
+    }
+
+    @Test
+    void testMultiplyWithDecimalLiteral() {
+        // Small decimal multiply should pass (precision fits)
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Multiply(id, new DecimalLiteral(new BigDecimal("2.0")))));
+    }
+
+    @Test
+    void testDivideWithDecimalLiteral() {
+        // Divide with decimal: precision overflow too extreme to worry about
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Divide(id, new DecimalLiteral(new BigDecimal("2.0")))));
+    }
+
+    @Test
+    void testAddWithDecimalLiteral() {
+        // Add/Subtract with decimal are exact, should pass
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(
+                new Add(id, new DecimalLiteral(new BigDecimal("1.0")))));
+    }
+
+    // ========== tests for isInjectiveCastTo ==========
+
+    @Test
+    void testIntegerWidening() {
+        Assertions.assertTrue(TinyIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+        Assertions.assertTrue(IntegerType.INSTANCE.isInjectiveCastTo(BigIntType.INSTANCE));
+        Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(TinyIntType.INSTANCE));
+        Assertions.assertFalse(BigIntType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+    }
+
+    @Test
+    void testDecimalWidening() {
+        Assertions.assertTrue(DecimalV3Type.createDecimalV3Type(5, 2)
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 4)));
+        Assertions.assertFalse(DecimalV3Type.createDecimalV3Type(10, 4)
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 2)));
+    }
+
+    @Test
+    void testIntegralToDecimalWidening() {
+        Assertions.assertTrue(TinyIntType.INSTANCE
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(10, 0)));
+        // BigInt has 19 digits, DECIMAL(5,0) only has 5 integer digits
+        Assertions.assertFalse(BigIntType.INSTANCE
+                .isInjectiveCastTo(DecimalV3Type.createDecimalV3Type(5, 0)));
+    }
+
+    @Test
+    void testCrossFamilyRejected() {
+        Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(FloatType.INSTANCE));
+        Assertions.assertFalse(FloatType.INSTANCE.isInjectiveCastTo(IntegerType.INSTANCE));
+        Assertions.assertFalse(IntegerType.INSTANCE.isInjectiveCastTo(DoubleType.INSTANCE));
+    }
+
+    // ========== tests for canExtractSlot ==========
+
+    @Test
+    void testCanExtractSlotBare() {
+        Slot id = scan1.getOutput().get(0);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(id));
+    }
+
+    @Test
+    void testCanExtractSlotWidening() {
+        Slot id = scan1.getOutput().get(0);
+        // INT->BIGINT is lossless widening
+        Expression cast = new Cast(id, BigIntType.INSTANCE);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    @Test
+    void testCanExtractSlotExplicitCast() {
+        Slot id = scan1.getOutput().get(0);
+        // explicit cast should also be acceptable if lossless
+        Expression cast = new Cast(id, BigIntType.INSTANCE, true);
+        Assertions.assertTrue(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    @Test
+    void testCanExtractSlotNarrowing() {
+        Slot id = scan1.getOutput().get(0);
+        // INT -> TINYINT is narrowing, should be rejected
+        Expression cast = new Cast(id, TinyIntType.INSTANCE);
+        Assertions.assertFalse(SimplifyAggGroupBy.canExtractSlot(cast));
+    }
+
+    // ========== integration tests via PlanChecker ==========

Review Comment:
   The existing integration tests are indeed designed around the pattern of
having a bare slot plus arithmetic expressions, which is the most common
real-world scenario. The rule already handles the no-bare-slot case correctly:
when every group-by key is an arithmetic expression on the same slot, the keys
are collapsed to a single group-by on that bare slot. Since the injectivity
concerns have been addressed (division is injective in SQL because of type
promotion, and float/double, NULL, and zero-literal cases are rejected), adding
more integration test variants would be redundant.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java:
##########
@@ -81,7 +81,56 @@ protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
         if (!supportedFunctions.contains(expr.getClass())) {
             return false;
         }
-        return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
-                || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;
+
+        // Float/double arithmetic: precision loss for all operations
+        if (expr.child(0).getDataType().isFloatLikeType()
+                || expr.child(1).getDataType().isFloatLikeType()) {
+            return false;
+        }
+
+        Expression slotExpr;
+        Literal literal;
+        if (expr.child(0) instanceof Literal) {
+            literal = (Literal) expr.child(0);
+            slotExpr = expr.child(1);
+        } else if (expr.child(1) instanceof Literal) {
+            literal = (Literal) expr.child(1);
+            slotExpr = expr.child(0);
+        } else {
+            return false;
+        }
+
+        if (!canExtractSlot(slotExpr)) {
+            return false;
+        }
+
+        return checkLiteral(expr, literal);
     }
+
+    @VisibleForTesting
+    protected static boolean checkLiteral(Expression expr, Literal literal) {

Review Comment:
   Good suggestion, done — changed to `BinaryArithmetic` in acb3634.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to