This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new fd8adb492d [fix](nereids) fix bugs in nereids window function (#17284)
fd8adb492d is described below

commit fd8adb492da916fe75a79b80fd720ec1a08e0d36
Author: minghong <[email protected]>
AuthorDate: Tue Mar 7 16:35:37 2023 +0800

    [fix](nereids) fix bugs in nereids window function (#17284)
    
    fix two problems:
    
    1. push aggregate functions nested inside a windowExpression down to the AggregateNode
    for example, sql:
    select sum(sum(a)) over (order by b)
    Plan:
    windowExpression( sum(y) over (order by b))
    +--- Agg(sum(a) as y, b)
    
    2. push the other expressions used by a window function into a project above the AggregateNode
    for example, sql:
    select sum(a+1) over ()
    Plan:
    windowExpression(sum(y) over ())
    +--- Project(a + 1 as y,...)
         +--- Agg(a,...)
---
 .../glue/translator/ExpressionTranslator.java      |   5 +-
 .../nereids/rules/analysis/CheckAfterRewrite.java  |   4 +-
 .../ExtractAndNormalizeWindowExpression.java       |  41 +++++-
 .../rules/rewrite/logical/NormalizeAggregate.java  | 138 ++++++++++++++++++---
 .../rules/rewrite/logical/NormalizeToSlot.java     |  19 +++
 .../org/apache/doris/nereids/trees/TreeNode.java   |  58 +++++++++
 .../trees/expressions/WindowExpression.java        |   7 ++
 .../functions/agg/AggregateFunction.java           |  13 +-
 .../apache/doris/nereids/util/ExpressionUtils.java |   7 ++
 .../org/apache/doris/planner/AnalyticEvalNode.java |   6 +-
 .../data/nereids_syntax_p0/window_function.out     |  50 ++++++++
 .../nereids_syntax_p0/window_function.groovy       |  32 ++++-
 12 files changed, 351 insertions(+), 29 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 31867e82d4..c379e49b8f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -107,9 +107,10 @@ public class ExpressionTranslator extends 
DefaultExpressionVisitor<Expr, PlanTra
         Expr staleExpr = expression.accept(INSTANCE, context);
         try {
             staleExpr.finalizeForNereids();
-        } catch (org.apache.doris.common.AnalysisException e) {
+        } catch (Exception e) {
             throw new AnalysisException(
-                    "Translate Nereids expression to stale expression failed. 
" + e.getMessage(), e);
+                    "Translate Nereids expression `" + expression.toSql()
+                            + "` to stale expression failed. " + 
e.getMessage(), e);
         }
         return staleExpr;
     }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
index faae491082..3e8a94e5cf 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
@@ -64,10 +64,10 @@ public class CheckAfterRewrite extends 
OneAnalysisRuleFactory {
                 .collect(Collectors.toSet());
         notFromChildren = removeValidSlotsNotFromChildren(notFromChildren, 
childrenOutput);
         if (!notFromChildren.isEmpty()) {
-            throw new AnalysisException(String.format("Input slot(s) not in 
child's output: %s",
+            throw new AnalysisException(String.format("Input slot(s) not in 
child's output: %s in plan: %s",
                     StringUtils.join(notFromChildren.stream()
                             .map(ExpressionTrait::toSql)
-                            .collect(Collectors.toSet()), ", ")));
+                            .collect(Collectors.toSet()), ", "), plan));
         }
     }
 
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
index 54121ad5b6..75dba0467a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
@@ -23,6 +23,7 @@ import 
org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.WindowExpression;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.plans.Plan;
@@ -38,6 +39,8 @@ import com.google.common.collect.Sets;
 
 import java.util.List;
 import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * extract window expressions from LogicalProject.projects and Normalize 
LogicalWindow
@@ -94,14 +97,44 @@ public class ExtractAndNormalizeWindowExpression extends 
OneRewriteRuleFactory i
         // bottomProjects includes:
         // 1. expressions from function and WindowSpec's partitionKeys and 
orderKeys
         // 2. other slots of outputExpressions
