This is an automated email from the ASF dual-hosted git repository. kxiao pushed a commit to branch branch-2.0 in repository https://gitbox.apache.org/repos/asf/doris.git
commit 96588736bef0fa2ab37135b3fb10eabe5383f845 Author: 谢健 <[email protected]> AuthorDate: Wed Aug 16 17:19:23 2023 +0800 [enhancement](Nereids): count(1) to count(*) #22999 add a rule to transform count(1) to count(*) --- .../doris/nereids/jobs/executor/Rewriter.java | 2 + .../org/apache/doris/nereids/rules/RuleType.java | 3 +- .../rules/rewrite/CountLiteralToCountStar.java | 71 ++++++++++++++++++++++ .../sub_query_count_with_const.groovy | 8 +++ 4 files changed, 83 insertions(+), 1 deletion(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 6902925809..3b080b5b46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -46,6 +46,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectProjectAboveConsumer; import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin; import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; +import org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; @@ -173,6 +174,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown( new SimplifyAggGroupBy(), new NormalizeAggregate(), + new CountLiteralToCountStar(), new NormalizeSort() ), topic("Window analysis", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index a124882378..4015a30ef6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -57,8 +57,9 @@ public enum RuleType { BINDING_SET_OPERATION_SLOT(RuleTypeClass.REWRITE), BINDING_GENERATE_FUNCTION(RuleTypeClass.REWRITE), BINDING_INSERT_TARGET_TABLE(RuleTypeClass.REWRITE), - BINDING_INSERT_FILE(RuleTypeClass.REWRITE), + BINDING_INSERT_FILE(RuleTypeClass.REWRITE), + COUNT_LITERAL_TO_COUNT_STAR(RuleTypeClass.REWRITE), REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE), FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java new file mode 100644 index 0000000000..d50d8b5a8e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * count(1) ==> count(*)
+ *
+ * <p>Rewrites every non-distinct {@code count(<non-null literal>)} found in the output
+ * expressions of a logical aggregate into {@code count(*)}. {@code count(NULL)} is left
+ * untouched because it does not count rows, and {@code count(DISTINCT ...)} is left
+ * untouched because it has different semantics.
+ */
+public class CountLiteralToCountStar extends OneRewriteRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate().then(
+                agg -> {
+                    List<NamedExpression> outputs = agg.getOutputExpressions();
+                    List<NamedExpression> newExprs = Lists.newArrayListWithCapacity(outputs.size());
+                    if (rewriteCountLiteral(outputs, newExprs)) {
+                        return agg.withAggOutput(newExprs);
+                    }
+                    // Nothing was rewritten: return the very same node.
+                    return agg;
+                }
+        ).toRule(RuleType.COUNT_LITERAL_TO_COUNT_STAR);
+    }
+
+    /**
+     * Copies {@code oldExprs} into {@code newExprs}, replacing each count-of-literal call
+     * with {@code count(*)} on the way.
+     *
+     * @param oldExprs output expressions of the aggregate; not modified
+     * @param newExprs receives one expression per element of {@code oldExprs}, in order
+     * @return {@code true} if at least one expression was rewritten
+     */
+    private boolean rewriteCountLiteral(List<NamedExpression> oldExprs, List<NamedExpression> newExprs) {
+        boolean changed = false;
+        for (Expression expr : oldExprs) {
+            Map<Expression, Expression> replaced = new HashMap<>();
+            Set<AggregateFunction> aggFuncs = expr.collect(AggregateFunction.class::isInstance);
+            for (AggregateFunction aggFunc : aggFuncs) {
+                if (isCountLiteral(aggFunc)) {
+                    replaced.put(aggFunc, new Count());
+                }
+            }
+            if (!replaced.isEmpty()) {
+                // The flag must accumulate over all output expressions. Overwriting it on
+                // every iteration would let the last expression alone decide the result,
+                // so "count(1), sum(x)" would be rewritten and then silently discarded.
+                changed = true;
+                expr = expr.rewriteUp(s -> replaced.getOrDefault(s, s));
+            }
+            newExprs.add((NamedExpression) expr);
+        }
+        return changed;
+    }
+
+    /**
+     * Tells whether {@code aggFunc} is a non-distinct {@code count} whose arguments are all
+     * non-null literals.
+     */
+    private boolean isCountLiteral(AggregateFunction aggFunc) {
+        return !aggFunc.isDistinct()
+                && aggFunc instanceof Count
+                // allMatch is vacuously true on an empty child list; a count with no
+                // argument (presumably already count(*), as built by new Count()) must
+                // not be reported as a rewrite.
+                && !aggFunc.children().isEmpty()
+                && aggFunc.children().stream().allMatch(e -> e.isLiteral() && !e.isNullLiteral());
+    }
+}
diff --git a/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy
b/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy index 02d6453aef..c3168444bb 100644 --- a/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy +++ b/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy @@ -40,4 +40,12 @@ suite("sub_query_count_with_const") { select 2022 as dt ,sum(id) from sub_query_count_with_const ) tmp;""" + explain { + sql("""select count(1) as count + from ( + select 2022 as dt ,sum(id) + from sub_query_count_with_const + ) tmp;""") + contains "count(*)" + } } --------------------------------------------------------------------- To unsubscribe, e-mail: [email protected] For additional commands, e-mail: [email protected]
