This is an automated email from the ASF dual-hosted git repository.
starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 00c30f075ff [fix](nereids)only push down subquery in non-window agg functions (#26034)
00c30f075ff is described below
commit 00c30f075ffdde6c7232dec7019fcf12bbe724f8
Author: starocean999 <[email protected]>
AuthorDate: Mon Oct 30 11:32:10 2023 +0800
[fix](nereids)only push down subquery in non-window agg functions (#26034)
---
.../doris/nereids/rules/analysis/NormalizeAggregate.java | 14 +++++++-------
.../nereids_p0/subquery/test_subquery_in_project.out | 12 ++++++++++++
.../nereids_p0/subquery/test_subquery_in_project.groovy | 16 ++++++++++++++--
3 files changed, 33 insertions(+), 9 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index 0acd366f1c2..203a7f09690 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -104,11 +104,13 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
List<NamedExpression> aggregateOutput =
aggregate.getOutputExpressions();
Set<Alias> existsAlias =
ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
- // we need push down subquery exprs in side non-distinct agg
functions
- Set<SubqueryExpr> subqueryExprs = ExpressionUtils.mutableCollect(
-
Lists.newArrayList(ExpressionUtils.mutableCollect(aggregateOutput,
- expr -> expr instanceof AggregateFunction
- && !((AggregateFunction)
expr).isDistinct())),
+
+ List<AggregateFunction> aggFuncs = Lists.newArrayList();
+ aggregateOutput.forEach(o ->
o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
+
+ // we need push down subquery exprs inside non-window and
non-distinct agg functions
+ Set<SubqueryExpr> subqueryExprs =
ExpressionUtils.mutableCollect(aggFuncs.stream()
+ .filter(aggFunc ->
!aggFunc.isDistinct()).collect(Collectors.toList()),
SubqueryExpr.class::isInstance);
Set<Expression> groupingByExprs =
ImmutableSet.copyOf(aggregate.getGroupByExpressions());
NormalizeToSlotContext bottomSlotContext =
@@ -116,8 +118,6 @@ public class NormalizeAggregate extends
OneRewriteRuleFactory implements Normali
Set<NamedExpression> bottomOutputs =
bottomSlotContext.pushDownToNamedExpression(Sets.union(groupingByExprs,
subqueryExprs));
- List<AggregateFunction> aggFuncs = Lists.newArrayList();
- aggregateOutput.forEach(o ->
o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
// use group by context to normalize agg functions to process
// sql like: select sum(a + 1) from t group by a + 1
//
diff --git
a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
index 4d8bd4c7361..c0d289b51fb 100644
--- a/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
+++ b/regression-test/data/nereids_p0/subquery/test_subquery_in_project.out
@@ -54,3 +54,15 @@ true
-- !sql16 --
12
+-- !sql17 --
+12
+12
+
+-- !sql18 --
+12
+12
+
+-- !sql20 --
+5
+7
+
diff --git
a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
index b9de14e530b..25920848c6d 100644
--- a/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
+++ b/regression-test/suites/nereids_p0/subquery/test_subquery_in_project.groovy
@@ -117,11 +117,23 @@ suite("test_subquery_in_project") {
"""
qt_sql15 """
- select sum(age + (select sum(age) from test_sql)) from test_sql;
+ select sum(age + (select sum(age) from test_sql)) from test_sql order
by 1;
"""
qt_sql16 """
- select sum(distinct age + (select sum(age) from test_sql)) from
test_sql;
+ select sum(distinct age + (select sum(age) from test_sql)) from
test_sql order by 1;
+ """
+
+ qt_sql17 """
+ select sum(age + (select sum(age) from test_sql)) over() from test_sql
order by 1;
+ """
+
+ qt_sql18 """
+ select sum(age + (select sum(age) from test_sql)) over() from test_sql
group by dt, age order by 1;
+ """
+
+ qt_sql20 """
+ select sum(age + (select sum(age) from test_sql)) from test_sql group
by dt, age order by 1;
"""
sql """drop table if exists test_sql;"""
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]