This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 4af3fd2a2e29696d0350d834b29a528c7a5a40ad
Author: minghong <[email protected]>
AuthorDate: Tue Jan 23 19:07:28 2024 +0800

    [fix](Nereids) fix bug in case-when/if stats estimation (#30265)
---
 .../doris/nereids/stats/ExpressionEstimation.java  | 68 ++++++++++++++++------
 .../nereids/stats/ExpressionEstimationTest.java    | 63 ++++++++++++++++++++
 2 files changed, 114 insertions(+), 17 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
index acf6bddfe57..874f47a50af 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java
@@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.Subtract;
 import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -137,21 +138,37 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
     //TODO: case-when need to re-implemented
     @Override
     public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) {
+        double ndv = caseWhen.getWhenClauses().size();
+        if (caseWhen.getDefaultValue().isPresent()) {
+            ndv += 1;
+        }
+        for (WhenClause clause : caseWhen.getWhenClauses()) {
+            ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context);
+            ndv = Math.max(ndv, colStats.ndv);
+        }
+        if (caseWhen.getDefaultValue().isPresent()) {
+            ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context);
+            ndv = Math.max(ndv, colStats.ndv);
+        }
         return new ColumnStatisticBuilder()
-                .setNdv(caseWhen.getWhenClauses().size() + 1)
-                .setMinValue(0)
-                .setMaxValue(Double.MAX_VALUE)
+                .setNdv(ndv)
+                .setMinValue(Double.NEGATIVE_INFINITY)
+                .setMaxValue(Double.POSITIVE_INFINITY)
                 .setAvgSizeByte(8)
                 .setNumNulls(0)
                 .build();
     }
 
     @Override
-    public ColumnStatistic visitIf(If function, Statistics context) {
-        // TODO: copy from visitCaseWhen, polish them.
+    public ColumnStatistic visitIf(If ifClause, Statistics context) {
+        double ndv = 2;
+        ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context);
+        ndv = Math.max(ndv, colStatsThen.ndv);
+        ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context);
+        ndv = Math.max(ndv, colStatsElse.ndv);
         return new ColumnStatisticBuilder()
-                .setNdv(2)
-                .setMinValue(0)
+                .setNdv(ndv)
+                .setMinValue(Double.NEGATIVE_INFINITY)
                 .setMaxValue(Double.POSITIVE_INFINITY)
                 .setAvgSizeByte(8)
                 .setNumNulls(0)
@@ -577,13 +594,22 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
         if (childColumnStats.minOrMaxIsInf()) {
             return columnStatisticBuilder.build();
         }
-        double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate()
-                .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
-        double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate()
-                .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+        double minValue;
+        double maxValue;
+        try {
+            // min/max values are finite here, but they may be too large to convert to a date
+            minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate()
+                    .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+            maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate()
+                    .atStartOfDay(ZoneId.systemDefault()).toEpochSecond();
+        } catch (Exception e) {
+            // ignore DateTimeException
+            minValue = Double.NEGATIVE_INFINITY;
+            maxValue = Double.POSITIVE_INFINITY;
+        }
         return columnStatisticBuilder.setMaxValue(maxValue)
-                .setMinValue(minValue)
-                .build();
+                .setMinValue(minValue).build();
+
     }
 
     private LocalDateTime getDatetimeFromLong(long dateTime) {
@@ -599,10 +625,18 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta
         if (childColumnStats.minOrMaxIsInf()) {
             return columnStatisticBuilder.build();
         }
-        double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay()
-                + (double) DAYS_FROM_0_TO_1970;
-        double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay()
-                + (double) DAYS_FROM_0_TO_1970;
+        double minValue;
+        double maxValue;
+        try {
+            minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay()
+                    + (double) DAYS_FROM_0_TO_1970;
+            maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay()
+                    + (double) DAYS_FROM_0_TO_1970;
+        } catch (Exception e) {
+            // ignore DateTimeException
+            minValue = Double.NEGATIVE_INFINITY;
+            maxValue = Double.POSITIVE_INFINITY;
+        }
         return columnStatisticBuilder.setMaxValue(maxValue)
                 .setMinValue(minValue)
                 .build();
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
index 7368735365c..1748802e4dd 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java
@@ -20,14 +20,18 @@ package org.apache.doris.nereids.stats;
 import org.apache.doris.analysis.DateLiteral;
 import org.apache.doris.analysis.StringLiteral;
 import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.CaseWhen;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.Divide;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.Multiply;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.WhenClause;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.types.DateType;
 import org.apache.doris.nereids.types.DoubleType;
 import org.apache.doris.nereids.types.IntegerType;
@@ -40,7 +44,9 @@ import org.apache.commons.math3.util.Precision;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
 
+import java.util.ArrayList;
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
 
 class ExpressionEstimationTest {
@@ -321,4 +327,61 @@ class ExpressionEstimationTest {
         Assertions.assertNull(est.minExpr);
         Assertions.assertNull(est.maxExpr);
     }
+
+    @Test
+    public void testCaseWhen() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        SlotReference b = new SlotReference("b", StringType.INSTANCE);
+        builder = new ColumnStatisticBuilder()
+                .setNdv(10)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(b, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+
+        WhenClause when1 = new WhenClause(BooleanLiteral.TRUE, a);
+        WhenClause when2 = new WhenClause(BooleanLiteral.FALSE, b);
+        List<WhenClause> whens = new ArrayList<>();
+        whens.add(when1);
+        whens.add(when2);
+        CaseWhen caseWhen = new CaseWhen(whens);
+        ColumnStatistic est = ExpressionEstimation.estimate(caseWhen, stats);
+        Assertions.assertEquals(est.ndv, 100);
+    }
+
+    @Test
+    public void testIf() {
+        SlotReference a = new SlotReference("a", StringType.INSTANCE);
+        Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>();
+        ColumnStatisticBuilder builder = new ColumnStatisticBuilder()
+                .setNdv(100)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(a, builder.build());
+        SlotReference b = new SlotReference("b", StringType.INSTANCE);
+        builder = new ColumnStatisticBuilder()
+                .setNdv(10)
+                .setMinExpr(new StringLiteral("2020-01-01"))
+                .setMinValue(20200101000000.0)
+                .setMaxExpr(new StringLiteral("2021abcdefg"))
+                .setMaxValue(20210101000000.0);
+        slotToColumnStat.put(b, builder.build());
+        Statistics stats = new Statistics(1000, slotToColumnStat);
+
+        If ifClause = new If(BooleanLiteral.TRUE, a, b);
+        ColumnStatistic est = ExpressionEstimation.estimate(ifClause, stats);
+        Assertions.assertEquals(est.ndv, 100);
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to