morrySnow commented on code in PR #63690:
URL: https://github.com/apache/doris/pull/63690#discussion_r3394347891


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -534,21 +656,382 @@ public Plan visitLogicalRelation(LogicalRelation 
relation, PushDownAggContext co
     }
 
     private Plan genAggregate(Plan child, PushDownAggContext context) {
-        if (context.isValid() && checkStats(child, context)) {
+        if (isPushDisabledByVariable(context)) {
+            context.getBilateralState().registerNoCountSlot(child);
+            return child;
+        }
+        if (checkStats(child, context) || isPushEnabledByVariable(context)) {
             List<NamedExpression> aggOutputExpressions = new ArrayList<>();
             for (AggregateFunction func : context.getAggFunctions()) {
                 aggOutputExpressions.add(context.getAliasMap().get(func));
             }
+            Alias countStarAlias = null;
+            boolean countStarAlreadyProjected = false;
+            Count countStar = new Count();
+            if (context.getAliasMap().containsKey(countStar)) {
+                countStarAlias = context.getAliasMap().get(countStar);
+                countStarAlreadyProjected = true;
+            } else {
+                countStarAlias = new Alias(countStar,
+                        "cnt" + 
context.getCascadesContext().getStatementContext().generateColumnName());
+            }
             aggOutputExpressions.addAll(context.getGroupKeys());
+            if (countStarAlias != null && !countStarAlreadyProjected) {
+                aggOutputExpressions.add(countStarAlias);
+            }
             LogicalAggregate genAgg = new 
LogicalAggregate(context.getGroupKeys(), aggOutputExpressions, child);
             NormalizeAggregate normalizeAggregate = new NormalizeAggregate();
-            return normalizeAggregate.normalizeAgg(genAgg, Optional.empty(),
+            Plan normalized = normalizeAggregate.normalizeAgg(genAgg, 
Optional.empty(),
                     context.getCascadesContext());
+
+            for (AggregateFunction func : context.getAggFunctions()) {
+                Alias a = context.getAliasMap().get(func);
+                
context.getBilateralState().registerPushedAggFuncSlot(a.getExprId(), 
a.toSlot());
+            }
+
+            if (countStarAlias != null) {
+                context.getBilateralState().registerCountSlot(normalized, 
countStarAlias.toSlot());
+            } else {
+                context.getBilateralState().registerNoCountSlot(normalized);
+            }
+            return normalized;
         } else {
+            context.getBilateralState().registerNoCountSlot(child);
             return child;
         }
     }
 
+    // Build the canonical project above a rewritten join after 
eager-aggregation pushdown.
+    // Responsibilities:
+    // 1. Restore the outputs expected by the parent rollup. If a join side 
has a childContext, materialize
+    //    that side's aggregate current values and group keys; otherwise 
forward the original join outputs.
+    // 2. For inner joins, recover join multiplicity by multiplying 
non-MIN/MAX aggregate current values by
+    //    the opposite side's count slot when that side contributes rows to 
the parent aggregate.
+    // 3. Append and register a synthetic join-count slot `cnt` (logical jcnt) 
for upper-level rollup.
+    //
+    // The examples below are schematic. The real project may keep extra 
forwarded slots such as join keys.
+    //
+    // Inner join + sum, single-side rewrite:
+    //   Before:
+    //     agg(sum(t1.a), sum(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum(s1), sum(s2), gby t2.k)
+    //       -> project(s1, t2.a * cnt1 as s2, t2.k, cnt1 as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(sum(t1.a) as s1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> scan(t2)
+    //
+    // Inner join + sum, bilateral rewrite:
+    //   Before:
+    //     agg(sum(t1.a), sum(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum(s1'), sum(s2'), gby t2.k)
+    //       -> project(s1 * cnt2 as s1', s2 * cnt1 as s2', t2.k, cnt1 * cnt2 
as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(sum(t1.a) as s1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> agg(sum(t2.a) as s2, count(*) as cnt2, gby k)
+    //                      -> scan(t2)
+    //
+    // Inner join + count(col), single-side rewrite:
+    //   Before:
+    //     agg(count(t1.a), count(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum0(c1), sum0(c2), gby t2.k)
+    //       -> project(c1, if(t2.a is null, 0, 1) * cnt1 as c2, t2.k, cnt1 as 
cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(count(t1.a) as c1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> scan(t2)
+    //
+    // Inner join + count(col), bilateral rewrite:
+    //   Before:
+    //     agg(count(t1.a), count(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum0(c1'), sum0(c2'), gby t2.k)
+    //       -> project(c1 * cnt2 as c1', c2 * cnt1 as c2', t2.k, cnt1 * cnt2 
as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(count(t1.a) as c1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> agg(count(t2.a) as c2, count(*) as cnt2, gby k)
+    //                      -> scan(t2)
+    //   For count(*), the current row value is 1 instead of if(col is null, 
0, 1).
+    //
+    // Semi/anti join:
+    //   The project does not multiply by the opposite-side count
+    //
+    // Outer join:
+    //   Aggregate outputs are not multiplied by the opposite-side count 
either; only `cnt` changes:
+    //     left outer join with left push  -> project(s1, t2.k, cnt1 as cnt)
+    //     right outer join with left push -> project(s1, t2.k, nvl(cnt1, 1) 
as cnt)
+    private Plan buildCanonicalJoinProject(LogicalJoin<? extends Plan, ? 
extends Plan> join, PushDownAggContext context,
+            Optional<PushDownAggContext> leftChildContext, 
Optional<PushDownAggContext> rightChildContext,
+            Optional<Slot> leftCountSlot, Optional<Slot> rightCountSlot) {
+        List<NamedExpression> projections = new ArrayList<>();
+        Set<ExprId> outputIds = new HashSet<>();
+        boolean remainLeft = join.getJoinType().isRemainLeftJoin();
+        boolean remainRight = join.getJoinType().isRemainRightJoin();
+        boolean shouldAdjustLeft = 
shouldUseJoinOppositeCntAdjustAggOutput(join, leftChildContext, rightCountSlot);
+        boolean shouldAdjustRight = 
shouldUseJoinOppositeCntAdjustAggOutput(join, rightChildContext, leftCountSlot);
+
+        if (remainLeft) {
+            appendJoinSideOutputs(projections, outputIds, join.left(), 
leftChildContext, context,
+                    rightCountSlot, shouldAdjustLeft);
+        }
+        if (remainRight) {
+            appendJoinSideOutputs(projections, outputIds, join.right(), 
rightChildContext, context,
+                    leftCountSlot, shouldAdjustRight);
+        }
+
+        Optional<Expression> joinCount = computeJoinCount(join, 
leftChildContext, rightChildContext,
+                leftCountSlot, rightCountSlot);
+        Optional<Slot> projectedCountSlot = Optional.empty();
+        if (joinCount.isPresent()) {
+            Alias countAlias = new Alias(joinCount.get(),
+                    "joinCnt" + 
context.getCascadesContext().getStatementContext().generateColumnName());
+            projections.add(countAlias);
+            projectedCountSlot = Optional.of(countAlias.toSlot());
+        }
+        LogicalProject<Plan> project = new LogicalProject<>(projections, join);
+        if (projectedCountSlot.isPresent()) {
+            context.getBilateralState().registerCountSlot(project,
+                    (Slot) project.getOutput().get(project.getOutput().size() 
- 1));
+        } else {
+            context.getBilateralState().registerNoCountSlot(project);
+        }
+        return project;
+    }
+
+    private void appendJoinSideOutputs(List<NamedExpression> projections, 
Set<ExprId> outputIds, Plan originalSide,
+            Optional<PushDownAggContext> childContext, PushDownAggContext 
parentContext,
+            Optional<Slot> oppositeCountSlot, boolean shouldAdjustOutput) {
+        if (childContext.isPresent()) {
+            for (AggregateFunction aggFunc : 
childContext.get().getAggFunctions()) {
+                NamedExpression aggOutput = shouldAdjustOutput
+                        ? adjustAggOutputUseOppositeCountOnJoin(aggFunc, 
parentContext, oppositeCountSlot)
+                        : buildAggOutputWithoutJoinAdjustment(aggFunc, 
parentContext);
+                appendProjectionIfAbsent(projections, outputIds, aggOutput);
+            }
+            for (SlotReference groupKey : childContext.get().getGroupKeys()) {
+                appendProjectionIfAbsent(projections, outputIds, groupKey);
+            }
+        } else {
+            for (Slot slot : originalSide.getOutput()) {
+                appendProjectionIfAbsent(projections, outputIds, slot);
+            }
+        }
+    }
+
+    private void appendProjectionIfAbsent(List<NamedExpression> projections, 
Set<ExprId> outputIds,
+            NamedExpression expression) {
+        if (outputIds.add(expression.getExprId())) {
+            projections.add(expression);
+        }
+    }
+
+    private boolean shouldUseJoinOppositeCntAdjustAggOutput(LogicalJoin<? 
extends Plan, ? extends Plan> join,
+            Optional<PushDownAggContext> childContext, Optional<Slot> 
oppositeCountSlot) {
+        return join.getJoinType().isInnerOrCrossJoin() && 
childContext.isPresent() && oppositeCountSlot.isPresent();
+    }
+
+    private Optional<Expression> computeJoinCount(LogicalJoin<? extends Plan, 
? extends Plan> join,
+            Optional<PushDownAggContext> leftChildContext, 
Optional<PushDownAggContext> rightChildContext,
+            Optional<Slot> leftCountSlot, Optional<Slot> rightCountSlot) {
+        JoinType joinType = join.getJoinType();
+        if (joinType.isInnerJoin()) {
+            if (leftCountSlot.isPresent() && rightCountSlot.isPresent()) {
+                return Optional.of(ExpressionUtils.rebuildSignature(
+                        new Multiply(leftCountSlot.get(), 
rightCountSlot.get())));
+            } else if (leftCountSlot.isPresent()) {
+                return Optional.of(leftCountSlot.get());
+            } else if (rightCountSlot.isPresent()) {
+                return Optional.of(rightCountSlot.get());
+            }
+            return Optional.empty();
+        }
+        if (joinType.isLeftOuterJoin()) {
+            if (leftChildContext.isPresent()) {
+                return leftCountSlot.map(cnt -> (Expression) cnt);
+            }
+            if (rightChildContext.isPresent()) {
+                return rightCountSlot.map(cnt -> (Expression) 
ExpressionUtils.rebuildSignature(
+                        new Nvl(cnt, BigIntLiteral.of(1))));
+            }
+            return Optional.empty();
+        }
+        if (joinType.isRightOuterJoin()) {
+            if (leftChildContext.isPresent()) {
+                return leftCountSlot.map(cnt -> (Expression) 
ExpressionUtils.rebuildSignature(
+                        new Nvl(cnt, BigIntLiteral.of(1))));
+            }
+            if (rightChildContext.isPresent()) {
+                return rightCountSlot.map(cnt -> (Expression) cnt);
+            }
+            return Optional.empty();
+        }
+        if (joinType.isLeftSemiOrAntiJoin()) {
+            return leftCountSlot.map(cnt -> (Expression) cnt);
+        }
+        if (joinType.isRightSemiOrAntiJoin()) {
+            return rightCountSlot.map(cnt -> (Expression) cnt);
+        }
+        return Optional.empty();
+    }
+
+    private Plan buildCanonicalProject(Plan child, PushDownAggContext context, 
Slot countSlot) {
+        List<NamedExpression> projections = new ArrayList<>();
+        Set<ExprId> outputIds = new HashSet<>();
+        for (AggregateFunction aggFunc : context.getAggFunctions()) {
+            ExprId exprId = context.getAliasMap().get(aggFunc).getExprId();
+            NamedExpression aggOutput = 
context.getBilateralState().getPushedAggFuncSlot(exprId);
+            projections.add(aggOutput);
+            outputIds.add(aggOutput.getExprId());
+        }
+        for (SlotReference groupKey : context.getGroupKeys()) {
+            if (outputIds.add(groupKey.getExprId())) {
+                projections.add(groupKey);
+            }
+        }
+        projections.add(countSlot);
+        if (projections.equals(child.getOutput())) {
+            return child;
+        } else {
+            LogicalProject<Plan> project = new LogicalProject<>(projections, 
child);
+            context.getBilateralState().registerCountSlot(project, countSlot);
+            return project;
+        }
+    }
+
+    private NamedExpression 
buildAggOutputWithoutJoinAdjustment(AggregateFunction aggFunc, 
PushDownAggContext context) {
+        Alias alias = context.getAliasMap().get(aggFunc);
+        ExprId exprId = alias.getExprId();
+        BilateralState state = context.getBilateralState();
+        NamedExpression output;
+        if (state.hasAggFuncOutput(exprId)) {
+            output = state.getPushedAggFuncSlot(exprId);
+        } else {
+            Expression currentValue;
+            if (aggFunc instanceof Count) {
+                if (aggFunc.arity() == 0) {
+                    currentValue = BigIntLiteral.of(1);
+                } else {
+                    currentValue = new If(new IsNull(aggFunc.child(0)), 
BigIntLiteral.of(0), BigIntLiteral.of(1));
+                }
+            } else {
+                currentValue = aggFunc.child(0);
+            }
+            output = (Alias) alias.withChildren(currentValue);
+            state.registerAggFuncOutput(exprId, output.toSlot(), 
state.isAggFuncActuallyPushed(exprId));
+        }
+        return output;
+    }
+
+    private NamedExpression 
adjustAggOutputUseOppositeCountOnJoin(AggregateFunction aggFunc, 
PushDownAggContext context,
+            Optional<Slot> countSlot) {
+        Alias alias = context.getAliasMap().get(aggFunc);
+        ExprId exprId = alias.getExprId();
+        BilateralState state = context.getBilateralState();
+        Expression currentValue = getCurrentAggValue(aggFunc, exprId, state);
+        Optional<Expression> multiplier = Optional.empty();
+        if (!(aggFunc instanceof Max) && !(aggFunc instanceof Min)) {
+            multiplier = countSlot.map(cnt -> (Expression) cnt);
+        }
+        Expression outputExpr = multiplier.map(expression -> (Expression) new 
Multiply(currentValue, expression))
+                .orElse(currentValue);
+        outputExpr = ExpressionUtils.rebuildSignature(outputExpr);
+        NamedExpression output = new Alias(outputExpr);
+        state.registerAggFuncOutput(exprId, output.toSlot(), 
state.isAggFuncActuallyPushed(exprId));

Review Comment:
   For Max/Min, where the multiplier is empty, new Alias(outputExpr) wraps an
already-existing SlotReference in a new Alias (with a new ExprId). This adds
unnecessary indirection. When currentValue is already a slot (from
getPushedAggFuncSlot), consider returning it directly instead of wrapping it in
a new Alias. Also, new Alias(outputExpr) auto-generates its name from
outputExpr.toSql(), which produces confusing names like (s1 * cnt2) for
Multiply expressions.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -534,21 +656,382 @@ public Plan visitLogicalRelation(LogicalRelation 
relation, PushDownAggContext co
     }
 
     private Plan genAggregate(Plan child, PushDownAggContext context) {
-        if (context.isValid() && checkStats(child, context)) {
+        if (isPushDisabledByVariable(context)) {
+            context.getBilateralState().registerNoCountSlot(child);
+            return child;
+        }
+        if (checkStats(child, context) || isPushEnabledByVariable(context)) {
             List<NamedExpression> aggOutputExpressions = new ArrayList<>();
             for (AggregateFunction func : context.getAggFunctions()) {
                 aggOutputExpressions.add(context.getAliasMap().get(func));
             }
+            Alias countStarAlias = null;
+            boolean countStarAlreadyProjected = false;
+            Count countStar = new Count();
+            if (context.getAliasMap().containsKey(countStar)) {
+                countStarAlias = context.getAliasMap().get(countStar);
+                countStarAlreadyProjected = true;
+            } else {
+                countStarAlias = new Alias(countStar,
+                        "cnt" + 
context.getCascadesContext().getStatementContext().generateColumnName());
+            }
             aggOutputExpressions.addAll(context.getGroupKeys());
+            if (countStarAlias != null && !countStarAlreadyProjected) {
+                aggOutputExpressions.add(countStarAlias);
+            }
             LogicalAggregate genAgg = new 
LogicalAggregate(context.getGroupKeys(), aggOutputExpressions, child);
             NormalizeAggregate normalizeAggregate = new NormalizeAggregate();
-            return normalizeAggregate.normalizeAgg(genAgg, Optional.empty(),
+            Plan normalized = normalizeAggregate.normalizeAgg(genAgg, 
Optional.empty(),
                     context.getCascadesContext());
+
+            for (AggregateFunction func : context.getAggFunctions()) {
+                Alias a = context.getAliasMap().get(func);
+                
context.getBilateralState().registerPushedAggFuncSlot(a.getExprId(), 
a.toSlot());
+            }

Review Comment:
   Type safety bug: registerPushedAggFuncSlot uses the pre-normalization slot
type. The slot from a.toSlot() has the type from BEFORE NormalizeAggregate
processes the agg at lines 1288-1289. After normalization, the actual output
slots may have different types. Consider using normalized.getOutput() to find
the matching post-normalization slot and register that instead:
normalized.getOutput().stream().filter(s ->
s.getExprId().equals(a.getExprId())).findFirst().orElse(a.toSlot()).



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/eageraggregation/EagerAggRewriter.java:
##########
@@ -534,21 +656,382 @@ public Plan visitLogicalRelation(LogicalRelation 
relation, PushDownAggContext co
     }
 
     private Plan genAggregate(Plan child, PushDownAggContext context) {
-        if (context.isValid() && checkStats(child, context)) {
+        if (isPushDisabledByVariable(context)) {
+            context.getBilateralState().registerNoCountSlot(child);
+            return child;
+        }
+        if (checkStats(child, context) || isPushEnabledByVariable(context)) {
             List<NamedExpression> aggOutputExpressions = new ArrayList<>();
             for (AggregateFunction func : context.getAggFunctions()) {
                 aggOutputExpressions.add(context.getAliasMap().get(func));
             }
+            Alias countStarAlias = null;
+            boolean countStarAlreadyProjected = false;
+            Count countStar = new Count();
+            if (context.getAliasMap().containsKey(countStar)) {
+                countStarAlias = context.getAliasMap().get(countStar);
+                countStarAlreadyProjected = true;
+            } else {
+                countStarAlias = new Alias(countStar,
+                        "cnt" + 
context.getCascadesContext().getStatementContext().generateColumnName());
+            }
             aggOutputExpressions.addAll(context.getGroupKeys());
+            if (countStarAlias != null && !countStarAlreadyProjected) {
+                aggOutputExpressions.add(countStarAlias);
+            }
             LogicalAggregate genAgg = new 
LogicalAggregate(context.getGroupKeys(), aggOutputExpressions, child);
             NormalizeAggregate normalizeAggregate = new NormalizeAggregate();
-            return normalizeAggregate.normalizeAgg(genAgg, Optional.empty(),
+            Plan normalized = normalizeAggregate.normalizeAgg(genAgg, 
Optional.empty(),
                     context.getCascadesContext());
+
+            for (AggregateFunction func : context.getAggFunctions()) {
+                Alias a = context.getAliasMap().get(func);
+                
context.getBilateralState().registerPushedAggFuncSlot(a.getExprId(), 
a.toSlot());
+            }
+
+            if (countStarAlias != null) {
+                context.getBilateralState().registerCountSlot(normalized, 
countStarAlias.toSlot());
+            } else {
+                context.getBilateralState().registerNoCountSlot(normalized);
+            }
+            return normalized;
         } else {
+            context.getBilateralState().registerNoCountSlot(child);
             return child;
         }
     }
 
+    // Build the canonical project above a rewritten join after 
eager-aggregation pushdown.
+    // Responsibilities:
+    // 1. Restore the outputs expected by the parent rollup. If a join side 
has a childContext, materialize
+    //    that side's aggregate current values and group keys; otherwise 
forward the original join outputs.
+    // 2. For inner joins, recover join multiplicity by multiplying 
non-MIN/MAX aggregate current values by
+    //    the opposite side's count slot when that side contributes rows to 
the parent aggregate.
+    // 3. Append and register a synthetic join-count slot `cnt` (logical jcnt) 
for upper-level rollup.
+    //
+    // The examples below are schematic. The real project may keep extra 
forwarded slots such as join keys.
+    //
+    // Inner join + sum, single-side rewrite:
+    //   Before:
+    //     agg(sum(t1.a), sum(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum(s1), sum(s2), gby t2.k)
+    //       -> project(s1, t2.a * cnt1 as s2, t2.k, cnt1 as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(sum(t1.a) as s1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> scan(t2)
+    //
+    // Inner join + sum, bilateral rewrite:
+    //   Before:
+    //     agg(sum(t1.a), sum(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum(s1'), sum(s2'), gby t2.k)
+    //       -> project(s1 * cnt2 as s1', s2 * cnt1 as s2', t2.k, cnt1 * cnt2 
as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(sum(t1.a) as s1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> agg(sum(t2.a) as s2, count(*) as cnt2, gby k)
+    //                      -> scan(t2)
+    //
+    // Inner join + count(col), single-side rewrite:
+    //   Before:
+    //     agg(count(t1.a), count(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum0(c1), sum0(c2), gby t2.k)
+    //       -> project(c1, if(t2.a is null, 0, 1) * cnt1 as c2, t2.k, cnt1 as 
cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(count(t1.a) as c1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> scan(t2)
+    //
+    // Inner join + count(col), bilateral rewrite:
+    //   Before:
+    //     agg(count(t1.a), count(t2.a), gby t2.k)
+    //       -> inner join(k = k)
+    //            -> scan(t1)
+    //            -> scan(t2)
+    //   After:
+    //     agg(sum0(c1'), sum0(c2'), gby t2.k)
+    //       -> project(c1 * cnt2 as c1', c2 * cnt1 as c2', t2.k, cnt1 * cnt2 
as cnt)
+    //            -> inner join(k = k)
+    //                 -> agg(count(t1.a) as c1, count(*) as cnt1, gby k)
+    //                      -> scan(t1)
+    //                 -> agg(count(t2.a) as c2, count(*) as cnt2, gby k)
+    //                      -> scan(t2)
+    //   For count(*), the current row value is 1 instead of if(col is null, 
0, 1).
+    //
+    // Semi/anti join:
+    //   The project does not multiply by the opposite-side count
+    //
+    // Outer join:
+    //   Aggregate outputs are not multiplied by the opposite-side count 
either; only `cnt` changes:
+    //     left outer join with left push  -> project(s1, t2.k, cnt1 as cnt)
+    //     right outer join with left push -> project(s1, t2.k, nvl(cnt1, 1) 
as cnt)
+    private Plan buildCanonicalJoinProject(LogicalJoin<? extends Plan, ? 
extends Plan> join, PushDownAggContext context,
+            Optional<PushDownAggContext> leftChildContext, 
Optional<PushDownAggContext> rightChildContext,
+            Optional<Slot> leftCountSlot, Optional<Slot> rightCountSlot) {
+        List<NamedExpression> projections = new ArrayList<>();
+        Set<ExprId> outputIds = new HashSet<>();
+        boolean remainLeft = join.getJoinType().isRemainLeftJoin();
+        boolean remainRight = join.getJoinType().isRemainRightJoin();
+        boolean shouldAdjustLeft = 
shouldUseJoinOppositeCntAdjustAggOutput(join, leftChildContext, rightCountSlot);
+        boolean shouldAdjustRight = 
shouldUseJoinOppositeCntAdjustAggOutput(join, rightChildContext, leftCountSlot);
+
+        if (remainLeft) {
+            appendJoinSideOutputs(projections, outputIds, join.left(), 
leftChildContext, context,
+                    rightCountSlot, shouldAdjustLeft);
+        }
+        if (remainRight) {
+            appendJoinSideOutputs(projections, outputIds, join.right(), 
rightChildContext, context,
+                    leftCountSlot, shouldAdjustRight);
+        }
+
+        Optional<Expression> joinCount = computeJoinCount(join, 
leftChildContext, rightChildContext,
+                leftCountSlot, rightCountSlot);
+        Optional<Slot> projectedCountSlot = Optional.empty();
+        if (joinCount.isPresent()) {
+            Alias countAlias = new Alias(joinCount.get(),
+                    "joinCnt" + 
context.getCascadesContext().getStatementContext().generateColumnName());
+            projections.add(countAlias);
+            projectedCountSlot = Optional.of(countAlias.toSlot());
+        }
+        LogicalProject<Plan> project = new LogicalProject<>(projections, join);
+        if (projectedCountSlot.isPresent()) {
+            context.getBilateralState().registerCountSlot(project,
+                    (Slot) project.getOutput().get(project.getOutput().size() 
- 1));
+        } else {
+            context.getBilateralState().registerNoCountSlot(project);
+        }
+        return project;
+    }
+
+    private void appendJoinSideOutputs(List<NamedExpression> projections, 
Set<ExprId> outputIds, Plan originalSide,
+            Optional<PushDownAggContext> childContext, PushDownAggContext 
parentContext,
+            Optional<Slot> oppositeCountSlot, boolean shouldAdjustOutput) {
+        if (childContext.isPresent()) {
+            for (AggregateFunction aggFunc : 
childContext.get().getAggFunctions()) {
+                NamedExpression aggOutput = shouldAdjustOutput
+                        ? adjustAggOutputUseOppositeCountOnJoin(aggFunc, 
parentContext, oppositeCountSlot)
+                        : buildAggOutputWithoutJoinAdjustment(aggFunc, 
parentContext);
+                appendProjectionIfAbsent(projections, outputIds, aggOutput);
+            }
+            for (SlotReference groupKey : childContext.get().getGroupKeys()) {
+                appendProjectionIfAbsent(projections, outputIds, groupKey);
+            }
+        } else {
+            for (Slot slot : originalSide.getOutput()) {
+                appendProjectionIfAbsent(projections, outputIds, slot);
+            }
+        }
+    }
+
+    private void appendProjectionIfAbsent(List<NamedExpression> projections, 
Set<ExprId> outputIds,
+            NamedExpression expression) {
+        if (outputIds.add(expression.getExprId())) {
+            projections.add(expression);
+        }
+    }
+
+    private boolean shouldUseJoinOppositeCntAdjustAggOutput(LogicalJoin<? 
extends Plan, ? extends Plan> join,
+            Optional<PushDownAggContext> childContext, Optional<Slot> 
oppositeCountSlot) {
+        return join.getJoinType().isInnerOrCrossJoin() && 
childContext.isPresent() && oppositeCountSlot.isPresent();
+    }
+
+    private Optional<Expression> computeJoinCount(LogicalJoin<? extends Plan, 
? extends Plan> join,
+            Optional<PushDownAggContext> leftChildContext, 
Optional<PushDownAggContext> rightChildContext,
+            Optional<Slot> leftCountSlot, Optional<Slot> rightCountSlot) {
+        JoinType joinType = join.getJoinType();
+        if (joinType.isInnerJoin()) {
+            if (leftCountSlot.isPresent() && rightCountSlot.isPresent()) {
+                return Optional.of(ExpressionUtils.rebuildSignature(
+                        new Multiply(leftCountSlot.get(), 
rightCountSlot.get())));
+            } else if (leftCountSlot.isPresent()) {
+                return Optional.of(leftCountSlot.get());
+            } else if (rightCountSlot.isPresent()) {
+                return Optional.of(rightCountSlot.get());
+            }
+            return Optional.empty();
+        }
+        if (joinType.isLeftOuterJoin()) {
+            if (leftChildContext.isPresent()) {
+                return leftCountSlot.map(cnt -> (Expression) cnt);
+            }
+            if (rightChildContext.isPresent()) {
+                return rightCountSlot.map(cnt -> (Expression) 
ExpressionUtils.rebuildSignature(
+                        new Nvl(cnt, BigIntLiteral.of(1))));
+            }

Review Comment:
   computeJoinCount Nvl type analysis: For a left outer join with right push,
cnt (BIGINT, non-nullable from count(*)) may become NULL after null-extension.
Nvl(cnt, 1) correctly ensures the join count stays non-nullable. Nvl is handled
by rebuildSignature's BoundFunction branch. cnt (BIGINT) and BigIntLiteral(1)
(BIGINT) have the same type, so type coercion is the identity. But if count
ever returns a different type (e.g. BIGINT UNSIGNED), the type of
BigIntLiteral.of(1) might not match. Consider using
Cast(BigIntLiteral.of(1), cnt.getDataType()) for explicit type alignment.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to