This is an automated email from the ASF dual-hosted git repository.
morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new fd8adb492d [fix](nereids) fix bugs in nereids window function (#17284)
fd8adb492d is described below
commit fd8adb492da916fe75a79b80fd720ec1a08e0d36
Author: minghong <[email protected]>
AuthorDate: Tue Mar 7 16:35:37 2023 +0800
[fix](nereids) fix bugs in nereids window function (#17284)
Fix two problems:
1. Push aggregate functions nested inside a window expression down to the Aggregate node.
For example, for the SQL:
select sum(sum(a)) over (order by b)
Plan:
windowExpression(sum(y) over (order by b))
+--- Agg(sum(a) as y, b)
2. Push other (non-aggregate) expressions in a window function, such as `a + 1`, into the Project above the Aggregate node.
For example, for the SQL:
select sum(a+1) over ()
Plan:
windowExpression(sum(y) over ())
+--- Project(a + 1 as y,...)
+--- Agg(a,...)
---
.../glue/translator/ExpressionTranslator.java | 5 +-
.../nereids/rules/analysis/CheckAfterRewrite.java | 4 +-
.../ExtractAndNormalizeWindowExpression.java | 41 +++++-
.../rules/rewrite/logical/NormalizeAggregate.java | 138 ++++++++++++++++++---
.../rules/rewrite/logical/NormalizeToSlot.java | 19 +++
.../org/apache/doris/nereids/trees/TreeNode.java | 58 +++++++++
.../trees/expressions/WindowExpression.java | 7 ++
.../functions/agg/AggregateFunction.java | 13 +-
.../apache/doris/nereids/util/ExpressionUtils.java | 7 ++
.../org/apache/doris/planner/AnalyticEvalNode.java | 6 +-
.../data/nereids_syntax_p0/window_function.out | 50 ++++++++
.../nereids_syntax_p0/window_function.groovy | 32 ++++-
12 files changed, 351 insertions(+), 29 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
index 31867e82d4..c379e49b8f 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/ExpressionTranslator.java
@@ -107,9 +107,10 @@ public class ExpressionTranslator extends
DefaultExpressionVisitor<Expr, PlanTra
Expr staleExpr = expression.accept(INSTANCE, context);
try {
staleExpr.finalizeForNereids();
- } catch (org.apache.doris.common.AnalysisException e) {
+ } catch (Exception e) {
throw new AnalysisException(
- "Translate Nereids expression to stale expression failed.
" + e.getMessage(), e);
+ "Translate Nereids expression `" + expression.toSql()
+ + "` to stale expression failed. " +
e.getMessage(), e);
}
return staleExpr;
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
index faae491082..3e8a94e5cf 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAfterRewrite.java
@@ -64,10 +64,10 @@ public class CheckAfterRewrite extends
OneAnalysisRuleFactory {
.collect(Collectors.toSet());
notFromChildren = removeValidSlotsNotFromChildren(notFromChildren,
childrenOutput);
if (!notFromChildren.isEmpty()) {
- throw new AnalysisException(String.format("Input slot(s) not in
child's output: %s",
+ throw new AnalysisException(String.format("Input slot(s) not in
child's output: %s in plan: %s",
StringUtils.join(notFromChildren.stream()
.map(ExpressionTrait::toSql)
- .collect(Collectors.toSet()), ", ")));
+ .collect(Collectors.toSet()), ", "), plan));
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
index 54121ad5b6..75dba0467a 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ExtractAndNormalizeWindowExpression.java
@@ -23,6 +23,7 @@ import
org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
@@ -38,6 +39,8 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
/**
* extract window expressions from LogicalProject.projects and Normalize
LogicalWindow
@@ -94,14 +97,44 @@ public class ExtractAndNormalizeWindowExpression extends
OneRewriteRuleFactory i
// bottomProjects includes:
// 1. expressions from function and WindowSpec's partitionKeys and
orderKeys
// 2. other slots of outputExpressions
+ /*
+ avg(c) / sum(a+1) over (order by avg(b)) group by a
+ win(x/sum(z) over y)
+ prj(x, y, a+1 as z)
+ agg(avg(c) x, avg(b) y, a)
+ proj(a b c)
+ toBePushDown = {avg(c), a+1, avg(b)}
+ */
return expressions.stream()
.flatMap(expression -> {
if (expression.anyMatch(WindowExpression.class::isInstance)) {
+ Set<Slot> inputSlots =
expression.getInputSlots().stream().collect(Collectors.toSet());
Set<WindowExpression> collects =
expression.collect(WindowExpression.class::isInstance);
- return collects.stream().flatMap(windowExpression ->
- windowExpression.getExpressionsInWindowSpec().stream()
- // constant arguments may in WindowFunctions(e.g.
Lead, Lag), which shouldn't be pushed down
- .filter(expr -> !expr.isConstant())
+ Set<Slot> windowInputSlots = collects.stream().flatMap(
+ win -> win.getInputSlots().stream()
+ ).collect(Collectors.toSet());
+ /*
+ substr(
+ ref_1.cp_type,
+ max(
+ cast(ref_1.`cp_catalog_page_number` as int)) over
(...)
+ ),
+ 1)
+
+ in above case, ref_1.cp_type should be pushed down.
ref_1.cp_type is in
+ substr.inputSlots, but not in windowExpression.inputSlots
+
+ inputSlots= {ref_1.cp_type}
+ */
+ inputSlots.removeAll(windowInputSlots);
+ return Stream.concat(
+ collects.stream().flatMap(windowExpression ->
+
windowExpression.getExpressionsInWindowSpec().stream()
+ // constant arguments may in
WindowFunctions(e.g. Lead, Lag)
+ // which shouldn't be pushed down
+ .filter(expr -> !expr.isConstant())
+ ),
+ inputSlots.stream()
);
}
return ImmutableList.of(expression).stream();
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
index 07ddffeced..cf802245ca 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java
@@ -34,10 +34,13 @@ import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
/**
* normalize aggregate's group keys and AggregateFunction's child to
SlotReference
@@ -57,6 +60,31 @@ import java.util.stream.Collectors;
* +-- Aggregate(keys:[k1#1, SR#9], outputs:[k1#1, SR#9, Alias(SUM(v1#3))#10,
Alias(SUM(v1#3 + 1))#11])
* +-- Project(k1#1, Alias(K2#2 + 1)#9, v1#3)
* <p>
+ *
+ * Note: window function will be moved to upper project
+ * all agg functions except the top agg should be pushed to Aggregate node.
+ * example 1:
+ * select min(x), sum(x) over () ...
+ * the 'sum(x)' is top agg of window function, it should be moved to upper
project
+ * plan:
+ * project(sum(x) over())
+ * Aggregate(min(x), x)
+ *
+ * example 2:
+ * select min(x), avg(sum(x)) over() ...
+ * the 'sum(x)' should be moved to Aggregate
+ * plan:
+ * project(avg(y) over())
+ * Aggregate(min(x), sum(x) as y)
+ * example 3:
+ * select sum(x+1), x+1, sum(x+1) over() ...
+ * window function should use x instead of x+1
+ * plan:
+ * project(sum(x+1) over())
+ * Agg(sum(y), x)
+ * project(x+1 as y)
+ *
+ *
* More example could get from UT {NormalizeAggregateTest}
*/
public class NormalizeAggregate extends OneRewriteRuleFactory implements
NormalizeToSlot {
@@ -64,8 +92,37 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
public Rule build() {
return
logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
// push expression to bottom project
- Set<Alias> existsAliases = ExpressionUtils.collect(
+ Set<Alias> existsAliases = ExpressionUtils.mutableCollect(
aggregate.getOutputExpressions(), Alias.class::isInstance);
+ Set<AggregateFunction> aggregateFunctionsInWindow =
collectAggregateFunctionsInWindow(
+ aggregate.getOutputExpressions());
+ Set<Expression> existsAggAlias = existsAliases.stream().map(alias
-> alias.child())
+ .filter(AggregateFunction.class::isInstance)
+ .collect(Collectors.toSet());
+
+ /*
+ * agg-functions inside window function is regarded as an output
of aggregate.
+ * select sum(avg(c)) over ...
+ * is regarded as
+ * select avg(c), sum(avg(c)) over ...
+ *
+ * the plan:
+ * project(sum(y) over)
+ * Aggregate(avg(c) as y)
+ *
+ * after Aggregate, the 'y' is removed by upper project.
+ *
+ * aliasOfAggFunInWindowUsedAsAggOutput = {alias(avg(c))}
+ */
+ List<Alias> aliasOfAggFunInWindowUsedAsAggOutput =
Lists.newArrayList();
+
+ for (AggregateFunction aggFun : aggregateFunctionsInWindow) {
+ if (!existsAggAlias.contains(aggFun)) {
+ Alias alias = new Alias(aggFun, aggFun.toSql());
+ existsAliases.add(alias);
+ aliasOfAggFunInWindowUsedAsAggOutput.add(alias);
+ }
+ }
Set<Expression> needToSlots =
collectGroupByAndArgumentsOfAggregateFunctions(aggregate);
NormalizeToSlotContext groupByAndArgumentToSlotContext =
NormalizeToSlotContext.buildContext(existsAliases,
needToSlots);
@@ -80,13 +137,22 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
// replace groupBy and arguments of aggregate function to slot,
may be this output contains
// some expression on the aggregate functions, e.g. `sum(value) +
1`, we should replace
// the sum(value) to slot and move the `slot + 1` to the upper
project later.
- List<NamedExpression> normalizeOutputPhase1 =
aggregate.getOutputExpressions().stream()
- .map(expr -> {
- if (expr.anyMatch(WindowExpression.class::isInstance))
{
- return expr;
- }
- return
groupByAndArgumentToSlotContext.normalizeToUseSlotRef(expr);
- }).collect(Collectors.toList());
+ List<NamedExpression> normalizeOutputPhase1 = Stream.concat(
+ aggregate.getOutputExpressions().stream(),
+ aliasOfAggFunInWindowUsedAsAggOutput.stream())
+ .map(expr -> groupByAndArgumentToSlotContext
+ .normalizeToUseSlotRefUp(expr,
WindowExpression.class::isInstance))
+ .collect(Collectors.toList());
+
+ Set<Slot> windowInputSlots =
collectWindowInputSlots(aggregate.getOutputExpressions());
+ Set<Expression> itemsInWindow = Sets.newHashSet(windowInputSlots);
+ itemsInWindow.addAll(aggregateFunctionsInWindow);
+ NormalizeToSlotContext windowToSlotContext =
+ NormalizeToSlotContext.buildContext(existsAliases,
itemsInWindow);
+ normalizeOutputPhase1 = normalizeOutputPhase1.stream()
+ .map(expr -> windowToSlotContext
+ .normalizeToUseSlotRefDown(expr,
WindowExpression.class::isInstance, true))
+ .collect(Collectors.toList());
Set<AggregateFunction> normalizedAggregateFunctions =
collectNonWindowedAggregateFunctions(normalizeOutputPhase1);
@@ -115,14 +181,10 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
LogicalAggregate<Plan> normalizedAggregate =
aggregate.withNormalized(
(List) normalizedGroupBy, normalizedAggregateOutput,
normalizedChild);
+
normalizeOutputPhase1.removeAll(aliasOfAggFunInWindowUsedAsAggOutput);
// exclude same-name functions in WindowExpression
List<NamedExpression> upperProjects =
normalizeOutputPhase1.stream()
- .map(expr -> {
- if (expr.anyMatch(WindowExpression.class::isInstance))
{
- return expr;
- }
- return
aggregateFunctionToSlotContext.normalizeToUseSlotRef(expr);
- }).collect(Collectors.toList());
+
.map(aggregateFunctionToSlotContext::normalizeToUseSlotRef).collect(Collectors.toList());
return new LogicalProject<>(upperProjects, normalizedAggregate);
}).toRule(RuleType.NORMALIZE_AGGREGATE);
}
@@ -146,21 +208,61 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
}))
.collect(ImmutableSet.toImmutableSet());
+ Set<Expression> windowFunctionKeys =
collectWindowFunctionKeys(aggregate.getOutputExpressions());
+
ImmutableSet<Expression> needPushDown =
ImmutableSet.<Expression>builder()
// group by should be pushed down, e.g. group by (k + 1),
// we should push down the `k + 1` to the bottom plan
.addAll(groupingByExpr)
// e.g. sum(k + 1), we should push down the `k + 1` to the
bottom plan
.addAll(argumentsOfAggregateFunction)
+ .addAll(windowFunctionKeys)
.build();
return needPushDown;
}
+ private Set<Expression> collectWindowFunctionKeys(List<NamedExpression>
aggOutput) {
+ Set<Expression> windowInputs = Sets.newHashSet();
+ for (Expression expr : aggOutput) {
+ Set<WindowExpression> windows =
expr.collect(WindowExpression.class::isInstance);
+ for (WindowExpression win : windows) {
+ windowInputs.addAll(win.getPartitionKeys().stream().flatMap(pk
-> pk.getInputSlots().stream()).collect(
+ Collectors.toList()));
+ windowInputs.addAll(win.getOrderKeys().stream().flatMap(ok ->
ok.getInputSlots().stream()).collect(
+ Collectors.toList()));
+ }
+ }
+ return windowInputs;
+ }
+
+ /**
+ * select sum(c2), avg(min(c2)) over (partition by max(c1) order by
count(c1)) from T ...
+ * extract {sum, min, max, count}. avg is not extracted.
+ */
private Set<AggregateFunction>
collectNonWindowedAggregateFunctions(List<NamedExpression> aggOutput) {
- List<Expression> expressionsWithoutWindow = aggOutput.stream()
- .filter(expr ->
!expr.anyMatch(WindowExpression.class::isInstance))
- .collect(Collectors.toList());
+ return ExpressionUtils.collect(aggOutput, expr -> {
+ if (expr instanceof AggregateFunction) {
+ return !((AggregateFunction) expr).isWindowFunction();
+ }
+ return false;
+ });
+ }
+
+ private Set<AggregateFunction>
collectAggregateFunctionsInWindow(List<NamedExpression> aggOutput) {
+
+ List<WindowExpression> windows = Lists.newArrayList(
+ ExpressionUtils.collect(aggOutput,
WindowExpression.class::isInstance));
+ return ExpressionUtils.collect(windows, expr -> {
+ if (expr instanceof AggregateFunction) {
+ return !((AggregateFunction) expr).isWindowFunction();
+ }
+ return false;
+ });
+ }
- return ExpressionUtils.collect(expressionsWithoutWindow,
AggregateFunction.class::isInstance);
+ private Set<Slot> collectWindowInputSlots(List<NamedExpression> aggOutput)
{
+ List<WindowExpression> windows = Lists.newArrayList(
+ ExpressionUtils.collect(aggOutput,
WindowExpression.class::isInstance));
+ return windows.stream().flatMap(win ->
win.getInputSlots().stream()).collect(Collectors.toSet());
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
index aee929c8be..8ef966496e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeToSlot.java
@@ -31,6 +31,7 @@ import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;
+import java.util.function.Predicate;
import javax.annotation.Nullable;
/** NormalizeToSlot */
@@ -88,6 +89,24 @@ public interface NormalizeToSlot {
})).collect(ImmutableList.toImmutableList());
}
+ public <E extends Expression> E normalizeToUseSlotRefUp(E expression,
Predicate skip) {
+ return (E) expression.rewriteDownShortCircuitUp(child -> {
+ NormalizeToSlotTriplet normalizeToSlotTriplet =
normalizeToSlotMap.get(child);
+ return normalizeToSlotTriplet == null ? child :
normalizeToSlotTriplet.remainExpr;
+ }, skip);
+ }
+
+ /**
+ * rewrite subtrees whose root matches predicate border
+ * when we traverse to the node satisfies border predicate,
aboveBorder becomes false
+ */
+ public <E extends Expression> E normalizeToUseSlotRefDown(E
expression, Predicate border, boolean aboveBorder) {
+ return (E) expression.rewriteDownShortCircuitDown(child -> {
+ NormalizeToSlotTriplet normalizeToSlotTriplet =
normalizeToSlotMap.get(child);
+ return normalizeToSlotTriplet == null ? child :
normalizeToSlotTriplet.remainExpr;
+ }, border, aboveBorder);
+ }
+
/**
* generate bottom projections with groupByExpressions.
* eg:
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
index d8f10a362d..92b99ec68e 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/TreeNode.java
@@ -96,6 +96,64 @@ public interface TreeNode<NODE_TYPE extends
TreeNode<NODE_TYPE>> {
return currentNode;
}
+ /**
+ * same as rewriteDownShortCircuit,
+ * except that subtrees, whose root satisfies predicate is satisfied, are
not rewritten
+ */
+ default NODE_TYPE rewriteDownShortCircuitUp(Function<NODE_TYPE, NODE_TYPE>
rewriteFunction, Predicate skip) {
+ NODE_TYPE currentNode = rewriteFunction.apply((NODE_TYPE) this);
+ if (skip.test(currentNode)) {
+ return currentNode;
+ }
+ if (currentNode == this) {
+ Builder<NODE_TYPE> newChildren =
ImmutableList.builderWithExpectedSize(arity());
+ boolean changed = false;
+ for (NODE_TYPE child : children()) {
+ NODE_TYPE newChild =
child.rewriteDownShortCircuitUp(rewriteFunction, skip);
+ if (child != newChild) {
+ changed = true;
+ }
+ newChildren.add(newChild);
+ }
+
+ if (changed) {
+ currentNode = currentNode.withChildren(newChildren.build());
+ }
+ }
+ return currentNode;
+ }
+
+ /**
+ * similar to rewriteDownShortCircuit, except that only subtrees, whose
root satisfies
+ * border predicate are rewritten.
+ */
+ default NODE_TYPE rewriteDownShortCircuitDown(Function<NODE_TYPE,
NODE_TYPE> rewriteFunction,
+ Predicate border, boolean aboveBorder) {
+ NODE_TYPE currentNode = (NODE_TYPE) this;
+ if (border.test(this)) {
+ aboveBorder = false;
+ }
+ if (!aboveBorder) {
+ currentNode = rewriteFunction.apply((NODE_TYPE) this);
+ }
+ if (currentNode == this) {
+ Builder<NODE_TYPE> newChildren =
ImmutableList.builderWithExpectedSize(arity());
+ boolean changed = false;
+ for (NODE_TYPE child : children()) {
+ NODE_TYPE newChild =
child.rewriteDownShortCircuitDown(rewriteFunction, border, aboveBorder);
+ if (child != newChild) {
+ changed = true;
+ }
+ newChildren.add(newChild);
+ }
+
+ if (changed) {
+ currentNode = currentNode.withChildren(newChildren.build());
+ }
+ }
+ return currentNode;
+ }
+
/**
* bottom-up rewrite.
* @param rewriteFunction rewrite function.
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
index c9e96da1ec..3da1610d9b 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/WindowExpression.java
@@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
+import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
@@ -53,6 +54,9 @@ public class WindowExpression extends Expression {
.addAll(orderKeys)
.build().toArray(new Expression[0]));
this.function = function;
+ if (function instanceof AggregateFunction) {
+ ((AggregateFunction) function).setWindowFunction(true);
+ }
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.empty();
@@ -68,6 +72,9 @@ public class WindowExpression extends Expression {
.add(windowFrame)
.build().toArray(new Expression[0]));
this.function = function;
+ if (function instanceof AggregateFunction) {
+ ((AggregateFunction) function).setWindowFunction(true);
+ }
this.partitionKeys = ImmutableList.copyOf(partitionKeys);
this.orderKeys = ImmutableList.copyOf(orderKeys);
this.windowFrame = Optional.of(Objects.requireNonNull(windowFrame));
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
index 7d4dd262a0..a170ae0dd5 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/AggregateFunction.java
@@ -38,6 +38,7 @@ import java.util.stream.Collectors;
public abstract class AggregateFunction extends BoundFunction implements
ExpectsInputTypes {
protected final boolean distinct;
+ protected boolean isWindowFunction = false;
public AggregateFunction(String name, Expression... arguments) {
this(name, false, arguments);
@@ -77,6 +78,14 @@ public abstract class AggregateFunction extends
BoundFunction implements Expects
return distinct;
}
+ public boolean isWindowFunction() {
+ return isWindowFunction;
+ }
+
+ public void setWindowFunction(boolean windowFunction) {
+ isWindowFunction = windowFunction;
+ }
+
@Override
public boolean equals(Object o) {
if (this == o) {
@@ -86,7 +95,8 @@ public abstract class AggregateFunction extends BoundFunction
implements Expects
return false;
}
AggregateFunction that = (AggregateFunction) o;
- return Objects.equals(distinct, that.distinct)
+ return isWindowFunction == that.isWindowFunction
+ && Objects.equals(distinct, that.distinct)
&& Objects.equals(getName(), that.getName())
&& Objects.equals(children, that.children);
}
@@ -123,4 +133,5 @@ public abstract class AggregateFunction extends
BoundFunction implements Expects
.collect(Collectors.joining(", "));
return getName() + "(" + (distinct ? "DISTINCT " : "") + args + ")";
}
+
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index f9cbd57a89..90e7c73c8e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -411,6 +411,13 @@ public class ExpressionUtils {
.collect(ImmutableSet.toImmutableSet());
}
+ public static <E> Set<E> mutableCollect(List<? extends Expression>
expressions,
+ Predicate<TreeNode<Expression>> predicate) {
+ return expressions.stream()
+ .flatMap(expr -> expr.<Set<E>>collect(predicate).stream())
+ .collect(Collectors.toSet());
+ }
+
public static <E> List<E> collectAll(List<? extends Expression>
expressions,
Predicate<TreeNode<Expression>> predicate) {
return expressions.stream()
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
index fb5a63b10f..ff218ce7e3 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/planner/AnalyticEvalNode.java
@@ -106,7 +106,11 @@ public class AnalyticEvalNode extends PlanNode {
AnalyticWindow analyticWindow, TupleDescriptor
intermediateTupleDesc,
TupleDescriptor outputTupleDesc, Expr partitionByEq, Expr
orderByEq,
TupleDescriptor bufferedTupleDesc) {
- super(id, input.getTupleIds(), "ANALYTIC",
StatisticalType.ANALYTIC_EVAL_NODE);
+ super(id,
+ (input.getOutputTupleDesc() != null
+ ?
Lists.newArrayList(input.getOutputTupleDesc().getId()) :
+ input.getTupleIds()),
+ "ANALYTIC", StatisticalType.ANALYTIC_EVAL_NODE);
Preconditions.checkState(!tupleIds.contains(outputTupleDesc.getId()));
// we're materializing the input row augmented with the analytic
output tuple
tupleIds.add(outputTupleDesc.getId());
diff --git a/regression-test/data/nereids_syntax_p0/window_function.out
b/regression-test/data/nereids_syntax_p0/window_function.out
index 921703d421..8434d65a73 100644
--- a/regression-test/data/nereids_syntax_p0/window_function.out
+++ b/regression-test/data/nereids_syntax_p0/window_function.out
@@ -389,3 +389,53 @@
-- !window_use_agg --
20
+-- !winExpr_not_agg_expr --
+2 5
+3 5
+3 5
+4 5
+4 5
+6 5
+6 5
+6 5
+
+-- !on_notgroupbycolumn --
+1.0
+2.0
+3.0
+3.0
+3.0
+6.0
+6.0
+6.0
+
+-- !orderby --
+1 2 2
+1 4 2
+1 4 2
+1 6 2
+2 3 5
+2 3 5
+2 6 5
+2 6 5
+
+-- !winExpr_with_others --
+0.4
+0.4
+0.5
+0.8
+0.8
+1.0
+1.0
+1.5
+
+-- !winExpr_with_others2 --
+0.4
+0.4
+0.5
+0.8
+0.8
+1.0
+1.0
+1.5
+
diff --git a/regression-test/suites/nereids_syntax_p0/window_function.groovy
b/regression-test/suites/nereids_syntax_p0/window_function.groovy
index 76318169b1..4c0509bbe6 100644
--- a/regression-test/suites/nereids_syntax_p0/window_function.groovy
+++ b/regression-test/suites/nereids_syntax_p0/window_function.groovy
@@ -15,7 +15,7 @@
// specific language governing permissions and limitations
// under the License.
-suite("test_window_function") {
+suite("window_function") {
sql "SET enable_nereids_planner=true"
sql "DROP TABLE IF EXISTS window_test"
@@ -118,4 +118,34 @@ suite("test_window_function") {
SELECT sum(sum(c1)) over(partition by avg(c2))
FROM window_test
"""
+
+ order_qt_winExpr_not_agg_expr """
+ select sum(c1+1), sum(c1+1) over (partition by avg(c2))
+ from window_test
+ group by c1, c2
+ """
+
+ order_qt_on_notgroupbycolumn """
+ select sum(sum(c3)) over (partition by avg(c2) order by c1)
+ from window_test
+ group by c1, c2
+ """
+
+ order_qt_orderby """
+ select c1, sum(c1+1), sum(c1+1) over (partition by avg(c2) order by c1)
+ from window_test
+ group by c1, c2
+ """
+
+ order_qt_winExpr_with_others """
+ select sum(c1)/sum(c1+1) over (partition by c2 order by c1)
+ from window_test
+ group by c1, c2
+ """
+
+ order_qt_winExpr_with_others2"""
+ select sum(c1)/sum(c1+1) over (partition by c2 order by c1)
+ from window_test
+ group by c1, c2
+ """
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]