+        /*
+        avg(c) / sum(a+1) over (order by avg(b))  group by a
+        win(x/sum(z) over y)
+            prj(x, y, a+1 as z)
+                agg(avg(c) x, avg(b) y, a)
+                    proj(a b c)
+        toBePushDown = {avg(c), a+1, avg(b)}
+         */
         return expressions.stream()
             .flatMap(expression -> {
                 if (expression.anyMatch(WindowExpression.class::isInstance)) {
+                    Set<Slot> inputSlots = 
expression.getInputSlots().stream().collect(Collectors.toSet());
                     Set<WindowExpression> collects = 
expression.collect(WindowExpression.class::isInstance);
-                    return collects.stream().flatMap(windowExpression ->
-                        windowExpression.getExpressionsInWindowSpec().stream()
-                            // constant arguments may in WindowFunctions(e.g. 
Lead, Lag), which shouldn't be pushed down
-                            .filter(expr -> !expr.isConstant())
+                    Set<Slot> windowInputSlots = collects.stream().flatMap(
+                            win -> win.getInputSlots().stream()
+                    ).collect(Collectors.toSet());
+                    /*
+                    substr(
+                      ref_1.cp_type,
+                      max(
+                          cast(ref_1.`cp_catalog_page_number` as int)) over 
(...)
+                          ),
+                      1)
+
+                      in above case, ref_1.cp_type should be pushed down. 
ref_1.cp_type is in
+                      substr.inputSlots, but not in windowExpression.inputSlots
+
+                      inputSlots= {ref_1.cp_type}
+                     */
+                    inputSlots.removeAll(windowInputSlots);
+                    return Stream.concat(
+                            collects.stream().flatMap(windowExpression ->
+                                    
windowExpression.getExpressionsInWindowSpec().stream()
+                                    // constant arguments may in 
WindowFunctions(e.g. Lead, Lag)
+                                    // which shouldn't be pushed down
+                                    .filter(expr -> !expr.isConstant())
+                            ),
+                            inputSlots.stream()
                     );
                 }
                 return ImmutableList.of(expression).stream();
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index 07ddffeced..cf802245ca 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -34,10 +34,13 @@ import org.apache.doris.nereids.util.ExpressionUtils;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
 import java.util.List;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * normalize aggregate's group keys and AggregateFunction's child to 
SlotReference
@@ -57,6 +60,31 @@ import java.util.stream.Collectors;
  * +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10, 
Alias(SUM(v1#3 + 1))#11])
  *   +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
  * <p>
+ *
+ * Note: window function will be moved to upper project
+ * all agg functions except the top agg should be pushed to Aggregate node.
+ * example 1:
+ *    select min(x), sum(x) over () ...
+ * the 'sum(x)' is top agg of window function, it should be moved to upper 
project
+ * plan:
+ *    project(sum(x) over())
+ *        Aggregate(min(x), x)
+ *
+ * example 2:
+ *    select min(x), avg(sum(x)) over() ...
+ * the 'sum(x)' should be moved to Aggregate
+ * plan:
+ *    project(avg(y) over())
+ *         Aggregate(min(x), sum(x) as y)
+ * example 3:
+ *    select sum(x+1), x+1, sum(x+1) over() ...
+ * window function should use x instead of x+1
+ * plan:
+ *    project(sum(x+1) over())
+ *        Agg(sum(y), x)
+ *            project(x+1 as y)
+ *
+ *
  * More example could get from UT {NormalizeAggregateTest}
  */
 public class NormalizeAggregate extends OneRewriteRuleFactory implements 
NormalizeToSlot {
@@ -64,8 +92,37 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
     public Rule build() {
         return 
logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
             // push expression to bottom project
-            Set<Alias> existsAliases = ExpressionUtils.collect(
+            Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
                     aggregate.getOutputExpressions(), Alias.class::isInstance);
+            Set<AggregateFunction> aggregateFunctionsInWindow = 
collectAggregateFunctionsInWindow(
+                    aggregate.getOutputExpressions());
+            Set<Expression> existsAggAlias = existsAliases.stream().map(alias 
-> alias.child())
+                    .filter(AggregateFunction.class::isInstance)
+                    .collect(Collectors.toSet());
+
+            /*
+             * agg-functions inside window function is regarded as an output 
of aggregate.
+             * select sum(avg(c)) over ...
+             * is regarded as
+             * select avg(c), sum(avg(c)) over ...
+             *
+             * the plan:
+             * project(sum(y) over)
+             *    Aggregate(avg(c) as y)
+             *
+             * after Aggregate, the 'y' is removed by upper project.
+             *
+             * aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
+             */
+            List<Alias> aliasOfAggFunInWindowUsedAsAggOutput = 
Lists.newArrayList();
+
+            for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
+                if (!existsAggAlias.contains(aggFun)) {
+                    Alias alias = new Alias(aggFun, aggFun.toSql());
+                    existsAliases.add(alias);
+                    aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
+                }
+            }
             Set<Expression> needToSlots = 
collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
             NormalizeToSlotContext groupByAndArgumentToSlotContext =
                     NormalizeToSlotContext.buildContext(existsAliases, 
needToSlots);
@@ -80,13 +137,22 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             // replace groupBy and arguments of aggregate function to slot, 
may be this output contains
             // some expression on the aggregate functions, e.g. `sum(value) + 
1`, we should replace
             // the sum(value) to slot and move the `slot + 1` to the upper 
project later.
-            List<NamedExpression> normalizeOutputPhase1 = 
aggregate.getOutputExpressions().stream()
-                    .map(expr -> {
-                        if (expr.anyMatch(WindowExpression.class::isInstance)) 
{
-                            return expr;
-                        }
-                        return 
groupByAndArgumentToSlotContext.normalizeToUseSlotRef(expr);
-                    }).collect(Collectors.toList());
+            List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
+                    aggregate.getOutputExpressions().stream(),
+                    aliasOfAggFunInWindowUsedAsAggOutput.stream())
+                    .map(expr -> groupByAndArgumentToSlotContext
+                            .normalizeToUseSlotRefUp(expr, 
WindowExpression.class::isInstance))
+                    .collect(Collectors.toList());
+
+            Set<Slot> windowInputSlots = 
collectWindowInputSlots(aggregate.getOutputExpressions());
+            Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
+            itemsInWindow.addAll(aggregateFunctionsInWindow);
+            NormalizeToSlotContext windowToSlotContext =
+                    NormalizeToSlotContext.buildContext(existsAliases, 
itemsInWindow);
+            normalizeOutputPhase1 = normalizeOutputPhase1.stream()
+                    .map(expr -> windowToSlotContext
+                            .normalizeToUseSlotRefDown(expr, 
WindowExpression.class::isInstance, true))
+                    .collect(Collectors.toList());
 
             Set<AggregateFunction> normalizedAggregateFunctions =
                     
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
@@ -115,14 +181,10 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
             LogicalAggregate<Plan> normalizedAggregate = 
aggregate.withNormalized(
                     (List) normalizedGroupBy, normalizedAggregateOutput, 
normalizedChild);
 
+            
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
             // exclude same-name functions in WindowExpression
             List<NamedExpression> upperProjects = 
normalizeOutputPhase1.stream()
-                    .map(expr -> {
-                        if (expr.anyMatch(WindowExpression.class::isInstance)) 
{
-                            return expr;
-                        }
-                        return 
aggregateFunctionToSlotContext.normalizeToUseSlotRef(expr);
-                    }).collect(Collectors.toList());
+                    
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
             return new LogicalProject<>(upperProjects, normalizedAggregate);
         }).toRule(RuleType.NORMALIZE_AGGREGATE);
     }
@@ -146,21 +208,61 @@ public class NormalizeAggregate extends 
OneRewriteRuleFactory implements Normali
                 }))
                 .collect(ImmutableSet.toImmutableSet());
 
+        Set<Expression> windowFunctionKeys = 
collectWindowFunctionKeys(aggregate.getOutputExpressions());
+
         ImmutableSet<Expression> needPushDown = 
ImmutableSet.<Expression>builder()
                 // group by should be pushed down, e.g. group by (k + 1),
                 // we should push down the `k + 1` to the bottom plan
                 .addAll(groupingByExpr)
                 // e.g. sum(k + 1), we should push down the `k + 1` to the 
bottom plan
                 .addAll(argumentsOfAggregateFunction)
+                .addAll(windowFunctionKeys)
                 .build();
         return needPushDown;
     }
 
+    private Set<Expression> collectWindowFunctionKeys(List<NamedExpression> 
aggOutput) {
+        Set<Expression> windowInputs = Sets.newHashSet();
+        for (Expression expr : aggOutput) {
+            Set<WindowExpression> windows = 
expr.collect(WindowExpression.class::isInstance);
+            for (WindowExpression win : windows) {
+                windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk 
-> pk.getInputSlots().stream()).collect(
+                        Collectors.toList()));
+                windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok -> 
ok.getInputSlots().stream()).collect(
+                        Collectors.toList()));
+            }
+        }
+        return windowInputs;
+    }
+
+    /**
+     * select sum(c2), avg(min(c2)) over (partition by max(c1) order by 
count(c1)) from T ...
+     * extract {sum, min, max, count}. avg is not extracted.
+     */
     private Set<AggregateFunction> 
collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
-        List<Expression> expressionsWithoutWindow = aggOutput.stream()
-                .filter(expr -> 
!expr.anyMatch(WindowExpression.class::isInstance))
-                .collect(Collectors.toList());
+        return ExpressionUtils.collect(aggOutput, expr -> {
+            if (expr instanceof AggregateFunction) {
+                return !((AggregateFunction) expr).isWindowFunction();
+            }
+            return false;
+        });
+    }
+
+    private Set<AggregateFunction> 
collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
+
+        List<WindowExpression> windows = Lists.newArrayList(
+                ExpressionUtils.collect(aggOutput, 
WindowExpression.class::isInstance));
+        return ExpressionUtils.collect(windows, expr -> {
+            if (expr instanceof AggregateFunction) {
+                return !((AggregateFunction) expr).isWindowFunction();
+            }
+            return false;
+        });
+    }
 
-        return ExpressionUtils.collect(expressionsWithoutWindow, 
AggregateFunction.class::isInstance);
+    private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput) 
{
+        List<WindowExpression> windows = Lists.newArrayList(
+                ExpressionUtils.collect(aggOutput, 
WindowExpression.class::isInstance));
+        return windows.stream().flatMap(win -> 
win.getInputSlots().stream()).collect(Collectors.toSet());
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
index aee929c8be..8ef966496e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
@@ -31,6 +31,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.function.BiFunction;
+import java.util.function.Predicate;
 import javax.annotation.Nullable;
 
 /** NormalizeToSlot */
@@ -88,6 +89,24 @@ public interface NormalizeToSlot {
                     })).collect(ImmutableList.toImmutableList());
         }
 
+        public <E extends Expression> E normalizeToUseSlotRefUp(E expression, 
Predicate skip) {
+            return (E) expression.rewriteDownShortCircuitUp(child -> {
+                NormalizeToSlotTriplet normalizeToSlotTriplet = 
normalizeToSlotMap.get(child);
+                return normalizeToSlotTriplet == null ? child : 
normalizeToSlotTriplet.remainExpr;
+            }, skip);
+        }
+
+        /**
+         * rewrite subtrees whose root matches the border predicate.
+         * once traversal reaches a node that satisfies the border predicate, aboveBorder becomes false
+         */
+        public <E extends Expression> E normalizeToUseSlotRefDown(E 
expression, Predicate border, boolean aboveBorder) {
+            return (E) expression.rewriteDownShortCircuitDown(child -> {
+                NormalizeToSlotTriplet normalizeToSlotTriplet = 
normalizeToSlotMap.get(child);
+                return normalizeToSlotTriplet == null ? child : 
normalizeToSlotTriplet.remainExpr;
+            }, border, aboveBorder);
+        }
+
         /**
          * generate bottom projections with groupByExpressions.
          * eg:
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index d8f10a362d..92b99ec68e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -96,6 +96,64 @@ public interface TreeNode<NODE_TYPE extends 
TreeNode<NODE_TYPE>> {
         return currentNode;
     }
 
+    /**
+     * same as rewriteDownShortCircuit,
+     * except that subtrees whose root satisfies the skip predicate are not rewritten
+     */
+    default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE> 
rewriteFunction, Predicate skip) {
+        NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
+        if (skip.test(currentNode)) {
+            return currentNode;
+        }
+        if (currentNode == this) {
+            Builder<NODE_TYPE> newChildren = 
ImmutableList.builderWithExpectedSize(arity());
+            boolean changed = false;
+            for (NODE_TYPE child : children()) {
+                NODE_TYPE newChild = 
child.rewriteDownShortCircuitUp(rewriteFunction, skip);
+                if (child != newChild) {
+                    changed = true;
+                }
+                newChildren.add(newChild);
+            }
+
+            if (changed) {
+                currentNode = currentNode.withChildren(newChildren.build());
+            }
+        }
+        return currentNode;
+    }
+
+    /**
+     * similar to rewriteDownShortCircuit, except that only subtrees, whose 
root satisfies
+     * border predicate are rewritten.
+     */
+    default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE, 
NODE_TYPE> rewriteFunction,
+            Predicate border, boolean aboveBorder) {
+        NODE_TYPE currentNode = (NODE_TYPE) this;
+        if (border.test(this)) {
+            aboveBorder = false;
+        }
+        if (!aboveBorder) {
+            currentNode = rewriteFunction.apply((NODE_TYPE) this);
+        }
+        if (currentNode == this) {
+            Builder<NODE_TYPE> newChildren = 
ImmutableList.builderWithExpectedSize(arity());
+            boolean changed = false;
+            for (NODE_TYPE child : children()) {
+                NODE_TYPE newChild = 
child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder);
+                if (child != newChild) {
+                    changed = true;
+                }
+                newChildren.add(newChild);
+            }
+
+            if (changed) {
+                currentNode = currentNode.withChildren(newChildren.build());
+            }
+        }
+        return currentNode;
+    }
+
     /**
      * bottom-up rewrite.
      * @param rewriteFunction rewrite function.
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
index c9e96da1ec..3da1610d9b 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
@@ -18,6 +18,7 @@
 package org.apache.doris.nereids.trees.expressions;
 
 import org.apache.doris.nereids.exceptions.UnboundException;
+import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.types.DataType;
 
@@ -53,6 +54,9 @@ public class WindowExpression extends Expression {
                 .addAll(orderKeys)
                 .build().toArray(new Expression[0]));
         this.function = function;
+        if (function instanceof AggregateFunction) {
+            ((AggregateFunction) function).setWindowFunction(true);
+        }
         this.partitionKeys = ImmutableList.copyOf(partitionKeys);
         this.orderKeys = ImmutableList.copyOf(orderKeys);
         this.windowFrame = Optional.empty();
@@ -68,6 +72,9 @@ public class WindowExpression extends Expression {
                 .add(windowFrame)
                 .build().toArray(new Expression[0]));
         this.function = function;
+        if (function instanceof AggregateFunction) {
+            ((AggregateFunction) function).setWindowFunction(true);
+        }
         this.partitionKeys = ImmutableList.copyOf(partitionKeys);
         this.orderKeys = ImmutableList.copyOf(orderKeys);
         this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index 7d4dd262a0..a170ae0dd5 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -38,6 +38,7 @@ import java.util.stream.Collectors;
 public abstract class AggregateFunction extends BoundFunction implements 
ExpectsInputTypes {
 
     protected final boolean distinct;
+    protected boolean isWindowFunction = false;
 
     public AggregateFunction(String name, Expression... arguments) {
         this(name, false, arguments);
@@ -77,6 +78,14 @@ public abstract class AggregateFunction extends 
BoundFunction implements Expects
         return distinct;
     }
 
+    public boolean isWindowFunction() {
+        return isWindowFunction;
+    }
+
+    public void setWindowFunction(boolean windowFunction) {
+        isWindowFunction = windowFunction;
+    }
+
     @Override
     public boolean equals(Object o) {
         if (this == o) {
@@ -86,7 +95,8 @@ public abstract class AggregateFunction extends BoundFunction 
implements Expects
             return false;
         }
         AggregateFunction that = (AggregateFunction) o;
-        return Objects.equals(distinct, that.distinct)
+        return isWindowFunction == that.isWindowFunction
+                && Objects.equals(distinct, that.distinct)
                 && Objects.equals(getName(), that.getName())
                 && Objects.equals(children, that.children);
     }
@@ -123,4 +133,5 @@ public abstract class AggregateFunction extends 
BoundFunction implements Expects
                 .collect(Collectors.joining(", "));
         return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
     }
+
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index f9cbd57a89..90e7c73c8e 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -411,6 +411,13 @@ public class ExpressionUtils {
                 .collect(ImmutableSet.toImmutableSet());
     }
 
+    public static <E> Set<E> mutableCollect(List<? extends Expression> 
expressions,
+            Predicate<TreeNode<Expression>> predicate) {
+        return expressions.stream()
+                .flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
+                .collect(Collectors.toSet());
+    }
+
     public static <E> List<E> collectAll(List<? extends Expression> 
expressions,
             Predicate<TreeNode<Expression>> predicate) {
         return expressions.stream()
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java 
b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
index fb5a63b10f..ff218ce7e3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
@@ -106,7 +106,11 @@ public class AnalyticEvalNode extends PlanNode {
             AnalyticWindow analyticWindow, TupleDescriptor 
intermediateTupleDesc,
             TupleDescriptor outputTupleDesc, Expr partitionByEq, Expr 
orderByEq,
             TupleDescriptor bufferedTupleDesc) {
-        super(id, input.getTupleIds(), "ANALYTIC", 
StatisticalType.ANALYTIC_EVAL_NODE);
+        super(id,
+                (input.getOutputTupleDesc() != null
+                        ? 
Lists.newArrayList(input.getOutputTupleDesc().getId()) :
+                        input.getTupleIds()),
+                "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
         Preconditions.checkState(!tupleIds.contains(outputTupleDesc.getId()));
         // we're materializing the input row augmented with the analytic 
output tuple
         tupleIds.add(outputTupleDesc.getId());
diff --git a/regression-test/data/nereids_syntax_p0/window_function.out 
b/regression-test/data/nereids_syntax_p0/window_function.out
index 921703d421..8434d65a73 100644
--- a/regression-test/data/nereids_syntax_p0/window_function.out
+++ b/regression-test/data/nereids_syntax_p0/window_function.out
@@ -389,3 +389,53 @@
 -- !window_use_agg --
 20
 
+-- !winExpr_not_agg_expr --
+2      5
+3      5
+3      5
+4      5
+4      5
+6      5
+6      5
+6      5
+
+-- !on_notgroupbycolumn --
+1.0
+2.0
+3.0
+3.0
+3.0
+6.0
+6.0
+6.0
+
+-- !orderby --
+1      2       2
+1      4       2
+1      4       2
+1      6       2
+2      3       5
+2      3       5
+2      6       5
+2      6       5
+
+-- !winExpr_with_others --
+0.4
+0.4
+0.5
+0.8
+0.8
+1.0
+1.0
+1.5
+
+-- !winExpr_with_others2 --
+0.4
+0.4
+0.5
+0.8
+0.8
+1.0
+1.0
+1.5
+
diff --git a/regression-test/suites/nereids_syntax_p0/window_function.groovy 
b/regression-test/suites/nereids_syntax_p0/window_function.groovy
index 76318169b1..4c0509bbe6 100644
--- a/regression-test/suites/nereids_syntax_p0/window_function.groovy
+++ b/regression-test/suites/nereids_syntax_p0/window_function.groovy
@@ -15,7 +15,7 @@
 // specific language governing permissions and limitations
 // under the License.
 
-suite("test_window_function") {
+suite("window_function") {
     sql "SET enable_nereids_planner=true"
 
     sql "DROP TABLE IF EXISTS window_test"
@@ -118,4 +118,34 @@ suite("test_window_function") {
         SELECT sum(sum(c1)) over(partition by avg(c2))
         FROM window_test
     """
+
+    order_qt_winExpr_not_agg_expr """
+        select sum(c1+1), sum(c1+1) over (partition by avg(c2))
+        from window_test
+        group by c1, c2
+    """
+
+    order_qt_on_notgroupbycolumn """
+        select sum(sum(c3)) over (partition by avg(c2) order by c1)
+        from window_test
+        group by c1, c2
+    """
+
+    order_qt_orderby """
+        select c1, sum(c1+1), sum(c1+1) over (partition by avg(c2) order by c1)
+        from window_test
+        group by c1, c2
+    """
+
+    order_qt_winExpr_with_others """
+        select sum(c1)/sum(c1+1) over (partition by c2 order by c1)
+        from window_test
+        group by c1, c2
+    """
+
+    order_qt_winExpr_with_others2"""
+        select sum(c1)/sum(c1+1) over (partition by c2 order by c1)
+        from window_test
+        group by c1, c2
+    """
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to