This is an automated email from the ASF dual-hosted git repository. yiguolei pushed a commit to branch branch-2.1 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 4af3fd2a2e29696d0350d834b29a528c7a5a40ad Author: minghong <[email protected]> AuthorDate: Tue Jan 23 19:07:28 2024 +0800 [fix](Nereids) fix bug in case-when/if stats estimation (#30265) --- .../doris/nereids/stats/ExpressionEstimation.java | 68 ++++++++++++++++------ .../nereids/stats/ExpressionEstimationTest.java | 63 ++++++++++++++++++++ 2 files changed, 114 insertions(+), 17 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java index acf6bddfe57..874f47a50af 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/ExpressionEstimation.java @@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; @@ -137,21 +138,37 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta //TODO: case-when need to re-implemented @Override public ColumnStatistic visitCaseWhen(CaseWhen caseWhen, Statistics context) { + double ndv = caseWhen.getWhenClauses().size(); + if (caseWhen.getDefaultValue().isPresent()) { + ndv += 1; + } + for (WhenClause clause : caseWhen.getWhenClauses()) { + ColumnStatistic colStats = ExpressionEstimation.estimate(clause.getResult(), context); + ndv = Math.max(ndv, colStats.ndv); + } + if (caseWhen.getDefaultValue().isPresent()) { + ColumnStatistic colStats = ExpressionEstimation.estimate(caseWhen.getDefaultValue().get(), context); + ndv = Math.max(ndv, colStats.ndv); + } return new ColumnStatisticBuilder() - .setNdv(caseWhen.getWhenClauses().size() + 1) - .setMinValue(0) - .setMaxValue(Double.MAX_VALUE) + .setNdv(ndv) + .setMinValue(Double.NEGATIVE_INFINITY) + .setMaxValue(Double.POSITIVE_INFINITY) .setAvgSizeByte(8) .setNumNulls(0) .build(); } @Override - public ColumnStatistic visitIf(If function, Statistics context) { - // TODO: copy from visitCaseWhen, polish them. + public ColumnStatistic visitIf(If ifClause, Statistics context) { + double ndv = 2; + ColumnStatistic colStatsThen = ExpressionEstimation.estimate(ifClause.child(1), context); + ndv = Math.max(ndv, colStatsThen.ndv); + ColumnStatistic colStatsElse = ExpressionEstimation.estimate(ifClause.child(2), context); + ndv = Math.max(ndv, colStatsElse.ndv); return new ColumnStatisticBuilder() - .setNdv(2) - .setMinValue(0) + .setNdv(ndv) + .setMinValue(Double.NEGATIVE_INFINITY) .setMaxValue(Double.POSITIVE_INFINITY) .setAvgSizeByte(8) .setNumNulls(0) @@ -577,13 +594,22 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta if (childColumnStats.minOrMaxIsInf()) { return columnStatisticBuilder.build(); } - double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate() - .atStartOfDay(ZoneId.systemDefault()).toEpochSecond(); - double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate() - .atStartOfDay(ZoneId.systemDefault()).toEpochSecond(); + double minValue; + double maxValue; + try { + // min/max value is infinite, but they may be too large to convert to date + minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate() + .atStartOfDay(ZoneId.systemDefault()).toEpochSecond(); + maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate() + .atStartOfDay(ZoneId.systemDefault()).toEpochSecond(); + } catch (Exception e) { + // ignore DateTimeException + minValue = Double.NEGATIVE_INFINITY; + maxValue = Double.POSITIVE_INFINITY; + } return columnStatisticBuilder.setMaxValue(maxValue) - .setMinValue(minValue) - .build(); + .setMinValue(minValue).build(); + } private LocalDateTime getDatetimeFromLong(long dateTime) { @@ -599,10 +625,18 @@ public class ExpressionEstimation extends ExpressionVisitor<ColumnStatistic, Sta if (childColumnStats.minOrMaxIsInf()) { return columnStatisticBuilder.build(); } - double minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay() - + (double) DAYS_FROM_0_TO_1970; - double maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay() - + (double) DAYS_FROM_0_TO_1970; + double minValue; + double maxValue; + try { + minValue = getDatetimeFromLong((long) childColumnStats.minValue).toLocalDate().toEpochDay() + + (double) DAYS_FROM_0_TO_1970; + maxValue = getDatetimeFromLong((long) childColumnStats.maxValue).toLocalDate().toEpochDay() + + (double) DAYS_FROM_0_TO_1970; + } catch (Exception e) { + // ignore DateTimeException + minValue = Double.NEGATIVE_INFINITY; + maxValue = Double.POSITIVE_INFINITY; + } return columnStatisticBuilder.setMaxValue(maxValue) .setMinValue(minValue) .build(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java index 7368735365c..1748802e4dd 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/ExpressionEstimationTest.java @@ -20,14 +20,18 @@ package org.apache.doris.nereids.stats; import org.apache.doris.analysis.DateLiteral; import org.apache.doris.analysis.StringLiteral; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.agg.Max; import org.apache.doris.nereids.trees.expressions.functions.agg.Min; +import org.apache.doris.nereids.trees.expressions.functions.scalar.If; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.IntegerType; @@ -40,7 +44,9 @@ import org.apache.commons.math3.util.Precision; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; class ExpressionEstimationTest { @@ -321,4 +327,61 @@ class ExpressionEstimationTest { Assertions.assertNull(est.minExpr); Assertions.assertNull(est.maxExpr); } + + @Test + public void testCaseWhen() { + SlotReference a = new SlotReference("a", StringType.INSTANCE); + Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>(); + ColumnStatisticBuilder builder = new ColumnStatisticBuilder() + .setNdv(100) + .setMinExpr(new StringLiteral("2020-01-01")) + .setMinValue(20200101000000.0) + .setMaxExpr(new StringLiteral("2021abcdefg")) + .setMaxValue(20210101000000.0); + slotToColumnStat.put(a, builder.build()); + SlotReference b = new SlotReference("b", StringType.INSTANCE); + builder = new ColumnStatisticBuilder() + .setNdv(10) + .setMinExpr(new StringLiteral("2020-01-01")) + .setMinValue(20200101000000.0) + .setMaxExpr(new StringLiteral("2021abcdefg")) + .setMaxValue(20210101000000.0); + slotToColumnStat.put(b, builder.build()); + Statistics stats = new Statistics(1000, slotToColumnStat); + + WhenClause when1 = new WhenClause(BooleanLiteral.TRUE, a); + WhenClause when2 = new WhenClause(BooleanLiteral.FALSE, b); + List<WhenClause> whens = new ArrayList<>(); + whens.add(when1); + whens.add(when2); + CaseWhen caseWhen = new CaseWhen(whens); + ColumnStatistic est = ExpressionEstimation.estimate(caseWhen, stats); + Assertions.assertEquals(est.ndv, 100); + } + + @Test + public void testIf() { + SlotReference a = new SlotReference("a", StringType.INSTANCE); + Map<Expression, ColumnStatistic> slotToColumnStat = new HashMap<>(); + ColumnStatisticBuilder builder = new ColumnStatisticBuilder() + .setNdv(100) + .setMinExpr(new StringLiteral("2020-01-01")) + .setMinValue(20200101000000.0) + .setMaxExpr(new StringLiteral("2021abcdefg")) + .setMaxValue(20210101000000.0); + slotToColumnStat.put(a, builder.build()); + SlotReference b = new SlotReference("b", StringType.INSTANCE); + builder = new ColumnStatisticBuilder() + .setNdv(10) + .setMinExpr(new StringLiteral("2020-01-01")) + .setMinValue(20200101000000.0) + .setMaxExpr(new StringLiteral("2021abcdefg")) + .setMaxValue(20210101000000.0); + slotToColumnStat.put(b, builder.build()); + Statistics stats = new Statistics(1000, slotToColumnStat); + + If ifClause = new If(BooleanLiteral.TRUE, a, b); + ColumnStatistic est = ExpressionEstimation.estimate(ifClause, stats); + Assertions.assertEquals(est.ndv, 100); + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
