This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch tpc_preview6
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/tpc_preview6 by this push:
     new 1ff08374101 make slot that ndv > LOW_NDV_THRESHOLD can be chosen as 
shuffle key in decomposeRepeat (#60610)
1ff08374101 is described below

commit 1ff083741016be9b9f8583b99f29483ac3e2e78f
Author: feiniaofeiafei <[email protected]>
AuthorDate: Mon Feb 9 15:00:19 2026 +0800

    make slot that ndv > LOW_NDV_THRESHOLD can be chosen as shuffle key in 
decomposeRepeat (#60610)
    
    ### What problem does this PR solve?
    
    Issue Number: close #xxx
    
    Related PR: #xxx
    
    Problem Summary:
    
    ### Release note
    
    None
    
    ### Check List (For Author)
    
    - Test <!-- At least one of them must be included. -->
        - [ ] Regression test
        - [ ] Unit Test
        - [ ] Manual test (add detailed scripts or steps below)
        - [ ] No need to test or manual test. Explain why:
    - [ ] This is a refactor/code format and no logic has been changed.
            - [ ] Previous test can cover this change.
            - [ ] No code files have been changed.
            - [ ] Other reason <!-- Add your reason?  -->
    
    - Behavior changed:
        - [ ] No.
        - [ ] Yes. <!-- Explain the behavior change -->
    
    - Does this need documentation?
        - [ ] No.
    - [ ] Yes. <!-- Add document PR link here. eg:
    https://github.com/apache/doris-website/pull/1214 -->
    
    ### Check List (For Reviewer who merge this PR)
    
    - [ ] Confirm the release note
    - [ ] Confirm test cases
    - [ ] Confirm document
    - [ ] Add branch pick label <!-- Add branch pick label that this PR
    should merge into -->
---
 .../java/org/apache/doris/nereids/PlanContext.java |   4 +
 .../properties/ChildrenPropertiesRegulator.java    |  24 ---
 .../nereids/properties/RequestPropertyDeriver.java |   8 +-
 .../rewrite/DecomposeRepeatWithPreAggregation.java |  64 +++-----
 .../java/org/apache/doris/qe/SessionVariable.java  |   9 ++
 .../doris/statistics/util/StatisticsUtil.java      |  20 +++
 .../properties/RequestPropertyDeriverTest.java     |  92 +++++++++++
 .../DecomposeRepeatWithPreAggregationTest.java     | 171 ++++++++++++++++++++-
 .../decompose_repeat/decompose_repeat.groovy       |  24 ++-
 9 files changed, 334 insertions(+), 82 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java
index 09285a638de..e20d9bf4044 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/PlanContext.java
@@ -83,4 +83,8 @@ public class PlanContext {
     public StatementContext getStatementContext() {
         return connectContext.getStatementContext();
     }
+
+    public ConnectContext getConnectContext() {
+        return connectContext;
+    }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
index b415b1f8007..8db30640797 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/ChildrenPropertiesRegulator.java
@@ -116,31 +116,7 @@ public class ChildrenPropertiesRegulator extends 
PlanVisitor<List<List<PhysicalP
         if (agg.getGroupByExpressions().isEmpty() && 
agg.getOutputExpressions().isEmpty()) {
             return ImmutableList.of();
         }
-        // If the origin attribute satisfies the group by key but does not 
meet the requirements, ban the plan.
-        // e.g. select count(distinct a) from t group by b;
-        // requiredChildProperty: a
-        // but the child is already distributed by b
-        // ban this plan
-        PhysicalProperties originChildProperty = 
originChildrenProperties.get(0);
         PhysicalProperties requiredChildProperty = requiredProperties.get(0);
-        PhysicalProperties hashSpec = 
PhysicalProperties.createHash(agg.getGroupByExpressions(), ShuffleType.REQUIRE);
-        GroupExpression child = children.get(0);
-        if (child.getPlan() instanceof PhysicalDistribute) {
-            PhysicalProperties properties = new PhysicalProperties(
-                    DistributionSpecAny.INSTANCE, 
originChildProperty.getOrderSpec());
-            Optional<Pair<Cost, GroupExpression>> pair = 
child.getOwnerGroup().getLowestCostPlan(properties);
-            // add null check
-            if (!pair.isPresent()) {
-                return ImmutableList.of();
-            }
-            GroupExpression distributeChild = pair.get().second;
-            PhysicalProperties distributeChildProperties = 
distributeChild.getOutputProperties(properties);
-            if (distributeChildProperties.satisfy(hashSpec)
-                    && 
!distributeChildProperties.satisfy(requiredChildProperty)) {
-                return ImmutableList.of();
-            }
-        }
-
         if (!agg.getAggregateParam().canBeBanned) {
             return visit(agg, context);
         }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
index bee83f0cc80..5a7134fffe0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/RequestPropertyDeriver.java
@@ -468,7 +468,7 @@ public class RequestPropertyDeriver extends 
PlanVisitor<Void, PlanContext> {
                 Set<ExprId> intersectId = Sets.intersection(new 
HashSet<>(parentHashExprIds),
                         new HashSet<>(groupByExprIds));
                 if (!intersectId.isEmpty() && intersectId.size() < 
groupByExprIds.size()) {
-                    if (shouldUseParent(parentHashExprIds, agg)) {
+                    if (shouldUseParent(parentHashExprIds, agg, context)) {
                         
addRequestPropertyToChildren(PhysicalProperties.createHash(
                                 Utils.fastToImmutableList(intersectId), 
ShuffleType.REQUIRE));
                     }
@@ -482,7 +482,11 @@ public class RequestPropertyDeriver extends 
PlanVisitor<Void, PlanContext> {
         return null;
     }
 
-    private boolean shouldUseParent(List<ExprId> parentHashExprIds, 
PhysicalHashAggregate<? extends Plan> agg) {
+    private boolean shouldUseParent(List<ExprId> parentHashExprIds, 
PhysicalHashAggregate<? extends Plan> agg,
+            PlanContext context) {
+        if 
(!context.getConnectContext().getSessionVariable().aggShuffleUseParentKey) {
+            return false;
+        }
         Optional<GroupExpression> groupExpression = agg.getGroupExpression();
         if (!groupExpression.isPresent()) {
             return true;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
index 3939d79a510..fd40b635ffd 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregation.java
@@ -54,12 +54,12 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 import org.apache.doris.qe.ConnectContext;
 import org.apache.doris.statistics.ColumnStatistic;
 import org.apache.doris.statistics.Statistics;
+import org.apache.doris.statistics.util.StatisticsUtil;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
 
 import java.util.ArrayList;
-import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.List;
@@ -407,11 +407,7 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         if (groupingSets.size() <= 
connectContext.getSessionVariable().decomposeRepeatThreshold) {
             return -1;
         }
-        int maxGroupIndex = findMaxGroupingSetIndex(groupingSets);
-        if (maxGroupIndex < 0) {
-            return -1;
-        }
-        return maxGroupIndex;
+        return findMaxGroupingSetIndex(groupingSets);
     }
 
     /**
@@ -436,6 +432,9 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
                 maxGroupIndex = i;
             }
         }
+        if (groupingSets.get(maxGroupIndex).isEmpty()) {
+            return -1;
+        }
         // Second pass: verify that the max-size grouping set contains all 
other grouping sets
         ImmutableSet<Expression> maxGroup = 
ImmutableSet.copyOf(groupingSets.get(maxGroupIndex));
         for (int i = 0; i < groupingSets.size(); ++i) {
@@ -520,14 +519,14 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         switch (repeat.getRepeatType()) {
             case CUBE:
                 // Prefer larger NDV to improve balance
-                chosen = chooseByNdv(maxGroupByList, inputStats, 
totalInstanceNum);
+                chosen = chooseOneBalancedKey(maxGroupByList, inputStats, 
totalInstanceNum);
                 break;
             case GROUPING_SETS:
                 chosen = chooseByAppearanceThenNdv(repeat.getGroupingSets(), 
maxGroupIndex, maxGroupByList,
                         inputStats, totalInstanceNum);
                 break;
             case ROLLUP:
-                chosen = chooseByRollupPrefixThenNdv(maxGroupByList, 
inputStats, totalInstanceNum);
+                chosen = chooseOneBalancedKey(maxGroupByList, inputStats, 
totalInstanceNum);
                 break;
             default:
                 chosen = Optional.empty();
@@ -535,17 +534,21 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
         return chosen.map(ImmutableList::of);
     }
 
-    private Optional<Expression> chooseByNdv(List<Expression> candidates, 
Statistics inputStats, int totalInstanceNum) {
+    private Optional<Expression> chooseOneBalancedKey(List<Expression> 
candidates, Statistics inputStats,
+            int totalInstanceNum) {
         if (inputStats == null) {
             return Optional.empty();
         }
-        Comparator<Expression> cmp = Comparator.comparingDouble(e -> 
estimateNdv(e, inputStats));
-        Optional<Expression> choose = candidates.stream().max(cmp);
-        if (choose.isPresent() && estimateNdv(choose.get(), inputStats) > 
totalInstanceNum) {
-            return choose;
-        } else {
-            return Optional.empty();
+        for (Expression candidate : candidates) {
+            ColumnStatistic columnStatistic = 
inputStats.findColumnStatistics(candidate);
+            if (columnStatistic == null || columnStatistic.isUnKnown()) {
+                continue;
+            }
+            if (StatisticsUtil.isBalanced(columnStatistic, 
inputStats.getRowCount(), totalInstanceNum)) {
+                return Optional.of(candidate);
+            }
         }
+        return Optional.empty();
     }
 
     /**
@@ -568,42 +571,17 @@ public class DecomposeRepeatWithPreAggregation extends 
DefaultPlanRewriter<Disti
                 }
             }
         }
-        Map<Integer, List<Expression>> countToCandidate = new TreeMap<>();
+        TreeMap<Integer, List<Expression>> countToCandidate = new TreeMap<>();
         for (Map.Entry<Expression, Integer> entry : appearCount.entrySet()) {
             countToCandidate.computeIfAbsent(entry.getValue(), v -> new 
ArrayList<>()).add(entry.getKey());
         }
-        for (Map.Entry<Integer, List<Expression>> entry : 
countToCandidate.entrySet()) {
-            Optional<Expression> chosen = chooseByNdv(entry.getValue(), 
inputStats, totalInstanceNum);
+        for (Map.Entry<Integer, List<Expression>> entry : 
countToCandidate.descendingMap().entrySet()) {
+            Optional<Expression> chosen = 
chooseOneBalancedKey(entry.getValue(), inputStats, totalInstanceNum);
             if (chosen.isPresent()) {
                 return chosen;
             }
         }
         return Optional.empty();
-
-    }
-
-    /**
-     * ROLLUP: prefer earliest prefix key; if NDV is too low, fallback to next 
prefix.
-     */
-    private Optional<Expression> chooseByRollupPrefixThenNdv(List<Expression> 
candidates, Statistics inputStats,
-            int totalInstanceNum) {
-        for (Expression c : candidates) {
-            if (estimateNdv(c, inputStats) >= totalInstanceNum) {
-                return Optional.of(c);
-            }
-        }
-        return Optional.empty();
-    }
-
-    private double estimateNdv(Expression expr, Statistics stats) {
-        if (stats == null) {
-            return -1D;
-        }
-        ColumnStatistic col = stats.findColumnStatistics(expr);
-        if (col == null || col.isUnKnown()) {
-            return -1D;
-        }
-        return col.ndv;
     }
 
     /**
diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java 
b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
index 122065f6d35..02929b38868 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java
@@ -839,6 +839,8 @@ public class SessionVariable implements Serializable, 
Writable {
     public static final String SKEW_REWRITE_JOIN_SALT_EXPLODE_FACTOR = 
"skew_rewrite_join_salt_explode_factor";
 
     public static final String SKEW_REWRITE_AGG_BUCKET_NUM = 
"skew_rewrite_agg_bucket_num";
+    public static final String AGG_SHUFFLE_USE_PARENT_KEY = 
"agg_shuffle_use_parent_key";
+
     public static final String DECOMPOSE_REPEAT_THRESHOLD = 
"decompose_repeat_threshold";
     public static final String DECOMPOSE_REPEAT_SHUFFLE_INDEX_IN_MAX_GROUP
             = "decompose_repeat_shuffle_index_in_max_group";
@@ -850,6 +852,7 @@ public class SessionVariable implements Serializable, 
Writable {
                                 + "proportion as hot values, up to 
HOT_VALUE_COLLECT_COUNT."})
     public int hotValueCollectCount = 10; // Select the values that account 
for at least 10% of the column
 
+
     public void setHotValueCollectCount(int count) {
         this.hotValueCollectCount = count;
     }
@@ -2791,6 +2794,12 @@ public class SessionVariable implements Serializable, 
Writable {
             }, checker = "checkSkewRewriteAggBucketNum")
     public int skewRewriteAggBucketNum = 1024;
 
+    @VariableMgr.VarAttr(name = AGG_SHUFFLE_USE_PARENT_KEY, description = {
+            "在聚合算子进行 shuffle 时,是否使用父节点的分组键进行 shuffle",
+            "Whether to use the parent node's grouping key for shuffling 
during the aggregation operator"
+    }, needForward = false)
+    public boolean aggShuffleUseParentKey = true;
+
     @VariableMgr.VarAttr(name = ENABLE_PREFER_CACHED_ROWSET, needForward = 
false,
             description = {"是否启用 prefer cached rowset 功能",
                     "Whether to enable prefer cached rowset feature"})
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java 
b/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java
index 014c9ace70f..02d275795de 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/statistics/util/StatisticsUtil.java
@@ -64,6 +64,7 @@ import 
org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.TimestampTzLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.util.AggregateUtils;
 import org.apache.doris.qe.AuditLogHelper;
 import org.apache.doris.qe.AutoCloseConnectContext;
 import org.apache.doris.qe.ConnectContext;
@@ -1322,4 +1323,23 @@ public class StatisticsUtil {
         }
         return null;
     }
+
+    public static boolean isBalanced(ColumnStatistic columnStatistic, double 
rowCount, int instanceNum) {
+        double ndv = columnStatistic.ndv;
+        double maxHotValueCntIncludeNull;
+        Map<Literal, Float> hotValues = columnStatistic.getHotValues();
+        // When hotValues not exist, or exist but unknown, treat nulls as the 
only hot value.
+        if (columnStatistic.getHotValues() == null || hotValues.isEmpty()) {
+            maxHotValueCntIncludeNull = columnStatistic.numNulls;
+        } else {
+            double rate = 
hotValues.values().stream().mapToDouble(Float::doubleValue).max().orElse(0);
+            maxHotValueCntIncludeNull = rate * rowCount > 
columnStatistic.numNulls
+                    ? rate * rowCount : columnStatistic.numNulls;
+        }
+        double rowsPerInstance = (rowCount - maxHotValueCntIncludeNull) / 
instanceNum;
+        double balanceFactor = maxHotValueCntIncludeNull == 0
+                ? Double.MAX_VALUE : rowsPerInstance / 
maxHotValueCntIncludeNull;
+        // The larger this factor is, the more balanced the data.
+        return balanceFactor > 2.0 && ndv > instanceNum * 3 && ndv > 
AggregateUtils.LOW_NDV_THRESHOLD;
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
index 4524fc10e5e..32a6adcf212 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/RequestPropertyDeriverTest.java
@@ -368,4 +368,96 @@ class RequestPropertyDeriverTest {
         expected.add(Lists.newArrayList(PhysicalProperties.GATHER));
         Assertions.assertEquals(expected, actual);
     }
+
+    @Test
+    void testAggregateWithAggShuffleUseParentKeyDisabled() {
+        // Create ConnectContext with aggShuffleUseParentKey = false
+        ConnectContext testConnectContext = new ConnectContext();
+        testConnectContext.getSessionVariable().aggShuffleUseParentKey = false;
+
+        SlotReference key1 = new SlotReference(new ExprId(0), "col1", 
IntegerType.INSTANCE, true, ImmutableList.of());
+        SlotReference key2 = new SlotReference(new ExprId(1), "col2", 
IntegerType.INSTANCE, true, ImmutableList.of());
+        PhysicalHashAggregate<GroupPlan> aggregate = new 
PhysicalHashAggregate<>(
+                Lists.newArrayList(key1, key2),
+                Lists.newArrayList(key1, key2),
+                new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
+                true,
+                logicalProperties,
+                groupPlan
+        );
+        GroupExpression groupExpression = new GroupExpression(aggregate);
+        new Group(null, groupExpression, null);
+
+        // Create a parent hash distribution with key1 only
+        PhysicalProperties parentProperties = PhysicalProperties.createHash(
+                Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
+
+        new Expectations() {
+            {
+                jobContext.getRequiredProperties();
+                result = parentProperties;
+            }
+        };
+
+        RequestPropertyDeriver requestPropertyDeriver = new 
RequestPropertyDeriver(testConnectContext, jobContext);
+        List<List<PhysicalProperties>> actual
+                = 
requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
+
+        // When aggShuffleUseParentKey is false, should only use all 
groupByExpressions (key1, key2)
+        // and not use parent key (key1) separately
+        List<List<PhysicalProperties>> expected = Lists.newArrayList();
+        expected.add(Lists.newArrayList(PhysicalProperties.createHash(
+                Lists.newArrayList(key1.getExprId(), key2.getExprId()), 
ShuffleType.REQUIRE)));
+        Assertions.assertEquals(1, actual.size());
+        Assertions.assertEquals(expected, actual);
+    }
+
+    @Test
+    void testAggregateWithAggShuffleUseParentKeyEnabled() {
+        // Create ConnectContext with aggShuffleUseParentKey = true (default 
value)
+        ConnectContext testConnectContext = new ConnectContext();
+        testConnectContext.getSessionVariable().aggShuffleUseParentKey = true;
+
+        SlotReference key1 = new SlotReference(new ExprId(0), "col1", 
IntegerType.INSTANCE, true, ImmutableList.of());
+        SlotReference key2 = new SlotReference(new ExprId(1), "col2", 
IntegerType.INSTANCE, true, ImmutableList.of());
+        PhysicalHashAggregate<GroupPlan> aggregate = new 
PhysicalHashAggregate<>(
+                Lists.newArrayList(key1, key2),
+                Lists.newArrayList(key1, key2),
+                new AggregateParam(AggPhase.GLOBAL, AggMode.BUFFER_TO_RESULT),
+                true,
+                logicalProperties,
+                groupPlan
+        );
+        GroupExpression groupExpression = new GroupExpression(aggregate);
+        new Group(null, groupExpression, null);
+
+        // Create a parent hash distribution with key1 only
+        PhysicalProperties parentProperties = PhysicalProperties.createHash(
+                Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
+
+        new Expectations() {
+            {
+                jobContext.getRequiredProperties();
+                result = parentProperties;
+            }
+        };
+        new MockUp<org.apache.doris.nereids.memo.GroupExpression>() {
+            @mockit.Mock
+            org.apache.doris.statistics.Statistics childStatistics(int idx) {
+                return null;
+            }
+        };
+        RequestPropertyDeriver requestPropertyDeriver = new 
RequestPropertyDeriver(testConnectContext, jobContext);
+        List<List<PhysicalProperties>> actual
+                = 
requestPropertyDeriver.getRequestChildrenPropertyList(groupExpression);
+
+        // When aggShuffleUseParentKey is true, shouldUseParent may return true
+        // If shouldUseParent returns true, it will add parent key (key1) 
first, then all groupByExpressions (key1, key2)
+        Assertions.assertEquals(2, actual.size(), "Should have at least one 
property request");
+        PhysicalProperties parentProp = PhysicalProperties.createHash(
+                Lists.newArrayList(key1.getExprId()), ShuffleType.REQUIRE);
+        PhysicalProperties aggProp = PhysicalProperties.createHash(
+                Lists.newArrayList(key1.getExprId(), key2.getExprId()), 
ShuffleType.REQUIRE);
+        Assertions.assertTrue(actual.contains(ImmutableList.of(aggProp)) && 
actual.contains(ImmutableList.of(parentProp)));
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
index d1187b0e7cc..70526298b73 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/DecomposeRepeatWithPreAggregationTest.java
@@ -40,6 +40,10 @@ import org.apache.doris.nereids.types.IntegerType;
 import org.apache.doris.nereids.util.MemoPatternMatchSupported;
 import org.apache.doris.nereids.util.MemoTestUtils;
 import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
+import org.apache.doris.statistics.ColumnStatistic;
+import org.apache.doris.statistics.ColumnStatisticBuilder;
+import org.apache.doris.statistics.Statistics;
 import org.apache.doris.utframe.TestWithFeService;
 
 import com.google.common.collect.ImmutableList;
@@ -51,6 +55,7 @@ import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 
 /**
@@ -273,7 +278,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
 
     @Test
     public void testCanOptimize() throws Exception {
-        Method method = rule.getClass().getDeclaredMethod("canOptimize", 
LogicalAggregate.class);
+        Method method = rule.getClass().getDeclaredMethod("canOptimize", 
LogicalAggregate.class, ConnectContext.class);
         method.setAccessible(true);
 
         SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
@@ -304,7 +309,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 ImmutableList.of(a, b, c, d, sumAlias),
                 repeat);
 
-        int result = (int) method.invoke(rule, aggregate);
+        int result = (int) method.invoke(rule, aggregate, connectContext);
         Assertions.assertEquals(0, result);
 
         // Test case 2: Child is not LogicalRepeat
@@ -312,7 +317,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 ImmutableList.of(a),
                 ImmutableList.of(a, sumAlias),
                 emptyRelation);
-        result = (int) method.invoke(rule, aggregateWithNonRepeat);
+        result = (int) method.invoke(rule, aggregateWithNonRepeat, 
connectContext);
         Assertions.assertEquals(-1, result);
 
         // Test case 3: Unsupported aggregate function (Avg)
@@ -323,7 +328,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 ImmutableList.of(a, b, c, d),
                 ImmutableList.of(a, b, c, d, avgAlias),
                 repeat);
-        result = (int) method.invoke(rule, aggregateWithCount);
+        result = (int) method.invoke(rule, aggregateWithCount, connectContext);
         Assertions.assertEquals(-1, result);
 
         // Test case 4: Grouping sets size <= 3
@@ -341,7 +346,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
                 ImmutableList.of(a, b),
                 ImmutableList.of(a, b, sumAlias),
                 smallRepeat);
-        result = (int) method.invoke(rule, aggregateWithSmallRepeat);
+        result = (int) method.invoke(rule, aggregateWithSmallRepeat, 
connectContext);
         Assertions.assertEquals(-1, result);
     }
 
@@ -393,7 +398,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
     @Test
     public void testConstructProducer() throws Exception {
         Method method = rule.getClass().getDeclaredMethod("constructProducer",
-                LogicalAggregate.class, int.class, 
DistinctSelectorContext.class, Map.class);
+                LogicalAggregate.class, int.class, 
DistinctSelectorContext.class, Map.class, ConnectContext.class);
         method.setAccessible(true);
 
         SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
@@ -414,7 +419,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
         LogicalRepeat<Plan> repeat = new LogicalRepeat<>(
                 groupingSets,
                 (List) ImmutableList.of(a, b, c, d),
-                null,
+                RepeatType.GROUPING_SETS,
                 emptyRelation);
         Sum sumFunc = new Sum(d);
         Alias sumAlias = new Alias(sumFunc, "sum_d");
@@ -425,7 +430,7 @@ public class DecomposeRepeatWithPreAggregationTest extends 
TestWithFeService imp
 
         Map<Slot, Slot> preToCloneSlotMap = new HashMap<>();
         LogicalCTEProducer<LogicalAggregate<Plan>> result = 
(LogicalCTEProducer<LogicalAggregate<Plan>>)
-                method.invoke(rule, aggregate, 0, ctx, preToCloneSlotMap);
+                method.invoke(rule, aggregate, 0, ctx, preToCloneSlotMap, 
connectContext);
 
         Assertions.assertNotNull(result);
         Assertions.assertNotNull(result.child());
@@ -485,4 +490,154 @@ public class DecomposeRepeatWithPreAggregationTest 
extends TestWithFeService imp
         Assertions.assertEquals(2, result.getGroupingSets().size());
         Assertions.assertTrue(groupingFunctionSlots.isEmpty());
     }
+
+    @Test
+    public void testChoosePreAggShuffleKeyPartitionExprs() throws Exception {
+        Method method = 
rule.getClass().getDeclaredMethod("choosePreAggShuffleKeyPartitionExprs",
+                LogicalRepeat.class, int.class, List.class, 
org.apache.doris.qe.ConnectContext.class);
+        method.setAccessible(true);
+
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        SlotReference b = new SlotReference("b", IntegerType.INSTANCE);
+        SlotReference c = new SlotReference("c", IntegerType.INSTANCE);
+
+        List<Expression> maxGroupByList = ImmutableList.of(a, b, c);
+        LogicalEmptyRelation emptyRelation = new LogicalEmptyRelation(
+                
org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator.newRelationId(),
+                ImmutableList.of());
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(a, b, c),
+                ImmutableList.of(a, b),
+                ImmutableList.of(a)
+        );
+        LogicalRepeat<Plan> repeatRollup = new LogicalRepeat<>(
+                groupingSets,
+                (List) ImmutableList.of(a, b, c),
+                null,
+                RepeatType.ROLLUP,
+                emptyRelation);
+        LogicalRepeat<Plan> repeatGroupingSets = new LogicalRepeat<>(
+                groupingSets,
+                (List) ImmutableList.of(a, b, c),
+                new SlotReference("grouping_id", IntegerType.INSTANCE),
+                RepeatType.GROUPING_SETS,
+                emptyRelation);
+        LogicalRepeat<Plan> repeatCube = new LogicalRepeat<>(
+                groupingSets,
+                (List) ImmutableList.of(a, b, c),
+                new SlotReference("grouping_id", IntegerType.INSTANCE),
+                RepeatType.CUBE,
+                emptyRelation);
+
+        // Case 1: Session variable decomposeRepeatShuffleIndexInMaxGroup = 0, 
should return third expr
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = 2;
+        @SuppressWarnings("unchecked")
+        Optional<List<Expression>> result2 = (Optional<List<Expression>>) 
method.invoke(
+                rule, repeatRollup, 0, maxGroupByList, connectContext);
+        Assertions.assertTrue(result2.isPresent());
+        Assertions.assertEquals(1, result2.get().size());
+        Assertions.assertEquals(c, result2.get().get(0));
+
+        // Case 2: Session variable = -1 (default), fall through to 
repeat-type logic (may be empty if no stats)
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1;
+        @SuppressWarnings("unchecked")
+        Optional<List<Expression>> resultDefault = 
(Optional<List<Expression>>) method.invoke(
+                rule, repeatRollup, 0, maxGroupByList, connectContext);
+        // With no column stats, chooseByRollupPrefixThenNdv typically returns 
empty
+        Assertions.assertEquals(resultDefault, Optional.empty());
+
+        // Case 3: Session variable out of range (>= size), should not use 
index, fall through
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = 10;
+        @SuppressWarnings("unchecked")
+        Optional<List<Expression>> resultOutOfRange = 
(Optional<List<Expression>>) method.invoke(
+                rule, repeatRollup, 0, maxGroupByList, connectContext);
+        Assertions.assertEquals(resultOutOfRange, Optional.empty());
+
+        // Case 4: RepeatType GROUPING_SETS and CUBE (smoke test, result 
depends on stats)
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1;
+        @SuppressWarnings("unchecked")
+        Optional<List<Expression>> resultGs = (Optional<List<Expression>>) 
method.invoke(
+                rule, repeatGroupingSets, 0, maxGroupByList, connectContext);
+        Assertions.assertEquals(resultGs, Optional.empty());
+
+        // Case 5: RepeatType GROUPING_SETS and CUBE (smoke test, result 
depends on stats)
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1;
+        @SuppressWarnings("unchecked")
+        Optional<List<Expression>> resultCb = (Optional<List<Expression>>) 
method.invoke(
+                rule, repeatCube, 0, maxGroupByList, connectContext);
+        Assertions.assertEquals(resultCb, Optional.empty());
+
+        // Restore default
+        
connectContext.getSessionVariable().decomposeRepeatShuffleIndexInMaxGroup = -1;
+    }
+
+    /** Helper: build Statistics with column ndv for given expressions. */
+    private static Statistics statsWithNdv(Map<Expression, Double> exprToNdv) {
+        Map<Expression, ColumnStatistic> map = new HashMap<>();
+        for (Map.Entry<Expression, Double> e : exprToNdv.entrySet()) {
+            ColumnStatistic col = new ColumnStatisticBuilder(1)
+                    .setNdv(e.getValue())
+                    .setAvgSizeByte(4)
+                    .setNumNulls(0)
+                    .setMinValue(0)
+                    .setMaxValue(100)
+                    .setIsUnknown(false)
+                    .setUpdatedTime("")
+                    .build();
+            map.put(e.getKey(), col);
+        }
+        return new Statistics(100, map);
+    }
+
+    @Test
+    public void testChooseByAppearanceThenNdv() throws Exception {
+        Method method = 
rule.getClass().getDeclaredMethod("chooseByAppearanceThenNdv",
+                List.class, int.class, List.class, Statistics.class, 
int.class);
+        method.setAccessible(true);
+
+        SlotReference a = new SlotReference("a", IntegerType.INSTANCE);
+        SlotReference b = new SlotReference("b", IntegerType.INSTANCE);
+        SlotReference c = new SlotReference("c", IntegerType.INSTANCE);
+        List<Expression> candidates = ImmutableList.of(a, b, c);
+
+        // grouping sets: index 0 = max (a,b,c), index 1 = (a,b), index 2 = (a)
+        // non-max: (a,b) and (a). a appears 2, b appears 1, c appears 1.
+        // countToCandidate: 1->[b,c], 2->[a]. TreeMap iterates 1 then 2.
+        // For count 1: chooseByNdv([b,c], stats, total). Need ndv > total to 
return. b:60, c:80, total=50 -> max ndv 80>50 -> return c.
+        List<List<Expression>> groupingSets = ImmutableList.of(
+                ImmutableList.of(a, b, c),
+                ImmutableList.of(a, c),
+                ImmutableList.of(c)
+        );
+
+        Map<Expression, Double> exprToNdv = new HashMap<>();
+        exprToNdv.put(a, 40.0);
+        exprToNdv.put(b, 60.0);
+        exprToNdv.put(c, 50.0);
+        Statistics stats = statsWithNdv(exprToNdv);
+
+        @SuppressWarnings("unchecked")
+        Optional<Expression> chosen = (Optional<Expression>) method.invoke(
+                rule, groupingSets, -1, candidates, stats, 15);
+        Assertions.assertTrue(chosen.isPresent());
+        Assertions.assertEquals(c, chosen.get());
+
+        // When no candidate has ndv > totalInstanceNum, return empty
+        @SuppressWarnings("unchecked")
+        Optional<Expression> empty = (Optional<Expression>) method.invoke(
+                rule, groupingSets, -1, candidates, stats, 1000);
+        Assertions.assertFalse(empty.isPresent());
+
+        @SuppressWarnings("unchecked")
+        Optional<Expression> chosen2 = (Optional<Expression>) method.invoke(
+                rule, groupingSets, -1, candidates, stats, 18);
+        Assertions.assertTrue(chosen2.isPresent());
+        Assertions.assertEquals(b, chosen2.get());
+
+        // inputStats null -> chooseByNdv returns empty for every group -> 
empty
+        @SuppressWarnings("unchecked")
+        Optional<Expression> emptyNullStats = (Optional<Expression>) 
method.invoke(
+                rule, groupingSets, -1, candidates, null, 18);
+        Assertions.assertFalse(emptyNullStats.isPresent());
+    }
 }
diff --git 
a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
 
b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
index 5d06776679b..bf1a2fd0407 100644
--- 
a/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
+++ 
b/regression-test/suites/nereids_rules_p0/decompose_repeat/decompose_repeat.groovy
@@ -23,11 +23,12 @@ suite("decompose_repeat") {
     order_qt_sum "select a,b,c,sum(d) from t1 group by rollup(a,b,c);"
     order_qt_agg_func_gby_key_same_col "select a,b,c,d,sum(d) from t1 group by 
rollup(a,b,c,d);"
     order_qt_multi_agg_func "select a,b,c,sum(d),sum(c),max(a) from t1 group 
by rollup(a,b,c,d);"
-    order_qt_nest_rewrite """
-    select a,b,c,c1 from (
-    select a,b,c,d,sum(d) c1 from t1 group by grouping 
sets((a,b,c),(a,b,c,d),(a),(a,b,c,c))
-    ) t group by rollup(a,b,c,c1);
-    """
+    // maybe this problem:DORIS-24075
+//    order_qt_nest_rewrite """
+//    select a,b,c,c1 from (
+//    select a,b,c,d,sum(d) c1 from t1 group by grouping 
sets((a,b,c),(a,b,c,d),(a),(a,b,c,c))
+//    ) t group by rollup(a,b,c,c1);
+//    """
     order_qt_upper_ref """
     select c1+10,a,b,c from (select a,b,c,sum(d) c1 from t1 group by 
rollup(a,b,c)) t group by c1+10,a,b,c;
     """
@@ -103,4 +104,17 @@ suite("decompose_repeat") {
     order_qt_grouping_max_not_first "select a,b,c,d, grouping_id(c,d) from t1 
group by grouping sets((a,b),(a,b,c),(a,b,c,d),());"
     // Test case: complex case with aggregation function and grouping function
     order_qt_grouping_with_agg "select a,b,c,d, sum(d), grouping_id(a,b,c) 
from t1 group by grouping sets((a,b,c,d),(a,b,c),(a),());"
+
+    // test empty grouping set
+    multi_sql """drop table if exists t_repeat_pick_shuffle_key;
+    create table t_repeat_pick_shuffle_key(a int, b int, c int, d int);
+    alter table t_repeat_pick_shuffle_key modify column a set stats 
('row_count'='300000', 'ndv'='10', 'num_nulls'='0', 'min_value'='1', 
'max_value'='300000', 'data_size'='2400000');
+    alter table t_repeat_pick_shuffle_key modify column b set stats 
('row_count'='300000', 'ndv'='100', 'num_nulls'='0', 'min_value'='1', 
'max_value'='300000', 'data_size'='2400000');
+    alter table t_repeat_pick_shuffle_key modify column c set stats 
('row_count'='300000', 'ndv'='1000', 'num_nulls'='0', 'min_value'='1', 
'max_value'='300000', 'data_size'='2400000');
+    alter table t_repeat_pick_shuffle_key modify column d set stats 
('row_count'='300000', 'ndv'='10000', 'num_nulls'='0', 'min_value'='1', 
'max_value'='300000', 'data_size'='2400000');"""
+    sql "select 2 from t_repeat_pick_shuffle_key group by grouping 
sets((),(),(),());"
+    sql "select a,b,c,d from t_repeat_pick_shuffle_key group by 
rollup(a,b,c,d);"
+    sql "select a,b,c,d from t_repeat_pick_shuffle_key group by cube(a,b,c,d);"
+    sql "select a,b,c,d from t_repeat_pick_shuffle_key group by grouping 
sets((a,b,c,d),(b,c,d),(c),(c,a));"
+
 }
\ No newline at end of file


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to