This is an automated email from the ASF dual-hosted git repository.

kxiao pushed a commit to branch branch-2.0
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 96588736bef0fa2ab37135b3fb10eabe5383f845
Author: 谢健 <[email protected]>
AuthorDate: Wed Aug 16 17:19:23 2023 +0800

    [enhancement](Nereids): count(1) to count(*) #22999
    
    Add a Nereids rewrite rule that turns a non-distinct count over non-null literals, e.g. count(1), into count(*).
---
 .../doris/nereids/jobs/executor/Rewriter.java      |  2 +
 .../org/apache/doris/nereids/rules/RuleType.java   |  3 +-
 .../rules/rewrite/CountLiteralToCountStar.java     | 71 ++++++++++++++++++++++
 .../sub_query_count_with_const.groovy              |  8 +++
 4 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
index 6902925809..3b080b5b46 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java
@@ -46,6 +46,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectProjectAboveConsumer;
 import org.apache.doris.nereids.rules.rewrite.ColumnPruning;
 import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin;
 import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite;
+import org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar;
 import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult;
 import org.apache.doris.nereids.rules.rewrite.EliminateAggregate;
 import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition;
@@ -173,6 +174,7 @@ public class Rewriter extends AbstractBatchJobExecutor {
             topDown(
                     new SimplifyAggGroupBy(),
                     new NormalizeAggregate(),
+                    new CountLiteralToCountStar(),
                     new NormalizeSort()
             ),
             topic("Window analysis",
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index a124882378..4015a30ef6 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -57,8 +57,9 @@ public enum RuleType {
     BINDING_SET_OPERATION_SLOT(RuleTypeClass.REWRITE),
     BINDING_GENERATE_FUNCTION(RuleTypeClass.REWRITE),
     BINDING_INSERT_TARGET_TABLE(RuleTypeClass.REWRITE),
-    BINDING_INSERT_FILE(RuleTypeClass.REWRITE),
 
+    BINDING_INSERT_FILE(RuleTypeClass.REWRITE),
+    COUNT_LITERAL_TO_COUNT_STAR(RuleTypeClass.REWRITE),
     REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE),
 
     FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java
new file mode 100644
index 0000000000..d50d8b5a8e
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java
@@ -0,0 +1,95 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+
+import com.google.common.collect.Lists;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites {@code count(<non-null literal>)}, e.g. {@code count(1)}, into {@code count(*)}.
+ *
+ * <p>A non-null literal is non-null on every input row, so counting it yields the number of rows,
+ * which is what {@code count(*)} computes. {@code count(DISTINCT <literal>)} is not rewritten
+ * because it counts distinct values, not rows.
+ */
+public class CountLiteralToCountStar extends OneRewriteRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate().then(
+                agg -> {
+                    List<NamedExpression> oldExprs = agg.getOutputExpressions();
+                    List<NamedExpression> newExprs = Lists.newArrayListWithCapacity(oldExprs.size());
+                    if (rewriteCountLiteral(oldExprs, newExprs)) {
+                        return agg.withAggOutput(newExprs);
+                    }
+                    // nothing matched: keep the original plan node
+                    return agg;
+                }
+        ).toRule(RuleType.COUNT_LITERAL_TO_COUNT_STAR);
+    }
+
+    /**
+     * Copies {@code oldExprs} into {@code newExprs}, replacing every non-distinct count of
+     * non-null literals found inside them with {@code count(*)}.
+     *
+     * @param oldExprs the aggregate's current output expressions
+     * @param newExprs receives one expression per entry of {@code oldExprs}, in the same order
+     * @return true if at least one output expression was rewritten
+     */
+    private boolean rewriteCountLiteral(List<NamedExpression> oldExprs, List<NamedExpression> newExprs) {
+        boolean changed = false;
+        for (Expression expr : oldExprs) {
+            Map<Expression, Expression> replaced = new HashMap<>();
+            Set<AggregateFunction> oldAggFuncSet = expr.collect(AggregateFunction.class::isInstance);
+            oldAggFuncSet.stream()
+                    .filter(this::isCountLiteral)
+                    .forEach(c -> replaced.put(c, new Count()));
+            if (!replaced.isEmpty()) {
+                expr = expr.rewriteUp(s -> replaced.getOrDefault(s, s));
+                // Sticky flag: a later expression without count(literal) must not
+                // reset it, otherwise the rewrites of earlier expressions are dropped.
+                changed = true;
+            }
+            newExprs.add((NamedExpression) expr);
+        }
+        return changed;
+    }
+
+    /**
+     * Returns true for a non-distinct {@code count} with at least one argument where every
+     * argument is a non-null literal. An argument-less {@code count} is already {@code count(*)}
+     * and is skipped so it is not reported as a rewrite.
+     */
+    private boolean isCountLiteral(AggregateFunction aggFunc) {
+        return !aggFunc.isDistinct()
+                && aggFunc instanceof Count
+                && !aggFunc.children().isEmpty()
+                && aggFunc.children().stream().allMatch(e -> e.isLiteral() && !e.isNullLiteral());
+    }
+}
diff --git a/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy b/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy
index 02d6453aef..c3168444bb 100644
--- a/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy
+++ b/regression-test/suites/nereids_syntax_p0/sub_query_count_with_const.groovy
@@ -40,4 +40,12 @@ suite("sub_query_count_with_const") {
                        select 2022 as dt ,sum(id)
                        from sub_query_count_with_const
                  ) tmp;"""
+    explain {
+      sql("""select count(1) as count
+                 from (
+                       select 2022 as dt ,sum(id)
+                       from sub_query_count_with_const
+                 ) tmp;""")
+      contains "count(*)"
+    }
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to