This is an automated email from the ASF dual-hosted git repository.

morrySnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d449371b20a [fix](agg)Adjust agg strategy when table satisfy distinct 
key distribution (#61248)
d449371b20a is described below

commit d449371b20a5451f76ebffe4ec315e8f4a6804b5
Author: feiniaofeiafei <[email protected]>
AuthorDate: Wed May 20 17:13:13 2026 +0800

    [fix](agg)Adjust agg strategy when table satisfy distinct key distribution 
(#61248)
---
 .../LogicalOlapScanToPhysicalOlapScan.java         |  12 +-
 .../rules/rewrite/DistinctAggregateRewriter.java   | 115 +++++++++++
 .../java/org/apache/doris/nereids/util/Utils.java  |  16 ++
 .../rewrite/DistinctAggregateRewriterTest.java     | 224 ++++++++++++++++++++-
 .../agg_skew_rewrite/agg_skew_rewrite.out          |   6 +-
 .../nereids_rules_p0/agg_strategy/agg_strategy.out |  22 +-
 .../data/shape_check/clickbench/query10.out        |  11 +-
 .../data/shape_check/clickbench/query11.out        |   7 +-
 .../data/shape_check/clickbench/query12.out        |   7 +-
 .../data/shape_check/clickbench/query14.out        |   7 +-
 .../data/shape_check/clickbench/query23.out        |   7 +-
 .../data/shape_check/clickbench/query9.out         |   5 +-
 .../distinct_split/disitinct_split.groovy          |   1 +
 13 files changed, 395 insertions(+), 45 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
index 6a8b24c2c11..48ff1674709 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/LogicalOlapScanToPhysicalOlapScan.java
@@ -17,13 +17,10 @@
 
 package org.apache.doris.nereids.rules.implementation;
 
-import org.apache.doris.catalog.ColocateTableIndex;
 import org.apache.doris.catalog.Column;
 import org.apache.doris.catalog.DistributionInfo;
-import org.apache.doris.catalog.Env;
 import org.apache.doris.catalog.HashDistributionInfo;
 import org.apache.doris.catalog.OlapTable;
-import org.apache.doris.catalog.PartitionType;
 import org.apache.doris.nereids.properties.DistributionSpec;
 import org.apache.doris.nereids.properties.DistributionSpecHash;
 import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType;
@@ -35,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
+import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
@@ -81,15 +79,11 @@ public class LogicalOlapScanToPhysicalOlapScan extends 
OneImplementationRuleFact
     public static DistributionSpec convertDistribution(LogicalOlapScan 
olapScan) {
         OlapTable olapTable = olapScan.getTable();
         DistributionInfo distributionInfo = 
olapTable.getDefaultDistributionInfo();
-        ColocateTableIndex colocateTableIndex = Env.getCurrentColocateIndex();
         // When there are multiple partitions, olapScan tasks of different 
buckets are dispatched in
         // rounded robin algorithm. Therefore, the hashDistributedSpec can be 
broken except they are in
         // the same stable colocateGroup(CG)
-        boolean isBelongStableCG = 
colocateTableIndex.isColocateTable(olapTable.getId())
-                && 
!colocateTableIndex.isGroupUnstable(colocateTableIndex.getGroup(olapTable.getId()))
-                && olapTable.getCatalogId() == 
Env.getCurrentInternalCatalog().getId();
-        boolean isSelectUnpartition = olapTable.getPartitionInfo().getType() 
== PartitionType.UNPARTITIONED
-                || olapScan.getSelectedPartitionIds().size() == 1;
+        boolean isBelongStableCG = Utils.isBelongStableCG(olapTable);
+        boolean isSelectUnpartition = Utils.isSelectUnpartition(olapTable, 
olapScan.getSelectedPartitionIds());
         // TODO: find a better way to handle both tablet num == 1 and colocate 
table together in future
         if (distributionInfo instanceof HashDistributionInfo && 
(isBelongStableCG || isSelectUnpartition)) {
             if (olapScan.getSelectedIndexId() != 
olapScan.getTable().getBaseIndexId()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java
index da0aa21308b..832641add2d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriter.java
@@ -17,6 +17,10 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.catalog.Column;
+import org.apache.doris.catalog.DistributionInfo;
+import org.apache.doris.catalog.HashDistributionInfo;
+import org.apache.doris.catalog.OlapTable;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.StatsDerive.DeriveContext;
@@ -24,6 +28,8 @@ import org.apache.doris.nereids.stats.ExpressionEstimation;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
@@ -34,6 +40,9 @@ import 
org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.util.AggregateUtils;
 import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.nereids.util.Utils;
@@ -110,6 +119,9 @@ public class DistinctAggregateRewriter implements 
RewriteRuleFactory {
         Statistics aggStats = aggregate.getStats();
         Statistics aggChildStats = aggregate.child().getStats();
         Set<Expression> dstArgs = aggregate.getDistinctArguments();
+        if (isDistinctKeySatisfyDistribution(aggregate)) {
+            return false;
+        }
         // has unknown statistics, split to bottom and top agg
         if 
(AggregateUtils.hasUnknownStatistics(aggregate.getGroupByExpressions(), 
aggChildStats)
                 || AggregateUtils.hasUnknownStatistics(dstArgs, 
aggChildStats)) {
@@ -130,6 +142,109 @@ public class DistinctAggregateRewriter implements 
RewriteRuleFactory {
                 && dstNdv > inputRows * 
AggregateUtils.HIGH_CARDINALITY_THRESHOLD;
     }
 
+    private boolean isDistinctKeySatisfyDistribution(LogicalAggregate<? 
extends Plan> aggregate) {
+        DistinctDistributionInfo info = 
resolveDistinctDistributionInfo(aggregate);
+        if (info == null) {
+            return false;
+        }
+        Set<String> distinctColumnNames = new HashSet<>();
+        for (SlotReference slot : info.distinctSlots) {
+            distinctColumnNames.add(slot.getName().toLowerCase());
+        }
+        DistributionInfo distributionInfo = info.distributionInfo;
+        if (!(distributionInfo instanceof HashDistributionInfo)) {
+            return false;
+        }
+        List<Column> distributionColumns = ((HashDistributionInfo) 
distributionInfo).getDistributionColumns();
+        if (distributionColumns.isEmpty()) {
+            return false;
+        }
+        for (Column column : distributionColumns) {
+            if (!distinctColumnNames.contains(column.getName().toLowerCase())) 
{
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /** This function get the DistinctDistributionInfo from aggregate,
+     * and can handle such scenarios: aggregate's child is filter or project, 
or both of them:
+     * agg(count(distinct a), group by b)
+     *   ->project()
+     *     ->filter()
+     *       ->scan(distributed by hash(a))
+     * @return Table DistributionInfo and the slots of the base table
+     *         referenced by the distinct column of the aggregate function.
+     */
+    private DistinctDistributionInfo 
resolveDistinctDistributionInfo(LogicalAggregate<? extends Plan> aggregate) {
+        Set<Expression> distinctArgs = aggregate.getDistinctArguments();
+        if (distinctArgs.isEmpty()) {
+            return null;
+        }
+        Set<SlotReference> distinctSlots = new HashSet<>();
+        for (Expression expression : distinctArgs) {
+            if (!(expression instanceof SlotReference)) {
+                return null;
+            }
+            distinctSlots.add((SlotReference) expression);
+        }
+        Plan child = aggregate.child();
+        while (child instanceof LogicalProject || child instanceof 
LogicalFilter) {
+            if (child instanceof LogicalProject) {
+                LogicalProject<? extends Plan> project = (LogicalProject<? 
extends Plan>) child;
+                Map<Slot, Expression> projectExprMap = new HashMap<>();
+                for (NamedExpression namedExpression : project.getProjects()) {
+                    Expression projectExpr = namedExpression;
+                    if (namedExpression instanceof Alias) {
+                        projectExpr = ((Alias) namedExpression).child();
+                    }
+                    projectExprMap.put(namedExpression.toSlot(), projectExpr);
+                }
+                Set<SlotReference> replaced = new HashSet<>();
+                for (SlotReference slot : distinctSlots) {
+                    Expression projectExpr = projectExprMap.get(slot);
+                    if (!(projectExpr instanceof SlotReference)) {
+                        return null;
+                    }
+                    replaced.add((SlotReference) projectExpr);
+                }
+                distinctSlots = replaced;
+                child = project.child();
+                continue;
+            }
+            child = ((LogicalFilter<? extends Plan>) child).child();
+        }
+        if (!(child instanceof LogicalOlapScan)) {
+            return null;
+        }
+        LogicalOlapScan scan = (LogicalOlapScan) child;
+        OlapTable olapTable = scan.getTable();
+        if (olapTable == null) {
+            return null;
+        }
+        if (!Utils.isSelectUnpartition(olapTable, 
scan.getSelectedPartitionIds())
+                && !Utils.isBelongStableCG(olapTable)) {
+            return null;
+        }
+        for (SlotReference slot : distinctSlots) {
+            if (!slot.getOriginalTable().isPresent()
+                    || slot.getOriginalTable().get() != olapTable) {
+                return null;
+            }
+        }
+        return new 
DistinctDistributionInfo(olapTable.getDefaultDistributionInfo(), distinctSlots);
+    }
+
+    private static class DistinctDistributionInfo {
+        private final DistributionInfo distributionInfo;
+        private final Set<SlotReference> distinctSlots;
+
+        private DistinctDistributionInfo(DistributionInfo distributionInfo, 
Set<SlotReference> distinctSlots) {
+            this.distributionInfo = distributionInfo;
+            this.distinctSlots = distinctSlots;
+        }
+    }
+
     private Plan rewrite(LogicalAggregate<? extends Plan> aggregate, 
ConnectContext ctx) {
         if (aggregate.distinctFuncNum() == 0) {
             return null;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
index 16e4eb82ef6..e01ae2baf60 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java
@@ -17,6 +17,10 @@
 
 package org.apache.doris.nereids.util;
 
+import org.apache.doris.catalog.ColocateTableIndex;
+import org.apache.doris.catalog.Env;
+import org.apache.doris.catalog.OlapTable;
+import org.apache.doris.catalog.PartitionType;
 import org.apache.doris.catalog.info.TableNameInfo;
 import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.glue.LogicalPlanAdapter;
@@ -175,6 +179,18 @@ public class Utils {
         return StringUtils.join(qualifierWithBackquote, ".");
     }
 
+    public static boolean isBelongStableCG(OlapTable olapTable) {
+        ColocateTableIndex colocateTableIndex = Env.getCurrentColocateIndex();
+        return colocateTableIndex.isColocateTable(olapTable.getId())
+                && 
!colocateTableIndex.isGroupUnstable(colocateTableIndex.getGroup(olapTable.getId()))
+                && olapTable.getCatalogId() == 
Env.getCurrentInternalCatalog().getId();
+    }
+
+    public static boolean isSelectUnpartition(OlapTable olapTable, 
Collection<Long> selectedPartitionIds) {
+        return olapTable.getPartitionInfo().getType() == 
PartitionType.UNPARTITIONED
+                || selectedPartitionIds.size() == 1;
+    }
+
     /**
      * Get sql string for plan.
      *
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriterTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriterTest.java
index ef2ae01f302..3bfb42918a1 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriterTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DistinctAggregateRewriterTest.java
@@ -17,17 +17,42 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.catalog.DistributionInfo;
+import org.apache.doris.catalog.HashDistributionInfo;
+import 
org.apache.doris.nereids.rules.analysis.LogicalSubQueryAliasToLogicalProject;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctCount;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.MultiDistinctGroupConcat;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
 import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.plans.AbstractPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.SessionVariable;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.ColumnStatisticBuilder;
+import org.apache.doris.statistics.Statistics;
 import org.apache.doris.utframe.TestWithFeService;
 
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
 import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
+import org.mockito.Mockito;
+
+import java.lang.reflect.Field;
+import java.lang.reflect.Method;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
 
 public class DistinctAggregateRewriterTest extends TestWithFeService 
implements MemoPatternMatchSupported {
     @Override
@@ -36,7 +61,32 @@ public class DistinctAggregateRewriterTest extends 
TestWithFeService implements
         createTable("create table test.distinct_agg_split_t(a int null, b int 
not null,"
                 + "c varchar(10) null, d date, dt datetime)\n"
                 + "distributed by hash(a) properties('replication_num' = 
'1');");
+        createTable("CREATE TABLE IF NOT EXISTS test.sales_records\n"
+                + "(\n"
+                + "    record_id BIGINT,\n"
+                + "    seller_id BIGINT,\n"
+                + "    sale_date DATE,\n"
+                + "    amount DECIMAL(18,2)\n"
+                + ")\n"
+                + "DUPLICATE KEY(record_id, seller_id)\n"
+                + "PARTITION BY RANGE(sale_date)\n"
+                + "(\n"
+                + "    PARTITION p202301 VALUES LESS THAN ('2023-02-01'),\n"
+                + "    PARTITION p202302 VALUES LESS THAN ('2023-03-01'),\n"
+                + "    PARTITION p202303 VALUES LESS THAN ('2023-04-01')\n"
+                + ")\n"
+                + "DISTRIBUTED BY HASH(record_id) BUCKETS 10\n"
+                + "PROPERTIES (\n"
+                + "    \"replication_num\" = \"1\"\n"
+                + ");");
+        createTable("create table test.distinct_agg_hash_ab(a int null, b int 
not null, c int null, d int null)\n"
+                + "distributed by hash(a, b) properties('replication_num' = 
'1');");
+        createTable("create table test.distinct_agg_hash_abcd(a int null, b 
int not null, c int null, d int null)\n"
+                + "distributed by hash(a, b, c, d) 
properties('replication_num' = '1');");
         connectContext.setDatabase("test");
+        SessionVariable spySessionVariable = 
Mockito.spy(connectContext.getSessionVariable());
+        
Mockito.doReturn(24).when(spySessionVariable).getParallelExecInstanceNum(Mockito.anyString());
+        connectContext.setSessionVariable(spySessionVariable);
         
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
     }
 
@@ -60,7 +110,7 @@ public class DistinctAggregateRewriterTest extends 
TestWithFeService implements
                                 logicalAggregate().when(agg -> 
agg.getGroupByExpressions().size() == 1
                                         && 
agg.getGroupByExpressions().get(0).toSql().equals("b")
                                         && agg.getAggregateFunctions().stream()
-                                        .anyMatch(f -> f instanceof 
MultiDistinctCount))
+                                        .anyMatch(f -> f instanceof Count))
                         )
                 );
     }
@@ -77,9 +127,7 @@ public class DistinctAggregateRewriterTest extends 
TestWithFeService implements
                                 logicalAggregate().when(agg -> 
agg.getGroupByExpressions().size() == 1
                                         && 
agg.getGroupByExpressions().get(0).toSql().equals("b")
                                         && agg.getAggregateFunctions().stream()
-                                        .anyMatch(f -> f instanceof 
MultiDistinctCount)
-                                        && agg.getAggregateFunctions().stream()
-                                        .anyMatch(f -> f instanceof Count && 
!f.isDistinct()))
+                                        .allMatch(f -> f instanceof Count || f 
instanceof Sum0))
                         )
                 );
     }
@@ -96,8 +144,7 @@ public class DistinctAggregateRewriterTest extends 
TestWithFeService implements
                                 logicalAggregate().when(agg -> 
agg.getGroupByExpressions().size() == 1
                                         && 
agg.getGroupByExpressions().get(0).toSql().equals("b")
                                         && agg.getAggregateFunctions().stream()
-                                        .anyMatch(f -> f instanceof 
MultiDistinctCount)
-                                        && 
agg.getAggregateFunctions().stream().noneMatch(AggregateFunction::isDistinct)
+                                        .anyMatch(f -> f instanceof Count && 
!f.isDistinct())
                                 )));
     }
 
@@ -194,5 +241,170 @@ public class DistinctAggregateRewriterTest extends 
TestWithFeService implements
                                 && 
agg.getAggregateFunctions().stream().noneMatch(AggregateFunction::isDistinct)
                                 && 
agg.getAggregateFunctions().stream().anyMatch(f -> f instanceof 
MultiDistinctCount)
                         ));
+        connectContext.getSessionVariable().setAggPhase(0);
+    }
+
+    @Test
+    void testShouldUseMultiDistinctWithoutStatsSatisfyDistribution() throws 
Exception {
+        DistinctAggregateRewriter rewriter = 
DistinctAggregateRewriter.INSTANCE;
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select bb, count(distinct aa) from "
+                        + "(select a as aa, b as bb from 
test.distinct_agg_split_t where b > 1) t "
+                        + "group by bb"
+        );
+        Plan child = aggregate.child();
+        Map<org.apache.doris.nereids.trees.expressions.Expression, 
ColumnStatistic> colStats = new HashMap<>();
+        aggregate.getGroupByExpressions().forEach(expr ->
+                colStats.put(expr, unknownColumnStats()));
+        aggregate.getDistinctArguments().forEach(expr ->
+                colStats.put(expr, unknownColumnStats()));
+        ((AbstractPlan) child).setStatistics(new Statistics(10000, colStats));
+        aggregate.setStatistics(new Statistics(100, ImmutableMap.of()));
+
+        Assertions.assertFalse(rewriter.shouldUseMultiDistinct(aggregate));
+    }
+
+    @Test
+    void testShouldUseMultiDistinctWithStatsSelected() throws Exception {
+        DistinctAggregateRewriter rewriter = new DistinctAggregateRewriter();
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select b, count(distinct a) from test.distinct_agg_split_t 
group by b"
+        );
+        Plan child = aggregate.child();
+        Map<org.apache.doris.nereids.trees.expressions.Expression, 
ColumnStatistic> colStats = new HashMap<>();
+        aggregate.getGroupByExpressions().forEach(expr ->
+                colStats.put(expr, buildColumnStats(240, false)));
+        aggregate.getDistinctArguments().forEach(expr ->
+                colStats.put(expr, buildColumnStats(10000.0, false)));
+        ((AbstractPlan) child).setStatistics(new Statistics(100000, colStats));
+        aggregate.setStatistics(new Statistics(240, ImmutableMap.of()));
+
+        Assertions.assertFalse(rewriter.shouldUseMultiDistinct(aggregate));
+    }
+
+    @Test
+    void testShouldUseMultiDistinctWithPartitionTable() {
+        DistinctAggregateRewriter rewriter = 
DistinctAggregateRewriter.INSTANCE;
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select count(distinct record_id) from sales_records group by 
sale_date;"
+        );
+        Plan child = aggregate.child();
+        Map<org.apache.doris.nereids.trees.expressions.Expression, 
ColumnStatistic> colStats = new HashMap<>();
+        aggregate.getGroupByExpressions().forEach(expr ->
+                colStats.put(expr, unknownColumnStats()));
+        aggregate.getDistinctArguments().forEach(expr ->
+                colStats.put(expr, unknownColumnStats()));
+        ((AbstractPlan) child).setStatistics(new Statistics(10000, colStats));
+        aggregate.setStatistics(new Statistics(100, ImmutableMap.of()));
+
+        Assertions.assertTrue(rewriter.shouldUseMultiDistinct(aggregate));
+    }
+
+    @Test
+    void testResolveDistinctDistributionInfoWithProjectAndFilter() throws 
Exception {
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select bb, count(distinct aa) from "
+                        + "(select a as aa, b as bb from 
test.distinct_agg_hash_ab where c > 1) t "
+                        + "group by bb"
+        );
+
+        Object info = invokeResolveDistinctDistributionInfo(aggregate);
+        Assertions.assertNotNull(info);
+
+        List<String> distinctSlotNames = getDistinctSlots(info).stream()
+                .map(SlotReference::getName)
+                .collect(Collectors.toList());
+        Assertions.assertEquals(ImmutableList.of("a"), distinctSlotNames);
+
+        DistributionInfo distributionInfo = getDistributionInfo(info);
+        Assertions.assertTrue(distributionInfo instanceof 
HashDistributionInfo);
+        List<String> distributionColumnNames = ((HashDistributionInfo) 
distributionInfo).getDistributionColumns().stream()
+                .map(column -> column.getName().toLowerCase())
+                .collect(Collectors.toList());
+        Assertions.assertEquals(ImmutableList.of("a", "b"), 
distributionColumnNames);
+    }
+
+    @Test
+    void 
testIsDistinctKeySatisfyDistributionWhenDistinctContainsDistributionColumns() 
throws Exception {
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select d, count(distinct a, b, c) from 
test.distinct_agg_hash_ab group by d"
+        );
+
+        
Assertions.assertTrue(invokeIsDistinctKeySatisfyDistribution(aggregate));
+    }
+
+    @Test
+    void testIsDistinctKeySatisfyDistributionWhenDistributionHasExtraColumns() 
throws Exception {
+        LogicalAggregate<? extends Plan> aggregate = getLogicalAggregate(
+                "select d, count(distinct a, b, c) from 
test.distinct_agg_hash_abcd group by d"
+        );
+
+        
Assertions.assertFalse(invokeIsDistinctKeySatisfyDistribution(aggregate));
+    }
+
+    private LogicalAggregate<? extends Plan> getLogicalAggregate(String sql) {
+        Plan plan = PlanChecker.from(connectContext)
+                .analyze(sql)
+                .applyTopDown(new LogicalSubQueryAliasToLogicalProject())
+                .getPlan();
+        Optional<LogicalAggregate<? extends Plan>> aggregate = 
findAggregate(plan);
+        Assertions.assertTrue(aggregate.isPresent());
+        return aggregate.get();
+    }
+
+    private Optional<LogicalAggregate<? extends Plan>> findAggregate(Plan 
plan) {
+        if (plan instanceof LogicalAggregate) {
+            return Optional.of((LogicalAggregate<? extends Plan>) plan);
+        }
+        for (Plan child : plan.children()) {
+            Optional<LogicalAggregate<? extends Plan>> found = 
findAggregate(child);
+            if (found.isPresent()) {
+                return found;
+            }
+        }
+        return Optional.empty();
+    }
+
+    private Object invokeResolveDistinctDistributionInfo(LogicalAggregate<? 
extends Plan> aggregate) throws Exception {
+        Method method = DistinctAggregateRewriter.class.getDeclaredMethod(
+                "resolveDistinctDistributionInfo", LogicalAggregate.class);
+        method.setAccessible(true);
+        return method.invoke(DistinctAggregateRewriter.INSTANCE, aggregate);
+    }
+
+    private boolean invokeIsDistinctKeySatisfyDistribution(LogicalAggregate<? 
extends Plan> aggregate) throws Exception {
+        Method method = DistinctAggregateRewriter.class.getDeclaredMethod(
+                "isDistinctKeySatisfyDistribution", LogicalAggregate.class);
+        method.setAccessible(true);
+        return (boolean) method.invoke(DistinctAggregateRewriter.INSTANCE, 
aggregate);
+    }
+
+    @SuppressWarnings("unchecked")
+    private Set<SlotReference> getDistinctSlots(Object info) throws Exception {
+        Field field = info.getClass().getDeclaredField("distinctSlots");
+        field.setAccessible(true);
+        return (Set<SlotReference>) field.get(info);
+    }
+
+    private DistributionInfo getDistributionInfo(Object info) throws Exception 
{
+        Field field = info.getClass().getDeclaredField("distributionInfo");
+        field.setAccessible(true);
+        return (DistributionInfo) field.get(info);
+    }
+
+    private ColumnStatistic unknownColumnStats() {
+        return buildColumnStats(0.0, true);
+    }
+
+    private ColumnStatistic buildColumnStats(double ndv, boolean isUnknown) {
+        return new ColumnStatisticBuilder(1)
+                .setNdv(ndv)
+                .setAvgSizeByte(4)
+                .setNumNulls(0)
+                .setMinValue(0)
+                .setMaxValue(100)
+                .setIsUnknown(isUnknown)
+                .setUpdatedTime("")
+                .build();
     }
 }
diff --git 
a/regression-test/data/nereids_rules_p0/agg_skew_rewrite/agg_skew_rewrite.out 
b/regression-test/data/nereids_rules_p0/agg_skew_rewrite/agg_skew_rewrite.out
index 8953d96f3eb..21af6f5304a 100644
--- 
a/regression-test/data/nereids_rules_p0/agg_skew_rewrite/agg_skew_rewrite.out
+++ 
b/regression-test/data/nereids_rules_p0/agg_skew_rewrite/agg_skew_rewrite.out
@@ -564,10 +564,12 @@ PhysicalResultSink
 PhysicalResultSink
 --hashAgg[GLOBAL]
 ----hashAgg[LOCAL]
-------PhysicalOlapScan[test_skew_hint]
+------hashAgg[GLOBAL]
+--------PhysicalOlapScan[test_skew_hint]
 
 -- !shape_not_rewrite --
 PhysicalResultSink
 --hashAgg[GLOBAL]
-----PhysicalOlapScan[test_skew_hint]
+----hashAgg[GLOBAL]
+------PhysicalOlapScan[test_skew_hint]
 
diff --git 
a/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out 
b/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out
index 4338ac45a66..cda86c3ee03 100644
--- a/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out
+++ b/regression-test/data/nereids_rules_p0/agg_strategy/agg_strategy.out
@@ -153,7 +153,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id]
+--------------hashAgg[GLOBAL]
+----------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id]
 
 -- !agg_distinct_with_gby_key_with_other_func --
 PhysicalResultSink
@@ -178,10 +179,11 @@ PhysicalResultSink
 --PhysicalQuickSort[MERGE_SORT]
 ----PhysicalDistribute[DistributionSpecGather]
 ------PhysicalQuickSort[LOCAL_SORT]
---------hashAgg[GLOBAL]
+--------hashAgg[DISTINCT_GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
-------------hashAgg[LOCAL]
---------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id]
+------------hashAgg[DISTINCT_LOCAL]
+--------------hashAgg[GLOBAL]
+----------------PhysicalOlapScan[t_gbykey_10_dstkey_10_1000_id]
 
 -- !agg_distinct_without_gby_key --
 PhysicalResultSink
@@ -561,7 +563,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalOlapScan[t_gbykey_2_dstkey_10_30_id]
+--------------hashAgg[GLOBAL]
+----------------PhysicalOlapScan[t_gbykey_2_dstkey_10_30_id]
 
 -- !agg_distinct_with_gby_key_with_other_func_low_ndv --
 PhysicalResultSink
@@ -586,10 +589,11 @@ PhysicalResultSink
 --PhysicalQuickSort[MERGE_SORT]
 ----PhysicalDistribute[DistributionSpecGather]
 ------PhysicalQuickSort[LOCAL_SORT]
---------hashAgg[GLOBAL]
-----------PhysicalDistribute[DistributionSpecHash]
-------------hashAgg[LOCAL]
---------------PhysicalOlapScan[t_gbykey_2_dstkey_10_30_id]
+--------hashAgg[DISTINCT_GLOBAL]
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalOlapScan[t_gbykey_2_dstkey_10_30_id]
 
 -- !agg_distinct_without_gby_key_low_ndv --
 PhysicalResultSink
diff --git a/regression-test/data/shape_check/clickbench/query10.out 
b/regression-test/data/shape_check/clickbench/query10.out
index c7840564369..36122c8cbf0 100644
--- a/regression-test/data/shape_check/clickbench/query10.out
+++ b/regression-test/data/shape_check/clickbench/query10.out
@@ -4,9 +4,10 @@ PhysicalResultSink
 --PhysicalTopN[MERGE_SORT]
 ----PhysicalDistribute[DistributionSpecGather]
 ------PhysicalTopN[LOCAL_SORT]
---------hashAgg[GLOBAL]
-----------PhysicalDistribute[DistributionSpecHash]
-------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------PhysicalOlapScan[hits]
+--------hashAgg[DISTINCT_GLOBAL]
+----------hashAgg[GLOBAL]
+------------PhysicalDistribute[DistributionSpecHash]
+--------------hashAgg[LOCAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/shape_check/clickbench/query11.out 
b/regression-test/data/shape_check/clickbench/query11.out
index 2ff48a34cad..0fac14167ad 100644
--- a/regression-test/data/shape_check/clickbench/query11.out
+++ b/regression-test/data/shape_check/clickbench/query11.out
@@ -7,7 +7,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------filter(( not (length(MobilePhoneModel) = 0)))
-------------------PhysicalOlapScan[hits]
+--------------hashAgg[GLOBAL]
+----------------PhysicalProject
+------------------filter(( not (length(MobilePhoneModel) = 0)))
+--------------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/shape_check/clickbench/query12.out 
b/regression-test/data/shape_check/clickbench/query12.out
index 2c2148174a1..69ad0325529 100644
--- a/regression-test/data/shape_check/clickbench/query12.out
+++ b/regression-test/data/shape_check/clickbench/query12.out
@@ -7,7 +7,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------filter(( not (length(MobilePhoneModel) = 0)))
-------------------PhysicalOlapScan[hits]
+--------------hashAgg[GLOBAL]
+----------------PhysicalProject
+------------------filter(( not (length(MobilePhoneModel) = 0)))
+--------------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/shape_check/clickbench/query14.out 
b/regression-test/data/shape_check/clickbench/query14.out
index 59f29de703e..a33268c31e4 100644
--- a/regression-test/data/shape_check/clickbench/query14.out
+++ b/regression-test/data/shape_check/clickbench/query14.out
@@ -7,7 +7,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------filter(( not (length(SearchPhrase) = 0)))
-------------------PhysicalOlapScan[hits]
+--------------hashAgg[GLOBAL]
+----------------PhysicalProject
+------------------filter(( not (length(SearchPhrase) = 0)))
+--------------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/shape_check/clickbench/query23.out 
b/regression-test/data/shape_check/clickbench/query23.out
index e38fca278f5..fec234f0fff 100644
--- a/regression-test/data/shape_check/clickbench/query23.out
+++ b/regression-test/data/shape_check/clickbench/query23.out
@@ -7,7 +7,8 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------filter(( not (URL like '%.google.%')) and ( not 
(length(SearchPhrase) = 0)) and (Title like '%Google%'))
-------------------PhysicalOlapScan[hits]
+--------------hashAgg[GLOBAL]
+----------------PhysicalProject
+------------------filter(( not (URL like '%.google.%')) and ( not 
(length(SearchPhrase) = 0)) and (Title like '%Google%'))
+--------------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/shape_check/clickbench/query9.out 
b/regression-test/data/shape_check/clickbench/query9.out
index dcece9f0ce7..b35cb2e2a80 100644
--- a/regression-test/data/shape_check/clickbench/query9.out
+++ b/regression-test/data/shape_check/clickbench/query9.out
@@ -7,6 +7,7 @@ PhysicalResultSink
 --------hashAgg[GLOBAL]
 ----------PhysicalDistribute[DistributionSpecHash]
 ------------hashAgg[LOCAL]
---------------PhysicalProject
-----------------PhysicalOlapScan[hits]
+--------------hashAgg[GLOBAL]
+----------------PhysicalProject
+------------------PhysicalOlapScan[hits]
 
diff --git 
a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy 
b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy
index f9da6636064..883f0529132 100644
--- 
a/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/distinct_split/disitinct_split.groovy
@@ -20,6 +20,7 @@ suite("distinct_split") {
     sql "set disable_join_reorder=true"
     sql "set global enable_auto_analyze=false;"
     sql "set be_number_for_test=1;"
+    sql "set parallel_pipeline_task_num=1;"
     sql "drop table if exists test_distinct_multi"
     sql "create table test_distinct_multi(a int, b int, c int, d varchar(10), 
e date) distributed by hash(a) properties('replication_num'='1');"
     sql "insert into test_distinct_multi 
values(1,2,3,'abc','2024-01-02'),(1,2,4,'abc','2024-01-03'),(2,2,4,'abcd','2024-01-02'),(1,2,3,'abcd','2024-01-04'),(1,2,4,'eee','2024-02-02'),(2,2,4,'abc','2024-01-02');"


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to